Skip to content

Commit f76bc0a

Browse files
authored
Merge pull request #1688 from tursodatabase/vector-search-bfloat16
vector search: implement and integrate bfloat16
2 parents 32037aa + 2d75b23 commit f76bc0a

11 files changed

Lines changed: 871 additions & 43 deletions

File tree

libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c

Lines changed: 272 additions & 12 deletions
Large diffs are not rendered by default.

libsql-ffi/bundled/src/sqlite3.c

Lines changed: 272 additions & 12 deletions
Large diffs are not rendered by default.

libsql-sqlite3/Makefile.in

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,7 @@ LIBOBJS0 = alter.lo analyze.lo attach.lo auth.lo \
195195
sqlite3session.lo select.lo sqlite3rbu.lo status.lo stmt.lo \
196196
table.lo threads.lo tokenize.lo treeview.lo trigger.lo \
197197
update.lo userauth.lo upsert.lo util.lo vacuum.lo \
198-
vector.lo vectorfloat32.lo vectorfloat64.lo vectorfloat1bit.lo vectorfloat8.lo vectorfloat16.lo \
198+
vector.lo vectorfloat32.lo vectorfloat64.lo vectorfloat1bit.lo vectorfloat8.lo vectorfloat16.lo vectorfloatb16.lo \
199199
vectorIndex.lo vectordiskann.lo vectorvtab.lo \
200200
vdbe.lo vdbeapi.lo vdbeaux.lo vdbeblob.lo vdbemem.lo vdbesort.lo \
201201
vdbetrace.lo vdbevtab.lo \
@@ -304,10 +304,12 @@ SRC = \
304304
$(TOP)/src/vector.c \
305305
$(TOP)/src/vectorInt.h \
306306
$(TOP)/src/vectorfloat1bit.c \
307+
$(TOP)/src/vectorfloat1bit.c \
307308
$(TOP)/src/vectorfloat16.c \
308309
$(TOP)/src/vectorfloat32.c \
309310
$(TOP)/src/vectorfloat64.c \
310311
$(TOP)/src/vectorfloat8.c \
312+
$(TOP)/src/vectorfloatb16.c \
311313
$(TOP)/src/vectorIndexInt.h \
312314
$(TOP)/src/vectorIndex.c \
313315
$(TOP)/src/vectordiskann.c \
@@ -1147,6 +1149,9 @@ vectorfloat1bit.lo: $(TOP)/src/vectorfloat1bit.c $(HDR)
11471149
vectorfloat16.lo: $(TOP)/src/vectorfloat16.c $(HDR)
11481150
$(LTCOMPILE) $(TEMP_STORE) -c $(TOP)/src/vectorfloat16.c
11491151

1152+
vectorfloatb16.lo: $(TOP)/src/vectorfloatb16.c $(HDR)
1153+
$(LTCOMPILE) $(TEMP_STORE) -c $(TOP)/src/vectorfloatb16.c
1154+
11501155
vectorfloat32.lo: $(TOP)/src/vectorfloat32.c $(HDR)
11511156
$(LTCOMPILE) $(TEMP_STORE) -c $(TOP)/src/vectorfloat32.c
11521157

libsql-sqlite3/src/vector.c

Lines changed: 129 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@ size_t vectorDataSize(VectorType type, VectorDims dims){
4747
return ALIGN(dims, sizeof(float)) + sizeof(float) /* alpha */ + sizeof(float) /* shift */;
4848
case VECTOR_TYPE_FLOAT16:
4949
return dims * sizeof(u16);
50+
case VECTOR_TYPE_FLOATB16:
51+
return dims * sizeof(u16);
5052
default:
5153
assert(0);
5254
}
@@ -124,6 +126,8 @@ float vectorDistanceCos(const Vector *pVector1, const Vector *pVector2){
124126
return vectorF8DistanceCos(pVector1, pVector2);
125127
case VECTOR_TYPE_FLOAT16:
126128
return vectorF16DistanceCos(pVector1, pVector2);
129+
case VECTOR_TYPE_FLOATB16:
130+
return vectorFB16DistanceCos(pVector1, pVector2);
127131
default:
128132
assert(0);
129133
}
@@ -141,6 +145,8 @@ float vectorDistanceL2(const Vector *pVector1, const Vector *pVector2){
141145
return vectorF8DistanceL2(pVector1, pVector2);
142146
case VECTOR_TYPE_FLOAT16:
143147
return vectorF16DistanceL2(pVector1, pVector2);
148+
case VECTOR_TYPE_FLOATB16:
149+
return vectorFB16DistanceL2(pVector1, pVector2);
144150
default:
145151
assert(0);
146152
}
@@ -314,6 +320,13 @@ static int vectorParseMeta(const unsigned char *pBlob, size_t nBlobSize, int *pT
314320
}
315321
*pDims = nBlobSize / sizeof(u16);
316322
*pDataSize = nBlobSize;
323+
}else if( *pType == VECTOR_TYPE_FLOATB16 ){
324+
if( nBlobSize % 2 != 0 ){
325+
*pzErrMsg = sqlite3_mprintf("vector: floatb16 vector blob length must be divisible by 2 (excluding 'type'-byte): length=%d", nBlobSize);
326+
return SQLITE_ERROR;
327+
}
328+
*pDims = nBlobSize / sizeof(u16);
329+
*pDataSize = nBlobSize;
317330
}else{
318331
*pzErrMsg = sqlite3_mprintf("vector: unexpected binary type: %d", *pType);
319332
return SQLITE_ERROR;
@@ -365,6 +378,9 @@ int vectorParseSqliteBlobWithType(
365378
case VECTOR_TYPE_FLOAT16:
366379
vectorF16DeserializeFromBlob(pVector, pBlob, nDataSize);
367380
return 0;
381+
case VECTOR_TYPE_FLOATB16:
382+
vectorFB16DeserializeFromBlob(pVector, pBlob, nDataSize);
383+
return 0;
368384
default:
369385
assert(0);
370386
}
@@ -469,6 +485,9 @@ void vectorDump(const Vector *pVector){
469485
case VECTOR_TYPE_FLOAT16:
470486
vectorF16Dump(pVector);
471487
break;
488+
case VECTOR_TYPE_FLOATB16:
489+
vectorFB16Dump(pVector);
490+
break;
472491
default:
473492
assert(0);
474493
}
@@ -494,7 +513,7 @@ static int vectorMetaSize(VectorType type, VectorDims dims){
494513
int nDataSize;
495514
if( type == VECTOR_TYPE_FLOAT32 ){
496515
return 0;
497-
}else if( type == VECTOR_TYPE_FLOAT64 || type == VECTOR_TYPE_FLOAT16 ){
516+
}else if( type == VECTOR_TYPE_FLOAT64 || type == VECTOR_TYPE_FLOAT16 || type == VECTOR_TYPE_FLOATB16 ){
498517
return 1;
499518
}else if( type == VECTOR_TYPE_FLOAT1BIT ){
500519
nDataSize = vectorDataSize(type, dims);
@@ -513,7 +532,7 @@ static int vectorMetaSize(VectorType type, VectorDims dims){
513532
static void vectorSerializeMeta(const Vector *pVector, size_t nDataSize, unsigned char *pBlob, size_t nBlobSize){
514533
if( pVector->type == VECTOR_TYPE_FLOAT32 ){
515534
// no meta for f32 type as this is "default" vector type
516-
}else if( pVector->type == VECTOR_TYPE_FLOAT64 || pVector->type == VECTOR_TYPE_FLOAT16 ){
535+
}else if( pVector->type == VECTOR_TYPE_FLOAT64 || pVector->type == VECTOR_TYPE_FLOAT16 || pVector->type == VECTOR_TYPE_FLOATB16 ){
517536
assert( nDataSize % 2 == 0 );
518537
assert( nBlobSize == nDataSize + 1 );
519538
pBlob[nBlobSize - 1] = pVector->type;
@@ -582,6 +601,9 @@ void vectorSerializeToBlob(const Vector *pVector, unsigned char *pBlob, size_t n
582601
case VECTOR_TYPE_FLOAT16:
583602
vectorF16SerializeToBlob(pVector, pBlob, nBlobSize);
584603
break;
604+
case VECTOR_TYPE_FLOATB16:
605+
vectorFB16SerializeToBlob(pVector, pBlob, nBlobSize);
606+
break;
585607
default:
586608
assert(0);
587609
}
@@ -624,6 +646,11 @@ static void vectorConvertFromF32(const Vector *pFrom, Vector *pTo){
624646
for(i = 0; i < pFrom->dims; i++){
625647
dstF16[i] = vectorF16FromFloat(src[i]);
626648
}
649+
}else if( pTo->type == VECTOR_TYPE_FLOATB16 ){
650+
dstF16 = pTo->data;
651+
for(i = 0; i < pFrom->dims; i++){
652+
dstF16[i] = vectorFB16FromFloat(src[i]);
653+
}
627654
}else{
628655
assert( 0 );
629656
}
@@ -662,6 +689,11 @@ static void vectorConvertFromF64(const Vector *pFrom, Vector *pTo){
662689
for(i = 0; i < pFrom->dims; i++){
663690
dstF16[i] = vectorF16FromFloat(src[i]);
664691
}
692+
}else if( pTo->type == VECTOR_TYPE_FLOATB16 ){
693+
dstF16 = pTo->data;
694+
for(i = 0; i < pFrom->dims; i++){
695+
dstF16[i] = vectorFB16FromFloat(src[i]);
696+
}
665697
}else{
666698
assert( 0 );
667699
}
@@ -673,7 +705,7 @@ static void vectorConvertFrom1Bit(const Vector *pFrom, Vector *pTo){
673705

674706
float *dstF32;
675707
double *dstF64;
676-
u16 *dstF16;
708+
u16 *dstU16;
677709

678710
assert( pFrom->dims == pTo->dims );
679711
assert( pFrom->type != pTo->type );
@@ -701,12 +733,23 @@ static void vectorConvertFrom1Bit(const Vector *pFrom, Vector *pTo){
701733
}else if( pTo->type == VECTOR_TYPE_FLOAT16 ){
702734
u16 positive = vectorF16FromFloat(+1);
703735
u16 negative = vectorF16FromFloat(-1);
704-
dstF16 = pTo->data;
736+
dstU16 = pTo->data;
737+
for(i = 0; i < pFrom->dims; i++){
738+
if( ((src[i / 8] >> (i & 7)) & 1) == 1 ){
739+
dstU16[i] = positive;
740+
}else{
741+
dstU16[i] = negative;
742+
}
743+
}
744+
}else if( pTo->type == VECTOR_TYPE_FLOATB16 ){
745+
u16 positive = vectorFB16FromFloat(+1);
746+
u16 negative = vectorFB16FromFloat(-1);
747+
dstU16 = pTo->data;
705748
for(i = 0; i < pFrom->dims; i++){
706749
if( ((src[i / 8] >> (i & 7)) & 1) == 1 ){
707-
dstF16[i] = positive;
750+
dstU16[i] = positive;
708751
}else{
709-
dstF16[i] = negative;
752+
dstU16[i] = negative;
710753
}
711754
}
712755
}else{
@@ -756,6 +799,11 @@ static void vectorConvertFromF8(const Vector *pFrom, Vector *pTo){
756799
for(i = 0; i < pFrom->dims; i++){
757800
dstF16[i] = vectorF16FromFloat(alpha * src[i] + shift);
758801
}
802+
}else if( pTo->type == VECTOR_TYPE_FLOATB16 ){
803+
dstF16 = pTo->data;
804+
for(i = 0; i < pFrom->dims; i++){
805+
dstF16[i] = vectorFB16FromFloat(alpha * src[i] + shift);
806+
}
759807
}else{
760808
assert( 0 );
761809
}
@@ -768,6 +816,7 @@ static void vectorConvertFromF16(const Vector *pFrom, Vector *pTo){
768816
float *dstF32;
769817
double *dstF64;
770818
u8 *dst1Bit;
819+
u16 *dstU16;
771820

772821
assert( pFrom->dims == pTo->dims );
773822
assert( pFrom->type != pTo->type );
@@ -784,6 +833,11 @@ static void vectorConvertFromF16(const Vector *pFrom, Vector *pTo){
784833
for(i = 0; i < pFrom->dims; i++){
785834
dstF64[i] = vectorF16ToFloat(src[i]);
786835
}
836+
}else if( pTo->type == VECTOR_TYPE_FLOATB16 ){
837+
dstU16 = pTo->data;
838+
for(i = 0; i < pFrom->dims; i++){
839+
dstU16[i] = vectorFB16FromFloat(vectorF16ToFloat(src[i]));
840+
}
787841
}else if( pTo->type == VECTOR_TYPE_FLOAT1BIT ){
788842
dst1Bit = pTo->data;
789843
for(i = 0; i < pFrom->dims; i += 8){
@@ -799,6 +853,50 @@ static void vectorConvertFromF16(const Vector *pFrom, Vector *pTo){
799853
}
800854
}
801855

856+
static void vectorConvertFromFB16(const Vector *pFrom, Vector *pTo){
857+
int i;
858+
u16 *src;
859+
860+
float *dstF32;
861+
double *dstF64;
862+
u8 *dst1Bit;
863+
u16 *dstU16;
864+
865+
assert( pFrom->dims == pTo->dims );
866+
assert( pFrom->type != pTo->type );
867+
assert( pFrom->type == VECTOR_TYPE_FLOATB16 );
868+
869+
src = pFrom->data;
870+
if( pTo->type == VECTOR_TYPE_FLOAT32 ){
871+
dstF32 = pTo->data;
872+
for(i = 0; i < pFrom->dims; i++){
873+
dstF32[i] = vectorFB16ToFloat(src[i]);
874+
}
875+
}else if( pTo->type == VECTOR_TYPE_FLOAT64 ){
876+
dstF64 = pTo->data;
877+
for(i = 0; i < pFrom->dims; i++){
878+
dstF64[i] = vectorFB16ToFloat(src[i]);
879+
}
880+
}else if( pTo->type == VECTOR_TYPE_FLOAT16 ){
881+
dstU16 = pTo->data;
882+
for(i = 0; i < pFrom->dims; i++){
883+
dstU16[i] = vectorF16FromFloat(vectorFB16ToFloat(src[i]));
884+
}
885+
}else if( pTo->type == VECTOR_TYPE_FLOAT1BIT ){
886+
dst1Bit = pTo->data;
887+
for(i = 0; i < pFrom->dims; i += 8){
888+
dst1Bit[i / 8] = 0;
889+
}
890+
for(i = 0; i < pFrom->dims; i++){
891+
if( vectorFB16ToFloat(src[i]) > 0 ){
892+
dst1Bit[i / 8] |= (1 << (i & 7));
893+
}
894+
}
895+
}else{
896+
assert( 0 );
897+
}
898+
}
899+
802900
static inline int clip(float f, int minF, int maxF){
803901
if( f < minF ){
804902
return minF;
@@ -819,7 +917,7 @@ static void vectorConvertToF8(const Vector *pFrom, Vector *pTo){
819917
float *srcF32;
820918
double *srcF64;
821919
u8 *src1Bit;
822-
u16 *srcF16;
920+
u16 *srcU16;
823921

824922
assert( pFrom->dims == pTo->dims );
825923
assert( pFrom->type != pTo->type );
@@ -857,14 +955,24 @@ static void vectorConvertToF8(const Vector *pFrom, Vector *pTo){
857955
dst[i] = clip(((((src1Bit[i / 8] >> (i & 7)) & 1) ? +1 : -1) - shift) / alpha, 0, 255);
858956
}
859957
}else if( pFrom->type == VECTOR_TYPE_FLOAT16 ){
860-
srcF16 = pFrom->data;
958+
srcU16 = pFrom->data;
959+
for(i = 0; i < pFrom->dims; i++){
960+
MINMAX(i, vectorF16ToFloat(srcU16[i]), minF, maxF);
961+
}
962+
shift = minF;
963+
alpha = (maxF - minF) / 255;
964+
for(i = 0; i < pFrom->dims; i++){
965+
dst[i] = clip((vectorF16ToFloat(srcU16[i]) - shift) / alpha, 0, 255);
966+
}
967+
}else if( pFrom->type == VECTOR_TYPE_FLOATB16 ){
968+
srcU16 = pFrom->data;
861969
for(i = 0; i < pFrom->dims; i++){
862-
MINMAX(i, vectorF16ToFloat(srcF16[i]), minF, maxF);
970+
MINMAX(i, vectorFB16ToFloat(srcU16[i]), minF, maxF);
863971
}
864972
shift = minF;
865973
alpha = (maxF - minF) / 255;
866974
for(i = 0; i < pFrom->dims; i++){
867-
dst[i] = clip((vectorF16ToFloat(srcF16[i]) - shift) / alpha, 0, 255);
975+
dst[i] = clip((vectorFB16ToFloat(srcU16[i]) - shift) / alpha, 0, 255);
868976
}
869977
}else{
870978
assert( 0 );
@@ -893,6 +1001,8 @@ void vectorConvert(const Vector *pFrom, Vector *pTo){
8931001
vectorConvertFromF8(pFrom, pTo);
8941002
}else if( pFrom->type == VECTOR_TYPE_FLOAT16 ){
8951003
vectorConvertFromF16(pFrom, pTo);
1004+
}else if( pFrom->type == VECTOR_TYPE_FLOATB16 ){
1005+
vectorConvertFromFB16(pFrom, pTo);
8961006
}else{
8971007
assert( 0 );
8981008
}
@@ -985,6 +1095,14 @@ static void vector16Func(
9851095
vectorFuncHintedType(context, argc, argv, VECTOR_TYPE_FLOAT16);
9861096
}
9871097

1098+
static void vectorb16Func(
1099+
sqlite3_context *context,
1100+
int argc,
1101+
sqlite3_value **argv
1102+
){
1103+
vectorFuncHintedType(context, argc, argv, VECTOR_TYPE_FLOATB16);
1104+
}
1105+
9881106
static void vector1BitFunc(
9891107
sqlite3_context *context,
9901108
int argc,
@@ -1144,6 +1262,7 @@ void sqlite3RegisterVectorFunctions(void){
11441262
FUNCTION(vector1bit, 1, 0, 0, vector1BitFunc),
11451263
FUNCTION(vector8, 1, 0, 0, vector8Func),
11461264
FUNCTION(vector16, 1, 0, 0, vector16Func),
1265+
FUNCTION(vectorb16, 1, 0, 0, vectorb16Func),
11471266
FUNCTION(vector_extract, 1, 0, 0, vectorExtractFunc),
11481267
FUNCTION(vector_distance_cos, 2, 0, 0, vectorDistanceCosFunc),
11491268
FUNCTION(vector_distance_l2, 2, 0, 0, vectorDistanceL2Func),

libsql-sqlite3/src/vectorIndex.c

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -387,6 +387,8 @@ static struct VectorColumnType VECTOR_COLUMN_TYPES[] = {
387387
{ "F8_BLOB", VECTOR_TYPE_FLOAT8 },
388388
{ "FLOAT16", VECTOR_TYPE_FLOAT16 },
389389
{ "F16_BLOB", VECTOR_TYPE_FLOAT16 },
390+
{ "FLOATB16", VECTOR_TYPE_FLOATB16 },
391+
{ "FB16_BLOB", VECTOR_TYPE_FLOATB16 },
390392
};
391393

392394
/*
@@ -408,6 +410,7 @@ static struct VectorParamName VECTOR_PARAM_NAMES[] = {
408410
{ "compress_neighbors", VECTOR_COMPRESS_NEIGHBORS_PARAM_ID, 0, "float1bit", VECTOR_TYPE_FLOAT1BIT },
409411
{ "compress_neighbors", VECTOR_COMPRESS_NEIGHBORS_PARAM_ID, 0, "float8", VECTOR_TYPE_FLOAT8 },
410412
{ "compress_neighbors", VECTOR_COMPRESS_NEIGHBORS_PARAM_ID, 0, "float16", VECTOR_TYPE_FLOAT16 },
413+
{ "compress_neighbors", VECTOR_COMPRESS_NEIGHBORS_PARAM_ID, 0, "floatb16", VECTOR_TYPE_FLOATB16 },
411414
{ "compress_neighbors", VECTOR_COMPRESS_NEIGHBORS_PARAM_ID, 0, "float32", VECTOR_TYPE_FLOAT32 },
412415
{ "alpha", VECTOR_PRUNING_ALPHA_PARAM_ID, 2, 0, 0 },
413416
{ "search_l", VECTOR_SEARCH_L_PARAM_ID, 1, 0, 0 },

0 commit comments

Comments
 (0)