Skip to content

Commit f8128d2

Browse files
committed
float16 implementation
1 parent 59f189e commit f8128d2

8 files changed

Lines changed: 276 additions & 1 deletion

File tree

libsql-sqlite3/.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,3 +62,5 @@ libsql
6262
/crates/target/
6363
/has_tclsh*
6464
/libsql.wasm
65+
test_libsql_f16_table.h
66+
test_libsql_f16

libsql-sqlite3/Makefile.in

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,7 @@ LIBOBJS0 = alter.lo analyze.lo attach.lo auth.lo \
195195
sqlite3session.lo select.lo sqlite3rbu.lo status.lo stmt.lo \
196196
table.lo threads.lo tokenize.lo treeview.lo trigger.lo \
197197
update.lo userauth.lo upsert.lo util.lo vacuum.lo \
198-
vector.lo vectorfloat32.lo vectorfloat64.lo vectorfloat1bit.lo vectorfloat8.lo \
198+
vector.lo vectorfloat32.lo vectorfloat64.lo vectorfloat1bit.lo vectorfloat8.lo vectorfloat16.lo \
199199
vectorIndex.lo vectordiskann.lo vectorvtab.lo \
200200
vdbe.lo vdbeapi.lo vdbeaux.lo vdbeblob.lo vdbemem.lo vdbesort.lo \
201201
vdbetrace.lo vdbevtab.lo \
@@ -304,6 +304,7 @@ SRC = \
304304
$(TOP)/src/vector.c \
305305
$(TOP)/src/vectorInt.h \
306306
$(TOP)/src/vectorfloat1bit.c \
307+
$(TOP)/src/vectorfloat16.c \
307308
$(TOP)/src/vectorfloat32.c \
308309
$(TOP)/src/vectorfloat64.c \
309310
$(TOP)/src/vectorfloat8.c \
@@ -1143,6 +1144,9 @@ vector.lo: $(TOP)/src/vector.c $(HDR)
11431144
vectorfloat1bit.lo: $(TOP)/src/vectorfloat1bit.c $(HDR)
11441145
$(LTCOMPILE) $(TEMP_STORE) -c $(TOP)/src/vectorfloat1bit.c
11451146

1147+
vectorfloat16.lo: $(TOP)/src/vectorfloat16.c $(HDR)
1148+
$(LTCOMPILE) $(TEMP_STORE) -c $(TOP)/src/vectorfloat16.c
1149+
11461150
vectorfloat32.lo: $(TOP)/src/vectorfloat32.c $(HDR)
11471151
$(LTCOMPILE) $(TEMP_STORE) -c $(TOP)/src/vectorfloat32.c
11481152

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
/*
2+
* BUILD: cc test_libsql_f16.c -I ../ -L ../.libs -llibsql -o test_libsql_f16
3+
* RUN: LD_LIBRARY_PATH=../.libs ./test_libsql_f16
4+
*/
5+
6+
#include "assert.h"
7+
#include "stdbool.h"
8+
#include "stdarg.h"
9+
#include "stddef.h"
10+
#include "vectorfloat16.c"
11+
#include "test_libsql_f16_table.h"
12+
13+
/* Print a printf-style diagnostic message to stderr. */
#define eprintf(...) fprintf(stderr, __VA_ARGS__)
/*
** If `condition` is false, print the printf-style message to stderr and
** terminate the process with exit status 1.
**
** The body is wrapped in do { } while (0) so that `ensure(...);` is a single
** statement and can be used safely in an unbraced if/else.
*/
#define ensure(condition, ...) do { if (!(condition)) { eprintf(__VA_ARGS__); exit(1); } } while (0)
15+
16+
int main() {
17+
for(int i = 0; i < 65536; i++){
18+
u32 expected = F16ToF32[i];
19+
float actual = vectorF16ToFloat(i);
20+
u32 actual_u32 = *((u32*)&actual);
21+
ensure(expected == actual_u32, "conversion from %x failed: %f != %f (%x != %x)", i, *(float*)&expected, *(float*)&actual_u32, expected, actual_u32);
22+
}
23+
for(int i = 0; i < 65536; i++){
24+
u16 expected = F32ToF16[i];
25+
u16 actual = vectorF16FromFloat(*(float*)&F32[i]);
26+
ensure(expected == actual, "conversion from %x (%f, it=%d) failed: %x != %x", F32[i], *(float*)&F32[i], i, expected, actual);
27+
}
28+
}
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
"""Generate the C reference tables used by the float16 conversion test.

Writes three C array definitions to stdout:

* ``u32 F32[65536]``      - random 32-bit patterns, reinterpreted as f32 inputs
* ``u16 F32ToF16[65536]`` - the float16 bit pattern numpy yields for each input
* ``u32 F16ToF32[65536]`` - the f32 bit pattern for every possible float16 value
"""
import random
import struct
import numpy as np

TABLE_SIZE = 65536
VALUES_PER_ROW = 8


def random_u32_samples(count):
    """Return `count` random integers, each representable as an unsigned 32-bit value."""
    # random.randint() includes BOTH endpoints, so the upper bound has to be
    # 2**32 - 1; with 2**32 a sample could overflow u32 and break struct.pack('<I').
    return [random.randint(0, 2**32 - 1) for _ in range(count)]


def f32_bits_to_f16_bits(bits):
    """Reinterpret `bits` as a little-endian f32 and return the float16 bit pattern."""
    f32 = struct.unpack('<f', struct.pack('<I', bits))[0]
    # the narrowing conversion itself is delegated to numpy
    f16 = np.float16(f32)
    return struct.unpack('<H', struct.pack('<e', f16))[0]


def f16_bits_to_f32_bits(bits):
    """Reinterpret `bits` as a little-endian float16 and return the f32 bit pattern."""
    f16 = struct.unpack('<e', struct.pack('<H', bits))[0]
    return struct.unpack('<I', struct.pack('<f', f16))[0]


def print_table(declaration, values, suffix):
    """Print `values` as a C array initializer, VALUES_PER_ROW entries per line."""
    print("\n" + declaration + " = {\n")
    for i, value in enumerate(values):
        if i % VALUES_PER_ROW == 0:
            print(" ", end='')
        print('{:>10}{}, '.format(value, suffix), end='')
        if i % VALUES_PER_ROW == VALUES_PER_ROW - 1:
            print()
    print("};")


def main():
    u32_list = random_u32_samples(TABLE_SIZE)
    print_table("u32 F32[65536]", u32_list, "u")
    print_table("u16 F32ToF16[65536]", [f32_bits_to_f16_bits(x) for x in u32_list], "")
    print_table("u32 F16ToF32[65536]", [f16_bits_to_f32_bits(x) for x in range(TABLE_SIZE)], "u")


if __name__ == "__main__":
    main()

libsql-sqlite3/src/vector.c

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,8 @@ size_t vectorDataSize(VectorType type, VectorDims dims){
4545
return (dims + 7) / 8;
4646
case VECTOR_TYPE_FLOAT8:
4747
return ALIGN(dims, sizeof(float)) + sizeof(float) /* alpha */ + sizeof(float) /* shift */;
48+
case VECTOR_TYPE_FLOAT16:
49+
return dims * sizeof(u16);
4850
default:
4951
assert(0);
5052
}

libsql-sqlite3/src/vectorInt.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ typedef u32 VectorDims;
5353
#define VECTOR_TYPE_FLOAT64 2
5454
#define VECTOR_TYPE_FLOAT1BIT 3
5555
#define VECTOR_TYPE_FLOAT8 4
56+
#define VECTOR_TYPE_FLOAT16 5
5657

5758
#define VECTOR_FLAGS_STATIC 1
5859

@@ -80,6 +81,7 @@ void vectorInit(Vector *, VectorType, VectorDims, void *);
8081
*/
8182
void vectorDump (const Vector *v);
8283
void vectorF8Dump (const Vector *v);
84+
void vectorF16Dump (const Vector *v);
8385
void vectorF32Dump (const Vector *v);
8486
void vectorF64Dump (const Vector *v);
8587
void vector1BitDump(const Vector *v);
@@ -99,6 +101,7 @@ void vectorF64MarshalToText(sqlite3_context *, const Vector *);
99101
*/
100102
void vectorSerializeToBlob (const Vector *, unsigned char *, size_t);
101103
void vectorF8SerializeToBlob (const Vector *, unsigned char *, size_t);
104+
void vectorF16SerializeToBlob (const Vector *, unsigned char *, size_t);
102105
void vectorF32SerializeToBlob (const Vector *, unsigned char *, size_t);
103106
void vectorF64SerializeToBlob (const Vector *, unsigned char *, size_t);
104107
void vector1BitSerializeToBlob(const Vector *, unsigned char *, size_t);
@@ -108,6 +111,7 @@ void vector1BitSerializeToBlob(const Vector *, unsigned char *, size_t);
108111
*/
109112
float vectorDistanceCos (const Vector *, const Vector *);
110113
float vectorF8DistanceCos (const Vector *, const Vector *);
114+
float vectorF16DistanceCos (const Vector *, const Vector *);
111115
float vectorF32DistanceCos (const Vector *, const Vector *);
112116
double vectorF64DistanceCos(const Vector *, const Vector *);
113117

@@ -121,6 +125,7 @@ int vector1BitDistanceHamming(const Vector *, const Vector *);
121125
*/
122126
float vectorDistanceL2 (const Vector *, const Vector *);
123127
float vectorF8DistanceL2 (const Vector *, const Vector *);
128+
float vectorF16DistanceL2 (const Vector *, const Vector *);
124129
float vectorF32DistanceL2 (const Vector *, const Vector *);
125130
double vectorF64DistanceL2(const Vector *, const Vector *);
126131

@@ -137,6 +142,7 @@ void vectorSerializeWithMeta(sqlite3_context *, const Vector *);
137142
int vectorParseSqliteBlobWithType(sqlite3_value *, Vector *, char **);
138143

139144
void vectorF8DeserializeFromBlob (Vector *, const unsigned char *, size_t);
145+
void vectorF16DeserializeFromBlob (Vector *, const unsigned char *, size_t);
140146
void vectorF32DeserializeFromBlob (Vector *, const unsigned char *, size_t);
141147
void vectorF64DeserializeFromBlob (Vector *, const unsigned char *, size_t);
142148
void vector1BitDeserializeFromBlob(Vector *, const unsigned char *, size_t);

libsql-sqlite3/src/vectorfloat16.c

Lines changed: 190 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,190 @@
1+
/*
2+
** 2024-07-04
3+
**
4+
** Copyright 2024 the libSQL authors
5+
**
6+
** Permission is hereby granted, free of charge, to any person obtaining a copy of
7+
** this software and associated documentation files (the "Software"), to deal in
8+
** the Software without restriction, including without limitation the rights to
9+
** use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
10+
** the Software, and to permit persons to whom the Software is furnished to do so,
11+
** subject to the following conditions:
12+
**
13+
** The above copyright notice and this permission notice shall be included in all
14+
** copies or substantial portions of the Software.
15+
**
16+
** THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17+
** IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
18+
** FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
19+
** COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
20+
** IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
21+
** CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
22+
**
23+
******************************************************************************
24+
**
25+
** 16-bit (FLOAT16) floating point vector format utilities.
26+
*/
27+
#ifndef SQLITE_OMIT_VECTOR
28+
#include "sqliteInt.h"
29+
30+
#include "vectorInt.h"
31+
32+
#include <math.h>
33+
34+
/**************************************************************************
35+
** Utility routines for vector serialization and deserialization
36+
**************************************************************************/
37+
38+
// f32: [fffffffffffffffffffffffeeeeeeees]
39+
// 01234567890123456789012345678901
40+
// f16: [ffffffffffeeeees]
41+
// 0123456789012345
42+
43+
/*
** Widen a 16-bit half-precision bit pattern (1 sign, 5 exponent and
** 10 mantissa bits) into a 32-bit float.
**
** Infinity maps to infinity, any NaN maps to a NaN whose mantissa has only
** the high bit set, zero keeps its sign, and f16 denormalized values are
** renormalized because they are representable as normal f32 values.
*/
static float vectorF16ToFloat(u16 f16){
  u32 f32;
  float result;
  // sgn: [0000000000000000000000000000000s]
  u32 sgn = ((u32)f16 & 0x8000) << 16;

  int expBits = (f16 >> 10) & 0x1f;
  int exp = expBits - 15; // 15 is exp bias for f16

  u32 mnt = ((u32)f16 & 0x3ff);
  u32 mntNonZero = !!mnt;

  if( exp == 16 ){ // NaN or +/- Infinity
    exp = 128, mnt = mntNonZero << 22; // set mnt high bit to represent NaN if it was NaN in f16
  }else if( exp == -15 && mnt == 0 ){ // zero
    exp = -127, mnt = 0;
  }else if( exp == -15 ){ // denormalized value
    // shift mantissa until we get 1 as a high bit
    exp++;
    while( (mnt & 0x400) == 0 ){
      mnt <<= 1;
      exp--;
    }
    // then reset high bit as this will be normal value (not denormalized) in f32
    mnt &= 0x3ff;
    mnt <<= 13;
  }else{
    mnt <<= 13;
  }
  f32 = sgn | ((u32)(exp + 127) << 23) | mnt;
  // copy the bits instead of dereferencing a casted pointer:
  // reading a u32 object through a float* violates the strict aliasing rule
  memcpy(&result, &f32, sizeof(result));
  return result;
}
74+
75+
static u16 vectorF16FromFloat(float f){
76+
u32 i = *((u32*)&f);
77+
78+
// sng: [000000000000000s]
79+
u32 sgn = (i >> 16) & (0x8000);
80+
81+
// expBits: [eeeeeeee]
82+
int expBits = (i >> 23) & (0xff);
83+
int exp = expBits - 127; // 127 is exp bias for f32
84+
85+
// mntBits: [fffffffffffffffffffffff]
86+
u32 mntBits = (i & 0x7fffff);
87+
u32 mntNonZero = !!mntBits;
88+
u32 mnt;
89+
90+
if( exp == 128 ){ // NaN or +/- Infinity
91+
exp = 16, mntBits = mntNonZero << 22; // set mnt high bit to represent NaN if it was NaN in f32
92+
}else if( exp > 15 ){ // just too big numbers for f16
93+
exp = 16, mntBits = 0;
94+
}else if( exp < -14 && exp >= -25 ){ // small value, but we can be represented as denormalized f16
95+
// set high bit to 1 as normally mantissa has form 1.[mnt] but denormalized mantissa has form 0.[mnt]
96+
mntBits = (mntBits | 0x800000) >> (-exp - 14);
97+
exp = -15;
98+
}else if( exp < -24 ){ // very small or denormalized value
99+
exp = -15, mntBits = 0;
100+
}
101+
// round to nearest, ties to even
102+
if( (mntBits & 0x1fff) > (0x1000 - ((mntBits >> 13) & 1)) ){
103+
mntBits += 0x2000;
104+
}
105+
mnt = mntBits >> 13;
106+
107+
// handle overflow here (note, that overflow can happen only if exp < 16)
108+
return sgn | ((u32)(exp + 15 + (mnt >> 10)) << 10) | (mnt & 0x3ff);
109+
}
110+
111+
void vectorF16Dump(const Vector *pVec){
112+
u16 *elems = pVec->data;
113+
unsigned i;
114+
115+
assert( pVec->type == VECTOR_TYPE_FLOAT16 );
116+
117+
printf("f16: [");
118+
for(i = 0; i < pVec->dims; i++){
119+
printf("%s%f", i == 0 ? "" : ", ", vectorF16ToFloat(elems[i]));
120+
}
121+
printf("]\n");
122+
}
123+
124+
void vectorF16SerializeToBlob(
125+
const Vector *pVector,
126+
unsigned char *pBlob,
127+
size_t nBlobSize
128+
){
129+
float alpha, shift;
130+
131+
assert( pVector->type == VECTOR_TYPE_FLOAT16 );
132+
assert( pVector->dims <= MAX_VECTOR_SZ );
133+
assert( nBlobSize >= vectorDataSize(pVector->type, pVector->dims) );
134+
135+
memcpy(pBlob, pVector->data, pVector->dims * sizeof(u16));
136+
}
137+
138+
float vectorF16DistanceCos(const Vector *v1, const Vector *v2){
139+
int i;
140+
float dot = 0, norm1 = 0, norm2 = 0;
141+
float value1, value2;
142+
u16 *data1 = v1->data, *data2 = v2->data;
143+
144+
assert( v1->dims == v2->dims );
145+
assert( v1->type == VECTOR_TYPE_FLOAT16 );
146+
assert( v2->type == VECTOR_TYPE_FLOAT16 );
147+
148+
for(i = 0; i < v1->dims; i++){
149+
value1 = vectorF16ToFloat(data1[i]);
150+
value2 = vectorF16ToFloat(data2[i]);
151+
dot += value1*value2;
152+
norm1 += value1*value1;
153+
norm2 += value2*value2;
154+
}
155+
156+
return 1.0 - (dot / sqrt(norm1 * norm2));
157+
}
158+
159+
float vectorF16DistanceL2(const Vector *v1, const Vector *v2){
160+
int i;
161+
float sum = 0;
162+
float value1, value2;
163+
u8 *data1 = v1->data, *data2 = v2->data;
164+
165+
assert( v1->dims == v2->dims );
166+
assert( v1->type == VECTOR_TYPE_FLOAT16 );
167+
assert( v2->type == VECTOR_TYPE_FLOAT16 );
168+
169+
for(i = 0; i < v1->dims; i++){
170+
value1 = vectorF16ToFloat(data1[i]);
171+
value2 = vectorF16ToFloat(data2[i]);
172+
float d = (value1 - value2);
173+
sum += d*d;
174+
}
175+
return sqrt(sum);
176+
}
177+
178+
void vectorF16DeserializeFromBlob(
179+
Vector *pVector,
180+
const unsigned char *pBlob,
181+
size_t nBlobSize
182+
){
183+
assert( pVector->type == VECTOR_TYPE_FLOAT16 );
184+
assert( 0 <= pVector->dims && pVector->dims <= MAX_VECTOR_SZ );
185+
assert( nBlobSize >= vectorDataSize(pVector->type, pVector->dims) );
186+
187+
memcpy((u8*)pVector->data, (u8*)pBlob, pVector->dims * sizeof(u16));
188+
}
189+
190+
#endif /* !defined(SQLITE_OMIT_VECTOR) */

libsql-sqlite3/tool/mksqlite3c.tcl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -473,6 +473,7 @@ set flist {
473473
vectorfloat32.c
474474
vectorfloat64.c
475475
vectorfloat8.c
476+
vectorfloat16.c
476477
vectorIndex.c
477478
vectorvtab.c
478479
rtree.c

0 commit comments

Comments
 (0)