#!/usr/bin/env python

# Simple server for the "count characters" protocol. Requests and responses
# are authenticated with RSA signatures, so in-transit modification of the
# data is detected.

import SocketServer
import socket
import logging
from Crypto.PublicKey import RSA
import cPickle as pickle

PORT=4923
SEPARATOR="---------------------------"
client_key_filename = "./client-public.key"
server_key_filename = "./server-full.key"

class RequestHandler(SocketServer.StreamRequestHandler):
    """Serve one signed "count characters" request per connection.

    Wire format, in both directions: ``<payload><SEPARATOR><signature>``,
    where ``<signature>`` is the decimal form of an RSA signature over
    ``<payload>``. The request ends when the client closes (or half-closes)
    its side of the connection. The reply states how many characters the
    UTF-8-decoded payload contains, and is signed with the server key.
    """

    @staticmethod
    def _split_signed(signed_data):
        """Split ``payload + SEPARATOR + signature`` into its two parts.

        Splits on the LAST separator, so a payload that itself contains
        SEPARATOR, or ends in '-', is still parsed correctly. The trailing
        signature is a decimal integer and contains no separator.

        :raises ValueError: no separator in ``signed_data``.
        :raises Exception: the signature part is empty.
        """
        data, sep, signature_str = signed_data.rpartition(SEPARATOR)
        if not sep:
            raise ValueError("No separator in message")
        if not signature_str:
            raise Exception("No signature in message")
        return data, signature_str

    def handle(self):
        """Read a signed request until EOF, verify it, and write a signed reply.

        :raises ValueError: missing separator, or a signature that is not an
            integer.
        :raises Exception: empty signature, or a signature that does not
            verify against the client key.
        :raises UnicodeDecodeError: the payload is not valid UTF-8; no reply
            is sent in that case.
        """
        logging.info("Client connection")
        # read() with no size blocks until the peer closes its side. This is
        # the same end-of-request condition as a byte-by-byte loop, but in
        # one call and without quadratic string concatenation.
        signed_data = self.rfile.read()
        data, signature_str = self._split_signed(signed_data)
        signature = long(signature_str)
        # NOTE(review): PyCrypto's legacy sign()/verify() are textbook RSA
        # (no hashing, no padding). This is weak as a MAC substitute, and
        # payloads larger than the modulus cannot be signed. Consider PKCS#1
        # PSS over a hash, on client and server together.
        if not self.server.client_key.verify(data, (signature, )):
            raise Exception("Wrong signature in message")
        text = data.decode("UTF-8")
        response = (u"You sent %i characters" % len(text)).encode("UTF-8")
        response_signature = self.server.server_key.sign(response, None)[0]
        self.wfile.write("%s%s%li" % (response, SEPARATOR, response_signature))

class Server(SocketServer.ThreadingMixIn, SocketServer.TCPServer):
    """Threaded TCP server holding the RSA keys used by request handlers.

    Each connection is served in its own thread. Handlers reach the keys
    through ``self.server.server_key`` and ``self.server.client_key``.
    """

    # Do not let worker threads keep the process alive after serve_forever()
    # returns (e.g. on Ctrl-C). Otherwise a handler blocked waiting for a
    # client's EOF would hang shutdown indefinitely.
    daemon_threads = True

    def __init__(self, address, handler, mykey, client_key):
        """Store the keys, then bind and activate the listening socket.

        :param address: (host, port) tuple passed to TCPServer.
        :param handler: request handler class instantiated per connection.
        :param mykey: server RSA key, used by handlers to sign responses.
        :param client_key: client RSA key, used by handlers to verify requests.
        """
        # Keys are set before the base initializer binds the socket, so they
        # exist before any request can be served.
        self.server_key = mykey
        self.client_key = client_key
        # SocketServer classes are old-style in Python 2, so super() cannot be
        # used; call the base initializer explicitly.
        SocketServer.TCPServer.__init__(self, address, handler)
        logging.info("Starting server")

logging.basicConfig(level=logging.DEBUG,
                    format='%(asctime)s %(levelname)s %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S')


def _load_key(filename):
    """Unpickle and return the key object stored in *filename*.

    SECURITY: pickle.load can execute arbitrary code. These key files must
    come from a trusted source and be writable only by the server operator.
    """
    # Binary mode: pickle data is bytes, and text mode can corrupt it on
    # platforms that translate line endings.
    with open(filename, 'rb') as key_file:
        return pickle.load(key_file)


server_key = _load_key(server_key_filename)
client_key = _load_key(client_key_filename)
# Listen on an IPv6 socket, and allow fast restarts while old connections
# are still in TIME_WAIT.
Server.address_family = socket.AF_INET6
Server.allow_reuse_address = True
myserver = Server(("", PORT), RequestHandler, server_key, client_key)
try:
    myserver.serve_forever()
except KeyboardInterrupt:
    logging.info("Server stopped")
finally:
    # Release the listening socket explicitly rather than relying on exit.
    myserver.server_close()



