1 /++
2     Listener and worker implementation
3  +/
4 module socketplate.server.worker;
5 
6 import core.atomic : atomicLoad, atomicStore;
7 import socketplate.connection;
8 import socketplate.log;
9 import std.conv : to;
10 import std.string : format;
11 import std.socket : Address, Socket, SocketShutdown;
12 
13 @safe:
14 
/++
    Accepts connections on a single socket and hands each of them to a
    connection handler.

    The listener moves through the states `initial` → `bound` → `listening`
    and finally `closed`; the contracts on the individual methods enforce
    this order.
 +/
final class SocketListener
{
@safe:

    private enum State
    {
        initial,
        bound,
        listening,
        closed,
    }

    private
    {
        State _state;

        Socket _socket;
        Address _address;
        ConnectionHandler _callback;
        int _timeout;

        // Connection currently being handled; kept so that `shutdownAccepted`
        // can abort it. `null` while no connection is in flight.
        // Note: a `static` class member is thread-local in D, i.e. shared by all
        // listeners of one thread rather than being a per-instance field.
        static Socket _accepted = null;
    }

    /++
        Params:
            socket = socket to bind, listen and accept on
            address = address to bind the socket to
            callback = handler invoked for every accepted connection
            timeout = receive timeout applied to accepted connections
                      (forwarded to `makeSocketConnection`, whose parameter is named `seconds`)
     +/
    public this(Socket socket, Address address, ConnectionHandler callback, int timeout) pure nothrow @nogc
    {
        _socket = socket;
        _address = address;
        _callback = callback;
        _timeout = timeout;

        _state = State.initial;
    }

    /// Returns: whether the listening socket has been shut down and closed.
    public bool isClosed() pure nothrow @nogc
    {
        return (_state == State.closed);
    }

    /++
        Binds the socket to the configured address.

        A leftover Unix Domain Socket file at the address' path is unlinked first.

        Params:
            socketOptionREUSEADDR = value applied to the `SO_REUSEADDR` socket option
     +/
    public void bind(bool socketOptionREUSEADDR = true)
    in (_state == State.initial)
    {
        // unlink Unix Domain Socket file if applicable
        unlinkUnixDomainSocket(_address);

        // enable address reuse
        _socket.setReuseAddr = socketOptionREUSEADDR;

        logTrace(format!"Binding to %s (#%X)"(_address.toString, _socket.handle));
        _socket.bind(_address);
        _state = State.bound;
    }

    /++
        Starts listening on the bound socket.

        Params:
            backlog = maximum length of the pending-connections queue, passed to `Socket.listen`
     +/
    public void listen(int backlog)
    in (_state == State.bound)
    {
        logTrace(format!"Listening on %s (#%X)"(_address.toString, _socket.handle));
        _socket.listen(backlog);
        _state = State.listening;
    }

    /++
        Waits for one incoming connection and runs the connection handler on it.

        Exceptions thrown by the handler are logged and swallowed. Afterwards
        the accepted socket is closed if it is still open, and the reference
        used by `shutdownAccepted` is cleared — also when something fails in
        between.

        Throws:
            whatever `Socket.accept` throws, e.g. once the listening socket
            has been shut down.
     +/
    private void accept(size_t workerID)
    in (_state == State.listening)
    {
        import std.socket : socket_t;

        logTrace(format!"Accepting incoming connections (#%X @%02d)"(_socket.handle, workerID));
        Socket connection = _socket.accept();
        _accepted = connection;

        // Use the local reference from here on: `_accepted` is also touched by
        // `shutdownAccepted` and is reset below.
        socket_t acceptedID = connection.handle;

        try
        {
            logTrace(format!"Incoming connection accepted (#%X @%02d)"(acceptedID, workerID));
            try
                _callback(makeSocketConnection(connection, _timeout));
            catch (Exception ex)
                logError(
                    format!"Unhandled Exception in connection handler (#%X): %s"(acceptedID, ex.msg)
                );

            logTrace(format!"Connection handled (#%X)"(acceptedID));
        }
        finally
        {
            // Forget the connection first, so a later signal does not
            // shut down / close a socket that is no longer in use.
            _accepted = null;

            if (connection.isAlive)
            {
                logTrace(format!"Closing still-alive connection (#%X)"(acceptedID));
                connection.close();
            }
        }
    }

    /++
        Shuts down both directions of the listening socket and closes it.

        With `doLog = false` nothing is logged, which keeps this usable from
        the `nothrow @nogc` function `ensureShutdownClosedNoLog`.
     +/
    private void shutdownClose(bool doLog = true)()
    in (_state != State.closed)
    {
        static if (doLog)
            logTrace(format!"Shutting down socket (#%X)"(_socket.handle));
        _socket.shutdown(SocketShutdown.BOTH);

        static if (doLog)
            logTrace(format!"Closing socket (#%X)"(_socket.handle));
        _socket.close();

        _state = State.closed;
    }

    /// Shuts down and closes the listening socket unless that already happened.
    private void ensureShutdownClosed()
    {
        if (_state == State.closed)
            return;

        shutdownClose!true();
    }

    /// Like `ensureShutdownClosed`, but without logging (`nothrow @nogc`).
    private void ensureShutdownClosedNoLog() nothrow @nogc
    {
        if (_state == State.closed)
            return;

        shutdownClose!false();
    }

    /// Shuts down and closes the connection currently being handled, if any.
    private void shutdownAccepted() nothrow @nogc
    {
        if (_accepted is null)
            return;

        _accepted.shutdown(SocketShutdown.BOTH);
        _accepted.close();
    }
}
141 
/++
    Runs the accept loop of a `SocketListener` until it is shut down.
 +/
class Worker
{
@safe:

    private
    {
        // Loop condition of `run`; cleared by `shutdown` (which is also called
        // from the signal-handler delegate), hence accessed atomically.
        shared(bool) _active = false;

        size_t _id;
        SocketListener _listener;
        bool _setupSignalHandlers;
    }

    /++
        Params:
            listener = listener whose connections this worker accepts
            id = worker number, used in log messages
            setupSignalHandlers = whether `run` installs signal handlers that shut this worker down
     +/
    public this(SocketListener listener, size_t id, bool setupSignalHandlers)
    {
        _listener = listener;
        _id = id;
        _setupSignalHandlers = setupSignalHandlers;
    }

    /++
        Accepts and handles connections until `shutdown` is called or the
        listener raises a `SocketException`.

        The listening socket is shut down and closed when this returns.
     +/
    public void run()
    {
        import std.socket : SocketException;

        scope (exit)
            logTrace(format!"Worker @%02d says goodbye"(_id));

        // Mark the worker active *before* installing the signal handlers.
        // Otherwise a signal arriving in between would have its shutdown
        // request overwritten here, and the loop would then try to accept on
        // an already closed listener.
        _active.atomicStore = true;

        if (_setupSignalHandlers)
            doSetupSignalHandlers();

        scope (exit)
        {
            logInfo(format!"Worker @%02d exiting"(_id));
            _listener.ensureShutdownClosed();
        }

        while (atomicLoad(_active))
        {
            try
                _listener.accept(_id);
            catch (SocketException ex)
            {
                // Expected once the listening socket has been shut down;
                // any other socket error ends the loop as well.
                logTrace(format!"Worker @%02d stops accepting: %s"(_id, ex.msg));
                break;
            }
        }
    }

    /// Requests the accept loop in `run` to stop after the current iteration.
    public void shutdown() nothrow @nogc
    {
        _active.atomicStore = false;
    }

    /++
        Installs signal handlers that stop this worker: they clear the loop
        flag, close the listening socket (unblocking a pending accept) and
        abort the connection currently being handled.
     +/
    private void doSetupSignalHandlers()
    {
        import socketplate.signal;

        setupSignalHandlers((int) @safe nothrow @nogc {
            this.shutdown();
            this._listener.ensureShutdownClosedNoLog();
            this._listener.shutdownAccepted();
        });
    }
}
204 
/++
    Wraps an accepted socket into a `SocketConnection` and applies a receive
    timeout to it.

    Params:
        socket = the accepted socket
        seconds = receive timeout, forwarded to `timeout!(Direction.receive)`

    Returns: the configured connection
 +/
private SocketConnection makeSocketConnection(Socket socket, int seconds)
{
    SocketConnection connection = SocketConnection(socket);
    connection.timeout!(Direction.receive)(seconds);
    return connection;
}
211 
/++
    Enables or disables the `SO_REUSEADDR` option of `socket`.

    Meant to be used via UFCS property syntax, e.g. `sock.setReuseAddr = true`.
 +/
private void setReuseAddr(Socket socket, bool enable)
{
    import std.socket : SocketOption, SocketOptionLevel;

    enum level = SocketOptionLevel.SOCKET;
    enum option = SocketOption.REUSEADDR;

    // `enable` converts implicitly to the integer option value (0 or 1).
    socket.setOption(level, option, enable);
}
218 
/++
    Removes a leftover Unix Domain Socket file before binding to its path.

    Does nothing for non-UNIX addresses and on non-Posix platforms.
    Only socket files are unlinked. If something else (e.g. a regular file)
    occupies the path, it is left untouched and an error is logged, so a
    misconfigured socket path cannot delete unrelated data. Binding will then
    most likely fail.

    Problems are logged, not thrown.
 +/
private void unlinkUnixDomainSocket(Address addr)
{
    import std.socket : AddressFamily;

    version (Posix)
    {
        if (addr.addressFamily == AddressFamily.UNIX)
        {
            import core.stdc.errno : errno;
            import core.sys.posix.sys.stat : S_IFMT, S_IFSOCK;
            import core.sys.posix.unistd : unlink;
            import std.file : exists, getAttributes;
            import std.socket : UnixAddress;
            import std.string : toStringz;

            UnixAddress uaddr = cast(UnixAddress) addr;

            if (uaddr is null)
            {
                logError("Cannot determine path of Unix Domain Socket");
                return;
            }

            if (!uaddr.path.exists)
            {
                logTrace("Unix Domain Socket path does not exists; nothing to unlink");
                return;
            }

            // `getAttributes` follows symlinks (as `exists` does), so a symlink
            // pointing to a socket is accepted; `unlink` then removes the link itself.
            if ((uaddr.path.getAttributes & S_IFMT) != S_IFSOCK)
            {
                logError(format!"Refusing to unlink non-socket file at Unix Domain Socket path: %s"(uaddr.path));
                return;
            }

            logTrace(format!"Unlinking Unix Domain Socket file: %s"(uaddr.path));

            // `unlink` only returns 0 or -1; the actual reason is in errno.
            immutable int err = () @trusted {
                return (unlink(uaddr.path.toStringz) == 0) ? 0 : errno;
            }();

            if (err != 0)
                logError(format!"Unlinking failed with errno: %d"(err));
        }
    }
}