JavaAlgorithms/DivideAndConquer/StrassenMatrixMultiplication.java
2021-10-16 16:43:51 +03:00

184 lines
5.0 KiB
Java
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package DivideAndConquer;
// Java Program to Implement Strassen Algorithm
// Class Strassen matrix multiplication
public class StrassenMatrixMultiplication {
// Method 1
// Function to multiply matrices
public int[][] multiply(int[][] A, int[][] B)
{
int n = A.length;
int[][] R = new int[n][n];
if (n == 1)
R[0][0] = A[0][0] * B[0][0];
else {
// Dividing Matrix into parts
// by storing sub-parts to variables
int[][] A11 = new int[n / 2][n / 2];
int[][] A12 = new int[n / 2][n / 2];
int[][] A21 = new int[n / 2][n / 2];
int[][] A22 = new int[n / 2][n / 2];
int[][] B11 = new int[n / 2][n / 2];
int[][] B12 = new int[n / 2][n / 2];
int[][] B21 = new int[n / 2][n / 2];
int[][] B22 = new int[n / 2][n / 2];
// Dividing matrix A into 4 parts
split(A, A11, 0, 0);
split(A, A12, 0, n / 2);
split(A, A21, n / 2, 0);
split(A, A22, n / 2, n / 2);
// Dividing matrix B into 4 parts
split(B, B11, 0, 0);
split(B, B12, 0, n / 2);
split(B, B21, n / 2, 0);
split(B, B22, n / 2, n / 2);
// Using Formulas as described in algorithm
// M1:=(A1+A3)×(B1+B2)
int[][] M1
= multiply(add(A11, A22), add(B11, B22));
// M2:=(A2+A4)×(B3+B4)
int[][] M2 = multiply(add(A21, A22), B11);
// M3:=(A1A4)×(B1+A4)
int[][] M3 = multiply(A11, sub(B12, B22));
// M4:=A1×(B2B4)
int[][] M4 = multiply(A22, sub(B21, B11));
// M5:=(A3+A4)×(B1)
int[][] M5 = multiply(add(A11, A12), B22);
// M6:=(A1+A2)×(B4)
int[][] M6
= multiply(sub(A21, A11), add(B11, B12));
// M7:=A4×(B3B1)
int[][] M7
= multiply(sub(A12, A22), add(B21, B22));
// P:=M2+M3M6M7
int[][] C11 = add(sub(add(M1, M4), M5), M7);
// Q:=M4+M6
int[][] C12 = add(M3, M5);
// R:=M5+M7
int[][] C21 = add(M2, M4);
// S:=M1M3M4M5
int[][] C22 = add(sub(add(M1, M3), M2), M6);
join(C11, R, 0, 0);
join(C12, R, 0, n / 2);
join(C21, R, n / 2, 0);
join(C22, R, n / 2, n / 2);
}
return R;
}
// Method 2
// Function to subtract two matrices
public int[][] sub(int[][] A, int[][] B)
{
int n = A.length;
int[][] C = new int[n][n];
for (int i = 0; i < n; i++)
for (int j = 0; j < n; j++)
C[i][j] = A[i][j] - B[i][j];
return C;
}
// Method 3
// Function to add two matrices
public int[][] add(int[][] A, int[][] B)
{
int n = A.length;
int[][] C = new int[n][n];
for (int i = 0; i < n; i++)
for (int j = 0; j < n; j++)
C[i][j] = A[i][j] + B[i][j];
return C;
}
// Method 4
// Function to split parent matrix
// into child matrices
public void split(int[][] P, int[][] C, int iB, int jB)
{
for (int i1 = 0, i2 = iB; i1 < C.length; i1++, i2++)
for (int j1 = 0, j2 = jB; j1 < C.length; j1++, j2++)
C[i1][j1] = P[i2][j2];
}
// Method 5
// Function to join child matrices
// into (to) parent matrix
public void join(int[][] C, int[][] P, int iB, int jB)
{
for (int i1 = 0, i2 = iB; i1 < C.length; i1++, i2++)
for (int j1 = 0, j2 = jB; j1 < C.length; j1++, j2++)
P[i2][j2] = C[i1][j1];
}
// Method 5
// Main driver method
public static void main(String[] args)
{
System.out.println("Strassen Multiplication Algorithm Implementation For Matrix Multiplication :\n");
StrassenMatrixMultiplication s = new StrassenMatrixMultiplication();
// Size of matrix
// Considering size as 4 in order to illustrate
int N = 4;
// Matrix A
// Custom input to matrix
int[][] A = { { 1, 2, 5, 4 },
{ 9, 3, 0, 6 },
{ 4, 6, 3, 1 },
{ 0, 2, 0, 6 } };
// Matrix B
// Custom input to matrix
int[][] B = { { 1, 0, 4, 1 },
{ 1, 2, 0, 2 },
{ 0, 3, 1, 3 },
{ 1, 8, 1, 2 } };
// Matrix C computations
// Matrix C calling method to get Result
int[][] C = s.multiply(A, B);
System.out.println("\nProduct of matrices A and B : ");
// Print the output
for (int i = 0; i < N; i++) {
for (int j = 0; j < N; j++)
System.out.print(C[i][j] + " ");
System.out.println();
}
}
}