@@ -6,18 +6,30 @@ use crate::{
66 namespace:: { NamespaceName , NamespaceStore } ,
77} ;
88use axum:: extract:: State as AxumState ;
9- use axum:: http:: header:: { CACHE_CONTROL , CONTENT_TYPE } ;
10- use axum:: http:: { HeaderValue , Uri } ;
11- use axum:: response:: { IntoResponse , Redirect , Response } ;
12- use axum_extra:: { extract:: Query , json_lines:: JsonLines } ;
9+ use axum:: http:: Uri ;
10+ use axum:: response:: {
11+ sse:: { Event , Sse } ,
12+ IntoResponse , Redirect ,
13+ } ;
14+ use axum_extra:: extract:: Query ;
1315use futures:: { Stream , StreamExt } ;
1416use hyper:: HeaderMap ;
1517use serde:: { Deserialize , Serialize } ;
18+ use std:: boxed:: Box ;
19+ use std:: convert:: Infallible ;
20+ use std:: pin:: Pin ;
21+ use std:: time:: Duration ;
1622use tokio_stream:: wrappers:: errors:: BroadcastStreamRecvError ;
1723
1824use super :: db_factory:: namespace_from_headers;
1925use super :: AppState ;
2026
27+ const LAGGED_MSG : & str = "some changes were lost" ;
28+ const KEEP_ALIVE_INTERVAL : Duration = Duration :: from_secs ( 15 ) ;
29+ const KEEP_ALIVE_TEXT : & str = "keep-alive" ;
30+
31+ type SseStream = Pin < Box < dyn Stream < Item = Result < Event , Infallible > > + Send > > ;
32+
2133#[ derive( Debug , Clone , Deserialize ) ]
2234#[ serde( rename_all = "lowercase" ) ]
2335pub enum Action {
@@ -33,16 +45,35 @@ pub struct ListenQuery {
3345 action : Option < Vec < Action > > ,
3446}
3547
36- const EVENT_STREAM : HeaderValue = HeaderValue :: from_static ( "text/event-stream" ) ;
37- const NO_CACHE : HeaderValue = HeaderValue :: from_static ( "no-cache" ) ;
48+ #[ derive( Debug , Serialize ) ]
49+ #[ serde( rename_all = "snake_case" ) ]
50+ enum AggregatorEvent {
51+ Error ( & ' static str ) ,
52+ #[ serde( untagged) ]
53+ Changes ( BroadcastMsg ) ,
54+ }
55+
56+ enum ListenResponse {
57+ SSE ( Sse < SseStream > ) ,
58+ Redirect ( Redirect ) ,
59+ }
60+
61+ impl IntoResponse for ListenResponse {
62+ fn into_response ( self ) -> axum:: response:: Response {
63+ match self {
64+ ListenResponse :: SSE ( sse) => sse. into_response ( ) ,
65+ ListenResponse :: Redirect ( redirect) => redirect. into_response ( ) ,
66+ }
67+ }
68+ }
3869
3970pub ( super ) async fn handle_listen (
4071 auth : Authenticated ,
4172 AxumState ( state) : AxumState < AppState > ,
4273 headers : HeaderMap ,
4374 uri : Uri ,
4475 query : Query < ListenQuery > ,
45- ) -> crate :: Result < Response > {
76+ ) -> crate :: Result < impl IntoResponse > {
4677 let namespace = namespace_from_headers (
4778 & headers,
4879 state. disable_default_namespace ,
@@ -53,47 +84,51 @@ pub(super) async fn handle_listen(
5384 return Err ( Error :: NamespaceDoesntExist ( namespace. to_string ( ) ) ) ;
5485 }
5586
56- if let Some ( primary_url) = state. primary_url {
57- let url = primary_url + uri. path_and_query ( ) . map_or ( "" , |x| x. as_str ( ) ) ;
58- return Ok ( Redirect :: temporary ( & url) . into_response ( ) ) ;
87+ if let Some ( primary_url) = state. primary_url . as_ref ( ) {
88+ let url = format ! (
89+ "{}{}" ,
90+ primary_url,
91+ uri. path_and_query( ) . map_or( "" , |x| x. as_str( ) )
92+ ) ;
93+ return Ok ( ListenResponse :: Redirect ( Redirect :: temporary ( & url) ) ) ;
5994 }
6095
61- let stream = listen_stream (
96+ let stream = sse_stream (
6297 state. namespaces . clone ( ) ,
6398 namespace,
6499 query. table . clone ( ) ,
65100 query. action . clone ( ) ,
66101 )
67102 . await ;
68103
69- let mut response = JsonLines :: new ( stream) . into_response ( ) ;
70- let headers = response. headers_mut ( ) ;
71- headers. insert ( CONTENT_TYPE , EVENT_STREAM ) ;
72- headers. insert ( CACHE_CONTROL , NO_CACHE ) ;
73-
74- Ok ( response)
104+ Ok ( ListenResponse :: SSE (
105+ Sse :: new ( stream) . keep_alive (
106+ axum:: response:: sse:: KeepAlive :: new ( )
107+ . interval ( KEEP_ALIVE_INTERVAL )
108+ . text ( KEEP_ALIVE_TEXT ) ,
109+ ) ,
110+ ) )
75111}
76112
77- static LAGGED_MSG : & str = "some changes were lost" ;
78-
79- #[ derive( Debug , Serialize ) ]
80- #[ serde( rename_all = "snake_case" ) ]
81- enum AggregatorEvent {
82- Error ( & ' static str ) ,
83- #[ serde( untagged) ]
84- Changes ( BroadcastMsg ) ,
85- }
86-
87- struct Subscription {
113+ async fn sse_stream (
88114 store : NamespaceStore ,
89115 namespace : NamespaceName ,
90116 table : String ,
91- }
92-
93- impl Drop for Subscription {
94- fn drop ( & mut self ) {
95- self . store . unsubscribe ( self . namespace . clone ( ) , & self . table ) ;
96- }
117+ actions : Option < Vec < Action > > ,
118+ ) -> SseStream {
119+ Box :: pin (
120+ listen_stream ( store, namespace, table, actions)
121+ . await
122+ . map ( |result| {
123+ Ok ( match result {
124+ Ok ( AggregatorEvent :: Error ( msg) ) => Event :: default ( ) . event ( "error" ) . data ( msg) ,
125+ Ok ( AggregatorEvent :: Changes ( msg) ) => {
126+ Event :: default ( ) . event ( "changes" ) . json_data ( msg) . unwrap ( )
127+ }
128+ Err ( e) => Event :: default ( ) . event ( "error" ) . data ( e. to_string ( ) ) ,
129+ } )
130+ } ) ,
131+ )
97132}
98133
99134async fn listen_stream (
@@ -103,23 +138,18 @@ async fn listen_stream(
103138 actions : Option < Vec < Action > > ,
104139) -> impl Stream < Item = crate :: Result < AggregatorEvent > > {
105140 async_stream:: try_stream! {
106- let _sub = Subscription {
107- store: store. clone( ) ,
108- namespace: namespace. clone( ) ,
109- table: table. clone( ) ,
110- } ;
111-
141+ let _sub = Subscription :: new( store. clone( ) , namespace. clone( ) , table. clone( ) ) ;
112142 let mut stream = store. subscribe( namespace. clone( ) , table. clone( ) ) ;
113143
114- while let Some ( item) = stream. next( ) . await {
144+ while let Some ( item) = stream. next( ) . await {
115145 match item {
116146 Ok ( msg) => if filter_actions( & msg, & actions) {
117147 LISTEN_EVENTS_SENT . increment( 1 ) ;
118148 yield AggregatorEvent :: Changes ( msg) ;
119149 } ,
120150 Err ( BroadcastStreamRecvError :: Lagged ( n) ) => {
121151 LISTEN_EVENTS_DROPPED . increment( n as u64 ) ;
122- yield AggregatorEvent :: Error ( & LAGGED_MSG ) ;
152+ yield AggregatorEvent :: Error ( LAGGED_MSG ) ;
123153 } ,
124154 }
125155 }
@@ -128,17 +158,36 @@ async fn listen_stream(
128158
129159fn filter_actions ( msg : & BroadcastMsg , actions : & Option < Vec < Action > > ) -> bool {
130160 actions. as_ref ( ) . map_or ( true , |actions| {
131- for action in actions {
161+ actions. iter ( ) . any ( |action| {
132162 let count = match action {
133163 Action :: DELETE => msg. delete ,
134164 Action :: INSERT => msg. insert ,
135165 Action :: UPDATE => msg. update ,
136166 Action :: UNKNOWN => msg. unknown ,
137167 } ;
138- if count > 0 {
139- return true ;
140- }
141- }
142- false
168+ count > 0
169+ } )
143170 } )
144171}
172+
173+ struct Subscription {
174+ store : NamespaceStore ,
175+ namespace : NamespaceName ,
176+ table : String ,
177+ }
178+
179+ impl Subscription {
180+ fn new ( store : NamespaceStore , namespace : NamespaceName , table : String ) -> Self {
181+ Self {
182+ store,
183+ namespace,
184+ table,
185+ }
186+ }
187+ }
188+
189+ impl Drop for Subscription {
190+ fn drop ( & mut self ) {
191+ self . store . unsubscribe ( self . namespace . clone ( ) , & self . table ) ;
192+ }
193+ }
0 commit comments