NASA Logo
Ocean Color Science Software

ocssw V2022
Strassen.h
Go to the documentation of this file.
1 #include "Matrix.h"
2 
3 float** strassenMultiply(float** A, float** B, float n) {
4  if (n == 1) {
5  float** C = initializeMatrix(1);
6  C[0][0] = A[0][0] * B[0][0];
7  return C;
8  }
9  float** C = initializeMatrix(n);
10  int k = n / 2;
11 
12  float** A11 = initializeMatrix(k);
13  float** A12 = initializeMatrix(k);
14  float** A21 = initializeMatrix(k);
15  float** A22 = initializeMatrix(k);
16  float** B11 = initializeMatrix(k);
17  float** B12 = initializeMatrix(k);
18  float** B21 = initializeMatrix(k);
19  float** B22 = initializeMatrix(k);
20 
21  for (int i = 0; i < k; i++)
22  for (int j = 0; j < k; j++) {
23  A11[i][j] = A[i][j];
24  A12[i][j] = A[i][k + j];
25  A21[i][j] = A[k + i][j];
26  A22[i][j] = A[k + i][k + j];
27  B11[i][j] = B[i][j];
28  B12[i][j] = B[i][k + j];
29  B21[i][j] = B[k + i][j];
30  B22[i][j] = B[k + i][k + j];
31  }
32 
33  float** P1 = strassenMultiply(A11, subtract(B12, B22, k), k);
34  float** P2 = strassenMultiply(add(A11, A12, k), B22, k);
35  float** P3 = strassenMultiply(add(A21, A22, k), B11, k);
36  float** P4 = strassenMultiply(A22, subtract(B21, B11, k), k);
37  float** P5 = strassenMultiply(add(A11, A22, k), add(B11, B22, k), k);
38  float** P6 = strassenMultiply(subtract(A12, A22, k), add(B21, B22, k), k);
39  float** P7 = strassenMultiply(subtract(A11, A21, k), add(B11, B12, k), k);
40 
41  float** C11 = subtract(add(add(P5, P4, k), P6, k), P2, k);
42  float** C12 = add(P1, P2, k);
43  float** C21 = add(P3, P4, k);
44  float** C22 = subtract(subtract(add(P5, P1, k), P3, k), P7, k);
45 
46  for (int i = 0; i < k; i++)
47  for (int j = 0; j < k; j++) {
48  C[i][j] = C11[i][j];
49  C[i][j + k] = C12[i][j];
50  C[k + i][j] = C21[i][j];
51  C[k + i][k + j] = C22[i][j];
52  }
53 
54  for (int i = 0; i < k; i++) {
55  delete[] A11[i];
56  delete[] A12[i];
57  delete[] A21[i];
58  delete[] A22[i];
59  delete[] B11[i];
60  delete[] B12[i];
61  delete[] B21[i];
62  delete[] B22[i];
63  delete[] P1[i];
64  delete[] P2[i];
65  delete[] P3[i];
66  delete[] P4[i];
67  delete[] P5[i];
68  delete[] P6[i];
69  delete[] P7[i];
70  delete[] C11[i];
71  delete[] C12[i];
72  delete[] C21[i];
73  delete[] C22[i];
74  }
75 
76  delete[] A11;
77  delete[] A12;
78  delete[] A21;
79  delete[] A22;
80  delete[] B11;
81  delete[] B12;
82  delete[] B21;
83  delete[] B22;
84  delete[] P1;
85  delete[] P2;
86  delete[] P3;
87  delete[] P4;
88  delete[] P5;
89  delete[] P6;
90  delete[] P7;
91  delete[] C11;
92  delete[] C12;
93  delete[] C21;
94  delete[] C22;
95 
96  return C;
97 }
int j
Definition: decode_rs.h:73
@ B11
Definition: l1_msi.cpp:30
const double C
Definition: calc_par.c:102
const float A
Definition: calc_par.c:100
@ B12
Definition: l1_msi.cpp:30
float ** subtract(float **M1, float **M2, int n)
Definition: Matrix.h:35
float ** add(float **M1, float **M2, int n)
Definition: Matrix.h:27
float ** initializeMatrix(int n)
Definition: Matrix.h:3
const float B
Definition: calc_par.c:101
int i
Definition: decode_rs.h:71
int k
Definition: decode_rs.h:73
float ** strassenMultiply(float **A, float **B, float n)
Definition: Strassen.h:3