/** Matrix: * matrix of doubles * secure version with checks * organized like a Fortran array: * indices running from 1 to n * stored in column order */ public class Matrix { protected int nRows, nCols; // number of colums and rows protected double[] val; // vector of values /** creates a new matrix with given number of rows and columns */ public Matrix(int rows, int cols) { if (rows < 0 || cols < 0) { // throw an exception throw new NegativeArraySizeException( "Invalid array size: " + rows + " x " + cols); } nRows = rows; nCols = cols; val = new double[nRows * nCols]; } /** returns shape (= number of rows and columns) of the matrix */ public int[] shape() { int[] sizes = new int[2]; sizes[0] = nRows; sizes[1] = nCols; return sizes; } /** sets Matrix(i,j) to v */ public void set(int i, int j, double v) { // check indices if (i > 0 && i <= nRows && j > 0 && j <= nCols) { val[i - 1 + nRows * (j - 1)] = v; } else { throw new IndexOutOfBoundsException( "Invalid index: (" + i + "," + j + ")" ); } } /** returns Matrix(i,j) */ public double get(int i, int j) { // check indices if (i > 0 && i <= nRows && j > 0 && j <= nCols) { return val[i - 1 + nRows * (j - 1)]; } else { throw new IndexOutOfBoundsException( "Invalid index: (" + i + "," + j + ")" ); } } /** result = matmul(this, b) */ public void mult(Matrix b, Matrix result) { int[] dimA = shape(); int[] dimB = b.shape(); int[] dimC = result.shape(); // check dimensions if (dimA[1] != dimB[0]) { throw new IllegalArgumentException( "Invalid matrix dimensions for multiplication: (" + dimA[0] + "," + dimA[1] + ") x (" + dimB[0] + "," + dimB[1] + ")" ); } else if (dimA[0] != dimC[0] || dimB[1] != dimC[1]) { throw new IllegalArgumentException( "Invalid matrix dimensions for result: is (" + dimC[0] + "," + dimC[1] + "), should be (" + dimA[0] + "," + dimB[1] + ")" ); } int m = dimA[0]; int l = dimA[1]; int n = dimB[1]; for (int i = 1; i <= m; i++) { for (int j = 1; j <= n; j++) { result.val[i - 1 + m * (j - 1)] = 0.0; for (int k = 1; k <= l; k++) { result.val[i - 1 + m * (j - 1)] += val[i - 1 + m * (k - 1)] * b.val[k - 1 + l * (j - 1)]; } } } } }