Skip to content

Commit 6273473

Browse files
feat: add kmeans high-level api skeleton
1 parent f5939db commit 6273473

2 files changed

Lines changed: 153 additions & 1 deletion

File tree

Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
/* eslint-disable valid-jsdoc */
2+
/**
3+
* @license Apache-2.0
4+
*
5+
* Copyright (c) 2026 The Stdlib Authors.
6+
*
7+
* Licensed under the Apache License, Version 2.0 (the "License");
8+
* you may not use this file except in compliance with the License.
9+
* You may obtain a copy of the License at
10+
*
11+
* http://www.apache.org/licenses/LICENSE-2.0
12+
*
13+
* Unless required by applicable law or agreed to in writing, software
14+
* distributed under the License is distributed on an "AS IS" BASIS,
15+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16+
* See the License for the specific language governing permissions and
17+
* limitations under the License.
18+
*/
19+
20+
'use strict';
21+
22+
// MODULES //
23+
24+
var dkmeanselk = require( '@stdlib/ml/cluster/strided/dkmeanselk' );
25+
var dkmeansld = require( '@stdlib/ml/cluster/strided/dkmeansld' );
26+
var setReadOnly = require( '@stdlib/utils/define-nonenumerable-read-only-property' );
27+
var isMatrixLike = require( '@stdlib/assert/is-matrix-like' );
28+
var isInteger = require( '@stdlib/assert/is-integer' );
29+
var format = require( '@stdlib/string/format' );
30+
var initCentroids = require( './init_centroids.js' );
31+
32+
33+
// MAIN //
34+
35+
/**
36+
* Kmeans clustering.
37+
*
38+
* @private
39+
* @param {PositiveInteger} k - number of clusters
40+
* @param {(string|ndarray)} init - initialization method or initial centroids
41+
* @param {(PositiveInteger|string)} replicates - number of replicates or 'auto'
42+
* @throws {TypeError} first argument must be a positive integer
43+
* @throws {TypeError} second argument must be a valid initialization method or matrix
44+
* @throws {TypeError} third argument must be a positive integer or 'auto'
45+
* @returns {Function} fit function
46+
*
47+
* @example
48+
* var Float64Array = require( '@stdlib/array/float64' );
49+
* var ndarray = require( '@stdlib/ndarray/ctor' );
50+
* var kmeans = require( '@stdlib/ml/cluster/strided/dkmeansld' );
51+
*
52+
*/
53+
function dkmeans( k, init, replicates, maxIter, tol, metric, algorithm ) { // This will live in ml/cluster/kmeans/ctor, kept here just for reference
54+
// TODO: refactor functions arguments and include a `options` argument to follow same pattern as `ml/incr/kmeans`
55+
var model;
56+
var reps;
57+
58+
// TODO: validate function arguments
59+
60+
if ( replicates === 'auto' ) {
61+
if ( init === 'kmeans++' || isMatrixLike( init ) ) {
62+
reps = 1;
63+
} else if ( init === 'random' ) {
64+
// reps = ??
65+
} else if ( init === 'forgy' ) {
66+
// reps = ??
67+
}
68+
} else if ( isInteger( replicates ) ) {
69+
reps = replicates;
70+
} else {
71+
throw new TypeError( format( 'invalid argument. Argument specifying method for initialization must be either `kmeans++`, `random`, `forgy` or matrix specifying initial centroids. Value: `%s`.', init ) );
72+
}
73+
74+
// TODO: update the below attachment to follow similar pattern to stats/strided/ztests
75+
setReadOnly( model, 'fit', fit );
76+
77+
return model;
78+
79+
/**
80+
* Computes fitted cluster results using kmeans clustering.
81+
*
82+
* @private
83+
* @param {MatrixLike} X - input data matrix
84+
* @throws {TypeError} first argument must be a matrix-like object
85+
* @returns {Object} clustering results
86+
*
87+
* @example
88+
* var Float64Array = require( '@stdlib/array/float64' );
89+
* var ndarray = require( '@stdlib/ndarray/ctor' );
90+
*
91+
*/
92+
function fit( X ) {
93+
var kmeansSingle;
94+
var centroids;
95+
var singleOut;
96+
var out;
97+
var sx1;
98+
var sx2;
99+
var ox;
100+
var M;
101+
var N;
102+
var i;
103+
104+
// TODO: Step 1 : validate input matrix
105+
106+
// TODO: Step 2 : define arguments
107+
M = X.shape[ 0 ];
108+
N = X.shape[ 1 ];
109+
sx1 = X.stride[ 0 ];
110+
sx2 = X.stride[ 1 ];
111+
ox = X.offset;
112+
113+
/**
114+
* NOTE : M should be greater than k (M > k)
115+
* ref : https://github.com/scikit-learn/scikit-learn/blob/d3898d9d57aeb1e960d266613a2e31b07bca39d7/sklearn/cluster/_kmeans.py#L876
116+
*/
117+
118+
if ( algorithm === 'elkan' ) {
119+
kmeansSingle = dkmeanselk;
120+
} else if ( algorithm === 'lloyd' ) {
121+
kmeansSingle = dkmeansld;
122+
}
123+
124+
for ( i = 0; i < reps; i++ ) {
125+
centroids = initCentroids( X, init, k ); // ref : https://github.com/scikit-learn/scikit-learn/blob/d3898d9d57aeb1e960d266613a2e31b07bca39d7/sklearn/cluster/_kmeans.py#L961
126+
singleOut = kmeansSingle( M, N, k, metric, maxIter, tol, X, sx1, sx2, ox, centroids, k, N, 0 ); // magic number `0` because we generate the centroid array with no offset
127+
128+
/**
129+
* According to sklearn, `singleOut` should be { labels, inertia, centers, nIter }
130+
* ref: https://github.com/scikit-learn/scikit-learn/blob/d3898d9d57aeb1e960d266613a2e31b07bca39d7/sklearn/cluster/_kmeans.py#L1531
131+
* ??? How should we handle this ???
132+
*/
133+
}
134+
135+
/**
136+
* TODO : Check convergence issue
137+
* ref : https://github.com/scikit-learn/scikit-learn/blob/d3898d9d57aeb1e960d266613a2e31b07bca39d7/sklearn/cluster/_kmeans.py#L1545
138+
*/
139+
140+
/**
141+
* TODO : Build the `out` object
142+
* ref : https://github.com/stdlib-js/stdlib/pull/9703#discussion_r2681280854
143+
*/
144+
145+
return out;
146+
}
147+
}
148+
149+
150+
// EXPORTS //
151+
152+
module.exports = dkmeans;

lib/node_modules/@stdlib/ml/cluster/strided/dkmeansld/lib/index.js

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
/**
22
* @license Apache-2.0
33
*
4-
* Copyright (c) 2018 The Stdlib Authors.
4+
* Copyright (c) 2026 The Stdlib Authors.
55
*
66
* Licensed under the Apache License, Version 2.0 (the "License");
77
* you may not use this file except in compliance with the License.

0 commit comments

Comments
 (0)