Code360 powered by Coding Ninjas X Naukri.com
Table of contents
1. Introduction
2. What is Strassen’s Matrix Multiplication Algorithm
3. Problem Statement
4. Example
5. Brute force Approach
5.1. Pseudocode
5.2. Implementation in C++
5.3. Time Complexity
5.4. Space Complexity
6. Optimized Approach
6.1. Pseudocode
6.2. Implementation in C++
6.3. Time Complexity
6.4. Space Complexity
7. Frequently Asked Questions
7.1. What is Strassen’s Matrix Multiplication?
7.2. Which technique does Strassen's matrix multiplication use?
7.3. Why is Strassen's matrix multiplication better than ordinary matrix multiplication?
7.4. What are the 4 methods of matrix multiplication?
8. Conclusion
Last Updated: Apr 14, 2024
Medium

Strassen’s Matrix Multiplication

Author Harsh

Introduction

Strassen's Matrix Multiplication is a divide-and-conquer technique used to efficiently solve matrix multiplication problems.

The standard algorithm for multiplying two N x N matrices takes O(N^3) time. An efficient approach known as Strassen's Matrix Multiplication reduces this to O(N^log2 7), which is about O(N^2.81).

In this blog, we will first solve the problem with a brute force approach and then implement the efficient Strassen’s Matrix Multiplication approach in detail.


What is Strassen’s Matrix Multiplication Algorithm

Here's how Strassen's Matrix Multiplication Algorithm works for two n x n matrices A and B, where n is a power of 2:

  • Divide: Take the two matrices you want to multiply, A and B, and split each of them into four submatrices (quadrants) of size n/2 x n/2.
  • Calculate: Use these quadrants to compute seven products, which we'll call P1, P2, P3, P4, P5, P6, and P7. Each of them is a single multiplication of two n/2 x n/2 matrices, solved recursively with the same method, whose operands are either a quadrant or the sum or difference of two quadrants.
  • Combine: Compute the four quadrants of the result matrix C by adding and subtracting the products P1 to P7.
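Written out, the seven products and the four quadrants of C are listed below. Here A_{11}, A_{12}, A_{21}, A_{22} are the top-left, top-right, bottom-left and bottom-right quadrants of A (called a00, a01, a10 and a11 in the C++ implementation of the optimized approach), and likewise for B and C. These are the same formulas that implementation uses for p1 to p7; note that the order of each product matters, because matrix multiplication is not commutative.

```latex
\begin{aligned}
P_1 &= A_{11}\,(B_{12} - B_{22}) \\
P_2 &= (A_{11} + A_{12})\,B_{22} \\
P_3 &= (A_{21} + A_{22})\,B_{11} \\
P_4 &= A_{22}\,(B_{21} - B_{11}) \\
P_5 &= (A_{11} + A_{22})\,(B_{11} + B_{22}) \\
P_6 &= (A_{12} - A_{22})\,(B_{21} + B_{22}) \\
P_7 &= (A_{11} - A_{21})\,(B_{11} + B_{12}) \\[4pt]
C_{11} &= P_5 + P_4 - P_2 + P_6 = A_{11}B_{11} + A_{12}B_{21} \\
C_{12} &= P_1 + P_2 = A_{11}B_{12} + A_{12}B_{22} \\
C_{21} &= P_3 + P_4 = A_{21}B_{11} + A_{22}B_{21} \\
C_{22} &= P_5 + P_1 - P_3 - P_7 = A_{21}B_{12} + A_{22}B_{22}
\end{aligned}
```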

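This method may sound more complicated, but it is faster for really big matrices because it needs only seven recursive multiplications of n/2 x n/2 matrices, instead of the eight that a straightforward split into quadrants requires, at the cost of extra additions and subtractions. Adding two matrices takes only O(n^2) time, so saving one multiplication at every level of the recursion pays off when the matrices are huge. For smaller matrices, the extra additions and bookkeeping make regular multiplication faster.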


Problem Statement

Given two square matrices A and B, each of size n x n, compute their product C = A x B.

Example

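Input

A =
2 2 3 1
1 4 1 2
2 3 1 1
1 3 1 2

B =
2 1 2 1
3 1 2 1
3 2 1 1
1 4 3 2

Output

C =
20 14 14 9
19 15 17 10
17 11 14 8
16 14 15 9

Explanation

C is the product of matrices A and B: each entry C[i][j] is the dot product of row i of A and column j of B. For example, the top-left entry is 2*2 + 2*3 + 3*3 + 1*1 = 20.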

Brute force Approach

The idea is to use three nested loops and calculate the value of each cell C[i][j] individually as the dot product of row i of A and column j of B.

Pseudocode

Algorithm MULTIPLY_MATRIX(A, B, C)
for i <- 1 to n do
  for j <- 1 to n do
    C[i][j] <- 0
    for k <- 1 to n do
      C[i][j] <- C[i][j] + A[i][k]*B[k][j]
    end
  end
end

Implementation in C++

#include <iostream>
#include <vector>
using namespace std;

// print the matrix
void print(const vector<vector<int>>& matrix) {
    for (const auto& row : matrix) {
        for (int value : row) {
            cout << value << ' ';
        }
        cout << endl;
    }
}

// C = A x B for two square matrices of the same size
void multiply(const vector<vector<int>>& A, const vector<vector<int>>& B, vector<vector<int>>& C) {
    int N = static_cast<int>(A.size());  // size of the matrices
    for (int i = 0; i < N; i++) {
        for (int j = 0; j < N; j++) {
            C[i][j] = 0;
            // dot product of row i of A and column j of B
            for (int k = 0; k < N; k++) {
                C[i][j] += A[i][k] * B[k][j];
            }
        }
    }
}

int main() {
    // Input Matrix A
    vector<vector<int>> A = {{2, 2, 3, 1}, {1, 4, 1, 2}, {2, 3, 1, 1}, {1, 3, 1, 2}};

    // Input Matrix B
    vector<vector<int>> B = {{2, 1, 2, 1}, {3, 1, 2, 1}, {3, 2, 1, 1}, {1, 4, 3, 2}};

    // Result matrix with the same size as the inputs
    vector<vector<int>> C(A.size(), vector<int>(A.size()));

    multiply(A, B, C);

    // Printing the result
    print(C);
    return 0;
}

 

Output

20 14 14 9 
19 15 17 10 
17 11 14 8 
16 14 15 9

Time Complexity

The innermost statement is enclosed by three nested loops that each run N times, so it executes N^3 times. So, the running time of the above algorithm is O(N^3).
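Written as a sum over the three loop indices, the number of times the innermost statement runs is:

```latex
T(N) = \sum_{i=1}^{N} \sum_{j=1}^{N} \sum_{k=1}^{N} 1 = N \cdot N \cdot N = N^{3}
```

For the 4 x 4 example above, that is 4^3 = 64 scalar multiplications.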

Space Complexity

A new N x N matrix is used to store the result of the multiplication, so the space complexity is O(N^2). Apart from that output matrix, the algorithm only needs a constant amount of extra space.

Optimized Approach

We can do better with Strassen’s Matrix Multiplication. The idea is to use the divide and conquer approach: divide each matrix into four sub-matrices of size N/2 x N/2, multiply these sub-matrices with only seven recursive multiplications using the formulas given by Strassen's method, and then combine the results into the final matrix.

Pseudocode

// Algorithm to multiply two n x n matrices A and B using Strassen's method
// (n is assumed to be a power of 2); it returns the product matrix C
Algorithm STRASSEN_METHOD (A, B, n)
if n == 1 then
  return A * B    // product of two scalars
else
  split A into A11, A12, A21, A22 and B into B11, B12, B21, B22 (each n/2 x n/2)
  P1 <- STRASSEN_METHOD (A11, B12 - B22, n/2)
  P2 <- STRASSEN_METHOD (A11 + A12, B22, n/2)
  P3 <- STRASSEN_METHOD (A21 + A22, B11, n/2)
  P4 <- STRASSEN_METHOD (A22, B21 - B11, n/2)
  P5 <- STRASSEN_METHOD (A11 + A22, B11 + B22, n/2)
  P6 <- STRASSEN_METHOD (A12 - A22, B21 + B22, n/2)
  P7 <- STRASSEN_METHOD (A11 - A21, B11 + B12, n/2)
  C11 <- P5 + P4 - P2 + P6
  C12 <- P1 + P2
  C21 <- P3 + P4
  C22 <- P5 + P1 - P3 - P7
  return C formed from the quadrants C11, C12, C21, C22
end

Implementation in C++

// Strassen’s Matrix Multiplication
// (assumes A and B are square matrices whose size is a power of 2)
#include <iostream>
#include <vector>
using namespace std;

// print the matrix
void print(const vector<vector<int> >& matrix) {
	for (const auto& row : matrix) {
		for (int value : row) {
			cout << value << ' ';
		}
		cout << endl;
	}
}

// Return A + multiplier * B over the top-left split_index x split_index
// block (multiplier = -1 gives A - B)
vector<vector<int>> add(vector<vector<int> > A, const vector<vector<int> >& B, int split_index, int multiplier = 1) {
	for (auto i = 0; i < split_index; i++)
		for (auto j = 0; j < split_index; j++)
			A[i][j] = A[i][j] + (multiplier * B[i][j]);
	return A;
}

vector<vector<int> >
strassen_multiplication(const vector<vector<int> >& A, const vector<vector<int> >& B) {
	
	// calculating the size of matrix
	int col_1 = A[0].size();
	int row_1 = A.size();
	int col_2 = B[0].size();
	int row_2 = B.size();

	// checking if multiplication is possible or not 
	// between the input matrices
	if (col_1 != row_2) {
		cout << "The Two Matrices cannot be multiplied";
		return {};
	}

	// creating an empty matrix to store the result
	vector<int> result_row(col_2, 0);
	vector<vector<int> > result(row_1, result_row);

	// Base case 
	// if size of matrix is 1
	if (col_1 == 1)
		result[0][0]
			= A[0][0] * B[0][0];
	else {

		// split index
		int split_index = col_1 / 2;

		vector<int> row_vector(split_index, 0);

		// Splitting the matrices in sub matrices
		vector<vector<int> > a00(split_index, row_vector);
		vector<vector<int> > a01(split_index, row_vector);
		vector<vector<int> > a10(split_index, row_vector);
		vector<vector<int> > a11(split_index, row_vector);
		vector<vector<int> > b00(split_index, row_vector);
		vector<vector<int> > b01(split_index, row_vector);
		vector<vector<int> > b10(split_index, row_vector);
		vector<vector<int> > b11(split_index, row_vector);

		// calculating and storing the result
		// inside our quadrants
		for (auto i = 0; i < split_index; i++)
			for (auto j = 0; j < split_index; j++) {
				a00[i][j] = A[i][j];
				a01[i][j] = A[i][j + split_index];
				a10[i][j] = A[split_index + i][j];
				a11[i][j] = A[i + split_index]
									[j + split_index];
				b00[i][j] = B[i][j];
				b01[i][j] = B[i][j + split_index];
				b10[i][j] = B[split_index + i][j];
				b11[i][j] = B[i + split_index]
									[j + split_index];
			}

		// Calculating the seven products using the formulas
		// given by Strassen's algorithm
		vector<vector<int>> p1(
			strassen_multiplication(a00, add(b01, b11, split_index, -1))
		);
		vector<vector<int>> p2(
			strassen_multiplication(add(a00, a01, split_index), b11)
		);
		vector<vector<int>> p3(
			strassen_multiplication(add(a10, a11, split_index), b00)
		);
		vector<vector<int>> p4(
			strassen_multiplication(a11, add(b10, b00, split_index, -1))
		);
		vector<vector<int>> p5(
			strassen_multiplication(add(a00, a11, split_index),add(b00, b11, split_index))
		);
		vector<vector<int>> p6(
			strassen_multiplication(add(a01, a11, split_index, -1),add(b10, b11, split_index))
		);
		vector<vector<int>> p7(
			strassen_multiplication(
				add(a00, a10, split_index, -1),
				add(b00, b01, split_index)
			)
		);

		// calculating the result
		vector<vector<int> > result_00(
			add(add(add(p5, p4, split_index), p6, split_index), p2, split_index, -1)
		);
		vector<vector<int> > result_01(
			add(p1, p2, split_index)
		);
		vector<vector<int> > result_10(	
			add(p3, p4, split_index)
		);
		vector<vector<int> > result_11(
			add(add(add(p5, p1, split_index), p3, split_index, -1), p7, split_index, -1)
		);

		// copying the four result quadrants
		// into the result matrix
		for (auto i = 0; i < split_index; i++){
			for (auto j = 0; j < split_index; j++) {
				result[i][j] = result_00[i][j];
				result[i][j + split_index] = result_01[i][j];
				result[split_index + i][j] = result_10[i][j];
				result[i + split_index][j + split_index] = result_11[i][j];
			}
		}

		// no manual clean-up is needed: the temporary matrices
		// are freed automatically when they go out of scope
	}
	return result;
}

int main() {

    // Input Matrix A
	vector<vector<int>> A = {{2, 2, 3, 1},{1, 4, 1, 2},{2, 3, 1, 1}, {1, 3, 1, 2}};
    
	// Input Matrix B
    vector<vector<int>> B = {{2, 1, 2, 1},{3, 1, 2, 1},{3, 2, 1, 1}, {1, 4, 3, 2}};

    // Getting the result
	vector<vector<int> > result(strassen_multiplication(A, B));

    // Printing the result
	print(result);
}

Output

20 14 14 9 
19 15 17 10 
17 11 14 8 
16 14 15 9

Time Complexity

Using Strassen's matrix multiplication method, we split a problem of size n into 7 subproblems of size n/2, plus a constant number of additions and subtractions of n/2 x n/2 matrices, which take O(n^2) time.
The recurrence equation for Strassen's matrix multiplication method is T(n) = 7T(n/2) + O(n^2). Solving this recurrence (for example, with the master theorem) gives T(n) = O(n^log2 7), which is about O(n^2.81), as the running time of Strassen’s matrix multiplication algorithm.

Space Complexity

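Each call allocates temporary n/2 x n/2 matrices for the quadrants, their sums and differences, and the seven products, which is O(n^2) extra space for that call. Only one call per level of the recursion is active at any time, so the total is at most a constant times n^2 + (n/2)^2 + (n/4)^2 + ... <= (4/3)n^2. So, the space complexity of Strassen’s matrix multiplication method is O(n^2).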


Frequently Asked Questions

What is Strassen’s Matrix Multiplication?

The Strassen algorithm is a recursive method for matrix multiplication in which each recursive step divides each matrix into four submatrices of dimensions n/2 x n/2 and computes the product with seven recursive multiplications instead of eight.

Which technique does Strassen's matrix multiplication use?

Strassen's matrix multiplication uses the divide and conquer technique. The larger matrices are decomposed into smaller n/2 x n/2 submatrices, and the product is rebuilt from seven recursive multiplications of those submatrices plus some additions and subtractions. This reduces the number of calculations compared with the eight multiplications that a straightforward split into quadrants needs.

Why is Strassen's matrix multiplication better than ordinary matrix multiplication?

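Strassen's matrix multiplication is asymptotically faster than ordinary matrix multiplication: each level of its divide and conquer recursion needs only seven multiplications of half-size matrices instead of eight, so its running time is O(n^2.81) instead of O(n^3). The advantage only shows up for large matrices, though. The extra additions, temporary matrices, and recursion overhead make it slower than ordinary multiplication for small matrices.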

What are the 4 methods of matrix multiplication?

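Four commonly discussed methods are standard matrix multiplication, which is the naive method; Strassen's matrix multiplication; block matrix multiplication; and matrix chain multiplication. Strictly speaking, matrix chain multiplication is not a different way to multiply two matrices: it is a dynamic programming technique for choosing the cheapest order in which to multiply a sequence of matrices.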

Conclusion

In this article, we discussed a coding problem in which we have to multiply two matrices, and we solved it with two different approaches: the brute force approach and Strassen’s Matrix Multiplication method. We hope that this blog has helped you enhance your knowledge about the above question. If you would like to learn more, check out more of our blogs related to coding questions, such as First non-repeating character in a stream, How to efficiently implement k Queues in a single array, Sorting of Queue, and many more on our website.


You can also check Interview Experiences and Interview Preparation Resources if you are interested in cracking the technical interviews at top Product-based companies like Amazon, Microsoft, Uber, etc.
