valence_network/
byte_channel.rs
1#![allow(dead_code)]
4
5use std::sync::{Arc, Mutex};
6
7use bytes::BytesMut;
8use thiserror::Error;
9use tokio::sync::Notify;
10
11pub(crate) fn byte_channel(limit: usize) -> (ByteSender, ByteReceiver) {
12 let shared = Arc::new(Shared {
13 mtx: Mutex::new(Inner {
14 bytes: BytesMut::new(),
15 disconnected: false,
16 }),
17 notify: Notify::new(),
18 limit,
19 });
20
21 let sender = ByteSender {
22 shared: shared.clone(),
23 };
24
25 let receiver = ByteReceiver { shared };
26
27 (sender, receiver)
28}
29
30pub(crate) struct ByteSender {
31 shared: Arc<Shared>,
32}
33
34pub(crate) struct ByteReceiver {
35 shared: Arc<Shared>,
36}
37
38struct Shared {
39 mtx: Mutex<Inner>,
40 notify: Notify,
41 limit: usize,
42}
43
44struct Inner {
45 bytes: BytesMut,
46 disconnected: bool,
47}
48
49impl ByteSender {
50 pub(crate) fn take_capacity(&mut self, additional: usize) -> BytesMut {
51 let mut lck = self.shared.mtx.lock().unwrap();
52
53 lck.bytes.reserve(additional);
54
55 let len = lck.bytes.len();
56 lck.bytes.split_off(len)
57 }
58
59 pub(crate) fn try_send(&mut self, mut bytes: BytesMut) -> Result<(), TrySendError> {
60 let mut lck = self.shared.mtx.lock().unwrap();
61
62 if lck.disconnected {
63 return Err(TrySendError::Disconnected(bytes));
64 }
65
66 if bytes.is_empty() {
67 return Ok(());
68 }
69
70 let available = self.shared.limit - lck.bytes.len();
71
72 if bytes.len() > available {
73 if available > 0 {
74 lck.bytes.unsplit(bytes.split_to(available));
75 self.shared.notify.notify_waiters();
76 }
77
78 return Err(TrySendError::Full(bytes));
79 }
80
81 lck.bytes.unsplit(bytes);
82 self.shared.notify.notify_waiters();
83
84 Ok(())
85 }
86
87 pub(crate) async fn send_async(&mut self, mut bytes: BytesMut) -> Result<(), SendError> {
88 loop {
89 {
90 let mut lck = self.shared.mtx.lock().unwrap();
91
92 if lck.disconnected {
93 return Err(SendError(bytes));
94 }
95
96 if bytes.is_empty() {
97 return Ok(());
98 }
99
100 let available = self.shared.limit - lck.bytes.len();
101
102 if bytes.len() <= available {
103 lck.bytes.unsplit(bytes);
104 self.shared.notify.notify_waiters();
105 return Ok(());
106 }
107
108 if available > 0 {
109 lck.bytes.unsplit(bytes.split_to(available));
110 self.shared.notify.notify_waiters();
111 }
112 }
113
114 self.shared.notify.notified().await;
115 }
116 }
117
118 pub(crate) fn is_disconnected(&self) -> bool {
119 self.shared.mtx.lock().unwrap().disconnected
120 }
121
122 pub(crate) fn limit(&self) -> usize {
123 self.shared.limit
124 }
125}
126
127#[derive(Clone, PartialEq, Eq, Debug, Error)]
129pub(crate) enum TrySendError {
130 #[error("sender disconnected")]
131 Disconnected(BytesMut),
132 #[error("channel full (see `Config::outgoing_capacity`)")]
133 Full(BytesMut),
134}
135
136#[derive(Clone, PartialEq, Eq, Debug, Error)]
137#[error("sender disconnected")]
138pub(crate) struct SendError(pub(crate) BytesMut);
139
140impl SendError {
141 pub(crate) fn into_inner(self) -> BytesMut {
142 self.0
143 }
144}
145
146impl ByteReceiver {
147 pub(crate) fn try_recv(&mut self) -> Result<BytesMut, TryRecvError> {
148 let mut lck = self.shared.mtx.lock().unwrap();
149
150 if !lck.bytes.is_empty() {
151 self.shared.notify.notify_waiters();
152 return Ok(lck.bytes.split());
153 }
154
155 if lck.disconnected {
156 return Err(TryRecvError::Disconnected);
157 }
158
159 Err(TryRecvError::Empty)
160 }
161
162 pub(crate) async fn recv_async(&mut self) -> Result<BytesMut, RecvError> {
163 loop {
164 {
165 let mut lck = self.shared.mtx.lock().unwrap();
166
167 if !lck.bytes.is_empty() {
168 self.shared.notify.notify_waiters();
169 return Ok(lck.bytes.split());
170 }
171
172 if lck.disconnected {
173 return Err(RecvError::Disconnected);
174 }
175 }
176
177 self.shared.notify.notified().await;
178 }
179 }
180
181 pub(crate) fn is_disconnected(&self) -> bool {
182 self.shared.mtx.lock().unwrap().disconnected
183 }
184
185 pub(crate) fn limit(&self) -> usize {
186 self.shared.limit
187 }
188}
189
190#[derive(Copy, Clone, PartialEq, Eq, Debug, Error)]
191pub(crate) enum TryRecvError {
192 #[error("empty channel")]
193 Empty,
194 #[error("receiver disconnected")]
195 Disconnected,
196}
197
198#[derive(Copy, Clone, PartialEq, Eq, Debug, Error)]
199pub(crate) enum RecvError {
200 #[error("receiver disconnected")]
201 Disconnected,
202}
203
204impl Drop for ByteSender {
205 fn drop(&mut self) {
206 self.shared.mtx.lock().unwrap().disconnected = true;
207 }
208}
209
210impl Drop for ByteReceiver {
211 fn drop(&mut self) {
212 self.shared.mtx.lock().unwrap().disconnected = true;
213 }
214}
215
#[cfg(test)]
mod tests {
    use super::*;

    /// Non-blocking send of 5 bytes into a 4-byte channel: the first 4 are
    /// enqueued and the overflow comes back in `Full`.
    #[test]
    fn byte_channel_try() {
        let (mut tx, mut rx) = byte_channel(4);

        let overflow = BytesMut::from(&b"o"[..]);
        assert_eq!(
            tx.try_send(BytesMut::from(&b"hello"[..])),
            Err(TrySendError::Full(overflow))
        );

        let received = rx.try_recv().unwrap();
        assert_eq!(received, BytesMut::from(&b"hell"[..]));
    }

    /// Async send of 5 bytes into a 4-byte channel is delivered in two
    /// chunks, and the receiver's drop is visible to the sender afterwards.
    #[tokio::test]
    async fn byte_channel_async() {
        let (mut tx, mut rx) = byte_channel(4);

        let reader = tokio::spawn(async move {
            let first = rx.recv_async().await.unwrap();
            assert_eq!(&first[..], b"hell");

            let second = rx.recv_async().await.unwrap();
            assert_eq!(&second[..], b"o");

            assert_eq!(rx.try_recv(), Err(TryRecvError::Empty));
        });

        tx.send_async(BytesMut::from(&b"hello"[..])).await.unwrap();

        reader.await.unwrap();

        // `rx` was moved into the task and dropped when it finished.
        assert!(tx.is_disconnected());
    }
}