mas_config/sections/
upstream_oauth2.rs

1// Copyright 2024, 2025 New Vector Ltd.
2// Copyright 2023, 2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
5// Please see LICENSE files in the repository root for full details.
6
7use std::collections::BTreeMap;
8
9use camino::Utf8PathBuf;
10use mas_iana::jose::JsonWebSignatureAlg;
11use schemars::JsonSchema;
12use serde::{Deserialize, Serialize, de::Error};
13use serde_with::skip_serializing_none;
14use ulid::Ulid;
15use url::Url;
16
17use crate::ConfigurationSection;
18
/// Upstream OAuth 2.0 providers configuration
// Maintainer note: schemars derives the JSON schema descriptions from the `///`
// comments in this file, so implementation notes are written as plain `//`
// comments to keep them out of the schema.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, Default)]
pub struct UpstreamOAuth2Config {
    /// List of OAuth 2.0 providers
    // This field has no `#[serde(default)]`, so serde requires it whenever this
    // struct is deserialized. Each entry is checked by
    // `ConfigurationSection::validate`.
    pub providers: Vec<Provider>,
}
25
26impl UpstreamOAuth2Config {
27    /// Returns true if the configuration is the default one
28    pub(crate) fn is_default(&self) -> bool {
29        self.providers.is_empty()
30    }
31}
32
33impl ConfigurationSection for UpstreamOAuth2Config {
34    const PATH: Option<&'static str> = Some("upstream_oauth2");
35
36    fn validate(&self, figment: &figment::Figment) -> Result<(), figment::Error> {
37        for (index, provider) in self.providers.iter().enumerate() {
38            let annotate = |mut error: figment::Error| {
39                error.metadata = figment
40                    .find_metadata(&format!("{root}.providers", root = Self::PATH.unwrap()))
41                    .cloned();
42                error.profile = Some(figment::Profile::Default);
43                error.path = vec![
44                    Self::PATH.unwrap().to_owned(),
45                    "providers".to_owned(),
46                    index.to_string(),
47                ];
48                Err(error)
49            };
50
51            if !matches!(provider.discovery_mode, DiscoveryMode::Disabled)
52                && provider.issuer.is_none()
53            {
54                return annotate(figment::Error::custom(
55                    "The `issuer` field is required when discovery is enabled",
56                ));
57            }
58
59            match provider.token_endpoint_auth_method {
60                TokenAuthMethod::None
61                | TokenAuthMethod::PrivateKeyJwt
62                | TokenAuthMethod::SignInWithApple => {
63                    if provider.client_secret.is_some() {
64                        return annotate(figment::Error::custom(
65                            "Unexpected field `client_secret` for the selected authentication method",
66                        ));
67                    }
68                }
69                TokenAuthMethod::ClientSecretBasic
70                | TokenAuthMethod::ClientSecretPost
71                | TokenAuthMethod::ClientSecretJwt => {
72                    if provider.client_secret.is_none() {
73                        return annotate(figment::Error::missing_field("client_secret"));
74                    }
75                }
76            }
77
78            match provider.token_endpoint_auth_method {
79                TokenAuthMethod::None
80                | TokenAuthMethod::ClientSecretBasic
81                | TokenAuthMethod::ClientSecretPost
82                | TokenAuthMethod::SignInWithApple => {
83                    if provider.token_endpoint_auth_signing_alg.is_some() {
84                        return annotate(figment::Error::custom(
85                            "Unexpected field `token_endpoint_auth_signing_alg` for the selected authentication method",
86                        ));
87                    }
88                }
89                TokenAuthMethod::ClientSecretJwt | TokenAuthMethod::PrivateKeyJwt => {
90                    if provider.token_endpoint_auth_signing_alg.is_none() {
91                        return annotate(figment::Error::missing_field(
92                            "token_endpoint_auth_signing_alg",
93                        ));
94                    }
95                }
96            }
97
98            match provider.token_endpoint_auth_method {
99                TokenAuthMethod::SignInWithApple => {
100                    if provider.sign_in_with_apple.is_none() {
101                        return annotate(figment::Error::missing_field("sign_in_with_apple"));
102                    }
103                }
104
105                _ => {
106                    if provider.sign_in_with_apple.is_some() {
107                        return annotate(figment::Error::custom(
108                            "Unexpected field `sign_in_with_apple` for the selected authentication method",
109                        ));
110                    }
111                }
112            }
113        }
114
115        Ok(())
116    }
117}
118
/// The response mode we ask the provider to use for the callback
// With `rename_all = "snake_case"`, the configuration values are `query` and
// `form_post`. No variant is marked `#[default]`; `Provider` holds this as an
// `Option<ResponseMode>`.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "snake_case")]
pub enum ResponseMode {
    /// `query`: The provider will send the response as a query string in the
    /// URL search parameters
    Query,

    /// `form_post`: The provider will send the response as a POST request with
    /// the response parameters in the request body
    ///
    /// <https://openid.net/specs/oauth-v2-form-post-response-mode-1_0.html>
    FormPost,
}
133
134/// Authentication methods used against the OAuth 2.0 provider
135#[derive(Debug, Clone, Copy, Serialize, Deserialize, JsonSchema)]
136#[serde(rename_all = "snake_case")]
137pub enum TokenAuthMethod {
138    /// `none`: No authentication
139    None,
140
141    /// `client_secret_basic`: `client_id` and `client_secret` used as basic
142    /// authorization credentials
143    ClientSecretBasic,
144
145    /// `client_secret_post`: `client_id` and `client_secret` sent in the
146    /// request body
147    ClientSecretPost,
148
149    /// `client_secret_jwt`: a `client_assertion` sent in the request body and
150    /// signed using the `client_secret`
151    ClientSecretJwt,
152
153    /// `private_key_jwt`: a `client_assertion` sent in the request body and
154    /// signed by an asymmetric key
155    PrivateKeyJwt,
156
157    /// `sign_in_with_apple`: a special method for Signin with Apple
158    SignInWithApple,
159}
160
/// How to handle a claim
// With `rename_all = "lowercase"`, the configuration values are `ignore`,
// `suggest`, `force` and `require`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default, JsonSchema)]
#[serde(rename_all = "lowercase")]
pub enum ImportAction {
    /// Ignore the claim
    // `#[default]`: this is what an omitted `action` field deserializes to.
    #[default]
    Ignore,

    /// Suggest the claim value, but allow the user to change it
    Suggest,

    /// Force the claim value, but don't fail if it is missing
    Force,

    /// Force the claim value, and fail if it is missing
    Require,
}
178
179impl ImportAction {
180    #[allow(clippy::trivially_copy_pass_by_ref)]
181    const fn is_default(&self) -> bool {
182        matches!(self, ImportAction::Ignore)
183    }
184}
185
/// What should be done for the subject attribute
// Every field has `#[serde(default)]`, so an empty map deserializes to the
// `Default` value.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default, JsonSchema)]
pub struct SubjectImportPreference {
    /// The Jinja2 template to use for the subject attribute
    ///
    /// If not provided, the default template is `{{ user.sub }}`
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub template: Option<String>,
}
195
196impl SubjectImportPreference {
197    const fn is_default(&self) -> bool {
198        self.template.is_none()
199    }
200}
201
/// What should be done for the localpart attribute
// Every field has `#[serde(default)]`, so an empty map deserializes to the
// `Default` value.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default, JsonSchema)]
pub struct LocalpartImportPreference {
    /// How to handle the attribute
    #[serde(default, skip_serializing_if = "ImportAction::is_default")]
    pub action: ImportAction,

    /// The Jinja2 template to use for the localpart attribute
    ///
    /// If not provided, the default template is `{{ user.preferred_username }}`
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub template: Option<String>,
}
215
216impl LocalpartImportPreference {
217    const fn is_default(&self) -> bool {
218        self.action.is_default() && self.template.is_none()
219    }
220}
221
/// What should be done for the displayname attribute
// Every field has `#[serde(default)]`, so an empty map deserializes to the
// `Default` value.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default, JsonSchema)]
pub struct DisplaynameImportPreference {
    /// How to handle the attribute
    #[serde(default, skip_serializing_if = "ImportAction::is_default")]
    pub action: ImportAction,

    /// The Jinja2 template to use for the displayname attribute
    ///
    /// If not provided, the default template is `{{ user.name }}`
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub template: Option<String>,
}
235
236impl DisplaynameImportPreference {
237    const fn is_default(&self) -> bool {
238        self.action.is_default() && self.template.is_none()
239    }
240}
241
/// What should be done with the email attribute
// Every field has `#[serde(default)]`, so an empty map deserializes to the
// `Default` value.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default, JsonSchema)]
pub struct EmailImportPreference {
    /// How to handle the claim
    #[serde(default, skip_serializing_if = "ImportAction::is_default")]
    pub action: ImportAction,

    /// The Jinja2 template to use for the email address attribute
    ///
    /// If not provided, the default template is `{{ user.email }}`
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub template: Option<String>,
}
255
256impl EmailImportPreference {
257    const fn is_default(&self) -> bool {
258        self.action.is_default() && self.template.is_none()
259    }
260}
261
/// What should be done for the account name attribute
// Every field has `#[serde(default)]`, so an empty map deserializes to the
// `Default` value.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default, JsonSchema)]
pub struct AccountNameImportPreference {
    /// The Jinja2 template to use for the account name. This name is only used
    /// for display purposes.
    ///
    /// If not provided, it will be ignored.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub template: Option<String>,
}
272
273impl AccountNameImportPreference {
274    const fn is_default(&self) -> bool {
275        self.template.is_none()
276    }
277}
278
/// How claims should be imported
// Each field has `#[serde(default)]` and is skipped when serializing while it
// is at its default, as decided by that preference type's `is_default`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default, JsonSchema)]
pub struct ClaimsImports {
    /// How to determine the subject of the user
    #[serde(default, skip_serializing_if = "SubjectImportPreference::is_default")]
    pub subject: SubjectImportPreference,

    /// Import the localpart of the MXID
    #[serde(default, skip_serializing_if = "LocalpartImportPreference::is_default")]
    pub localpart: LocalpartImportPreference,

    /// Import the displayname of the user.
    #[serde(
        default,
        skip_serializing_if = "DisplaynameImportPreference::is_default"
    )]
    pub displayname: DisplaynameImportPreference,

    /// Import the email address of the user based on the `email` and
    /// `email_verified` claims
    #[serde(default, skip_serializing_if = "EmailImportPreference::is_default")]
    pub email: EmailImportPreference,

    /// Set a human-readable name for the upstream account for display purposes
    #[serde(
        default,
        skip_serializing_if = "AccountNameImportPreference::is_default"
    )]
    pub account_name: AccountNameImportPreference,
}
309
310impl ClaimsImports {
311    const fn is_default(&self) -> bool {
312        self.subject.is_default()
313            && self.localpart.is_default()
314            && self.displayname.is_default()
315            && self.email.is_default()
316    }
317}
318
/// How to discover the provider's configuration
// With `rename_all = "snake_case"`, the configuration values are `oidc`,
// `insecure` and `disabled`. Only `Disabled` lifts the `issuer` requirement
// checked in `ConfigurationSection::validate`.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, JsonSchema, Default)]
#[serde(rename_all = "snake_case")]
pub enum DiscoveryMode {
    /// Use OIDC discovery with strict metadata verification
    #[default]
    Oidc,

    /// Use OIDC discovery with relaxed metadata verification
    Insecure,

    /// Use a static configuration
    Disabled,
}
333
334impl DiscoveryMode {
335    #[allow(clippy::trivially_copy_pass_by_ref)]
336    const fn is_default(&self) -> bool {
337        matches!(self, DiscoveryMode::Oidc)
338    }
339}
340
/// Whether to use proof key for code exchange (PKCE) when requesting and
/// exchanging the token.
// With `rename_all = "snake_case"`, the configuration values are `auto`,
// `always` and `never`.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, JsonSchema, Default)]
#[serde(rename_all = "snake_case")]
pub enum PkceMethod {
    /// Use PKCE if the provider supports it
    ///
    /// Defaults to no PKCE if provider discovery is disabled
    #[default]
    Auto,

    /// Always use PKCE with the S256 challenge method
    Always,

    /// Never use PKCE
    Never,
}
358
359impl PkceMethod {
360    #[allow(clippy::trivially_copy_pass_by_ref)]
361    const fn is_default(&self) -> bool {
362        matches!(self, PkceMethod::Auto)
363    }
364}
365
/// Serde `default` for flags that are on unless configured otherwise, such as
/// `Provider::enabled`.
fn default_true() -> bool {
    true
}

/// `skip_serializing_if` companion of `default_true`: a flag that is still
/// `true` is at its default and does not need to be written out.
// Takes `&bool` because serde's `skip_serializing_if` passes a reference.
#[allow(clippy::trivially_copy_pass_by_ref)]
fn is_default_true(value: &bool) -> bool {
    let &is_set = value;
    is_set
}
374
375#[allow(clippy::ref_option)]
376fn is_signed_response_alg_default(signed_response_alg: &JsonWebSignatureAlg) -> bool {
377    *signed_response_alg == signed_response_alg_default()
378}
379
380#[allow(clippy::unnecessary_wraps)]
381fn signed_response_alg_default() -> JsonWebSignatureAlg {
382    JsonWebSignatureAlg::Rs256
383}
384
385#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
386pub struct SignInWithApple {
387    /// The private key file used to sign the `id_token`
388    #[serde(skip_serializing_if = "Option::is_none")]
389    #[schemars(with = "Option<String>")]
390    pub private_key_file: Option<Utf8PathBuf>,
391
392    /// The private key used to sign the `id_token`
393    #[serde(skip_serializing_if = "Option::is_none")]
394    pub private_key: Option<String>,
395
396    /// The Team ID of the Apple Developer Portal
397    pub team_id: String,
398
399    /// The key ID of the Apple Developer Portal
400    pub key_id: String,
401}
402
/// Serde default for `Provider::scope`: request only the `openid` scope.
fn default_scope() -> String {
    String::from("openid")
}

/// `skip_serializing_if` companion of `default_scope`: `true` when the scope
/// is exactly the default one.
fn is_default_scope(scope: &str) -> bool {
    default_scope().as_str() == scope
}
410
/// What to do when receiving an OIDC Backchannel logout request.
// With `rename_all = "snake_case"`, the configuration values are
// `do_nothing`, `logout_browser_only` and `logout_all`.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, JsonSchema, Default)]
#[serde(rename_all = "snake_case")]
pub enum OnBackchannelLogout {
    /// Do nothing
    #[default]
    DoNothing,

    /// Only log out the MAS 'browser session' started by this OIDC session
    LogoutBrowserOnly,

    /// Log out all sessions started by this OIDC session, including MAS
    /// 'browser sessions' and client sessions
    LogoutAll,
}
426
427impl OnBackchannelLogout {
428    #[allow(clippy::trivially_copy_pass_by_ref)]
429    const fn is_default(&self) -> bool {
430        matches!(self, OnBackchannelLogout::DoNothing)
431    }
432}
433
434/// Configuration for one upstream OAuth 2 provider.
435#[skip_serializing_none]
436#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
437pub struct Provider {
438    /// Whether this provider is enabled.
439    ///
440    /// Defaults to `true`
441    #[serde(default = "default_true", skip_serializing_if = "is_default_true")]
442    pub enabled: bool,
443
444    /// An internal unique identifier for this provider
445    #[schemars(
446        with = "String",
447        regex(pattern = r"^[0123456789ABCDEFGHJKMNPQRSTVWXYZ]{26}$"),
448        description = "A ULID as per https://github.com/ulid/spec"
449    )]
450    pub id: Ulid,
451
452    /// The ID of the provider that was used by Synapse.
453    /// In order to perform a Synapse-to-MAS migration, this must be specified.
454    ///
455    /// ## For providers that used OAuth 2.0 or OpenID Connect in Synapse
456    ///
457    /// ### For `oidc_providers`:
458    /// This should be specified as `oidc-` followed by the ID that was
459    /// configured as `idp_id` in one of the `oidc_providers` in the Synapse
460    /// configuration.
461    /// For example, if Synapse's configuration contained `idp_id: wombat` for
462    /// this provider, then specify `oidc-wombat` here.
463    ///
464    /// ### For `oidc_config` (legacy):
465    /// Specify `oidc` here.
466    #[serde(skip_serializing_if = "Option::is_none")]
467    pub synapse_idp_id: Option<String>,
468
469    /// The OIDC issuer URL
470    ///
471    /// This is required if OIDC discovery is enabled (which is the default)
472    #[serde(skip_serializing_if = "Option::is_none")]
473    pub issuer: Option<String>,
474
475    /// A human-readable name for the provider, that will be shown to users
476    #[serde(skip_serializing_if = "Option::is_none")]
477    pub human_name: Option<String>,
478
479    /// A brand identifier used to customise the UI, e.g. `apple`, `google`,
480    /// `github`, etc.
481    ///
482    /// Values supported by the default template are:
483    ///
484    ///  - `apple`
485    ///  - `google`
486    ///  - `facebook`
487    ///  - `github`
488    ///  - `gitlab`
489    ///  - `twitter`
490    ///  - `discord`
491    #[serde(skip_serializing_if = "Option::is_none")]
492    pub brand_name: Option<String>,
493
494    /// The client ID to use when authenticating with the provider
495    pub client_id: String,
496
497    /// The client secret to use when authenticating with the provider
498    ///
499    /// Used by the `client_secret_basic`, `client_secret_post`, and
500    /// `client_secret_jwt` methods
501    #[serde(skip_serializing_if = "Option::is_none")]
502    pub client_secret: Option<String>,
503
504    /// The method to authenticate the client with the provider
505    pub token_endpoint_auth_method: TokenAuthMethod,
506
507    /// Additional parameters for the `sign_in_with_apple` method
508    #[serde(skip_serializing_if = "Option::is_none")]
509    pub sign_in_with_apple: Option<SignInWithApple>,
510
511    /// The JWS algorithm to use when authenticating the client with the
512    /// provider
513    ///
514    /// Used by the `client_secret_jwt` and `private_key_jwt` methods
515    #[serde(skip_serializing_if = "Option::is_none")]
516    pub token_endpoint_auth_signing_alg: Option<JsonWebSignatureAlg>,
517
518    /// Expected signature for the JWT payload returned by the token
519    /// authentication endpoint.
520    ///
521    /// Defaults to `RS256`.
522    #[serde(
523        default = "signed_response_alg_default",
524        skip_serializing_if = "is_signed_response_alg_default"
525    )]
526    pub id_token_signed_response_alg: JsonWebSignatureAlg,
527
528    /// The scopes to request from the provider
529    ///
530    /// Defaults to `openid`.
531    #[serde(default = "default_scope", skip_serializing_if = "is_default_scope")]
532    pub scope: String,
533
534    /// How to discover the provider's configuration
535    ///
536    /// Defaults to `oidc`, which uses OIDC discovery with strict metadata
537    /// verification
538    #[serde(default, skip_serializing_if = "DiscoveryMode::is_default")]
539    pub discovery_mode: DiscoveryMode,
540
541    /// Whether to use proof key for code exchange (PKCE) when requesting and
542    /// exchanging the token.
543    ///
544    /// Defaults to `auto`, which uses PKCE if the provider supports it.
545    #[serde(default, skip_serializing_if = "PkceMethod::is_default")]
546    pub pkce_method: PkceMethod,
547
548    /// Whether to fetch the user profile from the userinfo endpoint,
549    /// or to rely on the data returned in the `id_token` from the
550    /// `token_endpoint`.
551    ///
552    /// Defaults to `false`.
553    #[serde(default)]
554    pub fetch_userinfo: bool,
555
556    /// Expected signature for the JWT payload returned by the userinfo
557    /// endpoint.
558    ///
559    /// If not specified, the response is expected to be an unsigned JSON
560    /// payload.
561    #[serde(skip_serializing_if = "Option::is_none")]
562    pub userinfo_signed_response_alg: Option<JsonWebSignatureAlg>,
563
564    /// The URL to use for the provider's authorization endpoint
565    ///
566    /// Defaults to the `authorization_endpoint` provided through discovery
567    #[serde(skip_serializing_if = "Option::is_none")]
568    pub authorization_endpoint: Option<Url>,
569
570    /// The URL to use for the provider's userinfo endpoint
571    ///
572    /// Defaults to the `userinfo_endpoint` provided through discovery
573    #[serde(skip_serializing_if = "Option::is_none")]
574    pub userinfo_endpoint: Option<Url>,
575
576    /// The URL to use for the provider's token endpoint
577    ///
578    /// Defaults to the `token_endpoint` provided through discovery
579    #[serde(skip_serializing_if = "Option::is_none")]
580    pub token_endpoint: Option<Url>,
581
582    /// The URL to use for getting the provider's public keys
583    ///
584    /// Defaults to the `jwks_uri` provided through discovery
585    #[serde(skip_serializing_if = "Option::is_none")]
586    pub jwks_uri: Option<Url>,
587
588    /// The response mode we ask the provider to use for the callback
589    #[serde(skip_serializing_if = "Option::is_none")]
590    pub response_mode: Option<ResponseMode>,
591
592    /// How claims should be imported from the `id_token` provided by the
593    /// provider
594    #[serde(default, skip_serializing_if = "ClaimsImports::is_default")]
595    pub claims_imports: ClaimsImports,
596
597    /// Additional parameters to include in the authorization request
598    ///
599    /// Orders of the keys are not preserved.
600    #[serde(default, skip_serializing_if = "BTreeMap::is_empty")]
601    pub additional_authorization_parameters: BTreeMap<String, String>,
602
603    /// Whether the `login_hint` should be forwarded to the provider in the
604    /// authorization request.
605    ///
606    /// Defaults to `false`.
607    #[serde(default)]
608    pub forward_login_hint: bool,
609
610    /// What to do when receiving an OIDC Backchannel logout request.
611    ///
612    /// Defaults to "do_nothing".
613    #[serde(default, skip_serializing_if = "OnBackchannelLogout::is_default")]
614    pub on_backchannel_logout: OnBackchannelLogout,
615}