Skip to content

Commit 0bcfed6

Browse files
committed
test: add alpha & beta tests for blas/base/dgemm
1 parent 0c31eaf commit 0bcfed6

4 files changed

Lines changed: 87 additions & 0 deletions

File tree

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
{
2+
"transA": "no-transpose",
3+
"transB": "no-transpose",
4+
"M": 2,
5+
"N": 4,
6+
"K": 3,
7+
"alpha": 2.0,
8+
"A": [ 1.0, 2.0, 3.0, 4.0, 5.0, 6.0 ],
9+
"strideA1": 3,
10+
"strideA2": 1,
11+
"offsetA": 0,
12+
"B": [ 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
13+
"strideB1": 4,
14+
"strideB2": 1,
15+
"offsetB": 0,
16+
"beta": 3.0,
17+
"C": [ 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0 ],
18+
"strideC1": 4,
19+
"strideC2": 1,
20+
"offsetC": 0,
21+
"C_out": [ 15.0, 18.0, 21.0, 24.0, 45.0, 48.0, 51.0, 54.0 ]
22+
}
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
{
2+
"order": "row-major",
3+
"transA": "no-transpose",
4+
"transB": "no-transpose",
5+
"M": 2,
6+
"N": 4,
7+
"K": 3,
8+
"alpha": 2.0,
9+
"A": [ 1.0, 2.0, 3.0, 4.0, 5.0, 6.0 ],
10+
"lda": 3,
11+
"B": [ 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
12+
"ldb": 4,
13+
"beta": 3.0,
14+
"C": [ 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0 ],
15+
"ldc": 4,
16+
"C_out": [ 15.0, 18.0, 21.0, 24.0, 45.0, 48.0, 51.0, 54.0 ]
17+
}

lib/node_modules/@stdlib/blas/base/dgemm/test/test.dgemm.js

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ var rtantb = require( './fixtures/row_major_ta_ntb.json' );
4040
var rntatb = require( './fixtures/row_major_nta_tb.json' );
4141
var rtatb = require( './fixtures/row_major_ta_tb.json' );
4242

43+
var rntantbAlpha2Beta3 = require( './fixtures/row_major_nta_ntb_alpha2_beta3.json' );
44+
4345

4446
// TESTS //
4547

@@ -704,3 +706,25 @@ tape( 'if `α` is `0` and `β` is neither `0` nor `1`, the function returns the
704706

705707
t.end();
706708
});
709+
710+
tape( 'the function correctly applies both `α` and `β` scalars (row-major, no-transpose, no-transpose, α=2, β=3)', function test( t ) {
711+
var expected;
712+
var data;
713+
var out;
714+
var a;
715+
var b;
716+
var c;
717+
718+
data = rntantbAlpha2Beta3;
719+
720+
a = new Float64Array( data.A );
721+
b = new Float64Array( data.B );
722+
c = new Float64Array( data.C );
723+
724+
expected = new Float64Array( data.C_out );
725+
726+
out = dgemm( data.order, data.transA, data.transB, data.M, data.N, data.K, data.alpha, a, data.lda, b, data.ldb, data.beta, c, data.ldc );
727+
t.strictEqual( out, c, 'returns expected value' );
728+
t.deepEqual( out, expected, 'returns expected value' );
729+
t.end();
730+
});

lib/node_modules/@stdlib/blas/base/dgemm/test/test.ndarray.js

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,8 @@ var rarbrcntantbob = require( './fixtures/ra_rb_rc_nta_ntb_ob.json' );
8383
var rarbrcntantboc = require( './fixtures/ra_rb_rc_nta_ntb_oc.json' );
8484
var cap = require( './fixtures/ra_rb_rc_nta_ntb_complex_access_pattern.json' );
8585

86+
var rarbrcntantbAlpha2Beta3 = require( './fixtures/ra_rb_rc_nta_ntb_alpha2_beta3.json' );
87+
8688

8789
// TESTS //
8890

@@ -1487,3 +1489,25 @@ tape( 'the function supports computation over large arrays (column-major, column
14871489
t.deepEqual( out, expected, 'returns expected value' );
14881490
t.end();
14891491
});
1492+
1493+
tape( 'the function correctly applies both `α` and `β` scalars (row_major, row_major, row_major, no-transpose, no-transpose, α=2, β=3)', function test( t ) {
1494+
var expected;
1495+
var data;
1496+
var out;
1497+
var a;
1498+
var b;
1499+
var c;
1500+
1501+
data = rarbrcntantbAlpha2Beta3;
1502+
1503+
a = new Float64Array( data.A );
1504+
b = new Float64Array( data.B );
1505+
c = new Float64Array( data.C );
1506+
1507+
expected = new Float64Array( data.C_out );
1508+
1509+
out = dgemm( data.transA, data.transB, data.M, data.N, data.K, data.alpha, a, data.strideA1, data.strideA2, data.offsetA, b, data.strideB1, data.strideB2, data.offsetB, data.beta, c, data.strideC1, data.strideC2, data.offsetC );
1510+
t.strictEqual( out, c, 'returns expected value' );
1511+
t.deepEqual( out, expected, 'returns expected value' );
1512+
t.end();
1513+
});

0 commit comments

Comments
 (0)