mas_storage/upstream_oauth2/
session.rs

1// Copyright 2024, 2025 New Vector Ltd.
2// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
5// Please see LICENSE files in the repository root for full details.
6
7use async_trait::async_trait;
8use mas_data_model::{UpstreamOAuthAuthorizationSession, UpstreamOAuthLink, UpstreamOAuthProvider};
9use rand_core::RngCore;
10use ulid::Ulid;
11
12use crate::{Clock, Pagination, pagination::Page, repository_impl};
13
14/// Filter parameters for listing upstream OAuth sessions
15#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
16pub struct UpstreamOAuthSessionFilter<'a> {
17    provider: Option<&'a UpstreamOAuthProvider>,
18    sub_claim: Option<&'a str>,
19    sid_claim: Option<&'a str>,
20}
21
22impl<'a> UpstreamOAuthSessionFilter<'a> {
23    /// Create a new [`UpstreamOAuthSessionFilter`] with default values
24    #[must_use]
25    pub fn new() -> Self {
26        Self::default()
27    }
28
29    /// Set the upstream OAuth provider for which to list sessions
30    #[must_use]
31    pub fn for_provider(mut self, provider: &'a UpstreamOAuthProvider) -> Self {
32        self.provider = Some(provider);
33        self
34    }
35
36    /// Get the upstream OAuth provider filter
37    ///
38    /// Returns [`None`] if no filter was set
39    #[must_use]
40    pub fn provider(&self) -> Option<&UpstreamOAuthProvider> {
41        self.provider
42    }
43
44    /// Set the `sub` claim to filter by
45    #[must_use]
46    pub fn with_sub_claim(mut self, sub_claim: &'a str) -> Self {
47        self.sub_claim = Some(sub_claim);
48        self
49    }
50
51    /// Get the `sub` claim filter
52    ///
53    /// Returns [`None`] if no filter was set
54    #[must_use]
55    pub fn sub_claim(&self) -> Option<&str> {
56        self.sub_claim
57    }
58
59    /// Set the `sid` claim to filter by
60    #[must_use]
61    pub fn with_sid_claim(mut self, sid_claim: &'a str) -> Self {
62        self.sid_claim = Some(sid_claim);
63        self
64    }
65
66    /// Get the `sid` claim filter
67    ///
68    /// Returns [`None`] if no filter was set
69    #[must_use]
70    pub fn sid_claim(&self) -> Option<&str> {
71        self.sid_claim
72    }
73}
74
75/// An [`UpstreamOAuthSessionRepository`] helps interacting with
76/// [`UpstreamOAuthAuthorizationSession`] saved in the storage backend
77#[async_trait]
78pub trait UpstreamOAuthSessionRepository: Send + Sync {
79    /// The error type returned by the repository
80    type Error;
81
82    /// Lookup a session by its ID
83    ///
84    /// Returns `None` if the session does not exist
85    ///
86    /// # Parameters
87    ///
88    /// * `id`: the ID of the session to lookup
89    ///
90    /// # Errors
91    ///
92    /// Returns [`Self::Error`] if the underlying repository fails
93    async fn lookup(
94        &mut self,
95        id: Ulid,
96    ) -> Result<Option<UpstreamOAuthAuthorizationSession>, Self::Error>;
97
98    /// Add a session to the database
99    ///
100    /// Returns the newly created session
101    ///
102    /// # Parameters
103    ///
104    /// * `rng`: the random number generator to use
105    /// * `clock`: the clock source
106    /// * `upstream_oauth_provider`: the upstream OAuth provider for which to
107    ///   create the session
108    /// * `state`: the authorization grant `state` parameter sent to the
109    ///   upstream OAuth provider
110    /// * `code_challenge_verifier`: the code challenge verifier used in this
111    ///   session, if PKCE is being used
112    /// * `nonce`: the `nonce` used in this session if in OIDC mode
113    ///
114    /// # Errors
115    ///
116    /// Returns [`Self::Error`] if the underlying repository fails
117    async fn add(
118        &mut self,
119        rng: &mut (dyn RngCore + Send),
120        clock: &dyn Clock,
121        upstream_oauth_provider: &UpstreamOAuthProvider,
122        state: String,
123        code_challenge_verifier: Option<String>,
124        nonce: Option<String>,
125    ) -> Result<UpstreamOAuthAuthorizationSession, Self::Error>;
126
127    /// Mark a session as completed and associate the given link
128    ///
129    /// Returns the updated session
130    ///
131    /// # Parameters
132    ///
133    /// * `clock`: the clock source
134    /// * `upstream_oauth_authorization_session`: the session to update
135    /// * `upstream_oauth_link`: the link to associate with the session
136    /// * `id_token`: the ID token returned by the upstream OAuth provider, if
137    ///   present
138    /// * `id_token_claims`: the claims contained in the ID token, if present
139    /// * `extra_callback_parameters`: the extra query parameters returned in
140    ///   the callback, if any
141    /// * `userinfo`: the user info returned by the upstream OAuth provider, if
142    ///   requested
143    ///
144    /// # Errors
145    ///
146    /// Returns [`Self::Error`] if the underlying repository fails
147    #[expect(clippy::too_many_arguments)]
148    async fn complete_with_link(
149        &mut self,
150        clock: &dyn Clock,
151        upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession,
152        upstream_oauth_link: &UpstreamOAuthLink,
153        id_token: Option<String>,
154        id_token_claims: Option<serde_json::Value>,
155        extra_callback_parameters: Option<serde_json::Value>,
156        userinfo: Option<serde_json::Value>,
157    ) -> Result<UpstreamOAuthAuthorizationSession, Self::Error>;
158
159    /// Mark a session as consumed
160    ///
161    /// Returns the updated session
162    ///
163    /// # Parameters
164    ///
165    /// * `clock`: the clock source
166    /// * `upstream_oauth_authorization_session`: the session to consume
167    ///
168    /// # Errors
169    ///
170    /// Returns [`Self::Error`] if the underlying repository fails
171    async fn consume(
172        &mut self,
173        clock: &dyn Clock,
174        upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession,
175    ) -> Result<UpstreamOAuthAuthorizationSession, Self::Error>;
176
177    /// List [`UpstreamOAuthAuthorizationSession`] with the given filter and
178    /// pagination
179    ///
180    /// # Parameters
181    ///
182    /// * `filter`: The filter to apply
183    /// * `pagination`: The pagination parameters
184    ///
185    /// # Errors
186    ///
187    /// Returns [`Self::Error`] if the underlying repository fails
188    async fn list(
189        &mut self,
190        filter: UpstreamOAuthSessionFilter<'_>,
191        pagination: Pagination,
192    ) -> Result<Page<UpstreamOAuthAuthorizationSession>, Self::Error>;
193
194    /// Count the number of [`UpstreamOAuthAuthorizationSession`] with the given
195    /// filter
196    ///
197    /// # Parameters
198    ///
199    /// * `filter`: The filter to apply
200    ///
201    /// # Errors
202    ///
203    /// Returns [`Self::Error`] if the underlying repository fails
204    async fn count(&mut self, filter: UpstreamOAuthSessionFilter<'_>)
205    -> Result<usize, Self::Error>;
206}
207
208repository_impl!(UpstreamOAuthSessionRepository:
209    async fn lookup(
210        &mut self,
211        id: Ulid,
212    ) -> Result<Option<UpstreamOAuthAuthorizationSession>, Self::Error>;
213
214    async fn add(
215        &mut self,
216        rng: &mut (dyn RngCore + Send),
217        clock: &dyn Clock,
218        upstream_oauth_provider: &UpstreamOAuthProvider,
219        state: String,
220        code_challenge_verifier: Option<String>,
221        nonce: Option<String>,
222    ) -> Result<UpstreamOAuthAuthorizationSession, Self::Error>;
223
224    async fn complete_with_link(
225        &mut self,
226        clock: &dyn Clock,
227        upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession,
228        upstream_oauth_link: &UpstreamOAuthLink,
229        id_token: Option<String>,
230        id_token_claims: Option<serde_json::Value>,
231        extra_callback_parameters: Option<serde_json::Value>,
232        userinfo: Option<serde_json::Value>,
233    ) -> Result<UpstreamOAuthAuthorizationSession, Self::Error>;
234
235    async fn consume(
236        &mut self,
237        clock: &dyn Clock,
238        upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession,
239    ) -> Result<UpstreamOAuthAuthorizationSession, Self::Error>;
240
241    async fn list(
242        &mut self,
243        filter: UpstreamOAuthSessionFilter<'_>,
244        pagination: Pagination,
245    ) -> Result<Page<UpstreamOAuthAuthorizationSession>, Self::Error>;
246
247    async fn count(&mut self, filter: UpstreamOAuthSessionFilter<'_>) -> Result<usize, Self::Error>;
248);