Skip to content

Commit 59f189e

Browse files
authored
Merge pull request #1671 from tursodatabase/format-listen-sse-endpoint
feat(libsql-server): send correct sse response
2 parents 5d25018 + f679602 commit 59f189e

1 file changed

Lines changed: 97 additions & 48 deletions

File tree

libsql-server/src/http/user/listen.rs

Lines changed: 97 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -6,18 +6,30 @@ use crate::{
66
namespace::{NamespaceName, NamespaceStore},
77
};
88
use axum::extract::State as AxumState;
9-
use axum::http::header::{CACHE_CONTROL, CONTENT_TYPE};
10-
use axum::http::{HeaderValue, Uri};
11-
use axum::response::{IntoResponse, Redirect, Response};
12-
use axum_extra::{extract::Query, json_lines::JsonLines};
9+
use axum::http::Uri;
10+
use axum::response::{
11+
sse::{Event, Sse},
12+
IntoResponse, Redirect,
13+
};
14+
use axum_extra::extract::Query;
1315
use futures::{Stream, StreamExt};
1416
use hyper::HeaderMap;
1517
use serde::{Deserialize, Serialize};
18+
use std::boxed::Box;
19+
use std::convert::Infallible;
20+
use std::pin::Pin;
21+
use std::time::Duration;
1622
use tokio_stream::wrappers::errors::BroadcastStreamRecvError;
1723

1824
use super::db_factory::namespace_from_headers;
1925
use super::AppState;
2026

27+
const LAGGED_MSG: &str = "some changes were lost";
28+
const KEEP_ALIVE_INTERVAL: Duration = Duration::from_secs(15);
29+
const KEEP_ALIVE_TEXT: &str = "keep-alive";
30+
31+
type SseStream = Pin<Box<dyn Stream<Item = Result<Event, Infallible>> + Send>>;
32+
2133
#[derive(Debug, Clone, Deserialize)]
2234
#[serde(rename_all = "lowercase")]
2335
pub enum Action {
@@ -33,16 +45,35 @@ pub struct ListenQuery {
3345
action: Option<Vec<Action>>,
3446
}
3547

36-
const EVENT_STREAM: HeaderValue = HeaderValue::from_static("text/event-stream");
37-
const NO_CACHE: HeaderValue = HeaderValue::from_static("no-cache");
48+
#[derive(Debug, Serialize)]
49+
#[serde(rename_all = "snake_case")]
50+
enum AggregatorEvent {
51+
Error(&'static str),
52+
#[serde(untagged)]
53+
Changes(BroadcastMsg),
54+
}
55+
56+
enum ListenResponse {
57+
SSE(Sse<SseStream>),
58+
Redirect(Redirect),
59+
}
60+
61+
impl IntoResponse for ListenResponse {
62+
fn into_response(self) -> axum::response::Response {
63+
match self {
64+
ListenResponse::SSE(sse) => sse.into_response(),
65+
ListenResponse::Redirect(redirect) => redirect.into_response(),
66+
}
67+
}
68+
}
3869

3970
pub(super) async fn handle_listen(
4071
auth: Authenticated,
4172
AxumState(state): AxumState<AppState>,
4273
headers: HeaderMap,
4374
uri: Uri,
4475
query: Query<ListenQuery>,
45-
) -> crate::Result<Response> {
76+
) -> crate::Result<impl IntoResponse> {
4677
let namespace = namespace_from_headers(
4778
&headers,
4879
state.disable_default_namespace,
@@ -53,47 +84,51 @@ pub(super) async fn handle_listen(
5384
return Err(Error::NamespaceDoesntExist(namespace.to_string()));
5485
}
5586

56-
if let Some(primary_url) = state.primary_url {
57-
let url = primary_url + uri.path_and_query().map_or("", |x| x.as_str());
58-
return Ok(Redirect::temporary(&url).into_response());
87+
if let Some(primary_url) = state.primary_url.as_ref() {
88+
let url = format!(
89+
"{}{}",
90+
primary_url,
91+
uri.path_and_query().map_or("", |x| x.as_str())
92+
);
93+
return Ok(ListenResponse::Redirect(Redirect::temporary(&url)));
5994
}
6095

61-
let stream = listen_stream(
96+
let stream = sse_stream(
6297
state.namespaces.clone(),
6398
namespace,
6499
query.table.clone(),
65100
query.action.clone(),
66101
)
67102
.await;
68103

69-
let mut response = JsonLines::new(stream).into_response();
70-
let headers = response.headers_mut();
71-
headers.insert(CONTENT_TYPE, EVENT_STREAM);
72-
headers.insert(CACHE_CONTROL, NO_CACHE);
73-
74-
Ok(response)
104+
Ok(ListenResponse::SSE(
105+
Sse::new(stream).keep_alive(
106+
axum::response::sse::KeepAlive::new()
107+
.interval(KEEP_ALIVE_INTERVAL)
108+
.text(KEEP_ALIVE_TEXT),
109+
),
110+
))
75111
}
76112

77-
static LAGGED_MSG: &str = "some changes were lost";
78-
79-
#[derive(Debug, Serialize)]
80-
#[serde(rename_all = "snake_case")]
81-
enum AggregatorEvent {
82-
Error(&'static str),
83-
#[serde(untagged)]
84-
Changes(BroadcastMsg),
85-
}
86-
87-
struct Subscription {
113+
async fn sse_stream(
88114
store: NamespaceStore,
89115
namespace: NamespaceName,
90116
table: String,
91-
}
92-
93-
impl Drop for Subscription {
94-
fn drop(&mut self) {
95-
self.store.unsubscribe(self.namespace.clone(), &self.table);
96-
}
117+
actions: Option<Vec<Action>>,
118+
) -> SseStream {
119+
Box::pin(
120+
listen_stream(store, namespace, table, actions)
121+
.await
122+
.map(|result| {
123+
Ok(match result {
124+
Ok(AggregatorEvent::Error(msg)) => Event::default().event("error").data(msg),
125+
Ok(AggregatorEvent::Changes(msg)) => {
126+
Event::default().event("changes").json_data(msg).unwrap()
127+
}
128+
Err(e) => Event::default().event("error").data(e.to_string()),
129+
})
130+
}),
131+
)
97132
}
98133

99134
async fn listen_stream(
@@ -103,23 +138,18 @@ async fn listen_stream(
103138
actions: Option<Vec<Action>>,
104139
) -> impl Stream<Item = crate::Result<AggregatorEvent>> {
105140
async_stream::try_stream! {
106-
let _sub = Subscription {
107-
store: store.clone(),
108-
namespace: namespace.clone(),
109-
table: table.clone(),
110-
};
111-
141+
let _sub = Subscription::new(store.clone(), namespace.clone(), table.clone());
112142
let mut stream = store.subscribe(namespace.clone(), table.clone());
113143

114-
while let Some(item) = stream.next().await {
144+
while let Some(item) = stream.next().await {
115145
match item {
116146
Ok(msg) => if filter_actions(&msg, &actions) {
117147
LISTEN_EVENTS_SENT.increment(1);
118148
yield AggregatorEvent::Changes(msg);
119149
},
120150
Err(BroadcastStreamRecvError::Lagged(n)) => {
121151
LISTEN_EVENTS_DROPPED.increment(n as u64);
122-
yield AggregatorEvent::Error(&LAGGED_MSG);
152+
yield AggregatorEvent::Error(LAGGED_MSG);
123153
},
124154
}
125155
}
@@ -128,17 +158,36 @@ async fn listen_stream(
128158

129159
fn filter_actions(msg: &BroadcastMsg, actions: &Option<Vec<Action>>) -> bool {
130160
actions.as_ref().map_or(true, |actions| {
131-
for action in actions {
161+
actions.iter().any(|action| {
132162
let count = match action {
133163
Action::DELETE => msg.delete,
134164
Action::INSERT => msg.insert,
135165
Action::UPDATE => msg.update,
136166
Action::UNKNOWN => msg.unknown,
137167
};
138-
if count > 0 {
139-
return true;
140-
}
141-
}
142-
false
168+
count > 0
169+
})
143170
})
144171
}
172+
173+
struct Subscription {
174+
store: NamespaceStore,
175+
namespace: NamespaceName,
176+
table: String,
177+
}
178+
179+
impl Subscription {
180+
fn new(store: NamespaceStore, namespace: NamespaceName, table: String) -> Self {
181+
Self {
182+
store,
183+
namespace,
184+
table,
185+
}
186+
}
187+
}
188+
189+
impl Drop for Subscription {
190+
fn drop(&mut self) {
191+
self.store.unsubscribe(self.namespace.clone(), &self.table);
192+
}
193+
}

0 commit comments

Comments
 (0)