als_update_rules.hpp
Go to the documentation of this file.00001
00028 #ifndef __MLPACK_METHODS_NMF_ALS_UPDATE_RULES_HPP
00029 #define __MLPACK_METHODS_NMF_ALS_UPDATE_RULES_HPP
00030
00031 #include <mlpack/core.hpp>
00032
00033 namespace mlpack {
00034 namespace nmf {
00035
00042 class WAlternatingLeastSquaresRule
00043 {
00044 public:
00045
00046 WAlternatingLeastSquaresRule() { }
00047
00056 template<typename MatType>
00057 inline static void Update(const MatType& V,
00058 arma::mat& W,
00059 const arma::mat& H)
00060 {
00061
00062
00063 W = V * H.t() * pinv(H * H.t());
00064
00065
00066 for (size_t i = 0; i < W.n_elem; i++)
00067 {
00068 if (W(i) < 0.0)
00069 {
00070 W(i) = 0.0;
00071 }
00072 }
00073 }
00074 };
00075
00082 class HAlternatingLeastSquaresRule
00083 {
00084 public:
00085
00086 HAlternatingLeastSquaresRule() { }
00087
00096 template<typename MatType>
00097 inline static void Update(const MatType& V,
00098 const arma::mat& W,
00099 arma::mat& H)
00100 {
00101 H = pinv(W.t() * W) * W.t() * V;
00102
00103
00104 for (size_t i = 0; i < H.n_elem; i++)
00105 {
00106 if (H(i) < 0.0)
00107 {
00108 H(i) = 0.0;
00109 }
00110 }
00111 }
00112 };
00113
00114 };
00115 };
00116
00117 #endif