Update MatrixChainMultiplication.java

This commit is contained in:
Libin Yang 2018-12-21 09:51:26 +08:00 committed by GitHub
parent 2da7d89123
commit 3183957ec7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,120 +1,134 @@
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Scanner;
public class MatrixChainMultiplicationTest {
private static Scanner scan = new Scanner(System.in);
private static ArrayList<Matrix> mArray = new ArrayList<>();
private static int size;
private static int[][] m;
private static int[][] s;
private static int[] p;
public static void main(String[] args) {
int count = 1;
while(true) {
String [] mSize = input("input size of matrix A("+count+") ( ex. 10 20 ) : ");
int col = Integer.parseInt(mSize[0]);
if (col==0) break;
int row = Integer.parseInt(mSize[1]);
Matrix matrix = new Matrix(count, col, row);
mArray.add(matrix);
count++;
}
for(Matrix m : mArray) {
System.out.format("A(%d) = %2d x %2d\n", m.count(), m.col(), m.row());
}
size = mArray.size();
m = new int[size + 1][size + 1];
s = new int[size + 1][size + 1];
p = new int[size + 1];
public class MatrixChainMultiplication {
private static Scanner scan = new Scanner(System.in);
private static ArrayList<Matrix> mArray = new ArrayList<>();
private static int size;
private static int[][] m;
private static int[][] s;
private static int[] p;
for (int i=0; i<size+1; i++) {
for (int j=0; j<size+1; j++) {
m[i][j] = -1;
s[i][j] = -1;
}
}
public static void main(String[] args) {
int count = 1;
while (true) {
String[] mSize = input("input size of matrix A(" + count + ") ( ex. 10 20 ) : ");
int col = Integer.parseInt(mSize[0]);
if (col == 0) break;
int row = Integer.parseInt(mSize[1]);
for (int i=0; i<p.length; i++) {
if (i == 0) {
p[i] = mArray.get(i).col();
} else {
p[i] = mArray.get(i-1).row();
}
}
matrixChainOrder();
for(int i=0; i<size; i++) { System.out.print("-------"); }
System.out.println();
printArray(m);
for(int i=0; i<size; i++) { System.out.print("-------"); }
System.out.println();
printArray(s);
for(int i=0; i<size; i++) { System.out.print("-------"); }
System.out.println();
System.out.println("Optimal solution : "+ m[1][size]);
System.out.print("Optimal parens : ");
printOptimalParens(1, size);
}
private static void printOptimalParens(int i, int j) {
if (i == j) {
System.out.print("A" + i);
} else {
System.out.print("(");
printOptimalParens(i, s[i][j]);
printOptimalParens(s[i][j] + 1, j);
System.out.print(")");
}
}
private static void printArray(int[][] array) {
for (int i = 1; i < size+1; i++) {
for (int j = 1; j < size+1; j++) {
System.out.print(String.format("%7d", array[i][j]));
}
System.out.println();
}
}
private static void matrixChainOrder() {
for (int i = 1; i<size+1; i++) {
m[i][i] = 0;
}
Matrix matrix = new Matrix(count, col, row);
mArray.add(matrix);
count++;
}
for (Matrix m : mArray) {
System.out.format("A(%d) = %2d x %2d\n", m.count(), m.col(), m.row());
}
for (int l = 2; l <size+1; l++) {
for (int i = 1; i<size-l+2; i++) {
int j = i+l-1;
m[i][j] = Integer.MAX_VALUE;
size = mArray.size();
m = new int[size + 1][size + 1];
s = new int[size + 1][size + 1];
p = new int[size + 1];
for (int k = i; k < j; k++) {
int q = m[i][k] + m[k+1][j] + p[i-1] * p[k] * p[j];
if (q < m[i][j]) {
m[i][j] = q;
s[i][j] = k;
}
}
}
}
}
private static String[] input(String string) {
System.out.print(string);
return (scan.nextLine().split(" ") );
}
for (int i = 0; i < size + 1; i++) {
Arrays.fill(m[i], -1);
Arrays.fill(s[i], -1);
}
for (int i = 0; i < p.length; i++) {
p[i] = i == 0 ? mArray.get(i).col() : mArray.get(i - 1).row();
}
matrixChainOrder();
for (int i = 0; i < size; i++) {
System.out.print("-------");
}
System.out.println();
printArray(m);
for (int i = 0; i < size; i++) {
System.out.print("-------");
}
System.out.println();
printArray(s);
for (int i = 0; i < size; i++) {
System.out.print("-------");
}
System.out.println();
System.out.println("Optimal solution : " + m[1][size]);
System.out.print("Optimal parens : ");
printOptimalParens(1, size);
}
private static void printOptimalParens(int i, int j) {
if (i == j) {
System.out.print("A" + i);
} else {
System.out.print("(");
printOptimalParens(i, s[i][j]);
printOptimalParens(s[i][j] + 1, j);
System.out.print(")");
}
}
private static void printArray(int[][] array) {
for (int i = 1; i < size + 1; i++) {
for (int j = 1; j < size + 1; j++) {
System.out.print(String.format("%7d", array[i][j]));
}
System.out.println();
}
}
private static void matrixChainOrder() {
for (int i = 1; i < size + 1; i++) {
m[i][i] = 0;
}
for (int l = 2; l < size + 1; l++) {
for (int i = 1; i < size - l + 2; i++) {
int j = i + l - 1;
m[i][j] = Integer.MAX_VALUE;
for (int k = i; k < j; k++) {
int q = m[i][k] + m[k + 1][j] + p[i - 1] * p[k] * p[j];
if (q < m[i][j]) {
m[i][j] = q;
s[i][j] = k;
}
}
}
}
}
private static String[] input(String string) {
System.out.print(string);
return (scan.nextLine().split(" "));
}
}
class Matrix {
private int count;
private int col;
private int row;
public Matrix(int count, int col, int row) {
this.count = count;
this.col = col;
this.row = row;
}
public int count() { return this.count; }
public int col() { return this.col; }
public int row() { return this.row; }
private int count;
private int col;
private int row;
Matrix(int count, int col, int row) {
this.count = count;
this.col = col;
this.row = row;
}
int count() {
return count;
}
int col() {
return col;
}
int row() {
return row;
}
}