Matrix.java

public class Matrix {
  // mathematische Matrix
  
  private double[][] arr;
  
  public Matrix(int m, int n) {
    // erzeuge nxm-Matrix, mit 0 vorbesetzt
    arr = new double[m][n];
  }
  
  public Matrix(Matrix a) {
    // erzeuge Matrix als Kopie von a
    
    int m = a.getColumnDimension();
    int n = a.getRowDimension();
    arr = new double[m][n];
    
    for (int i = 0; i < m; i++) {
      for (int j = 0; j < n; j++) {
        arr[i][j] = a.arr[i][j];
      }
    }
  }
  
  public int getColumnDimension() {
    // Spaltenlänge
    return arr.length;
  }
  
  public int getRowDimension() {
    // Zeilenlänge
    return arr[0].length;
  }
  
  public double get(int i, int j) {
    // hole Array-Element A_ij
    return arr[i][j];
  }
  
  public void set(int i, int j, double s) {
    // setze Array-Element A_ij
    arr[i][j] = s;
  }
  
  public Matrix getMatrix(int i0, int i1, int j0, int j1) {
    // erzeugt eine Teilmatrix A(i0..i1, j0..j1)
    
    int m = i1 - i0 + 1;
    int n = j1 - j0 + 1;
    
    Matrix teil = new Matrix(m, n);
    
    for (int i = 0; i < m; i++) {
      for (int j = 0; j < n; j++) {
        teil.arr[i][j] = arr[i + i0][j + j0];
      }
    }
    return teil;
  }
  
  void setMatrix(int i0, int j0, Matrix teil) {
    // belegt die Matrix mit einer Teilmatrix
    
    int m = teil.getColumnDimension();
    int n = teil.getRowDimension();
    
    for (int i = 0; i < m; i++) {
      for (int j = 0; j < n; j++) {
        arr[i+ i0][j + j0] = teil.arr[i][j];
      }
    }
  }
  
  public static Matrix matmult(Matrix a, Matrix b) {
    // einfache Matrix-Multiplikation
    
    // Achtung: Dies ist eine Version zum Zweck der Zeitmessung, daher
    //   verzichtet sie auf Dimensions-Überprüfungen.
    
    int m = a.getColumnDimension();
    int l = a.getRowDimension();
    int n = b.getRowDimension();
    
    Matrix c = new Matrix(n, m);
    
    for (int i = 0; i < m; i++) {
      for (int j = 0; j < n; j++) {
        c.arr[i][j] = 0.0;
        for (int k = 0; k < l; k++) {
          c.arr[i][j] += a.arr[i][k] * b.arr[k][j];
        }
      }
    }
    
    return c;
  }
  
  public static Matrix matmultStrassen(Matrix a, Matrix b) {
    return matmultStrassenMitCutoff(a, b, 1);
  }
  
  public static Matrix matmultStrassenMitCutoff(Matrix a, Matrix b, int cutoff) {
    // Matrix-Multiplikation nach dem Strassen-Algorithmus
    // für Matrizen mit Dimension >= cutoff wird normal multipliziert
    
    // Achtung: Dies ist eine Version zum Zweck der Zeitmessung, daher
    //   - verzichtet sie auf Dimensions-Überprüfungen,
    //    - geht sie von quadratischen Matrizen mit einer Zweierpotenz als Dimension aus
    
    int m = a.getColumnDimension();
    
    Matrix c = new Matrix(m, m);
    
    // Abbruchbedingung der Rekursion: m <= cutoff
    if (m <= cutoff) {
      // normal multiplizieren
      c = matmult(a, b);
      return c;
    }
    
    // ansonsten rekursiv a la Strassen
    
    // Teilmatrizen herausschneiden
    int dim = m/2;     // sollte ohne Rest aufgehen, da m Zweierpotenz
    Matrix a11 = a.getMatrix(  0, dim-1,   0, dim-1);
    Matrix a12 = a.getMatrix(  0, dim-1, dim,   m-1);
    Matrix a21 = a.getMatrix(dim,   m-1,   0, dim-1);
    Matrix a22 = a.getMatrix(dim,   m-1, dim,   m-1);
    Matrix b11 = b.getMatrix(  0, dim-1,   0, dim-1);
    Matrix b12 = b.getMatrix(  0, dim-1, dim,   m-1);
    Matrix b21 = b.getMatrix(dim,   m-1,   0, dim-1);
    Matrix b22 = b.getMatrix(dim,   m-1, dim,   m-1);
    
    // Matrizen m1 .. m7 berechnen
    // dazu zwei Hilfsmatrizen d1, d2 für Zwischenwerte
    Matrix d1 = a12.minus(a22);
    Matrix d2 = b21.plus(b22);
    Matrix m1 = matmultStrassen(d1, d2);
    
    d1 = a11.plus(a22);
    d2 = b11.plus(b22);
    Matrix m2 = matmultStrassen(d1, d2);
    
    d1 = a11.minus(a21);
    d2 = b11.plus(b12);
    Matrix m3 = matmultStrassen(d1, d2);
    
    d1 = a11.plus(a12);
    Matrix m4 = matmultStrassen(d1, b22);
    
    d1 = b12.minus(b22);
    Matrix m5 = matmultStrassen(a11, d1);
    
    d1 = b21.minus(b11);
    Matrix m6 = matmultStrassen(a22, d1);
    
    d1 = a21.plus(a22);
    Matrix m7 = matmultStrassen(d1, b11);
    
    // Teilmatrizen c11 .. c22 bestimmen
    Matrix c11 = new Matrix(m1);
    c11.plusGleich(m2);
    c11.minusGleich(m4);
    c11.plusGleich(m6);
    
    Matrix c12 = new Matrix(m4);
    c12.plusGleich(m5);
    
    Matrix c21 = new Matrix(m6);
    c21.plusGleich(m7);
    
    Matrix c22 = new Matrix(m2);
    c22.minusGleich(m3);
    c22.plusGleich(m5);
    c22.minusGleich(m7);
    
    // Gesamtmatrix zusammensetzen
    c.setMatrix(  0,   0, c11);
    c.setMatrix(  0, dim, c12);
    c.setMatrix(dim,   0, c21);
    c.setMatrix(dim, dim, c22);
    
    return c;
  }
  
  public  Matrix minus(Matrix a) {
    // gibt Differenz von aktueller Matrix und a zurück
    
    int m = getColumnDimension();
    int n = getRowDimension();
    Matrix c = new Matrix(m, n);
    
    for (int i = 0; i < m; i++) {
      for (int j = 0; j < n; j++) {
        c.arr[i][j] = arr[i][j] - a.arr[i][j];
      }
    }
    return c;
  }
  
  public  Matrix plus(Matrix a) {
    // gibt Summe von aktueller Matrix und a zurück
    
    int m = getColumnDimension();
    int n = getRowDimension();
    Matrix c = new Matrix(m, n);
    
    for (int i = 0; i < m; i++) {
      for (int j = 0; j < n; j++) {
        c.arr[i][j] = arr[i][j] + a.arr[i][j];
      }
    }
    return c;
  }
  
  public void plusGleich(Matrix a) {
    // addiert a zur aktuellen Matrix
    
    for (int i = 0; i < getColumnDimension(); i++) {
      for (int j = 0; j < getRowDimension(); j++) {
        arr[i][j] += a.arr[i][j];
      }
    }
  }
  
  public void minusGleich(Matrix a) {
    // subtrahiert a von der aktuellen Matrix
    
    for (int i = 0; i < getColumnDimension(); i++) {
      for (int j = 0; j < getRowDimension(); j++) {
        arr[i][j] -= a.arr[i][j];
      }
    }
  }
  
  public String toString() {
    // Ausgabestring: Zeilen durch Newline getrennt, Werte durch Komma
    
    String s = "";
    for (int i = 0; i < getColumnDimension(); i++) {
      for (int j = 0; j < getColumnDimension(); j++) {
        s += arr[i][j] + ", ";
      }
      s += "\n";
    }
    return s;
  }
}