11use anyhow:: Context as _;
22use axum:: body:: StreamBody ;
33use axum:: extract:: { FromRef , Path , State } ;
4+ use axum:: middleware:: Next ;
45use axum:: routing:: delete;
56use axum:: Json ;
67use chrono:: NaiveDateTime ;
78use futures:: { SinkExt , StreamExt , TryStreamExt } ;
8- use hyper:: { Body , Request } ;
9+ use hyper:: { Body , Request , StatusCode } ;
910use metrics_exporter_prometheus:: { PrometheusBuilder , PrometheusHandle } ;
1011use parking_lot:: Mutex ;
1112use serde:: { Deserialize , Serialize } ;
@@ -64,6 +65,7 @@ pub async fn run<A, C>(
6465 connector : C ,
6566 disable_metrics : bool ,
6667 shutdown : Arc < Notify > ,
68+ auth : Option < Arc < str > > ,
6769) -> anyhow:: Result < ( ) >
6870where
6971 A : crate :: net:: Accept ,
@@ -162,15 +164,15 @@ where
162164 )
163165 . route ( "/v1/diagnostics" , get ( handle_diagnostics) )
164166 . route ( "/metrics" , get ( handle_metrics) )
167+ . route ( "/profile/heap/enable" , post ( enable_profile_heap) )
168+ . route ( "/profile/heap/disable/:id" , post ( disable_profile_heap) )
169+ . route ( "/profile/heap/:id" , delete ( delete_profile_heap) )
165170 . with_state ( Arc :: new ( AppState {
166171 namespaces,
167172 connector,
168173 user_http_server,
169174 metrics,
170175 } ) )
171- . route ( "/profile/heap/enable" , post ( enable_profile_heap) )
172- . route ( "/profile/heap/disable/:id" , post ( disable_profile_heap) )
173- . route ( "/profile/heap/:id" , delete ( delete_profile_heap) )
174176 . layer (
175177 tower_http:: trace:: TraceLayer :: new_for_http ( )
176178 . on_request ( trace_request)
@@ -179,7 +181,8 @@ where
179181 . level ( tracing:: Level :: DEBUG )
180182 . latency_unit ( tower_http:: LatencyUnit :: Micros ) ,
181183 ) ,
182- ) ;
184+ )
185+ . layer ( axum:: middleware:: from_fn_with_state ( auth, auth_middleware) ) ;
183186
184187 hyper:: server:: Server :: builder ( acceptor)
185188 . serve ( router. into_make_service ( ) )
@@ -190,6 +193,34 @@ where
190193 Ok ( ( ) )
191194}
192195
196+ async fn auth_middleware < B > (
197+ State ( auth) : State < Option < Arc < str > > > ,
198+ request : Request < B > ,
199+ next : Next < B > ,
200+ ) -> Result < axum:: response:: Response , StatusCode > {
201+ if let Some ( ref auth) = auth {
202+ let Some ( auth_header) = request. headers ( ) . get ( "authorization" ) else {
203+ return Err ( StatusCode :: UNAUTHORIZED ) ;
204+ } ;
205+ let Ok ( auth_str) = std:: str:: from_utf8 ( auth_header. as_bytes ( ) ) else {
206+ return Err ( StatusCode :: UNAUTHORIZED ) ;
207+ } ;
208+
209+ let mut split = auth_str. split_whitespace ( ) ;
210+ match split. next ( ) {
211+ Some ( s) if s. trim ( ) . eq_ignore_ascii_case ( "basic" ) => ( ) ,
212+ _ => return Err ( StatusCode :: UNAUTHORIZED ) ,
213+ }
214+
215+ match split. next ( ) {
216+ Some ( s) if s. trim ( ) == auth. as_ref ( ) => ( ) ,
217+ _ => return Err ( StatusCode :: UNAUTHORIZED ) ,
218+ }
219+ }
220+
221+ Ok ( next. run ( request) . await )
222+ }
223+
193224async fn handle_get_index ( ) -> & ' static str {
194225 "Welcome to the sqld admin API"
195226}
0 commit comments