ergo
mm_kernel_outer_A.h
Go to the documentation of this file.
1/* Ergo, version 3.8.2, a program for linear scaling electronic structure
2 * calculations.
3 * Copyright (C) 2023 Elias Rudberg, Emanuel H. Rubensson, Pawel Salek,
4 * and Anastasia Kruchinina.
5 *
6 * This program is free software: you can redistribute it and/or modify
7 * it under the terms of the GNU General Public License as published by
8 * the Free Software Foundation, either version 3 of the License, or
9 * (at your option) any later version.
10 *
11 * This program is distributed in the hope that it will be useful,
12 * but WITHOUT ANY WARRANTY; without even the implied warranty of
13 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 * GNU General Public License for more details.
15 *
16 * You should have received a copy of the GNU General Public License
17 * along with this program. If not, see <http://www.gnu.org/licenses/>.
18 *
19 * Primary academic reference:
20 * Ergo: An open-source program for linear-scaling electronic structure
21 * calculations,
22 * Elias Rudberg, Emanuel H. Rubensson, Pawel Salek, and Anastasia
23 * Kruchinina,
24 * SoftwareX 7, 107 (2018),
25 * <http://dx.doi.org/10.1016/j.softx.2018.03.005>
26 *
27 * For further information about Ergo, see <http://www.ergoscf.org>.
28 */
29
40#ifndef MM_KERNEL_OUTER_A_H
41#define MM_KERNEL_OUTER_A_H
42#include "common.h"
43#ifdef _OPENMP
44#include <omp.h>
45#endif
46
52template<typename T_gemm_kernel, int T_M_block, int T_N_block>
54 template<int T_rows_block, int T_cols_block, typename T_ordering_block, typename T_pack_type_kernel>
55 class Pack;
56 public:
57 static int const M_kernel = T_gemm_kernel::M;
58 static int const N_kernel = T_gemm_kernel::N;
59 static int const K_kernel = T_gemm_kernel::K;
60 static int const M_block = T_M_block;
61 static int const N_block = T_N_block;
62 static int const K_block = 1;
63 static int const M = M_kernel * M_block;
64 static int const N = N_kernel * N_block;
65 static int const K = K_kernel * K_block;
66 typedef typename T_gemm_kernel::real real;
71
78 static void exec( real const * const * const A,
79 real const * const * const B,
80 real * const C,
81 int const i = 1);
82
83};
84
85template<typename T_gemm_kernel, int T_M_block, int T_N_block>
87 real const * const * const B,
88 real * const C,
89 int const n_mul ) {
90#if 1
91 for ( int n = 0; n < N_block; ++n )
92 for ( int m = 0; m < M_block; ++m ) {
93 T_gemm_kernel::exec( A, B, C, n_mul,
94 Ordering_block_A::get( m, 0, M_block, K_block ) * T_gemm_kernel::Pack_type_A::size_packed,
95 Ordering_block_B::get( 0, n, K_block, N_block ) * T_gemm_kernel::Pack_type_B::size_packed,
96 Ordering_block_C::get( m, n, M_block, N_block ) * T_gemm_kernel::Pack_type_C::size_packed );
97 }
98
99#else
100#if 1
101 // FIXME: This is faster since the offsets are known at compile time, TODO: unroll for loops...
102 T_gemm_kernel::template exec<Ordering_block_A::template Get<0, 0, M_block, K_block>::index * T_gemm_kernel::Pack_type_A::size_packed,
103 Ordering_block_B::template Get<0, 0, K_block, N_block>::index * T_gemm_kernel::Pack_type_B::size_packed,
104 Ordering_block_C::template Get<0, 0, M_block, N_block>::index * T_gemm_kernel::Pack_type_C::size_packed>( A, B, C, n_mul );
105 T_gemm_kernel::template exec<Ordering_block_A::template Get<1, 0, M_block, K_block>::index * T_gemm_kernel::Pack_type_A::size_packed,
106 Ordering_block_B::template Get<0, 0, K_block, N_block>::index * T_gemm_kernel::Pack_type_B::size_packed,
107 Ordering_block_C::template Get<1, 0, M_block, N_block>::index * T_gemm_kernel::Pack_type_C::size_packed>( A, B, C, n_mul );
108#else
109 T_gemm_kernel::exec( A, B, C, n_mul,
110 Ordering_block_A::get( 0, 0, M_block, K_block ) * T_gemm_kernel::Pack_type_A::size_packed,
111 Ordering_block_B::get( 0, 0, K_block, N_block ) * T_gemm_kernel::Pack_type_B::size_packed,
112 Ordering_block_C::get( 0, 0, M_block, N_block ) * T_gemm_kernel::Pack_type_C::size_packed );
113#endif
114#endif
115}
116
117
127template<typename T_gemm_kernel, int T_M_block, int T_N_block>
128 template<int T_rows_block, int T_cols_block, typename T_ordering_block, typename T_pack_type_kernel>
129 class MM_kernel_outer_A<T_gemm_kernel, T_M_block, T_N_block>::Pack {
130 static int const rows_kernel = T_pack_type_kernel::rows;
131 static int const cols_kernel = T_pack_type_kernel::cols;
132 public:
133 static int const rows = rows_kernel * T_rows_block;
134 static int const cols = cols_kernel * T_cols_block;
138 // static int const size_packed = rows * cols * T_pack_type_kernel::size_packed;
139 static unsigned int const size_packed = T_rows_block * T_cols_block * T_pack_type_kernel::size_packed;
140 // typedef Packed<T_real, T_rows_block, T_cols_block, T_rows_kernel, T_cols_kernel, T_kernel_index, T_repetitions> ThisType;
141
142 template<typename T_ordering_matrix>
143 struct Assign_to_packed : public T_pack_type_kernel::template Assign_to_packed<T_ordering_matrix> {
144 typedef T_ordering_matrix Ordering_matrix;
145 };
146 template<typename T_ordering_matrix>
147 struct Extract_from_packed : public T_pack_type_kernel::template Extract_from_packed<T_ordering_matrix> {
148 typedef T_ordering_matrix Ordering_matrix;
149 };
150
151
156 template<template<typename T_ordering> class T_assign, typename T_ordering_matrix>
157 static void exec(typename T_assign<T_ordering_matrix>::PtrType X, typename T_assign<T_ordering_matrix>::PtrTypePacked X_packed,
158 int const rows_total_matrix, int const cols_total_matrix) {
159 // Loop over column blocks of new packed matrix
160 for ( int col_b = 0; col_b < T_cols_block; ++col_b ) {
161 // Loop over row blocks of new packed matrix
162 for ( int row_b = 0; row_b < T_rows_block; ++row_b ) {
163 T_pack_type_kernel::template exec< T_assign, T_ordering_matrix >
164 ( &X[ T_assign<T_ordering_matrix>::Ordering_matrix::get( row_b * rows_kernel, col_b * cols_kernel,
165 rows_total_matrix, cols_total_matrix ) ],
166 &X_packed[ T_ordering_block::get( row_b, col_b, T_rows_block, T_cols_block ) *
167 T_pack_type_kernel::size_packed ],
168 rows_total_matrix, cols_total_matrix );
169 // Indexes of original matrix : ( row_b * rows_kernel, col_b * cols_kernel )
170 // Block indexes (packed matrix) : ( row_b, col_b )
171 // Number of reals needed for each kernel : T_pack_type_kernel::size_packed
172 }
173 }
174 } // end exec()
175
180 template<typename T_ordering_matrix>
181 inline static void pack(real const * const X, real * X_packed,
182 int const rows_total_matrix, int const cols_total_matrix) {
183 exec< Assign_to_packed, T_ordering_matrix >(X, X_packed, rows_total_matrix, cols_total_matrix);
184 }
189 template<typename T_ordering_matrix>
190 inline static void unpack(real * X, real const * const X_packed,
191 int const rows_total_matrix, int const cols_total_matrix) {
192 exec< Extract_from_packed, T_ordering_matrix >(X, X_packed, rows_total_matrix, cols_total_matrix);
193 }
194
195 // real * values;
196};
197#endif
Template for translations between unpacked and packed matrix storage.
Definition: mm_kernel_outer_A.h:129
static void exec(typename T_assign< T_ordering_matrix >::PtrType X, typename T_assign< T_ordering_matrix >::PtrTypePacked X_packed, int const rows_total_matrix, int const cols_total_matrix)
General function that either assigns to or extracts from the packed format, depending on the T_assign template argument.
Definition: mm_kernel_outer_A.h:157
static void unpack(real *X, real const *const X_packed, int const rows_total_matrix, int const cols_total_matrix)
Convenience function for extracting matrix from packed matrix.
Definition: mm_kernel_outer_A.h:190
static void pack(real const *const X, real *X_packed, int const rows_total_matrix, int const cols_total_matrix)
Convenience function for assignments to packed matrix.
Definition: mm_kernel_outer_A.h:181
Template for matrix matrix multiplication that wraps around a kernel given as template argument.
Definition: mm_kernel_outer_A.h:53
Ordering_col_wise Ordering_block_A
Definition: mm_kernel_outer_A.h:68
static int const M
Number of rows of A and C.
Definition: mm_kernel_outer_A.h:63
static int const N_kernel
Number of columns of B and C kernels.
Definition: mm_kernel_outer_A.h:58
T_gemm_kernel::real real
Real number type (usually float or double)
Definition: mm_kernel_outer_A.h:66
static int const K_kernel
Number of columns of A kernels and rows of B kernels.
Definition: mm_kernel_outer_A.h:59
static int const K
Number of columns of A and rows of B.
Definition: mm_kernel_outer_A.h:65
static int const N_block
Number of columns of B and C (blocks).
Definition: mm_kernel_outer_A.h:61
static int const M_kernel
Number of rows of A and C kernels.
Definition: mm_kernel_outer_A.h:57
static void exec(real const *const *const A, real const *const *const B, real *const C, int const i=1)
Executes the matrix-matrix multiply C += A B with the three matrices A, B, and C stored using the pac...
Definition: mm_kernel_outer_A.h:86
Pack< M_block, K_block, Ordering_block_A, typename T_gemm_kernel::Pack_type_A > Pack_type_A
Definition: mm_kernel_outer_A.h:72
Pack< K_block, N_block, Ordering_block_B, typename T_gemm_kernel::Pack_type_B > Pack_type_B
Definition: mm_kernel_outer_A.h:73
static int const K_block
Number of columns of A and rows of B (blocks).
Definition: mm_kernel_outer_A.h:62
Ordering_col_wise Ordering_block_C
Definition: mm_kernel_outer_A.h:70
static int const M_block
Number of rows of A and C (blocks).
Definition: mm_kernel_outer_A.h:60
static int const N
Number of columns of B and C.
Definition: mm_kernel_outer_A.h:64
Pack< M_block, N_block, Ordering_block_C, typename T_gemm_kernel::Pack_type_C > Pack_type_C
Definition: mm_kernel_outer_A.h:74
Ordering_col_wise Ordering_block_B
Definition: mm_kernel_outer_A.h:69
Macros for inlining and static assertions and structs for access to matrix elements specifying the la...
mat::SizesAndBlocks rows
Definition: test.cc:51
mat::SizesAndBlocks cols
Definition: test.cc:52
#define B
#define A
Definition: mm_kernel_outer_A.h:143
T_ordering_matrix Ordering_matrix
Definition: mm_kernel_outer_A.h:144
Definition: mm_kernel_outer_A.h:147
T_ordering_matrix Ordering_matrix
Definition: mm_kernel_outer_A.h:148
Struct for access to matrix elements stored in column wise order.
Definition: common.h:104