Skip to content

Commit d5866fe

Browse files
[Rust-Axum] Implement support for Basic and Bearer auth in Claims (#20584)
* Implement a custom error handler for unhandled or generic endpoint errors * Pass in method, host and cookies to error handler * Update axum to 0.8 * Make API methods take references instead of ownership * Rebase error handler * Rebase with updated error handler * Update deps * Fix capture group syntax * Rebase rust-axum-error-handling * Update docs * Multipart is also part of the axum update * Update samples * Update docs
1 parent eb668b6 commit d5866fe

23 files changed

Lines changed: 267 additions & 207 deletions

File tree

docs/generators/rust-axum.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -209,16 +209,16 @@ These options may be applied as additional-properties (cli) or configOptions (pl
209209
|Union|✗|OAS3
210210
|allOf|✗|OAS2,OAS3
211211
|anyOf|✗|OAS3
212-
|oneOf||OAS3
212+
|oneOf||OAS3
213213
|not|✗|OAS3
214214

215215
### Security Feature
216216
| Name | Supported | Defined By |
217217
| ---- | --------- | ---------- |
218-
|BasicAuth||OAS2,OAS3
218+
|BasicAuth||OAS2,OAS3
219219
|ApiKey|✓|OAS2,OAS3
220220
|OpenIDConnect|✗|OAS3
221-
|BearerToken||OAS3
221+
|BearerToken||OAS3
222222
|OAuth2_Implicit|✗|OAS2,OAS3
223223
|OAuth2_Password|✗|OAS2,OAS3
224224
|OAuth2_ClientCredentials|✗|OAS2,OAS3

modules/openapi-generator/src/main/java/org/openapitools/codegen/languages/RustAxumServerCodegen.java

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ public class RustAxumServerCodegen extends AbstractRustCodegen implements Codege
8585
// Grouping (Method, Operation) by Path.
8686
private final Map<String, ArrayList<MethodOperation>> pathMethodOpMap = new HashMap<>();
8787
private boolean havingAuthMethods = false;
88+
private boolean havingBasicAuthMethods = false;
8889

8990
// Logger
9091
private final Logger LOGGER = LoggerFactory.getLogger(RustAxumServerCodegen.class);
@@ -98,7 +99,14 @@ public RustAxumServerCodegen() {
9899
WireFormatFeature.Custom
99100
))
100101
.securityFeatures(EnumSet.of(
101-
SecurityFeature.ApiKey
102+
SecurityFeature.ApiKey,
103+
SecurityFeature.BasicAuth,
104+
SecurityFeature.BearerToken
105+
))
106+
.schemaSupportFeatures(EnumSet.of(
107+
SchemaSupportFeature.Simple,
108+
SchemaSupportFeature.Composite,
109+
SchemaSupportFeature.oneOf
102110
))
103111
.excludeGlobalFeatures(
104112
GlobalFeature.Info,
@@ -777,6 +785,16 @@ private boolean postProcessOperationWithModels(final CodegenOperation op) {
777785

778786
op.vendorExtensions.put("x-has-auth-methods", true);
779787
hasAuthMethod = true;
788+
} else if (s.isBasic) {
789+
op.vendorExtensions.put("x-has-basic-auth-methods", true);
790+
op.vendorExtensions.put("x-is-basic-bearer", s.isBasicBearer);
791+
op.vendorExtensions.put("x-api-auth-header-name", "authorization");
792+
793+
op.vendorExtensions.put("x-has-auth-methods", true);
794+
hasAuthMethod = true;
795+
796+
if (!this.havingBasicAuthMethods)
797+
this.havingBasicAuthMethods = true;
780798
}
781799
}
782800
}
@@ -878,6 +896,7 @@ public Map<String, Object> postProcessSupportingFileData(Map<String, Object> bun
878896
.collect(Collectors.toList());
879897
bundle.put("pathMethodOps", pathMethodOps);
880898
if (havingAuthMethods) bundle.put("havingAuthMethods", true);
899+
if (havingBasicAuthMethods) bundle.put("havingBasicAuthMethods", true);
881900

882901
return super.postProcessSupportingFileData(bundle);
883902
}

modules/openapi-generator/src/main/resources/rust-axum/apis-mod.mustache

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,22 +12,43 @@ pub mod {{classFilename}};
1212
pub trait CookieAuthentication {
1313
type Claims;
1414
15-
/// Extracting Claims from Cookie. Return None if the Claims is invalid.
15+
/// Extracting Claims from Cookie. Return None if the Claims are invalid.
1616
async fn extract_claims_from_cookie(&self, cookies: &axum_extra::extract::CookieJar, key: &str) -> Option<Self::Claims>;
1717
}
18+
1819
{{/isKeyInCookie}}
1920
{{#isKeyInHeader}}
2021
/// API Key Authentication - Header.
2122
#[async_trait::async_trait]
2223
pub trait ApiKeyAuthHeader {
2324
type Claims;
2425
25-
/// Extracting Claims from Header. Return None if the Claims is invalid.
26+
/// Extracting Claims from Header. Return None if the Claims are invalid.
2627
async fn extract_claims_from_header(&self, headers: &axum::http::header::HeaderMap, key: &str) -> Option<Self::Claims>;
2728
}
29+
2830
{{/isKeyInHeader}}
2931
{{/isApiKey}}
3032
{{/authMethods}}
33+
{{#havingBasicAuthMethods}}
34+
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
35+
#[non_exhaustive]
36+
pub enum BasicAuthKind {
37+
Basic,
38+
Bearer,
39+
}
40+
41+
/// API Key Authentication - Authentication Header.
42+
/// For `Basic token` and `Bearer token`
43+
#[async_trait::async_trait]
44+
pub trait ApiAuthBasic {
45+
type Claims;
46+
47+
/// Extracting Claims from Header. Return None if the Claims are invalid.
48+
async fn extract_claims_from_auth_header(&self, kind: BasicAuthKind, headers: &axum::http::header::HeaderMap, key: &str) -> Option<Self::Claims>;
49+
}
50+
51+
{{/havingBasicAuthMethods}}
3152

3253
// Error handler for unhandled errors.
3354
#[async_trait::async_trait]

modules/openapi-generator/src/main/resources/rust-axum/server-operation.mustache

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,11 @@ async fn {{#vendorExtensions}}{{{x-operation-id}}}{{/vendorExtensions}}<I, A, E{
1212
{{#x-has-header-auth-methods}}
1313
headers: HeaderMap,
1414
{{/x-has-header-auth-methods}}
15+
{{^x-has-header-auth-methods}}
16+
{{#x-has-basic-auth-methods}}
17+
headers: HeaderMap,
18+
{{/x-has-basic-auth-methods}}
19+
{{/x-has-header-auth-methods}}
1520
{{/vendorExtensions}}
1621
{{/headerParams.size}}
1722
{{#pathParams.size}}
@@ -54,7 +59,7 @@ async fn {{#vendorExtensions}}{{{x-operation-id}}}{{/vendorExtensions}}<I, A, E{
5459
) -> Result<Response, StatusCode>
5560
where
5661
I: AsRef<A> + Send + Sync,
57-
A: apis::{{classFilename}}::{{classnamePascalCase}}<E{{#havingAuthMethod}}, Claims = C{{/havingAuthMethod}}>{{#vendorExtensions}}{{#x-has-cookie-auth-methods}}+ apis::CookieAuthentication<Claims = C>{{/x-has-cookie-auth-methods}}{{#x-has-header-auth-methods}}+ apis::ApiKeyAuthHeader<Claims = C>{{/x-has-header-auth-methods}}{{/vendorExtensions}} + Send + Sync,
62+
A: apis::{{classFilename}}::{{classnamePascalCase}}<E{{#havingAuthMethod}}, Claims = C{{/havingAuthMethod}}>{{#vendorExtensions}}{{#x-has-cookie-auth-methods}}+ apis::CookieAuthentication<Claims = C>{{/x-has-cookie-auth-methods}}{{#x-has-header-auth-methods}}+ apis::ApiKeyAuthHeader<Claims = C>{{/x-has-header-auth-methods}}{{#x-has-basic-auth-methods}}+ apis::ApiAuthBasic<Claims = C>{{/x-has-basic-auth-methods}}{{/vendorExtensions}} + Send + Sync,
5863
E: std::fmt::Debug + Send + Sync + 'static,
5964
{
6065
{{#vendorExtensions}}
@@ -67,14 +72,20 @@ where
6772
{{#x-has-header-auth-methods}}
6873
let claims_in_header = api_impl.as_ref().extract_claims_from_header(&headers, "{{x-api-key-header-name}}").await;
6974
{{/x-has-header-auth-methods}}
75+
{{#x-has-basic-auth-methods}}
76+
let claims_in_auth_header = api_impl.as_ref().extract_claims_from_auth_header(apis::BasicAuthKind::{{#x-is-basic-bearer}}Bearer{{/x-is-basic-bearer}}{{^x-is-basic-bearer}}Basic{{/x-is-basic-bearer}}, &headers, "{{x-api-auth-header-name}}").await;
77+
{{/x-has-basic-auth-methods}}
7078
{{#x-has-auth-methods}}
7179
let claims = None
7280
{{#x-has-cookie-auth-methods}}
7381
.or(claims_in_cookie)
7482
{{/x-has-cookie-auth-methods}}
7583
{{#x-has-header-auth-methods}}
76-
.or(claims_in_header)
84+
.or(claims_in_header)
7785
{{/x-has-header-auth-methods}}
86+
{{#x-has-basic-auth-methods}}
87+
.or(claims_in_auth_header)
88+
{{/x-has-basic-auth-methods}}
7889
;
7990
let Some(claims) = claims else {
8091
return Response::builder()
@@ -346,7 +357,6 @@ where
346357
Err(why) => {
347358
// Application code returned an error. This should not happen, as the implementation should
348359
// return a valid response.
349-
350360
return api_impl.as_ref().handle_error(&method, &host, &cookies, why).await;
351361
},
352362
};

modules/openapi-generator/src/main/resources/rust-axum/server-route.mustache

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
pub fn new<I, A, E{{#havingAuthMethods}}, C{{/havingAuthMethods}}>(api_impl: I) -> Router
33
where
44
I: AsRef<A> + Clone + Send + Sync + 'static,
5-
A: {{#apiInfo}}{{#apis}}{{#operations}}apis::{{classFilename}}::{{classnamePascalCase}}<E{{#havingAuthMethod}}, Claims = C{{/havingAuthMethod}}> + {{/operations}}{{/apis}}{{/apiInfo}}{{#authMethods}}{{#isApiKey}}{{#isKeyInCookie}}apis::CookieAuthentication<Claims = C> + {{/isKeyInCookie}}{{#isKeyInHeader}}apis::ApiKeyAuthHeader<Claims = C> + {{/isKeyInHeader}}{{/isApiKey}}{{/authMethods}}Send + Sync + 'static,
5+
A: {{#apiInfo}}{{#apis}}{{#operations}}apis::{{classFilename}}::{{classnamePascalCase}}<E{{#havingAuthMethod}}, Claims = C{{/havingAuthMethod}}> + {{/operations}}{{/apis}}{{/apiInfo}}{{#authMethods}}{{#isApiKey}}{{#isKeyInCookie}}apis::CookieAuthentication<Claims = C> + {{/isKeyInCookie}}{{#isKeyInHeader}}apis::ApiKeyAuthHeader<Claims = C> + {{/isKeyInHeader}}{{/isApiKey}}{{#isBasic}}apis::ApiAuthBasic<Claims = C> + {{/isBasic}}{{/authMethods}}Send + Sync + 'static,
66
E: std::fmt::Debug + Send + Sync + 'static,
77
{{#havingAuthMethods}}C: Send + Sync + 'static,{{/havingAuthMethods}}
88
{

samples/server/petstore/rust-axum/output/apikey-auths/src/apis/mod.rs

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,26 +5,49 @@ pub mod payments;
55
pub trait ApiKeyAuthHeader {
66
type Claims;
77

8-
/// Extracting Claims from Header. Return None if the Claims is invalid.
8+
/// Extracting Claims from Header. Return None if the Claims are invalid.
99
async fn extract_claims_from_header(
1010
&self,
1111
headers: &axum::http::header::HeaderMap,
1212
key: &str,
1313
) -> Option<Self::Claims>;
1414
}
15+
1516
/// Cookie Authentication.
1617
#[async_trait::async_trait]
1718
pub trait CookieAuthentication {
1819
type Claims;
1920

20-
/// Extracting Claims from Cookie. Return None if the Claims is invalid.
21+
/// Extracting Claims from Cookie. Return None if the Claims are invalid.
2122
async fn extract_claims_from_cookie(
2223
&self,
2324
cookies: &axum_extra::extract::CookieJar,
2425
key: &str,
2526
) -> Option<Self::Claims>;
2627
}
2728

29+
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
30+
#[non_exhaustive]
31+
pub enum BasicAuthKind {
32+
Basic,
33+
Bearer,
34+
}
35+
36+
/// API Key Authentication - Authentication Header.
37+
/// For `Basic token` and `Bearer token`
38+
#[async_trait::async_trait]
39+
pub trait ApiAuthBasic {
40+
type Claims;
41+
42+
/// Extracting Claims from Header. Return None if the Claims are invalid.
43+
async fn extract_claims_from_auth_header(
44+
&self,
45+
kind: BasicAuthKind,
46+
headers: &axum::http::header::HeaderMap,
47+
key: &str,
48+
) -> Option<Self::Claims>;
49+
}
50+
2851
// Error handler for unhandled errors.
2952
#[async_trait::async_trait]
3053
pub trait ErrorHandler<E: std::fmt::Debug + Send + Sync + 'static = ()> {

samples/server/petstore/rust-axum/output/apikey-auths/src/apis/payments.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ pub trait Payments<E: std::fmt::Debug + Send + Sync + 'static = ()>:
5151
method: &Method,
5252
host: &Host,
5353
cookies: &CookieJar,
54+
claims: &Self::Claims,
5455
path_params: &models::GetPaymentMethodByIdPathParams,
5556
) -> Result<GetPaymentMethodByIdResponse, E>;
5657

@@ -62,6 +63,7 @@ pub trait Payments<E: std::fmt::Debug + Send + Sync + 'static = ()>:
6263
method: &Method,
6364
host: &Host,
6465
cookies: &CookieJar,
66+
claims: &Self::Claims,
6567
) -> Result<GetPaymentMethodsResponse, E>;
6668

6769
/// Make a payment.

samples/server/petstore/rust-axum/output/apikey-auths/src/server/mod.rs

Lines changed: 41 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ pub fn new<I, A, E, C>(api_impl: I) -> Router
1717
where
1818
I: AsRef<A> + Clone + Send + Sync + 'static,
1919
A: apis::payments::Payments<E, Claims = C>
20+
+ apis::ApiAuthBasic<Claims = C>
21+
+ apis::ApiAuthBasic<Claims = C>
2022
+ apis::ApiKeyAuthHeader<Claims = C>
2123
+ apis::CookieAuthentication<Claims = C>
2224
+ Send
@@ -53,14 +55,28 @@ async fn get_payment_method_by_id<I, A, E, C>(
5355
method: Method,
5456
host: Host,
5557
cookies: CookieJar,
58+
headers: HeaderMap,
5659
Path(path_params): Path<models::GetPaymentMethodByIdPathParams>,
5760
State(api_impl): State<I>,
5861
) -> Result<Response, StatusCode>
5962
where
6063
I: AsRef<A> + Send + Sync,
61-
A: apis::payments::Payments<E, Claims = C> + Send + Sync,
64+
A: apis::payments::Payments<E, Claims = C> + apis::ApiAuthBasic<Claims = C> + Send + Sync,
6265
E: std::fmt::Debug + Send + Sync + 'static,
6366
{
67+
// Authentication
68+
let claims_in_auth_header = api_impl
69+
.as_ref()
70+
.extract_claims_from_auth_header(apis::BasicAuthKind::Bearer, &headers, "authorization")
71+
.await;
72+
let claims = None.or(claims_in_auth_header);
73+
let Some(claims) = claims else {
74+
return Response::builder()
75+
.status(StatusCode::UNAUTHORIZED)
76+
.body(Body::empty())
77+
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR);
78+
};
79+
6480
#[allow(clippy::redundant_closure)]
6581
let validation =
6682
tokio::task::spawn_blocking(move || get_payment_method_by_id_validation(path_params))
@@ -76,7 +92,7 @@ where
7692

7793
let result = api_impl
7894
.as_ref()
79-
.get_payment_method_by_id(&method, &host, &cookies, &path_params)
95+
.get_payment_method_by_id(&method, &host, &cookies, &claims, &path_params)
8096
.await;
8197

8298
let mut response = Response::builder();
@@ -133,7 +149,6 @@ where
133149
Err(why) => {
134150
// Application code returned an error. This should not happen, as the implementation should
135151
// return a valid response.
136-
137152
return api_impl
138153
.as_ref()
139154
.handle_error(&method, &host, &cookies, why)
@@ -157,13 +172,27 @@ async fn get_payment_methods<I, A, E, C>(
157172
method: Method,
158173
host: Host,
159174
cookies: CookieJar,
175+
headers: HeaderMap,
160176
State(api_impl): State<I>,
161177
) -> Result<Response, StatusCode>
162178
where
163179
I: AsRef<A> + Send + Sync,
164-
A: apis::payments::Payments<E, Claims = C> + Send + Sync,
180+
A: apis::payments::Payments<E, Claims = C> + apis::ApiAuthBasic<Claims = C> + Send + Sync,
165181
E: std::fmt::Debug + Send + Sync + 'static,
166182
{
183+
// Authentication
184+
let claims_in_auth_header = api_impl
185+
.as_ref()
186+
.extract_claims_from_auth_header(apis::BasicAuthKind::Bearer, &headers, "authorization")
187+
.await;
188+
let claims = None.or(claims_in_auth_header);
189+
let Some(claims) = claims else {
190+
return Response::builder()
191+
.status(StatusCode::UNAUTHORIZED)
192+
.body(Body::empty())
193+
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR);
194+
};
195+
167196
#[allow(clippy::redundant_closure)]
168197
let validation = tokio::task::spawn_blocking(move || get_payment_methods_validation())
169198
.await
@@ -178,7 +207,7 @@ where
178207

179208
let result = api_impl
180209
.as_ref()
181-
.get_payment_methods(&method, &host, &cookies)
210+
.get_payment_methods(&method, &host, &cookies, &claims)
182211
.await;
183212

184213
let mut response = Response::builder();
@@ -212,7 +241,6 @@ where
212241
Err(why) => {
213242
// Application code returned an error. This should not happen, as the implementation should
214243
// return a valid response.
215-
216244
return api_impl
217245
.as_ref()
218246
.handle_error(&method, &host, &cookies, why)
@@ -250,13 +278,15 @@ async fn post_make_payment<I, A, E, C>(
250278
method: Method,
251279
host: Host,
252280
cookies: CookieJar,
281+
headers: HeaderMap,
253282
State(api_impl): State<I>,
254283
Json(body): Json<Option<models::Payment>>,
255284
) -> Result<Response, StatusCode>
256285
where
257286
I: AsRef<A> + Send + Sync,
258287
A: apis::payments::Payments<E, Claims = C>
259288
+ apis::CookieAuthentication<Claims = C>
289+
+ apis::ApiAuthBasic<Claims = C>
260290
+ Send
261291
+ Sync,
262292
E: std::fmt::Debug + Send + Sync + 'static,
@@ -266,7 +296,11 @@ where
266296
.as_ref()
267297
.extract_claims_from_cookie(&cookies, "X-API-Key")
268298
.await;
269-
let claims = None.or(claims_in_cookie);
299+
let claims_in_auth_header = api_impl
300+
.as_ref()
301+
.extract_claims_from_auth_header(apis::BasicAuthKind::Bearer, &headers, "authorization")
302+
.await;
303+
let claims = None.or(claims_in_cookie).or(claims_in_auth_header);
270304
let Some(claims) = claims else {
271305
return Response::builder()
272306
.status(StatusCode::UNAUTHORIZED)
@@ -345,7 +379,6 @@ where
345379
Err(why) => {
346380
// Application code returned an error. This should not happen, as the implementation should
347381
// return a valid response.
348-
349382
return api_impl
350383
.as_ref()
351384
.handle_error(&method, &host, &cookies, why)

0 commit comments

Comments
 (0)