00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00040 #ifndef GEMM_SSE_H
00041 #define GEMM_SSE_H
00042 #include <stdexcept>
00043 #include "mm_kernel_inner_sse2_A.h"
00044 #include "mm_kernel_outer_A.h"
00045
00046
00047 template<typename real, typename regType,
00048 int m_kernel, int n_kernel, int k_kernel,
00049 int m_block, int n_block>
00050 static void gemm_sse(real const * const A,
00051 real const * const B,
00052 real * C,
00053 size_t const m,
00054 size_t const n,
00055 size_t const k,
00056 real * A_packed,
00057 real * B_packed,
00058 real * C_packed,
00059 size_t const ap_size,
00060 size_t const bp_size,
00061 size_t const cp_size) {
00062
00063
00064 typedef MM_kernel_inner_sse2_A<real, regType, m_kernel, n_kernel, k_kernel> MM_inner;
00065 typedef MM_kernel_outer_A<MM_inner, m_block, n_block> MM_outer;
00066 if (m != m_kernel*m_block)
00067 throw std::runtime_error("Error in gemm_sse(...): m != m_kernel*m_block");
00068 if (n != n_kernel*n_block)
00069 throw std::runtime_error("Error in gemm_sse(...): n != n_kernel*n_block");
00070 if (k != k_kernel)
00071 throw std::runtime_error("Error in gemm_sse(...): k != k_kernel");
00072 if (ap_size < MM_outer::Pack_type_A::size_packed)
00073 throw std::runtime_error("Error in gemm_sse(...): "
00074 "ap_size < MM_outer::Pack_type_A::size_packed");
00075 if (bp_size < MM_outer::Pack_type_B::size_packed)
00076 throw std::runtime_error("Error in gemm_sse(...): "
00077 "bp_size < MM_outer::Pack_type_B::size_packed");
00078 if (cp_size < MM_outer::Pack_type_C::size_packed)
00079 throw std::runtime_error("Error in gemm_sse(...): "
00080 "cp_size < MM_outer::Pack_type_C::size_packed");
00081 MM_outer::Pack_type_C::template pack<Ordering_col_wise>( C, C_packed, m, n);
00082 MM_outer::Pack_type_A::template pack<Ordering_col_wise>( A, A_packed, m, k);
00083 MM_outer::Pack_type_B::template pack<Ordering_col_wise>( B, B_packed, k, n);
00084 MM_outer::exec(&A_packed, &B_packed, C_packed);
00085 MM_outer::Pack_type_C::template unpack<Ordering_col_wise>(C, C_packed, m, n);
00086 }
00087
/** Generic fallback of the gemm_sse entry point.
 *
 * Selected for every element type that has no explicit specialization in
 * this header. It never touches its arguments and always fails at run time.
 *
 * @tparam real  Scalar element type.
 *
 * @param A, B, C                       Ignored.
 * @param m, n, k                       Ignored.
 * @param A_packed, B_packed, C_packed  Ignored.
 * @param ap_size, bp_size, cp_size     Ignored.
 *
 * @throws std::runtime_error always.
 */
template<typename real>
static void gemm_sse(real const * const A,
                     real const * const B,
                     real * C,
                     size_t const m,
                     size_t const n,
                     size_t const k,
                     real * A_packed,
                     real * B_packed,
                     real * C_packed,
                     size_t const ap_size,
                     size_t const bp_size,
                     size_t const cp_size) {
  std::runtime_error const not_implemented(
    "gemm_sse not implemented for chosen real type.");
  throw not_implemented;
}
00103
00104 template<>
00105 void gemm_sse(double const * const A,
00106 double const * const B,
00107 double * C,
00108 size_t const m,
00109 size_t const n,
00110 size_t const k,
00111 double * A_packed,
00112 double * B_packed,
00113 double * C_packed,
00114 size_t const ap_size,
00115 size_t const bp_size,
00116 size_t const cp_size) {
00117 gemm_sse<double, __m128d, 4, 4, 32, 8, 8>
00118 (A, B, C, m, n, k,
00119 A_packed, B_packed, C_packed, ap_size, bp_size, cp_size);
00120 }
00121
00122 template<>
00123 void gemm_sse(float const * const A,
00124 float const * const B,
00125 float * C,
00126 size_t const m,
00127 size_t const n,
00128 size_t const k,
00129 float * A_packed,
00130 float * B_packed,
00131 float * C_packed,
00132 size_t const ap_size,
00133 size_t const bp_size,
00134 size_t const cp_size) {
00135 gemm_sse<float, __m128, 8, 4, 32, 4, 8>
00136 (A, B, C, m, n, k,
00137 A_packed, B_packed, C_packed, ap_size, bp_size, cp_size);
00138 }
00139
00140 #endif