-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathhelpers.go
More file actions
440 lines (402 loc) · 12.8 KB
/
helpers.go
File metadata and controls
440 lines (402 loc) · 12.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
package params
//oapi-runtime:function params/ParamHelpers
import (
"bytes"
"encoding"
"encoding/base64"
"encoding/json"
"fmt"
"net/url"
"reflect"
"sort"
"strconv"
"strings"
"time"
"github.com/google/uuid"
"github.com/oapi-codegen/oapi-codegen-exp/codegen/internal/runtime/types"
)
// ParamLocation indicates where a parameter is located in an HTTP request.
// The escaping helpers in this file switch on it to choose between
// query-style escaping, path-style escaping, or leaving a value untouched.
type ParamLocation int
// Known parameter locations. The values are assigned with iota, so
// ParamLocationUndefined is the zero value of ParamLocation.
const (
// ParamLocationUndefined is the zero value. unescapeParameterString treats
// it like ParamLocationQuery; escapeParameterString leaves values unchanged.
ParamLocationUndefined ParamLocation = iota
// ParamLocationQuery selects query-string escaping (url.QueryEscape /
// url.QueryUnescape) in the helpers of this file.
ParamLocationQuery
// ParamLocationPath selects path-segment escaping (url.PathEscape /
// url.PathUnescape) in the helpers of this file.
ParamLocationPath
// ParamLocationHeader values are passed through unescaped by the helpers
// of this file.
ParamLocationHeader
// ParamLocationCookie values are passed through unescaped by the helpers
// of this file.
ParamLocationCookie
)
// Binder is an interface for types that can bind themselves from a string value.
// BindStringToObject calls Bind on a destination that implements Binder (after
// first checking for encoding.TextUnmarshaler) and returns its result as-is.
type Binder interface {
// Bind populates the receiver from value. Presumably a non-nil error means
// value could not be interpreted — confirm against implementers.
Bind(value string) error
}
// MissingRequiredParameterError reports that a required parameter was absent
// from the request. Code higher in the stack can match it with errors.As and
// translate it into a suitable HTTP error response.
type MissingRequiredParameterError struct {
	// ParamName is the name of the parameter that was not supplied.
	ParamName string
}

// Error implements the error interface and names the missing parameter.
func (e *MissingRequiredParameterError) Error() string {
	const format = "parameter '%s' is required"
	return fmt.Sprintf(format, e.ParamName)
}
// primitiveToString converts a primitive value to its string representation.
//
// Values recognised by marshalKnownTypes are formatted first. Otherwise one
// level of pointer indirection is removed and the value is formatted by kind:
// integers in base 10, floats in the shortest 'f' form that round-trips,
// booleans as "true"/"false", and strings verbatim. A struct implementing
// json.Marshaler is marshalled, decoded again (numbers kept as json.Number)
// and the decoded value is formatted recursively. Anything else falls back to
// fmt.Stringer.
//
// An error is returned for a nil interface, a nil pointer, or any type not
// covered above; the function does not panic on nil input.
func primitiveToString(value any) (string, error) {
	// A nil interface or nil pointer yields an invalid reflect.Value after
	// Indirect, and calling Type() on it would panic. Reject it up front,
	// before marshalKnownTypes gets a chance to do the same reflection.
	rv := reflect.ValueOf(value)
	if !rv.IsValid() || (rv.Kind() == reflect.Ptr && rv.IsNil()) {
		return "", fmt.Errorf("unsupported nil value of type %T", value)
	}

	// Check for known types first (time, date, uuid).
	if res, ok := marshalKnownTypes(value); ok {
		return res, nil
	}

	// Dereference pointers for optional values.
	v := reflect.Indirect(rv)

	switch v.Kind() {
	case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int:
		return strconv.FormatInt(v.Int(), 10), nil
	case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint:
		return strconv.FormatUint(v.Uint(), 10), nil
	case reflect.Float64:
		return strconv.FormatFloat(v.Float(), 'f', -1, 64), nil
	case reflect.Float32:
		return strconv.FormatFloat(v.Float(), 'f', -1, 32), nil
	case reflect.Bool:
		if v.Bool() {
			return "true", nil
		}
		return "false", nil
	case reflect.String:
		return v.String(), nil
	case reflect.Struct:
		// UUIDs are array-kinded and already handled by marshalKnownTypes,
		// so only the json.Marshaler route needs checking here.
		if m, ok := value.(json.Marshaler); ok {
			buf, err := m.MarshalJSON()
			if err != nil {
				return "", fmt.Errorf("failed to marshal to JSON: %w", err)
			}
			// UseNumber keeps numeric literals exactly as marshalled
			// instead of routing them through float64.
			dec := json.NewDecoder(bytes.NewReader(buf))
			dec.UseNumber()
			var decoded any
			if err = dec.Decode(&decoded); err != nil {
				return "", fmt.Errorf("failed to decode JSON: %w", err)
			}
			return primitiveToString(decoded)
		}
		fallthrough
	default:
		if s, ok := value.(fmt.Stringer); ok {
			return s.String(), nil
		}
		return "", fmt.Errorf("unsupported type %s", reflect.TypeOf(value).String())
	}
}
// marshalKnownTypes formats values whose type is, or is convertible to,
// time.Time, types.Date or uuid.UUID. One level of pointer indirection is
// removed first.
//
// It returns the formatted string and true when value matched one of those
// types: time.Time uses time.RFC3339Nano, types.Date uses types.DateFormat and
// uuid.UUID uses its String form. It returns "", false for every other value,
// including a nil interface or nil pointer.
func marshalKnownTypes(value any) (string, bool) {
	// reflect.Indirect on a nil pointer (or ValueOf(nil)) produces an invalid
	// Value whose Type() panics; nil is simply "not a known type".
	rv := reflect.ValueOf(value)
	if !rv.IsValid() || (rv.Kind() == reflect.Ptr && rv.IsNil()) {
		return "", false
	}
	v := reflect.Indirect(rv)
	t := v.Type()

	// convertible additionally requires matching kinds. ConvertibleTo alone
	// reports true for slice-to-array conversions (Go 1.20+), so a short
	// []byte would pass the check for the array-typed uuid.UUID and then
	// panic inside Convert.
	convertible := func(target reflect.Type) bool {
		return t.Kind() == target.Kind() && t.ConvertibleTo(target)
	}

	if timeType := reflect.TypeOf(time.Time{}); convertible(timeType) {
		timeVal := v.Convert(timeType).Interface().(time.Time)
		return timeVal.Format(time.RFC3339Nano), true
	}

	if dateType := reflect.TypeOf(types.Date{}); convertible(dateType) {
		dateVal := v.Convert(dateType).Interface().(types.Date)
		return dateVal.Format(types.DateFormat), true
	}

	if uuidType := reflect.TypeOf(uuid.UUID{}); convertible(uuidType) {
		uuidVal := v.Convert(uuidType).Interface().(uuid.UUID)
		return uuidVal.String(), true
	}

	return "", false
}
// escapeParameterName percent-encodes a parameter name for the given
// location, so that names such as user_ids[] are safe to place in a query
// string or path per RFC 3986.
func escapeParameterName(name string, paramLocation ParamLocation) string {
	// allowReserved is an OpenAPI option that applies to values only, so
	// names are always escaped with it switched off.
	const allowReserved = false
	return escapeParameterString(name, paramLocation, allowReserved)
}
// escapeParameterString escapes value for the location it will be placed in.
//
// Path values are escaped with url.PathEscape. Query values are escaped with
// url.QueryEscape, unless allowReserved is set, in which case RFC 3986
// reserved characters are kept as-is per OpenAPI's allowReserved option.
// Every other location (header, cookie, undefined) returns value unchanged.
func escapeParameterString(value string, paramLocation ParamLocation, allowReserved bool) string {
	if paramLocation == ParamLocationPath {
		return url.PathEscape(value)
	}
	if paramLocation != ParamLocationQuery {
		// Headers and cookies are not URL-escaped.
		return value
	}
	if allowReserved {
		return escapeQueryAllowReserved(value)
	}
	return url.QueryEscape(value)
}
// escapeQueryAllowReserved percent-encodes a query parameter value byte by
// byte. RFC 3986 unreserved characters and reserved characters
// (:/?#[]@!$&'()*+,;=) are copied through untouched, as OpenAPI's
// allowReserved option requires; every other byte becomes an upper-case %XX
// triplet.
func escapeQueryAllowReserved(value string) string {
	const reserved = `:/?#[]@!$&'()*+,;=`

	var out strings.Builder
	for i := 0; i < len(value); i++ {
		c := value[i]
		switch {
		case isUnreserved(c), strings.IndexByte(reserved, c) >= 0:
			out.WriteByte(c)
		default:
			fmt.Fprintf(&out, "%%%02X", c)
		}
	}
	return out.String()
}
// isUnreserved reports whether c belongs to the RFC 3986 "unreserved" set:
// ALPHA / DIGIT / "-" / "." / "_" / "~".
func isUnreserved(c byte) bool {
	// Letters and digits.
	switch {
	case 'A' <= c && c <= 'Z', 'a' <= c && c <= 'z', '0' <= c && c <= '9':
		return true
	}
	// The four unreserved punctuation marks.
	switch c {
	case '-', '.', '_', '~':
		return true
	}
	return false
}
// unescapeParameterString reverses URL escaping according to where the value
// came from. Path values go through url.PathUnescape; query values, and values
// whose location is undefined, go through url.QueryUnescape. Values from any
// other location are returned unchanged with a nil error.
func unescapeParameterString(value string, paramLocation ParamLocation) (string, error) {
	if paramLocation == ParamLocationPath {
		return url.PathUnescape(value)
	}
	if paramLocation == ParamLocationQuery || paramLocation == ParamLocationUndefined {
		return url.QueryUnescape(value)
	}
	return value, nil
}
// sortedKeys collects the keys of m and returns them in ascending order. The
// result is always a non-nil slice, even for an empty or nil map.
func sortedKeys(m map[string]string) []string {
	out := make([]string, len(m))
	next := 0
	for name := range m {
		out[next] = name
		next++
	}
	sort.Strings(out)
	return out
}
// BindStringToObject parses src and stores the result in dst.
//
// Resolution order:
//  1. If dst implements encoding.TextUnmarshaler, UnmarshalText is used.
//  2. If dst implements Binder, Bind is used.
//  3. Otherwise dst must be a pointer. Strings are stored verbatim; signed and
//     unsigned integers, floats and booleans are parsed with strconv.
//  4. Any other pointee kind is handed to json.Unmarshal as a fallback.
//
// Numbers are parsed with the bit size of the destination type, so a value
// that does not fit (for example "300" into an int8, or "1e40" into a
// float32) is reported as an error instead of being silently truncated.
//
// It returns an error when dst is not a pointer, when parsing fails, or when
// the delegated unmarshaller fails.
func BindStringToObject(src string, dst any) error {
	// Check for TextUnmarshaler
	if tu, ok := dst.(encoding.TextUnmarshaler); ok {
		return tu.UnmarshalText([]byte(src))
	}

	// Check for Binder interface
	if b, ok := dst.(Binder); ok {
		return b.Bind(src)
	}

	v := reflect.ValueOf(dst)
	if v.Kind() != reflect.Ptr {
		return fmt.Errorf("dst must be a pointer, got %T", dst)
	}
	v = v.Elem()

	switch v.Kind() {
	case reflect.String:
		v.SetString(src)
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		// Parsing at the destination width makes strconv report ErrRange;
		// SetInt alone would just drop the high bits.
		i, err := strconv.ParseInt(src, 10, v.Type().Bits())
		if err != nil {
			return fmt.Errorf("failed to parse int: %w", err)
		}
		v.SetInt(i)
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		u, err := strconv.ParseUint(src, 10, v.Type().Bits())
		if err != nil {
			return fmt.Errorf("failed to parse uint: %w", err)
		}
		v.SetUint(u)
	case reflect.Float32, reflect.Float64:
		// Bit size 32 both range-checks and rounds correctly for float32.
		f, err := strconv.ParseFloat(src, v.Type().Bits())
		if err != nil {
			return fmt.Errorf("failed to parse float: %w", err)
		}
		v.SetFloat(f)
	case reflect.Bool:
		b, err := strconv.ParseBool(src)
		if err != nil {
			return fmt.Errorf("failed to parse bool: %w", err)
		}
		v.SetBool(b)
	default:
		// Try JSON unmarshal as a fallback
		return json.Unmarshal([]byte(src), dst)
	}
	return nil
}
// bindSplitPartsToDestinationArray fills the slice that dest refers to with
// one element per entry of parts, converting each entry with
// BindStringToObject. The destination is only assigned once every element has
// been bound; on the first failure it is left untouched and the wrapped error
// is returned.
func bindSplitPartsToDestinationArray(parts []string, dest any) error {
	target := reflect.Indirect(reflect.ValueOf(dest))
	count := len(parts)

	// Build the result in a fresh slice of the destination's own type.
	bound := reflect.MakeSlice(target.Type(), count, count)
	for idx := range parts {
		elemPtr := bound.Index(idx).Addr().Interface()
		if err := BindStringToObject(parts[idx], elemPtr); err != nil {
			return fmt.Errorf("error setting array element: %w", err)
		}
	}

	target.Set(bound)
	return nil
}
// bindSplitPartsToDestinationStruct binds string parts to a destination
// struct by assembling a JSON object whose members are all strings and
// unmarshalling it into dest.
//
// When explode is true each part must be "key=value"; the part is split at the
// first '=' only, so values may themselves contain '='. When explode is false
// parts must alternate key, value, key, value.
//
// Keys and values are quoted with encoding/json, so characters such as '"' or
// '\' coming from the request are carried as data and cannot alter the
// structure of the generated JSON. Members are emitted in input order.
//
// It returns an error naming paramName when the parts are malformed, or the
// json.Unmarshal error when the object does not fit dest.
func bindSplitPartsToDestinationStruct(paramName string, parts []string, explode bool, dest any) error {
	fields := make([]string, 0, len(parts))

	// addField appends one `"key":"value"` member with proper JSON quoting.
	addField := func(key, value string) error {
		k, err := json.Marshal(key)
		if err != nil {
			return fmt.Errorf("parameter '%s' has an unencodable property name: %w", paramName, err)
		}
		v, err := json.Marshal(value)
		if err != nil {
			return fmt.Errorf("parameter '%s' has an unencodable property value: %w", paramName, err)
		}
		fields = append(fields, string(k)+":"+string(v))
		return nil
	}

	if explode {
		for _, property := range parts {
			key, value, ok := strings.Cut(property, "=")
			if !ok {
				return fmt.Errorf("parameter '%s' has invalid exploded format", paramName)
			}
			if err := addField(key, value); err != nil {
				return err
			}
		}
	} else {
		if len(parts)%2 != 0 {
			return fmt.Errorf("parameter '%s' has invalid format, property/values need to be pairs", paramName)
		}
		for i := 0; i < len(parts); i += 2 {
			if err := addField(parts[i], parts[i+1]); err != nil {
				return err
			}
		}
	}

	jsonParam := "{" + strings.Join(fields, ",") + "}"
	return json.Unmarshal([]byte(jsonParam), dest)
}
// splitStyledParameter splits a styled parameter string value into parts based
// on the OpenAPI style.
//
// Supported styles are "simple", "label", "matrix" and "form"; explode selects
// the exploded variant of the style. The object flag indicates whether the
// destination is a struct/map: when it is false, the exploded matrix and form
// styles strip the leading "paramName=" from every part; when it is true the
// "key=value" parts are returned intact.
//
// It returns an error when value lacks the prefix the style requires
// (including an empty value for the label and matrix styles) or when style is
// not one of the supported names.
func splitStyledParameter(style string, explode bool, object bool, paramName string, value string) ([]string, error) {
	switch style {
	case "simple":
		// In the simple case, we always split on comma
		return strings.Split(value, ","), nil
	case "label":
		if explode {
			// Exploded: .a.b.c or .key=value.key=value
			parts := strings.Split(value, ".")
			if parts[0] != "" {
				return nil, fmt.Errorf("invalid format for label parameter '%s', should start with '.'", paramName)
			}
			return parts[1:], nil
		}
		// Unexploded: .a,b,c
		// HasPrefix rather than value[0] so an empty value is rejected
		// instead of indexing out of range.
		if !strings.HasPrefix(value, ".") {
			return nil, fmt.Errorf("invalid format for label parameter '%s', should start with '.'", paramName)
		}
		return strings.Split(value[1:], ","), nil
	case "matrix":
		if explode {
			// Exploded: ;a;b;c or ;key=value;key=value
			parts := strings.Split(value, ";")
			if parts[0] != "" {
				return nil, fmt.Errorf("invalid format for matrix parameter '%s', should start with ';'", paramName)
			}
			parts = parts[1:]
			if !object {
				prefix := paramName + "="
				for i := range parts {
					parts[i] = strings.TrimPrefix(parts[i], prefix)
				}
			}
			return parts, nil
		}
		// Unexploded: ;paramName=a,b,c
		prefix := ";" + paramName + "="
		if !strings.HasPrefix(value, prefix) {
			return nil, fmt.Errorf("expected parameter '%s' to start with %s", paramName, prefix)
		}
		return strings.Split(strings.TrimPrefix(value, prefix), ","), nil
	case "form":
		if explode {
			// Exploded: paramName=a&paramName=b or key=value&key=value
			parts := strings.Split(value, "&")
			if !object {
				prefix := paramName + "="
				for i := range parts {
					parts[i] = strings.TrimPrefix(parts[i], prefix)
				}
			}
			return parts, nil
		}
		// Unexploded: paramName=a,b,c
		parts := strings.Split(value, ",")
		prefix := paramName + "="
		for i := range parts {
			parts[i] = strings.TrimPrefix(parts[i], prefix)
		}
		return parts, nil
	}
	return nil, fmt.Errorf("unhandled parameter style: %s", style)
}
// findRawQueryParam scans a raw (still percent-encoded) query string for every
// occurrence of paramName. Keys are decoded before being compared, while the
// collected values are returned exactly as they appear in rawQuery. Empty
// segments and segments whose key cannot be decoded are ignored. A key without
// '=' contributes an empty value. found reports whether any segment matched.
func findRawQueryParam(rawQuery, paramName string) (values []string, found bool) {
	remaining := rawQuery
	for remaining != "" {
		var pair string
		pair, remaining, _ = strings.Cut(remaining, "&")
		if pair == "" {
			continue
		}

		rawKey, rawVal, _ := strings.Cut(pair, "=")
		name, err := url.QueryUnescape(rawKey)
		if err != nil || name != paramName {
			// Malformed keys are skipped, as are other parameters.
			continue
		}

		values = append(values, rawVal)
		found = true
	}
	return values, found
}
// isByteSlice reports whether t is a slice type whose element kind is uint8,
// which covers []byte, []uint8 and slices of named uint8-kinded types.
func isByteSlice(t reflect.Type) bool {
	if t.Kind() != reflect.Slice {
		return false
	}
	return t.Elem().Kind() == reflect.Uint8
}
// base64Decode decodes s as base64, picking exactly one decoder from the
// shape of the input instead of trying several in turn.
//
// OpenAPI 3.0 "format: byte" means RFC 4648 Section 4 (standard alphabet with
// padding), but unpadded and URL-safe input is accepted as well:
//   - '=' present selects a padded encoding, otherwise a raw (unpadded) one;
//   - '-' or '_' present selects the URL-safe alphabet, otherwise the standard
//     one.
//
// Choosing up front avoids the corrupt output a cascade can produce when
// RawStdEncoding silently accepts padded input and treats '=' as data. An
// empty string decodes to an empty, non-nil slice.
func base64Decode(s string) ([]byte, error) {
	if len(s) == 0 {
		return []byte{}, nil
	}

	padded := strings.ContainsRune(s, '=')
	urlSafe := strings.ContainsAny(s, "-_")

	var enc *base64.Encoding
	switch {
	case padded && urlSafe:
		enc = base64.URLEncoding
	case padded:
		enc = base64.StdEncoding
	case urlSafe:
		enc = base64.RawURLEncoding
	default:
		enc = base64.RawStdEncoding
	}
	return base64Decode1(enc, s)
}

// base64Decode1 decodes s with enc and, on failure, wraps the decoder's error
// together with the offending input.
func base64Decode1(enc *base64.Encoding, s string) ([]byte, error) {
	decoded, err := enc.DecodeString(s)
	if err == nil {
		return decoded, nil
	}
	return nil, fmt.Errorf("failed to base64-decode string %q: %w", s, err)
}