#!/usr/bin/env python3

"""Measures the effectiveness of a DoS attack by sending N DNS
requests (to the default resolver or to a configurable name server) at
intervals of T (N and T being of course tunable) and displays
aggregates such as "X % of the requests were successful"."""

# Default settings, each of which can be overridden by a command-line option.
qtype = "AAAA"  # DNS record type to query
timeout, number = 2.0, 10  # seconds to wait for each answer, requests to send
# Seconds between two requests. Should be a bit higher than the TTL, if talking
# to a resolver (presumably so that answers are not served from its cache).
delay = 30.0
jitter = None  # None disables the random variation of the delay
server, only_noerror = None, False  # server None: query the default resolver

# Standard library
import sys
import time
import getopt
import random

# DNSpython https://www.dnspython.org/
import dns.exception
import dns.message
import dns.query
import dns.rcode
import dns.rdatatype
import dns.resolver

def usage(msg=None):
    """Print the command-line synopsis and option list to stderr.

    msg: optional extra text (typically an error description) printed
    after the synopsis; nothing extra is printed when it is None or empty.
    """
    print("Usage: %s [options] domain-name" % sys.argv[0], file=sys.stderr)
    print("Options:\n"
          "  -d, --delay=SECONDS    interval between two requests\n"
          "  -j, --jitter=SECONDS   random variation (+/-) of the delay, "
          "lower than the delay\n"
          "  -n, --number=N         number of requests to send\n"
          "  -s, --server=ADDRESS   name server to query "
          "(default: first system resolver)\n"
          "  -t, --timeout=SECONDS  timeout of each request\n"
          "  -y, --type=TYPE        DNS record type to query\n"
          "  -e, --only-noerror     count only NOERROR responses as successes\n"
          "  -h, --help             show this help",
          file=sys.stderr)
    if msg:
        print(msg, file=sys.stderr)

# Parse the command line; each option overrides one of the module-level defaults.
try:
    optlist, args = getopt.getopt(sys.argv[1:], "d:ehj:n:s:t:y:",
                                  ["delay=", "help", "jitter=",
                                   "number=", "only-noerror",
                                   "server=", "timeout=", "type="])
except getopt.error as reason:
    usage(reason)
    sys.exit(1)
for option, value in optlist:
    try:
        if option in ("--delay", "-d"):
            delay = float(value)
        elif option in ("--help", "-h"):
            usage()
            sys.exit(0)
        elif option in ("--jitter", "-j"):
            jitter = float(value)
        elif option in ("--number", "-n"):
            number = int(value)
        elif option in ("--only-noerror", "-e"):
            only_noerror = True
        elif option in ("--server", "-s"):
            server = value
        elif option in ("--timeout", "-t"):
            timeout = float(value)
        elif option in ("--type", "-y"):
            try:
                qtype = dns.rdatatype.from_text(value)
            except dns.rdatatype.UnknownRdatatype:
                usage("Wrong DNS resource record type %s" % value)
                sys.exit(1)
        else:
            usage("Unknown option %s" % option)
            sys.exit(1)
    except ValueError:
        # int()/float() reject malformed numbers with ValueError; report it
        # as a usage error instead of a traceback.
        usage("Invalid value %s for option %s" % (value, option))
        sys.exit(1)

# Reject values that would crash the measurement later: no request at all
# (division by zero in the report), a negative delay (time.sleep() raises),
# a negative jitter (random.randint() gets an empty range). Writing the tests
# as "not x >= 0" also rejects NaN, for which every comparison is false.
if number < 1:
    usage("Number of requests (%i) must be at least 1" % number)
    sys.exit(1)
if not delay >= 0:
    usage("Delay (%2.1f) must not be negative" % delay)
    sys.exit(1)
if not timeout > 0:
    usage("Timeout (%2.1f) must be positive" % timeout)
    sys.exit(1)
if jitter is not None:
    if not jitter >= 0:
        usage("Jitter (%2.1f) must not be negative" % jitter)
        sys.exit(1)
    if jitter >= delay:
        usage("Jitter (%2.1f) must be lower than delay (%2.1f)" % (jitter, delay))
        sys.exit(1)
if len(args) != 1:
    usage()
    sys.exit(1)
name = args[0]
if server is None:
    try:
        # We always use the first one
        server = dns.resolver.get_default_resolver().nameservers[0]
    except (dns.resolver.NoResolverConfiguration, IndexError):
        usage("No default resolver configured, use --server")
        sys.exit(1)

# Measurement: send `number` queries, `delay` seconds apart (plus or minus a
# random jitter), record which ones succeeded and how long they took, then
# print the aggregate. Control-C stops early and still reports the requests
# done so far.
generator = random.Random()
message = dns.message.make_query(name, qtype)
successes = 0
sent = 0
times = []  # durations, in seconds, of the successful requests
# No pause follows the last request, so the run takes (number - 1) delays
# plus the duration of the last request, which is at most the timeout.
expected = max(number - 1, 0) * delay + timeout
print("Expect an answer in more or less %2.1f seconds (around %s)" %
      (expected, time.strftime("%Y-%m-%dT%H:%M:%SZ",
                               time.gmtime(time.time() + expected))))
try:
    for i in range(number):
        if i > 0:
            # Pause between two requests only, not after the last one.
            if jitter is not None:
                milliseconds = int(jitter * 1000)
                extra = generator.randint(-milliseconds, +milliseconds)
            else:
                extra = 0
            time.sleep(delay + (extra / 1000))
        # monotonic() is not affected by wall-clock adjustments (NTP, manual
        # changes) that could make differences of time.time() wrong or negative.
        start = time.monotonic()
        try:
            response = dns.query.udp(message, server, timeout=timeout)
        except dns.exception.Timeout:
            response = None
        except (dns.exception.DNSException, OSError) as error:
            # Other DNS errors (e.g. a bad or malformed response) and network
            # errors (e.g. unreachable server) count as failures instead of
            # aborting the whole measurement.
            print("Request %i failed: %s" % (i + 1, error), file=sys.stderr)
            response = None
        stop = time.monotonic()
        sent += 1
        if response is not None and \
           (not only_noerror or response.rcode() == dns.rcode.NOERROR):
            successes += 1
            times.append(stop - start)
except KeyboardInterrupt:
    # An interrupted request is not counted in `sent`.
    print("Interrupted, reporting the %i request(s) done." % sent,
          file=sys.stderr)
if sent == 0:
    print("No request was done, nothing to report.", file=sys.stderr)
    sys.exit(1)
average = successes / sent
if times:
    average_time = sum(times) / len(times)
    average_time_s = "%2.3f s" % average_time
    # TODO median
else:
    average_time_s = "N/A"
print("%i requests among %i (%2.1f %%) succeeded. Average time %s. Measurement done on %s." %
      (successes, sent, average * 100.0, average_time_s,
       time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())))
