Skip to content

Commit 56baf84

Browse files
committed
Updated files
1 parent 1858c69 commit 56baf84

2 files changed

Lines changed: 113 additions & 109 deletions

File tree

Lines changed: 39 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,50 +1,51 @@
11
package com.thealgorithms.matrix;
22

33
public final class MatrixMultiplication {
4-
private MatrixMultiplication() {}
5-
6-
/**
7-
* Multiplies two matrices.
8-
*
9-
* @param matrixA the first matrix rowsA x colsA
10-
* @param matrixB the second matrix rowsB x colsB
11-
* @return the product of the two matrices rowsA x colsB
12-
* @throws IllegalArgumentException if the matrices cannot be multiplied
13-
*/
14-
public static double[][] multiply(double[][] matrixA, double[][] matrixB) {
15-
// Check the input matrices are not null
16-
if (matrixA == null || matrixB == null) {
17-
throw new IllegalArgumentException("Input matrices cannot be null");
4+
private MatrixMultiplication() {
185
}
196

20-
// Check for empty matrices
21-
if (matrixA.length == 0
22-
|| matrixB.length == 0
23-
|| matrixA[0].length == 0
24-
|| matrixB[0].length == 0) {
25-
throw new IllegalArgumentException("Input matrices must not be empty");
26-
}
7+
/**
8+
* Multiplies two matrices.
9+
*
10+
* @param matrixA the first matrix rowsA x colsA
11+
* @param matrixB the second matrix rowsB x colsB
12+
* @return the product of the two matrices rowsA x colsB
13+
* @throws IllegalArgumentException if the matrices cannot be multiplied
14+
*/
15+
public static double[][] multiply(double[][] matrixA, double[][] matrixB) {
16+
// Check the input matrices are not null
17+
if (matrixA == null || matrixB == null) {
18+
throw new IllegalArgumentException("Input matrices cannot be null");
19+
}
2720

28-
// Validate the matrix dimensions
29-
if (matrixA[0].length != matrixB.length) {
30-
throw new IllegalArgumentException("Matrices cannot be multiplied: incompatible dimensions.");
31-
}
21+
// Check for empty matrices
22+
if (matrixA.length == 0
23+
|| matrixB.length == 0
24+
|| matrixA[0].length == 0
25+
|| matrixB[0].length == 0) {
26+
throw new IllegalArgumentException("Input matrices must not be empty");
27+
}
28+
29+
// Validate the matrix dimensions
30+
if (matrixA[0].length != matrixB.length) {
31+
throw new IllegalArgumentException("Matrices cannot be multiplied: incompatible dimensions.");
32+
}
3233

33-
int rowsA = matrixA.length;
34-
int colsA = matrixA[0].length;
35-
int colsB = matrixB[0].length;
34+
int rowsA = matrixA.length;
35+
int colsA = matrixA[0].length;
36+
int colsB = matrixB[0].length;
3637

37-
// Initialize the result matrix with zeros
38-
double[][] result = new double[rowsA][colsB];
38+
// Initialize the result matrix with zeros
39+
double[][] result = new double[rowsA][colsB];
3940

40-
// Perform matrix multiplication
41-
for (int i = 0; i < rowsA; i++) {
42-
for (int j = 0; j < colsB; j++) {
43-
for (int k = 0; k < colsA; k++) {
44-
result[i][j] += matrixA[i][k] * matrixB[k][j];
41+
// Perform matrix multiplication
42+
for (int i = 0; i < rowsA; i++) {
43+
for (int j = 0; j < colsB; j++) {
44+
for (int k = 0; k < colsA; k++) {
45+
result[i][j] += matrixA[i][k] * matrixB[k][j];
46+
}
47+
}
4548
}
46-
}
49+
return result;
4750
}
48-
return result;
49-
}
5051
}

src/test/java/com/thealgorithms/matrix/MatrixMultiplicationTest.java

Lines changed: 74 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -7,76 +7,79 @@
77

88
public class MatrixMultiplicationTest {
99

10-
private static final double EPSILON = 1e-9; // for floating point comparison
11-
12-
@Test
13-
void testMultiply2by2() {
14-
double[][] matrixA = {{1.0, 2.0}, {3.0, 4.0}};
15-
double[][] matrixB = {{5.0, 6.0}, {7.0, 8.0}};
16-
double[][] expected = {{19.0, 22.0}, {43.0, 50.0}};
17-
18-
double[][] result = MatrixMultiplication.multiply(matrixA, matrixB);
19-
assertMatrixEquals(expected, result); // Use custom method due to floating point issues
20-
}
21-
22-
@Test
23-
void testMultiply3by2and2by1() {
24-
double[][] matrixA = {{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}};
25-
double[][] matrixB = {{7.0}, {8.0}};
26-
double[][] expected = {{23.0}, {53.0}, {83.0}};
27-
28-
double[][] result = MatrixMultiplication.multiply(matrixA, matrixB);
29-
assertMatrixEquals(expected, result);
30-
}
31-
32-
@Test
33-
void testNullMatrixA() {
34-
double[][] b = {{1, 2}, {3, 4}};
35-
assertThrows(IllegalArgumentException.class, () -> MatrixMultiplication.multiply(null, b));
36-
}
37-
38-
@Test
39-
void testNullMatrixB() {
40-
double[][] a = {{1, 2}, {3, 4}};
41-
assertThrows(IllegalArgumentException.class, () -> MatrixMultiplication.multiply(a, null));
42-
}
43-
44-
@Test
45-
void testMultiplyNull() {
46-
double[][] matrixA = {{1.0, 2.0}, {3.0, 4.0}};
47-
double[][] matrixB = null;
48-
49-
Exception exception =
50-
assertThrows(
51-
IllegalArgumentException.class, () -> MatrixMultiplication.multiply(matrixA, matrixB));
52-
53-
String expectedMessage = "Input matrices cannot be null";
54-
String actualMessage = exception.getMessage();
55-
56-
assertTrue(actualMessage.contains(expectedMessage));
57-
}
58-
59-
@Test
60-
void testIncompatibleDimensions() {
61-
double[][] a = {{1.0, 2.0}};
62-
double[][] b = {{1.0, 2.0}};
63-
assertThrows(IllegalArgumentException.class, () -> MatrixMultiplication.multiply(a, b));
64-
}
65-
66-
@Test
67-
void testEmptyMatrices() {
68-
double[][] a = new double[0][0];
69-
double[][] b = new double[0][0];
70-
assertThrows(IllegalArgumentException.class, () -> MatrixMultiplication.multiply(a, b));
71-
}
72-
73-
private void assertMatrixEquals(double[][] expected, double[][] actual) {
74-
assertEquals(expected.length, actual.length, "Row count mismatch");
75-
for (int i = 0; i < expected.length; i++) {
76-
assertEquals(expected[i].length, actual[i].length, "Column count mismatch at row " + i);
77-
for (int j = 0; j < expected[i].length; j++) {
78-
assertEquals(expected[i][j], actual[i][j], EPSILON, "Mismatch at (" + i + "," + j + ")");
79-
}
10+
private static final double EPSILON = 1e-9; // for floating point comparison
11+
12+
@Test
13+
void testMultiply2by2() {
14+
double[][] matrixA = {{1.0, 2.0}, {3.0, 4.0}};
15+
double[][] matrixB = {{5.0, 6.0}, {7.0, 8.0}};
16+
double[][] expected = {{19.0, 22.0}, {43.0, 50.0}};
17+
18+
double[][] result = MatrixMultiplication.multiply(matrixA, matrixB);
19+
assertMatrixEquals(expected, result); // Use custom method due to floating point issues
20+
}
21+
22+
@Test
23+
void testMultiply3by2and2by1() {
24+
double[][] matrixA = {{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}};
25+
double[][] matrixB = {{7.0}, {8.0}};
26+
double[][] expected = {{23.0}, {53.0}, {83.0}};
27+
28+
double[][] result = MatrixMultiplication.multiply(matrixA, matrixB);
29+
assertMatrixEquals(expected, result);
30+
}
31+
32+
@Test
33+
void testNullMatrixA() {
34+
double[][] b = {{1, 2}, {3, 4}};
35+
assertThrows(IllegalArgumentException.class,
36+
() -> MatrixMultiplication.multiply(null, b));
37+
}
38+
39+
@Test
40+
void testNullMatrixB() {
41+
double[][] a = {{1, 2}, {3, 4}};
42+
assertThrows(IllegalArgumentException.class,
43+
() -> MatrixMultiplication.multiply(a, null));
44+
}
45+
46+
@Test
47+
void testMultiplyNull() {
48+
double[][] matrixA = {{1.0, 2.0}, {3.0, 4.0}};
49+
double[][] matrixB = null;
50+
51+
Exception exception = assertThrows(IllegalArgumentException.class,
52+
() -> MatrixMultiplication.multiply(matrixA, matrixB));
53+
54+
String expectedMessage = "Input matrices cannot be null";
55+
String actualMessage = exception.getMessage();
56+
57+
assertTrue(actualMessage.contains(expectedMessage));
58+
}
59+
60+
@Test
61+
void testIncompatibleDimensions() {
62+
double[][] a = {{1.0, 2.0}};
63+
double[][] b = {{1.0, 2.0}};
64+
assertThrows(IllegalArgumentException.class,
65+
() -> MatrixMultiplication.multiply(a, b));
66+
}
67+
68+
@Test
69+
void testEmptyMatrices() {
70+
double[][] a = new double[0][0];
71+
double[][] b = new double[0][0];
72+
assertThrows(IllegalArgumentException.class,
73+
() -> MatrixMultiplication.multiply(a, b));
74+
}
75+
76+
private void assertMatrixEquals(double[][] expected, double[][] actual) {
77+
assertEquals(expected.length, actual.length, "Row count mismatch");
78+
for (int i = 0; i < expected.length; i++) {
79+
assertEquals(expected[i].length, actual[i].length, "Column count mismatch at row " + i);
80+
for (int j = 0; j < expected[i].length; j++) {
81+
assertEquals(expected[i][j], actual[i][j], EPSILON, "Mismatch at (" + i + "," + j + ")");
82+
}
83+
}
8084
}
81-
}
8285
}

0 commit comments

Comments
 (0)