Skip to main content

tower/steer/
mod.rs

//! This module provides functionality for routing requests between [`Service`]s.
2//!
3//! # Example
4//!
5//! [`Steer`] can for example be used to create a router, akin to what you might find in web
6//! frameworks.
7//!
8//! Here, `GET /` will be sent to the `root` service, while all other requests go to `not_found`.
9//!
10//! ```rust
11//! # use std::task::{Context, Poll, ready};
12//! # use tower_service::Service;
13//! # use tower::steer::Steer;
14//! # use tower::service_fn;
15//! # use tower::util::BoxService;
16//! # use tower::ServiceExt;
17//! # use std::convert::Infallible;
18//! use http::{Request, Response, StatusCode, Method};
19//!
20//! # #[tokio::main]
21//! # async fn main() -> Result<(), Box<dyn std::error::Error>> {
22//! // Service that responds to `GET /`
23//! let root = service_fn(|req: Request<String>| async move {
24//!     # assert_eq!(req.uri().path(), "/");
25//!     let res = Response::new("Hello, World!".to_string());
26//!     Ok::<_, Infallible>(res)
27//! });
28//! // We have to box the service so its type gets erased and we can put it in a `Vec` with other
29//! // services
30//! let root = BoxService::new(root);
31//!
32//! // Service that responds with `404 Not Found` to all requests
33//! let not_found = service_fn(|req: Request<String>| async move {
34//!     let res = Response::builder()
35//!         .status(StatusCode::NOT_FOUND)
36//!         .body(String::new())
37//!         .expect("response is valid");
38//!     Ok::<_, Infallible>(res)
39//! });
40//! // Box that as well
41//! let not_found = BoxService::new(not_found);
42//!
43//! let mut svc = Steer::new(
44//!     // All services we route between
45//!     vec![root, not_found],
46//!     // How we pick which service to send the request to
47//!     |req: &Request<String>, _services: &[_]| {
48//!         if req.method() == Method::GET && req.uri().path() == "/" {
49//!             0 // Index of `root`
50//!         } else {
51//!             1 // Index of `not_found`
52//!         }
53//!     },
54//! );
55//!
56//! // This request will get sent to `root`
57//! let req = Request::get("/").body(String::new()).unwrap();
58//! let res = svc.ready().await?.call(req).await?;
59//! assert_eq!(res.into_body(), "Hello, World!");
60//!
61//! // This request will get sent to `not_found`
62//! let req = Request::get("/does/not/exist").body(String::new()).unwrap();
63//! let res = svc.ready().await?.call(req).await?;
64//! assert_eq!(res.status(), StatusCode::NOT_FOUND);
65//! assert_eq!(res.into_body(), "");
66//! #
67//! # Ok(())
68//! # }
69//! ```
70use std::task::{Context, Poll};
71use std::{collections::VecDeque, fmt, marker::PhantomData};
72use tower_service::Service;
73
/// This is how callers of [`Steer`] tell it which `Service` a `Req` corresponds to.
///
/// Any closure or function usable as `FnMut(&Req, &[S]) -> usize` is a `Picker`. A picker may
/// therefore keep state between calls, for example a round-robin counter.
pub trait Picker<S, Req> {
    /// Returns the index of the service that should handle `r`.
    ///
    /// The index refers to the order in which services were passed to [`Steer::new`];
    /// `services` is that same list, in that same order.
    fn pick(&mut self, r: &Req, services: &[S]) -> usize;
}

impl<S, F, Req> Picker<S, Req> for F
where
    // `FnMut` rather than `Fn`: `pick` already receives `&mut self`, so stateful closures can
    // be accepted at no cost, and every `Fn` closure is also `FnMut`.
    F: FnMut(&Req, &[S]) -> usize,
{
    fn pick(&mut self, r: &Req, services: &[S]) -> usize {
        self(r, services)
    }
}
88
/// [`Steer`] manages a list of [`Service`]s which all handle the same type of request.
///
/// An example use case is a sharded service. [`Steer`] operates as follows:
///
/// 1. [`Service::poll_ready`] waits for *all* component services to be ready.
/// 2. [`Service::call`] asks the provided [`Picker`] which [`Service`] the request corresponds
///    to.
/// 3. That [`Service`] is called with the request and its future is returned. The chosen
///    service is then marked as not ready again, so the next [`Service::poll_ready`] drives it
///    back to readiness.
///
/// Note that [`Steer`] must wait for all services to be ready since it can't know ahead of time
/// which [`Service`] the next message will arrive for, and is unwilling to buffer items
/// indefinitely. This will cause head-of-line blocking unless paired with a [`Service`] that does
/// buffer items indefinitely, and thus always returns [`Poll::Ready`]. For example, wrapping each
/// component service with a [`Buffer`] with a high enough limit (the maximum number of concurrent
/// requests) will prevent head-of-line blocking in [`Steer`].
///
/// [`Buffer`]: crate::buffer::Buffer
pub struct Steer<S, F, Req> {
    /// Chooses which service handles each request.
    router: F,
    /// The component services, addressed by their position.
    services: Vec<S>,
    /// Indices of services that must be polled to readiness before the next call.
    not_ready: VecDeque<usize>,
    /// `Steer` never stores a `Req`. `fn() -> Req` keeps the parameter covariant without
    /// making `Steer`'s `Send`/`Sync` or drop-check behavior depend on `Req`.
    _phantom: PhantomData<fn() -> Req>,
}
112
113impl<S, F, Req> Steer<S, F, Req> {
114    /// Make a new [`Steer`] with a list of [`Service`]'s and a [`Picker`].
115    ///
116    /// # Panics
117    ///
118    /// Panics if the `services` collection is empty
119    ///
120    /// Note: the order of the [`Service`]'s is significant for [`Picker::pick`]'s return value.
121    pub fn new(services: impl IntoIterator<Item = S>, router: F) -> Self {
122        let services: Vec<_> = services.into_iter().collect();
123        assert!(
124            !services.is_empty(),
125            "steer must contain at least one service"
126        );
127        let not_ready: VecDeque<_> = services.iter().enumerate().map(|(i, _)| i).collect();
128        Self {
129            router,
130            services,
131            not_ready,
132            _phantom: PhantomData,
133        }
134    }
135}
136
137impl<S, Req, F> Service<Req> for Steer<S, F, Req>
138where
139    S: Service<Req>,
140    F: Picker<S, Req>,
141{
142    type Response = S::Response;
143    type Error = S::Error;
144    type Future = S::Future;
145
146    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
147        loop {
148            // must wait for *all* services to be ready.
149            // this will cause head-of-line blocking unless the underlying services are always ready.
150            if self.not_ready.is_empty() {
151                return Poll::Ready(Ok(()));
152            } else {
153                if self.services[self.not_ready[0]]
154                    .poll_ready(cx)?
155                    .is_pending()
156                {
157                    return Poll::Pending;
158                }
159
160                self.not_ready.pop_front();
161            }
162        }
163    }
164
165    fn call(&mut self, req: Req) -> Self::Future {
166        assert!(
167            self.not_ready.is_empty(),
168            "Steer must wait for all services to be ready. Did you forget to call poll_ready()?"
169        );
170
171        let idx = self.router.pick(&req, &self.services[..]);
172        let cl = &mut self.services[idx];
173        self.not_ready.push_back(idx);
174        cl.call(req)
175    }
176}
177
178impl<S, F, Req> Clone for Steer<S, F, Req>
179where
180    S: Clone,
181    F: Clone,
182{
183    fn clone(&self) -> Self {
184        Self {
185            router: self.router.clone(),
186            services: self.services.clone(),
187            not_ready: self.not_ready.clone(),
188            _phantom: PhantomData,
189        }
190    }
191}
192
193impl<S, F, Req> fmt::Debug for Steer<S, F, Req>
194where
195    S: fmt::Debug,
196    F: fmt::Debug,
197{
198    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
199        let Self {
200            router,
201            services,
202            not_ready,
203            _phantom,
204        } = self;
205        f.debug_struct("Steer")
206            .field("router", router)
207            .field("services", services)
208            .field("not_ready", not_ready)
209            .finish()
210    }
211}