Strassen algorithm in Python

Strassen’s algorithm is a divide-and-conquer method for matrix multiplication. It splits each matrix into four submatrices and builds the product from seven recursive block multiplications instead of the eight that the standard block method needs, at the cost of extra additions and subtractions. The saving compounds at every level of recursion, so the algorithm is mainly useful for large matrices.

Strassen Algorithm:

  • If the size of the matrices is small enough (e.g., 1×1 or 2×2), perform the standard matrix multiplication.
  • Divide each input matrix into four equally sized submatrices (this requires an even dimension).
  • Recursively calculate seven products whose operands are submatrices or sums and differences of submatrices.
  • Use these products to compute the four quadrants of the resulting matrix using addition and subtraction operations.
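
The split-and-reassemble part of these steps can be sketched with NumPy slicing. This is a minimal illustration only: the 4×4 matrix M and the names M11 to M22 are hypothetical and were chosen just for this example.

```python
import numpy as np

# A hypothetical 4x4 matrix used only for illustration.
M = np.arange(16).reshape(4, 4)
mid = M.shape[0] // 2

# Split into four equally sized quadrants (basic slices are views, not copies).
M11, M12 = M[:mid, :mid], M[:mid, mid:]
M21, M22 = M[mid:, :mid], M[mid:, mid:]

# np.block stitches the quadrants back into a single matrix.
rebuilt = np.block([[M11, M12], [M21, M22]])

print(M12)                        # top-right quadrant
print(np.array_equal(rebuilt, M)) # True
```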

How does Strassen’s algorithm work?

Consider two matrices:

A = [[1, 3], [7, 5]]
B = [[6, 8], [4, 2]]

We want to compute the product C=A×B.

  1. Initialization:
    • Start with the two input matrices A and B.
  2. Base Case Check:
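    • A 2×2 input is small enough that an implementation would normally multiply it directly. To illustrate the method, we apply one Strassen step here anyway, so the recursion bottoms out at 1×1 blocks, where multiplication is ordinary scalar multiplication.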
  3. Partition Matrices:
    • Partition matrices A and B into four submatrices each:
      • A11=[1], A12=[3], A21=[7], A22=[5]
      • B11=[6], B12=[8], B21=[4], B22=[2]
  4. Recursive Multiplication:
    • Recursively calculate seven products:
      • P1=A11×(B12−B22)=1×(8−2)=6
      • P2=(A11+A12)×B22=(1+3)×2=8
      • P3=(A21+A22)×B11=(7+5)×6=72
      • P4=A22×(B21−B11)=5×(4−6)=−10
      • P5=(A11+A22)×(B11+B22)=(1+5)×(6+2)=48
      • P6=(A12−A22)×(B21+B22)=(3−5)×(4+2)=−12
      • P7=(A11−A21)×(B11+B12)=(1−7)×(6+8)=−84
  5. Combine Results:
    • Calculate the four quadrants of the resulting matrix C:
      • C11=P5+P4−P2+P6=48+(−10)−8+(−12)=18
      • C12=P1+P2=6+8=14
      • C21=P3+P4=72+(−10)=62
      • C22=P5+P1−P3−P7=48+6−72−(−84)=66
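
The arithmetic in this walkthrough can be checked with plain Python scalars. The sketch below recomputes the seven products for the example matrices and compares the combined result with the ordinary row-by-column rule; the lowercase variable names are chosen only for this check.

```python
# Entries of the 2x2 example matrices A and B, written out as scalars.
a11, a12, a21, a22 = 1, 3, 7, 5
b11, b12, b21, b22 = 6, 8, 4, 2

# The seven Strassen products (each uses exactly one multiplication).
p1 = a11 * (b12 - b22)
p2 = (a11 + a12) * b22
p3 = (a21 + a22) * b11
p4 = a22 * (b21 - b11)
p5 = (a11 + a22) * (b11 + b22)
p6 = (a12 - a22) * (b21 + b22)
p7 = (a11 - a21) * (b11 + b12)

# The four entries of C, built only from additions and subtractions.
c = [[p5 + p4 - p2 + p6, p1 + p2],
     [p3 + p4, p5 + p1 - p3 - p7]]

# Reference result from the row-by-column rule (eight multiplications).
expected = [[a11 * b11 + a12 * b21, a11 * b12 + a12 * b22],
            [a21 * b11 + a22 * b21, a21 * b12 + a22 * b22]]

assert c == expected
print(c)  # [[18, 14], [62, 66]]
```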

Example Output:

The resulting matrix C (Result of A×B) is:

C = [[18, 14],
     [62, 66]]

Python Implementation for Strassen algorithm:

Python
import numpy as np

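def strassen(A, B):
    # Assumes A and B are square NumPy arrays of the same size.
    n = len(A)
    
    # Base case: small matrices, or an odd size that cannot be split
    # into four equal quadrants, use standard multiplication.
    if n <= 2 or n % 2 != 0:
        return np.dot(A, B)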
    
    # Partition matrices into submatrices
    mid = n // 2
    A11 = A[:mid, :mid]
    A12 = A[:mid, mid:]
    A21 = A[mid:, :mid]
    A22 = A[mid:, mid:]
    B11 = B[:mid, :mid]
    B12 = B[:mid, mid:]
    B21 = B[mid:, :mid]
    B22 = B[mid:, mid:]
    
    # Recursive multiplication
    P1 = strassen(A11, B12 - B22)
    P2 = strassen(A11 + A12, B22)
    P3 = strassen(A21 + A22, B11)
    P4 = strassen(A22, B21 - B11)
    P5 = strassen(A11 + A22, B11 + B22)
    P6 = strassen(A12 - A22, B21 + B22)
    P7 = strassen(A11 - A21, B11 + B12)
    
    # Compute the four quadrants of C from the seven products
    C11 = P5 + P4 - P2 + P6
    C12 = P1 + P2
    C21 = P3 + P4
    C22 = P5 + P1 - P3 - P7
    
    # Combine quadrants to form C
    C = np.vstack((np.hstack((C11, C12)), np.hstack((C21, C22))))
    return C

# Example usage:
A = np.array([[1, 3], [7, 5]])
B = np.array([[6, 8], [4, 2]])
C = strassen(A, B)
print("Matrix C (Result of A * B):\n", C)

Output
Matrix C (Result of A * B):
 [[18 14]
 [62 66]]
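
The quadrant split only works cleanly when the matrix size can be halved repeatedly. One common way to handle other sizes is to zero-pad both inputs up to the next power of two, multiply, and crop the result. The sketch below shows that idea under stated assumptions: both inputs are square and the same size, and the names strassen_pow2, strassen_padded, and cutoff are hypothetical helpers written for this example, not part of NumPy.

```python
import numpy as np

def strassen_pow2(A, B, cutoff=2):
    """Strassen recursion; assumes square inputs whose size is a power of two."""
    n = A.shape[0]
    if n <= cutoff:
        return A @ B
    m = n // 2
    A11, A12, A21, A22 = A[:m, :m], A[:m, m:], A[m:, :m], A[m:, m:]
    B11, B12, B21, B22 = B[:m, :m], B[:m, m:], B[m:, :m], B[m:, m:]
    P1 = strassen_pow2(A11, B12 - B22, cutoff)
    P2 = strassen_pow2(A11 + A12, B22, cutoff)
    P3 = strassen_pow2(A21 + A22, B11, cutoff)
    P4 = strassen_pow2(A22, B21 - B11, cutoff)
    P5 = strassen_pow2(A11 + A22, B11 + B22, cutoff)
    P6 = strassen_pow2(A12 - A22, B21 + B22, cutoff)
    P7 = strassen_pow2(A11 - A21, B11 + B12, cutoff)
    return np.block([[P5 + P4 - P2 + P6, P1 + P2],
                     [P3 + P4, P5 + P1 - P3 - P7]])

def strassen_padded(A, B):
    """Hypothetical wrapper: zero-pad to a power of two, multiply, then crop."""
    A, B = np.asarray(A), np.asarray(B)
    n = A.shape[0]
    size = 1 << (n - 1).bit_length()  # smallest power of two >= n
    Ap = np.zeros((size, size), dtype=np.result_type(A, B))
    Bp = np.zeros_like(Ap)
    Ap[:n, :n] = A
    Bp[:n, :n] = B
    # The padded rows and columns contribute only zeros, so the top-left
    # n x n block of the padded product equals A @ B.
    return strassen_pow2(Ap, Bp)[:n, :n]

# Usage: a 5x5 integer example, compared against NumPy's own product.
rng = np.random.default_rng(0)
A = rng.integers(-5, 5, size=(5, 5))
B = rng.integers(-5, 5, size=(5, 5))
print(np.array_equal(strassen_padded(A, B), A @ B))  # True
```

Padding keeps the recursion simple but can nearly double the dimension in the worst case, so it is a convenience rather than an optimization.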

Complexity Analysis:

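  • Time Complexity: Each call makes seven recursive calls on half-size matrices and does O(n^2) work for the additions and subtractions, so T(n) = 7T(n/2) + O(n^2), which solves to O(n^log2(7)) ≈ O(n^2.81) for n×n matrices, compared with O(n^3) for the standard method. The saving in multiplications comes at the cost of more additions and subtractions (18 per level instead of 4), so the standard method is often faster for small matrices.
  • Space Complexity: Strassen’s algorithm needs O(n^2) auxiliary space. The temporary sums, differences, and products at each recursion level are a quarter of the size of those one level up, so the total stays within a constant multiple of n^2.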
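
To make the growth rates concrete, the sketch below counts scalar multiplications only, assuming n is a power of two and the recursion runs all the way down to 1×1 blocks. The function names are hypothetical, and additions are deliberately ignored, so this is not a runtime prediction.

```python
def strassen_mults(n):
    """Scalar multiplications when recursing down to 1x1 (n a power of two)."""
    return 1 if n == 1 else 7 * strassen_mults(n // 2)

def standard_mults(n):
    """Scalar multiplications in the row-by-column method."""
    return n ** 3

for n in (2, 4, 8, 64, 1024):
    print(n, standard_mults(n), strassen_mults(n))
```

For n = 2 the counts are 8 and 7; the gap widens as n grows, because the Strassen count is 7^log2(n) = n^log2(7).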