@@ -122,6 +122,8 @@ float vectorDistanceCos(const Vector *pVector1, const Vector *pVector2){
122122 return vector1BitDistanceHamming (pVector1 , pVector2 );
123123 case VECTOR_TYPE_FLOAT8 :
124124 return vectorF8DistanceCos (pVector1 , pVector2 );
125+ case VECTOR_TYPE_FLOAT16 :
126+ return vectorF16DistanceCos (pVector1 , pVector2 );
125127 default :
126128 assert (0 );
127129 }
@@ -137,6 +139,8 @@ float vectorDistanceL2(const Vector *pVector1, const Vector *pVector2){
137139 return vectorF64DistanceL2 (pVector1 , pVector2 );
138140 case VECTOR_TYPE_FLOAT8 :
139141 return vectorF8DistanceL2 (pVector1 , pVector2 );
142+ case VECTOR_TYPE_FLOAT16 :
143+ return vectorF16DistanceL2 (pVector1 , pVector2 );
140144 default :
141145 assert (0 );
142146 }
@@ -303,6 +307,13 @@ static int vectorParseMeta(const unsigned char *pBlob, size_t nBlobSize, int *pT
303307 nTrailingBytes = pBlob [nBlobSize - 1 ];
304308 * pDims = (nBlobSize - 2 ) - sizeof (float ) - sizeof (float ) - nTrailingBytes ;
305309 * pDataSize = nBlobSize - 2 ;
310+ }else if ( * pType == VECTOR_TYPE_FLOAT16 ){
311+ if ( nBlobSize % 2 != 0 ){
312+ * pzErrMsg = sqlite3_mprintf ("vector: float16 vector blob length must be divisible by 2 (excluding 'type'-byte): length=%d" , nBlobSize );
313+ return SQLITE_ERROR ;
314+ }
315+ * pDims = nBlobSize / sizeof (u16 );
316+ * pDataSize = nBlobSize ;
306317 }else {
307318 * pzErrMsg = sqlite3_mprintf ("vector: unexpected binary type: %d" , * pType );
308319 return SQLITE_ERROR ;
@@ -351,6 +362,9 @@ int vectorParseSqliteBlobWithType(
351362 case VECTOR_TYPE_FLOAT8 :
352363 vectorF8DeserializeFromBlob (pVector , pBlob , nDataSize );
353364 return 0 ;
365+ case VECTOR_TYPE_FLOAT16 :
366+ vectorF16DeserializeFromBlob (pVector , pBlob , nDataSize );
367+ return 0 ;
354368 default :
355369 assert (0 );
356370 }
@@ -452,6 +466,9 @@ void vectorDump(const Vector *pVector){
452466 case VECTOR_TYPE_FLOAT8 :
453467 vectorF8Dump (pVector );
454468 break ;
469+ case VECTOR_TYPE_FLOAT16 :
470+ vectorF16Dump (pVector );
471+ break ;
455472 default :
456473 assert (0 );
457474 }
@@ -477,7 +494,7 @@ static int vectorMetaSize(VectorType type, VectorDims dims){
477494 int nDataSize ;
478495 if ( type == VECTOR_TYPE_FLOAT32 ){
479496 return 0 ;
480- }else if ( type == VECTOR_TYPE_FLOAT64 ){
497+ }else if ( type == VECTOR_TYPE_FLOAT64 || type == VECTOR_TYPE_FLOAT16 ){
481498 return 1 ;
482499 }else if ( type == VECTOR_TYPE_FLOAT1BIT ){
483500 nDataSize = vectorDataSize (type , dims );
@@ -496,14 +513,14 @@ static int vectorMetaSize(VectorType type, VectorDims dims){
496513static void vectorSerializeMeta (const Vector * pVector , size_t nDataSize , unsigned char * pBlob , size_t nBlobSize ){
497514 if ( pVector -> type == VECTOR_TYPE_FLOAT32 ){
498515 // no meta for f32 type as this is "default" vector type
499- }else if ( pVector -> type == VECTOR_TYPE_FLOAT64 ){
516+ }else if ( pVector -> type == VECTOR_TYPE_FLOAT64 || pVector -> type == VECTOR_TYPE_FLOAT16 ){
500517 assert ( nDataSize % 2 == 0 );
501518 assert ( nBlobSize == nDataSize + 1 );
502- pBlob [nBlobSize - 1 ] = VECTOR_TYPE_FLOAT64 ;
519+ pBlob [nBlobSize - 1 ] = pVector -> type ;
503520 }else if ( pVector -> type == VECTOR_TYPE_FLOAT1BIT ){
504521 assert ( nBlobSize % 2 == 1 );
505522 assert ( nBlobSize >= 3 );
506- pBlob [nBlobSize - 1 ] = VECTOR_TYPE_FLOAT1BIT ;
523+ pBlob [nBlobSize - 1 ] = pVector -> type ;
507524 pBlob [nBlobSize - 2 ] = 8 * (nBlobSize - 1 ) - pVector -> dims ;
508525 if ( vectorMetaSize (pVector -> type , pVector -> dims ) == 3 ){
509526 pBlob [nBlobSize - 3 ] = 0 ;
@@ -512,8 +529,9 @@ static void vectorSerializeMeta(const Vector *pVector, size_t nDataSize, unsigne
512529 assert ( nBlobSize % 2 == 1 );
513530 assert ( nDataSize % 2 == 0 );
514531 assert ( nBlobSize == nDataSize + 3 );
515- pBlob [nBlobSize - 1 ] = VECTOR_TYPE_FLOAT8 ;
532+ pBlob [nBlobSize - 1 ] = pVector -> type ;
516533 pBlob [nBlobSize - 2 ] = ALIGN (pVector -> dims , sizeof (float )) - pVector -> dims ;
534+ pBlob [nBlobSize - 3 ] = 0 ;
517535 }else {
518536 assert ( 0 );
519537 }
@@ -561,6 +579,9 @@ void vectorSerializeToBlob(const Vector *pVector, unsigned char *pBlob, size_t n
561579 case VECTOR_TYPE_FLOAT8 :
562580 vectorF8SerializeToBlob (pVector , pBlob , nBlobSize );
563581 break ;
582+ case VECTOR_TYPE_FLOAT16 :
583+ vectorF16SerializeToBlob (pVector , pBlob , nBlobSize );
584+ break ;
564585 default :
565586 assert (0 );
566587 }
@@ -576,6 +597,7 @@ static void vectorConvertFromF32(const Vector *pFrom, Vector *pTo){
576597
577598 u8 * dst1Bit ;
578599 double * dstF64 ;
600+ u16 * dstF16 ;
579601
580602 assert ( pFrom -> dims == pTo -> dims );
581603 assert ( pFrom -> type != pTo -> type );
@@ -597,6 +619,11 @@ static void vectorConvertFromF32(const Vector *pFrom, Vector *pTo){
597619 dst1Bit [i / 8 ] |= (1 << (i & 7 ));
598620 }
599621 }
622+ }else if ( pTo -> type == VECTOR_TYPE_FLOAT16 ){
623+ dstF16 = pTo -> data ;
624+ for (i = 0 ; i < pFrom -> dims ; i ++ ){
625+ dstF16 [i ] = vectorF16FromFloat (src [i ]);
626+ }
600627 }else {
601628 assert ( 0 );
602629 }
@@ -608,6 +635,7 @@ static void vectorConvertFromF64(const Vector *pFrom, Vector *pTo){
608635
609636 u8 * dst1Bit ;
610637 float * dstF32 ;
638+ u16 * dstF16 ;
611639
612640 assert ( pFrom -> dims == pTo -> dims );
613641 assert ( pFrom -> type != pTo -> type );
@@ -629,6 +657,11 @@ static void vectorConvertFromF64(const Vector *pFrom, Vector *pTo){
629657 dst1Bit [i / 8 ] |= (1 << (i & 7 ));
630658 }
631659 }
660+ }else if ( pTo -> type == VECTOR_TYPE_FLOAT16 ){
661+ dstF16 = pTo -> data ;
662+ for (i = 0 ; i < pFrom -> dims ; i ++ ){
663+ dstF16 [i ] = vectorF16FromFloat (src [i ]);
664+ }
632665 }else {
633666 assert ( 0 );
634667 }
@@ -640,6 +673,7 @@ static void vectorConvertFrom1Bit(const Vector *pFrom, Vector *pTo){
640673
641674 float * dstF32 ;
642675 double * dstF64 ;
676+ u16 * dstF16 ;
643677
644678 assert ( pFrom -> dims == pTo -> dims );
645679 assert ( pFrom -> type != pTo -> type );
@@ -664,6 +698,17 @@ static void vectorConvertFrom1Bit(const Vector *pFrom, Vector *pTo){
664698 dstF64 [i ] = -1 ;
665699 }
666700 }
701+ }else if ( pTo -> type == VECTOR_TYPE_FLOAT16 ){
702+ u16 positive = vectorF16FromFloat (+1 );
703+ u16 negative = vectorF16FromFloat (-1 );
704+ dstF16 = pTo -> data ;
705+ for (i = 0 ; i < pFrom -> dims ; i ++ ){
706+ if ( ((src [i / 8 ] >> (i & 7 )) & 1 ) == 1 ){
707+ dstF16 [i ] = positive ;
708+ }else {
709+ dstF16 [i ] = negative ;
710+ }
711+ }
667712 }else {
668713 assert ( 0 );
669714 }
@@ -677,6 +722,7 @@ static void vectorConvertFromF8(const Vector *pFrom, Vector *pTo){
677722 float * dstF32 ;
678723 double * dstF64 ;
679724 u8 * dst1Bit ;
725+ u16 * dstF16 ;
680726
681727 assert ( pFrom -> dims == pTo -> dims );
682728 assert ( pFrom -> type != pTo -> type );
@@ -705,6 +751,49 @@ static void vectorConvertFromF8(const Vector *pFrom, Vector *pTo){
705751 dst1Bit [i / 8 ] |= (1 << (i & 7 ));
706752 }
707753 }
754+ }else if ( pTo -> type == VECTOR_TYPE_FLOAT16 ){
755+ dstF16 = pTo -> data ;
756+ for (i = 0 ; i < pFrom -> dims ; i ++ ){
757+ dstF16 [i ] = vectorF16FromFloat (alpha * src [i ] + shift );
758+ }
759+ }else {
760+ assert ( 0 );
761+ }
762+ }
763+
764+ static void vectorConvertFromF16 (const Vector * pFrom , Vector * pTo ){
765+ int i ;
766+ u16 * src ;
767+
768+ float * dstF32 ;
769+ double * dstF64 ;
770+ u8 * dst1Bit ;
771+
772+ assert ( pFrom -> dims == pTo -> dims );
773+ assert ( pFrom -> type != pTo -> type );
774+ assert ( pFrom -> type == VECTOR_TYPE_FLOAT16 );
775+
776+ src = pFrom -> data ;
777+ if ( pTo -> type == VECTOR_TYPE_FLOAT32 ){
778+ dstF32 = pTo -> data ;
779+ for (i = 0 ; i < pFrom -> dims ; i ++ ){
780+ dstF32 [i ] = vectorF16ToFloat (src [i ]);
781+ }
782+ }else if ( pTo -> type == VECTOR_TYPE_FLOAT64 ){
783+ dstF64 = pTo -> data ;
784+ for (i = 0 ; i < pFrom -> dims ; i ++ ){
785+ dstF64 [i ] = vectorF16ToFloat (src [i ]);
786+ }
787+ }else if ( pTo -> type == VECTOR_TYPE_FLOAT1BIT ){
788+ dst1Bit = pTo -> data ;
789+ for (i = 0 ; i < pFrom -> dims ; i += 8 ){
790+ dst1Bit [i / 8 ] = 0 ;
791+ }
792+ for (i = 0 ; i < pFrom -> dims ; i ++ ){
793+ if ( vectorF16ToFloat (src [i ]) > 0 ){
794+ dst1Bit [i / 8 ] |= (1 << (i & 7 ));
795+ }
796+ }
708797 }else {
709798 assert ( 0 );
710799 }
@@ -730,6 +819,7 @@ static void vectorConvertToF8(const Vector *pFrom, Vector *pTo){
730819 float * srcF32 ;
731820 double * srcF64 ;
732821 u8 * src1Bit ;
822+ u16 * srcF16 ;
733823
734824 assert ( pFrom -> dims == pTo -> dims );
735825 assert ( pFrom -> type != pTo -> type );
@@ -766,6 +856,16 @@ static void vectorConvertToF8(const Vector *pFrom, Vector *pTo){
766856 for (i = 0 ; i < pFrom -> dims ; i ++ ){
767857 dst [i ] = clip (((((src1Bit [i / 8 ] >> (i & 7 )) & 1 ) ? +1 : -1 ) - shift ) / alpha , 0 , 255 );
768858 }
859+ }else if ( pFrom -> type == VECTOR_TYPE_FLOAT16 ){
860+ srcF16 = pFrom -> data ;
861+ for (i = 0 ; i < pFrom -> dims ; i ++ ){
862+ MINMAX (i , vectorF16ToFloat (srcF16 [i ]), minF , maxF );
863+ }
864+ shift = minF ;
865+ alpha = (maxF - minF ) / 255 ;
866+ for (i = 0 ; i < pFrom -> dims ; i ++ ){
867+ dst [i ] = clip ((vectorF16ToFloat (srcF16 [i ]) - shift ) / alpha , 0 , 255 );
868+ }
769869 }else {
770870 assert ( 0 );
771871 }
@@ -791,6 +891,8 @@ void vectorConvert(const Vector *pFrom, Vector *pTo){
791891 vectorConvertFrom1Bit (pFrom , pTo );
792892 }else if ( pFrom -> type == VECTOR_TYPE_FLOAT8 ){
793893 vectorConvertFromF8 (pFrom , pTo );
894+ }else if ( pFrom -> type == VECTOR_TYPE_FLOAT16 ){
895+ vectorConvertFromF16 (pFrom , pTo );
794896 }else {
795897 assert ( 0 );
796898 }
@@ -875,6 +977,14 @@ static void vector8Func(
875977 vectorFuncHintedType (context , argc , argv , VECTOR_TYPE_FLOAT8 );
876978}
877979
980+ static void vector16Func (
981+ sqlite3_context * context ,
982+ int argc ,
983+ sqlite3_value * * argv
984+ ){
985+ vectorFuncHintedType (context , argc , argv , VECTOR_TYPE_FLOAT16 );
986+ }
987+
878988static void vector1BitFunc (
879989 sqlite3_context * context ,
880990 int argc ,
@@ -1033,6 +1143,7 @@ void sqlite3RegisterVectorFunctions(void){
10331143 FUNCTION (vector64 , 1 , 0 , 0 , vector64Func ),
10341144 FUNCTION (vector1bit , 1 , 0 , 0 , vector1BitFunc ),
10351145 FUNCTION (vector8 , 1 , 0 , 0 , vector8Func ),
1146+ FUNCTION (vector16 , 1 , 0 , 0 , vector16Func ),
10361147 FUNCTION (vector_extract , 1 , 0 , 0 , vectorExtractFunc ),
10371148 FUNCTION (vector_distance_cos , 2 , 0 , 0 , vectorDistanceCosFunc ),
10381149 FUNCTION (vector_distance_l2 , 2 , 0 , 0 , vectorDistanceL2Func ),
0 commit comments