Skip to content

Commit 88623b8

Browse files
authored
Add ErrorHandler for middleware (#557)
1 parent dce4fa7 commit 88623b8

2 files changed

Lines changed: 16 additions & 3 deletions

File tree

oapi_validate.go

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,10 @@ import (
2929
echomiddleware "github.com/labstack/echo/v4/middleware"
3030
)
3131

32-
const EchoContextKey = "oapi-codegen/echo-context"
33-
const UserDataKey = "oapi-codegen/user-data"
32+
const (
33+
EchoContextKey = "oapi-codegen/echo-context"
34+
UserDataKey = "oapi-codegen/user-data"
35+
)
3436

3537
// This is an Echo middleware function which validates incoming HTTP requests
3638
// to make sure that they conform to the given OAPI 3.0 specification. When
@@ -56,9 +58,13 @@ func OapiRequestValidator(swagger *openapi3.T) echo.MiddlewareFunc {
5658
return OapiRequestValidatorWithOptions(swagger, nil)
5759
}
5860

61+
// ErrorHandler is called when there is an error in validation
62+
type ErrorHandler func(c echo.Context, err *echo.HTTPError) error
63+
5964
// Options to customize request validation. These are passed through to
6065
// openapi3filter.
6166
type Options struct {
67+
ErrorHandler ErrorHandler
6268
Options openapi3filter.Options
6369
ParamDecoder openapi3filter.ContentParameterDecoder
6470
UserData interface{}
@@ -81,6 +87,9 @@ func OapiRequestValidatorWithOptions(swagger *openapi3.T, options *Options) echo
8187

8288
err := ValidateRequestFromContext(c, router, options)
8389
if err != nil {
90+
if options != nil && options.ErrorHandler != nil {
91+
return options.ErrorHandler(c, err)
92+
}
8493
return err
8594
}
8695
return next(c)
@@ -90,7 +99,7 @@ func OapiRequestValidatorWithOptions(swagger *openapi3.T, options *Options) echo
9099

91100
// ValidateRequestFromContext is called from the middleware above and actually does the work
92101
// of validating a request.
93-
func ValidateRequestFromContext(ctx echo.Context, router routers.Router, options *Options) error {
102+
func ValidateRequestFromContext(ctx echo.Context, router routers.Router, options *Options) *echo.HTTPError {
94103
req := ctx.Request()
95104
route, pathParams, err := router.FindRoute(req)
96105

oapi_validate_test.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,9 @@ func TestOapiRequestValidator(t *testing.T) {
137137
// Set up an authenticator to check authenticated function. It will allow
138138
// access to "someScope", but disallow others.
139139
options := Options{
140+
ErrorHandler: func(c echo.Context, err *echo.HTTPError) error {
141+
return c.String(err.Code, "test: "+err.Error())
142+
},
140143
Options: openapi3filter.Options{
141144
AuthenticationFunc: func(c context.Context, input *openapi3filter.AuthenticationInput) error {
142145
// The echo context should be propagated into here.
@@ -268,6 +271,7 @@ func TestOapiRequestValidator(t *testing.T) {
268271
{
269272
rec := doGet(t, e, "http://deepmap.ai/protected_resource_401")
270273
assert.Equal(t, http.StatusUnauthorized, rec.Code)
274+
assert.Equal(t, "test: code=401, message=Unauthorized", rec.Body.String())
271275
assert.False(t, called, "Handler should not have been called")
272276
called = false
273277
}

0 commit comments

Comments
 (0)