#!/usr/bin/python

# Simple server for the "count characters" protocol. Symmetric crypto
# with AES, to prevent sniffing.

import SocketServer
import socket
import logging
from Crypto.Cipher import AES

# TCP port the server listens on.
PORT = 4923
# Shared AES secret used by Server to build its decoder; it is 16 bytes
# long, so AES.new() uses AES-128.
# NOTE(review): a secret hard-coded in source lets anyone holding this file
# decrypt the traffic; consider loading it from configuration instead.
KEY = "fubaaar12AZZ666."

class RequestHandler(SocketServer.StreamRequestHandler):
    """Handle one "count characters" client connection.

    The client sends encrypted 16-byte blocks until it closes its side of
    the connection. The handler decrypts them with the server's shared
    ``decoder`` (AES in this file), decodes the plaintext as UTF-8 and
    replies with the number of characters received.
    """

    # Cipher block size in bytes; every read is padded up to this length.
    BLOCK_SIZE = 16

    def handle(self):
        """Read encrypted blocks from the client and reply with a character count.

        Reads ``BLOCK_SIZE``-byte ciphertext blocks from ``self.rfile`` until
        EOF and decrypts each with ``self.server.decoder``. It then strips the
        trailing NUL padding, decodes the whole plaintext as UTF-8, and writes
        ``"You sent N characters"`` (UTF-8 encoded) to ``self.wfile``.
        If the plaintext is not valid UTF-8, a warning is logged and no reply
        is sent.
        """
        logging.info("Client connection")
        chunks = []
        while True:
            encrypted_block = self.rfile.read(self.BLOCK_SIZE)
            if not encrypted_block:
                break  # EOF: the client has finished sending.
            # A buffered read only comes back short at EOF; pad it with NULs
            # so the block cipher receives a full block.
            encrypted_block = encrypted_block.ljust(self.BLOCK_SIZE, b"\0")
            block = self.server.decoder.decrypt(encrypted_block)
            # NOTE(review): decrypting a full 16-byte block with AES yields
            # 16 bytes, so this comparison never matches in practice; the
            # effective end-of-message marker is EOF. Kept for compatibility.
            if block == b"\n" or block == b"":
                break
            chunks.append(block)
        # Reassemble before decoding. Decoding each block separately fails
        # whenever a multi-byte UTF-8 sequence straddles a block boundary.
        # The trailing NULs are the client's padding of its last block.
        plaintext = b"".join(chunks).rstrip(b"\0")
        try:
            data = plaintext.decode("UTF-8")
        except UnicodeDecodeError as err:
            logging.warning("Client sent invalid UTF-8: %s", err)
            return
        self.wfile.write((u"You sent %i characters" % len(data)).encode("UTF-8"))

class Server(SocketServer.ThreadingMixIn, SocketServer.TCPServer):
    """Threaded TCP server whose handlers share one AES decoder.

    ThreadingMixIn serves each connection in its own thread; handlers reach
    the shared cipher through ``self.server.decoder``.
    """

    def __init__(self, address, handler):
        # ECB is chosen because it carries no chaining state between blocks.
        # A single cipher object can therefore be shared by every connection,
        # and each connection starts from scratch. This is weak security-wise.
        # Switching to CBC would require creating the cipher per connection
        # rather than once per server.
        cipher = AES.new(KEY, AES.MODE_ECB)
        self.decoder = cipher
        base_init = SocketServer.TCPServer.__init__
        base_init(self, address, handler)
        logging.info("Starting server")

# Script entry: configure logging, then serve clients until interrupted.
logging.basicConfig(level=logging.DEBUG,
                    format='%(asctime)s %(levelname)s %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S')
# Listen on IPv6. The empty host string below binds the wildcard address.
# Whether IPv4 clients are also accepted depends on the OS IPV6_V6ONLY default.
Server.address_family = socket.AF_INET6
# Set SO_REUSEADDR so a restart does not fail on sockets left in TIME_WAIT.
Server.allow_reuse_address = True
myserver = Server(("", PORT), RequestHandler)
try:
    myserver.serve_forever()
except KeyboardInterrupt:
    logging.info("Server stopped")
finally:
    # Release the listening socket however serve_forever() exits.
    myserver.server_close()


