// Java program implementing Strassen's matrix multiplication algorithm

// Class implementing Strassen matrix multiplication
/**
 * Multiplies square {@code int} matrices with Strassen's divide-and-conquer
 * algorithm, which needs seven recursive half-size products per level
 * instead of the eight used by the schoolbook block method.
 *
 * <p>Arithmetic is plain {@code int} arithmetic, so results that do not fit
 * in 32 bits wrap around silently.
 */
public class StrassenMatrixMultiplication {

    /**
     * Returns the product {@code A * B} of two square matrices of equal size.
     *
     * <p>Any size is accepted: when the size is not a power of two both
     * operands are zero-padded up to the next power of two, multiplied, and
     * the result is cropped back. A 0x0 input yields a 0x0 result. The
     * arguments are never modified.
     *
     * @param A left operand, an n x n matrix
     * @param B right operand, an n x n matrix
     * @return a newly allocated n x n matrix holding {@code A * B}
     * @throws NullPointerException if {@code A} or {@code B} is null
     * @throws IllegalArgumentException if either matrix is not square or the
     *         two sizes differ
     */
    public int[][] multiply(int[][] A, int[][] B)
    {
        requireSquare(A, "A");
        requireSquare(B, "B");
        if (A.length != B.length) {
            throw new IllegalArgumentException(
                "Matrix sizes differ: A is " + A.length + "x" + A.length
                    + " but B is " + B.length + "x" + B.length);
        }

        int n = A.length;
        if (n == 0) {
            // Without this guard the recursion would never reach its base case.
            return new int[0][0];
        }

        int size = nextPowerOfTwo(n);
        if (size == n) {
            return strassen(A, B);
        }

        // Zero padding leaves the top-left n x n corner of the product
        // equal to A * B, so that corner is simply copied out afterwards.
        int[][] padded = strassen(pad(A, size), pad(B, size));
        int[][] R = new int[n][n];
        for (int i = 0; i < n; i++) {
            System.arraycopy(padded[i], 0, R[i], 0, n);
        }
        return R;
    }

    /**
     * Recursive Strassen step. Both operands must be square, of equal size,
     * and that size must be a power of two (at least 1).
     */
    private int[][] strassen(int[][] A, int[][] B)
    {
        int n = A.length;
        int[][] R = new int[n][n];

        if (n == 1) {
            R[0][0] = A[0][0] * B[0][0];
            return R;
        }

        int half = n / 2;

        // Quadrants of A
        int[][] A11 = new int[half][half];
        int[][] A12 = new int[half][half];
        int[][] A21 = new int[half][half];
        int[][] A22 = new int[half][half];
        split(A, A11, 0, 0);
        split(A, A12, 0, half);
        split(A, A21, half, 0);
        split(A, A22, half, half);

        // Quadrants of B
        int[][] B11 = new int[half][half];
        int[][] B12 = new int[half][half];
        int[][] B21 = new int[half][half];
        int[][] B22 = new int[half][half];
        split(B, B11, 0, 0);
        split(B, B12, 0, half);
        split(B, B21, half, 0);
        split(B, B22, half, half);

        // M1 = (A11 + A22) * (B11 + B22)
        int[][] M1 = strassen(add(A11, A22), add(B11, B22));

        // M2 = (A21 + A22) * B11
        int[][] M2 = strassen(add(A21, A22), B11);

        // M3 = A11 * (B12 - B22)
        int[][] M3 = strassen(A11, sub(B12, B22));

        // M4 = A22 * (B21 - B11)
        int[][] M4 = strassen(A22, sub(B21, B11));

        // M5 = (A11 + A12) * B22
        int[][] M5 = strassen(add(A11, A12), B22);

        // M6 = (A21 - A11) * (B11 + B12)
        int[][] M6 = strassen(sub(A21, A11), add(B11, B12));

        // M7 = (A12 - A22) * (B21 + B22)
        int[][] M7 = strassen(sub(A12, A22), add(B21, B22));

        // C11 = M1 + M4 - M5 + M7
        int[][] C11 = add(sub(add(M1, M4), M5), M7);

        // C12 = M3 + M5
        int[][] C12 = add(M3, M5);

        // C21 = M2 + M4
        int[][] C21 = add(M2, M4);

        // C22 = M1 + M3 - M2 + M6
        int[][] C22 = add(sub(add(M1, M3), M2), M6);

        join(C11, R, 0, 0);
        join(C12, R, 0, half);
        join(C21, R, half, 0);
        join(C22, R, half, half);

        return R;
    }

    /**
     * Returns the element-wise difference {@code A - B}.
     *
     * @param A minuend; its row count {@code n} fixes the n x n result size
     * @param B subtrahend, read over the same n x n index range
     * @return a newly allocated n x n matrix
     */
    public int[][] sub(int[][] A, int[][] B)
    {
        int n = A.length;
        int[][] C = new int[n][n];

        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                C[i][j] = A[i][j] - B[i][j];
            }
        }

        return C;
    }

    /**
     * Returns the element-wise sum {@code A + B}.
     *
     * @param A first addend; its row count {@code n} fixes the n x n result size
     * @param B second addend, read over the same n x n index range
     * @return a newly allocated n x n matrix
     */
    public int[][] add(int[][] A, int[][] B)
    {
        int n = A.length;
        int[][] C = new int[n][n];

        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                C[i][j] = A[i][j] + B[i][j];
            }
        }

        return C;
    }

    /**
     * Copies a square window of the parent matrix into a child matrix.
     *
     * @param P  parent matrix to read from
     * @param C  child matrix to fill; its row count sets the window's side
     * @param iB row in {@code P} where the window starts
     * @param jB column in {@code P} where the window starts
     */
    public void split(int[][] P, int[][] C, int iB, int jB)
    {
        for (int i1 = 0, i2 = iB; i1 < C.length; i1++, i2++) {
            for (int j1 = 0, j2 = jB; j1 < C.length; j1++, j2++) {
                C[i1][j1] = P[i2][j2];
            }
        }
    }

    /**
     * Copies a child matrix into a square window of the parent matrix.
     *
     * @param C  child matrix to read from; its row count sets the window's side
     * @param P  parent matrix to write into
     * @param iB row in {@code P} where the window starts
     * @param jB column in {@code P} where the window starts
     */
    public void join(int[][] C, int[][] P, int iB, int jB)
    {
        for (int i1 = 0, i2 = iB; i1 < C.length; i1++, i2++) {
            for (int j1 = 0, j2 = jB; j1 < C.length; j1++, j2++) {
                P[i2][j2] = C[i1][j1];
            }
        }
    }

    /** Throws unless {@code M} is non-null and every row has {@code M.length} entries. */
    private static void requireSquare(int[][] M, String name)
    {
        if (M == null) {
            throw new NullPointerException("Matrix " + name + " must not be null");
        }
        for (int i = 0; i < M.length; i++) {
            if (M[i] == null || M[i].length != M.length) {
                throw new IllegalArgumentException(
                    "Matrix " + name + " must be square: row " + i
                        + " does not have " + M.length + " columns");
            }
        }
    }

    /** Returns the smallest power of two that is at least {@code n}, for {@code n >= 1}. */
    private static int nextPowerOfTwo(int n)
    {
        int highest = Integer.highestOneBit(n);
        return (highest == n) ? n : highest << 1;
    }

    /** Returns a {@code size} x {@code size} copy of {@code M} with zeros in the added cells. */
    private static int[][] pad(int[][] M, int size)
    {
        int[][] out = new int[size][size];
        for (int i = 0; i < M.length; i++) {
            System.arraycopy(M[i], 0, out[i], 0, M.length);
        }
        return out;
    }

    /**
     * Demo driver: multiplies two fixed 4 x 4 matrices and prints the product.
     *
     * @param args unused
     */
    public static void main(String[] args)
    {
        System.out.println("Strassen Multiplication Algorithm Implementation For Matrix Multiplication :\n");

        StrassenMatrixMultiplication s = new StrassenMatrixMultiplication();

        // Sample 4 x 4 operands
        int[][] A = { { 1, 2, 5, 4 },
                      { 9, 3, 0, 6 },
                      { 4, 6, 3, 1 },
                      { 0, 2, 0, 6 } };

        int[][] B = { { 1, 0, 4, 1 },
                      { 1, 2, 0, 2 },
                      { 0, 3, 1, 3 },
                      { 1, 8, 1, 2 } };

        int[][] C = s.multiply(A, B);

        System.out.println("\nProduct of matrices A and B : ");

        // Print the result row by row, sized from the result itself
        for (int i = 0; i < C.length; i++) {
            for (int j = 0; j < C[i].length; j++) {
                System.out.print(C[i][j] + " ");
            }
            System.out.println();
        }
    }
}