mas_storage/upstream_oauth2/session.rs
1// Copyright 2024, 2025 New Vector Ltd.
2// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
5// Please see LICENSE files in the repository root for full details.
6
7use async_trait::async_trait;
8use mas_data_model::{UpstreamOAuthAuthorizationSession, UpstreamOAuthLink, UpstreamOAuthProvider};
9use rand_core::RngCore;
10use ulid::Ulid;
11
12use crate::{Clock, Pagination, pagination::Page, repository_impl};
13
14/// Filter parameters for listing upstream OAuth sessions
15#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
16pub struct UpstreamOAuthSessionFilter<'a> {
17 provider: Option<&'a UpstreamOAuthProvider>,
18 sub_claim: Option<&'a str>,
19 sid_claim: Option<&'a str>,
20}
21
22impl<'a> UpstreamOAuthSessionFilter<'a> {
23 /// Create a new [`UpstreamOAuthSessionFilter`] with default values
24 #[must_use]
25 pub fn new() -> Self {
26 Self::default()
27 }
28
29 /// Set the upstream OAuth provider for which to list sessions
30 #[must_use]
31 pub fn for_provider(mut self, provider: &'a UpstreamOAuthProvider) -> Self {
32 self.provider = Some(provider);
33 self
34 }
35
36 /// Get the upstream OAuth provider filter
37 ///
38 /// Returns [`None`] if no filter was set
39 #[must_use]
40 pub fn provider(&self) -> Option<&UpstreamOAuthProvider> {
41 self.provider
42 }
43
44 /// Set the `sub` claim to filter by
45 #[must_use]
46 pub fn with_sub_claim(mut self, sub_claim: &'a str) -> Self {
47 self.sub_claim = Some(sub_claim);
48 self
49 }
50
51 /// Get the `sub` claim filter
52 ///
53 /// Returns [`None`] if no filter was set
54 #[must_use]
55 pub fn sub_claim(&self) -> Option<&str> {
56 self.sub_claim
57 }
58
59 /// Set the `sid` claim to filter by
60 #[must_use]
61 pub fn with_sid_claim(mut self, sid_claim: &'a str) -> Self {
62 self.sid_claim = Some(sid_claim);
63 self
64 }
65
66 /// Get the `sid` claim filter
67 ///
68 /// Returns [`None`] if no filter was set
69 #[must_use]
70 pub fn sid_claim(&self) -> Option<&str> {
71 self.sid_claim
72 }
73}
74
75/// An [`UpstreamOAuthSessionRepository`] helps interacting with
76/// [`UpstreamOAuthAuthorizationSession`] saved in the storage backend
77#[async_trait]
78pub trait UpstreamOAuthSessionRepository: Send + Sync {
79 /// The error type returned by the repository
80 type Error;
81
82 /// Lookup a session by its ID
83 ///
84 /// Returns `None` if the session does not exist
85 ///
86 /// # Parameters
87 ///
88 /// * `id`: the ID of the session to lookup
89 ///
90 /// # Errors
91 ///
92 /// Returns [`Self::Error`] if the underlying repository fails
93 async fn lookup(
94 &mut self,
95 id: Ulid,
96 ) -> Result<Option<UpstreamOAuthAuthorizationSession>, Self::Error>;
97
98 /// Add a session to the database
99 ///
100 /// Returns the newly created session
101 ///
102 /// # Parameters
103 ///
104 /// * `rng`: the random number generator to use
105 /// * `clock`: the clock source
106 /// * `upstream_oauth_provider`: the upstream OAuth provider for which to
107 /// create the session
108 /// * `state`: the authorization grant `state` parameter sent to the
109 /// upstream OAuth provider
110 /// * `code_challenge_verifier`: the code challenge verifier used in this
111 /// session, if PKCE is being used
112 /// * `nonce`: the `nonce` used in this session if in OIDC mode
113 ///
114 /// # Errors
115 ///
116 /// Returns [`Self::Error`] if the underlying repository fails
117 async fn add(
118 &mut self,
119 rng: &mut (dyn RngCore + Send),
120 clock: &dyn Clock,
121 upstream_oauth_provider: &UpstreamOAuthProvider,
122 state: String,
123 code_challenge_verifier: Option<String>,
124 nonce: Option<String>,
125 ) -> Result<UpstreamOAuthAuthorizationSession, Self::Error>;
126
127 /// Mark a session as completed and associate the given link
128 ///
129 /// Returns the updated session
130 ///
131 /// # Parameters
132 ///
133 /// * `clock`: the clock source
134 /// * `upstream_oauth_authorization_session`: the session to update
135 /// * `upstream_oauth_link`: the link to associate with the session
136 /// * `id_token`: the ID token returned by the upstream OAuth provider, if
137 /// present
138 /// * `id_token_claims`: the claims contained in the ID token, if present
139 /// * `extra_callback_parameters`: the extra query parameters returned in
140 /// the callback, if any
141 /// * `userinfo`: the user info returned by the upstream OAuth provider, if
142 /// requested
143 ///
144 /// # Errors
145 ///
146 /// Returns [`Self::Error`] if the underlying repository fails
147 #[expect(clippy::too_many_arguments)]
148 async fn complete_with_link(
149 &mut self,
150 clock: &dyn Clock,
151 upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession,
152 upstream_oauth_link: &UpstreamOAuthLink,
153 id_token: Option<String>,
154 id_token_claims: Option<serde_json::Value>,
155 extra_callback_parameters: Option<serde_json::Value>,
156 userinfo: Option<serde_json::Value>,
157 ) -> Result<UpstreamOAuthAuthorizationSession, Self::Error>;
158
159 /// Mark a session as consumed
160 ///
161 /// Returns the updated session
162 ///
163 /// # Parameters
164 ///
165 /// * `clock`: the clock source
166 /// * `upstream_oauth_authorization_session`: the session to consume
167 ///
168 /// # Errors
169 ///
170 /// Returns [`Self::Error`] if the underlying repository fails
171 async fn consume(
172 &mut self,
173 clock: &dyn Clock,
174 upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession,
175 ) -> Result<UpstreamOAuthAuthorizationSession, Self::Error>;
176
177 /// List [`UpstreamOAuthAuthorizationSession`] with the given filter and
178 /// pagination
179 ///
180 /// # Parameters
181 ///
182 /// * `filter`: The filter to apply
183 /// * `pagination`: The pagination parameters
184 ///
185 /// # Errors
186 ///
187 /// Returns [`Self::Error`] if the underlying repository fails
188 async fn list(
189 &mut self,
190 filter: UpstreamOAuthSessionFilter<'_>,
191 pagination: Pagination,
192 ) -> Result<Page<UpstreamOAuthAuthorizationSession>, Self::Error>;
193
194 /// Count the number of [`UpstreamOAuthAuthorizationSession`] with the given
195 /// filter
196 ///
197 /// # Parameters
198 ///
199 /// * `filter`: The filter to apply
200 ///
201 /// # Errors
202 ///
203 /// Returns [`Self::Error`] if the underlying repository fails
204 async fn count(&mut self, filter: UpstreamOAuthSessionFilter<'_>)
205 -> Result<usize, Self::Error>;
206}
207
// Invoke the crate's `repository_impl!` macro (imported from `crate` at the top
// of this file) with the full list of `UpstreamOAuthSessionRepository` method
// signatures. They mirror the trait declaration above one-for-one, so keep the
// two in sync whenever the trait changes.
// NOTE(review): the exact impls this generates depend on the macro definition,
// which is not visible here — confirm against `crate::repository_impl`.
// Only plain `//` comments are used inside the invocation below, because `///`
// doc comments would become attributes that the macro matcher may reject.
repository_impl!(UpstreamOAuthSessionRepository:
    async fn lookup(
        &mut self,
        id: Ulid,
    ) -> Result<Option<UpstreamOAuthAuthorizationSession>, Self::Error>;

    async fn add(
        &mut self,
        rng: &mut (dyn RngCore + Send),
        clock: &dyn Clock,
        upstream_oauth_provider: &UpstreamOAuthProvider,
        state: String,
        code_challenge_verifier: Option<String>,
        nonce: Option<String>,
    ) -> Result<UpstreamOAuthAuthorizationSession, Self::Error>;

    async fn complete_with_link(
        &mut self,
        clock: &dyn Clock,
        upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession,
        upstream_oauth_link: &UpstreamOAuthLink,
        id_token: Option<String>,
        id_token_claims: Option<serde_json::Value>,
        extra_callback_parameters: Option<serde_json::Value>,
        userinfo: Option<serde_json::Value>,
    ) -> Result<UpstreamOAuthAuthorizationSession, Self::Error>;

    async fn consume(
        &mut self,
        clock: &dyn Clock,
        upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession,
    ) -> Result<UpstreamOAuthAuthorizationSession, Self::Error>;

    async fn list(
        &mut self,
        filter: UpstreamOAuthSessionFilter<'_>,
        pagination: Pagination,
    ) -> Result<Page<UpstreamOAuthAuthorizationSession>, Self::Error>;

    async fn count(&mut self, filter: UpstreamOAuthSessionFilter<'_>) -> Result<usize, Self::Error>;
);