aranet_core/
reconnect.rs

1//! Automatic reconnection handling for Aranet devices.
2//!
3//! This module provides a wrapper around Device that automatically
4//! handles reconnection when the connection is lost.
5//!
6//! [`ReconnectingDevice`] implements the [`AranetDevice`] trait,
7//! allowing it to be used interchangeably with regular devices in generic code.
8
9use std::sync::Arc;
10use std::sync::atomic::{AtomicBool, Ordering};
11use std::time::Duration;
12
13use async_trait::async_trait;
14use tokio::sync::RwLock;
15use tokio::time::sleep;
16use tracing::{info, warn};
17
18use aranet_types::{CurrentReading, DeviceInfo, DeviceType, HistoryRecord};
19
20use crate::device::Device;
21use crate::error::{Error, Result};
22use crate::events::{DeviceEvent, DeviceId, EventSender};
23use crate::history::{HistoryInfo, HistoryOptions};
24use crate::settings::{CalibrationData, MeasurementInterval};
25use crate::traits::AranetDevice;
26
/// Retry policy used when a lost connection has to be re-established.
///
/// The fields are public so a policy can be written as a struct literal, but
/// the builder-style methods on this type are usually more convenient.
#[derive(Clone, Debug)]
pub struct ReconnectOptions {
    /// Upper bound on the number of reconnection attempts; `None` retries
    /// without limit.
    pub max_attempts: Option<u32>,
    /// Wait applied before the first attempt. When exponential backoff is
    /// disabled this same wait is used before every attempt.
    pub initial_delay: Duration,
    /// Ceiling for the computed wait when exponential backoff is enabled.
    pub max_delay: Duration,
    /// Growth factor applied once per attempt under exponential backoff.
    pub backoff_multiplier: f64,
    /// Selects exponential backoff (`true`) or a fixed wait (`false`).
    pub use_exponential_backoff: bool,
}
41
42impl Default for ReconnectOptions {
43    fn default() -> Self {
44        Self {
45            max_attempts: Some(5),
46            initial_delay: Duration::from_secs(1),
47            max_delay: Duration::from_secs(60),
48            backoff_multiplier: 2.0,
49            use_exponential_backoff: true,
50        }
51    }
52}
53
54impl ReconnectOptions {
55    /// Create new reconnect options with defaults.
56    pub fn new() -> Self {
57        Self::default()
58    }
59
60    /// Create options with unlimited retry attempts.
61    pub fn unlimited() -> Self {
62        Self {
63            max_attempts: None,
64            ..Default::default()
65        }
66    }
67
68    /// Create options with a fixed delay (no backoff).
69    pub fn fixed_delay(delay: Duration) -> Self {
70        Self {
71            initial_delay: delay,
72            use_exponential_backoff: false,
73            ..Default::default()
74        }
75    }
76
77    /// Set maximum number of reconnection attempts.
78    pub fn max_attempts(mut self, attempts: u32) -> Self {
79        self.max_attempts = Some(attempts);
80        self
81    }
82
83    /// Set initial delay before first reconnection attempt.
84    pub fn initial_delay(mut self, delay: Duration) -> Self {
85        self.initial_delay = delay;
86        self
87    }
88
89    /// Set maximum delay between attempts.
90    pub fn max_delay(mut self, delay: Duration) -> Self {
91        self.max_delay = delay;
92        self
93    }
94
95    /// Set backoff multiplier for exponential backoff.
96    pub fn backoff_multiplier(mut self, multiplier: f64) -> Self {
97        self.backoff_multiplier = multiplier;
98        self
99    }
100
101    /// Enable or disable exponential backoff.
102    pub fn exponential_backoff(mut self, enabled: bool) -> Self {
103        self.use_exponential_backoff = enabled;
104        self
105    }
106
107    /// Calculate delay for a given attempt number.
108    pub fn delay_for_attempt(&self, attempt: u32) -> Duration {
109        if !self.use_exponential_backoff {
110            return self.initial_delay;
111        }
112
113        let delay_ms =
114            self.initial_delay.as_millis() as f64 * self.backoff_multiplier.powi(attempt as i32);
115        let delay = Duration::from_millis(delay_ms as u64);
116
117        delay.min(self.max_delay)
118    }
119
120    /// Validate the options and return an error if invalid.
121    ///
122    /// Checks that:
123    /// - `backoff_multiplier` is >= 1.0
124    /// - `initial_delay` is > 0
125    /// - `max_delay` >= `initial_delay`
126    pub fn validate(&self) -> Result<()> {
127        if self.backoff_multiplier < 1.0 {
128            return Err(Error::InvalidConfig(
129                "backoff_multiplier must be >= 1.0".to_string(),
130            ));
131        }
132        if self.initial_delay.is_zero() {
133            return Err(Error::InvalidConfig(
134                "initial_delay must be > 0".to_string(),
135            ));
136        }
137        if self.max_delay < self.initial_delay {
138            return Err(Error::InvalidConfig(
139                "max_delay must be >= initial_delay".to_string(),
140            ));
141        }
142        Ok(())
143    }
144}
145
/// Lifecycle state reported by a [`ReconnectingDevice`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ConnectionState {
    /// A device handle is held; set after the initial connection and after a
    /// successful reconnect.
    Connected,
    /// No connection is being maintained; set by an explicit disconnect or
    /// when a reconnect run is cancelled.
    Disconnected,
    /// A reconnect run is in progress.
    Reconnecting,
    /// The last reconnect run used up its maximum number of attempts.
    Failed,
}
158
159/// A device wrapper that automatically handles reconnection.
160///
161/// This wrapper caches the device name and type upon initial connection so they
162/// can be accessed synchronously via the [`AranetDevice`] trait, even while
163/// reconnecting.
164pub struct ReconnectingDevice {
165    identifier: String,
166    /// The connected device, wrapped in Arc to allow concurrent access.
167    device: RwLock<Option<Arc<Device>>>,
168    options: ReconnectOptions,
169    state: RwLock<ConnectionState>,
170    event_sender: Option<EventSender>,
171    attempt_count: RwLock<u32>,
172    /// Cancellation flag for stopping reconnection attempts.
173    cancelled: Arc<AtomicBool>,
174    /// Cached device name (populated on first connection).
175    cached_name: std::sync::OnceLock<String>,
176    /// Cached device type (populated on first connection).
177    cached_device_type: std::sync::OnceLock<DeviceType>,
178}
179
180impl ReconnectingDevice {
181    /// Create a new reconnecting device wrapper.
182    pub async fn connect(identifier: &str, options: ReconnectOptions) -> Result<Self> {
183        let device = Arc::new(Device::connect(identifier).await?);
184
185        // Cache the name and device type for synchronous access
186        let cached_name = std::sync::OnceLock::new();
187        if let Some(name) = device.name() {
188            let _ = cached_name.set(name.to_string());
189        }
190
191        let cached_device_type = std::sync::OnceLock::new();
192        if let Some(device_type) = device.device_type() {
193            let _ = cached_device_type.set(device_type);
194        }
195
196        Ok(Self {
197            identifier: identifier.to_string(),
198            device: RwLock::new(Some(device)),
199            options,
200            state: RwLock::new(ConnectionState::Connected),
201            event_sender: None,
202            attempt_count: RwLock::new(0),
203            cancelled: Arc::new(AtomicBool::new(false)),
204            cached_name,
205            cached_device_type,
206        })
207    }
208
209    /// Create with an event sender for notifications.
210    pub async fn connect_with_events(
211        identifier: &str,
212        options: ReconnectOptions,
213        event_sender: EventSender,
214    ) -> Result<Self> {
215        let mut this = Self::connect(identifier, options).await?;
216        this.event_sender = Some(event_sender);
217        Ok(this)
218    }
219
220    /// Cancel any ongoing reconnection attempts.
221    ///
222    /// This will cause the reconnect loop to exit on its next iteration.
223    pub fn cancel_reconnect(&self) {
224        self.cancelled.store(true, Ordering::SeqCst);
225    }
226
227    /// Check if reconnection has been cancelled.
228    pub fn is_cancelled(&self) -> bool {
229        self.cancelled.load(Ordering::SeqCst)
230    }
231
232    /// Reset the cancellation flag.
233    ///
234    /// Call this before starting a new reconnection attempt if you want to clear
235    /// a previous cancellation. The `reconnect()` method will check if cancelled
236    /// at the start of each iteration, so this allows re-using a previously
237    /// cancelled `ReconnectingDevice`.
238    pub fn reset_cancellation(&self) {
239        self.cancelled.store(false, Ordering::SeqCst);
240    }
241
242    /// Get the current connection state.
243    pub async fn state(&self) -> ConnectionState {
244        *self.state.read().await
245    }
246
247    /// Check if currently connected.
248    pub async fn is_connected(&self) -> bool {
249        let guard = self.device.read().await;
250        if let Some(device) = guard.as_ref() {
251            device.is_connected().await
252        } else {
253            false
254        }
255    }
256
257    /// Get the identifier.
258    pub fn identifier(&self) -> &str {
259        &self.identifier
260    }
261
262    /// Execute an operation, reconnecting if necessary.
263    ///
264    /// The closure is called with a reference to the device. If the operation
265    /// fails due to a connection issue, the device will attempt to reconnect
266    /// and retry the operation.
267    ///
268    /// # Example
269    ///
270    /// ```ignore
271    /// let reading = device.with_device(|d| async { d.read_current().await }).await?;
272    /// ```
273    pub async fn with_device<F, Fut, T>(&self, f: F) -> Result<T>
274    where
275        F: Fn(&Device) -> Fut,
276        Fut: std::future::Future<Output = Result<T>>,
277    {
278        // Try the operation if already connected
279        {
280            let guard = self.device.read().await;
281            if let Some(device) = guard.as_ref()
282                && device.is_connected().await
283            {
284                match f(device).await {
285                    Ok(result) => return Ok(result),
286                    Err(e) => {
287                        warn!("Operation failed: {}", e);
288                        // Fall through to reconnect
289                    }
290                }
291            }
292        }
293
294        // Need to reconnect
295        self.reconnect().await?;
296
297        // Retry the operation after reconnection
298        let guard = self.device.read().await;
299        if let Some(device) = guard.as_ref() {
300            f(device).await
301        } else {
302            Err(Error::NotConnected)
303        }
304    }
305
306    /// Internal helper that executes an operation with automatic reconnection using boxed futures.
307    ///
308    /// This method uses explicit HRTB (Higher-Rank Trait Bounds) to handle the complex
309    /// lifetime requirements when returning futures from closures. It's used internally
310    /// by the `AranetDevice` trait implementation.
311    ///
312    /// Note: We cannot consolidate this with `with_device` due to Rust's async closure
313    /// lifetime limitations. The `with_device` method provides a more ergonomic API for
314    /// callers, while this method handles the trait implementation requirements.
315    async fn run_with_reconnect<'a, T, F>(&'a self, f: F) -> Result<T>
316    where
317        F: for<'b> Fn(
318                &'b Device,
319            ) -> std::pin::Pin<
320                Box<dyn std::future::Future<Output = Result<T>> + Send + 'b>,
321            > + Send
322            + Sync,
323        T: Send,
324    {
325        // Try the operation if already connected
326        {
327            let guard = self.device.read().await;
328            if let Some(device) = guard.as_ref()
329                && device.is_connected().await
330            {
331                match f(device).await {
332                    Ok(result) => return Ok(result),
333                    Err(e) => {
334                        warn!("Operation failed: {}", e);
335                        // Fall through to reconnect
336                    }
337                }
338            }
339        }
340
341        // Need to reconnect
342        self.reconnect().await?;
343
344        // Retry the operation after reconnection
345        let guard = self.device.read().await;
346        if let Some(device) = guard.as_ref() {
347            f(device).await
348        } else {
349            Err(Error::NotConnected)
350        }
351    }
352
353    /// Attempt to reconnect to the device.
354    ///
355    /// This loop can be cancelled by calling `cancel_reconnect()` from another task.
356    /// When cancelled, returns `Error::Cancelled`.
357    ///
358    /// Note: If `cancel_reconnect()` was called before this method, reconnection
359    /// will still proceed. Call `reset_cancellation()` explicitly if you want to
360    /// clear a previous cancellation before starting a new reconnection attempt.
361    pub async fn reconnect(&self) -> Result<()> {
362        // Only reset if not already cancelled - this prevents a race condition
363        // where cancel_reconnect() is called just before reconnect() starts
364        // and would be immediately cleared.
365        if !self.is_cancelled() {
366            self.reset_cancellation();
367        }
368
369        *self.state.write().await = ConnectionState::Reconnecting;
370        *self.attempt_count.write().await = 0;
371
372        loop {
373            // Check for cancellation at the start of each iteration
374            if self.is_cancelled() {
375                *self.state.write().await = ConnectionState::Disconnected;
376                info!("Reconnection cancelled for {}", self.identifier);
377                return Err(Error::Cancelled);
378            }
379
380            let attempt = {
381                let mut count = self.attempt_count.write().await;
382                *count += 1;
383                *count
384            };
385
386            // Check if we've exceeded max attempts
387            if let Some(max) = self.options.max_attempts
388                && attempt > max
389            {
390                *self.state.write().await = ConnectionState::Failed;
391                return Err(Error::Timeout {
392                    operation: format!("reconnect to '{}'", self.identifier),
393                    duration: self.options.max_delay * max,
394                });
395            }
396
397            // Send reconnect started event
398            if let Some(sender) = &self.event_sender {
399                let _ = sender.send(DeviceEvent::ReconnectStarted {
400                    device: DeviceId::new(&self.identifier),
401                    attempt,
402                });
403            }
404
405            info!("Reconnection attempt {} for {}", attempt, self.identifier);
406
407            // Wait before attempting (check cancellation during sleep)
408            let delay = self.options.delay_for_attempt(attempt - 1);
409            sleep(delay).await;
410
411            // Check for cancellation after sleep
412            if self.is_cancelled() {
413                *self.state.write().await = ConnectionState::Disconnected;
414                info!("Reconnection cancelled for {}", self.identifier);
415                return Err(Error::Cancelled);
416            }
417
418            // Try to connect
419            match Device::connect(&self.identifier).await {
420                Ok(new_device) => {
421                    *self.device.write().await = Some(Arc::new(new_device));
422                    *self.state.write().await = ConnectionState::Connected;
423
424                    // Send reconnect succeeded event
425                    if let Some(sender) = &self.event_sender {
426                        let _ = sender.send(DeviceEvent::ReconnectSucceeded {
427                            device: DeviceId::new(&self.identifier),
428                            attempts: attempt,
429                        });
430                    }
431
432                    info!("Reconnected successfully after {} attempts", attempt);
433                    return Ok(());
434                }
435                Err(e) => {
436                    warn!("Reconnection attempt {} failed: {}", attempt, e);
437                }
438            }
439        }
440    }
441
442    /// Disconnect from the device.
443    pub async fn disconnect(&self) -> Result<()> {
444        let mut guard = self.device.write().await;
445        if let Some(device) = guard.take() {
446            device.disconnect().await?;
447        }
448        *self.state.write().await = ConnectionState::Disconnected;
449        Ok(())
450    }
451
452    /// Get the number of reconnection attempts made.
453    pub async fn attempt_count(&self) -> u32 {
454        *self.attempt_count.read().await
455    }
456
457    /// Get the device name, if available and connected.
458    pub async fn name(&self) -> Option<String> {
459        let guard = self.device.read().await;
460        guard.as_ref().and_then(|d| d.name().map(|s| s.to_string()))
461    }
462
463    /// Get the device address (returns identifier if not connected).
464    pub async fn address(&self) -> String {
465        let guard = self.device.read().await;
466        guard
467            .as_ref()
468            .map(|d| d.address().to_string())
469            .unwrap_or_else(|| self.identifier.clone())
470    }
471
472    /// Get the detected device type, if available.
473    pub async fn device_type(&self) -> Option<DeviceType> {
474        let guard = self.device.read().await;
475        guard.as_ref().and_then(|d| d.device_type())
476    }
477}
478
479// Implement the AranetDevice trait for ReconnectingDevice
480#[async_trait]
481impl AranetDevice for ReconnectingDevice {
482    async fn is_connected(&self) -> bool {
483        ReconnectingDevice::is_connected(self).await
484    }
485
486    async fn connect(&self) -> Result<()> {
487        // If already connected, this is a no-op
488        if self.is_connected().await {
489            return Ok(());
490        }
491        // Otherwise, attempt to reconnect
492        self.reconnect().await
493    }
494
495    async fn disconnect(&self) -> Result<()> {
496        ReconnectingDevice::disconnect(self).await
497    }
498
499    fn name(&self) -> Option<&str> {
500        self.cached_name.get().map(|s| s.as_str())
501    }
502
503    fn address(&self) -> &str {
504        &self.identifier
505    }
506
507    fn device_type(&self) -> Option<DeviceType> {
508        self.cached_device_type.get().copied()
509    }
510
511    async fn read_current(&self) -> Result<CurrentReading> {
512        self.run_with_reconnect(|d| Box::pin(d.read_current()))
513            .await
514    }
515
516    async fn read_device_info(&self) -> Result<DeviceInfo> {
517        self.run_with_reconnect(|d| Box::pin(d.read_device_info()))
518            .await
519    }
520
521    async fn read_rssi(&self) -> Result<i16> {
522        self.run_with_reconnect(|d| Box::pin(d.read_rssi())).await
523    }
524
525    async fn read_battery(&self) -> Result<u8> {
526        self.run_with_reconnect(|d| Box::pin(d.read_battery()))
527            .await
528    }
529
530    async fn get_history_info(&self) -> Result<HistoryInfo> {
531        self.run_with_reconnect(|d| Box::pin(d.get_history_info()))
532            .await
533    }
534
535    async fn download_history(&self) -> Result<Vec<HistoryRecord>> {
536        self.run_with_reconnect(|d| Box::pin(d.download_history()))
537            .await
538    }
539
540    async fn download_history_with_options(
541        &self,
542        options: HistoryOptions,
543    ) -> Result<Vec<HistoryRecord>> {
544        let opts = options.clone();
545        self.run_with_reconnect(move |d| {
546            let opts = opts.clone();
547            Box::pin(async move { d.download_history_with_options(opts).await })
548        })
549        .await
550    }
551
552    async fn get_interval(&self) -> Result<MeasurementInterval> {
553        self.run_with_reconnect(|d| Box::pin(d.get_interval()))
554            .await
555    }
556
557    async fn set_interval(&self, interval: MeasurementInterval) -> Result<()> {
558        self.run_with_reconnect(move |d| Box::pin(d.set_interval(interval)))
559            .await
560    }
561
562    async fn get_calibration(&self) -> Result<CalibrationData> {
563        self.run_with_reconnect(|d| Box::pin(d.get_calibration()))
564            .await
565    }
566}
567
#[cfg(test)]
mod tests {
    use super::*;

    /// Backoff-enabled options starting at one second and doubling per
    /// attempt, capped at `max_delay_secs`.
    fn doubling_options(max_delay_secs: u64) -> ReconnectOptions {
        ReconnectOptions {
            initial_delay: Duration::from_secs(1),
            max_delay: Duration::from_secs(max_delay_secs),
            backoff_multiplier: 2.0,
            use_exponential_backoff: true,
            ..Default::default()
        }
    }

    #[test]
    fn test_reconnect_options_default() {
        let ReconnectOptions {
            max_attempts,
            use_exponential_backoff,
            ..
        } = ReconnectOptions::default();

        assert_eq!(max_attempts, Some(5));
        assert!(use_exponential_backoff);
    }

    #[test]
    fn test_reconnect_options_unlimited() {
        assert!(ReconnectOptions::unlimited().max_attempts.is_none());
    }

    #[test]
    fn test_delay_calculation() {
        let opts = doubling_options(60);

        // (attempt, expected delay in seconds)
        for (attempt, expected_secs) in [(0, 1), (1, 2), (2, 4), (3, 8)] {
            assert_eq!(
                opts.delay_for_attempt(attempt),
                Duration::from_secs(expected_secs)
            );
        }
    }

    #[test]
    fn test_delay_capped_at_max() {
        let opts = doubling_options(10);

        // 2^10 = 1024 seconds, but capped at 10
        let capped = opts.delay_for_attempt(10);
        assert_eq!(capped, Duration::from_secs(10));
    }

    #[test]
    fn test_fixed_delay() {
        let five_secs = Duration::from_secs(5);
        let opts = ReconnectOptions::fixed_delay(five_secs);

        for attempt in [0, 5] {
            assert_eq!(opts.delay_for_attempt(attempt), five_secs);
        }
    }
}