package Maths; import java.math.BigDecimal; import java.util.Arrays; import java.util.Objects; import java.util.Optional; import java.util.function.BiFunction; import java.util.stream.IntStream; /** * @author: caos321 * @date: 31 October 2021 (Sunday) */ public class MatrixUtil { public static boolean isValid(final BigDecimal[][] matrix) { return matrix != null && matrix.length > 0 && matrix[0].length > 0; } public static boolean hasEqualSizes(final BigDecimal[][] matrix1, final BigDecimal[][] matrix2) { return isValid(matrix1) && isValid(matrix2) && matrix1.length == matrix2.length && matrix1[0].length == matrix2[0].length; } public static boolean canMultiply(final BigDecimal[][] matrix1, final BigDecimal[][] matrix2) { return isValid(matrix1) && isValid(matrix2) && matrix1[0].length == matrix2.length; } public static Optional operate(final BigDecimal[][] matrix1, final BigDecimal[][] matrix2, final BiFunction operation) { if (!hasEqualSizes(matrix1, matrix2)) { return Optional.empty(); } final int rowSize = matrix1.length; final int columnSize = matrix1[0].length; final BigDecimal[][] result = new BigDecimal[rowSize][columnSize]; IntStream.range(0, rowSize).forEach(rowIndex -> IntStream.range(0, columnSize).forEach(columnIndex -> { final BigDecimal value1 = matrix1[rowIndex][columnIndex]; final BigDecimal value2 = matrix2[rowIndex][columnIndex]; result[rowIndex][columnIndex] = operation.apply(value1, value2); })); return Optional.of(result); } public static Optional add(final BigDecimal[][] matrix1, final BigDecimal[][] matrix2) { return operate(matrix1, matrix2, BigDecimal::add); } public static Optional subtract(final BigDecimal[][] matrix1, final BigDecimal[][] matrix2) { return operate(matrix1, matrix2, BigDecimal::subtract); } public static Optional multiply(final BigDecimal[][] matrix1, final BigDecimal[][] matrix2) { if (!canMultiply(matrix1, matrix2)) { return Optional.empty(); } final int size = matrix1[0].length; final int matrix1RowSize = matrix1.length; final int matrix2ColumnSize = matrix2[0].length; final BigDecimal[][] result = new BigDecimal[matrix1RowSize][matrix2ColumnSize]; IntStream.range(0, matrix1RowSize).forEach(rowIndex -> IntStream.range(0, matrix2ColumnSize).forEach(columnIndex -> result[rowIndex][columnIndex] = IntStream.range(0, size).mapToObj(index -> { final BigDecimal value1 = matrix1[rowIndex][index]; final BigDecimal value2 = matrix2[index][columnIndex]; return value1.multiply(value2); }) .reduce(BigDecimal.ZERO, BigDecimal::add) ) ); return Optional.of(result); } public static void assertThat(final BigDecimal[][] actual, final BigDecimal[][] expected) { if (!Objects.deepEquals(actual, expected)) { throw new AssertionError(String.format( "expected=%s but was actual=%s", Arrays.deepToString(expected), Arrays.deepToString(actual) )); } } public static void main(final String[] args) { { final BigDecimal[][] matrix1 = { {new BigDecimal(3), new BigDecimal(2)}, {new BigDecimal(0), new BigDecimal(1)}, }; final BigDecimal[][] matrix2 = { {new BigDecimal(1), new BigDecimal(3)}, {new BigDecimal(2), new BigDecimal(0)}, }; final BigDecimal[][] actual = add(matrix1, matrix2) .orElseThrow(() -> new AssertionError("Could not compute matrix!")); final BigDecimal[][] expected = { {new BigDecimal(4), new BigDecimal(5)}, {new BigDecimal(2), new BigDecimal(1)} }; assertThat(actual, expected); } { final BigDecimal[][] matrix1 = { {new BigDecimal(1), new BigDecimal(4)}, {new BigDecimal(5), new BigDecimal(6)}, }; final BigDecimal[][] matrix2 = { {new BigDecimal(2), new BigDecimal(0)}, {new BigDecimal(-2), new BigDecimal(-3)}, }; final BigDecimal[][] actual = subtract(matrix1, matrix2) .orElseThrow(() -> new AssertionError("Could not compute matrix!")); final BigDecimal[][] expected = { {new BigDecimal(-1), new BigDecimal(4)}, {new BigDecimal(7), new BigDecimal(9)} }; assertThat(actual, expected); } { final BigDecimal[][] matrix1 = { {new BigDecimal(1), new BigDecimal(2), new BigDecimal(3)}, {new BigDecimal(4), new BigDecimal(5), new BigDecimal(6)}, {new BigDecimal(7), new BigDecimal(8), new BigDecimal(9)} }; final BigDecimal[][] matrix2 = { {new BigDecimal(1), new BigDecimal(2)}, {new BigDecimal(3), new BigDecimal(4)}, {new BigDecimal(5), new BigDecimal(6)} }; final BigDecimal[][] actual = multiply(matrix1, matrix2) .orElseThrow(() -> new AssertionError("Could not compute matrix!")); final BigDecimal[][] expected = { {new BigDecimal(22), new BigDecimal(28)}, {new BigDecimal(49), new BigDecimal(64)}, {new BigDecimal(76), new BigDecimal(100)} }; assertThat(actual, expected); } } }