C++
// Strassen’s Matrix Multiplication
#include <bits/stdc++.h>
using namespace std;
// Dimensions (rows x columns) of the two sample matrices built in main().
// Typed compile-time constants instead of macros, so they obey scope and type
// rules. NOTE(review): nothing visible in this file references them; the
// multiplication code derives sizes from its arguments at runtime.
constexpr int ROW_1 = 4;
constexpr int COL_1 = 4;
constexpr int ROW_2 = 4;
constexpr int COL_2 = 4;
// Print a matrix to standard output: one row per line, each value followed
// by a single space.
// The matrix is taken by const reference (it is only read, so copying it is
// wasted work). Range-for avoids the signed/unsigned index comparison. '\n'
// avoids flushing the stream after every row as std::endl would.
void print(const vector<vector<int> >& matrix) {
    for (const auto& row : matrix) {
        for (const int value : row)
            cout << value << ' ';
        cout << '\n';
    }
}
// Add two matrices element-wise and return the result.
//
// For every 0 <= i, j < split_index this computes
//   A[i][j] + multiplier * B[i][j]
// into (a copy of) A. Entries of A outside that square are returned
// unchanged. Passing multiplier = -1 turns the call into A - B.
// A is taken by value because it is modified and returned. B is only read,
// so it is taken by const reference to avoid a copy on every call.
vector<vector<int>> add(vector<vector<int> > A, const vector<vector<int> >& B, int split_index, int multiplier = 1) {
    for (int i = 0; i < split_index; i++)
        for (int j = 0; j < split_index; j++)
            A[i][j] += multiplier * B[i][j];
    return A;
}

// Strassen recursion for two n x n matrices where n is a power of two.
// strassen_multiplication pads its inputs to guarantee that shape.
// Returns the n x n product A * B.
static vector<vector<int>> strassen_square(const vector<vector<int>>& A,
                                           const vector<vector<int>>& B) {
    const int n = static_cast<int>(A.size());
    // Base case: a 1x1 product is a single scalar multiplication.
    if (n == 1)
        return {{A[0][0] * B[0][0]}};

    const int split_index = n / 2;

    // Copy the split_index x split_index quadrant of M whose top-left corner
    // is at (row, col).
    auto quadrant = [split_index](const vector<vector<int>>& M, int row, int col) {
        vector<vector<int>> q(split_index, vector<int>(split_index, 0));
        for (int i = 0; i < split_index; i++)
            for (int j = 0; j < split_index; j++)
                q[i][j] = M[row + i][col + j];
        return q;
    };
    const auto a00 = quadrant(A, 0, 0);
    const auto a01 = quadrant(A, 0, split_index);
    const auto a10 = quadrant(A, split_index, 0);
    const auto a11 = quadrant(A, split_index, split_index);
    const auto b00 = quadrant(B, 0, 0);
    const auto b01 = quadrant(B, 0, split_index);
    const auto b10 = quadrant(B, split_index, 0);
    const auto b11 = quadrant(B, split_index, split_index);

    // Strassen's seven sub-products (the naive split needs eight).
    const auto p1 = strassen_square(a00, add(b01, b11, split_index, -1));
    const auto p2 = strassen_square(add(a00, a01, split_index), b11);
    const auto p3 = strassen_square(add(a10, a11, split_index), b00);
    const auto p4 = strassen_square(a11, add(b10, b00, split_index, -1));
    const auto p5 = strassen_square(add(a00, a11, split_index), add(b00, b11, split_index));
    const auto p6 = strassen_square(add(a01, a11, split_index, -1), add(b10, b11, split_index));
    const auto p7 = strassen_square(add(a00, a10, split_index, -1), add(b00, b01, split_index));

    // Combine the sub-products into the four quadrants of the result.
    const auto c00 = add(add(add(p5, p4, split_index), p6, split_index), p2, split_index, -1);
    const auto c01 = add(p1, p2, split_index);
    const auto c10 = add(p3, p4, split_index);
    const auto c11 = add(add(add(p5, p1, split_index), p3, split_index, -1), p7, split_index, -1);

    vector<vector<int>> result(n, vector<int>(n, 0));
    for (int i = 0; i < split_index; i++) {
        for (int j = 0; j < split_index; j++) {
            result[i][j] = c00[i][j];
            result[i][j + split_index] = c01[i][j];
            result[i + split_index][j] = c10[i][j];
            result[i + split_index][j + split_index] = c11[i][j];
        }
    }
    return result;
}

// Multiply A (row_1 x col_1) by B (row_2 x col_2) using Strassen's algorithm.
//
// Both matrices are assumed rectangular (every row the same length).
// If col_1 != row_2 the product is undefined: a message is printed to
// standard output and an empty matrix is returned. Any compatible shapes are
// accepted, not just square powers of two. The operands are zero-padded to
// the next power-of-two square size, multiplied recursively, and the
// row_1 x col_2 product is cropped from the padded result.
// An empty A yields an empty result.
vector<vector<int> >
strassen_multiplication(vector<vector<int> > A, vector<vector<int> > B) {
    // A has no rows, so neither does the product (and A[0] would be invalid).
    if (A.empty())
        return {};

    const size_t row_1 = A.size();
    const size_t col_1 = A[0].size();
    const size_t row_2 = B.size();

    // checking if multiplication is possible or not
    // between the input matrices
    if (col_1 != row_2) {
        cout << "The Two Matrices cannot be multiplied" << '\n';
        return {};
    }
    const size_t col_2 = B.empty() ? 0 : B[0].size();

    // Degenerate shapes: an empty inner dimension or no output columns gives
    // a row_1 x col_2 matrix of zeros, with no recursion needed.
    if (col_1 == 0 || col_2 == 0)
        return vector<vector<int>>(row_1, vector<int>(col_2, 0));

    // Smallest power of two that covers every dimension.
    size_t n = 1;
    while (n < max({row_1, col_1, col_2}))
        n *= 2;

    // Zero-pad both operands to n x n. They were taken by value, so they can
    // be resized in place. Padding with zeros leaves the real product intact.
    auto pad = [n](vector<vector<int>>& M) {
        for (auto& row : M)
            row.resize(n, 0);
        M.resize(n, vector<int>(n, 0));
    };
    pad(A);
    pad(B);

    vector<vector<int>> result = strassen_square(A, B);

    // Crop away the padding, leaving the row_1 x col_2 product.
    result.resize(row_1);
    for (auto& row : result)
        row.resize(col_2);
    return result;
}
int main() {
    // Sample 4x4 operands, one row per line.
    const vector<vector<int>> A{
        {2, 2, 3, 1},
        {1, 4, 1, 2},
        {2, 3, 1, 1},
        {1, 3, 1, 2},
    };
    const vector<vector<int>> B{
        {2, 1, 2, 1},
        {3, 1, 2, 1},
        {3, 2, 1, 1},
        {1, 4, 3, 2},
    };

    // Multiply with Strassen's algorithm, then print the product row by row.
    const vector<vector<int>> product = strassen_multiplication(A, B);
    print(product);
}

You can also try this code with Online C++ Compiler
Run Code
Java
import java.util.Arrays;
public class StrassenMatrixMultiplication {
static int[][] add(int[][] A, int[][] B, int splitIndex, int multiplier) {
int[][] result = new int[splitIndex][splitIndex];
for (int i = 0; i < splitIndex; i++) {
for (int j = 0; j < splitIndex; j++) {
result[i][j] = A[i][j] + multiplier * B[i][j];
}
}
return result;
}
static int[][] strassenMultiplication(int[][] A, int[][] B) {
int N = A.length;
int[][] C = new int[N][N];
if (N == 1) {
C[0][0] = A[0][0] * B[0][0];
return C;
}
int splitIndex = N / 2;
int[][] a00 = new int[splitIndex][splitIndex];
int[][] a01 = new int[splitIndex][splitIndex];
int[][] a10 = new int[splitIndex][splitIndex];
int[][] a11 = new int[splitIndex][splitIndex];
int[][] b00 = new int[splitIndex][splitIndex];
int[][] b01 = new int[splitIndex][splitIndex];
int[][] b10 = new int[splitIndex][splitIndex];
int[][] b11 = new int[splitIndex][splitIndex];
for (int i = 0; i < splitIndex; i++) {
for (int j = 0; j < splitIndex; j++) {
a00[i][j] = A[i][j];
a01[i][j] = A[i][j + splitIndex];
a10[i][j] = A[i + splitIndex][j];
a11[i][j] = A[i + splitIndex][j + splitIndex];
b00[i][j] = B[i][j];
b01[i][j] = B[i][j + splitIndex];
b10[i][j] = B[i + splitIndex][j];
b11[i][j] = B[i + splitIndex][j + splitIndex];
}
}
int[][] p1 = strassenMultiplication(a00, add(b01, b11, splitIndex, -1));
int[][] p2 = strassenMultiplication(add(a00, a01, splitIndex, 1), b11);
int[][] p3 = strassenMultiplication(add(a10, a11, splitIndex, 1), b00);
int[][] p4 = strassenMultiplication(a11, add(b10, b00, splitIndex, -1));
int[][] p5 = strassenMultiplication(add(a00, a11, splitIndex, 1), add(b00, b11, splitIndex, 1));
int[][] p6 = strassenMultiplication(add(a01, a11, splitIndex, -1), add(b10, b11, splitIndex, 1));
int[][] p7 = strassenMultiplication(add(a00, a10, splitIndex, -1), add(b00, b01, splitIndex, 1));
int[][] c00 = add(add(add(p5, p4, splitIndex, 1), p6, splitIndex, 1), p2, splitIndex, -1);
int[][] c01 = add(p1, p2, splitIndex, 1);
int[][] c10 = add(p3, p4, splitIndex, 1);
int[][] c11 = add(add(add(p5, p1, splitIndex, 1), p3, splitIndex, -1), p7, splitIndex, -1);
for (int i = 0; i < splitIndex; i++) {
for (int j = 0; j < splitIndex; j++) {
C[i][j] = c00[i][j];
C[i][j + splitIndex] = c01[i][j];
C[i + splitIndex][j] = c10[i][j];
C[i + splitIndex][j + splitIndex] = c11[i][j];
}
}
return C;
}
static void printMatrix(int[][] matrix) {
for (int[] row : matrix) {
System.out.println(Arrays.toString(row));
}
}
public static void main(String[] args) {
int[][] A = {{2, 2, 3, 1}, {1, 4, 1, 2}, {2, 3, 1, 1}, {1, 3, 1, 2}};
int[][] B = {{2, 1, 2, 1}, {3, 1, 2, 1}, {3, 2, 1, 1}, {1, 4, 3, 2}};
int[][] result = strassenMultiplication(A, B);
printMatrix(result);
}
}

You can also try this code with Online Java Compiler
Run Code
Python
def add(A, B, split_index, multiplier=1):
    """Return the split_index x split_index matrix A + multiplier * B.

    Only the top-left split_index x split_index entries of A and B are read.
    Pass multiplier=-1 to compute A - B. A new list of lists is returned and
    the inputs are not modified.
    """
    return [
        [A[i][j] + multiplier * B[i][j] for j in range(split_index)]
        for i in range(split_index)
    ]


def _pad(M, n):
    """Return a new n x n copy of M with zeros filling any extra rows/columns."""
    padded = [[0] * n for _ in range(n)]
    for i, row in enumerate(M):
        padded[i][:len(row)] = row
    return padded


def _strassen_square(A, B):
    """Multiply two N x N matrices, N a power of two, with Strassen's recursion."""
    N = len(A)
    if N == 1:
        return [[A[0][0] * B[0][0]]]
    split_index = N // 2

    def quadrant(M, row, col):
        # Copy the split_index x split_index block of M starting at (row, col).
        return [M[row + i][col:col + split_index] for i in range(split_index)]

    a00 = quadrant(A, 0, 0)
    a01 = quadrant(A, 0, split_index)
    a10 = quadrant(A, split_index, 0)
    a11 = quadrant(A, split_index, split_index)
    b00 = quadrant(B, 0, 0)
    b01 = quadrant(B, 0, split_index)
    b10 = quadrant(B, split_index, 0)
    b11 = quadrant(B, split_index, split_index)

    # Strassen's seven sub-products.
    p1 = _strassen_square(a00, add(b01, b11, split_index, -1))
    p2 = _strassen_square(add(a00, a01, split_index), b11)
    p3 = _strassen_square(add(a10, a11, split_index), b00)
    p4 = _strassen_square(a11, add(b10, b00, split_index, -1))
    p5 = _strassen_square(add(a00, a11, split_index), add(b00, b11, split_index))
    p6 = _strassen_square(add(a01, a11, split_index, -1), add(b10, b11, split_index))
    p7 = _strassen_square(add(a00, a10, split_index, -1), add(b00, b01, split_index))

    # Combine them into the four quadrants of the result.
    c00 = add(add(add(p5, p4, split_index), p6, split_index), p2, split_index, -1)
    c01 = add(p1, p2, split_index)
    c10 = add(p3, p4, split_index)
    c11 = add(add(add(p5, p1, split_index), p3, split_index, -1), p7, split_index, -1)

    result = [[0] * N for _ in range(N)]
    for i in range(split_index):
        for j in range(split_index):
            result[i][j] = c00[i][j]
            result[i][j + split_index] = c01[i][j]
            result[i + split_index][j] = c10[i][j]
            result[i + split_index][j + split_index] = c11[i][j]
    return result


def strassen_multiplication(A, B):
    """Multiply A (rows x inner) by B (inner x cols) using Strassen's algorithm.

    Both matrices are assumed rectangular (every row the same length). Any
    compatible shapes are accepted, not just square powers of two. The
    operands are zero-padded to the next power-of-two square size, multiplied
    recursively, and the rows x cols product is cropped from the result.

    Returns a new list of lists; an empty A gives [].
    Raises ValueError if A's column count differs from B's row count.
    """
    rows = len(A)
    # No rows in A means no rows in the product. Returning early also
    # avoids the endless N == 0 recursion.
    if rows == 0:
        return []
    inner = len(A[0])
    if len(B) != inner:
        raise ValueError("The Two Matrices cannot be multiplied")
    cols = len(B[0]) if inner else 0
    # Degenerate shapes: the product is an all-zero rows x cols matrix.
    if inner == 0 or cols == 0:
        return [[0] * cols for _ in range(rows)]
    # Smallest power of two covering every dimension.
    n = 1
    while n < max(rows, inner, cols):
        n *= 2
    product = _strassen_square(_pad(A, n), _pad(B, n))
    # Crop the padding away.
    return [row[:cols] for row in product[:rows]]
def print_matrix(matrix):
    """Write matrix to stdout, one row per line, values separated by single spaces."""
    for row in matrix:
        line = " ".join(str(value) for value in row)
        print(line)
# Sample 4x4 operands, one row per line.
A = [
    [2, 2, 3, 1],
    [1, 4, 1, 2],
    [2, 3, 1, 1],
    [1, 3, 1, 2],
]
B = [
    [2, 1, 2, 1],
    [3, 1, 2, 1],
    [3, 2, 1, 1],
    [1, 4, 3, 2],
]

# Multiply with Strassen's algorithm and print the product.
result = strassen_multiplication(A, B)
print_matrix(result)

You can also try this code with Online Python Compiler
Run Code
JS
function add(A, B, splitIndex, multiplier = 1) {
let result = Array.from({ length: splitIndex }, () => Array(splitIndex).fill(0));
for (let i = 0; i < splitIndex; i++) {
for (let j = 0; j < splitIndex; j++) {
result[i][j] = A[i][j] + multiplier * B[i][j];
}
}
return result;
}
function strassenMultiplication(A, B) {
let N = A.length;
if (N === 1) {
return [[A[0][0] * B[0][0]]];
}
let splitIndex = N / 2;
let a00 = [], a01 = [], a10 = [], a11 = [];
let b00 = [], b01 = [], b10 = [], b11 = [];
for (let i = 0; i < splitIndex; i++) {
a00.push(A[i].slice(0, splitIndex));
a01.push(A[i].slice(splitIndex));
a10.push(A[i + splitIndex].slice(0, splitIndex));
a11.push(A[i + splitIndex].slice(splitIndex));
b00.push(B[i].slice(0, splitIndex));
b01.push(B[i].slice(splitIndex));
b10.push(B[i + splitIndex].slice(0, splitIndex));
b11.push(B[i + splitIndex].slice(splitIndex));
}
let p1 = strassenMultiplication(a00, add(b01, b11, splitIndex, -1));
let p2 = strassenMultiplication(add(a00, a01, splitIndex), b11);
let p3 = strassenMultiplication(add(a10, a11, splitIndex), b00);
let p4 = strassenMultiplication(a11, add(b10, b00, splitIndex, -1));
let p5 = strassenMultiplication(add(a00, a11, splitIndex), add(b00, b11, splitIndex));
let p6 = strassenMultiplication(add(a01, a11, splitIndex, -1), add(b10, b11, splitIndex));
let p7 = strassenMultiplication(add(a00, a10, splitIndex, -1), add(b00, b01, splitIndex));
let c00 = add(add(add(p5, p4, splitIndex), p6, splitIndex), p2, splitIndex, -1);
let c01 = add(p1, p2, splitIndex);
let c10 = add(p3, p4, splitIndex);
let c11 = add(add(add(p5, p1, splitIndex), p3, splitIndex, -1), p7, splitIndex, -1);
let result = Array.from({ length: N }, () => Array(N).fill(0));
for (let i = 0; i < splitIndex; i++) {
for (let j = 0; j < splitIndex; j++) {
result[i][j] = c00[i][j];
result[i][j + splitIndex] = c01[i][j];
result[i + splitIndex][j] = c10[i][j];
result[i + splitIndex][j + splitIndex] = c11[i][j];
}
}
return result;
}
function printMatrix(matrix) {
matrix.forEach(row => console.log(row.join(" ")));
}
// Sample 4x4 operands, one row per line.
let A = [
    [2, 2, 3, 1],
    [1, 4, 1, 2],
    [2, 3, 1, 1],
    [1, 3, 1, 2],
];
let B = [
    [2, 1, 2, 1],
    [3, 1, 2, 1],
    [3, 2, 1, 1],
    [1, 4, 3, 2],
];

// Multiply with Strassen's algorithm and print the product.
let result = strassenMultiplication(A, B);
printMatrix(result);

You can also try this code with Online JavaScript Compiler
Run Code
C#
using System;
class StrassenMultiplication {
    /// <summary>
    /// Returns a new splitIndex x splitIndex matrix with entries
    /// A[i, j] + multiplier * B[i, j]. Only the top-left splitIndex x splitIndex
    /// entries are read. Pass multiplier = -1 to compute A - B.
    /// (internal rather than private so it can be unit-tested.)
    /// </summary>
    internal static int[,] Add(int[,] A, int[,] B, int splitIndex, int multiplier = 1) {
        int[,] result = new int[splitIndex, splitIndex];
        for (int i = 0; i < splitIndex; i++) {
            for (int j = 0; j < splitIndex; j++) {
                result[i, j] = A[i, j] + multiplier * B[i, j];
            }
        }
        return result;
    }

    /// <summary>
    /// Multiplies A (rows x inner) by B (inner x cols) using Strassen's algorithm.
    /// Any compatible shapes are accepted, not just square powers of two.
    /// The operands are zero-padded to the next power-of-two square size,
    /// multiplied recursively, and the rows x cols product is copied out.
    /// (internal rather than private so it can be unit-tested.)
    /// </summary>
    /// <exception cref="ArgumentException">A's column count differs from B's row count.</exception>
    internal static int[,] StrassenMultiply(int[,] A, int[,] B) {
        int rows = A.GetLength(0);
        int inner = A.GetLength(1);
        if (B.GetLength(0) != inner) {
            throw new ArgumentException("The Two Matrices cannot be multiplied");
        }
        int cols = B.GetLength(1);
        // Degenerate shapes give an all-zero rows x cols matrix. Returning
        // early also avoids the endless N == 0 recursion.
        if (rows == 0 || inner == 0 || cols == 0) {
            return new int[rows, cols];
        }
        // Smallest power of two covering every dimension.
        int n = 1;
        while (n < Math.Max(rows, Math.Max(inner, cols))) {
            n *= 2;
        }
        int[,] padded = StrassenSquare(Pad(A, n), Pad(B, n));
        // Crop the padding away.
        int[,] C = new int[rows, cols];
        for (int i = 0; i < rows; i++) {
            for (int j = 0; j < cols; j++) {
                C[i, j] = padded[i, j];
            }
        }
        return C;
    }

    /// <summary>Returns a new n x n copy of M with zeros in any extra rows/columns.</summary>
    private static int[,] Pad(int[,] M, int n) {
        int[,] padded = new int[n, n];
        for (int i = 0; i < M.GetLength(0); i++) {
            for (int j = 0; j < M.GetLength(1); j++) {
                padded[i, j] = M[i, j];
            }
        }
        return padded;
    }

    /// <summary>
    /// Strassen recursion for two N x N matrices where N is a power of two
    /// (guaranteed by StrassenMultiply).
    /// </summary>
    private static int[,] StrassenSquare(int[,] A, int[,] B) {
        int N = A.GetLength(0);
        int[,] C = new int[N, N];
        if (N == 1) {
            C[0, 0] = A[0, 0] * B[0, 0];
            return C;
        }
        int splitIndex = N / 2;
        int[,] a00 = new int[splitIndex, splitIndex];
        int[,] a01 = new int[splitIndex, splitIndex];
        int[,] a10 = new int[splitIndex, splitIndex];
        int[,] a11 = new int[splitIndex, splitIndex];
        int[,] b00 = new int[splitIndex, splitIndex];
        int[,] b01 = new int[splitIndex, splitIndex];
        int[,] b10 = new int[splitIndex, splitIndex];
        int[,] b11 = new int[splitIndex, splitIndex];
        for (int i = 0; i < splitIndex; i++) {
            for (int j = 0; j < splitIndex; j++) {
                a00[i, j] = A[i, j];
                a01[i, j] = A[i, j + splitIndex];
                a10[i, j] = A[i + splitIndex, j];
                a11[i, j] = A[i + splitIndex, j + splitIndex];
                b00[i, j] = B[i, j];
                b01[i, j] = B[i, j + splitIndex];
                b10[i, j] = B[i + splitIndex, j];
                b11[i, j] = B[i + splitIndex, j + splitIndex];
            }
        }
        // Strassen's seven sub-products.
        int[,] p1 = StrassenSquare(a00, Add(b01, b11, splitIndex, -1));
        int[,] p2 = StrassenSquare(Add(a00, a01, splitIndex), b11);
        int[,] p3 = StrassenSquare(Add(a10, a11, splitIndex), b00);
        int[,] p4 = StrassenSquare(a11, Add(b10, b00, splitIndex, -1));
        int[,] p5 = StrassenSquare(Add(a00, a11, splitIndex), Add(b00, b11, splitIndex));
        int[,] p6 = StrassenSquare(Add(a01, a11, splitIndex, -1), Add(b10, b11, splitIndex));
        int[,] p7 = StrassenSquare(Add(a00, a10, splitIndex, -1), Add(b00, b01, splitIndex));
        // Combine them into the four quadrants of the result.
        int[,] c00 = Add(Add(Add(p5, p4, splitIndex), p6, splitIndex), p2, splitIndex, -1);
        int[,] c01 = Add(p1, p2, splitIndex);
        int[,] c10 = Add(p3, p4, splitIndex);
        int[,] c11 = Add(Add(Add(p5, p1, splitIndex), p3, splitIndex, -1), p7, splitIndex, -1);
        for (int i = 0; i < splitIndex; i++) {
            for (int j = 0; j < splitIndex; j++) {
                C[i, j] = c00[i, j];
                C[i, j + splitIndex] = c01[i, j];
                C[i + splitIndex, j] = c10[i, j];
                C[i + splitIndex, j + splitIndex] = c11[i, j];
            }
        }
        return C;
    }

    /// <summary>
    /// Prints the matrix row by row, each value followed by a space. Rows and
    /// columns are sized independently, so non-square matrices print correctly.
    /// </summary>
    static void PrintMatrix(int[,] matrix) {
        int rows = matrix.GetLength(0);
        int cols = matrix.GetLength(1);
        for (int i = 0; i < rows; i++) {
            for (int j = 0; j < cols; j++) {
                Console.Write(matrix[i, j] + " ");
            }
            Console.WriteLine();
        }
    }

    static void Main() {
        int[,] A = {{2, 2, 3, 1}, {1, 4, 1, 2}, {2, 3, 1, 1}, {1, 3, 1, 2}};
        int[,] B = {{2, 1, 2, 1}, {3, 1, 2, 1}, {3, 2, 1, 1}, {1, 4, 3, 2}};
        int[,] result = StrassenMultiply(A, B);
        PrintMatrix(result);
    }
}