1 /++
2 Socket connection
3 +/
4 module socketplate.connection;
5
6 import socketplate.log;
7 import std.format;
8 import std.socket;
9
10 @safe:
11
/// Delegate type for code that handles a [SocketConnection]
alias ConnectionHandler = void delegate(SocketConnection) @safe;
14
/// Return value signalling a failed socket operation (same value as [std.socket.Socket.ERROR])
enum socketERROR = Socket.ERROR;
17
/++
    A socket connection with trace logging and timeout detection

    Wraps an existing [std.socket.Socket].
    Default construction and copying are disabled.
 +/
struct SocketConnection
{
@safe:

    private
    {
        Socket _socket;
    }

    @disable private this();
    @disable private this(this);

    /++
        Wraps the passed socket

        Params:
            socket = socket to operate on
     +/
    this(Socket socket) pure nothrow @nogc
    {
        _socket = socket;
    }

    /++
        Determines whether the socket is still alive

        Returns:
            `false` once the connection has been closed via [close],
            otherwise the result of [std.socket.Socket.isAlive]
     +/
    bool isAlive() const
    {
        // `close` resets `_socket` to `null`; never dereference it afterwards.
        return (_socket !is null) && _socket.isAlive;
    }

    /++
        Determines whether there is no more data to be received

        Peeks at a single byte without consuming it.
        On a blocking socket this may wait until data arrives.

        Returns:
            `true` if the peek reported `0` bytes or failed (e.g. timed out),
            `false` if data is ready to be read
     +/
    bool empty()
    {
        ubyte[1] tmp;
        immutable ptrdiff_t bytesReceived = _socket.receive(tmp, SocketFlags.PEEK);
        return ((bytesReceived == 0) || (bytesReceived == socketERROR));
    }

    /++
        Closes the connection

        Shuts down both directions, closes the socket and drops the reference to it.
        Calling this on an already closed connection does nothing.
     +/
    void close() nothrow @nogc
    {
        if (_socket is null)
            return;

        _socket.shutdown(SocketShutdown.BOTH);
        _socket.close();
        _socket = null;
    }

    /++
        Reads received data into the provided buffer

        Params:
            buffer = destination for the received data

        Returns:
            Number of bytes received
            (`0` indicates that the connection got closed before receiving any bytes)

            or [socketplate.connection.socketERROR|socketERROR] on failure

        Throws:
            [SocketTimeoutException] on timeout
     +/
    ptrdiff_t receive(scope void[] buffer)
    {
        logTrace(format!"Receiving bytes (#%X)"(_socket.handle));
        immutable ptrdiff_t result = _socket.receive(buffer);

        // Must run right after the failing call, before the last error code gets overwritten.
        if (result == socketERROR)
            detectTimeout();

        logTrace(format!"Received bytes: %d (#%X)"(result, _socket.handle));
        return result;
    }

    /++
        Reads received data into the provided buffer

        Params:
            buffer = destination for the received data

        Returns:
            Slice of buffer containing the received data.

            Length of `0` indicates that the connection has been closed

        Throws:
            $(LIST
                * [SocketTimeoutException] on timeout
                * [SocketException] on failure
            )
     +/
    T[] receiveSlice(T)(return scope T[] buffer)
    if (is(T == void) || is(T == ubyte) || is(T == char))
    {
        immutable ptrdiff_t bytesReceived = this.receive(buffer);

        if (bytesReceived == socketERROR)
            throw new SocketException("An error occured while receiving data");

        return buffer[0 .. bytesReceived];
    }

    /++
        Fills the whole provided buffer with received data

        Keeps receiving until every element of `buffer` has been written.

        Params:
            buffer = destination for the received data

        Returns:
            The provided buffer, completely filled with received data.

        Throws:
            $(LIST
                * [SocketUnexpectedEndOfDataException] if the connection got closed
                    before the whole buffer could be filled
                * [SocketTimeoutException] on timeout
                * [SocketException] on failure
            )
     +/
    T[] receiveAll(T)(return scope T[] buffer)
    if (is(T == void) || is(T == ubyte) || is(T == char))
    {
        T[] bufferLeft = buffer;

        while (bufferLeft.length > 0)
        {
            immutable ptrdiff_t bytesReceived = this.receive(bufferLeft);

            if (bytesReceived == socketERROR)
                throw new SocketException("An error occured while receiving data");

            if (bytesReceived == 0)
            {
                throw new SocketUnexpectedEndOfDataException(
                    "Connection was closed before all of the provided buffer could have been filled"
                );
            }

            bufferLeft = bufferLeft[bytesReceived .. $];
        }

        return buffer;
    }

    /++
        Sends data on the connection

        Params:
            buffer = data to send

        Returns:
            number of bytes sent (may be less than `buffer.length`)

            or [socketplate.connection.socketERROR|socketERROR] on failure

        Throws:
            [SocketTimeoutException] on timeout
     +/
    ptrdiff_t send(scope const(void)[] buffer)
    {
        logTrace(format!"Sending bytes (#%X)"(_socket.handle));
        immutable ptrdiff_t result = _socket.send(buffer);

        // Must run right after the failing call, before the last error code gets overwritten.
        if (result == socketERROR)
            detectTimeout();

        logTrace(format!"Sent bytes: %d (#%X)"(result, _socket.handle));
        return result;
    }

    /++
        Sends all data from the passed slice on the connection

        Keeps sending until the whole slice has been transmitted.

        Params:
            buffer = data to send

        Throws:
            $(LIST
                * [SocketTimeoutException] on timeout
                * [SocketException] on failure
            )
     +/
    void sendAll(scope const(void)[] buffer)
    {
        const(void)[] bufferLeft = buffer;

        while (bufferLeft.length > 0)
        {
            immutable ptrdiff_t bytesSent = this.send(bufferLeft);

            if (bytesSent < 0)
                throw new SocketException("An error occured while sending data");

            bufferLeft = bufferLeft[bytesSent .. $];
        }
    }

    /// Address of the remote peer (see [std.socket.Socket.remoteAddress])
    Address remoteAddress()
    {
        return _socket.remoteAddress;
    }

    /// Local address of the socket (see [std.socket.Socket.localAddress])
    Address localAddress()
    {
        return _socket.localAddress;
    }

    /// Error text of the socket's current error status (see [std.socket.Socket.getErrorText])
    string popCurrentError()
    {
        return _socket.getErrorText();
    }

    /++
        Retrieves the timeout configured for a single direction

        Returns:
            the timeout in whole seconds
     +/
    long timeout(Direction direction)()
    if (direction == Direction.send || direction == Direction.receive)
    {
        return _socket.getTimeout!direction();
    }

    /++
        Sets the timeout for the given direction(s)

        `Direction.both` applies the value to sending and receiving.

        Params:
            seconds = timeout in seconds
     +/
    void timeout(Direction direction)(long seconds) if (direction != Direction.none)
    {
        _socket.setTimeout!direction(seconds);
    }
}
233
unittest
{
    // Compile-time API checks only; no socket is ever opened here.
    ubyte[] dynamicBuf;
    ubyte[4] staticBuf;

    // dfmt off

    // Sending accepts both dynamic and static arrays.
    static assert(__traits(compiles, (SocketConnection conn) => conn.sendAll(dynamicBuf)));
    static assert(__traits(compiles, (SocketConnection conn) => conn.sendAll(staticBuf)));

    // Partial receiving returns a slice typed like the passed buffer.
    static assert(__traits(compiles, (SocketConnection conn) { ubyte[] got = conn.receiveSlice(dynamicBuf); }));
    static assert(__traits(compiles, (SocketConnection conn) { ubyte[] got = conn.receiveSlice(staticBuf); }));

    // Full receiving does as well.
    static assert(__traits(compiles, (SocketConnection conn) { ubyte[] got = conn.receiveAll(dynamicBuf); }));
    static assert(__traits(compiles, (SocketConnection conn) { ubyte[] got = conn.receiveAll(staticBuf); }));

    // Timeouts can only be queried for exactly one direction.
    static assert( __traits(compiles, (SocketConnection conn) { long secs = conn.timeout!(Direction.receive); }));
    static assert( __traits(compiles, (SocketConnection conn) { long secs = conn.timeout!(Direction.send); }));
    static assert(!__traits(compiles, (SocketConnection conn) { long secs = conn.timeout!(Direction.both); }));
    static assert(!__traits(compiles, (SocketConnection conn) { long secs = conn.timeout!(Direction.none); }));

    // Timeouts can be assigned for any direction except `none`.
    static assert( __traits(compiles, (SocketConnection conn) { conn.timeout!(Direction.both) = 90; }));
    static assert( __traits(compiles, (SocketConnection conn) { conn.timeout!(Direction.receive) = 90; }));
    static assert( __traits(compiles, (SocketConnection conn) { conn.timeout!(Direction.send) = 90; }));
    static assert(!__traits(compiles, (SocketConnection conn) { conn.timeout!(Direction.none) = 90; }));

    // dfmt on
}
260
/++
    Throws a [SocketTimeoutException] if the last socket operation failed because of a timeout

    Must be called directly after the failing socket call,
    before anything else can overwrite the thread's last error code.

    Params:
        file = source file reported by the exception (defaults to the call site)
        line = source line reported by the exception (defaults to the call site)

    Throws:
        [SocketTimeoutException] if the last error indicates a timeout
 +/
private void detectTimeout(string file = __FILE__, size_t line = __LINE__)
{
    // `wouldHaveBlocked` maps the platform-specific error codes for us:
    // POSIX reports `SO_RCVTIMEO`/`SO_SNDTIMEO` expiry as `EAGAIN`,
    // Windows as `WSAETIMEDOUT` (or `WSAEWOULDBLOCK`).
    if (wouldHaveBlocked())
        throw new SocketTimeoutException(file, line);
}
278
/++
    Thrown when a socket operation did not complete before its timeout elapsed
 +/
class SocketTimeoutException : SocketException
{
    /++
        Params:
            file = source file the exception is attributed to
            line = source line the exception is attributed to
     +/
    this(string file = __FILE__, size_t line = __LINE__) pure nothrow @nogc @safe
    {
        enum message = "Socket operation timed out";
        super(message, file, line);
    }
}
287
/++
    Thrown when a connection ends before the expected amount of data has been received
 +/
class SocketUnexpectedEndOfDataException : SocketException
{
    /++
        Params:
            message = description of what was missing
            file    = source file the exception is attributed to
            line    = source line the exception is attributed to
     +/
    this(string message, string file = __FILE__, size_t line = __LINE__) pure nothrow @nogc @safe
    {
        super(message, file, line);
    }
}
296
/++
    Direction(s) of data transfer on a socket; the members combine as bit flags
 +/
enum Direction
{
    /// No direction at all
    none = 0,

    /// Incoming data
    receive = 1 << 0,

    /// Outgoing data
    send = 1 << 1,

    /// Incoming and outgoing data
    both = receive | send,
}
306
private
{
    /++
        Retrieves the timeout configured on `socket` for a single direction

        Returns:
            the timeout in whole seconds
     +/
    long getTimeout(Direction direction)(Socket socket)
    if (direction == Direction.send || direction == Direction.receive)
    {
        import std.datetime : Duration;

        // Each direction has its own socket option.
        enum SocketOption sockOpt = (direction == Direction.send)
            ? SocketOption.SNDTIMEO : SocketOption.RCVTIMEO;

        Duration result;
        socket.getOption(SocketOptionLevel.SOCKET, sockOpt, result);
        return result.total!"seconds";
    }

    /++
        Sets the timeout of `socket` for the given direction(s)

        `Direction.both` applies the same value to the send and the receive timeout.

        Params:
            socket  = socket to configure
            seconds = timeout in seconds
     +/
    void setTimeout(Direction direction)(Socket socket, long seconds)
    if (direction != Direction.none)
    {
        import std.datetime : durSeconds = seconds;

        static if (direction == Direction.both)
        {
            setTimeout!(Direction.send)(socket, seconds);
            setTimeout!(Direction.receive)(socket, seconds);
        }
        else static if (direction == Direction.receive)
        {
            logTrace(format!"Setting receive timeout to %d seconds (#%X)"(seconds, socket.handle));
            socket.setOption(SocketOptionLevel.SOCKET, SocketOption.RCVTIMEO, durSeconds(seconds));
        }
        else static if (direction == Direction.send)
        {
            logTrace(format!"Setting send timeout to %d seconds (#%X)"(seconds, socket.handle));
            socket.setOption(SocketOptionLevel.SOCKET, SocketOption.SNDTIMEO, durSeconds(seconds));
        }
        else
            static assert(false, "Bug");
    }
}