1 /++
2     Socket server implementation
3  +/
4 module socketplate.server.server;
5 
6 import core.thread;
7 import socketplate.address;
8 import socketplate.connection;
9 import socketplate.log;
10 import socketplate.server.worker;
11 import std.format;
12 import std.socket;
13 
14 @safe:
15 
/++
    Options to tailor the socket server to your needs

    Every field has a default, so `SocketServerTunables()` yields a usable
    configuration. A copy of these options is stored by `SocketServer` on
    construction.
 +/
struct SocketServerTunables
{
    /++
        Listening backlog

        Handed to every listener's `listen()` call when the server starts
        spawning its workers. Defaults to `SOMAXCONN`.
     +/
    int backlog = SOMAXCONN;

    /++
        Receive/read timeout

        Passed on to each `SocketListener` created through `listenTCP()`.
        NOTE(review): the unit is presumably seconds — confirm against `SocketListener`.
     +/
    int timeout = 60;

    /++
        Number of workers per listener

        The server creates this many worker threads for each registered listener.
     +/
    int workers = 2;

    /++
        Whether to set up signal handlers

        When enabled, the server installs signal handlers before starting the
        worker threads; the flag is also handed to every worker it creates.
     +/
    bool setupSignalHandlers = true;
}
41 
/++
    Socket server

    Collects listeners (see `registerListener()` and the `listenTCP()` helpers)
    and, once `run()` is called, serves every listener with a fixed number of
    worker threads as configured through `SocketServerTunables`.
 +/
final class SocketServer
{
@safe:

    private
    {
        // Configuration captured at construction time.
        SocketServerTunables _tunables;

        // NOTE(review): neither read nor written by any code in this class.
        bool _shutdown = false;

        // Listeners added through `registerListener()`, in registration order.
        SocketListener[] _listeners;
    }

    /++
        Creates a server that uses the supplied tunables

        Params:
            tunables = configuration to copy into the server
     +/
    public this(SocketServerTunables tunables) pure nothrow @nogc
    {
        _tunables = tunables;
    }

    /// Creates a server that uses the default tunables
    public this() pure nothrow @nogc
    {
        this(SocketServerTunables());
    }

    public
    {
        /++
            Starts listening on all registered listeners, spawns the worker
            threads and blocks until every one of them has exited.

            Returns:
                `0` if there was nothing to run (no listeners, or fewer than one
                worker per listener configured) or if all workers exited cleanly;
                `1` if joining at least one worker thread raised an `Exception`.
         +/
        int run()
        {
            if (_listeners.length == 0)
            {
                logWarning("There are no listeners, hence no workers to spawn.");
                return 0;
            }

            // A negative count would be sign-extended when multiplied with the
            // (unsigned) listener count and make the capacity reservation in
            // `spawnWorkers()` fail; a count of zero would start listening
            // without anybody ever accepting a connection.
            if (_tunables.workers < 1)
            {
                logWarning("The number of workers per listener is less than 1, hence no workers to spawn.");
                return 0;
            }

            logTrace("Running");
            int x = spawnWorkers();
            logTrace("Exiting (Main Thread)");
            return x;
        }

        /++
            Binds every registered listener

            Params:
                socketOptionREUSEADDR = forwarded unchanged to each listener's `bind()`
         +/
        void bind(bool socketOptionREUSEADDR = true)
        {
            foreach (listener; _listeners)
                listener.bind(socketOptionREUSEADDR);
        }

        /++
            Adds a listener to the server

            Only listeners that are registered by the time `run()` is called
            get worker threads.

            Params:
                listener = listener to append to the server's list
         +/
        void registerListener(SocketListener listener)
        {
            _listeners ~= listener;
        }
    }

    private
    {
        // Puts all listeners into listening state, creates and starts the worker
        // threads and waits for them to finish.
        // Returns 1 if joining any thread raised an `Exception`, 0 otherwise.
        // Expects `_tunables.workers` to be at least 1 (ensured by `run()`).
        int spawnWorkers()
        {
            logTrace("Starting SocketServer in Threading mode");

            Thread[] threads;
            Worker[] workers;

            size_t nWorkers = (_listeners.length * _tunables.workers);
            workers.reserve(nWorkers);

            // Whatever way this function is left, ask every worker created so far to shut down.
            scope (exit)
                foreach (worker; workers)
                    worker.shutdown();

            foreach (SocketListener listener; _listeners)
            {
                listener.listen(_tunables.backlog);

                // The current thread count doubles as the new worker's ID.
                foreach (i; 0 .. _tunables.workers)
                    threads ~= spawnWorkerThread(threads.length, listener, _tunables, workers);
            }

            // setup signal handlers (if requested)
            if (_tunables.setupSignalHandlers)
            {
                import socketplate.signal;

                setupSignalHandlers(delegate(int signal) @safe nothrow @nogc {
                    // signal threads
                    forwardSignal(signal, threads);
                });
            }

            // start worker threads
            // (`@trusted` wrapper so that `Thread.start()` can be called from `@safe` code)
            foreach (Thread thread; threads)
                function(Thread thread) @trusted { thread.start(); }(thread);

            bool error = false;

            // wait for workers to exit
            foreach (thread; threads)
            {
                // An `Exception` surfacing from `join()` is recorded instead of
                // propagated, so that the remaining threads still get joined.
                function(Thread thread, ref error) @trusted {
                    try
                        thread.join();
                    catch (Exception)
                        error = true;
                }(thread, error);
            }

            return (error) ? 1 : 0;
        }

        // Creates a worker for `listener`, records it in `workers` and returns a
        // not-yet-started thread that will execute the worker's `run()` method.
        static Thread spawnWorkerThread(
            size_t id,
            SocketListener listener,
            const SocketServerTunables tunables,
            ref Worker[] workers
        )
        {
            auto worker = new Worker(listener, id, tunables.setupSignalHandlers);
            workers ~= worker;
            return new Thread(&worker.run);
        }
    }
}
165 
/++
    Registers a new TCP listener

    Creates a stream socket matching the address family of `address`, wraps it
    in a `SocketListener` (configured with the server's timeout tunable) and
    adds that listener to `server`.

    Params:
        server = server to register the listener with
        address = address the listener is created for
        handler = connection handler handed to the listener
 +/
void listenTCP(SocketServer server, Address address, ConnectionHandler handler)
{
    logTrace("Registering TCP listener on ", address.toString());

    // Unix domain sockets are created with protocol number 0 rather than TCP.
    ProtocolType protocolType = ProtocolType.TCP;
    if (address.addressFamily == AddressFamily.UNIX)
        protocolType = cast(ProtocolType) 0;

    auto socket = new Socket(address.addressFamily, SocketType.STREAM, protocolType);

    server.registerListener(
        new SocketListener(socket, address, handler, server._tunables.timeout)
    );
}
185 
/++
    Registers a new TCP listener

    Converts `listenOn` to its `std.socket.Address` counterpart and delegates
    to the `Address`-based overload.
 +/
void listenTCP(SocketServer server, SocketAddress listenOn, ConnectionHandler handler)
{
    Address phobosAddress = toPhobos(listenOn);
    listenTCP(server, phobosAddress, handler);
}
191 
/++
    Registers a new TCP listener

    Parses `listenOn` with `parseSocketAddress()` and delegates to the
    `SocketAddress`-based overload. An address that fails to parse is treated
    as a programming error and triggers an assertion failure.

    Params:
        server = server to register the listener with
        listenOn = textual form of the address to listen on
        handler = connection handler handed to the listener
 +/
void listenTCP(SocketServer server, string listenOn, ConnectionHandler handler)
{
    SocketAddress sockAddr;

    // The parser call must not be placed inside an `assert` expression:
    // such expressions are not evaluated in release builds, which would skip
    // parsing altogether and leave `sockAddr` in its initial state.
    // `assert(false, …)` on the other hand is retained in release builds.
    if (!parseSocketAddress(listenOn, sockAddr))
        assert(false, "Invalid listening address");

    return listenTCP(server, sockAddr, handler);
}
199 
// Builds the `std.socket.Address` equivalent of a SocketAddress.
//
// Unsupported input is treated as a programming error: an invalid address type,
// a Unix domain address on a non-Posix platform, a non-positive port for IPv4/IPv6
// and any `AddressException` raised while constructing the address all end in an
// assertion failure.
private Address toPhobos(SocketAddress sockAddr)
{
    alias Kind = SocketAddress.Type;

    try
    {
        final switch (sockAddr.type)
        {
        case Kind.invalid:
            assert(false, "Invalid address");

        case Kind.unixDomain:
            version (Posix)
            {
                return new UnixAddress(sockAddr.address);
            }
            else
            {
                assert(false, "Unix Domain sockets unavailable");
            }

        case Kind.ipv4:
            assert(sockAddr.port > 0);
            const ushort portV4 = cast(ushort) sockAddr.port;
            return new InternetAddress(sockAddr.address, portV4);

        case Kind.ipv6:
            assert(sockAddr.port > 0);
            const ushort portV6 = cast(ushort) sockAddr.port;
            return new Internet6Address(sockAddr.address, portV6);
        }
    }
    catch (AddressException addressError)
    {
        assert(false, "Invalid address: " ~ addressError.msg);
    }
}