@@ -47,6 +47,8 @@ size_t vectorDataSize(VectorType type, VectorDims dims){
4747 return ALIGN (dims , sizeof (float )) + sizeof (float ) /* alpha */ + sizeof (float ) /* shift */ ;
4848 case VECTOR_TYPE_FLOAT16 :
4949 return dims * sizeof (u16 );
50+ case VECTOR_TYPE_FLOATB16 :
51+ return dims * sizeof (u16 );
5052 default :
5153 assert (0 );
5254 }
@@ -124,6 +126,8 @@ float vectorDistanceCos(const Vector *pVector1, const Vector *pVector2){
124126 return vectorF8DistanceCos (pVector1 , pVector2 );
125127 case VECTOR_TYPE_FLOAT16 :
126128 return vectorF16DistanceCos (pVector1 , pVector2 );
129+ case VECTOR_TYPE_FLOATB16 :
130+ return vectorFB16DistanceCos (pVector1 , pVector2 );
127131 default :
128132 assert (0 );
129133 }
@@ -141,6 +145,8 @@ float vectorDistanceL2(const Vector *pVector1, const Vector *pVector2){
141145 return vectorF8DistanceL2 (pVector1 , pVector2 );
142146 case VECTOR_TYPE_FLOAT16 :
143147 return vectorF16DistanceL2 (pVector1 , pVector2 );
148+ case VECTOR_TYPE_FLOATB16 :
149+ return vectorFB16DistanceL2 (pVector1 , pVector2 );
144150 default :
145151 assert (0 );
146152 }
@@ -314,6 +320,13 @@ static int vectorParseMeta(const unsigned char *pBlob, size_t nBlobSize, int *pT
314320 }
315321 * pDims = nBlobSize / sizeof (u16 );
316322 * pDataSize = nBlobSize ;
323+ }else if ( * pType == VECTOR_TYPE_FLOATB16 ){
324+ if ( nBlobSize % 2 != 0 ){
325+ * pzErrMsg = sqlite3_mprintf ("vector: floatb16 vector blob length must be divisible by 2 (excluding 'type'-byte): length=%d" , nBlobSize );
326+ return SQLITE_ERROR ;
327+ }
328+ * pDims = nBlobSize / sizeof (u16 );
329+ * pDataSize = nBlobSize ;
317330 }else {
318331 * pzErrMsg = sqlite3_mprintf ("vector: unexpected binary type: %d" , * pType );
319332 return SQLITE_ERROR ;
@@ -365,6 +378,9 @@ int vectorParseSqliteBlobWithType(
365378 case VECTOR_TYPE_FLOAT16 :
366379 vectorF16DeserializeFromBlob (pVector , pBlob , nDataSize );
367380 return 0 ;
381+ case VECTOR_TYPE_FLOATB16 :
382+ vectorFB16DeserializeFromBlob (pVector , pBlob , nDataSize );
383+ return 0 ;
368384 default :
369385 assert (0 );
370386 }
@@ -469,6 +485,9 @@ void vectorDump(const Vector *pVector){
469485 case VECTOR_TYPE_FLOAT16 :
470486 vectorF16Dump (pVector );
471487 break ;
488+ case VECTOR_TYPE_FLOATB16 :
489+ vectorFB16Dump (pVector );
490+ break ;
472491 default :
473492 assert (0 );
474493 }
@@ -494,7 +513,7 @@ static int vectorMetaSize(VectorType type, VectorDims dims){
494513 int nDataSize ;
495514 if ( type == VECTOR_TYPE_FLOAT32 ){
496515 return 0 ;
497- }else if ( type == VECTOR_TYPE_FLOAT64 || type == VECTOR_TYPE_FLOAT16 ){
516+ }else if ( type == VECTOR_TYPE_FLOAT64 || type == VECTOR_TYPE_FLOAT16 || type == VECTOR_TYPE_FLOATB16 ){
498517 return 1 ;
499518 }else if ( type == VECTOR_TYPE_FLOAT1BIT ){
500519 nDataSize = vectorDataSize (type , dims );
@@ -513,7 +532,7 @@ static int vectorMetaSize(VectorType type, VectorDims dims){
513532static void vectorSerializeMeta (const Vector * pVector , size_t nDataSize , unsigned char * pBlob , size_t nBlobSize ){
514533 if ( pVector -> type == VECTOR_TYPE_FLOAT32 ){
515534 // no meta for f32 type as this is "default" vector type
516- }else if ( pVector -> type == VECTOR_TYPE_FLOAT64 || pVector -> type == VECTOR_TYPE_FLOAT16 ){
535+ }else if ( pVector -> type == VECTOR_TYPE_FLOAT64 || pVector -> type == VECTOR_TYPE_FLOAT16 || pVector -> type == VECTOR_TYPE_FLOATB16 ){
517536 assert ( nDataSize % 2 == 0 );
518537 assert ( nBlobSize == nDataSize + 1 );
519538 pBlob [nBlobSize - 1 ] = pVector -> type ;
@@ -582,6 +601,9 @@ void vectorSerializeToBlob(const Vector *pVector, unsigned char *pBlob, size_t n
582601 case VECTOR_TYPE_FLOAT16 :
583602 vectorF16SerializeToBlob (pVector , pBlob , nBlobSize );
584603 break ;
604+ case VECTOR_TYPE_FLOATB16 :
605+ vectorFB16SerializeToBlob (pVector , pBlob , nBlobSize );
606+ break ;
585607 default :
586608 assert (0 );
587609 }
@@ -624,6 +646,11 @@ static void vectorConvertFromF32(const Vector *pFrom, Vector *pTo){
624646 for (i = 0 ; i < pFrom -> dims ; i ++ ){
625647 dstF16 [i ] = vectorF16FromFloat (src [i ]);
626648 }
649+ }else if ( pTo -> type == VECTOR_TYPE_FLOATB16 ){
650+ dstF16 = pTo -> data ;
651+ for (i = 0 ; i < pFrom -> dims ; i ++ ){
652+ dstF16 [i ] = vectorFB16FromFloat (src [i ]);
653+ }
627654 }else {
628655 assert ( 0 );
629656 }
@@ -662,6 +689,11 @@ static void vectorConvertFromF64(const Vector *pFrom, Vector *pTo){
662689 for (i = 0 ; i < pFrom -> dims ; i ++ ){
663690 dstF16 [i ] = vectorF16FromFloat (src [i ]);
664691 }
692+ }else if ( pTo -> type == VECTOR_TYPE_FLOATB16 ){
693+ dstF16 = pTo -> data ;
694+ for (i = 0 ; i < pFrom -> dims ; i ++ ){
695+ dstF16 [i ] = vectorFB16FromFloat (src [i ]);
696+ }
665697 }else {
666698 assert ( 0 );
667699 }
@@ -673,7 +705,7 @@ static void vectorConvertFrom1Bit(const Vector *pFrom, Vector *pTo){
673705
674706 float * dstF32 ;
675707 double * dstF64 ;
676- u16 * dstF16 ;
708+ u16 * dstU16 ;
677709
678710 assert ( pFrom -> dims == pTo -> dims );
679711 assert ( pFrom -> type != pTo -> type );
@@ -701,12 +733,23 @@ static void vectorConvertFrom1Bit(const Vector *pFrom, Vector *pTo){
701733 }else if ( pTo -> type == VECTOR_TYPE_FLOAT16 ){
702734 u16 positive = vectorF16FromFloat (+1 );
703735 u16 negative = vectorF16FromFloat (-1 );
704- dstF16 = pTo -> data ;
736+ dstU16 = pTo -> data ;
737+ for (i = 0 ; i < pFrom -> dims ; i ++ ){
738+ if ( ((src [i / 8 ] >> (i & 7 )) & 1 ) == 1 ){
739+ dstU16 [i ] = positive ;
740+ }else {
741+ dstU16 [i ] = negative ;
742+ }
743+ }
744+ }else if ( pTo -> type == VECTOR_TYPE_FLOATB16 ){
745+ u16 positive = vectorFB16FromFloat (+1 );
746+ u16 negative = vectorFB16FromFloat (-1 );
747+ dstU16 = pTo -> data ;
705748 for (i = 0 ; i < pFrom -> dims ; i ++ ){
706749 if ( ((src [i / 8 ] >> (i & 7 )) & 1 ) == 1 ){
707- dstF16 [i ] = positive ;
750+ dstU16 [i ] = positive ;
708751 }else {
709- dstF16 [i ] = negative ;
752+ dstU16 [i ] = negative ;
710753 }
711754 }
712755 }else {
@@ -756,6 +799,11 @@ static void vectorConvertFromF8(const Vector *pFrom, Vector *pTo){
756799 for (i = 0 ; i < pFrom -> dims ; i ++ ){
757800 dstF16 [i ] = vectorF16FromFloat (alpha * src [i ] + shift );
758801 }
802+ }else if ( pTo -> type == VECTOR_TYPE_FLOATB16 ){
803+ dstF16 = pTo -> data ;
804+ for (i = 0 ; i < pFrom -> dims ; i ++ ){
805+ dstF16 [i ] = vectorFB16FromFloat (alpha * src [i ] + shift );
806+ }
759807 }else {
760808 assert ( 0 );
761809 }
@@ -768,6 +816,7 @@ static void vectorConvertFromF16(const Vector *pFrom, Vector *pTo){
768816 float * dstF32 ;
769817 double * dstF64 ;
770818 u8 * dst1Bit ;
819+ u16 * dstU16 ;
771820
772821 assert ( pFrom -> dims == pTo -> dims );
773822 assert ( pFrom -> type != pTo -> type );
@@ -784,6 +833,11 @@ static void vectorConvertFromF16(const Vector *pFrom, Vector *pTo){
784833 for (i = 0 ; i < pFrom -> dims ; i ++ ){
785834 dstF64 [i ] = vectorF16ToFloat (src [i ]);
786835 }
836+ }else if ( pTo -> type == VECTOR_TYPE_FLOATB16 ){
837+ dstU16 = pTo -> data ;
838+ for (i = 0 ; i < pFrom -> dims ; i ++ ){
839+ dstU16 [i ] = vectorFB16FromFloat (vectorF16ToFloat (src [i ]));
840+ }
787841 }else if ( pTo -> type == VECTOR_TYPE_FLOAT1BIT ){
788842 dst1Bit = pTo -> data ;
789843 for (i = 0 ; i < pFrom -> dims ; i += 8 ){
@@ -799,6 +853,50 @@ static void vectorConvertFromF16(const Vector *pFrom, Vector *pTo){
799853 }
800854}
801855
856+ static void vectorConvertFromFB16 (const Vector * pFrom , Vector * pTo ){
857+ int i ;
858+ u16 * src ;
859+
860+ float * dstF32 ;
861+ double * dstF64 ;
862+ u8 * dst1Bit ;
863+ u16 * dstU16 ;
864+
865+ assert ( pFrom -> dims == pTo -> dims );
866+ assert ( pFrom -> type != pTo -> type );
867+ assert ( pFrom -> type == VECTOR_TYPE_FLOATB16 );
868+
869+ src = pFrom -> data ;
870+ if ( pTo -> type == VECTOR_TYPE_FLOAT32 ){
871+ dstF32 = pTo -> data ;
872+ for (i = 0 ; i < pFrom -> dims ; i ++ ){
873+ dstF32 [i ] = vectorFB16ToFloat (src [i ]);
874+ }
875+ }else if ( pTo -> type == VECTOR_TYPE_FLOAT64 ){
876+ dstF64 = pTo -> data ;
877+ for (i = 0 ; i < pFrom -> dims ; i ++ ){
878+ dstF64 [i ] = vectorFB16ToFloat (src [i ]);
879+ }
880+ }else if ( pTo -> type == VECTOR_TYPE_FLOAT16 ){
881+ dstU16 = pTo -> data ;
882+ for (i = 0 ; i < pFrom -> dims ; i ++ ){
883+ dstU16 [i ] = vectorF16FromFloat (vectorFB16ToFloat (src [i ]));
884+ }
885+ }else if ( pTo -> type == VECTOR_TYPE_FLOAT1BIT ){
886+ dst1Bit = pTo -> data ;
887+ for (i = 0 ; i < pFrom -> dims ; i += 8 ){
888+ dst1Bit [i / 8 ] = 0 ;
889+ }
890+ for (i = 0 ; i < pFrom -> dims ; i ++ ){
891+ if ( vectorFB16ToFloat (src [i ]) > 0 ){
892+ dst1Bit [i / 8 ] |= (1 << (i & 7 ));
893+ }
894+ }
895+ }else {
896+ assert ( 0 );
897+ }
898+ }
899+
802900static inline int clip (float f , int minF , int maxF ){
803901 if ( f < minF ){
804902 return minF ;
@@ -819,7 +917,7 @@ static void vectorConvertToF8(const Vector *pFrom, Vector *pTo){
819917 float * srcF32 ;
820918 double * srcF64 ;
821919 u8 * src1Bit ;
822- u16 * srcF16 ;
920+ u16 * srcU16 ;
823921
824922 assert ( pFrom -> dims == pTo -> dims );
825923 assert ( pFrom -> type != pTo -> type );
@@ -857,14 +955,24 @@ static void vectorConvertToF8(const Vector *pFrom, Vector *pTo){
857955 dst [i ] = clip (((((src1Bit [i / 8 ] >> (i & 7 )) & 1 ) ? +1 : -1 ) - shift ) / alpha , 0 , 255 );
858956 }
859957 }else if ( pFrom -> type == VECTOR_TYPE_FLOAT16 ){
860- srcF16 = pFrom -> data ;
958+ srcU16 = pFrom -> data ;
959+ for (i = 0 ; i < pFrom -> dims ; i ++ ){
960+ MINMAX (i , vectorF16ToFloat (srcU16 [i ]), minF , maxF );
961+ }
962+ shift = minF ;
963+ alpha = (maxF - minF ) / 255 ;
964+ for (i = 0 ; i < pFrom -> dims ; i ++ ){
965+ dst [i ] = clip ((vectorF16ToFloat (srcU16 [i ]) - shift ) / alpha , 0 , 255 );
966+ }
967+ }else if ( pFrom -> type == VECTOR_TYPE_FLOATB16 ){
968+ srcU16 = pFrom -> data ;
861969 for (i = 0 ; i < pFrom -> dims ; i ++ ){
862- MINMAX (i , vectorF16ToFloat ( srcF16 [i ]), minF , maxF );
970+ MINMAX (i , vectorFB16ToFloat ( srcU16 [i ]), minF , maxF );
863971 }
864972 shift = minF ;
865973 alpha = (maxF - minF ) / 255 ;
866974 for (i = 0 ; i < pFrom -> dims ; i ++ ){
867- dst [i ] = clip ((vectorF16ToFloat ( srcF16 [i ]) - shift ) / alpha , 0 , 255 );
975+ dst [i ] = clip ((vectorFB16ToFloat ( srcU16 [i ]) - shift ) / alpha , 0 , 255 );
868976 }
869977 }else {
870978 assert ( 0 );
@@ -893,6 +1001,8 @@ void vectorConvert(const Vector *pFrom, Vector *pTo){
8931001 vectorConvertFromF8 (pFrom , pTo );
8941002 }else if ( pFrom -> type == VECTOR_TYPE_FLOAT16 ){
8951003 vectorConvertFromF16 (pFrom , pTo );
1004+ }else if ( pFrom -> type == VECTOR_TYPE_FLOATB16 ){
1005+ vectorConvertFromFB16 (pFrom , pTo );
8961006 }else {
8971007 assert ( 0 );
8981008 }
@@ -985,6 +1095,14 @@ static void vector16Func(
9851095 vectorFuncHintedType (context , argc , argv , VECTOR_TYPE_FLOAT16 );
9861096}
9871097
1098+ static void vectorb16Func (
1099+ sqlite3_context * context ,
1100+ int argc ,
1101+ sqlite3_value * * argv
1102+ ){
1103+ vectorFuncHintedType (context , argc , argv , VECTOR_TYPE_FLOATB16 );
1104+ }
1105+
9881106static void vector1BitFunc (
9891107 sqlite3_context * context ,
9901108 int argc ,
@@ -1144,6 +1262,7 @@ void sqlite3RegisterVectorFunctions(void){
11441262 FUNCTION (vector1bit , 1 , 0 , 0 , vector1BitFunc ),
11451263 FUNCTION (vector8 , 1 , 0 , 0 , vector8Func ),
11461264 FUNCTION (vector16 , 1 , 0 , 0 , vector16Func ),
1265+ FUNCTION (vectorb16 , 1 , 0 , 0 , vectorb16Func ),
11471266 FUNCTION (vector_extract , 1 , 0 , 0 , vectorExtractFunc ),
11481267 FUNCTION (vector_distance_cos , 2 , 0 , 0 , vectorDistanceCosFunc ),
11491268 FUNCTION (vector_distance_l2 , 2 , 0 , 0 , vectorDistanceL2Func ),
0 commit comments