Skip to content

Commit d86739a

Browse files
committed
Improve request validation
Allow passing through openapi3filter Options and parameter decoder, as well as passing through the echo context to inner kin-openapi3 callbacks. This allows for implementing proper request validation.
1 parent 34f5cb8 commit d86739a

2 files changed

Lines changed: 134 additions & 11 deletions

File tree

oapi_validate.go

Lines changed: 58 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,17 @@ package middleware
1717
import (
1818
"context"
1919
"fmt"
20+
"io/ioutil"
21+
"net/http"
22+
2023
"github.com/getkin/kin-openapi/openapi3"
2124
"github.com/getkin/kin-openapi/openapi3filter"
2225
"github.com/labstack/echo/v4"
23-
"io/ioutil"
24-
"net/http"
2526
)
2627

28+
const EchoContextKey = "oapi-codegen/echo-context"
29+
const UserDataKey = "oapi-codegen/user-data"
30+
2731
// This is an Echo middleware function which validates incoming HTTP requests
2832
// to make sure that they conform to the given OAPI 3.0 specification. When
2933
// OAPI validation failes on the request, we return an HTTP/400.
@@ -45,10 +49,23 @@ func OapiValidatorFromYamlFile(path string) (echo.MiddlewareFunc, error) {
4549

4650
// Create a validator from a swagger object.
4751
func OapiRequestValidator(swagger *openapi3.Swagger) echo.MiddlewareFunc {
52+
return OapiRequestValidatorWithOptions(swagger, nil)
53+
}
54+
55+
// Options to customize request validation. These are passed through to
56+
// openapi3filter.
57+
type Options struct {
58+
Options openapi3filter.Options
59+
ParamDecoder openapi3filter.ContentParameterDecoder
60+
UserData interface{}
61+
}
62+
63+
// Create a validator from a swagger object, with validation options
64+
func OapiRequestValidatorWithOptions(swagger *openapi3.Swagger, options *Options) echo.MiddlewareFunc {
4865
router := openapi3filter.NewRouter().WithSwagger(swagger)
4966
return func(next echo.HandlerFunc) echo.HandlerFunc {
5067
return func(c echo.Context) error {
51-
err := ValidateRequestFromContext(c, router)
68+
err := ValidateRequestFromContext(c, router, options)
5269
if err != nil {
5370
return err
5471
}
@@ -59,7 +76,7 @@ func OapiRequestValidator(swagger *openapi3.Swagger) echo.MiddlewareFunc {
5976

6077
// This function is called from the middleware above and actually does the work
6178
// of validating a request.
62-
func ValidateRequestFromContext(ctx echo.Context, router *openapi3filter.Router) error {
79+
func ValidateRequestFromContext(ctx echo.Context, router *openapi3filter.Router, options *Options) error {
6380
req := ctx.Request()
6481
route, pathParams, err := router.FindRoute(req.Method, req.URL)
6582

@@ -78,17 +95,30 @@ func ValidateRequestFromContext(ctx echo.Context, router *openapi3filter.Router)
7895
}
7996
}
8097

81-
err = openapi3filter.ValidateRequest(context.Background(),
82-
&openapi3filter.RequestValidationInput{
83-
Request: req,
84-
PathParams: pathParams,
85-
Route: route,
86-
})
98+
validationInput := &openapi3filter.RequestValidationInput{
99+
Request: req,
100+
PathParams: pathParams,
101+
Route: route,
102+
}
103+
104+
// Pass the Echo context into the request validator, so that any callbacks
105+
// which it invokes make it available.
106+
requestContext := context.WithValue(context.Background(), EchoContextKey, ctx)
107+
108+
if options != nil {
109+
validationInput.Options = &options.Options
110+
validationInput.ParamDecoder = options.ParamDecoder
111+
requestContext = context.WithValue(requestContext, UserDataKey, options.UserData)
112+
}
113+
114+
err = openapi3filter.ValidateRequest(requestContext, validationInput)
87115
if err != nil {
88116
switch e := err.(type) {
89117
case *openapi3filter.RequestError:
90118
// We've got a bad request
91119
return echo.NewHTTPError(http.StatusBadRequest, e.Reason)
120+
case *openapi3filter.SecurityRequirementsError:
121+
return echo.NewHTTPError(http.StatusForbidden, e.Error())
92122
default:
93123
// This should never happen today, but if our upstream code changes,
94124
// we don't want to crash the server, so handle the unexpected error.
@@ -98,3 +128,21 @@ func ValidateRequestFromContext(ctx echo.Context, router *openapi3filter.Router)
98128
}
99129
return nil
100130
}
131+
132+
// Helper function to get the echo context from within requests. It returns
133+
// nil if not found or wrong type.
134+
func GetEchoContext(c context.Context) echo.Context {
135+
iface := c.Value(EchoContextKey)
136+
if iface == nil {
137+
return nil
138+
}
139+
eCtx, ok := iface.(echo.Context)
140+
if !ok {
141+
return nil
142+
}
143+
return eCtx
144+
}
145+
146+
func GetUserData(c context.Context) interface{} {
147+
return c.Value(UserDataKey)
148+
}

oapi_validate_test.go

Lines changed: 76 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,14 @@
1515
package middleware
1616

1717
import (
18+
"context"
19+
"errors"
1820
"net/http"
1921
"net/http/httptest"
2022
"testing"
2123

2224
"github.com/getkin/kin-openapi/openapi3"
25+
"github.com/getkin/kin-openapi/openapi3filter"
2326
"github.com/labstack/echo/v4"
2427
"github.com/stretchr/testify/assert"
2528

@@ -66,6 +69,30 @@ paths:
6669
properties:
6770
name:
6871
type: string
72+
/protected_resource:
73+
get:
74+
operationId: getProtectedResource
75+
security:
76+
- BearerAuth:
77+
- someScope
78+
responses:
79+
'204':
80+
description: no content
81+
/protected_resource2:
82+
get:
83+
operationId: getProtectedResource
84+
security:
85+
- BearerAuth:
86+
- otherScope
87+
responses:
88+
'204':
89+
description: no content
90+
components:
91+
securitySchemes:
92+
BearerAuth:
93+
type: http
94+
scheme: bearer
95+
bearerFormat: JWT
6996
`
7097

7198
func doGet(t *testing.T, e *echo.Echo, url string) *httptest.ResponseRecorder {
@@ -85,8 +112,30 @@ func TestOapiRequestValidator(t *testing.T) {
85112
// Create a new echo router
86113
e := echo.New()
87114

115+
// Set up an authenticator to check authenticated function. It will allow
116+
// access to "someScope", but disallow others.
117+
options := Options{
118+
Options: openapi3filter.Options{
119+
AuthenticationFunc: func(c context.Context, input *openapi3filter.AuthenticationInput) error {
120+
// The echo context should be propagated into here.
121+
eCtx := GetEchoContext(c)
122+
assert.NotNil(t, eCtx)
123+
// As should user data
124+
assert.EqualValues(t, "hi!", GetUserData(c))
125+
126+
for _, s := range input.Scopes {
127+
if s == "someScope" {
128+
return nil
129+
}
130+
}
131+
return errors.New("forbidden")
132+
},
133+
},
134+
UserData:"hi!",
135+
}
136+
88137
// Install our OpenApi based request validator
89-
e.Use(OapiRequestValidator(swagger))
138+
e.Use(OapiRequestValidatorWithOptions(swagger, &options))
90139

91140
called := false
92141

@@ -159,4 +208,30 @@ func TestOapiRequestValidator(t *testing.T) {
159208
assert.False(t, called, "Handler should not have been called")
160209
called = false
161210
}
211+
212+
e.GET("/protected_resource", func(c echo.Context) error {
213+
called = true
214+
return c.NoContent(http.StatusNoContent)
215+
216+
})
217+
218+
// Call a protected function to which we have access
219+
{
220+
rec := doGet(t, e, "http://deepmap.ai/protected_resource")
221+
assert.Equal(t, http.StatusNoContent, rec.Code)
222+
assert.True(t, called, "Handler should have been called")
223+
called = false
224+
}
225+
226+
e.GET("/protected_resource2", func(c echo.Context) error {
227+
called = true
228+
return c.NoContent(http.StatusNoContent)
229+
})
230+
// Call a protected function to which we dont have access
231+
{
232+
rec := doGet(t, e, "http://deepmap.ai/protected_resource2")
233+
assert.Equal(t, http.StatusForbidden, rec.Code)
234+
assert.False(t, called, "Handler should not have been called")
235+
called = false
236+
}
162237
}

0 commit comments

Comments
 (0)