Skip to content

Commit 69bce66

Browse files
committed
properly integrate float16
1 parent f8128d2 commit 69bce66

6 files changed

Lines changed: 218 additions & 28 deletions

File tree

libsql-sqlite3/src/vector.c

Lines changed: 116 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,8 @@ float vectorDistanceCos(const Vector *pVector1, const Vector *pVector2){
122122
return vector1BitDistanceHamming(pVector1, pVector2);
123123
case VECTOR_TYPE_FLOAT8:
124124
return vectorF8DistanceCos(pVector1, pVector2);
125+
case VECTOR_TYPE_FLOAT16:
126+
return vectorF16DistanceCos(pVector1, pVector2);
125127
default:
126128
assert(0);
127129
}
@@ -137,6 +139,8 @@ float vectorDistanceL2(const Vector *pVector1, const Vector *pVector2){
137139
return vectorF64DistanceL2(pVector1, pVector2);
138140
case VECTOR_TYPE_FLOAT8:
139141
return vectorF8DistanceL2(pVector1, pVector2);
142+
case VECTOR_TYPE_FLOAT16:
143+
return vectorF16DistanceL2(pVector1, pVector2);
140144
default:
141145
assert(0);
142146
}
@@ -303,6 +307,13 @@ static int vectorParseMeta(const unsigned char *pBlob, size_t nBlobSize, int *pT
303307
nTrailingBytes = pBlob[nBlobSize - 1];
304308
*pDims = (nBlobSize - 2) - sizeof(float) - sizeof(float) - nTrailingBytes;
305309
*pDataSize = nBlobSize - 2;
310+
}else if( *pType == VECTOR_TYPE_FLOAT16 ){
311+
if( nBlobSize % 2 != 0 ){
312+
*pzErrMsg = sqlite3_mprintf("vector: float16 vector blob length must be divisible by 2 (excluding 'type'-byte): length=%d", nBlobSize);
313+
return SQLITE_ERROR;
314+
}
315+
*pDims = nBlobSize / sizeof(u16);
316+
*pDataSize = nBlobSize;
306317
}else{
307318
*pzErrMsg = sqlite3_mprintf("vector: unexpected binary type: %d", *pType);
308319
return SQLITE_ERROR;
@@ -351,6 +362,9 @@ int vectorParseSqliteBlobWithType(
351362
case VECTOR_TYPE_FLOAT8:
352363
vectorF8DeserializeFromBlob(pVector, pBlob, nDataSize);
353364
return 0;
365+
case VECTOR_TYPE_FLOAT16:
366+
vectorF16DeserializeFromBlob(pVector, pBlob, nDataSize);
367+
return 0;
354368
default:
355369
assert(0);
356370
}
@@ -452,6 +466,9 @@ void vectorDump(const Vector *pVector){
452466
case VECTOR_TYPE_FLOAT8:
453467
vectorF8Dump(pVector);
454468
break;
469+
case VECTOR_TYPE_FLOAT16:
470+
vectorF16Dump(pVector);
471+
break;
455472
default:
456473
assert(0);
457474
}
@@ -477,7 +494,7 @@ static int vectorMetaSize(VectorType type, VectorDims dims){
477494
int nDataSize;
478495
if( type == VECTOR_TYPE_FLOAT32 ){
479496
return 0;
480-
}else if( type == VECTOR_TYPE_FLOAT64 ){
497+
}else if( type == VECTOR_TYPE_FLOAT64 || type == VECTOR_TYPE_FLOAT16 ){
481498
return 1;
482499
}else if( type == VECTOR_TYPE_FLOAT1BIT ){
483500
nDataSize = vectorDataSize(type, dims);
@@ -496,14 +513,14 @@ static int vectorMetaSize(VectorType type, VectorDims dims){
496513
static void vectorSerializeMeta(const Vector *pVector, size_t nDataSize, unsigned char *pBlob, size_t nBlobSize){
497514
if( pVector->type == VECTOR_TYPE_FLOAT32 ){
498515
// no meta for f32 type as this is "default" vector type
499-
}else if( pVector->type == VECTOR_TYPE_FLOAT64 ){
516+
}else if( pVector->type == VECTOR_TYPE_FLOAT64 || pVector->type == VECTOR_TYPE_FLOAT16 ){
500517
assert( nDataSize % 2 == 0 );
501518
assert( nBlobSize == nDataSize + 1 );
502-
pBlob[nBlobSize - 1] = VECTOR_TYPE_FLOAT64;
519+
pBlob[nBlobSize - 1] = pVector->type;
503520
}else if( pVector->type == VECTOR_TYPE_FLOAT1BIT ){
504521
assert( nBlobSize % 2 == 1 );
505522
assert( nBlobSize >= 3 );
506-
pBlob[nBlobSize - 1] = VECTOR_TYPE_FLOAT1BIT;
523+
pBlob[nBlobSize - 1] = pVector->type;
507524
pBlob[nBlobSize - 2] = 8 * (nBlobSize - 1) - pVector->dims;
508525
if( vectorMetaSize(pVector->type, pVector->dims) == 3 ){
509526
pBlob[nBlobSize - 3] = 0;
@@ -512,8 +529,9 @@ static void vectorSerializeMeta(const Vector *pVector, size_t nDataSize, unsigne
512529
assert( nBlobSize % 2 == 1 );
513530
assert( nDataSize % 2 == 0 );
514531
assert( nBlobSize == nDataSize + 3 );
515-
pBlob[nBlobSize - 1] = VECTOR_TYPE_FLOAT8;
532+
pBlob[nBlobSize - 1] = pVector->type;
516533
pBlob[nBlobSize - 2] = ALIGN(pVector->dims, sizeof(float)) - pVector->dims;
534+
pBlob[nBlobSize - 3] = 0;
517535
}else{
518536
assert( 0 );
519537
}
@@ -561,6 +579,9 @@ void vectorSerializeToBlob(const Vector *pVector, unsigned char *pBlob, size_t n
561579
case VECTOR_TYPE_FLOAT8:
562580
vectorF8SerializeToBlob(pVector, pBlob, nBlobSize);
563581
break;
582+
case VECTOR_TYPE_FLOAT16:
583+
vectorF16SerializeToBlob(pVector, pBlob, nBlobSize);
584+
break;
564585
default:
565586
assert(0);
566587
}
@@ -576,6 +597,7 @@ static void vectorConvertFromF32(const Vector *pFrom, Vector *pTo){
576597

577598
u8 *dst1Bit;
578599
double *dstF64;
600+
u16 *dstF16;
579601

580602
assert( pFrom->dims == pTo->dims );
581603
assert( pFrom->type != pTo->type );
@@ -597,6 +619,11 @@ static void vectorConvertFromF32(const Vector *pFrom, Vector *pTo){
597619
dst1Bit[i / 8] |= (1 << (i & 7));
598620
}
599621
}
622+
}else if( pTo->type == VECTOR_TYPE_FLOAT16 ){
623+
dstF16 = pTo->data;
624+
for(i = 0; i < pFrom->dims; i++){
625+
dstF16[i] = vectorF16FromFloat(src[i]);
626+
}
600627
}else{
601628
assert( 0 );
602629
}
@@ -608,6 +635,7 @@ static void vectorConvertFromF64(const Vector *pFrom, Vector *pTo){
608635

609636
u8 *dst1Bit;
610637
float *dstF32;
638+
u16 *dstF16;
611639

612640
assert( pFrom->dims == pTo->dims );
613641
assert( pFrom->type != pTo->type );
@@ -629,6 +657,11 @@ static void vectorConvertFromF64(const Vector *pFrom, Vector *pTo){
629657
dst1Bit[i / 8] |= (1 << (i & 7));
630658
}
631659
}
660+
}else if( pTo->type == VECTOR_TYPE_FLOAT16 ){
661+
dstF16 = pTo->data;
662+
for(i = 0; i < pFrom->dims; i++){
663+
dstF16[i] = vectorF16FromFloat(src[i]);
664+
}
632665
}else{
633666
assert( 0 );
634667
}
@@ -640,6 +673,7 @@ static void vectorConvertFrom1Bit(const Vector *pFrom, Vector *pTo){
640673

641674
float *dstF32;
642675
double *dstF64;
676+
u16 *dstF16;
643677

644678
assert( pFrom->dims == pTo->dims );
645679
assert( pFrom->type != pTo->type );
@@ -664,6 +698,17 @@ static void vectorConvertFrom1Bit(const Vector *pFrom, Vector *pTo){
664698
dstF64[i] = -1;
665699
}
666700
}
701+
}else if( pTo->type == VECTOR_TYPE_FLOAT16 ){
702+
u16 positive = vectorF16FromFloat(+1);
703+
u16 negative = vectorF16FromFloat(-1);
704+
dstF16 = pTo->data;
705+
for(i = 0; i < pFrom->dims; i++){
706+
if( ((src[i / 8] >> (i & 7)) & 1) == 1 ){
707+
dstF16[i] = positive;
708+
}else{
709+
dstF16[i] = negative;
710+
}
711+
}
667712
}else{
668713
assert( 0 );
669714
}
@@ -677,6 +722,7 @@ static void vectorConvertFromF8(const Vector *pFrom, Vector *pTo){
677722
float *dstF32;
678723
double *dstF64;
679724
u8 *dst1Bit;
725+
u16 *dstF16;
680726

681727
assert( pFrom->dims == pTo->dims );
682728
assert( pFrom->type != pTo->type );
@@ -705,6 +751,49 @@ static void vectorConvertFromF8(const Vector *pFrom, Vector *pTo){
705751
dst1Bit[i / 8] |= (1 << (i & 7));
706752
}
707753
}
754+
}else if( pTo->type == VECTOR_TYPE_FLOAT16 ){
755+
dstF16 = pTo->data;
756+
for(i = 0; i < pFrom->dims; i++){
757+
dstF16[i] = vectorF16FromFloat(alpha * src[i] + shift);
758+
}
759+
}else{
760+
assert( 0 );
761+
}
762+
}
763+
764+
static void vectorConvertFromF16(const Vector *pFrom, Vector *pTo){
765+
int i;
766+
u16 *src;
767+
768+
float *dstF32;
769+
double *dstF64;
770+
u8 *dst1Bit;
771+
772+
assert( pFrom->dims == pTo->dims );
773+
assert( pFrom->type != pTo->type );
774+
assert( pFrom->type == VECTOR_TYPE_FLOAT16 );
775+
776+
src = pFrom->data;
777+
if( pTo->type == VECTOR_TYPE_FLOAT32 ){
778+
dstF32 = pTo->data;
779+
for(i = 0; i < pFrom->dims; i++){
780+
dstF32[i] = vectorF16ToFloat(src[i]);
781+
}
782+
}else if( pTo->type == VECTOR_TYPE_FLOAT64 ){
783+
dstF64 = pTo->data;
784+
for(i = 0; i < pFrom->dims; i++){
785+
dstF64[i] = vectorF16ToFloat(src[i]);
786+
}
787+
}else if( pTo->type == VECTOR_TYPE_FLOAT1BIT ){
788+
dst1Bit = pTo->data;
789+
for(i = 0; i < pFrom->dims; i += 8){
790+
dst1Bit[i / 8] = 0;
791+
}
792+
for(i = 0; i < pFrom->dims; i++){
793+
if( vectorF16ToFloat(src[i]) > 0 ){
794+
dst1Bit[i / 8] |= (1 << (i & 7));
795+
}
796+
}
708797
}else{
709798
assert( 0 );
710799
}
@@ -730,6 +819,7 @@ static void vectorConvertToF8(const Vector *pFrom, Vector *pTo){
730819
float *srcF32;
731820
double *srcF64;
732821
u8 *src1Bit;
822+
u16 *srcF16;
733823

734824
assert( pFrom->dims == pTo->dims );
735825
assert( pFrom->type != pTo->type );
@@ -766,6 +856,16 @@ static void vectorConvertToF8(const Vector *pFrom, Vector *pTo){
766856
for(i = 0; i < pFrom->dims; i++){
767857
dst[i] = clip(((((src1Bit[i / 8] >> (i & 7)) & 1) ? +1 : -1) - shift) / alpha, 0, 255);
768858
}
859+
}else if( pFrom->type == VECTOR_TYPE_FLOAT16 ){
860+
srcF16 = pFrom->data;
861+
for(i = 0; i < pFrom->dims; i++){
862+
MINMAX(i, vectorF16ToFloat(srcF16[i]), minF, maxF);
863+
}
864+
shift = minF;
865+
alpha = (maxF - minF) / 255;
866+
for(i = 0; i < pFrom->dims; i++){
867+
dst[i] = clip((vectorF16ToFloat(srcF16[i]) - shift) / alpha, 0, 255);
868+
}
769869
}else{
770870
assert( 0 );
771871
}
@@ -791,6 +891,8 @@ void vectorConvert(const Vector *pFrom, Vector *pTo){
791891
vectorConvertFrom1Bit(pFrom, pTo);
792892
}else if( pFrom->type == VECTOR_TYPE_FLOAT8 ){
793893
vectorConvertFromF8(pFrom, pTo);
894+
}else if( pFrom->type == VECTOR_TYPE_FLOAT16 ){
895+
vectorConvertFromF16(pFrom, pTo);
794896
}else{
795897
assert( 0 );
796898
}
@@ -875,6 +977,14 @@ static void vector8Func(
875977
vectorFuncHintedType(context, argc, argv, VECTOR_TYPE_FLOAT8);
876978
}
877979

980+
/*
** Callback registered for the vector16 SQL function: forwards context, argc
** and argv unchanged to vectorFuncHintedType(), passing VECTOR_TYPE_FLOAT16
** as the type argument.
*/
static void vector16Func(
981+
sqlite3_context *context,
982+
int argc,
983+
sqlite3_value **argv
984+
){
985+
vectorFuncHintedType(context, argc, argv, VECTOR_TYPE_FLOAT16);
986+
}
987+
878988
static void vector1BitFunc(
879989
sqlite3_context *context,
880990
int argc,
@@ -1033,6 +1143,7 @@ void sqlite3RegisterVectorFunctions(void){
10331143
FUNCTION(vector64, 1, 0, 0, vector64Func),
10341144
FUNCTION(vector1bit, 1, 0, 0, vector1BitFunc),
10351145
FUNCTION(vector8, 1, 0, 0, vector8Func),
1146+
FUNCTION(vector16, 1, 0, 0, vector16Func),
10361147
FUNCTION(vector_extract, 1, 0, 0, vectorExtractFunc),
10371148
FUNCTION(vector_distance_cos, 2, 0, 0, vectorDistanceCosFunc),
10381149
FUNCTION(vector_distance_l2, 2, 0, 0, vectorDistanceL2Func),

libsql-sqlite3/src/vectorIndex.c

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -385,6 +385,8 @@ static struct VectorColumnType VECTOR_COLUMN_TYPES[] = {
385385
{ "F1BIT_BLOB", VECTOR_TYPE_FLOAT1BIT },
386386
{ "FLOAT8", VECTOR_TYPE_FLOAT8 },
387387
{ "F8_BLOB", VECTOR_TYPE_FLOAT8 },
388+
{ "FLOAT16", VECTOR_TYPE_FLOAT16 },
389+
{ "F16_BLOB", VECTOR_TYPE_FLOAT16 },
388390
};
389391

390392
/*
@@ -405,6 +407,7 @@ static struct VectorParamName VECTOR_PARAM_NAMES[] = {
405407
{ "metric", VECTOR_METRIC_TYPE_PARAM_ID, 0, "l2", VECTOR_METRIC_TYPE_L2 },
406408
{ "compress_neighbors", VECTOR_COMPRESS_NEIGHBORS_PARAM_ID, 0, "float1bit", VECTOR_TYPE_FLOAT1BIT },
407409
{ "compress_neighbors", VECTOR_COMPRESS_NEIGHBORS_PARAM_ID, 0, "float8", VECTOR_TYPE_FLOAT8 },
410+
{ "compress_neighbors", VECTOR_COMPRESS_NEIGHBORS_PARAM_ID, 0, "float16", VECTOR_TYPE_FLOAT16 },
408411
{ "compress_neighbors", VECTOR_COMPRESS_NEIGHBORS_PARAM_ID, 0, "float32", VECTOR_TYPE_FLOAT32 },
409412
{ "alpha", VECTOR_PRUNING_ALPHA_PARAM_ID, 2, 0, 0 },
410413
{ "search_l", VECTOR_SEARCH_L_PARAM_ID, 1, 0, 0 },

libsql-sqlite3/src/vectorInt.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,9 @@ void vector1BitDeserializeFromBlob(Vector *, const unsigned char *, size_t);
150150
void vectorInitStatic(Vector *, VectorType, VectorDims, void *);
151151
void vectorInitFromBlob(Vector *, const unsigned char *, size_t);
152152

153+
u16 vectorF16FromFloat(float);
154+
float vectorF16ToFloat(u16);
155+
153156
void vectorConvert(const Vector *, Vector *);
154157

155158
/* Detect type and dimension of vector provided with first parameter of sqlite3_value * type */

libsql-sqlite3/src/vectorfloat16.c

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
******************************************************************************
2424
**
2525
** 16-bit (FLOAT16) floating point vector format utilities.
26+
**
27+
** See https://en.wikipedia.org/wiki/Half-precision_floating-point_format
2628
*/
2729
#ifndef SQLITE_OMIT_VECTOR
2830
#include "sqliteInt.h"
@@ -40,7 +42,7 @@
4042
// f16: [ffffffffffeeeees]
4143
// 0123456789012345
4244

43-
static float vectorF16ToFloat(u16 f16){
45+
float vectorF16ToFloat(u16 f16){
4446
u32 f32;
4547
// sgn: [0000000000000000000000000000000s]
4648
u32 sgn = ((u32)f16 & 0x8000) << 16;
@@ -72,7 +74,7 @@ static float vectorF16ToFloat(u16 f16){
7274
return *((float*)&f32);
7375
}
7476

75-
static u16 vectorF16FromFloat(float f){
77+
u16 vectorF16FromFloat(float f){
7678
u32 i = *((u32*)&f);
7779

7880
// sgn: [000000000000000s]
@@ -160,7 +162,7 @@ float vectorF16DistanceL2(const Vector *v1, const Vector *v2){
160162
int i;
161163
float sum = 0;
162164
float value1, value2;
163-
u8 *data1 = v1->data, *data2 = v2->data;
165+
u16 *data1 = v1->data, *data2 = v2->data;
164166

165167
assert( v1->dims == v2->dims );
166168
assert( v1->type == VECTOR_TYPE_FLOAT16 );

0 commit comments

Comments
 (0)