Skip to main content

ssg/
livereload.rs

1// Copyright © 2023 - 2026 Static Site Generator (SSG). All rights reserved.
2// SPDX-License-Identifier: Apache-2.0 OR MIT
3
//! Live-reload script injection plugin.
//!
//! Injects a WebSocket-based live-reload client into all HTML files in
//! the site directory when the development server starts.
//!
//! # How it works
//!
//! 1. The `LiveReloadPlugin` hooks into the `on_serve` lifecycle event.
//! 2. It walks all HTML files in the site directory.
//! 3. It injects a `<script>` tag before `</body>` (or at the end of the
//!    file when there is no `</body>`) that opens a WebSocket connection
//!    to a configurable port (default 35729). The client only connects
//!    when the page is served from `localhost`, `127.0.0.1` or `0.0.0.0`.
//! 4. When the server sends a `"reload"` message, the page reloads and its
//!    scroll position is restored. JSON messages can also show or clear a
//!    build-error overlay (`"error"` / `"clear-error"`) or refresh the
//!    stylesheets without a full reload (`"css-reload"`).
//! 5. On disconnect, the script auto-reconnects with exponential backoff
//!    (1s, 2s, 4s, capped at 10s) and shows a small "Connecting..."
//!    indicator in the bottom-right corner.
19
20use crate::plugin::{Plugin, PluginContext};
21use anyhow::{Context, Result};
22use std::fs;
23use std::path::{Path, PathBuf};
24
/// WebSocket port used by [`LiveReloadPlugin::new`] and by `Default`.
pub const DEFAULT_PORT: u16 = 35_729;

/// Limit handed to the directory walker when collecting HTML files.
const MAX_FILES: usize = 50_000;

/// Identifier embedded in the injected script — as the `<script>` tag's
/// `data-ssg-livereload` attribute and as the indicator element's id — and
/// used to recognise files that already contain the script.
const MARKER: &str = "ssg-livereload";
33
/// Plugin that injects a live-reload script into all HTML files.
///
/// The injected script opens a WebSocket connection to the configured
/// port and reloads the page when it receives a `"reload"` message. It
/// reconnects automatically with exponential backoff on disconnect.
///
/// The type is a small `Copy` value. It also implements `PartialEq`, `Eq`
/// and `Hash` so plugin configurations can be compared or used as keys.
///
/// # Example
///
/// ```rust
/// use ssg::plugin::PluginManager;
/// use ssg::livereload::LiveReloadPlugin;
///
/// let mut pm = PluginManager::new();
/// pm.register(LiveReloadPlugin::new());
/// ```
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct LiveReloadPlugin {
    /// WebSocket port the live-reload client connects to.
    port: u16,
}
54
55impl LiveReloadPlugin {
56    /// Creates a new `LiveReloadPlugin` with the default port (35729).
57    #[must_use]
58    pub const fn new() -> Self {
59        Self { port: DEFAULT_PORT }
60    }
61
62    /// Creates a new `LiveReloadPlugin` with a custom WebSocket port.
63    #[must_use]
64    pub const fn with_port(port: u16) -> Self {
65        Self { port }
66    }
67
68    /// Returns the configured port.
69    #[must_use]
70    pub const fn port(&self) -> u16 {
71        self.port
72    }
73}
74
75impl Default for LiveReloadPlugin {
76    fn default() -> Self {
77        Self::new()
78    }
79}
80
81impl Plugin for LiveReloadPlugin {
82    fn name(&self) -> &'static str {
83        "livereload"
84    }
85
86    fn on_serve(&self, ctx: &PluginContext) -> Result<()> {
87        if !ctx.site_dir.exists() {
88            return Ok(());
89        }
90
91        let html_files = collect_html_files(&ctx.site_dir)?;
92        if html_files.is_empty() {
93            return Ok(());
94        }
95
96        for path in &html_files {
97            inject_livereload(path, self.port)?;
98        }
99
100        println!(
101            "[livereload] Injected live-reload script into {} HTML file(s) (port {})",
102            html_files.len(),
103            self.port,
104        );
105        Ok(())
106    }
107}
108
109/// Collect all `.html` files under `dir` (iterative, bounded).
110fn collect_html_files(dir: &Path) -> Result<Vec<PathBuf>> {
111    crate::walk::walk_files_bounded_count(dir, "html", MAX_FILES)
112}
113
114/// Inject the live-reload script into a single HTML file.
115///
116/// Inserts a `<script>` block before `</body>`. The script:
117/// 1. Opens a WebSocket to `ws://localhost:{port}`
118/// 2. Reloads on receiving a `"reload"` message
119/// 3. Reconnects with exponential backoff (1s, 2s, 4s, max 10s)
120/// 4. Shows a "Connecting..." indicator during reconnection
121///
122/// The injection is idempotent — if the marker is already present,
123/// the file is left unchanged.
124fn inject_livereload(path: &Path, port: u16) -> Result<()> {
125    let html = fs::read_to_string(path)
126        .with_context(|| format!("cannot read {}", path.display()))?;
127
128    if html.contains(MARKER) {
129        return Ok(()); // Already injected
130    }
131
132    let script = livereload_script(port);
133
134    let injected = if let Some(pos) = html.rfind("</body>") {
135        format!("{}{}{}", &html[..pos], script, &html[pos..])
136    } else {
137        format!("{html}{script}")
138    };
139
140    fs::write(path, injected)
141        .with_context(|| format!("cannot write {}", path.display()))?;
142    Ok(())
143}
144
/// Generate the live-reload `<script>` block for a given WebSocket port.
///
/// The returned snippet starts with an `<!-- SSG Live-Reload -->` comment.
/// The script only connects when the page is served from `localhost`,
/// `127.0.0.1` or `0.0.0.0`. Once connected, it:
///
/// - connects to `ws://localhost:{port}`;
/// - on a plain `"reload"` message, stores the scroll position in
///   `sessionStorage` (restored after the reload) and reloads the page;
/// - on JSON messages, handles these `type` values:
///   - `"error"` shows a full-screen build-error overlay with `file`,
///     `line` and `message`;
///   - `"clear-error"` removes the overlay;
///   - `"css-reload"` cache-busts every `<link rel=stylesheet>`;
/// - on disconnect, shows a "Connecting…" indicator in the bottom-right
///   corner and reconnects after 1s, 2s, 4s, … capped at 10s; the
///   indicator is removed as soon as a connection opens.
///
/// The `<script>` tag carries a `data-ssg-livereload` attribute so that
/// already-injected pages can be recognised.
fn livereload_script(port: u16) -> String {
    format!(
        r"
<!-- SSG Live-Reload -->
<script data-ssg-livereload>
(function(){{
  var url='ws://localhost:{port}',delay=1000,maxDelay=10000,indicator=null;
  try{{var sp=sessionStorage.getItem('ssg-scroll');if(sp){{sessionStorage.removeItem('ssg-scroll');var p=JSON.parse(sp);setTimeout(function(){{scrollTo(p.x,p.y);}},50);}}}}catch(se){{}}
  function showIndicator(){{
    if(indicator||!document.body)return;
    indicator=document.createElement('div');
    indicator.id='ssg-livereload';
    indicator.textContent='Connecting\u2026';
    indicator.style.cssText='position:fixed;bottom:8px;right:8px;z-index:99999;'
      +'background:rgba(0,0,0,0.75);color:#fff;padding:6px 12px;border-radius:6px;'
      +'font:13px/1 -apple-system,system-ui,sans-serif;pointer-events:none';
    document.body.appendChild(indicator);
  }}
  function hideIndicator(){{
    if(indicator){{indicator.remove();indicator=null;}}
  }}
  function showOverlay(msg){{
    hideOverlay();
    var d=document.createElement('div');
    d.id='ssg-error-overlay';
    d.style.cssText='position:fixed;top:0;left:0;right:0;bottom:0;background:rgba(0,0,0,0.85);color:#fff;font-family:monospace;font-size:14px;z-index:999999;padding:32px;overflow:auto;';
    var c=document.createElement('div');
    c.style.cssText='max-width:800px;margin:0 auto;';
    var hdr=document.createElement('div');
    hdr.style.cssText='display:flex;justify-content:space-between;align-items:center;margin-bottom:16px;';
    var title=document.createElement('span');
    title.style.cssText='color:#ff6b6b;font-size:18px;font-weight:bold;';
    title.textContent='Build Error';
    var btn=document.createElement('button');
    btn.textContent='\u2715';
    btn.style.cssText='background:none;border:1px solid #666;color:#fff;padding:4px 12px;cursor:pointer;border-radius:4px;';
    btn.addEventListener('click',hideOverlay);
    hdr.appendChild(title);
    hdr.appendChild(btn);
    c.appendChild(hdr);
    if(msg.file){{
      var fp=document.createElement('div');
      fp.style.cssText='color:#ffd93d;margin-bottom:8px;';
      fp.textContent=msg.file+(msg.line?':'+msg.line:'');
      c.appendChild(fp);
    }}
    var pre=document.createElement('pre');
    pre.style.cssText='background:#1a1a2e;padding:16px;border-radius:8px;border-left:4px solid #ff6b6b;overflow-x:auto;white-space:pre-wrap;word-break:break-word;';
    pre.textContent=msg.message;
    c.appendChild(pre);
    d.appendChild(c);
    document.body.appendChild(d);
  }}
  function hideOverlay(){{var e=document.getElementById('ssg-error-overlay');if(e)e.remove();}}
  function connect(){{
    try{{
      var ws=new WebSocket(url);
      ws.onopen=function(){{delay=1000;hideIndicator();}};
      ws.onmessage=function(e){{
        if(e.data==='reload'){{hideOverlay();try{{sessionStorage.setItem('ssg-scroll',JSON.stringify({{x:scrollX,y:scrollY}}));}}catch(se){{}}location.reload();return;}}
        try{{var msg=JSON.parse(e.data);
        if(msg.type==='error'){{showOverlay(msg);}}
        else if(msg.type==='clear-error'){{hideOverlay();}}
        else if(msg.type==='css-reload'){{
          var links=document.querySelectorAll('link[rel=stylesheet]');
          links.forEach(function(link){{
            var href=link.getAttribute('href');
            if(href){{link.setAttribute('href',href.split('?')[0]+'?v='+Date.now());}}
          }});
        }}
        }}catch(x){{}}
      }};
      ws.onclose=function(){{
        showIndicator();
        var d=delay;
        delay=Math.min(delay*2,maxDelay);
        setTimeout(connect,d);
      }};
      ws.onerror=function(){{}};
    }}catch(e){{}}
  }}
  // Only connect when served from a local development host, so a copy of
  // the page deployed elsewhere never tries to reach a live-reload server.
  if(location.hostname==='localhost'||location.hostname==='127.0.0.1'||location.hostname==='0.0.0.0'){{
    if(document.readyState==='loading'){{
      document.addEventListener('DOMContentLoaded',connect);
    }}else{{
      connect();
    }}
  }}
}})();
</script>
"
    )
}
240
241/// Returns a WebSocket message for CSS-only reload.
242#[must_use]
243#[allow(dead_code)]
244pub fn css_reload_message(css_path: &str) -> String {
245    serde_json::json!({
246        "type": "css-reload",
247        "file": css_path,
248    })
249    .to_string()
250}
251
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Builds a minimal page whose `<body>` contains `body`.
    fn make_html(body: &str) -> String {
        [
            "<html><head><title>Test</title></head><body>",
            body,
            "</body></html>",
        ]
        .concat()
    }

    /// Writes `contents` to `dir/name` and returns the full path.
    fn write_file(dir: &Path, name: &str, contents: &str) -> Result<PathBuf> {
        let path = dir.join(name);
        fs::write(&path, contents)?;
        Ok(path)
    }

    /// Runs a default-port plugin's `on_serve` hook against `site_dir`.
    fn serve(site_dir: &Path) -> Result<()> {
        let ctx = PluginContext::new(
            Path::new("content"),
            Path::new("build"),
            site_dir,
            Path::new("templates"),
        );
        LiveReloadPlugin::new().on_serve(&ctx)
    }

    #[test]
    fn inject_adds_script() -> Result<()> {
        let tmp = tempdir()?;
        let page = write_file(tmp.path(), "page.html", &make_html("<p>Hello</p>"))?;

        inject_livereload(&page, DEFAULT_PORT)?;

        let html = fs::read_to_string(&page)?;
        for needle in [MARKER, "WebSocket", "35729", "location.reload()"] {
            assert!(html.contains(needle), "injected page should contain {needle:?}");
        }
        Ok(())
    }

    #[test]
    fn inject_before_closing_body() -> Result<()> {
        let tmp = tempdir()?;
        let page = write_file(tmp.path(), "page.html", &make_html("<p>Hi</p>"))?;

        inject_livereload(&page, DEFAULT_PORT)?;

        let html = fs::read_to_string(&page)?;
        let script_at = html.find(MARKER).expect("marker should be present");
        let body_end_at = html.rfind("</body>").expect("closing body tag should remain");
        assert!(script_at < body_end_at, "script must precede </body>");
        Ok(())
    }

    #[test]
    fn inject_idempotent() -> Result<()> {
        let tmp = tempdir()?;
        let page = write_file(tmp.path(), "page.html", &make_html("<p>Hi</p>"))?;

        inject_livereload(&page, DEFAULT_PORT)?;
        let first = fs::read_to_string(&page)?;
        inject_livereload(&page, DEFAULT_PORT)?;

        assert_eq!(first, fs::read_to_string(&page)?);
        Ok(())
    }

    #[test]
    fn inject_custom_port() -> Result<()> {
        let tmp = tempdir()?;
        let page = write_file(tmp.path(), "page.html", &make_html("<p>Hi</p>"))?;

        inject_livereload(&page, 9999)?;

        let html = fs::read_to_string(&page)?;
        assert!(html.contains("9999"));
        assert!(!html.contains("35729"), "default port must not leak in");
        Ok(())
    }

    #[test]
    fn inject_no_body_tag() -> Result<()> {
        let tmp = tempdir()?;
        let page = write_file(tmp.path(), "page.html", "<html><p>No body tag</p></html>")?;

        inject_livereload(&page, DEFAULT_PORT)?;

        assert!(fs::read_to_string(&page)?.contains(MARKER));
        Ok(())
    }

    #[test]
    fn skip_non_html_files() -> Result<()> {
        let tmp = tempdir()?;
        for (name, contents) in [
            ("style.css", "body{}"),
            ("data.json", "{}"),
            ("readme.txt", "hello"),
        ] {
            write_file(tmp.path(), name, contents)?;
        }

        assert!(collect_html_files(tmp.path())?.is_empty());
        Ok(())
    }

    #[test]
    fn empty_directory() -> Result<()> {
        let tmp = tempdir()?;
        assert!(collect_html_files(tmp.path())?.is_empty());
        Ok(())
    }

    #[test]
    fn nonexistent_directory() {
        let ctx = PluginContext::new(
            Path::new("c"),
            Path::new("b"),
            Path::new("/nonexistent_dir_ssg_test"),
            Path::new("t"),
        );
        assert!(LiveReloadPlugin::new().on_serve(&ctx).is_ok());
    }

    #[test]
    fn plugin_name() {
        let plugin = LiveReloadPlugin::new();
        assert_eq!(plugin.name(), "livereload");
    }

    #[test]
    fn plugin_registration() {
        let mut manager = crate::plugin::PluginManager::new();
        manager.register(LiveReloadPlugin::new());
        assert_eq!(manager.names(), vec!["livereload"]);
    }

    #[test]
    fn with_port_constructor() {
        assert_eq!(LiveReloadPlugin::with_port(8080).port(), 8080);
    }

    #[test]
    fn default_port_value() {
        assert_eq!(LiveReloadPlugin::new().port(), 35729);
    }

    #[test]
    fn default_trait_impl() {
        assert_eq!(LiveReloadPlugin::default().port(), DEFAULT_PORT);
    }

    #[test]
    fn on_serve_injects_all_html_files() -> Result<()> {
        let tmp = tempdir()?;
        let index = write_file(tmp.path(), "index.html", &make_html("<p>Home</p>"))?;
        let about = write_file(tmp.path(), "about.html", &make_html("<p>About</p>"))?;
        let css = write_file(tmp.path(), "style.css", "body{}")?;

        serve(tmp.path())?;

        assert!(fs::read_to_string(&index)?.contains(MARKER));
        assert!(fs::read_to_string(&about)?.contains(MARKER));
        assert!(!fs::read_to_string(&css)?.contains(MARKER));
        Ok(())
    }

    #[test]
    fn script_contains_reconnect_backoff() {
        let script = livereload_script(DEFAULT_PORT);
        for needle in ["delay*2", "maxDelay", "10000"] {
            assert!(script.contains(needle), "script should contain {needle:?}");
        }
    }

    #[test]
    fn script_contains_connecting_indicator() {
        let script = livereload_script(DEFAULT_PORT);
        for needle in ["Connecting", "showIndicator", "hideIndicator", "bottom", "right"] {
            assert!(script.contains(needle), "script should contain {needle:?}");
        }
    }

    #[test]
    fn livereload_custom_port() {
        let script = livereload_script(44_444);

        // The custom port appears; the default one does not.
        assert!(script.contains("44444"));
        assert!(!script.contains("35729"));
    }

    #[test]
    fn livereload_plugin_no_html_files() -> Result<()> {
        let tmp = tempdir()?;
        write_file(tmp.path(), "style.css", "body{}")?;
        write_file(tmp.path(), "data.json", "{}")?;

        assert!(serve(tmp.path()).is_ok());
        Ok(())
    }

    #[test]
    fn livereload_plugin_idempotent() -> Result<()> {
        let tmp = tempdir()?;
        let page = write_file(tmp.path(), "page.html", &make_html("<p>Hello</p>"))?;

        // Run the full plugin twice.
        serve(tmp.path())?;
        let after_first = fs::read_to_string(&page)?;
        serve(tmp.path())?;
        let after_second = fs::read_to_string(&page)?;

        assert_eq!(after_first, after_second);
        // MARKER also occurs in the indicator's element id, so count the
        // script tag's data attribute rather than the bare marker.
        assert_eq!(
            after_second.matches("data-ssg-livereload").count(),
            1,
            "script tag should appear exactly once"
        );
        Ok(())
    }

    #[test]
    fn livereload_script_contains_reconnect_logic() {
        let script = livereload_script(DEFAULT_PORT);
        let expectations = [
            ("delay*2", "should double the delay"),
            ("maxDelay", "should cap the delay"),
            ("setTimeout", "should schedule reconnect"),
            ("connect", "should call connect again"),
        ];
        for (needle, why) in expectations {
            assert!(script.contains(needle), "{why}");
        }
    }

    #[test]
    fn livereload_plugin_nonexistent_dir() -> Result<()> {
        let missing = Path::new("/absolutely/nonexistent/directory/for/test");

        // A missing site directory is not an error.
        assert!(serve(missing).is_ok());
        Ok(())
    }

    #[test]
    fn test_script_contains_error_overlay() {
        let script = livereload_script(DEFAULT_PORT);
        for (needle, why) in [
            ("showOverlay", "script must contain showOverlay function"),
            ("hideOverlay", "script must contain hideOverlay function"),
            ("ssg-error-overlay", "script must contain overlay element id"),
        ] {
            assert!(script.contains(needle), "{why}");
        }
    }

    #[test]
    fn test_script_backward_compat() {
        assert!(
            livereload_script(DEFAULT_PORT).contains("'reload'"),
            "script must still handle plain 'reload' messages"
        );
    }

    #[test]
    fn test_script_contains_css_reload() {
        assert!(
            livereload_script(DEFAULT_PORT).contains("css-reload"),
            "script must contain css-reload handler"
        );
    }

    #[test]
    fn test_script_contains_scroll_preservation() {
        let script = livereload_script(DEFAULT_PORT);
        assert!(
            script.contains("ssg-scroll"),
            "script must contain scroll preservation key"
        );
        assert!(
            script.contains("sessionStorage"),
            "script must use sessionStorage for scroll"
        );
    }

    #[test]
    fn test_css_reload_message() {
        let parsed: serde_json::Value =
            serde_json::from_str(&css_reload_message("styles/main.css")).expect("valid JSON");
        assert_eq!(parsed["type"], "css-reload");
        assert_eq!(parsed["file"], "styles/main.css");
    }
}