1use std::{
2 borrow::Cow,
3 cmp,
4 future::Future,
5 pin::Pin,
6 task::{Context, Poll},
7 time::Duration,
8};
9
10use futures::FutureExt;
11use tokio::time::{Sleep, sleep};
12use tower::{retry::Policy, timeout::error::Elapsed};
13use vector_lib::configurable::configurable_component;
14
15use crate::Error;
16
/// Verdict produced by inspecting a response, telling the retry policy what to do next.
///
/// `Request` is the request type handed to (and returned from) the `RetryPartial` callback.
pub enum RetryAction<Request = ()> {
    /// The request should be retried; the string is the reason, which the policy logs.
    Retry(Cow<'static, str>),
    /// The request should be retried after being rewritten: the policy replaces the pending
    /// request with the output of this function applied to a clone of it.
    RetryPartial(Box<dyn Fn(Request) -> Request + Send + Sync>),
    /// The request must not be retried; the string is the reason, which the policy logs.
    DontRetry(Cow<'static, str>),
    /// The response was successful, so there is nothing to retry.
    Successful,
}
27
/// Caller-supplied rules that a retry policy consults to decide whether a request is retried.
pub trait RetryLogic: Clone + Send + Sync + 'static {
    /// Error type this logic knows how to classify; the policy downcasts errors to it.
    type Error: std::error::Error + Send + Sync + 'static;
    /// Request type passed through `RetryAction::RetryPartial` callbacks.
    type Request;
    /// Response type inspected by `should_retry_response`.
    type Response;

    /// Returns `true` when `error` is worth retrying.
    fn is_retriable_error(&self, error: &Self::Error) -> bool;

    /// Returns `true` when a request that timed out should be retried.
    ///
    /// The default implementation always returns `true`.
    fn is_retriable_timeout(&self) -> bool {
        true
    }

    /// Classifies a response that was received without an error.
    ///
    /// The default implementation treats every response as `RetryAction::Successful`.
    fn should_retry_response(&self, _response: &Self::Response) -> RetryAction<Self::Request> {
        RetryAction::Successful
    }

    /// Hook invoked with an error that `is_retriable_error` accepted, before the retry is
    /// scheduled. The default implementation does nothing.
    fn on_retriable_error(&self, _error: &Self::Error) {}
}
57
#[configurable_component]
/// The jitter mode applied to retry backoff delays.
#[derive(Clone, Copy, Debug, Default)]
pub enum JitterMode {
    /// No jitter: the computed backoff is used as-is.
    None,

    /// Full jitter: the delay is a random duration no longer than the computed backoff
    /// (picked in whole milliseconds, starting at 1 ms).
    ///
    /// This is the default mode.
    #[default]
    Full,
}
76
/// Retry policy whose backoff grows along a Fibonacci sequence, capped at `max_duration`.
#[derive(Debug, Clone)]
pub struct FibonacciRetryPolicy<L> {
    /// Retries still allowed; decremented (saturating at zero) each time the policy advances.
    remaining_attempts: usize,
    /// Previous term of the sequence; added to `current_duration` to produce the next term.
    previous_duration: Duration,
    /// Current un-jittered backoff.
    current_duration: Duration,
    /// Selects whether `current_duration` or `current_jitter_duration` is used as the delay.
    jitter_mode: JitterMode,
    /// Jittered counterpart of `current_duration`, recomputed on every advance.
    current_jitter_duration: Duration,
    /// Upper bound applied to the un-jittered backoff.
    max_duration: Duration,
    /// Rules deciding which errors and responses are retried.
    logic: L,
}
87
/// Future returned by the retry policy; it completes once the backoff delay has elapsed.
pub struct RetryPolicyFuture {
    /// Heap-pinned timer for the backoff delay.
    delay: Pin<Box<Sleep>>,
}
91
92impl<L: RetryLogic> FibonacciRetryPolicy<L> {
93 pub fn new(
94 remaining_attempts: usize,
95 initial_backoff: Duration,
96 max_duration: Duration,
97 logic: L,
98 jitter_mode: JitterMode,
99 ) -> Self {
100 FibonacciRetryPolicy {
101 remaining_attempts,
102 previous_duration: Duration::from_secs(0),
103 current_duration: initial_backoff,
104 jitter_mode,
105 current_jitter_duration: Self::add_full_jitter(initial_backoff),
106 max_duration,
107 logic,
108 }
109 }
110
111 fn add_full_jitter(d: Duration) -> Duration {
112 let jitter = (rand::random::<u64>() % (d.as_millis() as u64)) + 1;
113 Duration::from_millis(jitter)
114 }
115
116 const fn backoff(&self) -> Duration {
117 match self.jitter_mode {
118 JitterMode::None => self.current_duration,
119 JitterMode::Full => self.current_jitter_duration,
120 }
121 }
122
123 fn advance(&mut self) {
124 let sum = self
125 .previous_duration
126 .checked_add(self.current_duration)
127 .unwrap_or(Duration::MAX);
128 let next_duration = cmp::min(sum, self.max_duration);
129 self.remaining_attempts = self.remaining_attempts.saturating_sub(1);
130 self.previous_duration = self.current_duration;
131 self.current_duration = next_duration;
132 self.current_jitter_duration = Self::add_full_jitter(next_duration);
133 }
134
135 fn build_retry(&mut self) -> RetryPolicyFuture {
136 self.advance();
137 let delay = Box::pin(sleep(self.backoff()));
138
139 debug!(message = "Retrying request.", delay_ms = %self.backoff().as_millis());
140 RetryPolicyFuture { delay }
141 }
142}
143
144impl<Req, Res, L> Policy<Req, Res, Error> for FibonacciRetryPolicy<L>
145where
146 Req: Clone + Send + 'static,
147 L: RetryLogic<Request = Req, Response = Res>,
148{
149 type Future = RetryPolicyFuture;
150
151 fn retry(&mut self, req: &mut Req, result: &mut Result<Res, Error>) -> Option<Self::Future> {
154 match result {
155 Ok(response) => match self.logic.should_retry_response(response) {
156 RetryAction::Retry(reason) => {
157 if self.remaining_attempts == 0 {
158 error!(
159 message = "OK/retry response but retries exhausted; dropping the request.",
160 reason = ?reason,
161 );
162 return None;
163 }
164
165 warn!(message = "Retrying after response.", reason = %reason);
166 Some(self.build_retry())
167 }
168 RetryAction::RetryPartial(modify_request) => {
169 if self.remaining_attempts == 0 {
170 error!(
171 message =
172 "OK/retry response but retries exhausted; dropping the request.",
173 );
174 return None;
175 }
176 *req = modify_request(req.clone());
177 warn!("OK/retrying partial after response.");
178 Some(self.build_retry())
179 }
180 RetryAction::DontRetry(reason) => {
181 error!(message = "Not retriable; dropping the request.", ?reason);
182 None
183 }
184
185 RetryAction::Successful => None,
186 },
187 Err(error) => {
188 if self.remaining_attempts == 0 {
189 error!(message = "Retries exhausted; dropping the request.", %error);
190 return None;
191 }
192
193 if let Some(expected) = error.downcast_ref::<L::Error>() {
194 if self.logic.is_retriable_error(expected) {
195 self.logic.on_retriable_error(expected);
196 warn!(message = "Retrying after error.", error = %expected);
197 Some(self.build_retry())
198 } else {
199 error!(
200 message = "Non-retriable error; dropping the request.",
201 %error,
202 );
203 None
204 }
205 } else if error.downcast_ref::<Elapsed>().is_some() {
206 if self.logic.is_retriable_timeout() {
207 warn!(
208 "Request timed out. If this happens often while the events are actually reaching their destination, try decreasing `batch.max_bytes` and/or using `compression` if applicable. Alternatively `request.timeout_secs` can be increased."
209 );
210 Some(self.build_retry())
211 } else {
212 error!(
213 message =
214 "Request timed out and is not retriable; dropping the request."
215 );
216 None
217 }
218 } else {
219 error!(
220 message = "Unexpected error type; dropping the request.",
221 %error
222 );
223 None
224 }
225 }
226 }
227 }
228
229 fn clone_request(&mut self, request: &Req) -> Option<Req> {
230 Some(request.clone())
231 }
232}
233
// The future only holds its timer behind a `Pin<Box<_>>` and polls it through that pointer,
// so moving the wrapper itself never moves the pinned timer.
impl Unpin for RetryPolicyFuture {}
237
238impl Future for RetryPolicyFuture {
239 type Output = ();
240
241 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
242 std::task::ready!(self.delay.poll_unpin(cx));
243 Poll::Ready(())
244 }
245}
246
247impl<Request> RetryAction<Request> {
248 pub const fn is_retryable(&self) -> bool {
249 matches!(self, RetryAction::Retry(_) | RetryAction::RetryPartial(_))
250 }
251
252 pub const fn is_not_retryable(&self) -> bool {
253 matches!(self, RetryAction::DontRetry(_))
254 }
255
256 pub const fn is_successful(&self) -> bool {
257 matches!(self, RetryAction::Successful)
258 }
259}
260
// Unit tests for the Fibonacci retry policy, driven through a mocked tower service.
#[cfg(test)]
mod tests {
    use std::{fmt, time::Duration};

    use tokio::time;
    use tokio_test::{assert_pending, assert_ready_err, assert_ready_ok, task};
    use tower::retry::RetryLayer;
    use tower_test::{assert_request_eq, mock};

    use super::*;
    use crate::test_util::trace_init;

    // A retriable service error is followed by a second attempt once the backoff has elapsed.
    #[tokio::test]
    async fn service_error_retry() {
        trace_init();

        // Freeze the clock so the backoff only elapses when the test advances it.
        time::pause();

        let policy = FibonacciRetryPolicy::new(
            5,
            Duration::from_secs(1),
            Duration::from_secs(10),
            SvcRetryLogic,
            JitterMode::None,
        );

        let (mut svc, mut handle) = mock::spawn_layer(RetryLayer::new(policy));

        assert_ready_ok!(svc.poll_ready());

        let fut = svc.call("hello");
        let mut fut = task::spawn(fut);

        // First attempt fails with an error flagged as retriable.
        assert_request_eq!(handle, "hello").send_error(Error(true));

        // The call is still waiting on the backoff timer.
        assert_pending!(fut.poll());

        // Move past the backoff; the call stays pending until the mock answers the retry.
        time::advance(Duration::from_secs(2)).await;
        assert_pending!(fut.poll());

        assert_request_eq!(handle, "hello").send_response("world");
        assert_eq!(fut.await.unwrap(), "world");
    }

    // A non-retriable service error fails the call immediately.
    #[tokio::test]
    async fn service_error_no_retry() {
        trace_init();

        let policy = FibonacciRetryPolicy::new(
            5,
            Duration::from_secs(1),
            Duration::from_secs(10),
            SvcRetryLogic,
            JitterMode::None,
        );

        let (mut svc, mut handle) = mock::spawn_layer(RetryLayer::new(policy));

        assert_ready_ok!(svc.poll_ready());

        let mut fut = task::spawn(svc.call("hello"));
        // `Error(false)` is rejected by `is_retriable_error`, so no retry is scheduled.
        assert_request_eq!(handle, "hello").send_error(Error(false));
        assert_ready_err!(fut.poll());
    }

    // A timeout is retried when the logic keeps the default `is_retriable_timeout`.
    #[tokio::test]
    async fn timeout_error() {
        trace_init();

        // Freeze the clock so the backoff only elapses when the test advances it.
        time::pause();

        let policy = FibonacciRetryPolicy::new(
            5,
            Duration::from_secs(1),
            Duration::from_secs(10),
            SvcRetryLogic,
            JitterMode::None,
        );

        let (mut svc, mut handle) = mock::spawn_layer(RetryLayer::new(policy));

        assert_ready_ok!(svc.poll_ready());

        let mut fut = task::spawn(svc.call("hello"));
        assert_request_eq!(handle, "hello").send_error(Elapsed::new());
        // The call is still waiting on the backoff timer.
        assert_pending!(fut.poll());

        // Move past the backoff; the call stays pending until the mock answers the retry.
        time::advance(Duration::from_secs(2)).await;
        assert_pending!(fut.poll());

        assert_request_eq!(handle, "hello").send_response("world");
        assert_eq!(fut.await.unwrap(), "world");
    }

    // A timeout fails the call immediately when the logic opts out of timeout retries.
    #[tokio::test]
    async fn timeout_error_no_retry() {
        trace_init();

        let policy = FibonacciRetryPolicy::new(
            5,
            Duration::from_secs(1),
            Duration::from_secs(10),
            NoTimeoutRetryLogic,
            JitterMode::None,
        );

        let (mut svc, mut handle) = mock::spawn_layer(RetryLayer::new(policy));

        assert_ready_ok!(svc.poll_ready());

        let mut fut = task::spawn(svc.call("hello"));
        assert_request_eq!(handle, "hello").send_error(Elapsed::new());
        assert_ready_err!(fut.poll());
    }

    // Without jitter the backoff follows 1, 1, 2, 3, 5, 8 seconds and then stays at the cap.
    #[test]
    fn backoff_grows_to_max() {
        let mut policy = FibonacciRetryPolicy::new(
            10,
            Duration::from_secs(1),
            Duration::from_secs(10),
            SvcRetryLogic,
            JitterMode::None,
        );
        assert_eq!(Duration::from_secs(1), policy.backoff());

        policy.advance();
        assert_eq!(Duration::from_secs(1), policy.backoff());

        policy.advance();
        assert_eq!(Duration::from_secs(2), policy.backoff());

        policy.advance();
        assert_eq!(Duration::from_secs(3), policy.backoff());

        policy.advance();
        assert_eq!(Duration::from_secs(5), policy.backoff());

        policy.advance();
        assert_eq!(Duration::from_secs(8), policy.backoff());

        // 5 + 8 would be 13 seconds, which is clamped to the 10 second maximum.
        policy.advance();
        assert_eq!(Duration::from_secs(10), policy.backoff());

        policy.advance();
        assert_eq!(Duration::from_secs(10), policy.backoff());
    }

    // With full jitter every delay is non-zero and never exceeds the un-jittered term.
    #[test]
    fn backoff_grows_to_max_with_jitter() {
        let max_duration = Duration::from_secs(10);
        let mut policy = FibonacciRetryPolicy::new(
            10,
            Duration::from_secs(1),
            max_duration,
            SvcRetryLogic,
            JitterMode::Full,
        );

        // Un-jittered terms, in seconds, before the cap is reached.
        let expected_fib = [1, 1, 2, 3, 5, 8];

        for (i, &exp_fib_secs) in expected_fib.iter().enumerate() {
            let backoff = policy.backoff();
            let upper_bound = Duration::from_secs(exp_fib_secs);

            assert!(
                // The jittered delay lies in (0, un-jittered term].
                !backoff.is_zero() && backoff <= upper_bound,
                "Attempt {}: Expected backoff to be within 0 and {:?}, got {:?}",
                i + 1,
                upper_bound,
                backoff
            );

            policy.advance();
        }

        // Once capped, the jittered delay must stay within the maximum.
        for _ in 0..4 {
            let backoff = policy.backoff();
            assert!(
                !backoff.is_zero() && backoff <= max_duration,
                "Expected backoff to not exceed {max_duration:?}, got {backoff:?}"
            );

            policy.advance();
        }
    }

    // Retry logic that retries errors according to their flag and keeps the default
    // timeout behavior.
    #[derive(Debug, Clone)]
    struct SvcRetryLogic;

    impl RetryLogic for SvcRetryLogic {
        type Error = Error;
        type Request = &'static str;
        type Response = &'static str;

        fn is_retriable_error(&self, error: &Self::Error) -> bool {
            error.0
        }
    }

    // Same as `SvcRetryLogic`, except that timeouts are never retried.
    #[derive(Debug, Clone)]
    struct NoTimeoutRetryLogic;

    impl RetryLogic for NoTimeoutRetryLogic {
        type Error = Error;
        type Request = &'static str;
        type Response = &'static str;

        fn is_retriable_error(&self, error: &Self::Error) -> bool {
            error.0
        }

        fn is_retriable_timeout(&self) -> bool {
            false
        }
    }

    // Test error whose flag states whether the retry logics above treat it as retriable.
    #[derive(Debug)]
    struct Error(bool);

    impl fmt::Display for Error {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            write!(f, "error")
        }
    }

    impl std::error::Error for Error {}
}