mas_storage_pg/upstream_oauth2/
session.rs

1// Copyright 2024, 2025 New Vector Ltd.
2// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
5// Please see LICENSE files in the repository root for full details.
6
7use async_trait::async_trait;
8use chrono::{DateTime, Utc};
9use mas_data_model::{
10    UpstreamOAuthAuthorizationSession, UpstreamOAuthAuthorizationSessionState, UpstreamOAuthLink,
11    UpstreamOAuthProvider,
12};
13use mas_storage::{
14    Clock, Page, Pagination,
15    upstream_oauth2::{UpstreamOAuthSessionFilter, UpstreamOAuthSessionRepository},
16};
17use rand::RngCore;
18use sea_query::{Expr, PostgresQueryBuilder, Query, enum_def, extension::postgres::PgExpr};
19use sea_query_binder::SqlxBinder;
20use sqlx::PgConnection;
21use ulid::Ulid;
22use uuid::Uuid;
23
24use crate::{
25    DatabaseError, DatabaseInconsistencyError,
26    filter::{Filter, StatementExt},
27    iden::UpstreamOAuthAuthorizationSessions,
28    pagination::QueryBuilderExt,
29    tracing::ExecuteExt,
30};
31
32impl Filter for UpstreamOAuthSessionFilter<'_> {
33    fn generate_condition(&self, _has_joins: bool) -> impl sea_query::IntoCondition {
34        sea_query::Condition::all()
35            .add_option(self.provider().map(|provider| {
36                Expr::col((
37                    UpstreamOAuthAuthorizationSessions::Table,
38                    UpstreamOAuthAuthorizationSessions::UpstreamOAuthProviderId,
39                ))
40                .eq(Uuid::from(provider.id))
41            }))
42            .add_option(self.sub_claim().map(|sub| {
43                Expr::col((
44                    UpstreamOAuthAuthorizationSessions::Table,
45                    UpstreamOAuthAuthorizationSessions::IdTokenClaims,
46                ))
47                .cast_json_field("sub")
48                .eq(sub)
49            }))
50            .add_option(self.sid_claim().map(|sid| {
51                Expr::col((
52                    UpstreamOAuthAuthorizationSessions::Table,
53                    UpstreamOAuthAuthorizationSessions::IdTokenClaims,
54                ))
55                .cast_json_field("sid")
56                .eq(sid)
57            }))
58    }
59}
60
61/// An implementation of [`UpstreamOAuthSessionRepository`] for a PostgreSQL
62/// connection
63pub struct PgUpstreamOAuthSessionRepository<'c> {
64    conn: &'c mut PgConnection,
65}
66
67impl<'c> PgUpstreamOAuthSessionRepository<'c> {
68    /// Create a new [`PgUpstreamOAuthSessionRepository`] from an active
69    /// PostgreSQL connection
70    pub fn new(conn: &'c mut PgConnection) -> Self {
71        Self { conn }
72    }
73}
74
/// One raw row of the `upstream_oauth_authorization_sessions` table.
///
/// The field names match the column names selected by `lookup` and the
/// aliases used by `list`. Do not rename them: `sqlx::FromRow` maps columns by
/// name, and `#[enum_def]` derives the `SessionLookupIden` identifiers that
/// `list` uses as column aliases.
///
/// Which of the nullable columns are set decides the session's lifecycle
/// state when the row is converted with `TryFrom` (see the impl below).
#[derive(sqlx::FromRow)]
#[enum_def]
struct SessionLookup {
    upstream_oauth_authorization_session_id: Uuid,
    upstream_oauth_provider_id: Uuid,
    // Set once the session has been completed with a link.
    upstream_oauth_link_id: Option<Uuid>,
    // Surfaced as `state_str` on the domain type.
    state: String,
    code_challenge_verifier: Option<String>,
    nonce: Option<String>,
    id_token: Option<String>,
    // Queried as JSON by the `sub`/`sid` filters.
    id_token_claims: Option<serde_json::Value>,
    userinfo: Option<serde_json::Value>,
    created_at: DateTime<Utc>,
    completed_at: Option<DateTime<Utc>>,
    consumed_at: Option<DateTime<Utc>>,
    extra_callback_parameters: Option<serde_json::Value>,
    unlinked_at: Option<DateTime<Utc>>,
}
93
94impl TryFrom<SessionLookup> for UpstreamOAuthAuthorizationSession {
95    type Error = DatabaseInconsistencyError;
96
97    fn try_from(value: SessionLookup) -> Result<Self, Self::Error> {
98        let id = value.upstream_oauth_authorization_session_id.into();
99        let state = match (
100            value.upstream_oauth_link_id,
101            value.id_token,
102            value.id_token_claims,
103            value.extra_callback_parameters,
104            value.userinfo,
105            value.completed_at,
106            value.consumed_at,
107            value.unlinked_at,
108        ) {
109            (None, None, None, None, None, None, None, None) => {
110                UpstreamOAuthAuthorizationSessionState::Pending
111            }
112            (
113                Some(link_id),
114                id_token,
115                id_token_claims,
116                extra_callback_parameters,
117                userinfo,
118                Some(completed_at),
119                None,
120                None,
121            ) => UpstreamOAuthAuthorizationSessionState::Completed {
122                completed_at,
123                link_id: link_id.into(),
124                id_token,
125                id_token_claims,
126                extra_callback_parameters,
127                userinfo,
128            },
129            (
130                Some(link_id),
131                id_token,
132                id_token_claims,
133                extra_callback_parameters,
134                userinfo,
135                Some(completed_at),
136                Some(consumed_at),
137                None,
138            ) => UpstreamOAuthAuthorizationSessionState::Consumed {
139                completed_at,
140                link_id: link_id.into(),
141                id_token,
142                id_token_claims,
143                extra_callback_parameters,
144                userinfo,
145                consumed_at,
146            },
147            (
148                _,
149                id_token,
150                id_token_claims,
151                _,
152                _,
153                Some(completed_at),
154                consumed_at,
155                Some(unlinked_at),
156            ) => UpstreamOAuthAuthorizationSessionState::Unlinked {
157                completed_at,
158                id_token,
159                id_token_claims,
160                consumed_at,
161                unlinked_at,
162            },
163            _ => {
164                return Err(DatabaseInconsistencyError::on(
165                    "upstream_oauth_authorization_sessions",
166                )
167                .row(id));
168            }
169        };
170
171        Ok(Self {
172            id,
173            provider_id: value.upstream_oauth_provider_id.into(),
174            state_str: value.state,
175            nonce: value.nonce,
176            code_challenge_verifier: value.code_challenge_verifier,
177            created_at: value.created_at,
178            state,
179        })
180    }
181}
182
183#[async_trait]
184impl UpstreamOAuthSessionRepository for PgUpstreamOAuthSessionRepository<'_> {
185    type Error = DatabaseError;
186
187    #[tracing::instrument(
188        name = "db.upstream_oauth_authorization_session.lookup",
189        skip_all,
190        fields(
191            db.query.text,
192            upstream_oauth_provider.id = %id,
193        ),
194        err,
195    )]
196    async fn lookup(
197        &mut self,
198        id: Ulid,
199    ) -> Result<Option<UpstreamOAuthAuthorizationSession>, Self::Error> {
200        let res = sqlx::query_as!(
201            SessionLookup,
202            r#"
203                SELECT
204                    upstream_oauth_authorization_session_id,
205                    upstream_oauth_provider_id,
206                    upstream_oauth_link_id,
207                    state,
208                    code_challenge_verifier,
209                    nonce,
210                    id_token,
211                    id_token_claims,
212                    extra_callback_parameters,
213                    userinfo,
214                    created_at,
215                    completed_at,
216                    consumed_at,
217                    unlinked_at
218                FROM upstream_oauth_authorization_sessions
219                WHERE upstream_oauth_authorization_session_id = $1
220            "#,
221            Uuid::from(id),
222        )
223        .traced()
224        .fetch_optional(&mut *self.conn)
225        .await?;
226
227        let Some(res) = res else { return Ok(None) };
228
229        Ok(Some(res.try_into()?))
230    }
231
232    #[tracing::instrument(
233        name = "db.upstream_oauth_authorization_session.add",
234        skip_all,
235        fields(
236            db.query.text,
237            %upstream_oauth_provider.id,
238            upstream_oauth_provider.issuer = upstream_oauth_provider.issuer,
239            %upstream_oauth_provider.client_id,
240            upstream_oauth_authorization_session.id,
241        ),
242        err,
243    )]
244    async fn add(
245        &mut self,
246        rng: &mut (dyn RngCore + Send),
247        clock: &dyn Clock,
248        upstream_oauth_provider: &UpstreamOAuthProvider,
249        state_str: String,
250        code_challenge_verifier: Option<String>,
251        nonce: Option<String>,
252    ) -> Result<UpstreamOAuthAuthorizationSession, Self::Error> {
253        let created_at = clock.now();
254        let id = Ulid::from_datetime_with_source(created_at.into(), rng);
255        tracing::Span::current().record(
256            "upstream_oauth_authorization_session.id",
257            tracing::field::display(id),
258        );
259
260        sqlx::query!(
261            r#"
262                INSERT INTO upstream_oauth_authorization_sessions (
263                    upstream_oauth_authorization_session_id,
264                    upstream_oauth_provider_id,
265                    state,
266                    code_challenge_verifier,
267                    nonce,
268                    created_at,
269                    completed_at,
270                    consumed_at,
271                    id_token,
272                    userinfo
273                ) VALUES ($1, $2, $3, $4, $5, $6, NULL, NULL, NULL, NULL)
274            "#,
275            Uuid::from(id),
276            Uuid::from(upstream_oauth_provider.id),
277            &state_str,
278            code_challenge_verifier.as_deref(),
279            nonce,
280            created_at,
281        )
282        .traced()
283        .execute(&mut *self.conn)
284        .await?;
285
286        Ok(UpstreamOAuthAuthorizationSession {
287            id,
288            state: UpstreamOAuthAuthorizationSessionState::default(),
289            provider_id: upstream_oauth_provider.id,
290            state_str,
291            code_challenge_verifier,
292            nonce,
293            created_at,
294        })
295    }
296
297    #[tracing::instrument(
298        name = "db.upstream_oauth_authorization_session.complete_with_link",
299        skip_all,
300        fields(
301            db.query.text,
302            %upstream_oauth_authorization_session.id,
303            %upstream_oauth_link.id,
304        ),
305        err,
306    )]
307    async fn complete_with_link(
308        &mut self,
309        clock: &dyn Clock,
310        upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession,
311        upstream_oauth_link: &UpstreamOAuthLink,
312        id_token: Option<String>,
313        id_token_claims: Option<serde_json::Value>,
314        extra_callback_parameters: Option<serde_json::Value>,
315        userinfo: Option<serde_json::Value>,
316    ) -> Result<UpstreamOAuthAuthorizationSession, Self::Error> {
317        let completed_at = clock.now();
318
319        sqlx::query!(
320            r#"
321                UPDATE upstream_oauth_authorization_sessions
322                SET upstream_oauth_link_id = $1
323                  , completed_at = $2
324                  , id_token = $3
325                  , id_token_claims = $4
326                  , extra_callback_parameters = $5
327                  , userinfo = $6
328                WHERE upstream_oauth_authorization_session_id = $7
329            "#,
330            Uuid::from(upstream_oauth_link.id),
331            completed_at,
332            id_token,
333            id_token_claims,
334            extra_callback_parameters,
335            userinfo,
336            Uuid::from(upstream_oauth_authorization_session.id),
337        )
338        .traced()
339        .execute(&mut *self.conn)
340        .await?;
341
342        let upstream_oauth_authorization_session = upstream_oauth_authorization_session
343            .complete(
344                completed_at,
345                upstream_oauth_link,
346                id_token,
347                id_token_claims,
348                extra_callback_parameters,
349                userinfo,
350            )
351            .map_err(DatabaseError::to_invalid_operation)?;
352
353        Ok(upstream_oauth_authorization_session)
354    }
355
356    /// Mark a session as consumed
357    #[tracing::instrument(
358        name = "db.upstream_oauth_authorization_session.consume",
359        skip_all,
360        fields(
361            db.query.text,
362            %upstream_oauth_authorization_session.id,
363        ),
364        err,
365    )]
366    async fn consume(
367        &mut self,
368        clock: &dyn Clock,
369        upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession,
370    ) -> Result<UpstreamOAuthAuthorizationSession, Self::Error> {
371        let consumed_at = clock.now();
372        sqlx::query!(
373            r#"
374                UPDATE upstream_oauth_authorization_sessions
375                SET consumed_at = $1
376                WHERE upstream_oauth_authorization_session_id = $2
377            "#,
378            consumed_at,
379            Uuid::from(upstream_oauth_authorization_session.id),
380        )
381        .traced()
382        .execute(&mut *self.conn)
383        .await?;
384
385        let upstream_oauth_authorization_session = upstream_oauth_authorization_session
386            .consume(consumed_at)
387            .map_err(DatabaseError::to_invalid_operation)?;
388
389        Ok(upstream_oauth_authorization_session)
390    }
391
392    #[tracing::instrument(
393        name = "db.upstream_oauth_authorization_session.list",
394        skip_all,
395        fields(
396            db.query.text,
397        ),
398        err,
399    )]
400    async fn list(
401        &mut self,
402        filter: UpstreamOAuthSessionFilter<'_>,
403        pagination: Pagination,
404    ) -> Result<Page<UpstreamOAuthAuthorizationSession>, Self::Error> {
405        let (sql, arguments) = Query::select()
406            .expr_as(
407                Expr::col((
408                    UpstreamOAuthAuthorizationSessions::Table,
409                    UpstreamOAuthAuthorizationSessions::UpstreamOAuthAuthorizationSessionId,
410                )),
411                SessionLookupIden::UpstreamOauthAuthorizationSessionId,
412            )
413            .expr_as(
414                Expr::col((
415                    UpstreamOAuthAuthorizationSessions::Table,
416                    UpstreamOAuthAuthorizationSessions::UpstreamOAuthProviderId,
417                )),
418                SessionLookupIden::UpstreamOauthProviderId,
419            )
420            .expr_as(
421                Expr::col((
422                    UpstreamOAuthAuthorizationSessions::Table,
423                    UpstreamOAuthAuthorizationSessions::UpstreamOAuthLinkId,
424                )),
425                SessionLookupIden::UpstreamOauthLinkId,
426            )
427            .expr_as(
428                Expr::col((
429                    UpstreamOAuthAuthorizationSessions::Table,
430                    UpstreamOAuthAuthorizationSessions::State,
431                )),
432                SessionLookupIden::State,
433            )
434            .expr_as(
435                Expr::col((
436                    UpstreamOAuthAuthorizationSessions::Table,
437                    UpstreamOAuthAuthorizationSessions::CodeChallengeVerifier,
438                )),
439                SessionLookupIden::CodeChallengeVerifier,
440            )
441            .expr_as(
442                Expr::col((
443                    UpstreamOAuthAuthorizationSessions::Table,
444                    UpstreamOAuthAuthorizationSessions::Nonce,
445                )),
446                SessionLookupIden::Nonce,
447            )
448            .expr_as(
449                Expr::col((
450                    UpstreamOAuthAuthorizationSessions::Table,
451                    UpstreamOAuthAuthorizationSessions::IdToken,
452                )),
453                SessionLookupIden::IdToken,
454            )
455            .expr_as(
456                Expr::col((
457                    UpstreamOAuthAuthorizationSessions::Table,
458                    UpstreamOAuthAuthorizationSessions::IdTokenClaims,
459                )),
460                SessionLookupIden::IdTokenClaims,
461            )
462            .expr_as(
463                Expr::col((
464                    UpstreamOAuthAuthorizationSessions::Table,
465                    UpstreamOAuthAuthorizationSessions::ExtraCallbackParameters,
466                )),
467                SessionLookupIden::ExtraCallbackParameters,
468            )
469            .expr_as(
470                Expr::col((
471                    UpstreamOAuthAuthorizationSessions::Table,
472                    UpstreamOAuthAuthorizationSessions::Userinfo,
473                )),
474                SessionLookupIden::Userinfo,
475            )
476            .expr_as(
477                Expr::col((
478                    UpstreamOAuthAuthorizationSessions::Table,
479                    UpstreamOAuthAuthorizationSessions::CreatedAt,
480                )),
481                SessionLookupIden::CreatedAt,
482            )
483            .expr_as(
484                Expr::col((
485                    UpstreamOAuthAuthorizationSessions::Table,
486                    UpstreamOAuthAuthorizationSessions::CompletedAt,
487                )),
488                SessionLookupIden::CompletedAt,
489            )
490            .expr_as(
491                Expr::col((
492                    UpstreamOAuthAuthorizationSessions::Table,
493                    UpstreamOAuthAuthorizationSessions::ConsumedAt,
494                )),
495                SessionLookupIden::ConsumedAt,
496            )
497            .expr_as(
498                Expr::col((
499                    UpstreamOAuthAuthorizationSessions::Table,
500                    UpstreamOAuthAuthorizationSessions::UnlinkedAt,
501                )),
502                SessionLookupIden::UnlinkedAt,
503            )
504            .from(UpstreamOAuthAuthorizationSessions::Table)
505            .apply_filter(filter)
506            .generate_pagination(
507                (
508                    UpstreamOAuthAuthorizationSessions::Table,
509                    UpstreamOAuthAuthorizationSessions::UpstreamOAuthAuthorizationSessionId,
510                ),
511                pagination,
512            )
513            .build_sqlx(PostgresQueryBuilder);
514
515        let edges: Vec<SessionLookup> = sqlx::query_as_with(&sql, arguments)
516            .traced()
517            .fetch_all(&mut *self.conn)
518            .await?;
519
520        let page = pagination
521            .process(edges)
522            .try_map(UpstreamOAuthAuthorizationSession::try_from)?;
523
524        Ok(page)
525    }
526
527    #[tracing::instrument(
528        name = "db.upstream_oauth_authorization_session.count",
529        skip_all,
530        fields(
531            db.query.text,
532        ),
533        err,
534    )]
535    async fn count(
536        &mut self,
537        filter: UpstreamOAuthSessionFilter<'_>,
538    ) -> Result<usize, Self::Error> {
539        let (sql, arguments) = Query::select()
540            .expr(
541                Expr::col((
542                    UpstreamOAuthAuthorizationSessions::Table,
543                    UpstreamOAuthAuthorizationSessions::UpstreamOAuthAuthorizationSessionId,
544                ))
545                .count(),
546            )
547            .from(UpstreamOAuthAuthorizationSessions::Table)
548            .apply_filter(filter)
549            .build_sqlx(PostgresQueryBuilder);
550
551        let count: i64 = sqlx::query_scalar_with(&sql, arguments)
552            .traced()
553            .fetch_one(&mut *self.conn)
554            .await?;
555
556        count
557            .try_into()
558            .map_err(DatabaseError::to_invalid_operation)
559    }
560}