/++
    Socket connection
 +/
module socketplate.connection;

import socketplate.log;
import std.format;
import std.socket;

@safe:

/// Signature of a callback that handles a single [SocketConnection]
alias ConnectionHandler = void delegate(SocketConnection) @safe;

/// Value returned by the low-level `send`/`receive` functions on failure
enum socketERROR = Socket.ERROR;

/++
    Thin wrapper around a [std.socket.Socket] that adds trace logging,
    timeout detection and "send/receive everything" helpers.

    Instances cannot be default-constructed or copied.
 +/
struct SocketConnection
{
@safe:

    private
    {
        Socket _socket;
    }

    @disable private this();
    @disable private this(this);

    /// Wraps the passed socket
    this(Socket socket) pure nothrow @nogc
    {
        _socket = socket;
    }

    /++
        Determines whether the socket is still alive

        Returns:
            `false` once the connection has been closed through [close]
     +/
    bool isAlive() const
    {
        // `close()` resets the handle to `null`; do not dereference it afterwards.
        return (_socket !is null) && _socket.isAlive;
    }

    /++
        Determines whether there is no more data to be received

        Peeks a single byte without consuming it.

        Returns:
            `true` if the peek yielded no data or failed
            (a failed peek is reported as "empty" as well)
     +/
    bool empty()
    {
        ubyte[1] tmp;
        immutable ptrdiff_t bytesReceived = _socket.receive(tmp, SocketFlags.PEEK);
        return ((bytesReceived == 0) || (bytesReceived == socketERROR));
    }

    /++
        Closes the connection

        Shuts down both directions, closes the socket and drops the reference to it.
        Calling this function again afterwards is a no-op.
     +/
    void close() nothrow @nogc
    {
        // Already closed: nothing left to shut down.
        if (_socket is null)
            return;

        _socket.shutdown(SocketShutdown.BOTH);
        _socket.close();
        _socket = null;
    }

    /++
        Reads received data into the provided buffer

        Returns:
            Number of bytes received
            (`0` indicates that the connection got closed before receiving any bytes)

            or [socketplate.connection.socketERROR|socketERROR] on failure

        Throws:
            [SocketTimeoutException] on timeout
     +/
    ptrdiff_t receive(scope void[] buffer)
    {
        logTrace(format!"Receiving bytes (#%X)"(_socket.handle));
        immutable ptrdiff_t result = _socket.receive(buffer);

        // Turn a timed-out operation into an exception; other errors are returned as-is.
        if (result == socketERROR)
            detectTimeout();

        logTrace(format!"Received bytes: %d (#%X)"(result, _socket.handle));
        return result;
    }

    /++
        Reads received data into the provided buffer

        Returns:
            Slice of buffer containing the received data.

            Length of `0` indicates that the connection has been closed

        Throws:
            $(LIST
                * [SocketTimeoutException] on timeout
                * [SocketException] on failure
            )
     +/
    T[] receiveSlice(T)(return scope T[] buffer)
            if (is(T == void) || is(T == ubyte) || is(T == char))
    {
        ptrdiff_t bytesReceived = this.receive(buffer);

        if (bytesReceived == socketERROR)
            throw new SocketException("An error occured while receiving data");

        return buffer[0 .. bytesReceived];
    }

    /++
        Fills the whole provided buffer with received data

        Keeps receiving until every element of the buffer has been written.

        Returns:
            The passed buffer, completely filled.
            (An empty buffer is returned as is, without receiving anything.)

        Throws:
            $(LIST
                * [SocketUnexpectedEndOfDataException] if the connection got closed
                    before the whole buffer could be filled
                * [SocketTimeoutException] on timeout
                * [SocketException] on failure
            )
     +/
    T[] receiveAll(T)(return scope T[] buffer)
            if (is(T == void) || is(T == ubyte) || is(T == char))
    {
        ptrdiff_t bytesReceived = 0;
        T[] bufferLeft = buffer;

        while (bufferLeft.length > 0)
        {
            bytesReceived = this.receive(bufferLeft);

            if (bytesReceived == socketERROR)
                throw new SocketException("An error occured while receiving data");

            if (bytesReceived == 0)
            {
                throw new SocketUnexpectedEndOfDataException(
                    "Connection was closed before all of the provided buffer could have been filled"
                );
            }

            // Continue behind the part that has just been filled.
            bufferLeft = bufferLeft[bytesReceived .. $];
        }

        return buffer;
    }

    /++
        Sends data on the connection

        Returns:
            Number of bytes sent

            or [socketplate.connection.socketERROR|socketERROR] on failure

        Throws:
            [SocketTimeoutException] on timeout
     +/
    ptrdiff_t send(scope const(void)[] buffer)
    {
        logTrace(format!"Sending bytes (#%X)"(_socket.handle));
        immutable ptrdiff_t result = _socket.send(buffer);

        // Turn a timed-out operation into an exception; other errors are returned as-is.
        if (result == socketERROR)
            detectTimeout();

        logTrace(format!"Sent bytes: %d (#%X)"(result, _socket.handle));
        return result;
    }

    /++
        Sends all data from the passed slice on the connection

        Throws:
            $(LIST
                * [SocketTimeoutException] on timeout
                * [SocketException] on failure
            )
     +/
    void sendAll(scope const(void)[] buffer)
    {
        ptrdiff_t bytesSent = 0;
        const(void)[] bufferLeft = buffer;

        while (bufferLeft.length > 0)
        {
            bytesSent = this.send(bufferLeft);

            if (bytesSent < 0)
                throw new SocketException("An error occured while sending data");

            // Continue with the part that has not been sent yet.
            bufferLeft = bufferLeft[bytesSent .. $];
        }
    }

    /// Address of the remote end of the connection
    Address remoteAddress()
    {
        return _socket.remoteAddress;
    }

    /// Address of the local end of the connection
    Address localAddress()
    {
        return _socket.localAddress;
    }

    /// Returns: the error text reported by the underlying socket's `getErrorText()`
    string popCurrentError()
    {
        return _socket.getErrorText();
    }

    /++
        Gets the timeout of the specified direction

        Only [Direction.send] and [Direction.receive] can be queried.

        Returns:
            Timeout in (whole) seconds
     +/
    long timeout(Direction direction)()
            if (direction == Direction.send || direction == Direction.receive)
    {
        return _socket.getTimeout!direction();
    }

    /++
        Sets the timeout of the specified direction(s)

        Params:
            seconds = new timeout in seconds
     +/
    void timeout(Direction direction)(long seconds) if (direction != Direction.none)
    {
        return _socket.setTimeout!direction(seconds);
    }
}

unittest
{
    ubyte[] bufferDyn;
    ubyte[4] bufferStat;

    // dfmt off
    static assert(__traits(compiles, (SocketConnection sc) => sc.sendAll(bufferDyn)));
    static assert(__traits(compiles, (SocketConnection sc) => sc.sendAll(bufferStat)));

    static assert(__traits(compiles, (SocketConnection sc) { ubyte[] r = sc.receiveSlice(bufferDyn); }));
    static assert(__traits(compiles, (SocketConnection sc) { ubyte[] r = sc.receiveSlice(bufferStat); }));

    static assert(__traits(compiles, (SocketConnection sc) { ubyte[] r = sc.receiveAll(bufferDyn); }));
    static assert(__traits(compiles, (SocketConnection sc) { ubyte[] r = sc.receiveAll(bufferStat); }));

    static assert( __traits(compiles, (SocketConnection sc) { long t = sc.timeout!(Direction.receive); }));
    static assert( __traits(compiles, (SocketConnection sc) { long t = sc.timeout!(Direction.send); }));
    static assert(!__traits(compiles, (SocketConnection sc) { long t = sc.timeout!(Direction.both); }));
    static assert(!__traits(compiles, (SocketConnection sc) { long t = sc.timeout!(Direction.none); }));

    static assert( __traits(compiles, (SocketConnection sc) { sc.timeout!(Direction.both) = 90; }));
    static assert( __traits(compiles, (SocketConnection sc) { sc.timeout!(Direction.receive) = 90; }));
    static assert( __traits(compiles, (SocketConnection sc) { sc.timeout!(Direction.send) = 90; }));
    static assert(!__traits(compiles, (SocketConnection sc) { sc.timeout!(Direction.none) = 90; }));
    // dfmt on
}

/++
    Checks the last OS-level error and throws if it denotes a timed-out socket operation

    Does nothing for any other error.

    Throws:
        [SocketTimeoutException] if the last error was `EAGAIN` (Posix)
        or `WSAETIMEDOUT` (Windows)
 +/
private void detectTimeout(string file = __FILE__, size_t line = __LINE__)
{
    version (Posix)
    {
        import core.stdc.errno;

        if (errno() == EAGAIN)
            throw new SocketTimeoutException(file, line);
    }
    else version (Windows) // the predefined identifier is `Windows`, not `Window`
    {
        import core.sys.windows.winsock2 : WSAETIMEDOUT, WSAGetLastError;

        if (WSAGetLastError() == WSAETIMEDOUT)
            throw new SocketTimeoutException(file, line);
    }
}

/// Thrown when a socket operation timed out
class SocketTimeoutException : SocketException
{
    public this(string file = __FILE__, size_t line = __LINE__) @safe pure nothrow @nogc
    {
        super("Socket operation timed out", file, line);
    }
}

/// Thrown when a connection got closed before the expected amount of data was received
class SocketUnexpectedEndOfDataException : SocketException
{
    public this(string message, string file = __FILE__, size_t line = __LINE__) @safe pure nothrow @nogc
    {
        super(message, file, line);
    }
}

/// Direction(s) of data transfer; usable as bit flags
enum Direction
{
    /// Neither direction
    none = 0b00,

    /// Receiving
    receive = 0b01,

    /// Sending
    send = 0b10,

    /// Receiving and sending
    both = (receive | send),
}

private
{
    /++
        Queries the send or receive timeout of a socket

        Returns:
            Timeout in (whole) seconds
     +/
    long getTimeout(Direction direction)(Socket socket)
            if (direction == Direction.send || direction == Direction.receive)
    {
        import std.datetime : Duration;

        // Pick the option that matches the requested direction.
        enum SocketOption sockOpt = (direction == Direction.send)
            ? SocketOption.SNDTIMEO : SocketOption.RCVTIMEO;

        Duration result;
        socket.getOption(SocketOptionLevel.SOCKET, sockOpt, result);
        return result.total!"seconds";
    }

    /++
        Sets the send and/or receive timeout of a socket

        [Direction.both] applies the value to the send timeout first,
        then to the receive timeout.

        Params:
            socket = socket to configure
            seconds = new timeout in seconds
     +/
    void setTimeout(Direction direction)(Socket socket, long seconds)
            if (direction != Direction.none)
    {
        import std.datetime : durSeconds = seconds;

        static if (direction == Direction.both)
        {
            setTimeout!(Direction.send)(socket, seconds);
            setTimeout!(Direction.receive)(socket, seconds);
        }
        else static if (direction == Direction.receive)
        {
            logTrace(format!"Setting receive timeout to %d seconds (#%X)"(seconds, socket.handle));
            socket.setOption(SocketOptionLevel.SOCKET, SocketOption.RCVTIMEO, durSeconds(seconds));
        }
        else static if (direction == Direction.send)
        {
            logTrace(format!"Setting send timeout to %d seconds (#%X)"(seconds, socket.handle));
            socket.setOption(SocketOptionLevel.SOCKET, SocketOption.SNDTIMEO, durSeconds(seconds));
        }
        else
            static assert(false, "Bug");
    }
}