Skip to content

Commit 8262d23

Browse files
authored
Merge pull request #1650 from tursodatabase/vector-search-compression
vector search: neighbors compression (1bit quantization)
2 parents e4c2afc + 4c38e5f commit 8262d23

13 files changed

Lines changed: 1804 additions & 931 deletions

File tree

libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c

Lines changed: 597 additions & 310 deletions
Large diffs are not rendered by default.

libsql-ffi/bundled/src/sqlite3.c

Lines changed: 597 additions & 310 deletions
Large diffs are not rendered by default.

libsql-sqlite3/Makefile.in

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,7 @@ LIBOBJS0 = alter.lo analyze.lo attach.lo auth.lo \
195195
sqlite3session.lo select.lo sqlite3rbu.lo status.lo stmt.lo \
196196
table.lo threads.lo tokenize.lo treeview.lo trigger.lo \
197197
update.lo userauth.lo upsert.lo util.lo vacuum.lo \
198-
vector.lo vectorfloat32.lo vectorfloat64.lo \
198+
vector.lo vectorfloat32.lo vectorfloat64.lo vector1bit.lo \
199199
vectorIndex.lo vectordiskann.lo vectorvtab.lo \
200200
vdbe.lo vdbeapi.lo vdbeaux.lo vdbeblob.lo vdbemem.lo vdbesort.lo \
201201
vdbetrace.lo vdbevtab.lo \
@@ -302,6 +302,7 @@ SRC = \
302302
$(TOP)/src/util.c \
303303
$(TOP)/src/vacuum.c \
304304
$(TOP)/src/vector.c \
305+
$(TOP)/src/vector1bit.c \
305306
$(TOP)/src/vectorInt.h \
306307
$(TOP)/src/vectorfloat32.c \
307308
$(TOP)/src/vectorfloat64.c \
@@ -1138,6 +1139,9 @@ vacuum.lo: $(TOP)/src/vacuum.c $(HDR)
11381139
vector.lo: $(TOP)/src/vector.c $(HDR)
11391140
$(LTCOMPILE) $(TEMP_STORE) -c $(TOP)/src/vector.c
11401141

1142+
vector1bit.lo: $(TOP)/src/vector1bit.c $(HDR)
1143+
$(LTCOMPILE) $(TEMP_STORE) -c $(TOP)/src/vector1bit.c
1144+
11411145
vectorfloat32.lo: $(TOP)/src/vectorfloat32.c $(HDR)
11421146
$(LTCOMPILE) $(TEMP_STORE) -c $(TOP)/src/vectorfloat32.c
11431147

libsql-sqlite3/src/vector.c

Lines changed: 95 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,9 @@ size_t vectorDataSize(VectorType type, VectorDims dims){
4141
return dims * sizeof(float);
4242
case VECTOR_TYPE_FLOAT64:
4343
return dims * sizeof(double);
44+
case VECTOR_TYPE_1BIT:
45+
assert( dims > 0 );
46+
return (dims + 7) / 8;
4447
default:
4548
assert(0);
4649
}
@@ -72,10 +75,11 @@ Vector *vectorAlloc(VectorType type, VectorDims dims){
7275
** Note that the vector object points to the blob so if
7376
** you free the blob, the vector becomes invalid.
7477
**/
75-
void vectorInitStatic(Vector *pVector, VectorType type, const unsigned char *pBlob, size_t nBlobSize){
76-
pVector->type = type;
78+
void vectorInitStatic(Vector *pVector, VectorType type, VectorDims dims, void *pBlob){
7779
pVector->flags = VECTOR_FLAGS_STATIC;
78-
vectorInitFromBlob(pVector, pBlob, nBlobSize);
80+
pVector->type = type;
81+
pVector->dims = dims;
82+
pVector->data = pBlob;
7983
}
8084

8185
/*
@@ -111,6 +115,8 @@ float vectorDistanceCos(const Vector *pVector1, const Vector *pVector2){
111115
return vectorF32DistanceCos(pVector1, pVector2);
112116
case VECTOR_TYPE_FLOAT64:
113117
return vectorF64DistanceCos(pVector1, pVector2);
118+
case VECTOR_TYPE_1BIT:
119+
return vector1BitDistanceHamming(pVector1, pVector2);
114120
default:
115121
assert(0);
116122
}
@@ -247,16 +253,34 @@ static int vectorParseSqliteText(
247253
return -1;
248254
}
249255

250-
int vectorParseSqliteBlob(
256+
int vectorParseSqliteBlobWithType(
251257
sqlite3_value *arg,
252258
Vector *pVector,
253259
char **pzErrMsg
254260
){
261+
const unsigned char *pBlob;
262+
size_t nBlobSize;
263+
264+
assert( sqlite3_value_type(arg) == SQLITE_BLOB );
265+
266+
pBlob = sqlite3_value_blob(arg);
267+
nBlobSize = sqlite3_value_bytes(arg);
268+
if( nBlobSize % 2 == 1 ){
269+
nBlobSize--;
270+
}
271+
272+
if( nBlobSize < vectorDataSize(pVector->type, pVector->dims) ){
273+
*pzErrMsg = sqlite3_mprintf("invalid vector: not enough bytes: type=%d, dims=%d, size=%ull", pVector->type, pVector->dims, nBlobSize);
274+
return SQLITE_ERROR;
275+
}
276+
255277
switch (pVector->type) {
256278
case VECTOR_TYPE_FLOAT32:
257-
return vectorF32ParseSqliteBlob(arg, pVector, pzErrMsg);
279+
vectorF32DeserializeFromBlob(pVector, pBlob, nBlobSize);
280+
return 0;
258281
case VECTOR_TYPE_FLOAT64:
259-
return vectorF64ParseSqliteBlob(arg, pVector, pzErrMsg);
282+
vectorF64DeserializeFromBlob(pVector, pBlob, nBlobSize);
283+
return 0;
260284
default:
261285
assert(0);
262286
}
@@ -339,14 +363,14 @@ int detectVectorParameters(sqlite3_value *arg, int typeHint, int *pType, int *pD
339363
}
340364
}
341365

342-
int vectorParse(
366+
int vectorParseWithType(
343367
sqlite3_value *arg,
344368
Vector *pVector,
345369
char **pzErrMsg
346370
){
347371
switch( sqlite3_value_type(arg) ){
348372
case SQLITE_BLOB:
349-
return vectorParseSqliteBlob(arg, pVector, pzErrMsg);
373+
return vectorParseSqliteBlobWithType(arg, pVector, pzErrMsg);
350374
case SQLITE_TEXT:
351375
return vectorParseSqliteText(arg, pVector, pzErrMsg);
352376
default:
@@ -363,6 +387,9 @@ void vectorDump(const Vector *pVector){
363387
case VECTOR_TYPE_FLOAT64:
364388
vectorF64Dump(pVector);
365389
break;
390+
case VECTOR_TYPE_1BIT:
391+
vector1BitDump(pVector);
392+
break;
366393
default:
367394
assert(0);
368395
}
@@ -384,20 +411,47 @@ void vectorMarshalToText(
384411
}
385412
}
386413

387-
void vectorSerialize(
414+
void vectorSerializeWithType(
388415
sqlite3_context *context,
389416
const Vector *pVector
390417
){
418+
unsigned char *pBlob;
419+
size_t nBlobSize, nDataSize;
420+
421+
assert( pVector->dims <= MAX_VECTOR_SZ );
422+
423+
nDataSize = vectorDataSize(pVector->type, pVector->dims);
424+
nBlobSize = nDataSize;
425+
if( pVector->type != VECTOR_TYPE_FLOAT32 ){
426+
nBlobSize += (nBlobSize % 2 == 0 ? 1 : 2);
427+
}
428+
429+
if( nBlobSize == 0 ){
430+
sqlite3_result_zeroblob(context, 0);
431+
return;
432+
}
433+
434+
pBlob = sqlite3_malloc64(nBlobSize);
435+
if( pBlob == NULL ){
436+
sqlite3_result_error_nomem(context);
437+
return;
438+
}
439+
440+
if( pVector->type != VECTOR_TYPE_FLOAT32 ){
441+
pBlob[nBlobSize - 1] = pVector->type;
442+
}
443+
391444
switch (pVector->type) {
392445
case VECTOR_TYPE_FLOAT32:
393-
vectorF32Serialize(context, pVector);
446+
vectorF32SerializeToBlob(pVector, pBlob, nDataSize);
394447
break;
395448
case VECTOR_TYPE_FLOAT64:
396-
vectorF64Serialize(context, pVector);
449+
vectorF64SerializeToBlob(pVector, pBlob, nDataSize);
397450
break;
398451
default:
399452
assert(0);
400453
}
454+
sqlite3_result_blob(context, (char*)pBlob, nBlobSize, sqlite3_free);
401455
}
402456

403457
size_t vectorSerializeToBlob(const Vector *pVector, unsigned char *pBlob, size_t nBlobSize){
@@ -406,18 +460,8 @@ size_t vectorSerializeToBlob(const Vector *pVector, unsigned char *pBlob, size_t
406460
return vectorF32SerializeToBlob(pVector, pBlob, nBlobSize);
407461
case VECTOR_TYPE_FLOAT64:
408462
return vectorF64SerializeToBlob(pVector, pBlob, nBlobSize);
409-
default:
410-
assert(0);
411-
}
412-
return 0;
413-
}
414-
415-
size_t vectorDeserializeFromBlob(Vector *pVector, const unsigned char *pBlob, size_t nBlobSize){
416-
switch (pVector->type) {
417-
case VECTOR_TYPE_FLOAT32:
418-
return vectorF32DeserializeFromBlob(pVector, pBlob, nBlobSize);
419-
case VECTOR_TYPE_FLOAT64:
420-
return vectorF64DeserializeFromBlob(pVector, pBlob, nBlobSize);
463+
case VECTOR_TYPE_1BIT:
464+
return vector1BitSerializeToBlob(pVector, pBlob, nBlobSize);
421465
default:
422466
assert(0);
423467
}
@@ -437,6 +481,29 @@ void vectorInitFromBlob(Vector *pVector, const unsigned char *pBlob, size_t nBlo
437481
}
438482
}
439483

484+
/*
** Convert pFrom into the representation of pTo, writing into pTo's data
** buffer. Both vectors must have the same number of dimensions.
**
** Only FLOAT32 -> 1BIT is implemented: bit (i & 7) of output byte i / 8 is
** set when component i is strictly positive and left clear otherwise.
** Unused bits of the last byte stay zero. Any other type pair trips
** assert(0) and, in builds without asserts, leaves pTo untouched.
*/
void vectorConvert(const Vector *pFrom, Vector *pTo){
  float *aSrc;
  u8 *aDst;
  int iDim, iByte;

  assert( pFrom->dims == pTo->dims );

  if( pFrom->type != VECTOR_TYPE_FLOAT32 || pTo->type != VECTOR_TYPE_1BIT ){
    assert(0);
    return;
  }

  aSrc = pFrom->data;
  aDst = pTo->data;

  /* First pass: clear every output byte that holds at least one dimension. */
  for(iByte = 0; iByte * 8 < pFrom->dims; iByte++){
    aDst[iByte] = 0;
  }
  /* Second pass: raise the bit of each strictly positive component. */
  for(iDim = 0; iDim < pFrom->dims; iDim++){
    if( aSrc[iDim] > 0 ){
      aDst[iDim / 8] |= (1 << (iDim & 7));
    }
  }
}
506+
440507
/**************************************************************************
441508
** SQL function implementations
442509
****************************************************************************/
@@ -465,12 +532,12 @@ static void vectorFuncHintedType(
465532
if( pVector==NULL ){
466533
return;
467534
}
468-
if( vectorParse(argv[0], pVector, &pzErrMsg) != 0 ){
535+
if( vectorParseWithType(argv[0], pVector, &pzErrMsg) != 0 ){
469536
sqlite3_result_error(context, pzErrMsg, -1);
470537
sqlite3_free(pzErrMsg);
471538
goto out_free_vec;
472539
}
473-
vectorSerialize(context, pVector);
540+
vectorSerializeWithType(context, pVector);
474541
out_free_vec:
475542
vectorFree(pVector);
476543
}
@@ -515,7 +582,7 @@ static void vectorExtractFunc(
515582
if( pVector==NULL ){
516583
return;
517584
}
518-
if( vectorParse(argv[0], pVector, &pzErrMsg)<0 ){
585+
if( vectorParseWithType(argv[0], pVector, &pzErrMsg)<0 ){
519586
sqlite3_result_error(context, pzErrMsg, -1);
520587
sqlite3_free(pzErrMsg);
521588
goto out_free;
@@ -570,12 +637,12 @@ static void vectorDistanceCosFunc(
570637
if( pVector2==NULL ){
571638
goto out_free;
572639
}
573-
if( vectorParse(argv[0], pVector1, &pzErrMsg)<0 ){
640+
if( vectorParseWithType(argv[0], pVector1, &pzErrMsg)<0 ){
574641
sqlite3_result_error(context, pzErrMsg, -1);
575642
sqlite3_free(pzErrMsg);
576643
goto out_free;
577644
}
578-
if( vectorParse(argv[1], pVector2, &pzErrMsg)<0 ){
645+
if( vectorParseWithType(argv[1], pVector2, &pzErrMsg)<0 ){
579646
sqlite3_result_error(context, pzErrMsg, -1);
580647
sqlite3_free(pzErrMsg);
581648
goto out_free;

libsql-sqlite3/src/vector1bit.c

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
/*
2+
** 2024-07-04
3+
**
4+
** Copyright 2024 the libSQL authors
5+
**
6+
** Permission is hereby granted, free of charge, to any person obtaining a copy of
7+
** this software and associated documentation files (the "Software"), to deal in
8+
** the Software without restriction, including without limitation the rights to
9+
** use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
10+
** the Software, and to permit persons to whom the Software is furnished to do so,
11+
** subject to the following conditions:
12+
**
13+
** The above copyright notice and this permission notice shall be included in all
14+
** copies or substantial portions of the Software.
15+
**
16+
** THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17+
** IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
18+
** FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
19+
** COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
20+
** IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
21+
** CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
22+
**
23+
******************************************************************************
24+
**
25+
** 1-bit vector format utilities.
26+
*/
27+
#ifndef SQLITE_OMIT_VECTOR
28+
#include "sqliteInt.h"
29+
30+
#include "vectorInt.h"
31+
32+
#include <math.h>
33+
34+
/**************************************************************************
35+
** Utility routines for debugging
36+
**************************************************************************/
37+
38+
void vector1BitDump(const Vector *pVec){
39+
u8 *elems = pVec->data;
40+
unsigned i;
41+
42+
assert( pVec->type == VECTOR_TYPE_1BIT );
43+
44+
printf("f1bit: [");
45+
for(i = 0; i < pVec->dims; i++){
46+
printf("%s%d", i == 0 ? "" : ", ", ((elems[i / 8] >> (i & 7)) & 1) ? +1 : -1);
47+
}
48+
printf("]\n");
49+
}
50+
51+
/**************************************************************************
52+
** Utility routines for vector serialization and deserialization
53+
**************************************************************************/
54+
55+
size_t vector1BitSerializeToBlob(
56+
const Vector *pVector,
57+
unsigned char *pBlob,
58+
size_t nBlobSize
59+
){
60+
u8 *elems = pVector->data;
61+
u8 *pPtr = pBlob;
62+
unsigned i;
63+
64+
assert( pVector->type == VECTOR_TYPE_1BIT );
65+
assert( pVector->dims <= MAX_VECTOR_SZ );
66+
assert( nBlobSize >= (pVector->dims + 7) / 8 );
67+
68+
for(i = 0; i < (pVector->dims + 7) / 8; i++){
69+
pPtr[i] = elems[i];
70+
}
71+
return (pVector->dims + 7) / 8;
72+
}
73+
74+
// [sum(map(int, bin(i)[2:])) for i in range(256)]
75+
static int BitsCount[256] = {
76+
0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4,
77+
1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5,
78+
1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5,
79+
2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6,
80+
1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5,
81+
2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6,
82+
2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6,
83+
3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7,
84+
1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5,
85+
2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6,
86+
2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6,
87+
3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7,
88+
2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6,
89+
3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7,
90+
3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7,
91+
4, 5, 5, 6, 5, 6, 6, 7, 5, 6, 6, 7, 6, 7, 7, 8,
92+
};
93+
94+
static inline int sqlite3PopCount32(u32 a){
95+
#if GCC_VERSION>=5004000 && !defined(__INTEL_COMPILER)
96+
return __builtin_popcount(a);
97+
#else
98+
return BitsCount[a >> 24] + BitsCount[(a >> 16) & 0xff] + BitsCount[(a >> 8) & 0xff] + BitsCount[a & 0xff];
99+
#endif
100+
}
101+
102+
int vector1BitDistanceHamming(const Vector *v1, const Vector *v2){
103+
int diff = 0;
104+
u8 *e1U8 = v1->data;
105+
u32 *e1U32 = v1->data;
106+
u8 *e2U8 = v2->data;
107+
u32 *e2U32 = v2->data;
108+
int i, len8, len32, offset8;
109+
110+
assert( v1->dims == v2->dims );
111+
assert( v1->type == VECTOR_TYPE_1BIT );
112+
assert( v2->type == VECTOR_TYPE_1BIT );
113+
114+
len8 = (v1->dims + 7) / 8;
115+
len32 = v1->dims / 32;
116+
offset8 = len32 * 4;
117+
118+
for(i = 0; i < len32; i++){
119+
diff += sqlite3PopCount32(e1U32[i] ^ e2U32[i]);
120+
}
121+
for(i = offset8; i < len8; i++){
122+
diff += sqlite3PopCount32(e1U8[i] ^ e2U8[i]);
123+
}
124+
return diff;
125+
}
126+
127+
#endif /* !defined(SQLITE_OMIT_VECTOR) */

0 commit comments

Comments
 (0)