1 /++
2     Socket connection
3  +/
4 module socketplate.connection;
5 
6 import socketplate.log;
7 import std.format;
8 import std.socket;
9 
10 @safe:
11 
/// Delegate type for code that handles a single [SocketConnection]
alias ConnectionHandler = void delegate(SocketConnection) @safe;

/++
    Error sentinel; same value as `std.socket`'s `Socket.ERROR`

    Returned by [SocketConnection.receive] and [SocketConnection.send] on failure.
 +/
enum socketERROR = Socket.ERROR;
17 
/++
    Connection wrapper around a [std.socket.Socket]

    Provides receive/send helpers that detect timeouts;
    the slice/all variants additionally report failures as exceptions.

    Default construction and copying are disabled:
    instances are created from a `Socket` and can only be moved.
 +/
struct SocketConnection
{
@safe:

    private
    {
        Socket _socket;
    }

    @disable private this();
    @disable private this(this);

    /++
        Wraps the passed socket

        Params:
            socket = socket to send and receive data through
     +/
    this(Socket socket) pure nothrow @nogc
    {
        _socket = socket;
    }

    /++
        Determines whether the socket is still alive

        Returns:
            `false` if the connection has already been closed via [close]
     +/
    bool isAlive() const
    {
        return (_socket !is null) && _socket.isAlive;
    }

    /++
        Determines whether there is no more data to be received

        Peeks at (without consuming) a single byte.
        On a blocking socket this call may block until data is available,
        the peer closes the connection, or an error (e.g. a receive timeout) occurs.

        Returns:
            `true` if the peer closed the connection, the peek failed,
            or the connection has already been closed via [close]
     +/
    bool empty()
    {
        // A locally closed connection cannot deliver any further data.
        if (_socket is null)
            return true;

        ubyte[1] tmp;
        immutable ptrdiff_t bytesReceived = _socket.receive(tmp, SocketFlags.PEEK);
        return ((bytesReceived == 0) || (bytesReceived == socketERROR));
    }

    /++
        Closes the connection

        Shuts down both directions, then closes and releases the socket.
        Calling this on an already closed connection is a no-op.
     +/
    void close() nothrow @nogc
    {
        // `_socket` is nulled below; guard so a second call doesn't dereference null.
        if (_socket is null)
            return;

        _socket.shutdown(SocketShutdown.BOTH);
        _socket.close();
        _socket = null;
    }

    /++
        Reads received data into the provided buffer

        Returns:
            Number of bytes received
            (`0` indicates that the connection got closed before receiving any bytes)

            or [socketplate.connection.socketERROR|socketERROR] on failure

        Throws:
            [SocketTimeoutException] on timeout
     +/
    ptrdiff_t receive(scope void[] buffer)
    {
        logTrace(format!"Receiving bytes (#%X)"(_socket.handle));
        immutable ptrdiff_t result = _socket.receive(buffer);

        // Must run right after the failing call, before anything overwrites the OS error code.
        if (result == socketERROR)
            detectTimeout();

        logTrace(format!"Received bytes: %d (#%X)"(result, _socket.handle));
        return result;
    }

    /++
        Reads received data into the provided buffer

        Returns:
            Slice of buffer containing the received data.

            Length of `0` indicates that the connection has been closed

        Throws:
            $(LIST
                * [SocketTimeoutException] on timeout
                * [SocketException] on failure
            )
     +/
    T[] receiveSlice(T)(return scope T[] buffer)
            if (is(T == void) || is(T == ubyte) || is(T == char))
    {
        immutable ptrdiff_t bytesReceived = this.receive(buffer);

        if (bytesReceived == socketERROR)
            throw new SocketException("An error occured while receiving data");

        return buffer[0 .. bytesReceived];
    }

    /++
        Fills the whole provided buffer with received data

        Returns:
            The provided buffer, completely filled with received data.
            (An empty buffer is returned as-is without receiving anything.)

        Throws:
            $(LIST
                * [SocketUnexpectedEndOfDataException] if the connection got closed
                  before there was enough data to fill the whole buffer
                * [SocketTimeoutException] on timeout
                * [SocketException] on failure
            )
     +/
    T[] receiveAll(T)(return scope T[] buffer)
            if (is(T == void) || is(T == ubyte) || is(T == char))
    {
        T[] bufferLeft = buffer;

        while (bufferLeft.length > 0)
        {
            immutable ptrdiff_t bytesReceived = this.receive(bufferLeft);

            if (bytesReceived == socketERROR)
                throw new SocketException("An error occured while receiving data");

            // `0` means the peer closed the connection before the buffer was filled.
            if (bytesReceived == 0)
            {
                throw new SocketUnexpectedEndOfDataException(
                    "Connection was closed before all of the provided buffer could have been filled"
                );
            }

            bufferLeft = bufferLeft[bytesReceived .. $];
        }

        return buffer;
    }

    /++
        Sends data on the connection

        Returns:
            number of bytes sent (may be less than `buffer.length`)

            or [socketplate.connection.socketERROR|socketERROR] on failure

        Throws:
            [SocketTimeoutException] on timeout
     +/
    ptrdiff_t send(scope const(void)[] buffer)
    {
        logTrace(format!"Sending bytes (#%X)"(_socket.handle));
        immutable ptrdiff_t result = _socket.send(buffer);

        // Must run right after the failing call, before anything overwrites the OS error code.
        if (result == socketERROR)
            detectTimeout();

        logTrace(format!"Sent bytes: %d (#%X)"(result, _socket.handle));
        return result;
    }

    /++
        Sends all data from the passed slice on the connection

        Keeps sending until every byte has been handed to the socket.

        Throws:
            $(LIST
                * [SocketTimeoutException] on timeout
                * [SocketException] on failure
            )
     +/
    void sendAll(scope const(void)[] buffer)
    {
        const(void)[] bufferLeft = buffer;

        while (bufferLeft.length > 0)
        {
            immutable ptrdiff_t bytesSent = this.send(bufferLeft);

            if (bytesSent < 0)
                throw new SocketException("An error occured while sending data");

            bufferLeft = bufferLeft[bytesSent .. $];
        }
    }

    /// Address of the remote end of the connection (as reported by `Socket.remoteAddress`)
    Address remoteAddress()
    {
        return _socket.remoteAddress;
    }

    /// Address of the local end of the connection (as reported by `Socket.localAddress`)
    Address localAddress()
    {
        return _socket.localAddress;
    }

    /// Text description of the socket's current error (as reported by `Socket.getErrorText`)
    string popCurrentError()
    {
        return _socket.getErrorText();
    }

    /++
        Retrieves the timeout configured for the given direction, in whole seconds

        Only `Direction.send` and `Direction.receive` can be queried.
     +/
    long timeout(Direction direction)()
            if (direction == Direction.send || direction == Direction.receive)
    {
        return _socket.getTimeout!direction();
    }

    /++
        Sets the timeout for the given direction

        `Direction.both` sets the send and the receive timeout.

        Params:
            seconds = timeout in seconds
     +/
    void timeout(Direction direction)(long seconds) if (direction != Direction.none)
    {
        _socket.setTimeout!direction(seconds);
    }
}
233 
unittest
{
    ubyte[4] staticBuffer;
    ubyte[] dynamicBuffer;

    // dfmt off

    // sendAll() accepts static and dynamic byte arrays alike.
    static assert(__traits(compiles, (SocketConnection conn) => conn.sendAll(staticBuffer)));
    static assert(__traits(compiles, (SocketConnection conn) => conn.sendAll(dynamicBuffer)));

    // The receiving helpers hand back a `ubyte[]` slice for either buffer kind.
    static assert(__traits(compiles, (SocketConnection conn) { ubyte[] got = conn.receiveSlice(staticBuffer); }));
    static assert(__traits(compiles, (SocketConnection conn) { ubyte[] got = conn.receiveSlice(dynamicBuffer); }));
    static assert(__traits(compiles, (SocketConnection conn) { ubyte[] got = conn.receiveAll(staticBuffer); }));
    static assert(__traits(compiles, (SocketConnection conn) { ubyte[] got = conn.receiveAll(dynamicBuffer); }));

    // Timeouts can be read for exactly one direction at a time ...
    static assert( __traits(compiles, (SocketConnection conn) { long secs = conn.timeout!(Direction.send); }));
    static assert( __traits(compiles, (SocketConnection conn) { long secs = conn.timeout!(Direction.receive); }));
    static assert(!__traits(compiles, (SocketConnection conn) { long secs = conn.timeout!(Direction.none); }));
    static assert(!__traits(compiles, (SocketConnection conn) { long secs = conn.timeout!(Direction.both); }));

    // ... but can be assigned for any direction except `none`.
    static assert( __traits(compiles, (SocketConnection conn) { conn.timeout!(Direction.send) = 90; }));
    static assert( __traits(compiles, (SocketConnection conn) { conn.timeout!(Direction.receive) = 90; }));
    static assert( __traits(compiles, (SocketConnection conn) { conn.timeout!(Direction.both) = 90; }));
    static assert(!__traits(compiles, (SocketConnection conn) { conn.timeout!(Direction.none) = 90; }));

    // dfmt on
}
260 
/++
    Throws a [SocketTimeoutException] if the OS error code of the last socket
    operation indicates an elapsed send/receive timeout; otherwise does nothing.

    Has to be called directly after the failing socket call,
    before anything else can overwrite `errno` / the WinSock error code.

    Params:
        file = source file reported by the exception
        line = source line reported by the exception
 +/
private void detectTimeout(string file = __FILE__, size_t line = __LINE__)
{
    version (Posix)
    {
        import core.stdc.errno : EAGAIN, EWOULDBLOCK, errno;

        // With SO_RCVTIMEO/SO_SNDTIMEO set, POSIX reports an elapsed timeout as
        // EAGAIN or EWOULDBLOCK (which are not required to be the same value).
        immutable int lastError = errno();
        if ((lastError == EAGAIN) || (lastError == EWOULDBLOCK))
            throw new SocketTimeoutException(file, line);
    }
    else version (Windows)
    {
        import core.sys.windows.winsock2 : WSAETIMEDOUT, WSAGetLastError;

        // WSAGetLastError is declared @system but only reads the calling thread's error code.
        immutable int lastError = () @trusted { return WSAGetLastError(); }();
        if (lastError == WSAETIMEDOUT)
            throw new SocketTimeoutException(file, line);
    }
}
278 
/++
    Thrown when a socket send or receive operation fails because its timeout elapsed
 +/
class SocketTimeoutException : SocketException
{
    /++
        Params:
            file = source file where the timeout was detected
            line = source line where the timeout was detected
     +/
    this(string file = __FILE__, size_t line = __LINE__) @safe pure nothrow @nogc
    {
        enum string timeoutMessage = "Socket operation timed out";
        super(timeoutMessage, file, line);
    }
}
287 
/++
    Thrown when the connection ends before the expected amount of data was received
 +/
class SocketUnexpectedEndOfDataException : SocketException
{
    /++
        Params:
            message = description of the failure
            file = source file where the failure was detected
            line = source line where the failure was detected
     +/
    this(string message, string file = __FILE__, size_t line = __LINE__) @safe pure nothrow @nogc
    {
        // No chained exception: pass `null` explicitly as the `next` argument.
        super(message, file, line, null);
    }
}
296 
/++
    Transfer direction(s) of a socket

    Members are bit flags; `both` is the combination of `receive` and `send`.
 +/
enum Direction
{
    /// Neither direction
    none = 0,

    /// Incoming data
    receive = 1 << 0,

    /// Outgoing data
    send = 1 << 1,

    /// Both directions (`receive | send`)
    both = receive | send,
}
306 
private
{
    /++
        Reads the send or receive timeout of `socket`, truncated to whole seconds
     +/
    long getTimeout(Direction direction)(Socket socket)
            if (direction == Direction.send || direction == Direction.receive)
    {
        import std.datetime : Duration;

        // Each direction has its own socket option: SO_SNDTIMEO for send, SO_RCVTIMEO for receive.
        enum SocketOption sockOpt = (direction == Direction.send)
            ? SocketOption.SNDTIMEO : SocketOption.RCVTIMEO;

        Duration result;
        socket.getOption(SocketOptionLevel.SOCKET, sockOpt, result);
        return result.total!"seconds";
    }

    /++
        Sets the send and/or receive timeout of `socket`

        `Direction.both` sets both options; `Direction.none` is rejected at compile time.

        Params:
            seconds = timeout in seconds
     +/
    void setTimeout(Direction direction)(Socket socket, long seconds)
            if (direction != Direction.none)
    {
        import std.datetime : durSeconds = seconds;

        static if (direction == Direction.both)
        {
            setTimeout!(Direction.send)(socket, seconds);
            setTimeout!(Direction.receive)(socket, seconds);
        }
        else static if (direction == Direction.receive)
        {
            logTrace(format!"Setting receive timeout to %d seconds (#%X)"(seconds, socket.handle));
            socket.setOption(SocketOptionLevel.SOCKET, SocketOption.RCVTIMEO, durSeconds(seconds));
        }
        else static if (direction == Direction.send)
        {
            logTrace(format!"Setting send timeout to %d seconds (#%X)"(seconds, socket.handle));
            socket.setOption(SocketOptionLevel.SOCKET, SocketOption.SNDTIMEO, durSeconds(seconds));
        }
        else
            static assert(false, "Bug");
    }
}