ergo
gemm_sse.h
Go to the documentation of this file.
1/* Ergo, version 3.8.2, a program for linear scaling electronic structure
2 * calculations.
3 * Copyright (C) 2023 Elias Rudberg, Emanuel H. Rubensson, Pawel Salek,
4 * and Anastasia Kruchinina.
5 *
6 * This program is free software: you can redistribute it and/or modify
7 * it under the terms of the GNU General Public License as published by
8 * the Free Software Foundation, either version 3 of the License, or
9 * (at your option) any later version.
10 *
11 * This program is distributed in the hope that it will be useful,
12 * but WITHOUT ANY WARRANTY; without even the implied warranty of
13 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 * GNU General Public License for more details.
15 *
16 * You should have received a copy of the GNU General Public License
17 * along with this program. If not, see <http://www.gnu.org/licenses/>.
18 *
19 * Primary academic reference:
20 * Ergo: An open-source program for linear-scaling electronic structure
21 * calculations,
22 * Elias Rudberg, Emanuel H. Rubensson, Pawel Salek, and Anastasia
23 * Kruchinina,
24 * SoftwareX 7, 107 (2018),
25 * <http://dx.doi.org/10.1016/j.softx.2018.03.005>
26 *
27 * For further information about Ergo, see <http://www.ergoscf.org>.
28 */
29
40#ifndef GEMM_SSE_H
41#define GEMM_SSE_H
#include <stdexcept>
#include "mm_kernel_inner_sse2_A.h"
#include "mm_kernel_outer_A.h"
45
46
47template<typename real, typename regType,
48 int m_kernel, int n_kernel, int k_kernel,
49 int m_block, int n_block>
50 static void gemm_sse(real const * const A,
51 real const * const B,
52 real * C,
53 size_t const m,
54 size_t const n,
55 size_t const k,
56 real * A_packed,
57 real * B_packed,
58 real * C_packed,
59 size_t const ap_size,
60 size_t const bp_size,
61 size_t const cp_size) {
62 // typedef double real; typedef __m128d regType;
63 // typedef float real; typedef __m128 regType;
66 if (m != m_kernel*m_block)
67 throw std::runtime_error("Error in gemm_sse(...): m != m_kernel*m_block");
68 if (n != n_kernel*n_block)
69 throw std::runtime_error("Error in gemm_sse(...): n != n_kernel*n_block");
70 if (k != k_kernel)
71 throw std::runtime_error("Error in gemm_sse(...): k != k_kernel");
72 if (ap_size < MM_outer::Pack_type_A::size_packed)
73 throw std::runtime_error("Error in gemm_sse(...): "
74 "ap_size < MM_outer::Pack_type_A::size_packed");
75 if (bp_size < MM_outer::Pack_type_B::size_packed)
76 throw std::runtime_error("Error in gemm_sse(...): "
77 "bp_size < MM_outer::Pack_type_B::size_packed");
78 if (cp_size < MM_outer::Pack_type_C::size_packed)
79 throw std::runtime_error("Error in gemm_sse(...): "
80 "cp_size < MM_outer::Pack_type_C::size_packed");
81 MM_outer::Pack_type_C::template pack<Ordering_col_wise>( C, C_packed, m, n);
82 MM_outer::Pack_type_A::template pack<Ordering_col_wise>( A, A_packed, m, k);
83 MM_outer::Pack_type_B::template pack<Ordering_col_wise>( B, B_packed, k, n);
84 MM_outer::exec(&A_packed, &B_packed, C_packed);
85 MM_outer::Pack_type_C::template unpack<Ordering_col_wise>(C, C_packed, m, n);
86}
87
/** Fallback gemm_sse for element types without an SSE kernel.
 *
 *  Only the explicit specializations for double and float in this header do
 *  real work. Any other element type still compiles, but calling it always
 *  throws. Every argument is ignored.
 *
 *  @throws std::runtime_error unconditionally.
 */
template<typename real>
static void gemm_sse(real const * const A,
                     real const * const B,
                     real * C,
                     size_t const m,
                     size_t const n,
                     size_t const k,
                     real * A_packed,
                     real * B_packed,
                     real * C_packed,
                     size_t const ap_size,
                     size_t const bp_size,
                     size_t const cp_size) {
  char const * const reason =
    "gemm_sse not implemented for chosen real type.";
  throw std::runtime_error(reason);
}
103
104template<>
105void gemm_sse(double const * const A,
106 double const * const B,
107 double * C,
108 size_t const m,
109 size_t const n,
110 size_t const k,
111 double * A_packed,
112 double * B_packed,
113 double * C_packed,
114 size_t const ap_size,
115 size_t const bp_size,
116 size_t const cp_size) {
117 gemm_sse<double, __m128d, 4, 4, 32, 8, 8>
118 (A, B, C, m, n, k,
119 A_packed, B_packed, C_packed, ap_size, bp_size, cp_size);
120}
121
122template<>
123void gemm_sse(float const * const A,
124 float const * const B,
125 float * C,
126 size_t const m,
127 size_t const n,
128 size_t const k,
129 float * A_packed,
130 float * B_packed,
131 float * C_packed,
132 size_t const ap_size,
133 size_t const bp_size,
134 size_t const cp_size) {
135 gemm_sse<float, __m128, 8, 4, 32, 4, 8>
136 (A, B, C, m, n, k,
137 A_packed, B_packed, C_packed, ap_size, bp_size, cp_size);
138}
139
140#endif
Matrix multiplication template for architectures with SSE2 or higher and compilers that support C++ i...
Definition: mm_kernel_inner_sse2_A.h:63
Template for matrix matrix multiplication that wraps around a kernel given as template argument.
Definition: mm_kernel_outer_A.h:53
ergo_real real
Definition: test.cc:46
static void gemm_sse(real const *const A, real const *const B, real *C, size_t const m, size_t const n, size_t const k, real *A_packed, real *B_packed, real *C_packed, size_t const ap_size, size_t const bp_size, size_t const cp_size)
Definition: gemm_sse.h:50
#define B
#define A
Templates for efficient gemm kernels.
Templates for efficient gemm kernels.