1 module draklib.server.session;
2 import draklib.core;
3 import draklib.util;
4 import draklib.bytestream;
5 import draklib.server.raknetserver;
6 import draklib.protocol.offline;
7 import draklib.protocol.reliability;
8 import draklib.protocol.online;
9 import draklib.protocol.connected;
10 
11 import std.conv;
12 
/**
 * Lifecycle of a Session, listed in the order the handshake advances.
 * Members use the implicit 0-based numbering (DISCONNECTED = 0 ... ONLINE_CONNECTED = 4).
 */
enum SessionState {
	DISCONNECTED,     // session closed; incoming packets are ignored
	OFFLINE_1,        // initial state; waiting for OFFLINE_CONNECTION_REQUEST_1
	OFFLINE_2,        // MTU agreed; waiting for OFFLINE_CONNECTION_REQUEST_2
	ONLINE_HANDSHAKE, // offline handshake done; online handshake in progress
	ONLINE_CONNECTED  // handshake complete; session has been reported open to the server
}
20 
/**
 * Server-side state for one RakNet client, identified by its IP address and port.
 *
 * The session advances through SessionState as the offline and online
 * handshakes complete. It keeps sent datagrams in a recovery queue until they
 * are ACKed (resending them on NACK), queues ACKs/NACKs for datagrams it
 * receives, reassembles split packets, and forwards application payloads to
 * the owning RakNetServer.
 */
class Session {
	/// Exclusive upper bound accepted for splitCount and splitIndex of incoming split packets.
	static immutable uint MAX_SPLIT_SIZE = 128;
	/// Maximum number of split packets that may be under reassembly at the same time.
	static immutable uint MAX_SPLIT_COUNT = 4;
	/// Largest sequence-number gap turned into NACKs for a single datagram. This bounds
	/// the work a client can trigger with one forged (huge) sequence number.
	private static immutable int MAX_NACK_GAP = 1024;

	private SessionState state;
	private ushort mtu;
	private long clientGUID;
	private long timeLastPacketReceived;

	private shared int lastPing = -99;

	private int lastSeqNum = -1;  // highest datagram sequence number received so far
	private uint sendSeqNum = 0;  // sequence number assigned to the next outgoing datagram

	private uint messageIndex = 0;
	private ushort splitID = 0;

	private ContainerPacket sendQueue;               // datagram currently being filled by queuePacket()
	private ContainerPacket[uint] recoveryQueue;     // sent datagrams awaiting ACK, keyed by sequence number
	private bool[uint] ACKQueue;                     // received sequence numbers to ACK on the next update()
	private bool[uint] NACKQueue;                    // missing sequence numbers to NACK on the next update()
	private EncapsulatedPacket[int][int] splitQueue; // splitID -> (splitIndex -> fragment)

	private RakNetServer server;
	private const string ip;
	private const ushort port;

	/**
	 * Creates a session in state OFFLINE_1 for the client at ip:port.
	 * Params:
	 *     server = The server that owns this session and receives its callbacks.
	 *     ip     = Client IP address.
	 *     port   = Client UDP port.
	 */
	this(RakNetServer server, in string ip, in ushort port) {
		this.server = server;
		this.ip = ip;
		this.port = port;

		state = SessionState.OFFLINE_1;
		// Start the timeout clock now. Otherwise an update() that runs before the first
		// handlePacket() would see a zero timestamp and time the session out at once.
		timeLastPacketReceived = getTimeMillis();

		sendQueue = newContainer();
	}

	/// Creates an empty outgoing container with the default 0x84 datagram header.
	private static ContainerPacket newContainer() {
		ContainerPacket cp = new ContainerPacket();
		cp.header = 0x84; //Default
		return cp;
	}

	/**
	 * Periodic tick. Disconnects the session if the time since the last received
	 * packet (measured with getTimeMillis()) reaches server.options.timeoutThreshold.
	 * Otherwise flushes pending ACKs, pending NACKs, and the queued outgoing datagram.
	 * Does nothing once the session is DISCONNECTED.
	 */
	package void update() {
		if(state == SessionState.DISCONNECTED) return;
		if((getTimeMillis() - timeLastPacketReceived) >= server.options.timeoutThreshold) {
			disconnect("connection timed out");
			return;
		}
		flushReceipts!ACKPacket(ACKQueue);
		flushReceipts!NACKPacket(NACKQueue);
		sendQueuedPackets();
	}

	/// Sends every sequence number in `queue` as one T (ACKPacket or NACKPacket),
	/// then empties the queue. Does nothing when the queue is empty.
	private void flushReceipts(T)(ref bool[uint] queue) {
		if(queue.length == 0) return;
		T receipt = new T();
		receipt.nums = queue.keys;
		byte[] data;
		receipt.encode(data);
		sendRaw(data);
		queue = null; // portable way to empty an associative array
	}

	/**
	 * Sends the pending container if it holds any packets and records it in
	 * recoveryQueue for NACK-driven resends. A fresh container then takes its place.
	 */
	private void sendQueuedPackets() {
		if(sendQueue.packets.length == 0) return;

		// The sent container lives on in recoveryQueue, so it must not be reused.
		// Clearing or renumbering it later would corrupt the copy awaiting an ACK.
		ContainerPacket sent = sendQueue;
		sendQueue = newContainer();

		sent.sequenceNumber = sendSeqNum++;
		byte[] data;
		sent.encode(data);
		sendRaw(data);
		recoveryQueue[sent.sequenceNumber] = sent;
	}

	/**
	 * Adds an EncapsulatedPacket to the queue, and sets its
	 * messageIndex, orderIndex, and any other values
	 * depending on the Reliability.
	 *
	 * If the packet's total length is longer than the MTU (Maximum Transport Unit)
	 * then the packet will be split into smaller chunks, which each
	 * will be sent immediately in its own datagram.
	 * Params:
	 *     pk =         The EncapsulatedPacket to be added
	 *     immediate =  If the packet should skip the queue
	 *                  and be sent immediately.
	 */
	public void addToQueue(EncapsulatedPacket pk, in bool immediate = false) {
		switch(pk.reliability) {
			case Reliability.RELIABLE_ORDERED:
				//TODO: orderIndex
				goto case;
			case Reliability.RELIABLE:
			case Reliability.RELIABLE_SEQUENCED:
			case Reliability.RELIABLE_WITH_ACK_RECEIPT:
			case Reliability.RELIABLE_ORDERED_WITH_ACK_RECEIPT:
				pk.messageIndex = messageIndex++;
				debug(logMessageIndex) server.logger.logDebug("Set message index to: " ~ to!string(pk.messageIndex));
				break;
			default:
				break;
		}

		if(pk.getSize() + 4 > mtu) { //4 is overhead for CustomPacket header
			// Packet is too big and needs to be split. Each chunk reserves 34 bytes of MTU
			// for headers; every fragment shares one split ID.
			byte[][] buffers = splitByteArray(pk.payload, mtu - 34);
			immutable ushort id = this.splitID++;
			for(uint count = 0; count < buffers.length; count++) {
				EncapsulatedPacket ep = new EncapsulatedPacket();
				ep.splitID = id;
				ep.split = true;
				ep.splitCount = cast(uint) buffers.length;
				ep.reliability = pk.reliability;
				ep.splitIndex = count;
				ep.payload = buffers[count];

				// The first fragment reuses the index assigned above; later ones get fresh ones.
				ep.messageIndex = count > 0 ? messageIndex++ : pk.messageIndex;
				if(ep.reliability == Reliability.RELIABLE_ORDERED) {
					ep.orderChannel = pk.orderChannel;
					ep.orderIndex = pk.orderIndex;
				}

				queuePacket(ep, true);
			}
		} else {
			queuePacket(pk, immediate);
		}
	}

	/**
	 * Either sends `pkt` right away in its own datagram (tracked in recoveryQueue),
	 * or appends it to sendQueue. If appending would exceed the MTU, sendQueue is
	 * flushed first.
	 */
	private void queuePacket(EncapsulatedPacket pkt, in bool immediate) {
		if(immediate) {
			ContainerPacket cp = newContainer();
			cp.packets = [pkt];
			cp.sequenceNumber = sendSeqNum++;
			byte[] data;
			cp.encode(data);
			sendRaw(data);

			recoveryQueue[cp.sequenceNumber] = cp;
		} else {
			if((sendQueue.getSize() + pkt.getSize()) > mtu) {
				sendQueuedPackets();
			}
			sendQueue.packets ~= pkt;
		}
	}

	/// Sends raw bytes to this session's ip:port through the owning server.
	public void sendRaw(in byte[] data) {
		import std.socket : InternetAddress;
		server.sendPacket(new InternetAddress(ip, port), data);
	}

	/**
	 * Entry point for every datagram received from this client.
	 * Handles the two offline handshake requests, ACK and NACK datagrams, and
	 * container datagrams (first byte 0x80..0x8F). Other datagrams are ignored,
	 * as is all input once the session is DISCONNECTED or when `packet` is empty.
	 */
	package void handlePacket(byte[] packet) {
		if(state == SessionState.DISCONNECTED) return;
		if(packet.length == 0) return; // nothing to dispatch on; avoid indexing packet[0]

		timeLastPacketReceived = getTimeMillis();
		immutable ubyte id = cast(ubyte) packet[0];
		byte[] data;
		switch(id) {
			// Non - Reliable Packets
			case RakNetInfo.OFFLINE_CONNECTION_REQUEST_1:
				if(state != SessionState.OFFLINE_1) return;
				OfflineConnectionRequest1 req1 = new OfflineConnectionRequest1();
				req1.decode(packet);
				mtu = req1.mtuSize;

				debug(sessionInfo) server.logger.logDebug("MTU: " ~ to!string(mtu));

				OfflineConnectionResponse1 res1 = new OfflineConnectionResponse1();
				res1.serverGUID = server.options.serverGUID;
				res1.mtu = mtu;

				res1.encode(data);
				sendRaw(data);

				state = SessionState.OFFLINE_2;
				debug(sessionInfo) server.logger.logDebug("Enter state OFFLINE_2");
				break;
			case RakNetInfo.OFFLINE_CONNECTION_REQUEST_2:
				if(state != SessionState.OFFLINE_2) break;
				OfflineConnectionRequest2 req2 = new OfflineConnectionRequest2();
				req2.decode(packet);
				clientGUID = req2.clientGUID;

				OfflineConnectionResponse2 res2 = new OfflineConnectionResponse2();
				res2.serverGUID = server.options.serverGUID;
				res2.clientAddress = ip;
				res2.clientPort = port;
				res2.mtu = mtu;
				res2.encryptionEnabled = false; // RakNet encryption not implemented

				res2.encode(data);
				sendRaw(data);

				state = SessionState.ONLINE_HANDSHAKE;
				debug(sessionInfo) server.logger.logDebug("Enter state ONLINE_HANDSHAKE");
				break;
			// ACK/NACK
			case RakNetInfo.ACK:
				ACKPacket ack = new ACKPacket();
				ack.decode(packet);

				// Acknowledged datagrams no longer need to be kept for resending.
				foreach(uint num; ack.nums) {
					recoveryQueue.remove(num);
				}
				break;
			case RakNetInfo.NACK:
				NACKPacket nack = new NACKPacket();
				nack.decode(packet);

				foreach(uint num; nack.nums) {
					auto lost = num in recoveryQueue;
					if(lost is null) {
						debug(sessionInfo) server.logger.logWarn("NACK " ~ to!string(num) ~ " not found in recovery queue");
						continue;
					}
					ContainerPacket cp = *lost;
					recoveryQueue.remove(num);

					cp.sequenceNumber = sendSeqNum++;
					byte[] resent; // fresh buffer per datagram
					cp.encode(resent);
					sendRaw(resent);

					// Track the resend under its new sequence number so it can be recovered again.
					recoveryQueue[cp.sequenceNumber] = cp;
				}
				break;
			default:
				if(id >= 0x80 && id <= 0x8F) {
					ContainerPacket cp = new ContainerPacket();
					cp.decode(packet);
					handleContainerPacket(cp);
				}
				break;
		}
	}

	/**
	 * Updates ACK/NACK bookkeeping for a received container, then handles each
	 * encapsulated packet inside it.
	 */
	private void handleContainerPacket(ContainerPacket cp) {
		int diff = cp.sequenceNumber - lastSeqNum;

		// This datagram arrived: it no longer needs a NACK, and it gets ACKed on the next update().
		NACKQueue.remove(cp.sequenceNumber);
		ACKQueue[cp.sequenceNumber] = true;

		// Every sequence number strictly between the previous highest and this one is missing.
		// Gaps larger than MAX_NACK_GAP are not NACKed, to bound work per datagram.
		if(diff > 1 && diff <= MAX_NACK_GAP) {
			foreach(int offset; 1 .. diff) {
				NACKQueue[cast(uint) (lastSeqNum + offset)] = true;
			}
		}

		if(diff >= 1) lastSeqNum = cp.sequenceNumber;

		foreach(EncapsulatedPacket pk; cp.packets) {
			handleEncapsulatedPacket(pk);
		}
	}

	/**
	 * Stores one fragment of a split packet. When all fragments of that split ID
	 * have arrived, concatenates them in splitIndex order and handles the result
	 * as a regular encapsulated packet. Drops fragments that are out of bounds or
	 * that would exceed MAX_SPLIT_COUNT concurrent reassemblies.
	 */
	private void handleSplitPacket(EncapsulatedPacket pk) {
		if(pk.splitCount >= MAX_SPLIT_SIZE || pk.splitIndex >= MAX_SPLIT_SIZE) {
			debug server.logger.logWarn("Skipped split Encapsulated: size too big (splitCount: " ~ to!string(pk.splitCount) ~ ", splitIndex: " ~ to!string(pk.splitIndex) ~ ")");
			return;
		}
		if(pk.splitIndex >= pk.splitCount) {
			debug server.logger.logWarn("Skipped split Encapsulated: index out of range (splitCount: " ~ to!string(pk.splitCount) ~ ", splitIndex: " ~ to!string(pk.splitIndex) ~ ")");
			return;
		}

		auto pending = pk.splitID in splitQueue;
		if(pending is null) {
			if(splitQueue.length >= MAX_SPLIT_COUNT) {
				debug server.logger.logWarn("Skipped split Encapsulated: too many in queue (" ~ to!string(splitQueue.length) ~ ")");
				return;
			}
			EncapsulatedPacket[int] m;
			m[pk.splitIndex] = pk;
			splitQueue[pk.splitID] = m;
		} else {
			(*pending)[pk.splitIndex] = pk;
		}

		auto received = splitQueue[pk.splitID];
		if(received.length != pk.splitCount) return; // still waiting for fragments

		splitQueue.remove(pk.splitID);

		// Associative-array iteration order is unspecified, so reassemble strictly by splitIndex.
		byte[] payload;
		foreach(int i; 0 .. cast(int) pk.splitCount) {
			auto fragment = i in received;
			if(fragment is null) {
				// Fragments disagreed on splitCount; the set cannot be reassembled.
				debug server.logger.logWarn("Skipped split Encapsulated: missing fragment " ~ to!string(i));
				return;
			}
			payload ~= (*fragment).payload;
		}

		EncapsulatedPacket ep = new EncapsulatedPacket();
		ep.payload = payload;
		handleEncapsulatedPacket(ep);
	}

	/**
	 * Dispatches one encapsulated packet by its first payload byte. The session
	 * handles disconnect, online connection request, 0x13, and connected ping itself.
	 * Anything else goes to server.onSessionReceivePacket. Packets are only
	 * processed in the ONLINE_HANDSHAKE or ONLINE_CONNECTED states. Split fragments
	 * are routed to reassembly and are never dispatched directly.
	 */
	private void handleEncapsulatedPacket(EncapsulatedPacket pk) {
		if(pk.payload.length == 0) {
			// Untrusted input: drop instead of asserting or indexing an empty payload.
			debug server.logger.logWarn("Skipped Encapsulated: empty payload");
			return;
		}
		if(!(state == SessionState.ONLINE_CONNECTED || state == SessionState.ONLINE_HANDSHAKE)) {
			debug server.logger.logWarn("Skipped Encapsulated: not in right state (" ~ to!string(state) ~ ")");
			return;
		}
		if(pk.split) {
			if(state == SessionState.ONLINE_CONNECTED)
				handleSplitPacket(pk);
			else debug server.logger.logWarn("Skipped split Encapsulated: not in right state (" ~ to!string(state) ~ ")");
			// A fragment is never a complete packet. The reassembled packet re-enters this method.
			return;
		}

		switch(cast(ubyte) pk.payload[0]) {
			case RakNetInfo.DISCONNECT_NOTIFICATION:
				disconnect("client disconnected");
				break;
			case RakNetInfo.ONLINE_CONNECTION_REQUEST:
				OnlineConnectionRequest ocr = new OnlineConnectionRequest();
				ocr.decode(pk.payload);

				OnlineConnectionRequestAccepted ocra = new OnlineConnectionRequestAccepted();
				ocra.clientAddress = ip;
				ocra.clientPort = port;
				ocra.requestTime = ocr.time;
				ocra.time = ocr.time + 1000L;

				EncapsulatedPacket ep = new EncapsulatedPacket();
				ep.reliability = Reliability.UNRELIABLE;
				ocra.encode(ep.payload);
				addToQueue(ep, true);
				break;
			case 0x13: // NOTE(review): presumably NEW_INCOMING_CONNECTION — confirm against RakNetInfo
				// Only the first one completes the handshake; a repeat must not re-open the session.
				if(state != SessionState.ONLINE_HANDSHAKE) break;
				state = SessionState.ONLINE_CONNECTED;
				debug(sessionInfo) server.logger.logDebug("Enter state ONLINE_CONNECTED");
				server.onSessionOpen(this, clientGUID);
				break;
			case RakNetInfo.CONNECTED_PING:
				ConnectedPingPacket ping = new ConnectedPingPacket();
				ping.decode(pk.payload);

				ConnectedPongPacket pong = new ConnectedPongPacket();
				pong.time = ping.time;

				EncapsulatedPacket ep = new EncapsulatedPacket();
				ep.reliability = Reliability.UNRELIABLE;
				pong.encode(ep.payload);
				addToQueue(ep, true);
				break;
			default:
				server.onSessionReceivePacket(this, cast(shared) pk.payload);
				break;
		}
	}

	/**
	 * Closes the session. The steps are:
	 *  - send a 0x15 notification, if the online handshake was reached;
	 *  - blacklist ip:port via server.addToBlacklist(getIdentifier(), 30);
	 *  - enter DISCONNECTED;
	 *  - report the close through server.onSessionClose.
	 * Calling it again on an already disconnected session does nothing.
	 * Params:
	 *     reason = Passed through to server.onSessionClose.
	 */
	public void disconnect(in string reason = null) {
		if(state == SessionState.DISCONNECTED) return; // avoid a second onSessionClose

		// Encapsulated packets only make sense once the online handshake has begun. Before
		// that the MTU may still be 0, which would send the notification down the split
		// path with a negative chunk size.
		if(state == SessionState.ONLINE_HANDSHAKE || state == SessionState.ONLINE_CONNECTED) {
			EncapsulatedPacket ep = new EncapsulatedPacket();
			ep.reliability = Reliability.UNRELIABLE;
			ep.payload = cast(byte[]) [0x15];
			addToQueue(ep, true);
		}

		server.addToBlacklist(getIdentifier(), 30);

		state = SessionState.DISCONNECTED;

		server.onSessionClose(this, reason);
	}

	/// Returns the server that owns this session.
	public RakNetServer getServer() {
		return server;
	}

	/// Returns the client's IP address.
	public string getIpAddress() {
		return ip;
	}

	/// Returns the client's UDP port.
	public ushort getPort() {
		return port;
	}

	/// Returns "ip:port", the key used for blacklisting on disconnect.
	public string getIdentifier() {
		return ip ~ ":" ~ to!string(port);
	}

	/// Returns the client GUID received in OFFLINE_CONNECTION_REQUEST_2 (0 before that).
	public long getClientGUID() {
		return clientGUID;
	}
}