From 65e32641fc96012ae840ebc93eb2176589527783 Mon Sep 17 00:00:00 2001 From: Alex Klymenko Date: Mon, 9 Sep 2024 09:15:41 +0200 Subject: [PATCH] refactor: `InverseOfMatrix` (#5446) refactor: InverseOfMatrix --- .../thealgorithms/misc/InverseOfMatrix.java | 74 ++++++------------- .../misc/InverseOfMatrixTest.java | 28 +++++++ 2 files changed, 52 insertions(+), 50 deletions(-) create mode 100644 src/test/java/com/thealgorithms/misc/InverseOfMatrixTest.java diff --git a/src/main/java/com/thealgorithms/misc/InverseOfMatrix.java b/src/main/java/com/thealgorithms/misc/InverseOfMatrix.java index 5543463e..706feab0 100644 --- a/src/main/java/com/thealgorithms/misc/InverseOfMatrix.java +++ b/src/main/java/com/thealgorithms/misc/InverseOfMatrix.java @@ -1,57 +1,29 @@ package com.thealgorithms.misc; -import java.util.Scanner; - -/* - * Wikipedia link : https://en.wikipedia.org/wiki/Invertible_matrix - * - * Here we use gauss elimination method to find the inverse of a given matrix. - * To understand gauss elimination method to find inverse of a matrix: - * https://www.sangakoo.com/en/unit/inverse-matrix-method-of-gaussian-elimination - * - * We can also find the inverse of a matrix +/** + * This class provides methods to compute the inverse of a square matrix + * using Gaussian elimination. For more details, refer to: + * https://en.wikipedia.org/wiki/Invertible_matrix */ public final class InverseOfMatrix { private InverseOfMatrix() { } - public static void main(String[] argv) { - Scanner input = new Scanner(System.in); - System.out.println("Enter the matrix size (Square matrix only): "); - int n = input.nextInt(); - double[][] a = new double[n][n]; - System.out.println("Enter the elements of matrix: "); - for (int i = 0; i < n; i++) { - for (int j = 0; j < n; j++) { - a[i][j] = input.nextDouble(); - } - } - - double[][] d = invert(a); - System.out.println(); - System.out.println("The inverse is: "); - for (int i = 0; i < n; ++i) { - for (int j = 0; j < n; ++j) { - System.out.print(d[i][j] + " "); - } - System.out.println(); - } - input.close(); - } - public static double[][] invert(double[][] a) { int n = a.length; double[][] x = new double[n][n]; double[][] b = new double[n][n]; int[] index = new int[n]; + + // Initialize the identity matrix for (int i = 0; i < n; ++i) { b[i][i] = 1; } - // Transform the matrix into an upper triangle + // Perform Gaussian elimination gaussian(a, index); - // Update the matrix b[i][j] with the ratios stored + // Update matrix b with the ratios stored during elimination for (int i = 0; i < n - 1; ++i) { for (int j = i + 1; j < n; ++j) { for (int k = 0; k < n; ++k) { @@ -60,7 +32,7 @@ public final class InverseOfMatrix { } } - // Perform backward substitutions + // Perform backward substitution to find the inverse for (int i = 0; i < n; ++i) { x[n - 1][i] = b[index[n - 1]][i] / a[index[n - 1]][n - 1]; for (int j = n - 2; j >= 0; --j) { @@ -73,19 +45,20 @@ public final class InverseOfMatrix { } return x; } - - // Method to carry out the partial-pivoting Gaussian - // elimination. Here index[] stores pivoting order. - public static void gaussian(double[][] a, int[] index) { + /** + * Method to carry out the partial-pivoting Gaussian + * elimination. Here index[] stores pivoting order. + **/ + private static void gaussian(double[][] a, int[] index) { int n = index.length; double[] c = new double[n]; - // Initialize the index + // Initialize the index array for (int i = 0; i < n; ++i) { index[i] = i; } - // Find the rescaling factors, one from each row + // Find the rescaling factors for each row for (int i = 0; i < n; ++i) { double c1 = 0; for (int j = 0; j < n; ++j) { @@ -97,22 +70,23 @@ public final class InverseOfMatrix { c[i] = c1; } - // Search the pivoting element from each column - int k = 0; + // Perform pivoting for (int j = 0; j < n - 1; ++j) { double pi1 = 0; + int k = j; for (int i = j; i < n; ++i) { - double pi0 = Math.abs(a[index[i]][j]); - pi0 /= c[index[i]]; + double pi0 = Math.abs(a[index[i]][j]) / c[index[i]]; if (pi0 > pi1) { pi1 = pi0; k = i; } } - // Interchange rows according to the pivoting order - int itmp = index[j]; + + // Swap rows + int temp = index[j]; index[j] = index[k]; - index[k] = itmp; + index[k] = temp; + for (int i = j + 1; i < n; ++i) { double pj = a[index[i]][j] / a[index[j]][j]; diff --git a/src/test/java/com/thealgorithms/misc/InverseOfMatrixTest.java b/src/test/java/com/thealgorithms/misc/InverseOfMatrixTest.java new file mode 100644 index 00000000..2f20de44 --- /dev/null +++ b/src/test/java/com/thealgorithms/misc/InverseOfMatrixTest.java @@ -0,0 +1,28 @@ +package com.thealgorithms.misc; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; + +import java.util.stream.Stream; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; + +class InverseOfMatrixTest { + + @ParameterizedTest + @MethodSource("provideTestCases") + void testInvert(double[][] matrix, double[][] expectedInverse) { + double[][] result = InverseOfMatrix.invert(matrix); + assertMatrixEquals(expectedInverse, result); + } + + private static Stream provideTestCases() { + return Stream.of(Arguments.of(new double[][] {{1, 0, 0}, {0, 1, 0}, {0, 0, 1}}, new double[][] {{1, 0, 0}, {0, 1, 0}, {0, 0, 1}}), Arguments.of(new double[][] {{4, 7}, {2, 6}}, new double[][] {{0.6, -0.7}, {-0.2, 0.4}})); + } + + private void assertMatrixEquals(double[][] expected, double[][] actual) { + for (int i = 0; i < expected.length; i++) { + assertArrayEquals(expected[i], actual[i], 1.0E-10, "Row " + i + " is not equal"); + } + } +}