Question
In describing and analyzing Strassen's algorithm, we assumed that divide and conquer is used all the way down to tiny matrices. On small matrices, however, the ordinary brute-force matrix multiplication algorithm is faster because it has lower overhead. This is a common issue with divide-and-conquer algorithms. The best way to run such an algorithm is to test the input size n at the start to see whether it is large enough to make divide and conquer worthwhile: if n is above some threshold, the algorithm performs a level of recursion; if n is at or below the threshold, it runs the non-recursive algorithm instead. Your job is to find the best choice of that threshold for a version of Strassen's algorithm based on your own implementation. (See the class slides for a description of Strassen's algorithm and for the code of the basic non-recursive matrix multiplication algorithm.)

Steps:
1. Write a function that implements the ordinary brute-force algorithm for matrix multiplication over the integers (see the slides).
2. Write a recursive function that implements Strassen's algorithm for matrix multiplication over the integers (see the slides). To keep the recursion simple, you may assume that the input matrices are n × n, where n is a power of 2.
3. Modify the recursive function so that it takes an extra argument s and tests the input size n first. If n is at most s, it skips the recursion and calls the brute-force version from step 1.
4. Determine which value of s works best with your implementation by measuring how long your code takes for various values of s. (You may assume s is a power of 2; you can experiment with n = 2048.)
5. Generalize your code so that it works with matrices of any dimension, not just powers of 2, and find the best value of s in this case by repeating the analysis above.

Answer
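The listing that follows recurses all the way down to 1 × 1 blocks, so it never uses the cutoff s from step 3. A minimal sketch of that cutoff is below, assuming square matrices whose size n is a power of 2. The class name `HybridStrassen`, its method names, the default n = 256, the optional smallest-s argument and the fixed random seed are illustrative choices, not part of the assignment. `main` times one product for each power-of-two s, after an untimed warm-up call so that most JIT compilation happens before timing starts.

```java
import java.util.Random;

/** Hypothetical sketch: Strassen's algorithm with a brute-force cutoff s. */
public class HybridStrassen {

    /** Ordinary O(n^3) product; the i-k-j loop order walks B row by row. */
    public static int[][] bruteForce(int[][] A, int[][] B) {
        int n = A.length;
        int[][] C = new int[n][n];
        for (int i = 0; i < n; i++)
            for (int k = 0; k < n; k++) {
                int a = A[i][k];
                for (int j = 0; j < n; j++)
                    C[i][j] += a * B[k][j];
            }
        return C;
    }

    /** Copies the h x h block of M whose top-left corner is (r, c). */
    private static int[][] block(int[][] M, int r, int c, int h) {
        int[][] X = new int[h][h];
        for (int i = 0; i < h; i++)
            System.arraycopy(M[r + i], c, X[i], 0, h);
        return X;
    }

    /** Entrywise X + sign * Y, so sign = 1 adds and sign = -1 subtracts. */
    private static int[][] combine(int[][] X, int[][] Y, int sign) {
        int n = X.length;
        int[][] Z = new int[n][n];
        for (int i = 0; i < n; i++)
            for (int j = 0; j < n; j++)
                Z[i][j] = X[i][j] + sign * Y[i][j];
        return Z;
    }

    /**
     * Strassen's recursion on n x n matrices (n a power of 2) that switches to
     * bruteForce once n <= s. The n == 1 test only matters when s < 1.
     */
    public static int[][] multiply(int[][] A, int[][] B, int s) {
        int n = A.length;
        if (n <= s || n == 1) return bruteForce(A, B);
        int h = n / 2;
        int[][] A11 = block(A, 0, 0, h), A12 = block(A, 0, h, h);
        int[][] A21 = block(A, h, 0, h), A22 = block(A, h, h, h);
        int[][] B11 = block(B, 0, 0, h), B12 = block(B, 0, h, h);
        int[][] B21 = block(B, h, 0, h), B22 = block(B, h, h, h);
        int[][] M1 = multiply(combine(A11, A22, 1), combine(B11, B22, 1), s);
        int[][] M2 = multiply(combine(A21, A22, 1), B11, s);
        int[][] M3 = multiply(A11, combine(B12, B22, -1), s);
        int[][] M4 = multiply(A22, combine(B21, B11, -1), s);
        int[][] M5 = multiply(combine(A11, A12, 1), B22, s);
        int[][] M6 = multiply(combine(A21, A11, -1), combine(B11, B12, 1), s);
        int[][] M7 = multiply(combine(A12, A22, -1), combine(B21, B22, 1), s);
        int[][] C = new int[n][n];
        for (int i = 0; i < h; i++)
            for (int j = 0; j < h; j++) {
                C[i][j] = M1[i][j] + M4[i][j] - M5[i][j] + M7[i][j];
                C[i][j + h] = M3[i][j] + M5[i][j];
                C[i + h][j] = M2[i][j] + M4[i][j];
                C[i + h][j + h] = M1[i][j] - M2[i][j] + M3[i][j] + M6[i][j];
            }
        return C;
    }

    /** Usage: java HybridStrassen [n] [smallest s]; times one product per power-of-two s. */
    public static void main(String[] args) {
        int n = args.length > 0 ? Integer.parseInt(args[0]) : 256;
        int sMin = args.length > 1 ? Math.max(1, Integer.parseInt(args[1])) : 1;
        Random rnd = new Random(42);
        int[][] A = new int[n][n], B = new int[n][n];
        for (int i = 0; i < n; i++)
            for (int j = 0; j < n; j++) {
                A[i][j] = rnd.nextInt(100);
                B[i][j] = rnd.nextInt(100);
            }
        multiply(A, B, Math.min(16, n)); // untimed warm-up run for the JIT
        for (int s = sMin; s <= n; s *= 2) {
            long t0 = System.nanoTime();
            multiply(A, B, s);
            long ms = (System.nanoTime() - t0) / 1_000_000;
            System.out.println("s = " + s + ": " + ms + " ms");
        }
    }
}
```

Running `java HybridStrassen 2048 16` uses the size suggested in the question; the second argument skips the smallest cutoffs, since s = 1 means recursing down to 1 × 1 blocks and is very slow at that size. The best s depends on the machine and JVM, so it has to be read off these timings rather than assumed; repeating each measurement a few times and taking the minimum gives steadier numbers.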
public class matmult
{
    /**
     * Strassen's algorithm for two n x n matrices, recursing all the way down
     * to 1 x 1 blocks (no cutoff). Assumes n is a power of 2.
     */
    public int[][] strassen(int[][] A, int[][] B)
    {
        int n = A.length;
        int[][] R = new int[n][n];
        // Base case: a 1 x 1 product is a single multiplication.
        if (n == 1)
            R[0][0] = A[0][0] * B[0][0];
        else
        {
            // Split A and B into four (n/2) x (n/2) quadrants each.
            int[][] A11 = new int[n/2][n/2];
            int[][] A12 = new int[n/2][n/2];
            int[][] A21 = new int[n/2][n/2];
            int[][] A22 = new int[n/2][n/2];
            int[][] B11 = new int[n/2][n/2];
            int[][] B12 = new int[n/2][n/2];
            int[][] B21 = new int[n/2][n/2];
            int[][] B22 = new int[n/2][n/2];
            split(A, A11, 0, 0);
            split(A, A12, 0, n/2);
            split(A, A21, n/2, 0);
            split(A, A22, n/2, n/2);
            split(B, B11, 0, 0);
            split(B, B12, 0, n/2);
            split(B, B21, n/2, 0);
            split(B, B22, n/2, n/2);

            // Strassen's seven recursive products.
            int[][] M1 = strassen(add(A11, A22), add(B11, B22));
            int[][] M2 = strassen(add(A21, A22), B11);
            int[][] M3 = strassen(A11, sub(B12, B22));
            int[][] M4 = strassen(A22, sub(B21, B11));
            int[][] M5 = strassen(add(A11, A12), B22);
            int[][] M6 = strassen(sub(A21, A11), add(B11, B12));
            int[][] M7 = strassen(sub(A12, A22), add(B21, B22));

            // Combine the products into the four quadrants of the result.
            int[][] C11 = add(sub(add(M1, M4), M5), M7);
            int[][] C12 = add(M3, M5);
            int[][] C21 = add(M2, M4);
            int[][] C22 = add(sub(add(M1, M3), M2), M6);

            // Copy the quadrants back into R.
            join(C11, R, 0, 0);
            join(C12, R, 0, n/2);
            join(C21, R, n/2, 0);
            join(C22, R, n/2, n/2);
        }
        return R;
    }

    /** Returns the entrywise difference A - B of two n x n matrices. */
    public int[][] sub(int[][] A, int[][] B)
    {
        int n = A.length;
        int[][] C = new int[n][n];
        for (int i = 0; i < n; i++)
            for (int j = 0; j < n; j++)
                C[i][j] = A[i][j] - B[i][j];
        return C;
    }

    /** Returns the entrywise sum A + B of two n x n matrices. */
    public int[][] add(int[][] A, int[][] B)
    {
        int n = A.length;
        int[][] C = new int[n][n];
        for (int i = 0; i < n; i++)
            for (int j = 0; j < n; j++)
                C[i][j] = A[i][j] + B[i][j];
        return C;
    }

    /** Copies the C.length x C.length block of P whose top-left corner is (iB, jB) into C. */
    public void split(int[][] P, int[][] C, int iB, int jB)
    {
        for (int i1 = 0, i2 = iB; i1 < C.length; i1++, i2++)
            for (int j1 = 0, j2 = jB; j1 < C.length; j1++, j2++)
                C[i1][j1] = P[i2][j2];
    }

    /** Writes C into P with its top-left corner at (iB, jB); the inverse of split. */
    public void join(int[][] C, int[][] P, int iB, int jB)
    {
        for (int i1 = 0, i2 = iB; i1 < C.length; i1++, i2++)
            for (int j1 = 0, j2 = jB; j1 < C.length; j1++, j2++)
                P[i2][j2] = C[i1][j1];
    }
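
    /** Ordinary triple-loop product of a p x q matrix A and a q x r matrix B. */
    public int[][] brute_force(int[][] A, int[][] B)
    {
        int[][] mat3 = new int[A.length][B[0].length];
        for (int i = 0; i < A.length; i++)
            for (int j = 0; j < B[0].length; j++)
                for (int k = 0; k < B.length; k++)
                    mat3[i][j] += A[i][k] * B[k][j];
        return mat3;
    }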

    /**
     * Times strassen() against brute_force() on random k x k matrices.
     * strassen() only handles powers of 2, so k doubles each round.
     * Without a cutoff, strassen() recurses down to 1 x 1 blocks, so the
     * largest sizes can take a very long time.
     */
    public void compare_runtime()
    {
        System.out.printf("%6s | %14s | %16s%n", "n", "Strassen (ms)", "Brute-force (ms)");
        for (int k = 2; k <= 2048; k *= 2)
        {
            int[][] A = new int[k][k];
            for (int i = 0; i < k; i++)
                for (int j = 0; j < k; j++)
                    A[i][j] = (int) (Math.random() * 100);

            long startTime = System.nanoTime();
            int[][] C = strassen(A, A);
            long strassenMs = (System.nanoTime() - startTime) / 1_000_000;

            startTime = System.nanoTime();
            C = brute_force(A, A);
            long bruteMs = (System.nanoTime() - startTime) / 1_000_000;

            System.out.printf("%6d | %14d | %16d%n", k, strassenMs, bruteMs);
        }
    }

    public static void main(String[] args)
    {
        System.out.println("Runtime comparison of Strassen's algorithm and the brute-force approach");
        matmult m = new matmult();
        m.compare_runtime();
    }
}
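For step 5, one simple way to reuse a power-of-2-only routine such as `strassen` from the listing above is to pad the inputs with zeros up to the next power of 2, multiply, and crop the result. The sketch below assumes that approach; padding odd sizes by one row and column at each level of the recursion is another option. `PadToPowerOfTwo` and its methods are hypothetical names, and `naive` is only a stand-in multiplier so the file runs on its own.

```java
import java.util.Arrays;
import java.util.function.BinaryOperator;

/** Hypothetical helper: lets a power-of-2-only multiplier handle any matrix shape. */
public class PadToPowerOfTwo {

    /** Smallest power of 2 that is >= n, for n >= 1. */
    public static int nextPowerOfTwo(int n) {
        int m = 1;
        while (m < n) m <<= 1;
        return m;
    }

    /** Copies M into the top-left corner of a zero-filled m x m matrix. */
    private static int[][] pad(int[][] M, int m) {
        int[][] P = new int[m][m];
        for (int i = 0; i < M.length; i++)
            System.arraycopy(M[i], 0, P[i], 0, M[i].length);
        return P;
    }

    /**
     * Multiplies a (p x q) by b (q x r). powMul only has to handle square
     * matrices whose size is a power of 2; the zero rows and columns added by
     * pad() leave the top-left p x r block of the product unchanged.
     */
    public static int[][] multiply(int[][] a, int[][] b, BinaryOperator<int[][]> powMul) {
        int p = a.length, q = b.length, r = b[0].length;
        if (a[0].length != q) throw new IllegalArgumentException("inner dimensions differ");
        int m = nextPowerOfTwo(Math.max(p, Math.max(q, r)));
        int[][] full = powMul.apply(pad(a, m), pad(b, m));
        int[][] c = new int[p][r];
        for (int i = 0; i < p; i++)
            System.arraycopy(full[i], 0, c[i], 0, r);
        return c;
    }

    /** Stand-in for a Strassen routine so this file runs on its own. */
    public static int[][] naive(int[][] a, int[][] b) {
        int n = a.length;
        int[][] c = new int[n][n];
        for (int i = 0; i < n; i++)
            for (int k = 0; k < n; k++)
                for (int j = 0; j < n; j++)
                    c[i][j] += a[i][k] * b[k][j];
        return c;
    }

    public static void main(String[] args) {
        int[][] a = {{1, 2, 3}, {4, 5, 6}};       // 2 x 3
        int[][] b = {{7, 8}, {9, 10}, {11, 12}};  // 3 x 2
        int[][] c = multiply(a, b, PadToPowerOfTwo::naive);
        System.out.println(Arrays.deepToString(c)); // prints [[58, 64], [139, 154]]
    }
}
```

With the listing's class, the call would be `PadToPowerOfTwo.multiply(A, B, m::strassen)` for some `matmult m`. Padding has a cost: a size just above a power of 2 (for example n = 1025, padded to 2048) can multiply the work several times over, which is one reason step 5 asks for the best s to be measured again for general sizes.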