#!/usr/bin/env python3

"""A simple Python program to test TLS connections and certificate
checks. It uses only the standard library."""

# Defaults, which the command-line options below may override
host, path = None, "/"   # --server/-s is mandatory; --path/-p is optional

# Not configurable: HTTPS only
PORT = 443

import socket
import getopt
import sys
import ssl

def error(msg=None):
    """Report a fatal problem on stderr and terminate with exit status 1.

    msg: text to print; "Unknown error" is used when it is None.
    """
    text = "Unknown error" if msg is None else msg
    print(text, file=sys.stderr)
    sys.exit(1)
    
def usage(msg=None):
    """Print an optional message, then the command-line synopsis, on stderr.

    msg: anything printable (a str, or e.g. the getopt.error instance caught
    by the option parser); nothing extra is printed when it is falsy.

    The synopsis lists every option the parser accepts, including the
    optional --path and --help.
    """
    if msg:
        print(msg, file=sys.stderr)
    print("Usage: %s [--help] --server hostname [--path path]" % sys.argv[0],
          file=sys.stderr)

def canonicalize(hostname):
    """Return *hostname* in canonical form, suitable for an HTTP Host header.

    The name is lower-cased, converted to its ASCII (IDNA) form and stripped
    of a trailing dot (the DNS root). An empty string is returned unchanged.

    Raises UnicodeError if the name cannot be IDNA-encoded (for instance
    when it contains two consecutive dots).
    """
    result = hostname.lower()
    # TODO handle properly the case where it fails with UnicodeError
    # (two consecutive dots for instance) to get a custom exception
    result = result.encode('idna').decode()
    # endswith() is safe on an empty string, unlike result[len(result)-1]
    if result.endswith('.'):
        result = result[:-1]
    return result

# Main program: fill the host/path globals from the command line.
# getopt.error is the only exception the parser itself raises, so only the
# getopt() call needs to be guarded.
try:
    options, _operands = getopt.getopt(sys.argv[1:], "hs:p:",
                                       ["help", "server=", "path="])
except getopt.error as reason:
    usage(reason)
    sys.exit(1)

for flag, argument in options:
    if flag in ("--help", "-h"):
        usage()
        sys.exit(0)
    if flag in ("--server", "-s"):
        host = argument
    elif flag in ("--path", "-p"):
        path = argument
    else:
        # Defensive only: getopt already rejects options it was not told about
        error("Unknown option %s" % flag)

# The server is the one mandatory setting
if host is None:
    usage("No server given")
    sys.exit(1)
        
def _common_name(name):
    """Return the commonName from a certificate name given by getpeercert().

    getpeercert() represents 'subject' and 'issuer' as a tuple of relative
    distinguished names, each itself a tuple of (attribute, value) pairs.
    Falls back to the whole structure, as a string, if there is no
    commonName attribute.
    """
    for rdn in name:
        for attribute, value in rdn:
            if attribute == "commonName":
                return value
    return str(name)


# Resolve the server. Asking for SOCK_STREAM only avoids duplicate entries
# for datagram/raw sockets that would never be used.
try:
    addrinfo = socket.getaddrinfo(host, PORT, 0, socket.SOCK_STREAM)
except socket.gaierror as err:
    error("Cannot resolve %s: %s" % (host, err))

# TCP: try each address in turn until one accepts the connection
sock = None
last_error = None
for family, socktype, proto, _canonname, addr in addrinfo:
    print("Connecting to %s ..." % str(addr))
    candidate = socket.socket(family, socktype, proto)
    try:
        candidate.connect(addr)
    except OSError as err:
        candidate.close()
        last_error = err
        continue
    sock = candidate
    break
if sock is None:
    error("Cannot connect to %s: %s" % (host, last_error))

# PROTOCOL_TLS_CLIENT negotiates the best version both sides support
# (TLS 1.3 included) instead of pinning TLS 1.2, which fails against
# TLS-1.3-only servers and is deprecated since Python 3.10.
context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
context.minimum_version = ssl.TLSVersion.TLSv1_2
# Both are already the defaults for PROTOCOL_TLS_CLIENT; kept explicit
# because checking the certificate is the point of this program.
context.verify_mode = ssl.CERT_REQUIRED
context.check_hostname = True
# Use the OS' default CAs
context.load_default_certs()

# TLS: the socket is already connected, so wrap_socket() performs the
# handshake (and thus certificate validation) right away. Passing
# server_hostname sends SNI and enables the hostname check.
try:
    session = context.wrap_socket(sock, server_hostname=host)
except OSError as err:  # includes ssl.SSLError / certificate failures
    sock.close()
    error("TLS handshake with %s failed: %s" % (host, err))

with session:  # closes the TLS socket (and the wrapped TCP one) on exit
    cert = session.getpeercert()
    print("Connected with %s, its certificate is for \"%s\", delivered by \"%s\"" %
          (session.version(),
           _common_name(cert.get('subject', ())),
           _common_name(cert.get('issuer', ()))))

    # Explicit CRLFs; the request must start directly with the request line
    request = ("GET %s HTTP/1.1\r\n"
               "Host: %s\r\n"
               "Connection: close\r\n"
               "\r\n") % (path, canonicalize(host))
    # sendall() keeps writing until every byte is sent, unlike write()/send()
    session.sendall(request.encode())

    # In a real application, we would loop to get all the data
    data = session.recv(256)
    # A 256-byte cut may split a multi-byte sequence (and the body need not
    # be UTF-8 at all), so do not crash on undecodable bytes.
    print(data.decode(errors="replace"))
