valence_network/
byte_channel.rs

//! A bounded channel specifically for sending and receiving batches of bytes.
2
3#![allow(dead_code)]
4
5use std::sync::{Arc, Mutex};
6
7use bytes::BytesMut;
8use thiserror::Error;
9use tokio::sync::Notify;
10
11pub(crate) fn byte_channel(limit: usize) -> (ByteSender, ByteReceiver) {
12    let shared = Arc::new(Shared {
13        mtx: Mutex::new(Inner {
14            bytes: BytesMut::new(),
15            disconnected: false,
16        }),
17        notify: Notify::new(),
18        limit,
19    });
20
21    let sender = ByteSender {
22        shared: shared.clone(),
23    };
24
25    let receiver = ByteReceiver { shared };
26
27    (sender, receiver)
28}
29
30pub(crate) struct ByteSender {
31    shared: Arc<Shared>,
32}
33
34pub(crate) struct ByteReceiver {
35    shared: Arc<Shared>,
36}
37
38struct Shared {
39    mtx: Mutex<Inner>,
40    notify: Notify,
41    limit: usize,
42}
43
44struct Inner {
45    bytes: BytesMut,
46    disconnected: bool,
47}
48
49impl ByteSender {
50    pub(crate) fn take_capacity(&mut self, additional: usize) -> BytesMut {
51        let mut lck = self.shared.mtx.lock().unwrap();
52
53        lck.bytes.reserve(additional);
54
55        let len = lck.bytes.len();
56        lck.bytes.split_off(len)
57    }
58
59    pub(crate) fn try_send(&mut self, mut bytes: BytesMut) -> Result<(), TrySendError> {
60        let mut lck = self.shared.mtx.lock().unwrap();
61
62        if lck.disconnected {
63            return Err(TrySendError::Disconnected(bytes));
64        }
65
66        if bytes.is_empty() {
67            return Ok(());
68        }
69
70        let available = self.shared.limit - lck.bytes.len();
71
72        if bytes.len() > available {
73            if available > 0 {
74                lck.bytes.unsplit(bytes.split_to(available));
75                self.shared.notify.notify_waiters();
76            }
77
78            return Err(TrySendError::Full(bytes));
79        }
80
81        lck.bytes.unsplit(bytes);
82        self.shared.notify.notify_waiters();
83
84        Ok(())
85    }
86
87    pub(crate) async fn send_async(&mut self, mut bytes: BytesMut) -> Result<(), SendError> {
88        loop {
89            {
90                let mut lck = self.shared.mtx.lock().unwrap();
91
92                if lck.disconnected {
93                    return Err(SendError(bytes));
94                }
95
96                if bytes.is_empty() {
97                    return Ok(());
98                }
99
100                let available = self.shared.limit - lck.bytes.len();
101
102                if bytes.len() <= available {
103                    lck.bytes.unsplit(bytes);
104                    self.shared.notify.notify_waiters();
105                    return Ok(());
106                }
107
108                if available > 0 {
109                    lck.bytes.unsplit(bytes.split_to(available));
110                    self.shared.notify.notify_waiters();
111                }
112            }
113
114            self.shared.notify.notified().await;
115        }
116    }
117
118    pub(crate) fn is_disconnected(&self) -> bool {
119        self.shared.mtx.lock().unwrap().disconnected
120    }
121
122    pub(crate) fn limit(&self) -> usize {
123        self.shared.limit
124    }
125}
126
127/// Contains any excess bytes not sent.
128#[derive(Clone, PartialEq, Eq, Debug, Error)]
129pub(crate) enum TrySendError {
130    #[error("sender disconnected")]
131    Disconnected(BytesMut),
132    #[error("channel full (see `Config::outgoing_capacity`)")]
133    Full(BytesMut),
134}
135
136#[derive(Clone, PartialEq, Eq, Debug, Error)]
137#[error("sender disconnected")]
138pub(crate) struct SendError(pub(crate) BytesMut);
139
140impl SendError {
141    pub(crate) fn into_inner(self) -> BytesMut {
142        self.0
143    }
144}
145
146impl ByteReceiver {
147    pub(crate) fn try_recv(&mut self) -> Result<BytesMut, TryRecvError> {
148        let mut lck = self.shared.mtx.lock().unwrap();
149
150        if !lck.bytes.is_empty() {
151            self.shared.notify.notify_waiters();
152            return Ok(lck.bytes.split());
153        }
154
155        if lck.disconnected {
156            return Err(TryRecvError::Disconnected);
157        }
158
159        Err(TryRecvError::Empty)
160    }
161
162    pub(crate) async fn recv_async(&mut self) -> Result<BytesMut, RecvError> {
163        loop {
164            {
165                let mut lck = self.shared.mtx.lock().unwrap();
166
167                if !lck.bytes.is_empty() {
168                    self.shared.notify.notify_waiters();
169                    return Ok(lck.bytes.split());
170                }
171
172                if lck.disconnected {
173                    return Err(RecvError::Disconnected);
174                }
175            }
176
177            self.shared.notify.notified().await;
178        }
179    }
180
181    pub(crate) fn is_disconnected(&self) -> bool {
182        self.shared.mtx.lock().unwrap().disconnected
183    }
184
185    pub(crate) fn limit(&self) -> usize {
186        self.shared.limit
187    }
188}
189
190#[derive(Copy, Clone, PartialEq, Eq, Debug, Error)]
191pub(crate) enum TryRecvError {
192    #[error("empty channel")]
193    Empty,
194    #[error("receiver disconnected")]
195    Disconnected,
196}
197
198#[derive(Copy, Clone, PartialEq, Eq, Debug, Error)]
199pub(crate) enum RecvError {
200    #[error("receiver disconnected")]
201    Disconnected,
202}
203
204impl Drop for ByteSender {
205    fn drop(&mut self) {
206        self.shared.mtx.lock().unwrap().disconnected = true;
207    }
208}
209
210impl Drop for ByteReceiver {
211    fn drop(&mut self) {
212        self.shared.mtx.lock().unwrap().disconnected = true;
213    }
214}
215
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn byte_channel_try() {
        let (mut sender, mut receiver) = byte_channel(4);

        assert_eq!(
            sender.try_send("hello".as_bytes().into()),
            Err(TrySendError::Full("o".as_bytes().into()))
        );

        assert_eq!(
            receiver.try_recv().unwrap(),
            BytesMut::from("hell".as_bytes())
        );
    }

    #[test]
    fn byte_channel_try_empty_then_disconnected() {
        let (mut sender, mut receiver) = byte_channel(8);

        assert_eq!(receiver.try_recv(), Err(TryRecvError::Empty));
        assert_eq!(sender.try_send(BytesMut::new()), Ok(()));

        sender.try_send("abc".as_bytes().into()).unwrap();
        drop(sender);

        assert!(receiver.is_disconnected());
        // Bytes buffered before the sender was dropped are still delivered.
        assert_eq!(&receiver.try_recv().unwrap()[..], b"abc");
        assert_eq!(receiver.try_recv(), Err(TryRecvError::Disconnected));
    }

    #[test]
    fn try_send_after_receiver_dropped() {
        let (mut sender, receiver) = byte_channel(8);
        drop(receiver);

        assert!(sender.is_disconnected());
        assert_eq!(
            sender.try_send("abc".as_bytes().into()),
            Err(TrySendError::Disconnected("abc".as_bytes().into()))
        );
    }

    #[tokio::test]
    async fn byte_channel_async() {
        let (mut sender, mut receiver) = byte_channel(4);

        let t = tokio::spawn(async move {
            let bytes = receiver.recv_async().await.unwrap();
            assert_eq!(&bytes[..], b"hell");
            let bytes = receiver.recv_async().await.unwrap();
            assert_eq!(&bytes[..], b"o");

            assert_eq!(receiver.try_recv(), Err(TryRecvError::Empty));
        });

        sender.send_async("hello".as_bytes().into()).await.unwrap();

        t.await.unwrap();

        assert!(sender.is_disconnected());
    }

    #[tokio::test]
    async fn recv_async_after_sender_dropped() {
        let (sender, mut receiver) = byte_channel(4);
        drop(sender);

        assert_eq!(receiver.recv_async().await, Err(RecvError::Disconnected));
    }
}