mas_templates/
lib.rs

1// Copyright 2024, 2025 New Vector Ltd.
2// Copyright 2021-2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
5// Please see LICENSE files in the repository root for full details.
6
7#![deny(missing_docs)]
8#![allow(clippy::module_name_repetitions)]
9
10//! Templates rendering
11
12use std::{collections::HashSet, sync::Arc};
13
14use anyhow::Context as _;
15use arc_swap::ArcSwap;
16use camino::{Utf8Path, Utf8PathBuf};
17use mas_i18n::Translator;
18use mas_router::UrlBuilder;
19use mas_spa::ViteManifest;
20use minijinja::{UndefinedBehavior, Value};
21use rand::Rng;
22use serde::Serialize;
23use thiserror::Error;
24use tokio::task::JoinError;
25use tracing::{debug, info};
26use walkdir::DirEntry;
27
28mod context;
29mod forms;
30mod functions;
31
32#[macro_use]
33mod macros;
34
35pub use self::{
36    context::{
37        AccountInactiveContext, ApiDocContext, AppContext, CompatSsoContext, ConsentContext,
38        DeviceConsentContext, DeviceLinkContext, DeviceLinkFormField, DeviceNameContext,
39        EmailRecoveryContext, EmailVerificationContext, EmptyContext, ErrorContext,
40        FormPostContext, IndexContext, LoginContext, LoginFormField, NotFoundContext,
41        PasswordRegisterContext, PolicyViolationContext, PostAuthContext, PostAuthContextInner,
42        RecoveryExpiredContext, RecoveryFinishContext, RecoveryFinishFormField,
43        RecoveryProgressContext, RecoveryStartContext, RecoveryStartFormField, RegisterContext,
44        RegisterFormField, RegisterStepsDisplayNameContext, RegisterStepsDisplayNameFormField,
45        RegisterStepsEmailInUseContext, RegisterStepsRegistrationTokenContext,
46        RegisterStepsRegistrationTokenFormField, RegisterStepsVerifyEmailContext,
47        RegisterStepsVerifyEmailFormField, SiteBranding, SiteConfigExt, SiteFeatures,
48        TemplateContext, UpstreamExistingLinkContext, UpstreamRegister, UpstreamRegisterFormField,
49        UpstreamSuggestLink, WithCaptcha, WithCsrf, WithLanguage, WithOptionalSession, WithSession,
50    },
51    forms::{FieldError, FormError, FormField, FormState, ToFormState},
52};
53
54/// Escape the given string for use in HTML
55///
56/// It uses the same crate as the one used by the minijinja templates
57#[must_use]
58pub fn escape_html(input: &str) -> String {
59    v_htmlescape::escape(input).to_string()
60}
61
/// Wrapper around [`minijinja::Environment`] helping rendering the various
/// templates
///
/// The environment and the translator both live behind an
/// `Arc<ArcSwap<_>>`: clones of this struct share them, and
/// [`Templates::reload`] replaces them in place for every clone.
#[derive(Debug, Clone)]
pub struct Templates {
    /// The template environment currently in use; swapped on reload
    environment: Arc<ArcSwap<minijinja::Environment<'static>>>,
    /// The translator currently in use; swapped on reload
    translator: Arc<ArcSwap<Translator>>,
    /// URL builder handed to the template functions on (re)load
    url_builder: UrlBuilder,
    /// Branding information, exposed to templates as the `branding` global
    branding: SiteBranding,
    /// Feature flags, exposed to templates as the `features` global
    features: SiteFeatures,
    /// Location of the Vite assets manifest, re-read on every (re)load
    vite_manifest_path: Utf8PathBuf,
    /// Location of the translations, re-read on every (re)load
    translations_path: Utf8PathBuf,
    /// Root directory the templates are loaded from
    path: Utf8PathBuf,
    /// Whether template rendering is in strict mode (for testing,
    /// until this can be rolled out in production.)
    strict: bool,
}
78
/// There was an issue while loading the templates
///
/// This is the error type returned by [`Templates::load`] and
/// [`Templates::reload`].
#[derive(Error, Debug)]
pub enum TemplateLoadingError {
    /// I/O error
    ///
    /// Raised for filesystem failures other than reading the assets
    /// manifest, e.g. resolving the templates directory or reading a
    /// template file.
    #[error(transparent)]
    IO(#[from] std::io::Error),

    /// Failed to read the assets manifest
    // Deliberately not `#[from]`: `std::io::Error` already converts to `IO`,
    // so this variant has to be constructed explicitly.
    #[error("failed to read the assets manifest")]
    ViteManifestIO(#[source] std::io::Error),

    /// Failed to deserialize the assets manifest
    #[error("invalid assets manifest")]
    ViteManifest(#[from] serde_json::Error),

    /// Failed to load the translations
    #[error("failed to load the translations")]
    Translations(#[from] mas_i18n::LoadError),

    /// Failed to traverse the filesystem
    #[error("failed to traverse the filesystem")]
    WalkDir(#[from] walkdir::Error),

    /// Encountered non-UTF-8 path (while converting a borrowed path)
    #[error("encountered non-UTF-8 path")]
    NonUtf8Path(#[from] camino::FromPathError),

    /// Encountered non-UTF-8 path (while converting an owned path buffer)
    #[error("encountered non-UTF-8 path")]
    NonUtf8PathBuf(#[from] camino::FromPathBufError),

    /// Encountered invalid path
    ///
    /// A discovered file could not be expressed relative to the templates
    /// root directory.
    #[error("encountered invalid path")]
    InvalidPath(#[from] std::path::StripPrefixError),

    /// Some templates failed to compile
    #[error("could not load and compile some templates")]
    Compile(#[from] minijinja::Error),

    /// Could not join blocking task
    #[error("error from async runtime")]
    Runtime(#[from] JoinError),

    /// There are essential templates missing
    ///
    /// Only `missing` is part of the error message; `loaded` is carried along
    /// for callers that want to inspect what was found.
    #[error("missing templates {missing:?}")]
    MissingTemplates {
        /// List of missing templates
        missing: HashSet<String>,
        /// List of templates that were loaded
        loaded: HashSet<String>,
    },
}
131
132fn is_hidden(entry: &DirEntry) -> bool {
133    entry
134        .file_name()
135        .to_str()
136        .is_some_and(|s| s.starts_with('.'))
137}
138
139impl Templates {
140    /// Load the templates from the given config
141    ///
142    /// # Errors
143    ///
144    /// Returns an error if the templates could not be loaded from disk.
145    #[tracing::instrument(
146        name = "templates.load",
147        skip_all,
148        fields(%path),
149    )]
150    pub async fn load(
151        path: Utf8PathBuf,
152        url_builder: UrlBuilder,
153        vite_manifest_path: Utf8PathBuf,
154        translations_path: Utf8PathBuf,
155        branding: SiteBranding,
156        features: SiteFeatures,
157        strict: bool,
158    ) -> Result<Self, TemplateLoadingError> {
159        let (translator, environment) = Self::load_(
160            &path,
161            url_builder.clone(),
162            &vite_manifest_path,
163            &translations_path,
164            branding.clone(),
165            features,
166            strict,
167        )
168        .await?;
169        Ok(Self {
170            environment: Arc::new(ArcSwap::new(environment)),
171            translator: Arc::new(ArcSwap::new(translator)),
172            path,
173            url_builder,
174            vite_manifest_path,
175            translations_path,
176            branding,
177            features,
178            strict,
179        })
180    }
181
182    async fn load_(
183        path: &Utf8Path,
184        url_builder: UrlBuilder,
185        vite_manifest_path: &Utf8Path,
186        translations_path: &Utf8Path,
187        branding: SiteBranding,
188        features: SiteFeatures,
189        strict: bool,
190    ) -> Result<(Arc<Translator>, Arc<minijinja::Environment<'static>>), TemplateLoadingError> {
191        let path = path.to_owned();
192        let span = tracing::Span::current();
193
194        // Read the assets manifest from disk
195        let vite_manifest = tokio::fs::read(vite_manifest_path)
196            .await
197            .map_err(TemplateLoadingError::ViteManifestIO)?;
198
199        // Parse it
200        let vite_manifest: ViteManifest =
201            serde_json::from_slice(&vite_manifest).map_err(TemplateLoadingError::ViteManifest)?;
202
203        let translations_path = translations_path.to_owned();
204        let translator =
205            tokio::task::spawn_blocking(move || Translator::load_from_path(&translations_path))
206                .await??;
207        let translator = Arc::new(translator);
208
209        debug!(locales = ?translator.available_locales(), "Loaded translations");
210
211        let (loaded, mut env) = tokio::task::spawn_blocking(move || {
212            span.in_scope(move || {
213                let mut loaded: HashSet<_> = HashSet::new();
214                let mut env = minijinja::Environment::new();
215                // Don't allow use of undefined variables
216                env.set_undefined_behavior(if strict {
217                    UndefinedBehavior::Strict
218                } else {
219                    // For now, allow semi-strict, because we don't have total test coverage of
220                    // tests and some tests rely on if conditions against sometimes-undefined
221                    // variables
222                    UndefinedBehavior::SemiStrict
223                });
224                let root = path.canonicalize_utf8()?;
225                info!(%root, "Loading templates from filesystem");
226                for entry in walkdir::WalkDir::new(&root)
227                    .min_depth(1)
228                    .into_iter()
229                    .filter_entry(|e| !is_hidden(e))
230                {
231                    let entry = entry?;
232                    if entry.file_type().is_file() {
233                        let path = Utf8PathBuf::try_from(entry.into_path())?;
234                        let Some(ext) = path.extension() else {
235                            continue;
236                        };
237
238                        if ext == "html" || ext == "txt" || ext == "subject" {
239                            let relative = path.strip_prefix(&root)?;
240                            debug!(%relative, "Registering template");
241                            let template = std::fs::read_to_string(&path)?;
242                            env.add_template_owned(relative.as_str().to_owned(), template)?;
243                            loaded.insert(relative.as_str().to_owned());
244                        }
245                    }
246                }
247
248                Ok::<_, TemplateLoadingError>((loaded, env))
249            })
250        })
251        .await??;
252
253        env.add_global("branding", Value::from_object(branding));
254        env.add_global("features", Value::from_object(features));
255
256        self::functions::register(
257            &mut env,
258            url_builder,
259            vite_manifest,
260            Arc::clone(&translator),
261        );
262
263        let env = Arc::new(env);
264
265        let needed: HashSet<_> = TEMPLATES.into_iter().map(ToOwned::to_owned).collect();
266        debug!(?loaded, ?needed, "Templates loaded");
267        let missing: HashSet<_> = needed.difference(&loaded).cloned().collect();
268
269        if missing.is_empty() {
270            Ok((translator, env))
271        } else {
272            Err(TemplateLoadingError::MissingTemplates { missing, loaded })
273        }
274    }
275
276    /// Reload the templates on disk
277    ///
278    /// # Errors
279    ///
280    /// Returns an error if the templates could not be reloaded from disk.
281    #[tracing::instrument(
282        name = "templates.reload",
283        skip_all,
284        fields(path = %self.path),
285    )]
286    pub async fn reload(&self) -> Result<(), TemplateLoadingError> {
287        let (translator, environment) = Self::load_(
288            &self.path,
289            self.url_builder.clone(),
290            &self.vite_manifest_path,
291            &self.translations_path,
292            self.branding.clone(),
293            self.features,
294            self.strict,
295        )
296        .await?;
297
298        // Swap them
299        self.environment.store(environment);
300        self.translator.store(translator);
301
302        Ok(())
303    }
304
305    /// Get the translator
306    #[must_use]
307    pub fn translator(&self) -> Arc<Translator> {
308        self.translator.load_full()
309    }
310}
311
/// Failed to render a template
///
/// Both variants carry the name of the template involved and the underlying
/// [`minijinja::Error`], which is exposed as the error source.
// NOTE(review): the variants are not constructed in this file; presumably the
// render methods generated by `register_templates!` build them. Confirm
// against the macro definition.
#[derive(Error, Debug)]
pub enum TemplateError {
    /// Missing template
    #[error("missing template {template:?}")]
    Missing {
        /// The name of the template being rendered
        template: &'static str,

        /// The underlying error
        #[source]
        source: minijinja::Error,
    },

    /// Failed to render the template
    #[error("could not render template {template:?}")]
    Render {
        /// The name of the template being rendered
        template: &'static str,

        /// The underlying error
        #[source]
        source: minijinja::Error,
    },
}
337
// Declares every template known to this crate. Each entry gives the name of
// the render method, the context type it takes and the path of the template,
// relative to the templates directory.
//
// NOTE(review): `TEMPLATES` (used when loading) and the `check` module (used
// by `check_render`) are not defined in this file, so they are presumably
// generated by this macro invocation; confirm against the `macros` module.
register_templates! {
    /// Render the not found fallback page
    pub fn render_not_found(WithLanguage<NotFoundContext>) { "pages/404.html" }

    /// Render the frontend app
    pub fn render_app(WithLanguage<AppContext>) { "app.html" }

    /// Render the Swagger API reference
    pub fn render_swagger(ApiDocContext) { "swagger/doc.html" }

    /// Render the Swagger OAuth callback page
    pub fn render_swagger_callback(ApiDocContext) { "swagger/oauth2-redirect.html" }

    /// Render the login page
    pub fn render_login(WithLanguage<WithCsrf<LoginContext>>) { "pages/login.html" }

    /// Render the registration page
    pub fn render_register(WithLanguage<WithCsrf<RegisterContext>>) { "pages/register/index.html" }

    /// Render the password registration page
    pub fn render_password_register(WithLanguage<WithCsrf<WithCaptcha<PasswordRegisterContext>>>) { "pages/register/password.html" }

    /// Render the email verification page
    pub fn render_register_steps_verify_email(WithLanguage<WithCsrf<RegisterStepsVerifyEmailContext>>) { "pages/register/steps/verify_email.html" }

    /// Render the email in use page
    pub fn render_register_steps_email_in_use(WithLanguage<RegisterStepsEmailInUseContext>) { "pages/register/steps/email_in_use.html" }

    /// Render the display name page
    pub fn render_register_steps_display_name(WithLanguage<WithCsrf<RegisterStepsDisplayNameContext>>) { "pages/register/steps/display_name.html" }

    /// Render the registration token page
    pub fn render_register_steps_registration_token(WithLanguage<WithCsrf<RegisterStepsRegistrationTokenContext>>) { "pages/register/steps/registration_token.html" }

    /// Render the client consent page
    pub fn render_consent(WithLanguage<WithCsrf<WithSession<ConsentContext>>>) { "pages/consent.html" }

    /// Render the policy violation page
    pub fn render_policy_violation(WithLanguage<WithCsrf<WithSession<PolicyViolationContext>>>) { "pages/policy_violation.html" }

    /// Render the legacy SSO login consent page
    pub fn render_sso_login(WithLanguage<WithCsrf<WithSession<CompatSsoContext>>>) { "pages/sso.html" }

    /// Render the home page
    pub fn render_index(WithLanguage<WithCsrf<WithOptionalSession<IndexContext>>>) { "pages/index.html" }

    /// Render the account recovery start page
    pub fn render_recovery_start(WithLanguage<WithCsrf<RecoveryStartContext>>) { "pages/recovery/start.html" }

    /// Render the account recovery progress page
    pub fn render_recovery_progress(WithLanguage<WithCsrf<RecoveryProgressContext>>) { "pages/recovery/progress.html" }

    /// Render the account recovery finish page
    pub fn render_recovery_finish(WithLanguage<WithCsrf<RecoveryFinishContext>>) { "pages/recovery/finish.html" }

    /// Render the account recovery link expired page
    pub fn render_recovery_expired(WithLanguage<WithCsrf<RecoveryExpiredContext>>) { "pages/recovery/expired.html" }

    /// Render the account recovery link consumed page
    pub fn render_recovery_consumed(WithLanguage<EmptyContext>) { "pages/recovery/consumed.html" }

    /// Render the account recovery disabled page
    pub fn render_recovery_disabled(WithLanguage<EmptyContext>) { "pages/recovery/disabled.html" }

    /// Render the form used by the `form_post` response mode
    pub fn render_form_post<T: Serialize>(WithLanguage<FormPostContext<T>>) { "form_post.html" }

    /// Render the HTML error page
    pub fn render_error(ErrorContext) { "pages/error.html" }

    /// Render the email recovery email (plain text variant)
    pub fn render_email_recovery_txt(WithLanguage<EmailRecoveryContext>) { "emails/recovery.txt" }

    /// Render the email recovery email (HTML variant)
    pub fn render_email_recovery_html(WithLanguage<EmailRecoveryContext>) { "emails/recovery.html" }

    /// Render the email recovery subject
    pub fn render_email_recovery_subject(WithLanguage<EmailRecoveryContext>) { "emails/recovery.subject" }

    /// Render the email verification email (plain text variant)
    pub fn render_email_verification_txt(WithLanguage<EmailVerificationContext>) { "emails/verification.txt" }

    /// Render the email verification email (HTML variant)
    pub fn render_email_verification_html(WithLanguage<EmailVerificationContext>) { "emails/verification.html" }

    /// Render the email verification subject
    pub fn render_email_verification_subject(WithLanguage<EmailVerificationContext>) { "emails/verification.subject" }

    /// Render the upstream link mismatch message
    pub fn render_upstream_oauth2_link_mismatch(WithLanguage<WithCsrf<WithSession<UpstreamExistingLinkContext>>>) { "pages/upstream_oauth2/link_mismatch.html" }

    /// Render the upstream link match
    pub fn render_upstream_oauth2_login_link(WithLanguage<WithCsrf<UpstreamExistingLinkContext>>) { "pages/upstream_oauth2/login_link.html" }

    /// Render the upstream suggest link message
    pub fn render_upstream_oauth2_suggest_link(WithLanguage<WithCsrf<WithSession<UpstreamSuggestLink>>>) { "pages/upstream_oauth2/suggest_link.html" }

    /// Render the upstream register screen
    pub fn render_upstream_oauth2_do_register(WithLanguage<WithCsrf<UpstreamRegister>>) { "pages/upstream_oauth2/do_register.html" }

    /// Render the device code link page
    pub fn render_device_link(WithLanguage<DeviceLinkContext>) { "pages/device_link.html" }

    /// Render the device code consent page
    pub fn render_device_consent(WithLanguage<WithCsrf<WithSession<DeviceConsentContext>>>) { "pages/device_consent.html" }

    /// Render the 'account deactivated' page
    pub fn render_account_deactivated(WithLanguage<WithCsrf<AccountInactiveContext>>) { "pages/account/deactivated.html" }

    /// Render the 'account locked' page
    pub fn render_account_locked(WithLanguage<WithCsrf<AccountInactiveContext>>) { "pages/account/locked.html" }

    /// Render the 'account logged out' page
    pub fn render_account_logged_out(WithLanguage<WithCsrf<AccountInactiveContext>>) { "pages/account/logged_out.html" }

    /// Render the automatic device name for OAuth 2.0 client
    pub fn render_device_name(WithLanguage<DeviceNameContext>) { "device_name.txt" }
}
456
impl Templates {
    /// Render all templates with the generated samples to check if they render
    /// properly
    ///
    /// The checks run one after the other, in the order the templates are
    /// registered, and stop at the first failure. `now` and `rng` are handed
    /// to each check in turn.
    ///
    /// # Errors
    ///
    /// Returns an error if any of the templates fails to render
    pub fn check_render(
        &self,
        now: chrono::DateTime<chrono::Utc>,
        rng: &mut impl Rng,
    ) -> anyhow::Result<()> {
        // NOTE(review): this list is maintained by hand and has to be kept in
        // sync with the `register_templates!` invocation; a template missing
        // here is silently left unchecked.
        check::render_not_found(self, now, rng)?;
        check::render_app(self, now, rng)?;
        check::render_swagger(self, now, rng)?;
        check::render_swagger_callback(self, now, rng)?;
        check::render_login(self, now, rng)?;
        check::render_register(self, now, rng)?;
        check::render_password_register(self, now, rng)?;
        check::render_register_steps_verify_email(self, now, rng)?;
        check::render_register_steps_email_in_use(self, now, rng)?;
        check::render_register_steps_display_name(self, now, rng)?;
        check::render_register_steps_registration_token(self, now, rng)?;
        check::render_consent(self, now, rng)?;
        check::render_policy_violation(self, now, rng)?;
        check::render_sso_login(self, now, rng)?;
        check::render_index(self, now, rng)?;
        check::render_recovery_start(self, now, rng)?;
        check::render_recovery_progress(self, now, rng)?;
        check::render_recovery_finish(self, now, rng)?;
        check::render_recovery_expired(self, now, rng)?;
        check::render_recovery_consumed(self, now, rng)?;
        check::render_recovery_disabled(self, now, rng)?;
        // The form post template is generic over its payload; any serializable
        // type will do for the check
        check::render_form_post::<EmptyContext>(self, now, rng)?;
        check::render_error(self, now, rng)?;
        check::render_email_recovery_txt(self, now, rng)?;
        check::render_email_recovery_html(self, now, rng)?;
        check::render_email_recovery_subject(self, now, rng)?;
        check::render_email_verification_txt(self, now, rng)?;
        check::render_email_verification_html(self, now, rng)?;
        check::render_email_verification_subject(self, now, rng)?;
        check::render_upstream_oauth2_link_mismatch(self, now, rng)?;
        check::render_upstream_oauth2_login_link(self, now, rng)?;
        check::render_upstream_oauth2_suggest_link(self, now, rng)?;
        check::render_upstream_oauth2_do_register(self, now, rng)?;
        check::render_device_link(self, now, rng)?;
        check::render_device_consent(self, now, rng)?;
        check::render_account_deactivated(self, now, rng)?;
        check::render_account_locked(self, now, rng)?;
        check::render_account_logged_out(self, now, rng)?;
        check::render_device_name(self, now, rng)?;
        Ok(())
    }
}
511
#[cfg(test)]
mod tests {
    use super::*;

    /// Load the templates shipped in the repository in strict mode, and make
    /// sure all of them render with the sample contexts
    #[tokio::test]
    async fn check_builtin_templates() {
        #[allow(clippy::disallowed_methods)]
        let now = chrono::Utc::now();
        #[allow(clippy::disallowed_methods)]
        let mut rng = rand::thread_rng();

        // Everything is located relative to this crate's manifest
        let crate_dir = Utf8Path::new(env!("CARGO_MANIFEST_DIR"));
        let templates_dir = crate_dir.join("../../templates/");
        let manifest_file = crate_dir.join("../../frontend/dist/manifest.json");
        let translations_dir = crate_dir.join("../../translations");

        let base_url = "https://example.com/".parse().unwrap();
        let url_builder = UrlBuilder::new(base_url, None, None);
        let branding = SiteBranding::new("example.com");
        let features = SiteFeatures {
            password_login: true,
            password_registration: true,
            password_registration_email_required: true,
            account_recovery: true,
            login_with_email_allowed: true,
        };

        // Use strict mode in tests
        let strict = true;

        let templates = Templates::load(
            templates_dir,
            url_builder,
            manifest_file,
            translations_dir,
            branding,
            features,
            strict,
        )
        .await
        .unwrap();

        templates.check_render(now, &mut rng).unwrap();
    }
}