tokio_tungstenite/
compat.rs1use log::*;
2use std::{
3 io::{Read, Write},
4 pin::Pin,
5 task::{Context, Poll},
6};
7
8use futures_util::task;
9use std::sync::Arc;
10use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
11use tungstenite::Error as WsError;
12
/// Which I/O direction a waker registration or a poll belongs to.
///
/// Used by `AllowStd` to pick between its read-side and write-side waker
/// proxies / waker slots. It is a plain two-state selector, so it is `Copy`
/// and comparable, and it is `Debug` so it can appear in trace output.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum ContextWaker {
    /// The read side (`Read::read` -> `poll_read`).
    Read,
    /// The write side (`Write::write` / `Write::flush` -> `poll_write` / `poll_flush`).
    Write,
}
17
18#[derive(Debug)]
19pub(crate) struct AllowStd<S> {
20 inner: S,
21 write_waker_proxy: Arc<WakerProxy>,
45 read_waker_proxy: Arc<WakerProxy>,
46}
47
48pub(crate) trait SetWaker {
54 fn set_waker(&self, waker: &task::Waker);
55}
56
57impl<S> SetWaker for AllowStd<S> {
58 fn set_waker(&self, waker: &task::Waker) {
59 self.set_waker(ContextWaker::Read, waker);
60 }
61}
62
63impl<S> AllowStd<S> {
64 pub(crate) fn new(inner: S, waker: &task::Waker) -> Self {
65 let res = Self {
66 inner,
67 write_waker_proxy: Default::default(),
68 read_waker_proxy: Default::default(),
69 };
70
71 res.write_waker_proxy.read_waker.register(waker);
74 res.read_waker_proxy.read_waker.register(waker);
75
76 res
77 }
78
79 pub(crate) fn set_waker(&self, kind: ContextWaker, waker: &task::Waker) {
88 match kind {
89 ContextWaker::Read => {
90 self.write_waker_proxy.read_waker.register(waker);
91 self.read_waker_proxy.read_waker.register(waker);
92 }
93 ContextWaker::Write => {
94 self.write_waker_proxy.write_waker.register(waker);
95 self.read_waker_proxy.write_waker.register(waker);
96 }
97 }
98 }
99}
100
101#[derive(Debug, Default)]
106struct WakerProxy {
107 read_waker: task::AtomicWaker,
108 write_waker: task::AtomicWaker,
109}
110
111impl task::ArcWake for WakerProxy {
112 fn wake_by_ref(arc_self: &Arc<Self>) {
113 arc_self.read_waker.wake();
114 arc_self.write_waker.wake();
115 }
116}
117
118impl<S> AllowStd<S>
119where
120 S: Unpin,
121{
122 fn with_context<F, R>(&mut self, kind: ContextWaker, f: F) -> Poll<std::io::Result<R>>
123 where
124 F: FnOnce(&mut Context<'_>, Pin<&mut S>) -> Poll<std::io::Result<R>>,
125 {
126 trace!("{}:{} AllowStd.with_context", file!(), line!());
127 let waker = match kind {
128 ContextWaker::Read => task::waker_ref(&self.read_waker_proxy),
129 ContextWaker::Write => task::waker_ref(&self.write_waker_proxy),
130 };
131 let mut context = task::Context::from_waker(&waker);
132 f(&mut context, Pin::new(&mut self.inner))
133 }
134
135 pub(crate) fn get_mut(&mut self) -> &mut S {
136 &mut self.inner
137 }
138
139 pub(crate) fn get_ref(&self) -> &S {
140 &self.inner
141 }
142}
143
144impl<S> Read for AllowStd<S>
145where
146 S: AsyncRead + Unpin,
147{
148 fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
149 trace!("{}:{} Read.read", file!(), line!());
150 let mut buf = ReadBuf::new(buf);
151 match self.with_context(ContextWaker::Read, |ctx, stream| {
152 trace!("{}:{} Read.with_context read -> poll_read", file!(), line!());
153 stream.poll_read(ctx, &mut buf)
154 }) {
155 Poll::Ready(Ok(_)) => Ok(buf.filled().len()),
156 Poll::Ready(Err(err)) => Err(err),
157 Poll::Pending => Err(std::io::Error::from(std::io::ErrorKind::WouldBlock)),
158 }
159 }
160}
161
162impl<S> Write for AllowStd<S>
163where
164 S: AsyncWrite + Unpin,
165{
166 fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
167 trace!("{}:{} Write.write", file!(), line!());
168 match self.with_context(ContextWaker::Write, |ctx, stream| {
169 trace!("{}:{} Write.with_context write -> poll_write", file!(), line!());
170 stream.poll_write(ctx, buf)
171 }) {
172 Poll::Ready(r) => r,
173 Poll::Pending => Err(std::io::Error::from(std::io::ErrorKind::WouldBlock)),
174 }
175 }
176
177 fn flush(&mut self) -> std::io::Result<()> {
178 trace!("{}:{} Write.flush", file!(), line!());
179 match self.with_context(ContextWaker::Write, |ctx, stream| {
180 trace!("{}:{} Write.with_context flush -> poll_flush", file!(), line!());
181 stream.poll_flush(ctx)
182 }) {
183 Poll::Ready(r) => r,
184 Poll::Pending => Err(std::io::Error::from(std::io::ErrorKind::WouldBlock)),
185 }
186 }
187}
188
189pub(crate) fn cvt<T>(r: Result<T, WsError>) -> Poll<Result<T, WsError>> {
190 match r {
191 Ok(v) => Poll::Ready(Ok(v)),
192 Err(WsError::Io(ref e)) if e.kind() == std::io::ErrorKind::WouldBlock => {
193 trace!("WouldBlock");
194 Poll::Pending
195 }
196 Err(e) => Poll::Ready(Err(e)),
197 }
198}