Skip to content

Commit 2e0647f

Browse files
committed
expose vector_distance_l2 func
- we had it before but it's harder to add tests for l2 metric without it
1 parent 562680c commit 2e0647f

3 files changed

Lines changed: 35 additions & 7 deletions

File tree

libsql-sqlite3/src/vector.c

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,8 @@ float vectorDistanceL2(const Vector *pVector1, const Vector *pVector2){
133133
return vectorF32DistanceL2(pVector1, pVector2);
134134
case VECTOR_TYPE_FLOAT64:
135135
return vectorF64DistanceL2(pVector1, pVector2);
136+
case VECTOR_TYPE_FLOAT8:
137+
return vectorF8DistanceL2(pVector1, pVector2);
136138
default:
137139
assert(0);
138140
}
@@ -928,13 +930,11 @@ static void vectorExtractFunc(
928930
}
929931
}
930932

931-
/*
932-
** Implementation of vector_distance_cos(X, Y) function.
933-
*/
934-
static void vectorDistanceCosFunc(
933+
static void vectorDistanceFunc(
935934
sqlite3_context *context,
936935
int argc,
937-
sqlite3_value **argv
936+
sqlite3_value **argv,
937+
float (*vectorDistance)(const Vector *pVector1, const Vector *pVector2)
938938
){
939939
char *pzErrMsg = NULL;
940940
Vector *pVector1 = NULL, *pVector2 = NULL;
@@ -983,7 +983,7 @@ static void vectorDistanceCosFunc(
983983
sqlite3_free(pzErrMsg);
984984
goto out_free;
985985
}
986-
sqlite3_result_double(context, vectorDistanceCos(pVector1, pVector2));
986+
sqlite3_result_double(context, vectorDistance(pVector1, pVector2));
987987
out_free:
988988
if( pVector2 ){
989989
vectorFree(pVector2);
@@ -993,6 +993,20 @@ static void vectorDistanceCosFunc(
993993
}
994994
}
995995

996+
/*
997+
** Implementation of vector_distance_cos(X, Y) function.
998+
*/
999+
static void vectorDistanceCosFunc(sqlite3_context *context, int argc, sqlite3_value **argv){
1000+
vectorDistanceFunc(context, argc, argv, vectorDistanceCos);
1001+
}
1002+
1003+
/*
1004+
** Implementation of vector_distance_l2(X, Y) function.
1005+
*/
1006+
static void vectorDistanceL2Func(sqlite3_context *context, int argc, sqlite3_value **argv){
1007+
vectorDistanceFunc(context, argc, argv, vectorDistanceL2);
1008+
}
1009+
9961010
/*
9971011
* Marker function which is used in index creation syntax: CREATE INDEX idx ON t(libsql_vector_idx(emb));
9981012
*/
@@ -1013,6 +1027,7 @@ void sqlite3RegisterVectorFunctions(void){
10131027
FUNCTION(vector8, 1, 0, 0, vector8Func),
10141028
FUNCTION(vector_extract, 1, 0, 0, vectorExtractFunc),
10151029
FUNCTION(vector_distance_cos, 2, 0, 0, vectorDistanceCosFunc),
1030+
FUNCTION(vector_distance_l2, 2, 0, 0, vectorDistanceL2Func),
10161031

10171032
FUNCTION(libsql_vector_idx, -1, 0, 0, libsqlVectorIdx),
10181033
};

libsql-sqlite3/src/vectorInt.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,7 @@ int vector1BitDistanceHamming(const Vector *, const Vector *);
120120
* Calculates L2 distance between two vectors (vector must have same type and same dimensions)
121121
*/
122122
float vectorDistanceL2 (const Vector *, const Vector *);
123+
float vectorF8DistanceL2 (const Vector *, const Vector *);
123124
float vectorF32DistanceL2 (const Vector *, const Vector *);
124125
double vectorF64DistanceL2(const Vector *, const Vector *);
125126

libsql-sqlite3/src/vectorfloat8.c

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,11 +123,23 @@ float vectorF8DistanceCos(const Vector *v1, const Vector *v2){
123123
}
124124

125125
float vectorF8DistanceL2(const Vector *v1, const Vector *v2){
126+
int i;
127+
float alpha1, shift1, alpha2, shift2;
128+
float sum = 0;
129+
u8 *data1 = v1->data, *data2 = v2->data;
130+
126131
assert( v1->dims == v2->dims );
127132
assert( v1->type == VECTOR_TYPE_FLOAT8 );
128133
assert( v2->type == VECTOR_TYPE_FLOAT8 );
129134

130-
assert( 0 );
135+
vectorF8GetParameters(v1->data, v1->dims, &alpha1, &shift1);
136+
vectorF8GetParameters(v2->data, v2->dims, &alpha2, &shift2);
137+
138+
for(i = 0; i < v1->dims; i++){
139+
float d = (alpha1 * data1[i] + shift1) - (alpha2 * data2[i] + shift2);
140+
sum += d*d;
141+
}
142+
return sqrt(sum);
131143
}
132144

133145
void vectorF8DeserializeFromBlob(

0 commit comments

Comments
 (0)