1use std::sync::Arc;
10use std::sync::atomic::{AtomicBool, Ordering};
11use std::time::Duration;
12
13use async_trait::async_trait;
14use tokio::sync::RwLock;
15use tokio::time::sleep;
16use tracing::{info, warn};
17
18use aranet_types::{CurrentReading, DeviceInfo, DeviceType, HistoryRecord};
19
20use crate::device::Device;
21use crate::error::{Error, Result};
22use crate::events::{DeviceEvent, DeviceId, EventSender};
23use crate::history::{HistoryInfo, HistoryOptions};
24use crate::settings::{CalibrationData, MeasurementInterval};
25use crate::traits::AranetDevice;
26
/// Policy controlling how [`ReconnectingDevice`] retries a lost connection.
///
/// Delays follow `initial_delay * backoff_multiplier^n` (capped at `max_delay`) when
/// exponential backoff is enabled, or a constant `initial_delay` otherwise.
#[derive(Debug, Clone, PartialEq)]
pub struct ReconnectOptions {
    /// Maximum number of reconnection attempts; `None` retries indefinitely.
    pub max_attempts: Option<u32>,
    /// Delay before the first attempt (and before every attempt when backoff is disabled).
    pub initial_delay: Duration,
    /// Upper bound applied to exponentially growing delays.
    pub max_delay: Duration,
    /// Factor applied to the delay for each successive attempt.
    pub backoff_multiplier: f64,
    /// Whether delays grow exponentially (`true`) or stay at `initial_delay` (`false`).
    pub use_exponential_backoff: bool,
}
41
42impl Default for ReconnectOptions {
43 fn default() -> Self {
44 Self {
45 max_attempts: Some(5),
46 initial_delay: Duration::from_secs(1),
47 max_delay: Duration::from_secs(60),
48 backoff_multiplier: 2.0,
49 use_exponential_backoff: true,
50 }
51 }
52}
53
54impl ReconnectOptions {
55 pub fn new() -> Self {
57 Self::default()
58 }
59
60 pub fn unlimited() -> Self {
62 Self {
63 max_attempts: None,
64 ..Default::default()
65 }
66 }
67
68 pub fn fixed_delay(delay: Duration) -> Self {
70 Self {
71 initial_delay: delay,
72 use_exponential_backoff: false,
73 ..Default::default()
74 }
75 }
76
77 pub fn max_attempts(mut self, attempts: u32) -> Self {
79 self.max_attempts = Some(attempts);
80 self
81 }
82
83 pub fn initial_delay(mut self, delay: Duration) -> Self {
85 self.initial_delay = delay;
86 self
87 }
88
89 pub fn max_delay(mut self, delay: Duration) -> Self {
91 self.max_delay = delay;
92 self
93 }
94
95 pub fn backoff_multiplier(mut self, multiplier: f64) -> Self {
97 self.backoff_multiplier = multiplier;
98 self
99 }
100
101 pub fn exponential_backoff(mut self, enabled: bool) -> Self {
103 self.use_exponential_backoff = enabled;
104 self
105 }
106
107 pub fn delay_for_attempt(&self, attempt: u32) -> Duration {
109 if !self.use_exponential_backoff {
110 return self.initial_delay;
111 }
112
113 let capped_attempt = attempt.min(32);
116 let delay_ms = self.initial_delay.as_millis() as f64
117 * self.backoff_multiplier.powi(capped_attempt as i32);
118
119 let delay = if delay_ms.is_finite() && delay_ms <= u64::MAX as f64 {
121 Duration::from_millis(delay_ms as u64)
122 } else {
123 self.max_delay
124 };
125
126 delay.min(self.max_delay)
127 }
128
129 pub fn validate(&self) -> Result<()> {
136 if self.backoff_multiplier < 1.0 {
137 return Err(Error::InvalidConfig(
138 "backoff_multiplier must be >= 1.0".to_string(),
139 ));
140 }
141 if self.initial_delay.is_zero() {
142 return Err(Error::InvalidConfig(
143 "initial_delay must be > 0".to_string(),
144 ));
145 }
146 if self.max_delay < self.initial_delay {
147 return Err(Error::InvalidConfig(
148 "max_delay must be >= initial_delay".to_string(),
149 ));
150 }
151 Ok(())
152 }
153}
154
/// Connection lifecycle state tracked by [`ReconnectingDevice`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ConnectionState {
    /// The most recent connect or reconnect succeeded.
    Connected,
    /// Explicitly disconnected, or a reconnect loop was cancelled.
    Disconnected,
    /// A reconnect loop is in progress.
    Reconnecting,
    /// The reconnect loop gave up after exhausting its attempt budget.
    Failed,
}
167
168pub struct ReconnectingDevice {
174 identifier: String,
175 device: RwLock<Option<Arc<Device>>>,
177 options: ReconnectOptions,
178 state: RwLock<ConnectionState>,
179 event_sender: Option<EventSender>,
180 attempt_count: RwLock<u32>,
181 cancelled: Arc<AtomicBool>,
183 cached_name: std::sync::OnceLock<String>,
185 cached_device_type: std::sync::OnceLock<DeviceType>,
187}
188
189impl ReconnectingDevice {
190 pub async fn connect(identifier: &str, options: ReconnectOptions) -> Result<Self> {
192 let device = Arc::new(Device::connect(identifier).await?);
193
194 let cached_name = std::sync::OnceLock::new();
196 if let Some(name) = device.name() {
197 let _ = cached_name.set(name.to_string());
198 }
199
200 let cached_device_type = std::sync::OnceLock::new();
201 if let Some(device_type) = device.device_type() {
202 let _ = cached_device_type.set(device_type);
203 }
204
205 Ok(Self {
206 identifier: identifier.to_string(),
207 device: RwLock::new(Some(device)),
208 options,
209 state: RwLock::new(ConnectionState::Connected),
210 event_sender: None,
211 attempt_count: RwLock::new(0),
212 cancelled: Arc::new(AtomicBool::new(false)),
213 cached_name,
214 cached_device_type,
215 })
216 }
217
218 pub async fn connect_with_events(
220 identifier: &str,
221 options: ReconnectOptions,
222 event_sender: EventSender,
223 ) -> Result<Self> {
224 let mut this = Self::connect(identifier, options).await?;
225 this.event_sender = Some(event_sender);
226 Ok(this)
227 }
228
229 pub fn cancel_reconnect(&self) {
233 self.cancelled.store(true, Ordering::SeqCst);
234 }
235
236 pub fn is_cancelled(&self) -> bool {
238 self.cancelled.load(Ordering::SeqCst)
239 }
240
241 pub fn reset_cancellation(&self) {
248 self.cancelled.store(false, Ordering::SeqCst);
249 }
250
251 pub async fn state(&self) -> ConnectionState {
253 *self.state.read().await
254 }
255
256 pub async fn is_connected(&self) -> bool {
258 let guard = self.device.read().await;
259 if let Some(device) = guard.as_ref() {
260 device.is_connected().await
261 } else {
262 false
263 }
264 }
265
266 pub fn identifier(&self) -> &str {
268 &self.identifier
269 }
270
271 pub async fn with_device<F, Fut, T>(&self, f: F) -> Result<T>
283 where
284 F: Fn(&Device) -> Fut,
285 Fut: std::future::Future<Output = Result<T>>,
286 {
287 {
289 let guard = self.device.read().await;
290 if let Some(device) = guard.as_ref()
291 && device.is_connected().await
292 {
293 match f(device).await {
294 Ok(result) => return Ok(result),
295 Err(e) => {
296 warn!("Operation failed: {}", e);
297 }
299 }
300 }
301 }
302
303 self.reconnect().await?;
305
306 let guard = self.device.read().await;
308 if let Some(device) = guard.as_ref() {
309 f(device).await
310 } else {
311 Err(Error::NotConnected)
312 }
313 }
314
315 async fn run_with_reconnect<'a, T, F>(&'a self, f: F) -> Result<T>
325 where
326 F: for<'b> Fn(
327 &'b Device,
328 ) -> std::pin::Pin<
329 Box<dyn std::future::Future<Output = Result<T>> + Send + 'b>,
330 > + Send
331 + Sync,
332 T: Send,
333 {
334 {
336 let guard = self.device.read().await;
337 if let Some(device) = guard.as_ref()
338 && device.is_connected().await
339 {
340 match f(device).await {
341 Ok(result) => return Ok(result),
342 Err(e) => {
343 warn!("Operation failed: {}", e);
344 }
346 }
347 }
348 }
349
350 self.reconnect().await?;
352
353 let guard = self.device.read().await;
355 if let Some(device) = guard.as_ref() {
356 f(device).await
357 } else {
358 Err(Error::NotConnected)
359 }
360 }
361
362 pub async fn reconnect(&self) -> Result<()> {
371 if !self.is_cancelled() {
375 self.reset_cancellation();
376 }
377
378 *self.state.write().await = ConnectionState::Reconnecting;
379 *self.attempt_count.write().await = 0;
380
381 loop {
382 if self.is_cancelled() {
384 *self.state.write().await = ConnectionState::Disconnected;
385 info!("Reconnection cancelled for {}", self.identifier);
386 return Err(Error::Cancelled);
387 }
388
389 let attempt = {
390 let mut count = self.attempt_count.write().await;
391 *count += 1;
392 *count
393 };
394
395 if let Some(max) = self.options.max_attempts
397 && attempt > max
398 {
399 *self.state.write().await = ConnectionState::Failed;
400 return Err(Error::Timeout {
401 operation: format!("reconnect to '{}'", self.identifier),
402 duration: self.options.max_delay * max,
403 });
404 }
405
406 if let Some(sender) = &self.event_sender {
408 let _ = sender.send(DeviceEvent::ReconnectStarted {
409 device: DeviceId::new(&self.identifier),
410 attempt,
411 });
412 }
413
414 info!("Reconnection attempt {} for {}", attempt, self.identifier);
415
416 let delay = self.options.delay_for_attempt(attempt - 1);
418 sleep(delay).await;
419
420 if self.is_cancelled() {
422 *self.state.write().await = ConnectionState::Disconnected;
423 info!("Reconnection cancelled for {}", self.identifier);
424 return Err(Error::Cancelled);
425 }
426
427 match Device::connect(&self.identifier).await {
429 Ok(new_device) => {
430 *self.device.write().await = Some(Arc::new(new_device));
431 *self.state.write().await = ConnectionState::Connected;
432
433 if let Some(sender) = &self.event_sender {
435 let _ = sender.send(DeviceEvent::ReconnectSucceeded {
436 device: DeviceId::new(&self.identifier),
437 attempts: attempt,
438 });
439 }
440
441 info!("Reconnected successfully after {} attempts", attempt);
442 return Ok(());
443 }
444 Err(e) => {
445 warn!("Reconnection attempt {} failed: {}", attempt, e);
446 }
447 }
448 }
449 }
450
451 pub async fn disconnect(&self) -> Result<()> {
453 let mut guard = self.device.write().await;
454 if let Some(device) = guard.take() {
455 device.disconnect().await?;
456 }
457 *self.state.write().await = ConnectionState::Disconnected;
458 Ok(())
459 }
460
461 pub async fn attempt_count(&self) -> u32 {
463 *self.attempt_count.read().await
464 }
465
466 pub async fn name(&self) -> Option<String> {
468 let guard = self.device.read().await;
469 guard.as_ref().and_then(|d| d.name().map(|s| s.to_string()))
470 }
471
472 pub async fn address(&self) -> String {
474 let guard = self.device.read().await;
475 guard
476 .as_ref()
477 .map(|d| d.address().to_string())
478 .unwrap_or_else(|| self.identifier.clone())
479 }
480
481 pub async fn device_type(&self) -> Option<DeviceType> {
483 let guard = self.device.read().await;
484 guard.as_ref().and_then(|d| d.device_type())
485 }
486}
487
488#[async_trait]
490impl AranetDevice for ReconnectingDevice {
491 async fn is_connected(&self) -> bool {
492 ReconnectingDevice::is_connected(self).await
493 }
494
495 async fn connect(&self) -> Result<()> {
496 if self.is_connected().await {
498 return Ok(());
499 }
500 self.reconnect().await
502 }
503
504 async fn disconnect(&self) -> Result<()> {
505 ReconnectingDevice::disconnect(self).await
506 }
507
508 fn name(&self) -> Option<&str> {
509 self.cached_name.get().map(|s| s.as_str())
510 }
511
512 fn address(&self) -> &str {
513 &self.identifier
514 }
515
516 fn device_type(&self) -> Option<DeviceType> {
517 self.cached_device_type.get().copied()
518 }
519
520 async fn read_current(&self) -> Result<CurrentReading> {
521 self.run_with_reconnect(|d| Box::pin(d.read_current()))
522 .await
523 }
524
525 async fn read_device_info(&self) -> Result<DeviceInfo> {
526 self.run_with_reconnect(|d| Box::pin(d.read_device_info()))
527 .await
528 }
529
530 async fn read_rssi(&self) -> Result<i16> {
531 self.run_with_reconnect(|d| Box::pin(d.read_rssi())).await
532 }
533
534 async fn read_battery(&self) -> Result<u8> {
535 self.run_with_reconnect(|d| Box::pin(d.read_battery()))
536 .await
537 }
538
539 async fn get_history_info(&self) -> Result<HistoryInfo> {
540 self.run_with_reconnect(|d| Box::pin(d.get_history_info()))
541 .await
542 }
543
544 async fn download_history(&self) -> Result<Vec<HistoryRecord>> {
545 self.run_with_reconnect(|d| Box::pin(d.download_history()))
546 .await
547 }
548
549 async fn download_history_with_options(
550 &self,
551 options: HistoryOptions,
552 ) -> Result<Vec<HistoryRecord>> {
553 let opts = options.clone();
554 self.run_with_reconnect(move |d| {
555 let opts = opts.clone();
556 Box::pin(async move { d.download_history_with_options(opts).await })
557 })
558 .await
559 }
560
561 async fn get_interval(&self) -> Result<MeasurementInterval> {
562 self.run_with_reconnect(|d| Box::pin(d.get_interval()))
563 .await
564 }
565
566 async fn set_interval(&self, interval: MeasurementInterval) -> Result<()> {
567 self.run_with_reconnect(move |d| Box::pin(d.set_interval(interval)))
568 .await
569 }
570
571 async fn get_calibration(&self) -> Result<CalibrationData> {
572 self.run_with_reconnect(|d| Box::pin(d.get_calibration()))
573 .await
574 }
575}
576
#[cfg(test)]
mod tests {
    use super::*;

    /// Exponential options with a one-second base that doubles, capped at `cap_secs`.
    fn doubling_from_one_second(cap_secs: u64) -> ReconnectOptions {
        ReconnectOptions {
            initial_delay: Duration::from_secs(1),
            max_delay: Duration::from_secs(cap_secs),
            backoff_multiplier: 2.0,
            use_exponential_backoff: true,
            ..Default::default()
        }
    }

    #[test]
    fn test_reconnect_options_default() {
        let defaults = ReconnectOptions::default();
        assert_eq!(defaults.max_attempts, Some(5));
        assert!(defaults.use_exponential_backoff);
    }

    #[test]
    fn test_reconnect_options_unlimited() {
        assert!(ReconnectOptions::unlimited().max_attempts.is_none());
    }

    #[test]
    fn test_delay_calculation() {
        let schedule = doubling_from_one_second(60);
        for (attempt, expected_secs) in [(0, 1), (1, 2), (2, 4), (3, 8)] {
            assert_eq!(
                schedule.delay_for_attempt(attempt),
                Duration::from_secs(expected_secs)
            );
        }
    }

    #[test]
    fn test_delay_capped_at_max() {
        let schedule = doubling_from_one_second(10);
        assert_eq!(schedule.delay_for_attempt(10), Duration::from_secs(10));
    }

    #[test]
    fn test_fixed_delay() {
        let five_seconds = Duration::from_secs(5);
        let schedule = ReconnectOptions::fixed_delay(five_seconds);
        for attempt in [0, 5] {
            assert_eq!(schedule.delay_for_attempt(attempt), five_seconds);
        }
    }
}