Skip to content

Commit 8a2a019

Browse files
committed
add implementation of float8 vector type (int8 quantization)
1 parent 8604065 commit 8a2a019

10 files changed

Lines changed: 403 additions & 59 deletions

File tree

libsql-sqlite3/Makefile.in

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,7 @@ LIBOBJS0 = alter.lo analyze.lo attach.lo auth.lo \
195195
sqlite3session.lo select.lo sqlite3rbu.lo status.lo stmt.lo \
196196
table.lo threads.lo tokenize.lo treeview.lo trigger.lo \
197197
update.lo userauth.lo upsert.lo util.lo vacuum.lo \
198-
vector.lo vectorfloat32.lo vectorfloat64.lo vectorfloat1bit.lo \
198+
vector.lo vectorfloat32.lo vectorfloat64.lo vectorfloat1bit.lo vectorfloat8.lo \
199199
vectorIndex.lo vectordiskann.lo vectorvtab.lo \
200200
vdbe.lo vdbeapi.lo vdbeaux.lo vdbeblob.lo vdbemem.lo vdbesort.lo \
201201
vdbetrace.lo vdbevtab.lo \
@@ -306,6 +306,7 @@ SRC = \
306306
$(TOP)/src/vectorfloat1bit.c \
307307
$(TOP)/src/vectorfloat32.c \
308308
$(TOP)/src/vectorfloat64.c \
309+
$(TOP)/src/vectorfloat8.c \
309310
$(TOP)/src/vectorIndexInt.h \
310311
$(TOP)/src/vectorIndex.c \
311312
$(TOP)/src/vectordiskann.c \
@@ -1148,6 +1149,9 @@ vectorfloat32.lo: $(TOP)/src/vectorfloat32.c $(HDR)
11481149
vectorfloat64.lo: $(TOP)/src/vectorfloat64.c $(HDR)
11491150
$(LTCOMPILE) $(TEMP_STORE) -c $(TOP)/src/vectorfloat64.c
11501151

1152+
vectorfloat8.lo: $(TOP)/src/vectorfloat8.c $(HDR)
1153+
$(LTCOMPILE) $(TEMP_STORE) -c $(TOP)/src/vectorfloat8.c
1154+
11511155
vectorIndex.lo: $(TOP)/src/vectorIndex.c $(HDR)
11521156
$(LTCOMPILE) $(TEMP_STORE) -c $(TOP)/src/vectorIndex.c
11531157

libsql-sqlite3/src/vector.c

Lines changed: 169 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@ size_t vectorDataSize(VectorType type, VectorDims dims){
4343
return dims * sizeof(double);
4444
case VECTOR_TYPE_FLOAT1BIT:
4545
return (dims + 7) / 8;
46+
case VECTOR_TYPE_FLOAT8:
47+
return ALIGN(dims, sizeof(float)) + sizeof(float) /* alpha */ + sizeof(float) /* shift */;
4648
default:
4749
assert(0);
4850
}
@@ -116,6 +118,8 @@ float vectorDistanceCos(const Vector *pVector1, const Vector *pVector2){
116118
return vectorF64DistanceCos(pVector1, pVector2);
117119
case VECTOR_TYPE_FLOAT1BIT:
118120
return vector1BitDistanceHamming(pVector1, pVector2);
121+
case VECTOR_TYPE_FLOAT8:
122+
return vectorF8DistanceCos(pVector1, pVector2);
119123
default:
120124
assert(0);
121125
}
@@ -253,7 +257,8 @@ static int vectorParseSqliteText(
253257
}
254258

255259
static int vectorParseMeta(const unsigned char *pBlob, size_t nBlobSize, int *pType, int *pDims, size_t *pDataSize, char **pzErrMsg){
256-
int nLeftoverBits;
260+
int nTrailingBits;
261+
int nTrailingBytes;
257262

258263
if( nBlobSize % 2 == 0 ){
259264
*pType = VECTOR_TYPE_FLOAT32;
@@ -266,26 +271,34 @@ static int vectorParseMeta(const unsigned char *pBlob, size_t nBlobSize, int *pT
266271

267272
if( *pType == VECTOR_TYPE_FLOAT32 ){
268273
if( nBlobSize % 4 != 0 ){
269-
*pzErrMsg = sqlite3_mprintf("vector: f32 vector blob length must be divisible by 4 (excluding optional 'type'-byte): length=%d", nBlobSize);
274+
*pzErrMsg = sqlite3_mprintf("vector: float32 vector blob length must be divisible by 4 (excluding optional 'type'-byte): length=%d", nBlobSize);
270275
return SQLITE_ERROR;
271276
}
272277
*pDims = nBlobSize / sizeof(float);
273278
*pDataSize = nBlobSize;
274279
}else if( *pType == VECTOR_TYPE_FLOAT64 ){
275280
if( nBlobSize % 8 != 0 ){
276-
*pzErrMsg = sqlite3_mprintf("vector: f64 vector blob length must be divisible by 8 (excluding 'type'-byte): length=%d", nBlobSize);
281+
*pzErrMsg = sqlite3_mprintf("vector: float64 vector blob length must be divisible by 8 (excluding 'type'-byte): length=%d", nBlobSize);
277282
return SQLITE_ERROR;
278283
}
279284
*pDims = nBlobSize / sizeof(double);
280285
*pDataSize = nBlobSize;
281286
}else if( *pType == VECTOR_TYPE_FLOAT1BIT ){
282287
if( nBlobSize == 0 || nBlobSize % 2 != 0 ){
283-
*pzErrMsg = sqlite3_mprintf("vector: 1bit vector blob length must be divisible by 2 and not be empty (excluding 'type'-byte): length=%d", nBlobSize);
288+
*pzErrMsg = sqlite3_mprintf("vector: float1bit vector blob length must be divisible by 2 and not be empty (excluding 'type'-byte): length=%d", nBlobSize);
284289
return SQLITE_ERROR;
285290
}
286-
nLeftoverBits = pBlob[nBlobSize - 1];
287-
*pDims = nBlobSize * 8 - nLeftoverBits;
291+
nTrailingBits = pBlob[nBlobSize - 1];
292+
*pDims = nBlobSize * 8 - nTrailingBits;
288293
*pDataSize = (*pDims + 7) / 8;
294+
}else if( *pType == VECTOR_TYPE_FLOAT8 ){
295+
if( nBlobSize < 2 || nBlobSize % 2 != 0 ){
296+
*pzErrMsg = sqlite3_mprintf("vector: float8 vector blob length must be divisible by 2 and has at least 2 bytes (excluding 'type'-byte): length=%d", nBlobSize);
297+
return SQLITE_ERROR;
298+
}
299+
nTrailingBytes = pBlob[nBlobSize - 1];
300+
*pDims = (nBlobSize - 2) - sizeof(float) - sizeof(float) - nTrailingBytes;
301+
*pDataSize = nBlobSize - 2;
289302
}else{
290303
*pzErrMsg = sqlite3_mprintf("vector: unexpected binary type: %d", *pType);
291304
return SQLITE_ERROR;
@@ -331,6 +344,9 @@ int vectorParseSqliteBlobWithType(
331344
case VECTOR_TYPE_FLOAT1BIT:
332345
vector1BitDeserializeFromBlob(pVector, pBlob, nDataSize);
333346
return 0;
347+
case VECTOR_TYPE_FLOAT8:
348+
vectorF8DeserializeFromBlob(pVector, pBlob, nDataSize);
349+
return 0;
334350
default:
335351
assert(0);
336352
}
@@ -429,6 +445,9 @@ void vectorDump(const Vector *pVector){
429445
case VECTOR_TYPE_FLOAT1BIT:
430446
vector1BitDump(pVector);
431447
break;
448+
case VECTOR_TYPE_FLOAT8:
449+
vectorF8Dump(pVector);
450+
break;
432451
default:
433452
assert(0);
434453
}
@@ -451,20 +470,20 @@ void vectorMarshalToText(
451470
}
452471

453472
static int vectorMetaSize(VectorType type, VectorDims dims){
454-
int nMetaSize = 0;
455473
int nDataSize;
456474
if( type == VECTOR_TYPE_FLOAT32 ){
457475
return 0;
458476
}else if( type == VECTOR_TYPE_FLOAT64 ){
459477
return 1;
460478
}else if( type == VECTOR_TYPE_FLOAT1BIT ){
461479
nDataSize = vectorDataSize(type, dims);
462-
nMetaSize++; // one byte which specify amount of leftover bits
463-
if( nDataSize % 2 == 0 ){
464-
nMetaSize++; // pad "leftover-bits" byte to the even length
465-
}
466-
nMetaSize++; // one byte for vector type
467-
return nMetaSize;
480+
// optional padding byte + "trailing-bits" byte + "vector-type" byte
481+
return (nDataSize % 2 == 0 ? 1 : 0) + 1 + 1;
482+
}else if( type == VECTOR_TYPE_FLOAT8 ){
483+
nDataSize = vectorDataSize(type, dims);
484+
assert( nDataSize % 2 == 0 );
485+
/* padding byte + "trailing-bytes" byte + "vector-type" byte */
486+
return 1 + 1 + 1;
468487
}else{
469488
assert( 0 );
470489
}
@@ -482,6 +501,15 @@ static void vectorSerializeMeta(const Vector *pVector, size_t nDataSize, unsigne
482501
assert( nBlobSize >= 3 );
483502
pBlob[nBlobSize - 1] = VECTOR_TYPE_FLOAT1BIT;
484503
pBlob[nBlobSize - 2] = 8 * (nBlobSize - 1) - pVector->dims;
504+
if( vectorMetaSize(pVector->type, pVector->dims) == 3 ){
505+
pBlob[nBlobSize - 3] = 0;
506+
}
507+
}else if( pVector->type == VECTOR_TYPE_FLOAT8 ){
508+
assert( nBlobSize % 2 == 1 );
509+
assert( nDataSize % 2 == 0 );
510+
assert( nBlobSize == nDataSize + 3 );
511+
pBlob[nBlobSize - 1] = VECTOR_TYPE_FLOAT8;
512+
pBlob[nBlobSize - 2] = ALIGN(pVector->dims, sizeof(float)) - pVector->dims;
485513
}else{
486514
assert( 0 );
487515
}
@@ -520,25 +548,30 @@ void vectorSerializeWithMeta(
520548
case VECTOR_TYPE_FLOAT1BIT:
521549
vector1BitSerializeToBlob(pVector, pBlob, nDataSize);
522550
break;
551+
case VECTOR_TYPE_FLOAT8:
552+
vectorF8SerializeToBlob(pVector, pBlob, nDataSize);
553+
break;
523554
default:
524555
assert(0);
525556
}
526557
vectorSerializeMeta(pVector, nDataSize, pBlob, nBlobSize);
527558
sqlite3_result_blob(context, (char*)pBlob, nBlobSize, sqlite3_free);
528559
}
529560

530-
size_t vectorSerializeToBlob(const Vector *pVector, unsigned char *pBlob, size_t nBlobSize){
561+
void vectorSerializeToBlob(const Vector *pVector, unsigned char *pBlob, size_t nBlobSize){
531562
switch (pVector->type) {
532563
case VECTOR_TYPE_FLOAT32:
533-
return vectorF32SerializeToBlob(pVector, pBlob, nBlobSize);
564+
vectorF32SerializeToBlob(pVector, pBlob, nBlobSize);
565+
break;
534566
case VECTOR_TYPE_FLOAT64:
535-
return vectorF64SerializeToBlob(pVector, pBlob, nBlobSize);
567+
vectorF64SerializeToBlob(pVector, pBlob, nBlobSize);
568+
break;
536569
case VECTOR_TYPE_FLOAT1BIT:
537-
return vector1BitSerializeToBlob(pVector, pBlob, nBlobSize);
570+
vector1BitSerializeToBlob(pVector, pBlob, nBlobSize);
571+
break;
538572
default:
539573
assert(0);
540574
}
541-
return 0;
542575
}
543576

544577
void vectorInitFromBlob(Vector *pVector, const unsigned char *pBlob, size_t nBlobSize){
@@ -644,6 +677,110 @@ static void vectorConvertFrom1Bit(const Vector *pFrom, Vector *pTo){
644677
}
645678
}
646679

680+
static void vectorConvertFromF8(const Vector *pFrom, Vector *pTo){
681+
int i;
682+
u8 *src;
683+
float alpha, shift;
684+
685+
float *dstF32;
686+
double *dstF64;
687+
u8 *dst1Bit;
688+
689+
assert( pFrom->dims == pTo->dims );
690+
assert( pFrom->type != pTo->type );
691+
assert( pFrom->type == VECTOR_TYPE_FLOAT8 );
692+
693+
vectorF8GetParameters(pFrom->data, pFrom->dims, &alpha, &shift);
694+
695+
src = pFrom->data;
696+
if( pTo->type == VECTOR_TYPE_FLOAT32 ){
697+
dstF32 = pTo->data;
698+
for(i = 0; i < pFrom->dims; i++){
699+
dstF32[i] = alpha * src[i] + shift;
700+
}
701+
}else if( pTo->type == VECTOR_TYPE_FLOAT64 ){
702+
dstF64 = pTo->data;
703+
for(i = 0; i < pFrom->dims; i++){
704+
dstF64[i] = alpha * src[i] + shift;
705+
}
706+
}else if( pTo->type == VECTOR_TYPE_FLOAT1BIT ){
707+
dst1Bit = pTo->data;
708+
for(i = 0; i < pFrom->dims; i += 8){
709+
dst1Bit[i / 8] = 0;
710+
}
711+
for(i = 0; i < pFrom->dims; i++){
712+
if( (alpha * src[i] + shift) > 0 ){
713+
dst1Bit[i / 8] |= (1 << (i & 7));
714+
}
715+
}
716+
}else{
717+
assert( 0 );
718+
}
719+
}
720+
721+
static inline int clip(float f, int minF, int maxF){
722+
if( f < minF ){
723+
return minF;
724+
}else if( f > maxF ){
725+
return maxF;
726+
}
727+
return (int)(f + 0.5);
728+
}
729+
730+
#define MINMAX(i, value, minValue, maxValue) {if(i == 0){ minValue = (value); maxValue = (value);} else { minValue = MIN(minValue, (value)); maxValue = MAX(maxValue, (value)); }}
731+
732+
static void vectorConvertToF8(const Vector *pFrom, Vector *pTo){
733+
int i;
734+
u8 *dst;
735+
float alpha, shift;
736+
float minF = 0, maxF = 0;
737+
738+
float *srcF32;
739+
double *srcF64;
740+
u8 *src1Bit;
741+
742+
assert( pFrom->dims == pTo->dims );
743+
assert( pFrom->type != pTo->type );
744+
assert( pTo->type == VECTOR_TYPE_FLOAT8 );
745+
746+
dst = pTo->data;
747+
if( pFrom->type == VECTOR_TYPE_FLOAT32 ){
748+
srcF32 = pFrom->data;
749+
for(i = 0; i < pFrom->dims; i++){
750+
MINMAX(i, srcF32[i], minF, maxF);
751+
}
752+
shift = minF;
753+
alpha = (maxF - minF) / 255;
754+
for(i = 0; i < pFrom->dims; i++){
755+
dst[i] = clip((srcF32[i] - shift) / alpha, 0, 255);
756+
}
757+
}else if( pFrom->type == VECTOR_TYPE_FLOAT64 ){
758+
srcF64 = pFrom->data;
759+
for(i = 0; i < pFrom->dims; i++){
760+
MINMAX(i, srcF64[i], minF, maxF);
761+
}
762+
shift = minF;
763+
alpha = (maxF - minF) / 255;
764+
for(i = 0; i < pFrom->dims; i++){
765+
dst[i] = clip((srcF64[i] - shift) / alpha, 0, 255);
766+
}
767+
}else if( pFrom->type == VECTOR_TYPE_FLOAT1BIT ){
768+
src1Bit = pFrom->data;
769+
for(i = 0; i < pFrom->dims; i++){
770+
MINMAX(i, ((src1Bit[i / 8] >> (i & 7)) & 1) ? +1 : -1, minF, maxF);
771+
}
772+
shift = minF;
773+
alpha = (maxF - minF) / 255;
774+
for(i = 0; i < pFrom->dims; i++){
775+
dst[i] = clip(((((src1Bit[i / 8] >> (i & 7)) & 1) ? +1 : -1) - shift) / alpha, 0, 255);
776+
}
777+
}else{
778+
assert( 0 );
779+
}
780+
vectorF8SetParameters(pTo->data, pTo->dims, alpha, shift);
781+
}
782+
783+
647784
void vectorConvert(const Vector *pFrom, Vector *pTo){
648785
assert( pFrom->dims == pTo->dims );
649786

@@ -652,12 +789,16 @@ void vectorConvert(const Vector *pFrom, Vector *pTo){
652789
return;
653790
}
654791

655-
if( pFrom->type == VECTOR_TYPE_FLOAT32 ){
792+
if( pTo->type == VECTOR_TYPE_FLOAT8 ){
793+
vectorConvertToF8(pFrom, pTo);
794+
}else if( pFrom->type == VECTOR_TYPE_FLOAT32 ){
656795
vectorConvertFromF32(pFrom, pTo);
657796
}else if( pFrom->type == VECTOR_TYPE_FLOAT64 ){
658797
vectorConvertFromF64(pFrom, pTo);
659798
}else if( pFrom->type == VECTOR_TYPE_FLOAT1BIT ){
660799
vectorConvertFrom1Bit(pFrom, pTo);
800+
}else if( pFrom->type == VECTOR_TYPE_FLOAT8 ){
801+
vectorConvertFromF8(pFrom, pTo);
661802
}else{
662803
assert( 0 );
663804
}
@@ -734,6 +875,14 @@ static void vector64Func(
734875
vectorFuncHintedType(context, argc, argv, VECTOR_TYPE_FLOAT64);
735876
}
736877

878+
static void vector8Func(
879+
sqlite3_context *context,
880+
int argc,
881+
sqlite3_value **argv
882+
){
883+
vectorFuncHintedType(context, argc, argv, VECTOR_TYPE_FLOAT8);
884+
}
885+
737886
static void vector1BitFunc(
738887
sqlite3_context *context,
739888
int argc,
@@ -873,6 +1022,7 @@ void sqlite3RegisterVectorFunctions(void){
8731022
FUNCTION(vector32, 1, 0, 0, vector32Func),
8741023
FUNCTION(vector64, 1, 0, 0, vector64Func),
8751024
FUNCTION(vector1bit, 1, 0, 0, vector1BitFunc),
1025+
FUNCTION(vector8, 1, 0, 0, vector8Func),
8761026
FUNCTION(vector_extract, 1, 0, 0, vectorExtractFunc),
8771027
FUNCTION(vector_distance_cos, 2, 0, 0, vectorDistanceCosFunc),
8781028

libsql-sqlite3/src/vectorIndex.c

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -383,6 +383,8 @@ static struct VectorColumnType VECTOR_COLUMN_TYPES[] = {
383383
{ "F64_BLOB", VECTOR_TYPE_FLOAT64 },
384384
{ "FLOAT1BIT", VECTOR_TYPE_FLOAT1BIT },
385385
{ "F1BIT_BLOB", VECTOR_TYPE_FLOAT1BIT },
386+
{ "FLOAT8", VECTOR_TYPE_FLOAT8 },
387+
{ "F8_BLOB", VECTOR_TYPE_FLOAT8 },
386388
};
387389

388390
/*

0 commit comments

Comments
 (0)