package DynamicProgramming; import java.util.ArrayList; import java.util.Arrays; import java.util.Scanner; public class MatrixChainMultiplication { private static Scanner scan = new Scanner(System.in); private static ArrayList mArray = new ArrayList<>(); private static int size; private static int[][] m; private static int[][] s; private static int[] p; public static void main(String[] args) { int count = 1; while (true) { String[] mSize = input("input size of matrix A(" + count + ") ( ex. 10 20 ) : "); int col = Integer.parseInt(mSize[0]); if (col == 0) break; int row = Integer.parseInt(mSize[1]); Matrix matrix = new Matrix(count, col, row); mArray.add(matrix); count++; } for (Matrix m : mArray) { System.out.format("A(%d) = %2d x %2d%n", m.count(), m.col(), m.row()); } size = mArray.size(); m = new int[size + 1][size + 1]; s = new int[size + 1][size + 1]; p = new int[size + 1]; for (int i = 0; i < size + 1; i++) { Arrays.fill(m[i], -1); Arrays.fill(s[i], -1); } for (int i = 0; i < p.length; i++) { p[i] = i == 0 ? mArray.get(i).col() : mArray.get(i - 1).row(); } matrixChainOrder(); for (int i = 0; i < size; i++) { System.out.print("-------"); } System.out.println(); printArray(m); for (int i = 0; i < size; i++) { System.out.print("-------"); } System.out.println(); printArray(s); for (int i = 0; i < size; i++) { System.out.print("-------"); } System.out.println(); System.out.println("Optimal solution : " + m[1][size]); System.out.print("Optimal parens : "); printOptimalParens(1, size); } private static void printOptimalParens(int i, int j) { if (i == j) { System.out.print("A" + i); } else { System.out.print("("); printOptimalParens(i, s[i][j]); printOptimalParens(s[i][j] + 1, j); System.out.print(")"); } } private static void printArray(int[][] array) { for (int i = 1; i < size + 1; i++) { for (int j = 1; j < size + 1; j++) { System.out.print(String.format("%7d", array[i][j])); } System.out.println(); } } private static void matrixChainOrder() { for (int i = 1; i < size + 1; i++) { m[i][i] = 0; } for (int l = 2; l < size + 1; l++) { for (int i = 1; i < size - l + 2; i++) { int j = i + l - 1; m[i][j] = Integer.MAX_VALUE; for (int k = i; k < j; k++) { int q = m[i][k] + m[k + 1][j] + p[i - 1] * p[k] * p[j]; if (q < m[i][j]) { m[i][j] = q; s[i][j] = k; } } } } } private static String[] input(String string) { System.out.print(string); return (scan.nextLine().split(" ")); } } class Matrix { private int count; private int col; private int row; Matrix(int count, int col, int row) { this.count = count; this.col = col; this.row = row; } int count() { return count; } int col() { return col; } int row() { return row; } }