em_fit.hpp

Go to the documentation of this file.
00001 
00023 #ifndef __MLPACK_METHODS_GMM_EM_FIT_HPP
00024 #define __MLPACK_METHODS_GMM_EM_FIT_HPP
00025 
00026 #include <mlpack/core.hpp>
00027 
00028 // Default clustering mechanism.
00029 #include <mlpack/methods/kmeans/kmeans.hpp>
00030 // Default covariance matrix constraint.
00031 #include "positive_definite_constraint.hpp"
00032 
00033 namespace mlpack {
00034 namespace gmm {
00035 
00049 template<typename InitialClusteringType = kmeans::KMeans<>,
00050          typename CovarianceConstraintPolicy = PositiveDefiniteConstraint>
00051 class EMFit
00052 {
00053  public:
00071   EMFit(const size_t maxIterations = 300,
00072         const double tolerance = 1e-10,
00073         InitialClusteringType clusterer = InitialClusteringType(),
00074         CovarianceConstraintPolicy constraint = CovarianceConstraintPolicy());
00075 
00091   void Estimate(const arma::mat& observations,
00092                 std::vector<arma::vec>& means,
00093                 std::vector<arma::mat>& covariances,
00094                 arma::vec& weights,
00095                 const bool useInitialModel = false);
00096 
00114   void Estimate(const arma::mat& observations,
00115                 const arma::vec& probabilities,
00116                 std::vector<arma::vec>& means,
00117                 std::vector<arma::mat>& covariances,
00118                 arma::vec& weights,
00119                 const bool useInitialModel = false);
00120 
00122   const InitialClusteringType& Clusterer() const { return clusterer; }
00124   InitialClusteringType& Clusterer() { return clusterer; }
00125 
00127   const CovarianceConstraintPolicy& Constraint() const { return constraint; }
00129   CovarianceConstraintPolicy& Constraint() { return constraint; }
00130 
00132   size_t MaxIterations() const { return maxIterations; }
00134   size_t& MaxIterations() { return maxIterations; }
00135 
00137   double Tolerance() const { return tolerance; }
00139   double& Tolerance() { return tolerance; }
00140 
00141  private:
00152   void InitialClustering(const arma::mat& observations,
00153                          std::vector<arma::vec>& means,
00154                          std::vector<arma::mat>& covariances,
00155                          arma::vec& weights);
00156 
00167   double LogLikelihood(const arma::mat& data,
00168                        const std::vector<arma::vec>& means,
00169                        const std::vector<arma::mat>& covariances,
00170                        const arma::vec& weights) const;
00171 
00173   size_t maxIterations;
00175   double tolerance;
00177   InitialClusteringType clusterer;
00179   CovarianceConstraintPolicy constraint;
00180 };
00181 
00182 }; // namespace gmm
00183 }; // namespace mlpack
00184 
00185 // Include implementation.
00186 #include "em_fit_impl.hpp"
00187 
00188 #endif

Generated on 13 Aug 2014 for MLPACK by  doxygen 1.6.1