1use async_trait::async_trait;
8use chrono::{DateTime, Utc};
9use mas_data_model::{
10 UpstreamOAuthAuthorizationSession, UpstreamOAuthAuthorizationSessionState, UpstreamOAuthLink,
11 UpstreamOAuthProvider,
12};
13use mas_storage::{
14 Clock, Page, Pagination,
15 upstream_oauth2::{UpstreamOAuthSessionFilter, UpstreamOAuthSessionRepository},
16};
17use rand::RngCore;
18use sea_query::{Expr, PostgresQueryBuilder, Query, enum_def, extension::postgres::PgExpr};
19use sea_query_binder::SqlxBinder;
20use sqlx::PgConnection;
21use ulid::Ulid;
22use uuid::Uuid;
23
24use crate::{
25 DatabaseError, DatabaseInconsistencyError,
26 filter::{Filter, StatementExt},
27 iden::UpstreamOAuthAuthorizationSessions,
28 pagination::QueryBuilderExt,
29 tracing::ExecuteExt,
30};
31
32impl Filter for UpstreamOAuthSessionFilter<'_> {
33 fn generate_condition(&self, _has_joins: bool) -> impl sea_query::IntoCondition {
34 sea_query::Condition::all()
35 .add_option(self.provider().map(|provider| {
36 Expr::col((
37 UpstreamOAuthAuthorizationSessions::Table,
38 UpstreamOAuthAuthorizationSessions::UpstreamOAuthProviderId,
39 ))
40 .eq(Uuid::from(provider.id))
41 }))
42 .add_option(self.sub_claim().map(|sub| {
43 Expr::col((
44 UpstreamOAuthAuthorizationSessions::Table,
45 UpstreamOAuthAuthorizationSessions::IdTokenClaims,
46 ))
47 .cast_json_field("sub")
48 .eq(sub)
49 }))
50 .add_option(self.sid_claim().map(|sid| {
51 Expr::col((
52 UpstreamOAuthAuthorizationSessions::Table,
53 UpstreamOAuthAuthorizationSessions::IdTokenClaims,
54 ))
55 .cast_json_field("sid")
56 .eq(sid)
57 }))
58 }
59}
60
61pub struct PgUpstreamOAuthSessionRepository<'c> {
64 conn: &'c mut PgConnection,
65}
66
67impl<'c> PgUpstreamOAuthSessionRepository<'c> {
68 pub fn new(conn: &'c mut PgConnection) -> Self {
71 Self { conn }
72 }
73}
74
/// A raw row of the `upstream_oauth_authorization_sessions` table.
///
/// `sqlx::FromRow` maps result columns onto these fields by name, and
/// `#[enum_def]` generates the `SessionLookupIden` enum used as column aliases
/// when building the `list` query, so field names must match column names.
/// The session state is reconstructed from the nullable columns by the
/// `TryFrom<SessionLookup>` implementation.
#[derive(sqlx::FromRow)]
#[enum_def]
struct SessionLookup {
    upstream_oauth_authorization_session_id: Uuid,
    upstream_oauth_provider_id: Uuid,
    // Set once the session is completed; may be NULL again for unlinked sessions.
    upstream_oauth_link_id: Option<Uuid>,
    // The OAuth `state` parameter string (not the session state machine).
    state: String,
    code_challenge_verifier: Option<String>,
    nonce: Option<String>,
    id_token: Option<String>,
    id_token_claims: Option<serde_json::Value>,
    userinfo: Option<serde_json::Value>,
    created_at: DateTime<Utc>,
    completed_at: Option<DateTime<Utc>>,
    consumed_at: Option<DateTime<Utc>>,
    extra_callback_parameters: Option<serde_json::Value>,
    unlinked_at: Option<DateTime<Utc>>,
}
93
94impl TryFrom<SessionLookup> for UpstreamOAuthAuthorizationSession {
95 type Error = DatabaseInconsistencyError;
96
97 fn try_from(value: SessionLookup) -> Result<Self, Self::Error> {
98 let id = value.upstream_oauth_authorization_session_id.into();
99 let state = match (
100 value.upstream_oauth_link_id,
101 value.id_token,
102 value.id_token_claims,
103 value.extra_callback_parameters,
104 value.userinfo,
105 value.completed_at,
106 value.consumed_at,
107 value.unlinked_at,
108 ) {
109 (None, None, None, None, None, None, None, None) => {
110 UpstreamOAuthAuthorizationSessionState::Pending
111 }
112 (
113 Some(link_id),
114 id_token,
115 id_token_claims,
116 extra_callback_parameters,
117 userinfo,
118 Some(completed_at),
119 None,
120 None,
121 ) => UpstreamOAuthAuthorizationSessionState::Completed {
122 completed_at,
123 link_id: link_id.into(),
124 id_token,
125 id_token_claims,
126 extra_callback_parameters,
127 userinfo,
128 },
129 (
130 Some(link_id),
131 id_token,
132 id_token_claims,
133 extra_callback_parameters,
134 userinfo,
135 Some(completed_at),
136 Some(consumed_at),
137 None,
138 ) => UpstreamOAuthAuthorizationSessionState::Consumed {
139 completed_at,
140 link_id: link_id.into(),
141 id_token,
142 id_token_claims,
143 extra_callback_parameters,
144 userinfo,
145 consumed_at,
146 },
147 (
148 _,
149 id_token,
150 id_token_claims,
151 _,
152 _,
153 Some(completed_at),
154 consumed_at,
155 Some(unlinked_at),
156 ) => UpstreamOAuthAuthorizationSessionState::Unlinked {
157 completed_at,
158 id_token,
159 id_token_claims,
160 consumed_at,
161 unlinked_at,
162 },
163 _ => {
164 return Err(DatabaseInconsistencyError::on(
165 "upstream_oauth_authorization_sessions",
166 )
167 .row(id));
168 }
169 };
170
171 Ok(Self {
172 id,
173 provider_id: value.upstream_oauth_provider_id.into(),
174 state_str: value.state,
175 nonce: value.nonce,
176 code_challenge_verifier: value.code_challenge_verifier,
177 created_at: value.created_at,
178 state,
179 })
180 }
181}
182
183#[async_trait]
184impl UpstreamOAuthSessionRepository for PgUpstreamOAuthSessionRepository<'_> {
185 type Error = DatabaseError;
186
187 #[tracing::instrument(
188 name = "db.upstream_oauth_authorization_session.lookup",
189 skip_all,
190 fields(
191 db.query.text,
192 upstream_oauth_provider.id = %id,
193 ),
194 err,
195 )]
196 async fn lookup(
197 &mut self,
198 id: Ulid,
199 ) -> Result<Option<UpstreamOAuthAuthorizationSession>, Self::Error> {
200 let res = sqlx::query_as!(
201 SessionLookup,
202 r#"
203 SELECT
204 upstream_oauth_authorization_session_id,
205 upstream_oauth_provider_id,
206 upstream_oauth_link_id,
207 state,
208 code_challenge_verifier,
209 nonce,
210 id_token,
211 id_token_claims,
212 extra_callback_parameters,
213 userinfo,
214 created_at,
215 completed_at,
216 consumed_at,
217 unlinked_at
218 FROM upstream_oauth_authorization_sessions
219 WHERE upstream_oauth_authorization_session_id = $1
220 "#,
221 Uuid::from(id),
222 )
223 .traced()
224 .fetch_optional(&mut *self.conn)
225 .await?;
226
227 let Some(res) = res else { return Ok(None) };
228
229 Ok(Some(res.try_into()?))
230 }
231
232 #[tracing::instrument(
233 name = "db.upstream_oauth_authorization_session.add",
234 skip_all,
235 fields(
236 db.query.text,
237 %upstream_oauth_provider.id,
238 upstream_oauth_provider.issuer = upstream_oauth_provider.issuer,
239 %upstream_oauth_provider.client_id,
240 upstream_oauth_authorization_session.id,
241 ),
242 err,
243 )]
244 async fn add(
245 &mut self,
246 rng: &mut (dyn RngCore + Send),
247 clock: &dyn Clock,
248 upstream_oauth_provider: &UpstreamOAuthProvider,
249 state_str: String,
250 code_challenge_verifier: Option<String>,
251 nonce: Option<String>,
252 ) -> Result<UpstreamOAuthAuthorizationSession, Self::Error> {
253 let created_at = clock.now();
254 let id = Ulid::from_datetime_with_source(created_at.into(), rng);
255 tracing::Span::current().record(
256 "upstream_oauth_authorization_session.id",
257 tracing::field::display(id),
258 );
259
260 sqlx::query!(
261 r#"
262 INSERT INTO upstream_oauth_authorization_sessions (
263 upstream_oauth_authorization_session_id,
264 upstream_oauth_provider_id,
265 state,
266 code_challenge_verifier,
267 nonce,
268 created_at,
269 completed_at,
270 consumed_at,
271 id_token,
272 userinfo
273 ) VALUES ($1, $2, $3, $4, $5, $6, NULL, NULL, NULL, NULL)
274 "#,
275 Uuid::from(id),
276 Uuid::from(upstream_oauth_provider.id),
277 &state_str,
278 code_challenge_verifier.as_deref(),
279 nonce,
280 created_at,
281 )
282 .traced()
283 .execute(&mut *self.conn)
284 .await?;
285
286 Ok(UpstreamOAuthAuthorizationSession {
287 id,
288 state: UpstreamOAuthAuthorizationSessionState::default(),
289 provider_id: upstream_oauth_provider.id,
290 state_str,
291 code_challenge_verifier,
292 nonce,
293 created_at,
294 })
295 }
296
297 #[tracing::instrument(
298 name = "db.upstream_oauth_authorization_session.complete_with_link",
299 skip_all,
300 fields(
301 db.query.text,
302 %upstream_oauth_authorization_session.id,
303 %upstream_oauth_link.id,
304 ),
305 err,
306 )]
307 async fn complete_with_link(
308 &mut self,
309 clock: &dyn Clock,
310 upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession,
311 upstream_oauth_link: &UpstreamOAuthLink,
312 id_token: Option<String>,
313 id_token_claims: Option<serde_json::Value>,
314 extra_callback_parameters: Option<serde_json::Value>,
315 userinfo: Option<serde_json::Value>,
316 ) -> Result<UpstreamOAuthAuthorizationSession, Self::Error> {
317 let completed_at = clock.now();
318
319 sqlx::query!(
320 r#"
321 UPDATE upstream_oauth_authorization_sessions
322 SET upstream_oauth_link_id = $1
323 , completed_at = $2
324 , id_token = $3
325 , id_token_claims = $4
326 , extra_callback_parameters = $5
327 , userinfo = $6
328 WHERE upstream_oauth_authorization_session_id = $7
329 "#,
330 Uuid::from(upstream_oauth_link.id),
331 completed_at,
332 id_token,
333 id_token_claims,
334 extra_callback_parameters,
335 userinfo,
336 Uuid::from(upstream_oauth_authorization_session.id),
337 )
338 .traced()
339 .execute(&mut *self.conn)
340 .await?;
341
342 let upstream_oauth_authorization_session = upstream_oauth_authorization_session
343 .complete(
344 completed_at,
345 upstream_oauth_link,
346 id_token,
347 id_token_claims,
348 extra_callback_parameters,
349 userinfo,
350 )
351 .map_err(DatabaseError::to_invalid_operation)?;
352
353 Ok(upstream_oauth_authorization_session)
354 }
355
356 #[tracing::instrument(
358 name = "db.upstream_oauth_authorization_session.consume",
359 skip_all,
360 fields(
361 db.query.text,
362 %upstream_oauth_authorization_session.id,
363 ),
364 err,
365 )]
366 async fn consume(
367 &mut self,
368 clock: &dyn Clock,
369 upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession,
370 ) -> Result<UpstreamOAuthAuthorizationSession, Self::Error> {
371 let consumed_at = clock.now();
372 sqlx::query!(
373 r#"
374 UPDATE upstream_oauth_authorization_sessions
375 SET consumed_at = $1
376 WHERE upstream_oauth_authorization_session_id = $2
377 "#,
378 consumed_at,
379 Uuid::from(upstream_oauth_authorization_session.id),
380 )
381 .traced()
382 .execute(&mut *self.conn)
383 .await?;
384
385 let upstream_oauth_authorization_session = upstream_oauth_authorization_session
386 .consume(consumed_at)
387 .map_err(DatabaseError::to_invalid_operation)?;
388
389 Ok(upstream_oauth_authorization_session)
390 }
391
392 #[tracing::instrument(
393 name = "db.upstream_oauth_authorization_session.list",
394 skip_all,
395 fields(
396 db.query.text,
397 ),
398 err,
399 )]
400 async fn list(
401 &mut self,
402 filter: UpstreamOAuthSessionFilter<'_>,
403 pagination: Pagination,
404 ) -> Result<Page<UpstreamOAuthAuthorizationSession>, Self::Error> {
405 let (sql, arguments) = Query::select()
406 .expr_as(
407 Expr::col((
408 UpstreamOAuthAuthorizationSessions::Table,
409 UpstreamOAuthAuthorizationSessions::UpstreamOAuthAuthorizationSessionId,
410 )),
411 SessionLookupIden::UpstreamOauthAuthorizationSessionId,
412 )
413 .expr_as(
414 Expr::col((
415 UpstreamOAuthAuthorizationSessions::Table,
416 UpstreamOAuthAuthorizationSessions::UpstreamOAuthProviderId,
417 )),
418 SessionLookupIden::UpstreamOauthProviderId,
419 )
420 .expr_as(
421 Expr::col((
422 UpstreamOAuthAuthorizationSessions::Table,
423 UpstreamOAuthAuthorizationSessions::UpstreamOAuthLinkId,
424 )),
425 SessionLookupIden::UpstreamOauthLinkId,
426 )
427 .expr_as(
428 Expr::col((
429 UpstreamOAuthAuthorizationSessions::Table,
430 UpstreamOAuthAuthorizationSessions::State,
431 )),
432 SessionLookupIden::State,
433 )
434 .expr_as(
435 Expr::col((
436 UpstreamOAuthAuthorizationSessions::Table,
437 UpstreamOAuthAuthorizationSessions::CodeChallengeVerifier,
438 )),
439 SessionLookupIden::CodeChallengeVerifier,
440 )
441 .expr_as(
442 Expr::col((
443 UpstreamOAuthAuthorizationSessions::Table,
444 UpstreamOAuthAuthorizationSessions::Nonce,
445 )),
446 SessionLookupIden::Nonce,
447 )
448 .expr_as(
449 Expr::col((
450 UpstreamOAuthAuthorizationSessions::Table,
451 UpstreamOAuthAuthorizationSessions::IdToken,
452 )),
453 SessionLookupIden::IdToken,
454 )
455 .expr_as(
456 Expr::col((
457 UpstreamOAuthAuthorizationSessions::Table,
458 UpstreamOAuthAuthorizationSessions::IdTokenClaims,
459 )),
460 SessionLookupIden::IdTokenClaims,
461 )
462 .expr_as(
463 Expr::col((
464 UpstreamOAuthAuthorizationSessions::Table,
465 UpstreamOAuthAuthorizationSessions::ExtraCallbackParameters,
466 )),
467 SessionLookupIden::ExtraCallbackParameters,
468 )
469 .expr_as(
470 Expr::col((
471 UpstreamOAuthAuthorizationSessions::Table,
472 UpstreamOAuthAuthorizationSessions::Userinfo,
473 )),
474 SessionLookupIden::Userinfo,
475 )
476 .expr_as(
477 Expr::col((
478 UpstreamOAuthAuthorizationSessions::Table,
479 UpstreamOAuthAuthorizationSessions::CreatedAt,
480 )),
481 SessionLookupIden::CreatedAt,
482 )
483 .expr_as(
484 Expr::col((
485 UpstreamOAuthAuthorizationSessions::Table,
486 UpstreamOAuthAuthorizationSessions::CompletedAt,
487 )),
488 SessionLookupIden::CompletedAt,
489 )
490 .expr_as(
491 Expr::col((
492 UpstreamOAuthAuthorizationSessions::Table,
493 UpstreamOAuthAuthorizationSessions::ConsumedAt,
494 )),
495 SessionLookupIden::ConsumedAt,
496 )
497 .expr_as(
498 Expr::col((
499 UpstreamOAuthAuthorizationSessions::Table,
500 UpstreamOAuthAuthorizationSessions::UnlinkedAt,
501 )),
502 SessionLookupIden::UnlinkedAt,
503 )
504 .from(UpstreamOAuthAuthorizationSessions::Table)
505 .apply_filter(filter)
506 .generate_pagination(
507 (
508 UpstreamOAuthAuthorizationSessions::Table,
509 UpstreamOAuthAuthorizationSessions::UpstreamOAuthAuthorizationSessionId,
510 ),
511 pagination,
512 )
513 .build_sqlx(PostgresQueryBuilder);
514
515 let edges: Vec<SessionLookup> = sqlx::query_as_with(&sql, arguments)
516 .traced()
517 .fetch_all(&mut *self.conn)
518 .await?;
519
520 let page = pagination
521 .process(edges)
522 .try_map(UpstreamOAuthAuthorizationSession::try_from)?;
523
524 Ok(page)
525 }
526
527 #[tracing::instrument(
528 name = "db.upstream_oauth_authorization_session.count",
529 skip_all,
530 fields(
531 db.query.text,
532 ),
533 err,
534 )]
535 async fn count(
536 &mut self,
537 filter: UpstreamOAuthSessionFilter<'_>,
538 ) -> Result<usize, Self::Error> {
539 let (sql, arguments) = Query::select()
540 .expr(
541 Expr::col((
542 UpstreamOAuthAuthorizationSessions::Table,
543 UpstreamOAuthAuthorizationSessions::UpstreamOAuthAuthorizationSessionId,
544 ))
545 .count(),
546 )
547 .from(UpstreamOAuthAuthorizationSessions::Table)
548 .apply_filter(filter)
549 .build_sqlx(PostgresQueryBuilder);
550
551 let count: i64 = sqlx::query_scalar_with(&sql, arguments)
552 .traced()
553 .fetch_one(&mut *self.conn)
554 .await?;
555
556 count
557 .try_into()
558 .map_err(DatabaseError::to_invalid_operation)
559 }
560}