Skip to content

Commit 244a34a

Browse files
committed
fix bug in cosine distance
1 parent 8a2a019 commit 244a34a

3 files changed

Lines changed: 14 additions & 25 deletions

File tree

libsql-sqlite3/src/vector.c

Lines changed: 4 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -538,22 +538,7 @@ void vectorSerializeWithMeta(
538538
return;
539539
}
540540

541-
switch (pVector->type) {
542-
case VECTOR_TYPE_FLOAT32:
543-
vectorF32SerializeToBlob(pVector, pBlob, nDataSize);
544-
break;
545-
case VECTOR_TYPE_FLOAT64:
546-
vectorF64SerializeToBlob(pVector, pBlob, nDataSize);
547-
break;
548-
case VECTOR_TYPE_FLOAT1BIT:
549-
vector1BitSerializeToBlob(pVector, pBlob, nDataSize);
550-
break;
551-
case VECTOR_TYPE_FLOAT8:
552-
vectorF8SerializeToBlob(pVector, pBlob, nDataSize);
553-
break;
554-
default:
555-
assert(0);
556-
}
541+
vectorSerializeToBlob(pVector, pBlob, nDataSize);
557542
vectorSerializeMeta(pVector, nDataSize, pBlob, nBlobSize);
558543
sqlite3_result_blob(context, (char*)pBlob, nBlobSize, sqlite3_free);
559544
}
@@ -569,6 +554,9 @@ void vectorSerializeToBlob(const Vector *pVector, unsigned char *pBlob, size_t n
569554
case VECTOR_TYPE_FLOAT1BIT:
570555
vector1BitSerializeToBlob(pVector, pBlob, nBlobSize);
571556
break;
557+
case VECTOR_TYPE_FLOAT8:
558+
vectorF8SerializeToBlob(pVector, pBlob, nBlobSize);
559+
break;
572560
default:
573561
assert(0);
574562
}

libsql-sqlite3/src/vectorfloat8.c

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -100,10 +100,10 @@ float vectorF8DistanceCos(const Vector *v1, const Vector *v2){
100100
assert( v2->type == VECTOR_TYPE_FLOAT8 );
101101

102102
vectorF8GetParameters(v1->data, v1->dims, &alpha1, &shift1);
103-
vectorF8GetParameters(v2->data, v1->dims, &alpha2, &shift2);
103+
vectorF8GetParameters(v2->data, v2->dims, &alpha2, &shift2);
104104

105105
/*
106-
* (Ax + S)^2 = A^2 x^2 + S^2 + 2AS x -> we need to maintain 'sumsq' and 'sum'
106+
* (Ax + S)^2 = A^2 x^2 + 2AS x + S^2 -> we need to maintain 'sumsq' and 'sum'
107107
* (A1x + S1) * (A2y + S2) = A1A2 xy + A1 S2 x + A2 S1 y + S1 S2 -> we need to maintain 'dot' and 'sum' again
108108
*/
109109

@@ -112,12 +112,12 @@ float vectorF8DistanceCos(const Vector *v1, const Vector *v2){
112112
sum2 += data2[i];
113113
sumsq1 += data1[i]*data1[i];
114114
sumsq2 += data2[i]*data2[i];
115-
doti += data1[i] * data2[i];
115+
doti += data1[i]*data2[i];
116116
}
117117

118-
dot = alpha1 * alpha2 * (float)doti + alpha1 * shift2 * (float)sum1 + alpha2 * shift1 * (float)sum2 + shift1 * shift2;
119-
norm1 = alpha1 * alpha1 * (float)sumsq1 + 2 * alpha1 * shift1 * (float)sum1 + shift1 * shift1;
120-
norm2 = alpha2 * alpha2 * (float)sumsq2 + 2 * alpha2 * shift2 * (float)sum2 + shift2 * shift2;
118+
dot = alpha1 * alpha2 * (float)doti + alpha1 * shift2 * (float)sum1 + alpha2 * shift1 * (float)sum2 + shift1 * shift2 * v1->dims;
119+
norm1 = alpha1 * alpha1 * (float)sumsq1 + 2 * alpha1 * shift1 * (float)sum1 + shift1 * shift1 * v1->dims;
120+
norm2 = alpha2 * alpha2 * (float)sumsq2 + 2 * alpha2 * shift2 * (float)sum2 + shift2 * shift2 * v1->dims;
121121

122122
return 1.0 - (dot / sqrt(norm1 * norm2));
123123
}

libsql-sqlite3/test/libsql_vector.test

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -80,9 +80,10 @@ do_execsql_test vector-1-func-valid {
8080
{2.0}
8181
{1.0}
8282
{0.0}
83-
{-1.22070709096533e-08} {0.0}
84-
{1.54134213516954e-05} {0.000117244853754528}
85-
{-0.297326117753983} {0.0582110174000263}
83+
84+
{-6.10352568486405e-09} {0.0}
85+
{0.000111237335659098} {0.000117244853754528}
86+
{0.0576796568930149} {0.0582110174000263}
8687
}
8788

8889
do_execsql_test vector-1-conversion {

0 commit comments

Comments
 (0)