00001
00023 #ifndef __MLPACK_METHODS_MOG_MOG_EM_HPP
00024 #define __MLPACK_METHODS_MOG_MOG_EM_HPP
00025
00026 #include <mlpack/core.hpp>
00027
00028
00029 #include "em_fit.hpp"
00030
00031 namespace mlpack {
00032 namespace gmm {
00033
00088 template<typename FittingType = EMFit<> >
00089 class GMM
00090 {
00091 private:
00093 size_t gaussians;
00095 size_t dimensionality;
00097 std::vector<arma::vec> means;
00099 std::vector<arma::mat> covariances;
00101 arma::vec weights;
00102
00103 public:
00107 GMM() :
00108 gaussians(0),
00109 dimensionality(0),
00110 localFitter(FittingType()),
00111 fitter(localFitter)
00112 {
00113
00114
00115
00116 Log::Debug << "GMM::GMM(): no parameters given; Estimate() may fail "
00117 << "unless parameters are set." << std::endl;
00118 }
00119
00127 GMM(const size_t gaussians, const size_t dimensionality) :
00128 gaussians(gaussians),
00129 dimensionality(dimensionality),
00130 means(gaussians, arma::vec(dimensionality)),
00131 covariances(gaussians, arma::mat(dimensionality, dimensionality)),
00132 weights(gaussians),
00133 localFitter(FittingType()),
00134 fitter(localFitter) { }
00135
00146 GMM(const size_t gaussians,
00147 const size_t dimensionality,
00148 FittingType& fitter) :
00149 gaussians(gaussians),
00150 dimensionality(dimensionality),
00151 means(gaussians, arma::vec(dimensionality)),
00152 covariances(gaussians, arma::mat(dimensionality, dimensionality)),
00153 weights(gaussians),
00154 fitter(fitter) { }
00155
00163 GMM(const std::vector<arma::vec>& means,
00164 const std::vector<arma::mat>& covariances,
00165 const arma::vec& weights) :
00166 gaussians(means.size()),
00167 dimensionality((!means.empty()) ? means[0].n_elem : 0),
00168 means(means),
00169 covariances(covariances),
00170 weights(weights),
00171 localFitter(FittingType()),
00172 fitter(localFitter) { }
00173
00183 GMM(const std::vector<arma::vec>& means,
00184 const std::vector<arma::mat>& covariances,
00185 const arma::vec& weights,
00186 FittingType& fitter) :
00187 gaussians(means.size()),
00188 dimensionality((!means.empty()) ? means[0].n_elem : 0),
00189 means(means),
00190 covariances(covariances),
00191 weights(weights),
00192 fitter(fitter) { }
00193
00197 template<typename OtherFittingType>
00198 GMM(const GMM<OtherFittingType>& other);
00199
00204 GMM(const GMM& other);
00205
00209 template<typename OtherFittingType>
00210 GMM& operator=(const GMM<OtherFittingType>& other);
00211
00216 GMM& operator=(const GMM& other);
00217
00224 void Load(const std::string& filename);
00225
00231 void Save(const std::string& filename) const;
00232
00234 size_t Gaussians() const { return gaussians; }
00237 size_t& Gaussians() { return gaussians; }
00238
00240 size_t Dimensionality() const { return dimensionality; }
00243 size_t& Dimensionality() { return dimensionality; }
00244
00246 const std::vector<arma::vec>& Means() const { return means; }
00248 std::vector<arma::vec>& Means() { return means; }
00249
00251 const std::vector<arma::mat>& Covariances() const { return covariances; }
00253 std::vector<arma::mat>& Covariances() { return covariances; }
00254
00256 const arma::vec& Weights() const { return weights; }
00258 arma::vec& Weights() { return weights; }
00259
00261 const FittingType& Fitter() const { return fitter; }
00263 FittingType& Fitter() { return fitter; }
00264
00271 double Probability(const arma::vec& observation) const;
00272
00280 double Probability(const arma::vec& observation,
00281 const size_t component) const;
00282
00289 arma::vec Random() const;
00290
00313 double Estimate(const arma::mat& observations,
00314 const size_t trials = 1,
00315 const bool useExistingModel = false);
00316
00341 double Estimate(const arma::mat& observations,
00342 const arma::vec& probabilities,
00343 const size_t trials = 1,
00344 const bool useExistingModel = false);
00345
00362 void Classify(const arma::mat& observations,
00363 arma::Col<size_t>& labels) const;
00364
00365 private:
00375 double LogLikelihood(const arma::mat& dataPoints,
00376 const std::vector<arma::vec>& means,
00377 const std::vector<arma::mat>& covars,
00378 const arma::vec& weights) const;
00379
00381 FittingType localFitter;
00382
00384 FittingType& fitter;
00385 };
00386
00387 };
00388 };
00389
00390
00391 #include "gmm_impl.hpp"
00392
00393 #endif