Skip to content

Commit f5939db

Browse files
feat: implement skeleton for ml/cluster/strided/dkmeansld
1 parent b665247 commit f5939db

2 files changed

Lines changed: 205 additions & 0 deletions

File tree

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
/**
2+
* @license Apache-2.0
3+
*
4+
* Copyright (c) 2018 The Stdlib Authors.
5+
*
6+
* Licensed under the Apache License, Version 2.0 (the "License");
7+
* you may not use this file except in compliance with the License.
8+
* You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing, software
13+
* distributed under the License is distributed on an "AS IS" BASIS,
14+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
* See the License for the specific language governing permissions and
16+
* limitations under the License.
17+
*/
18+
19+
'use strict';
20+
21+
/**
22+
* Compute fitted k-means cluster results using Lloyd's algorithm.
23+
*
24+
* @module @stdlib/ml/cluster/strided/dkmeansld
25+
*
26+
* @example
27+
* var Float64Array = require( '@stdlib/array/float64' );
28+
* var ndarray = require( '@stdlib/ndarray/ctor' );
29+
* var kmeans = require( '@stdlib/ml/cluster/strided/dkmeansld' );
30+
* // TODO: add a usage example which calls `kmeans` once the results object is implemented
31+
*/
32+
33+
// MAIN //
34+
35+
var main = require( './main.js' );
36+
37+
38+
// EXPORTS //
39+
40+
module.exports = main;
Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,165 @@
1+
/**
2+
* @license Apache-2.0
3+
*
4+
* Copyright (c) 2026 The Stdlib Authors.
5+
*
6+
* Licensed under the Apache License, Version 2.0 (the "License");
7+
* you may not use this file except in compliance with the License.
8+
* You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing, software
13+
* distributed under the License is distributed on an "AS IS" BASIS,
14+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
* See the License for the specific language governing permissions and
16+
* limitations under the License.
17+
*/
18+
19+
'use strict';
20+
21+
// MODULES //
22+
23+
var dlacpy = require( '@stdlib/lapack/base/dlacpy' ).ndarray;
24+
var Float64Array = require( '@stdlib/array/float64' );
25+
var Int32Array = require( '@stdlib/array/int32' );
26+
var dfill = require( '@stdlib/blas/ext/base/dfill' );
27+
var isEqualArray = require( '@stdlib/assert/is-equal-array' );
28+
var dcopy = require( '@stdlib/blas/base/dcopy' ).ndarray;
29+
var deuclidean = require( '@stdlib/stats/strided/distances/deuclidean' ).ndarray;
30+
var dcosine = require( '@stdlib/stats/strided/distances/deuclidean' ).ndarray;
31+
var dcityblock = require( '@stdlib/stats/strided/distances/deuclidean' ).ndarray;
32+
33+
34+
// MAIN //
35+
36+
/**
* Computes fitted k-means cluster results using Lloyd's algorithm.
*
* ## Notes
*
* -   This implementation is a skeleton. The `replicates`, `tol`, and `strideInit1` arguments are accepted but not yet used, only the centroid set starting at `offsetInit` is read from `init`, and the function currently returns `undefined`.
* -   Working centroids are stored contiguously in row-major order (`k` rows of `N` features).
* -   A cluster which receives no samples during an iteration keeps its previous centroid.
* -   Iteration stops early once no sample changes its cluster assignment.
*
* @param {PositiveInteger} M - number of samples
* @param {PositiveInteger} N - number of features
* @param {PositiveInteger} k - number of clusters
* @param {NonNegativeInteger} replicates - number of times to repeat clustering with different centroids
* @param {string} metric - distance metric (`'euclidean'`, `'cosine'`, or `'cityblock'`)
* @param {NonNegativeInteger} maxIter - maximum number of iterations
* @param {number} tol - relative tolerance before declaring convergence
* @param {Float64Array} X - input strided matrix
* @param {integer} strideX1 - stride of the first dimension of `X`
* @param {integer} strideX2 - stride of the second dimension of `X`
* @param {NonNegativeInteger} offsetX - starting index for `X`
* @param {Float64Array} init - strided array containing initial centroid locations
* @param {integer} strideInit1 - stride of the first dimension of `init`
* @param {integer} strideInit2 - stride of the second dimension of `init`
* @param {integer} strideInit3 - stride of the third dimension of `init`
* @param {NonNegativeInteger} offsetInit - starting index for `init`
* @throws {TypeError} fifth argument must be a supported distance metric
* @returns {(Result|void)} results object (not yet implemented)
*/
function dkmeansld( M, N, k, replicates, metric, maxIter, tol, X, strideX1, strideX2, offsetX, init, strideInit1, strideInit2, strideInit3, offsetInit ) { // eslint-disable-line max-len
	var centroidsNew;
	var strictConv;
	var centroids;
	var bestDist;
	var labels;
	var counts;
	var shift;
	var same;
	var dist;
	var best;
	var iter;
	var out;
	var ox;
	var ix;
	var ic;
	var i;
	var j;
	var c;
	var d;

	// Resolve the distance kernel first so that an unsupported metric fails with a descriptive error instead of a later "dist is not a function"...
	// NOTE(review): at module level, `dcosine` and `dcityblock` are both required from the `deuclidean` package; confirm and correct those require paths.
	if ( metric === 'euclidean' ) {
		dist = deuclidean; // TODO: change it to dsquared-euclidean once implemented
	} else if ( metric === 'cosine' ) {
		dist = dcosine; // TODO: change it to dsquared-cosine once implemented
	} else if ( metric === 'cityblock' ) {
		dist = dcityblock;
	} else {
		throw new TypeError( 'invalid argument. Fifth argument must be one of the following: "euclidean", "cosine", "cityblock". Value: `' + metric + '`.' ); // eslint-disable-line max-len
	}

	centroids = new Float64Array( k*N );
	centroidsNew = new Float64Array( k*N );
	labels = new Int32Array( M );
	counts = new Int32Array( k ); // TODO: decide whether to support sample weights (as sklearn does); if so, this should become a Float64Array

	// Use `-1` as a "not yet assigned" sentinel so that the first pass can never be mistaken for convergence...
	for ( i = 0; i < M; i++ ) {
		labels[ i ] = -1;
	}

	// Copy the initial centroids into contiguous row-major storage (strides `N` and `1`), which is the layout assumed by all indexing below:
	dlacpy( 'all', k, N, init, strideInit2, strideInit3, offsetInit, centroids, N, 1, 0 ); // eslint-disable-line max-len

	// This is a dense implementation; sklearn also has a sparse implementation:
	// https://github.com/scikit-learn/scikit-learn/blob/d3898d9d57aeb1e960d266613a2e31b07bca39d7/sklearn/cluster/_kmeans.py#L696C1-L700C46
	strictConv = false;
	for ( iter = 0; iter < maxIter; iter++ ) {
		// Reset the accumulators (plain loops are used as `counts` is an Int32Array and thus not suitable for Float64 kernels)...
		for ( i = 0; i < k*N; i++ ) {
			centroidsNew[ i ] = 0.0;
		}
		for ( c = 0; c < k; c++ ) {
			counts[ c ] = 0;
		}

		// Assignment step: attach every sample to its nearest centroid and accumulate per-cluster feature sums...
		same = true;
		ox = offsetX;
		for ( i = 0; i < M; i++ ) {
			best = 0;
			bestDist = dist( N, X, strideX2, ox, centroids, 1, 0 );
			for ( c = 1; c < k; c++ ) {
				d = dist( N, X, strideX2, ox, centroids, 1, c*N );
				if ( d < bestDist ) {
					bestDist = d;
					best = c;
				}
			}
			if ( labels[ i ] !== best ) {
				labels[ i ] = best;
				same = false;
			}
			counts[ best ] += 1;

			ic = best * N;
			ix = ox;
			for ( j = 0; j < N; j++ ) {
				centroidsNew[ ic+j ] += X[ ix ];
				ix += strideX2;
			}
			ox += strideX1;
		}

		// Update step: convert sums to means, keep the previous centroid for empty clusters, and accumulate the total squared centroid shift...
		shift = 0.0;
		for ( c = 0; c < k; c++ ) {
			ic = c * N;
			for ( j = 0; j < N; j++ ) {
				if ( counts[ c ] > 0 ) {
					centroidsNew[ ic+j ] /= counts[ c ];
				} else {
					centroidsNew[ ic+j ] = centroids[ ic+j ];
				}
				d = centroidsNew[ ic+j ] - centroids[ ic+j ];
				shift += d * d;
				centroids[ ic+j ] = centroidsNew[ ic+j ];
			}
		}

		// Strict convergence: no sample changed its cluster during this pass...
		if ( same ) {
			strictConv = true;
			break;
		}
		// TODO: declare convergence when `shift` falls below the (scaled) tolerance `tol`
	}

	if ( !strictConv ) {
		// TODO: rerun the E-step so that labels are consistent with the final centroids
	}

	// TODO: compute inertia

	return out; // TODO: create a results object similar to stats/base/ztest/two-sample/results/factory
}
161+
162+
163+
// EXPORTS //
164+
165+
module.exports = dkmeansld;

0 commit comments

Comments
 (0)