1use std::sync::Arc;
10use std::sync::atomic::{AtomicBool, Ordering};
11use std::time::Duration;
12
13use async_trait::async_trait;
14use tokio::sync::RwLock;
15use tokio::time::sleep;
16use tracing::{info, warn};
17
18use aranet_types::{CurrentReading, DeviceInfo, DeviceType, HistoryRecord};
19
20use crate::device::Device;
21use crate::error::{Error, Result};
22use crate::events::{DeviceEvent, DeviceId, EventSender};
23use crate::history::{HistoryInfo, HistoryOptions};
24use crate::settings::{CalibrationData, MeasurementInterval};
25use crate::traits::AranetDevice;
26
/// Policy controlling how a [`ReconnectingDevice`] retries a lost connection.
///
/// Start from [`ReconnectOptions::default`], [`ReconnectOptions::unlimited`] or
/// [`ReconnectOptions::fixed_delay`] and adjust with the builder-style setters.
#[derive(Clone, Debug)]
pub struct ReconnectOptions {
    /// Upper bound on reconnection attempts; `None` retries forever.
    pub max_attempts: Option<u32>,
    /// Wait before the first attempt (and before every attempt when backoff is disabled).
    pub initial_delay: Duration,
    /// Ceiling applied to the exponentially growing delay.
    pub max_delay: Duration,
    /// Factor the delay is multiplied by for each successive attempt.
    pub backoff_multiplier: f64,
    /// When `false`, every attempt waits exactly `initial_delay`.
    pub use_exponential_backoff: bool,
}

impl Default for ReconnectOptions {
    /// Five attempts, starting at 1 s and doubling up to a 60 s cap.
    fn default() -> Self {
        let initial_delay = Duration::from_secs(1);
        let max_delay = Duration::from_secs(60);
        ReconnectOptions {
            max_attempts: Some(5),
            initial_delay,
            max_delay,
            backoff_multiplier: 2.0,
            use_exponential_backoff: true,
        }
    }
}
53
54impl ReconnectOptions {
55 pub fn new() -> Self {
57 Self::default()
58 }
59
60 pub fn unlimited() -> Self {
62 Self {
63 max_attempts: None,
64 ..Default::default()
65 }
66 }
67
68 pub fn fixed_delay(delay: Duration) -> Self {
70 Self {
71 initial_delay: delay,
72 use_exponential_backoff: false,
73 ..Default::default()
74 }
75 }
76
77 pub fn max_attempts(mut self, attempts: u32) -> Self {
79 self.max_attempts = Some(attempts);
80 self
81 }
82
83 pub fn initial_delay(mut self, delay: Duration) -> Self {
85 self.initial_delay = delay;
86 self
87 }
88
89 pub fn max_delay(mut self, delay: Duration) -> Self {
91 self.max_delay = delay;
92 self
93 }
94
95 pub fn backoff_multiplier(mut self, multiplier: f64) -> Self {
97 self.backoff_multiplier = multiplier;
98 self
99 }
100
101 pub fn exponential_backoff(mut self, enabled: bool) -> Self {
103 self.use_exponential_backoff = enabled;
104 self
105 }
106
107 pub fn delay_for_attempt(&self, attempt: u32) -> Duration {
109 if !self.use_exponential_backoff {
110 return self.initial_delay;
111 }
112
113 let delay_ms =
114 self.initial_delay.as_millis() as f64 * self.backoff_multiplier.powi(attempt as i32);
115 let delay = Duration::from_millis(delay_ms as u64);
116
117 delay.min(self.max_delay)
118 }
119
120 pub fn validate(&self) -> Result<()> {
127 if self.backoff_multiplier < 1.0 {
128 return Err(Error::InvalidConfig(
129 "backoff_multiplier must be >= 1.0".to_string(),
130 ));
131 }
132 if self.initial_delay.is_zero() {
133 return Err(Error::InvalidConfig(
134 "initial_delay must be > 0".to_string(),
135 ));
136 }
137 if self.max_delay < self.initial_delay {
138 return Err(Error::InvalidConfig(
139 "max_delay must be >= initial_delay".to_string(),
140 ));
141 }
142 Ok(())
143 }
144}
145
/// Connection status tracked by a [`ReconnectingDevice`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ConnectionState {
    /// Set after a successful connect or reconnect.
    Connected,
    /// Set by an explicit disconnect or when a reconnection loop is cancelled.
    Disconnected,
    /// A reconnection loop is in progress.
    Reconnecting,
    /// Reconnection gave up after exceeding `max_attempts`.
    Failed,
}
158
/// A [`Device`] wrapper that re-establishes the connection when it is lost.
///
/// Operations routed through the wrapper are attempted on the current device;
/// if there is no connected device or the operation fails, the wrapper
/// reconnects according to its [`ReconnectOptions`] and runs the operation once
/// more.
pub struct ReconnectingDevice {
    /// Identifier passed to `Device::connect`, reused for every reconnect.
    identifier: String,
    /// Current device handle; `None` after `disconnect()`.
    device: RwLock<Option<Arc<Device>>>,
    /// Retry/backoff policy.
    options: ReconnectOptions,
    /// Last recorded connection state.
    state: RwLock<ConnectionState>,
    /// Optional channel for reconnect events.
    event_sender: Option<EventSender>,
    /// Attempts counted by the current (or most recent) reconnection loop.
    attempt_count: RwLock<u32>,
    /// Cancellation flag checked by the reconnection loop; once set, reconnects
    /// fail with `Error::Cancelled` until `reset_cancellation()` is called.
    cancelled: Arc<AtomicBool>,
    /// Name cached from the device, served synchronously by `AranetDevice::name`.
    cached_name: std::sync::OnceLock<String>,
    /// Type cached from the device, served by `AranetDevice::device_type`.
    cached_device_type: std::sync::OnceLock<DeviceType>,
}
179
180impl ReconnectingDevice {
181 pub async fn connect(identifier: &str, options: ReconnectOptions) -> Result<Self> {
183 let device = Arc::new(Device::connect(identifier).await?);
184
185 let cached_name = std::sync::OnceLock::new();
187 if let Some(name) = device.name() {
188 let _ = cached_name.set(name.to_string());
189 }
190
191 let cached_device_type = std::sync::OnceLock::new();
192 if let Some(device_type) = device.device_type() {
193 let _ = cached_device_type.set(device_type);
194 }
195
196 Ok(Self {
197 identifier: identifier.to_string(),
198 device: RwLock::new(Some(device)),
199 options,
200 state: RwLock::new(ConnectionState::Connected),
201 event_sender: None,
202 attempt_count: RwLock::new(0),
203 cancelled: Arc::new(AtomicBool::new(false)),
204 cached_name,
205 cached_device_type,
206 })
207 }
208
209 pub async fn connect_with_events(
211 identifier: &str,
212 options: ReconnectOptions,
213 event_sender: EventSender,
214 ) -> Result<Self> {
215 let mut this = Self::connect(identifier, options).await?;
216 this.event_sender = Some(event_sender);
217 Ok(this)
218 }
219
220 pub fn cancel_reconnect(&self) {
224 self.cancelled.store(true, Ordering::SeqCst);
225 }
226
227 pub fn is_cancelled(&self) -> bool {
229 self.cancelled.load(Ordering::SeqCst)
230 }
231
232 pub fn reset_cancellation(&self) {
239 self.cancelled.store(false, Ordering::SeqCst);
240 }
241
242 pub async fn state(&self) -> ConnectionState {
244 *self.state.read().await
245 }
246
247 pub async fn is_connected(&self) -> bool {
249 let guard = self.device.read().await;
250 if let Some(device) = guard.as_ref() {
251 device.is_connected().await
252 } else {
253 false
254 }
255 }
256
257 pub fn identifier(&self) -> &str {
259 &self.identifier
260 }
261
262 pub async fn with_device<F, Fut, T>(&self, f: F) -> Result<T>
274 where
275 F: Fn(&Device) -> Fut,
276 Fut: std::future::Future<Output = Result<T>>,
277 {
278 {
280 let guard = self.device.read().await;
281 if let Some(device) = guard.as_ref()
282 && device.is_connected().await
283 {
284 match f(device).await {
285 Ok(result) => return Ok(result),
286 Err(e) => {
287 warn!("Operation failed: {}", e);
288 }
290 }
291 }
292 }
293
294 self.reconnect().await?;
296
297 let guard = self.device.read().await;
299 if let Some(device) = guard.as_ref() {
300 f(device).await
301 } else {
302 Err(Error::NotConnected)
303 }
304 }
305
306 async fn run_with_reconnect<'a, T, F>(&'a self, f: F) -> Result<T>
316 where
317 F: for<'b> Fn(
318 &'b Device,
319 ) -> std::pin::Pin<
320 Box<dyn std::future::Future<Output = Result<T>> + Send + 'b>,
321 > + Send
322 + Sync,
323 T: Send,
324 {
325 {
327 let guard = self.device.read().await;
328 if let Some(device) = guard.as_ref()
329 && device.is_connected().await
330 {
331 match f(device).await {
332 Ok(result) => return Ok(result),
333 Err(e) => {
334 warn!("Operation failed: {}", e);
335 }
337 }
338 }
339 }
340
341 self.reconnect().await?;
343
344 let guard = self.device.read().await;
346 if let Some(device) = guard.as_ref() {
347 f(device).await
348 } else {
349 Err(Error::NotConnected)
350 }
351 }
352
353 pub async fn reconnect(&self) -> Result<()> {
362 if !self.is_cancelled() {
366 self.reset_cancellation();
367 }
368
369 *self.state.write().await = ConnectionState::Reconnecting;
370 *self.attempt_count.write().await = 0;
371
372 loop {
373 if self.is_cancelled() {
375 *self.state.write().await = ConnectionState::Disconnected;
376 info!("Reconnection cancelled for {}", self.identifier);
377 return Err(Error::Cancelled);
378 }
379
380 let attempt = {
381 let mut count = self.attempt_count.write().await;
382 *count += 1;
383 *count
384 };
385
386 if let Some(max) = self.options.max_attempts
388 && attempt > max
389 {
390 *self.state.write().await = ConnectionState::Failed;
391 return Err(Error::Timeout {
392 operation: format!("reconnect to '{}'", self.identifier),
393 duration: self.options.max_delay * max,
394 });
395 }
396
397 if let Some(sender) = &self.event_sender {
399 let _ = sender.send(DeviceEvent::ReconnectStarted {
400 device: DeviceId::new(&self.identifier),
401 attempt,
402 });
403 }
404
405 info!("Reconnection attempt {} for {}", attempt, self.identifier);
406
407 let delay = self.options.delay_for_attempt(attempt - 1);
409 sleep(delay).await;
410
411 if self.is_cancelled() {
413 *self.state.write().await = ConnectionState::Disconnected;
414 info!("Reconnection cancelled for {}", self.identifier);
415 return Err(Error::Cancelled);
416 }
417
418 match Device::connect(&self.identifier).await {
420 Ok(new_device) => {
421 *self.device.write().await = Some(Arc::new(new_device));
422 *self.state.write().await = ConnectionState::Connected;
423
424 if let Some(sender) = &self.event_sender {
426 let _ = sender.send(DeviceEvent::ReconnectSucceeded {
427 device: DeviceId::new(&self.identifier),
428 attempts: attempt,
429 });
430 }
431
432 info!("Reconnected successfully after {} attempts", attempt);
433 return Ok(());
434 }
435 Err(e) => {
436 warn!("Reconnection attempt {} failed: {}", attempt, e);
437 }
438 }
439 }
440 }
441
442 pub async fn disconnect(&self) -> Result<()> {
444 let mut guard = self.device.write().await;
445 if let Some(device) = guard.take() {
446 device.disconnect().await?;
447 }
448 *self.state.write().await = ConnectionState::Disconnected;
449 Ok(())
450 }
451
452 pub async fn attempt_count(&self) -> u32 {
454 *self.attempt_count.read().await
455 }
456
457 pub async fn name(&self) -> Option<String> {
459 let guard = self.device.read().await;
460 guard.as_ref().and_then(|d| d.name().map(|s| s.to_string()))
461 }
462
463 pub async fn address(&self) -> String {
465 let guard = self.device.read().await;
466 guard
467 .as_ref()
468 .map(|d| d.address().to_string())
469 .unwrap_or_else(|| self.identifier.clone())
470 }
471
472 pub async fn device_type(&self) -> Option<DeviceType> {
474 let guard = self.device.read().await;
475 guard.as_ref().and_then(|d| d.device_type())
476 }
477}
478
479#[async_trait]
481impl AranetDevice for ReconnectingDevice {
482 async fn is_connected(&self) -> bool {
483 ReconnectingDevice::is_connected(self).await
484 }
485
486 async fn connect(&self) -> Result<()> {
487 if self.is_connected().await {
489 return Ok(());
490 }
491 self.reconnect().await
493 }
494
495 async fn disconnect(&self) -> Result<()> {
496 ReconnectingDevice::disconnect(self).await
497 }
498
499 fn name(&self) -> Option<&str> {
500 self.cached_name.get().map(|s| s.as_str())
501 }
502
503 fn address(&self) -> &str {
504 &self.identifier
505 }
506
507 fn device_type(&self) -> Option<DeviceType> {
508 self.cached_device_type.get().copied()
509 }
510
511 async fn read_current(&self) -> Result<CurrentReading> {
512 self.run_with_reconnect(|d| Box::pin(d.read_current()))
513 .await
514 }
515
516 async fn read_device_info(&self) -> Result<DeviceInfo> {
517 self.run_with_reconnect(|d| Box::pin(d.read_device_info()))
518 .await
519 }
520
521 async fn read_rssi(&self) -> Result<i16> {
522 self.run_with_reconnect(|d| Box::pin(d.read_rssi())).await
523 }
524
525 async fn read_battery(&self) -> Result<u8> {
526 self.run_with_reconnect(|d| Box::pin(d.read_battery()))
527 .await
528 }
529
530 async fn get_history_info(&self) -> Result<HistoryInfo> {
531 self.run_with_reconnect(|d| Box::pin(d.get_history_info()))
532 .await
533 }
534
535 async fn download_history(&self) -> Result<Vec<HistoryRecord>> {
536 self.run_with_reconnect(|d| Box::pin(d.download_history()))
537 .await
538 }
539
540 async fn download_history_with_options(
541 &self,
542 options: HistoryOptions,
543 ) -> Result<Vec<HistoryRecord>> {
544 let opts = options.clone();
545 self.run_with_reconnect(move |d| {
546 let opts = opts.clone();
547 Box::pin(async move { d.download_history_with_options(opts).await })
548 })
549 .await
550 }
551
552 async fn get_interval(&self) -> Result<MeasurementInterval> {
553 self.run_with_reconnect(|d| Box::pin(d.get_interval()))
554 .await
555 }
556
557 async fn set_interval(&self, interval: MeasurementInterval) -> Result<()> {
558 self.run_with_reconnect(move |d| Box::pin(d.set_interval(interval)))
559 .await
560 }
561
562 async fn get_calibration(&self) -> Result<CalibrationData> {
563 self.run_with_reconnect(|d| Box::pin(d.get_calibration()))
564 .await
565 }
566}
567
#[cfg(test)]
mod tests {
    use super::*;

    /// Default policy: five attempts with exponential backoff.
    #[test]
    fn test_reconnect_options_default() {
        let defaults = ReconnectOptions::default();
        assert!(defaults.use_exponential_backoff);
        assert_eq!(defaults.max_attempts, Some(5));
    }

    /// `unlimited()` removes the attempt cap.
    #[test]
    fn test_reconnect_options_unlimited() {
        assert_eq!(ReconnectOptions::unlimited().max_attempts, None);
    }

    /// Delay doubles per attempt starting from the initial delay.
    #[test]
    fn test_delay_calculation() {
        let opts = ReconnectOptions {
            initial_delay: Duration::from_secs(1),
            max_delay: Duration::from_secs(60),
            backoff_multiplier: 2.0,
            use_exponential_backoff: true,
            ..Default::default()
        };

        for (attempt, secs) in [(0, 1), (1, 2), (2, 4), (3, 8)] {
            assert_eq!(opts.delay_for_attempt(attempt), Duration::from_secs(secs));
        }
    }

    /// Large attempt numbers are clamped to `max_delay`.
    #[test]
    fn test_delay_capped_at_max() {
        let cap = Duration::from_secs(10);
        let opts = ReconnectOptions {
            initial_delay: Duration::from_secs(1),
            max_delay: cap,
            backoff_multiplier: 2.0,
            use_exponential_backoff: true,
            ..Default::default()
        };

        assert_eq!(opts.delay_for_attempt(10), cap);
    }

    /// Fixed-delay policy ignores the attempt number.
    #[test]
    fn test_fixed_delay() {
        let expected = Duration::from_secs(5);
        let opts = ReconnectOptions::fixed_delay(expected);
        for attempt in [0, 5] {
            assert_eq!(opts.delay_for_attempt(attempt), expected);
        }
    }
}