1use async_trait::async_trait;
8use chrono::{DateTime, Utc};
9use mas_data_model::{
10 BrowserSession, User, UserEmail, UserEmailAuthentication, UserEmailAuthenticationCode,
11 UserRegistration,
12};
13use mas_storage::{
14 Clock, Page, Pagination,
15 user::{UserEmailFilter, UserEmailRepository},
16};
17use rand::RngCore;
18use sea_query::{Expr, Func, PostgresQueryBuilder, Query, SimpleExpr, enum_def};
19use sea_query_binder::SqlxBinder;
20use sqlx::PgConnection;
21use ulid::Ulid;
22use uuid::Uuid;
23
24use crate::{
25 DatabaseError,
26 filter::{Filter, StatementExt},
27 iden::UserEmails,
28 pagination::QueryBuilderExt,
29 tracing::ExecuteExt,
30};
31
32pub struct PgUserEmailRepository<'c> {
34 conn: &'c mut PgConnection,
35}
36
37impl<'c> PgUserEmailRepository<'c> {
38 pub fn new(conn: &'c mut PgConnection) -> Self {
41 Self { conn }
42 }
43}
44
45#[derive(Debug, Clone, sqlx::FromRow)]
46#[enum_def]
47struct UserEmailLookup {
48 user_email_id: Uuid,
49 user_id: Uuid,
50 email: String,
51 created_at: DateTime<Utc>,
52}
53
54impl From<UserEmailLookup> for UserEmail {
55 fn from(e: UserEmailLookup) -> UserEmail {
56 UserEmail {
57 id: e.user_email_id.into(),
58 user_id: e.user_id.into(),
59 email: e.email,
60 created_at: e.created_at,
61 }
62 }
63}
64
65struct UserEmailAuthenticationLookup {
66 user_email_authentication_id: Uuid,
67 user_session_id: Option<Uuid>,
68 user_registration_id: Option<Uuid>,
69 email: String,
70 created_at: DateTime<Utc>,
71 completed_at: Option<DateTime<Utc>>,
72}
73
74impl From<UserEmailAuthenticationLookup> for UserEmailAuthentication {
75 fn from(value: UserEmailAuthenticationLookup) -> Self {
76 UserEmailAuthentication {
77 id: value.user_email_authentication_id.into(),
78 user_session_id: value.user_session_id.map(Ulid::from),
79 user_registration_id: value.user_registration_id.map(Ulid::from),
80 email: value.email,
81 created_at: value.created_at,
82 completed_at: value.completed_at,
83 }
84 }
85}
86
87struct UserEmailAuthenticationCodeLookup {
88 user_email_authentication_code_id: Uuid,
89 user_email_authentication_id: Uuid,
90 code: String,
91 created_at: DateTime<Utc>,
92 expires_at: DateTime<Utc>,
93}
94
95impl From<UserEmailAuthenticationCodeLookup> for UserEmailAuthenticationCode {
96 fn from(value: UserEmailAuthenticationCodeLookup) -> Self {
97 UserEmailAuthenticationCode {
98 id: value.user_email_authentication_code_id.into(),
99 user_email_authentication_id: value.user_email_authentication_id.into(),
100 code: value.code,
101 created_at: value.created_at,
102 expires_at: value.expires_at,
103 }
104 }
105}
106
107impl Filter for UserEmailFilter<'_> {
108 fn generate_condition(&self, _has_joins: bool) -> impl sea_query::IntoCondition {
109 sea_query::Condition::all()
110 .add_option(self.user().map(|user| {
111 Expr::col((UserEmails::Table, UserEmails::UserId)).eq(Uuid::from(user.id))
112 }))
113 .add_option(self.email().map(|email| {
114 SimpleExpr::from(Func::lower(Expr::col((
115 UserEmails::Table,
116 UserEmails::Email,
117 ))))
118 .eq(Func::lower(email))
119 }))
120 }
121}
122
123#[async_trait]
124impl UserEmailRepository for PgUserEmailRepository<'_> {
125 type Error = DatabaseError;
126
127 #[tracing::instrument(
128 name = "db.user_email.lookup",
129 skip_all,
130 fields(
131 db.query.text,
132 user_email.id = %id,
133 ),
134 err,
135 )]
136 async fn lookup(&mut self, id: Ulid) -> Result<Option<UserEmail>, Self::Error> {
137 let res = sqlx::query_as!(
138 UserEmailLookup,
139 r#"
140 SELECT user_email_id
141 , user_id
142 , email
143 , created_at
144 FROM user_emails
145
146 WHERE user_email_id = $1
147 "#,
148 Uuid::from(id),
149 )
150 .traced()
151 .fetch_optional(&mut *self.conn)
152 .await?;
153
154 let Some(user_email) = res else {
155 return Ok(None);
156 };
157
158 Ok(Some(user_email.into()))
159 }
160
161 #[tracing::instrument(
162 name = "db.user_email.find",
163 skip_all,
164 fields(
165 db.query.text,
166 %user.id,
167 user_email.email = email,
168 ),
169 err,
170 )]
171 async fn find(&mut self, user: &User, email: &str) -> Result<Option<UserEmail>, Self::Error> {
172 let res = sqlx::query_as!(
173 UserEmailLookup,
174 r#"
175 SELECT user_email_id
176 , user_id
177 , email
178 , created_at
179 FROM user_emails
180
181 WHERE user_id = $1 AND LOWER(email) = LOWER($2)
182 "#,
183 Uuid::from(user.id),
184 email,
185 )
186 .traced()
187 .fetch_optional(&mut *self.conn)
188 .await?;
189
190 let Some(user_email) = res else {
191 return Ok(None);
192 };
193
194 Ok(Some(user_email.into()))
195 }
196
197 #[tracing::instrument(
198 name = "db.user_email.find_by_email",
199 skip_all,
200 fields(
201 db.query.text,
202 user_email.email = email,
203 ),
204 err,
205 )]
206 async fn find_by_email(&mut self, email: &str) -> Result<Option<UserEmail>, Self::Error> {
207 let res = sqlx::query_as!(
208 UserEmailLookup,
209 r#"
210 SELECT user_email_id
211 , user_id
212 , email
213 , created_at
214 FROM user_emails
215 WHERE LOWER(email) = LOWER($1)
216 "#,
217 email,
218 )
219 .traced()
220 .fetch_all(&mut *self.conn)
221 .await?;
222
223 if res.len() != 1 {
224 return Ok(None);
225 }
226
227 let Some(user_email) = res.into_iter().next() else {
228 return Ok(None);
229 };
230
231 Ok(Some(user_email.into()))
232 }
233
234 #[tracing::instrument(
235 name = "db.user_email.all",
236 skip_all,
237 fields(
238 db.query.text,
239 %user.id,
240 ),
241 err,
242 )]
243 async fn all(&mut self, user: &User) -> Result<Vec<UserEmail>, Self::Error> {
244 let res = sqlx::query_as!(
245 UserEmailLookup,
246 r#"
247 SELECT user_email_id
248 , user_id
249 , email
250 , created_at
251 FROM user_emails
252
253 WHERE user_id = $1
254
255 ORDER BY email ASC
256 "#,
257 Uuid::from(user.id),
258 )
259 .traced()
260 .fetch_all(&mut *self.conn)
261 .await?;
262
263 Ok(res.into_iter().map(Into::into).collect())
264 }
265
266 #[tracing::instrument(
267 name = "db.user_email.list",
268 skip_all,
269 fields(
270 db.query.text,
271 ),
272 err,
273 )]
274 async fn list(
275 &mut self,
276 filter: UserEmailFilter<'_>,
277 pagination: Pagination,
278 ) -> Result<Page<UserEmail>, DatabaseError> {
279 let (sql, arguments) = Query::select()
280 .expr_as(
281 Expr::col((UserEmails::Table, UserEmails::UserEmailId)),
282 UserEmailLookupIden::UserEmailId,
283 )
284 .expr_as(
285 Expr::col((UserEmails::Table, UserEmails::UserId)),
286 UserEmailLookupIden::UserId,
287 )
288 .expr_as(
289 Expr::col((UserEmails::Table, UserEmails::Email)),
290 UserEmailLookupIden::Email,
291 )
292 .expr_as(
293 Expr::col((UserEmails::Table, UserEmails::CreatedAt)),
294 UserEmailLookupIden::CreatedAt,
295 )
296 .from(UserEmails::Table)
297 .apply_filter(filter)
298 .generate_pagination((UserEmails::Table, UserEmails::UserEmailId), pagination)
299 .build_sqlx(PostgresQueryBuilder);
300
301 let edges: Vec<UserEmailLookup> = sqlx::query_as_with(&sql, arguments)
302 .traced()
303 .fetch_all(&mut *self.conn)
304 .await?;
305
306 let page = pagination.process(edges).map(UserEmail::from);
307
308 Ok(page)
309 }
310
311 #[tracing::instrument(
312 name = "db.user_email.count",
313 skip_all,
314 fields(
315 db.query.text,
316 ),
317 err,
318 )]
319 async fn count(&mut self, filter: UserEmailFilter<'_>) -> Result<usize, Self::Error> {
320 let (sql, arguments) = Query::select()
321 .expr(Expr::col((UserEmails::Table, UserEmails::UserEmailId)).count())
322 .from(UserEmails::Table)
323 .apply_filter(filter)
324 .build_sqlx(PostgresQueryBuilder);
325
326 let count: i64 = sqlx::query_scalar_with(&sql, arguments)
327 .traced()
328 .fetch_one(&mut *self.conn)
329 .await?;
330
331 count
332 .try_into()
333 .map_err(DatabaseError::to_invalid_operation)
334 }
335
336 #[tracing::instrument(
337 name = "db.user_email.add",
338 skip_all,
339 fields(
340 db.query.text,
341 %user.id,
342 user_email.id,
343 user_email.email = email,
344 ),
345 err,
346 )]
347 async fn add(
348 &mut self,
349 rng: &mut (dyn RngCore + Send),
350 clock: &dyn Clock,
351 user: &User,
352 email: String,
353 ) -> Result<UserEmail, Self::Error> {
354 let created_at = clock.now();
355 let id = Ulid::from_datetime_with_source(created_at.into(), rng);
356 tracing::Span::current().record("user_email.id", tracing::field::display(id));
357
358 sqlx::query!(
359 r#"
360 INSERT INTO user_emails (user_email_id, user_id, email, created_at)
361 VALUES ($1, $2, $3, $4)
362 "#,
363 Uuid::from(id),
364 Uuid::from(user.id),
365 &email,
366 created_at,
367 )
368 .traced()
369 .execute(&mut *self.conn)
370 .await?;
371
372 Ok(UserEmail {
373 id,
374 user_id: user.id,
375 email,
376 created_at,
377 })
378 }
379
380 #[tracing::instrument(
381 name = "db.user_email.remove",
382 skip_all,
383 fields(
384 db.query.text,
385 user.id = %user_email.user_id,
386 %user_email.id,
387 %user_email.email,
388 ),
389 err,
390 )]
391 async fn remove(&mut self, user_email: UserEmail) -> Result<(), Self::Error> {
392 let res = sqlx::query!(
393 r#"
394 DELETE FROM user_emails
395 WHERE user_email_id = $1
396 "#,
397 Uuid::from(user_email.id),
398 )
399 .traced()
400 .execute(&mut *self.conn)
401 .await?;
402
403 DatabaseError::ensure_affected_rows(&res, 1)?;
404
405 Ok(())
406 }
407
408 #[tracing::instrument(
409 name = "db.user_email.remove_bulk",
410 skip_all,
411 fields(
412 db.query.text,
413 ),
414 err,
415 )]
416 async fn remove_bulk(&mut self, filter: UserEmailFilter<'_>) -> Result<usize, Self::Error> {
417 let (sql, arguments) = Query::delete()
418 .from_table(UserEmails::Table)
419 .apply_filter(filter)
420 .build_sqlx(PostgresQueryBuilder);
421
422 let res = sqlx::query_with(&sql, arguments)
423 .traced()
424 .execute(&mut *self.conn)
425 .await?;
426
427 Ok(res.rows_affected().try_into().unwrap_or(usize::MAX))
428 }
429
430 #[tracing::instrument(
431 name = "db.user_email.add_authentication_for_session",
432 skip_all,
433 fields(
434 db.query.text,
435 %session.id,
436 user_email_authentication.id,
437 user_email_authentication.email = email,
438 ),
439 err,
440 )]
441 async fn add_authentication_for_session(
442 &mut self,
443 rng: &mut (dyn RngCore + Send),
444 clock: &dyn Clock,
445 email: String,
446 session: &BrowserSession,
447 ) -> Result<UserEmailAuthentication, Self::Error> {
448 let created_at = clock.now();
449 let id = Ulid::from_datetime_with_source(created_at.into(), rng);
450 tracing::Span::current()
451 .record("user_email_authentication.id", tracing::field::display(id));
452
453 sqlx::query!(
454 r#"
455 INSERT INTO user_email_authentications
456 ( user_email_authentication_id
457 , user_session_id
458 , email
459 , created_at
460 )
461 VALUES ($1, $2, $3, $4)
462 "#,
463 Uuid::from(id),
464 Uuid::from(session.id),
465 &email,
466 created_at,
467 )
468 .traced()
469 .execute(&mut *self.conn)
470 .await?;
471
472 Ok(UserEmailAuthentication {
473 id,
474 user_session_id: Some(session.id),
475 user_registration_id: None,
476 email,
477 created_at,
478 completed_at: None,
479 })
480 }
481
482 #[tracing::instrument(
483 name = "db.user_email.add_authentication_for_registration",
484 skip_all,
485 fields(
486 db.query.text,
487 %user_registration.id,
488 user_email_authentication.id,
489 user_email_authentication.email = email,
490 ),
491 err,
492 )]
493 async fn add_authentication_for_registration(
494 &mut self,
495 rng: &mut (dyn RngCore + Send),
496 clock: &dyn Clock,
497 email: String,
498 user_registration: &UserRegistration,
499 ) -> Result<UserEmailAuthentication, Self::Error> {
500 let created_at = clock.now();
501 let id = Ulid::from_datetime_with_source(created_at.into(), rng);
502 tracing::Span::current()
503 .record("user_email_authentication.id", tracing::field::display(id));
504
505 sqlx::query!(
506 r#"
507 INSERT INTO user_email_authentications
508 ( user_email_authentication_id
509 , user_registration_id
510 , email
511 , created_at
512 )
513 VALUES ($1, $2, $3, $4)
514 "#,
515 Uuid::from(id),
516 Uuid::from(user_registration.id),
517 &email,
518 created_at,
519 )
520 .traced()
521 .execute(&mut *self.conn)
522 .await?;
523
524 Ok(UserEmailAuthentication {
525 id,
526 user_session_id: None,
527 user_registration_id: Some(user_registration.id),
528 email,
529 created_at,
530 completed_at: None,
531 })
532 }
533
534 #[tracing::instrument(
535 name = "db.user_email.add_authentication_code",
536 skip_all,
537 fields(
538 db.query.text,
539 %user_email_authentication.id,
540 %user_email_authentication.email,
541 user_email_authentication_code.id,
542 user_email_authentication_code.code = code,
543 ),
544 err,
545 )]
546 async fn add_authentication_code(
547 &mut self,
548 rng: &mut (dyn RngCore + Send),
549 clock: &dyn Clock,
550 duration: chrono::Duration,
551 user_email_authentication: &UserEmailAuthentication,
552 code: String,
553 ) -> Result<UserEmailAuthenticationCode, Self::Error> {
554 let created_at = clock.now();
555 let expires_at = created_at + duration;
556 let id = Ulid::from_datetime_with_source(created_at.into(), rng);
557 tracing::Span::current().record(
558 "user_email_authentication_code.id",
559 tracing::field::display(id),
560 );
561
562 sqlx::query!(
563 r#"
564 INSERT INTO user_email_authentication_codes
565 ( user_email_authentication_code_id
566 , user_email_authentication_id
567 , code
568 , created_at
569 , expires_at
570 )
571 VALUES ($1, $2, $3, $4, $5)
572 "#,
573 Uuid::from(id),
574 Uuid::from(user_email_authentication.id),
575 &code,
576 created_at,
577 expires_at,
578 )
579 .traced()
580 .execute(&mut *self.conn)
581 .await?;
582
583 Ok(UserEmailAuthenticationCode {
584 id,
585 user_email_authentication_id: user_email_authentication.id,
586 code,
587 created_at,
588 expires_at,
589 })
590 }
591
592 #[tracing::instrument(
593 name = "db.user_email.lookup_authentication",
594 skip_all,
595 fields(
596 db.query.text,
597 user_email_authentication.id = %id,
598 ),
599 err,
600 )]
601 async fn lookup_authentication(
602 &mut self,
603 id: Ulid,
604 ) -> Result<Option<UserEmailAuthentication>, Self::Error> {
605 let res = sqlx::query_as!(
606 UserEmailAuthenticationLookup,
607 r#"
608 SELECT user_email_authentication_id
609 , user_session_id
610 , user_registration_id
611 , email
612 , created_at
613 , completed_at
614 FROM user_email_authentications
615 WHERE user_email_authentication_id = $1
616 "#,
617 Uuid::from(id),
618 )
619 .traced()
620 .fetch_optional(&mut *self.conn)
621 .await?;
622
623 Ok(res.map(UserEmailAuthentication::from))
624 }
625
626 #[tracing::instrument(
627 name = "db.user_email.find_authentication_by_code",
628 skip_all,
629 fields(
630 db.query.text,
631 %authentication.id,
632 user_email_authentication_code.code = code,
633 ),
634 err,
635 )]
636 async fn find_authentication_code(
637 &mut self,
638 authentication: &UserEmailAuthentication,
639 code: &str,
640 ) -> Result<Option<UserEmailAuthenticationCode>, Self::Error> {
641 let res = sqlx::query_as!(
642 UserEmailAuthenticationCodeLookup,
643 r#"
644 SELECT user_email_authentication_code_id
645 , user_email_authentication_id
646 , code
647 , created_at
648 , expires_at
649 FROM user_email_authentication_codes
650 WHERE user_email_authentication_id = $1
651 AND code = $2
652 "#,
653 Uuid::from(authentication.id),
654 code,
655 )
656 .traced()
657 .fetch_optional(&mut *self.conn)
658 .await?;
659
660 Ok(res.map(UserEmailAuthenticationCode::from))
661 }
662
663 #[tracing::instrument(
664 name = "db.user_email.complete_email_authentication",
665 skip_all,
666 fields(
667 db.query.text,
668 %user_email_authentication.id,
669 %user_email_authentication.email,
670 %user_email_authentication_code.id,
671 %user_email_authentication_code.code,
672 ),
673 err,
674 )]
675 async fn complete_authentication(
676 &mut self,
677 clock: &dyn Clock,
678 mut user_email_authentication: UserEmailAuthentication,
679 user_email_authentication_code: &UserEmailAuthenticationCode,
680 ) -> Result<UserEmailAuthentication, Self::Error> {
681 let completed_at = clock.now();
685
686 let res = sqlx::query!(
690 r#"
691 UPDATE user_email_authentications
692 SET completed_at = $2
693 WHERE user_email_authentication_id = $1
694 AND completed_at IS NULL
695 "#,
696 Uuid::from(user_email_authentication.id),
697 completed_at,
698 )
699 .traced()
700 .execute(&mut *self.conn)
701 .await?;
702
703 DatabaseError::ensure_affected_rows(&res, 1)?;
704
705 user_email_authentication.completed_at = Some(completed_at);
706 Ok(user_email_authentication)
707 }
708}