/* Test of various "hash-table" solutions for a C program. This one
   analyses one or more ".pcap" files to find the top requests which
   yielded an NXDOMAIN response, i.e. "the most popular non-existing
   domains". */

#include <stdio.h>
#include <pcap.h>
#include <stdint.h>
#include <arpa/inet.h>
#include <stdlib.h>
#include <string.h>
#include <ctype.h>
#include <assert.h>

#ifdef HASHTABLE_GLIB
#include <glib.h>
#elif HASHTABLE_JUDY
#include <Judy.h>
#elif HASHTABLE_HAMSTERDB
#include <ham/hamsterdb.h>
#elif HASHTABLE_BERKELEYDB
#include <db.h>
#else
#error "No hash table defined"
#endif

#include "test-hashtable-nxdomain.h"

/* Size in bytes of the buffers holding a domain name (NUL included). */
#define MAX_NAME 255
/* Longest DNS label, in octets (RFC 1035). */
#define MAX_LABEL 63
/* Capacity of the label array used by reg_domain(). */
#define MAX_NUM_LABELS 127
/* Size of the key buffer used when iterating over the Judy array. */
#define MAX_KEY 256

/* The table mapping each domain kept by reg_domain() to its number of
   NXDOMAIN responses.  The database backends additionally copy the keys
   into sorted_array so they can be sorted by count before printing.
   None of these are used outside this file, hence "static". */
#ifdef HASHTABLE_GLIB
static GHashTable *table;
#elif HASHTABLE_JUDY
static Pvoid_t  table;
#elif HASHTABLE_HAMSTERDB
static ham_db_t *table;
static char   **sorted_array;
#elif HASHTABLE_BERKELEYDB
static DB      *table;
static char   **sorted_array;
/* for type try DB_BTREE or DB_HASH */
static const DBTYPE db_type = DB_BTREE;
#endif

/* Number of packets read so far, over all input files. */
static unsigned int total_packets = 0;

#ifdef HASHTABLE_GLIB
/* g_list_foreach() callback: KEY is one domain name taken from `table`;
   print it together with the request counter the table stores for it.
   USER_DATA is not used. */
static void
print_value(gpointer key, /* @unused@ */ gpointer user_data)
{
    const unsigned int *count;

    (void) user_data;
    count = (const unsigned int *) g_hash_table_lookup(table, key);
    printf("%s: %u requests\n", (const char *) key, *count);
}

/* GCompareFunc for g_list_sort(): LEFT and RIGHT are domain names (keys
   of `table`).  Orders them by decreasing request count.  Names with
   equal counts are ordered alphabetically, so the output does not depend
   on the hash table's iteration order. */
static          gint
compare(gconstpointer left, gconstpointer right)
{
    const unsigned int *left_value = g_hash_table_lookup(table, left);
    const unsigned int *right_value = g_hash_table_lookup(table, right);

    if (*left_value < *right_value) {
        return 1;
    }
    if (*left_value > *right_value) {
        return -1;
    }
    return strcmp((const char *) left, (const char *) right);
}
#elif  HASHTABLE_JUDY
#elif  HASHTABLE_HAMSTERDB
/* Return the counter stored in the hamsterdb `table` for DOMAIN, i.e. the
   number of NXDOMAIN responses main() recorded for it.  Returns 0 when
   DOMAIN is not in the table, or when the lookup fails (the failure is
   reported on stderr).

   The previous version returned the stored value + 1, which
   over-counted every domain by one. */
static unsigned int
occurrences_of(const char *domain)
{
    ham_key_t       key;
    ham_record_t    record;
    ham_status_t    status;
    unsigned int    result = 0;

    memset(&key, 0, sizeof(key));
    memset(&record, 0, sizeof(record));
    key.data = (char *) domain;
    key.size = strlen(domain) + 1;      /* +1 for the terminating zero-byte */
    status = ham_find(table, NULL, &key, &record, 0);
    if (status == HAM_SUCCESS) {
        /* Copy the counter out instead of dereferencing a pointer into
           hamsterdb's buffer, which need not be suitably aligned. */
        if (record.size >= sizeof result) {
            memcpy(&result, record.data, sizeof result);
        }
    } else if (status != HAM_KEY_NOT_FOUND) {
        fprintf(stderr, "ham_find failed: %d (%s)\n", status,
                ham_strerror(status));
    }
    return result;
}

/* qsort() comparator over an array of `char *` domain names.  Sorts by
   decreasing occurrences_of() count, then alphabetically.  qsort() is
   not stable, so without the tie-break domains with the same count would
   come out in an unspecified order. */
static int
compare_by_occurrences(const void *lhs, const void *rhs)
{
    const char     *lname = *(char *const *) lhs;
    const char     *rname = *(char *const *) rhs;
    unsigned int    nlhs = occurrences_of(lname);
    unsigned int    nrhs = occurrences_of(rname);

    if (nlhs < nrhs)
        return +1;
    if (nrhs < nlhs)
        return -1;
    return strcmp(lname, rname);
}
#elif HASHTABLE_BERKELEYDB
/* Look DOMAIN up in the Berkeley DB `table` and return the counter stored
   for it.  Any failure of table->get(), including a missing key, is
   reported through table->err() and yields 0. */
static unsigned int
occurrences_of(const char *domain)
{
    unsigned int    count;
    DBT             dkey;
    DBT             drec;
    int             rc;

    memset(&dkey, 0, sizeof(dkey));
    memset(&drec, 0, sizeof(drec));

    /* Keys are stored with their terminating NUL (see main()). */
    dkey.data = (char *) domain;
    dkey.size = strlen(domain) + 1;

    /* Ask Berkeley DB to copy the value straight into `count`. */
    drec.data = &count;
    drec.ulen = sizeof count;
    drec.size = drec.ulen;
    drec.flags = DB_DBT_USERMEM;

    rc = table->get(table, NULL, &dkey, &drec, 0);
    if (rc != 0) {
        table->err(table, rc, "table->get failed for %s", domain);
        return 0;
    }
    return count;
}

/* qsort() comparator over an array of `char *` domain names.  Sorts by
   decreasing occurrences_of() count, then alphabetically.  qsort() is
   not stable, so without the tie-break domains with the same count would
   come out in an unspecified order. */
static int
compare_by_occurrences(const void *lhs, const void *rhs)
{
    const char     *lname = *(char *const *) lhs;
    const char     *rname = *(char *const *) rhs;
    unsigned int    nlhs = occurrences_of(lname);
    unsigned int    nrhs = occurrences_of(rname);

    if (nlhs < nrhs)
        return +1;
    if (nrhs < nlhs)
        return -1;
    return strcmp(lname, rname);
}
#endif

/* Reduce DOMAIN, a dotted name such as "www.example.com", to the part
   under which NXDOMAIN answers are aggregated.  That is its last two
   labels ("example.com"), or its last three when the name sits below one
   of the .fr second-level zones listed below ("foo.asso.fr").

   Returns a newly malloc'ed string owned by the caller, or NULL when
   DOMAIN has fewer than two labels, has a label longer than MAX_LABEL
   characters, or has more than MAX_NUM_LABELS labels.  The last two
   cases also print a warning with the current packet number.

   Fixes over the previous version:
   - Every label is now length-checked and label buffers have room for
     the NUL, so a long or 63-octet label no longer overflows.
   - labels[] can no longer overflow.
   - No memory is leaked on any return path. */
static /* @null@ */ char *
reg_domain(char *domain)
{
    char           *labels[MAX_NUM_LABELS];     /* labels[0] is the rightmost */
    char           *result = NULL;
    unsigned int    count = 0;  /* labels[0 .. count-1] are allocated */
    unsigned int    x = 0;      /* length of the label being scanned */
    unsigned int    cut, k;
    int             i;

    assert(domain != NULL);
#ifdef DEBUG
    printf("\tDEBUG: reg_domain(%s)\n", domain);
#endif
    /* Scan right to left; a label ends at a dot or at the start of the
       name. */
    for (i = (int) strlen(domain) - 1; i >= 0; i--) {
#ifdef DEBUG
        printf("\tDEBUG: %d, %u, %c\n", i, x, domain[i]);
#endif
        if (domain[i] != '.' && i != 0) {
            x = x + 1;
            continue;
        }
        if (i == 0) {
            /* The leftmost label includes the character at index 0. */
            cut = 0;
            x = x + 1;
        } else {
            cut = (unsigned int) i + 1;
        }
        if (x > MAX_LABEL) {
            fprintf(stderr,
                    "WARNING: invalid packet %u: label %u is too long\n",
                    total_packets, count);
            goto cleanup;
        }
        if (count >= MAX_NUM_LABELS) {
            fprintf(stderr,
                    "WARNING: invalid packet %u: too many labels\n",
                    total_packets);
            goto cleanup;
        }
        labels[count] = malloc(MAX_LABEL + 1);  /* +1 for the NUL */
        assert(labels[count] != NULL);
        memcpy(labels[count], domain + cut, (size_t) x);
        labels[count][x] = '\0';
#ifdef DEBUG
        fprintf(stdout, "\t\tDEBUG: \"%s\" %u (%u)\n", labels[count], cut, x);
#endif
        count++;
        x = 0;
    }
    if (count < 2) {
        goto cleanup;           /* a bare TLD, or an empty name */
    }
    result = malloc(MAX_NAME);
    assert(result != NULL);
    if (count >= 3 && strcmp(labels[0], "fr") == 0 &&
        (strcmp(labels[1], "cci") == 0 ||
         strcmp(labels[1], "asso") == 0 ||
         strcmp(labels[1], "gouv") == 0 ||
         strcmp(labels[1], "nom") == 0 ||
         strcmp(labels[1], "tm") == 0 ||
         strcmp(labels[1], "com") == 0)) {
        snprintf(result, MAX_NAME, "%s.%s.%s", labels[2], labels[1],
                 labels[0]);
    } else {
        snprintf(result, MAX_NAME, "%s.%s", labels[1], labels[0]);
    }
#ifdef DEBUG
    printf("DEBUG: %s -> %s\n", domain, result);
#endif
  cleanup:
    for (k = 0; k < count; k++) {
        free(labels[k]);
    }
    return result;
}

int
main(int argc, char *argv[])
{
    char           *filename, errbuf[PCAP_ERRBUF_SIZE];
    pcap_t         *handle;
    const uint8_t  *packet;     /* The actual packet */
    struct pcap_pkthdr header;  /* The header that pcap gives us */
    const struct sniff_ethernet *ethernet;      /* The ethernet header */
    const struct sniff_ipv4 *ipv4;      /* The IPv4 header */
    const struct sniff_ipv6 *ipv6;      /* The IPv6 header */
    const struct sniff_udp *udp;        /* The UDP header */
    const struct sniff_dns *dns;
    const uint8_t  *qsection;

    unsigned int    i, filenum;

    unsigned short  ip_version;
    u_int           size_ip;
    u_short         source_port, dest_port;
    uint8_t         labelsize;
    const uint8_t  *nameptr;
    char           *fqdn, *interesting_domain;
    unsigned int    total_domains = 0;
    unsigned int    nxdomain_packets = 0;
    unsigned int   *requests;
#ifdef HASHTABLE_GLIB
    GList          *name_list, *sorted_name_list;
    gpointer       *result;
#elif HASHTABLE_JUDY
    uint8_t        *key;
    Word_t         *result, *value;
#elif HASHTABLE_HAMSTERDB
    ham_key_t       key;
    ham_record_t    record;
    ham_status_t    status;
    ham_cursor_t   *cursor;
#elif HASHTABLE_BERKELEYDB
    DBT             key;
    DBT             record;
    DBC            *cursor;
    unsigned int    record_data;
    int             status;
#endif

    if (argc < 2) {
        fprintf(stderr, "Usage: readfile filename(s)\n");
        return (2);
    }
#ifdef HASHTABLE_GLIB
    table = g_hash_table_new(g_str_hash, g_str_equal);
#elif HASHTABLE_JUDY
    table = (Pvoid_t) NULL;
#elif HASHTABLE_HAMSTERDB
    status = ham_new(&table);
    if (status != HAM_SUCCESS) {
        fprintf(stderr, "ham_new: %d", status);
        return (2);
    }
    status = ham_create(table, NULL, HAM_IN_MEMORY_DB, 0);
    if (status != HAM_SUCCESS) {
        fprintf(stderr, "ham_create: %d", status);
        return (2);
    }
#elif HASHTABLE_BERKELEYDB
    memset(&key, 0, sizeof key);
    memset(&record, 0, sizeof record);
    record.data = &record_data;
    record.size = record.ulen = sizeof record_data;
    record.flags = DB_DBT_USERMEM;
    if ((status = db_create(&table, NULL, 0))) {
        fprintf(stderr, "db_create: %d: %s\n", status, db_strerror(status));
      err:
        if (table && (status = table->close(table, 0)))
            table->err(table, status, "table->close(...) failed");
        return 2;
    }
    if ((status =
         table->set_cachesize(table, CACHE_GIGAS, CACHE_BYTES, CACHE_COUNT)))
        table->err(table, status, "table->set_cachesize failed, continuing");
    if ((status = table->open(table, NULL, NULL, "table", db_type, DB_CREATE, 0))) {
        table->err(table, status, "table->open(...) failed");
        goto err;
    }
#endif
    for (filenum = 1; filenum < argc; filenum++) {
        filename = argv[filenum];
        handle = pcap_open_offline(filename, errbuf);
        if (handle == NULL) {
            fprintf(stderr, "Couldn't open file %s: %s\n", filename, errbuf);
            return (2);
        }
#ifdef DEBUG
        printf("DEBUG: file %i (%s), we are at %u packets\n", filenum, filename,
               total_packets);
#endif
        for (;;) {
            /* Grab a packet */
            packet = (uint8_t *) pcap_next(handle, &header);
            if (packet == NULL) {       /* End of file */
                break;
            }
            total_packets++;
#ifdef DEBUG
            printf("DEBUG: packet %u of size %u (%u in the file)\n", total_packets,
                   header.len, header.caplen);
#endif
            ethernet = (struct sniff_ethernet *) (packet);
            if (ntohs(ethernet->ether_type) == IPv6_ETHERTYPE) {
                ipv6 = (struct sniff_ipv6 *) (packet + SIZE_ETHERNET);
                assert(IP_VERSION(ipv6) == 6);
                ip_version = 6;
                size_ip = SIZE_IPv6;
            } else if (ntohs(ethernet->ether_type) == IPv4_ETHERTYPE) {
                ipv4 = (struct sniff_ipv4 *) (packet + SIZE_ETHERNET);
                assert(IP_V(ipv4) == 4);
                ip_version = 4;
                size_ip = IP_HL(ipv4) * 4;
            } else {
                ip_version = 0;
            }
            assert(ip_version != 0);
            if ((ip_version == 6 && ipv6->ip_nxt == UDP)
                || (ip_version == 4 && ipv4->ip_p == UDP)) {
                udp = (struct sniff_udp *) (packet + SIZE_ETHERNET + size_ip);
                source_port = ntohs(udp->sport);
                dest_port = ntohs(udp->dport);
                if (source_port == DNS_PORT || dest_port == DNS_PORT) {
                    dns =
                        (struct sniff_dns
                         *) (packet + SIZE_ETHERNET + size_ip + SIZE_UDP);
                    if (DNS_RCODE(dns) == NXDOMAIN) {
                        nxdomain_packets++;
                        qsection = (uint8_t *) (packet +
                                                SIZE_ETHERNET +
                                                size_ip + SIZE_UDP + SIZE_DNS);
                        fqdn = malloc(MAX_NAME);
                        assert(fqdn != NULL);
                        fqdn[0] = '\0';
                        for (nameptr = qsection; nameptr != NULL;) {
                            labelsize = (uint8_t) * nameptr;
                            if (labelsize == 0) {       /* End of name */
                                nameptr = NULL;
                            } else {
                                if (strlen(fqdn) == 0) {
                                    strncpy(fqdn, (char *)
                                            nameptr + 1, labelsize);
                                    fqdn[labelsize] = '\0';
                                } else {
                                    fqdn = strncat(fqdn, ".", 1);
                                    fqdn = strncat(fqdn, (char *)
                                                   nameptr + 1, labelsize);
                                }
                                nameptr = nameptr + labelsize + 1;
                            }
                        }
                        for (i = 0; i < strlen(fqdn); i++) {
                            fqdn[i] = (char) tolower((int) fqdn[i]);
                        }
                        interesting_domain = reg_domain(fqdn);
                        if (interesting_domain == NULL) {
                            continue;
                        }
#ifdef HASHTABLE_GLIB
                        result = g_hash_table_lookup(table, interesting_domain);
                        if (result == NULL) {
                            requests = malloc(sizeof(unsigned int));
                            assert(requests != NULL);
#ifdef DEBUG
                            printf("DEBUG: inserting \"%s\" (from \"%s\")\n",
                                   interesting_domain, fqdn);
#endif
                            total_domains++;
                            *requests = 1;
                            g_hash_table_insert(table,
                                                interesting_domain,
                                                (gpointer) requests);
                        } else {
                            (*(unsigned int *) result)++;
                        }
#elif HASHTABLE_JUDY
                        JSLG(result, table, (const uint8_t *) interesting_domain);
                        if (result == NULL) {
                            total_domains++;
                            requests = malloc(sizeof(unsigned int));
                            assert(requests != NULL);
                            JSLI(requests, table,
                                 (const uint8_t *) interesting_domain);
                            *requests = 1;
                            total_domains++;
                        } else {
                            requests = (unsigned int *) result;
                            *requests = (*requests) + 1;
                        }
#elif HASHTABLE_HAMSTERDB
                        memset(&key, 0, sizeof(key));
                        memset(&record, 0, sizeof(record));
                        key.data = interesting_domain;
                        key.size = strlen(key.data) + 1;        /* +1 for the
                                                                 * terminating
                                                                 * zero-byte */
                        record.data = "";
                        if ((status =
                             ham_find(table, NULL, &key, &record,
                                      0)) != HAM_SUCCESS) {
                            if (status == HAM_KEY_NOT_FOUND) {
                                total_domains++;
                                requests = malloc(sizeof(unsigned int));
                                assert(requests != NULL);
                                *requests = 1;
                            } else {
                                printf("ham_find failed: %d (%s)\n", status,
                                       ham_strerror(status));
                                return (2);
                            }
                        } else {
                            *requests = *((unsigned int *) record.data) + 1;
                        }
                        record.data = requests;
                        record.size = sizeof(*requests);
                        if ((status =
                             ham_insert(table, NULL, &key, &record,
                                        HAM_OVERWRITE)) != HAM_SUCCESS) {
                            fprintf(stderr, "ham_insert failed: %d (%s)\n", status,
                                    ham_strerror(status));
                            return (2);
                        }
#elif HASHTABLE_BERKELEYDB
                        key.data = interesting_domain;
                        key.size = strlen(key.data) + 1;
                        status = table->get(table, 0, &key, &record, 0);
                        switch (status) {
                        case 0:
                            record_data++;
                            break;
                        case DB_NOTFOUND:
                            record_data = 1;
                            total_domains++;
                            break;
                        default:
                            table->err(table, status, "table->get failed");
                            goto err;
                        }
                        if ((status = table->put(table, 0, &key, &record, 0))) {
                            table->err(table, status, "table->put failed");
                            goto err;
                        }
#endif
                    }
                }
            }
        }
        pcap_close(handle);
    }
    fprintf(stderr,
            "%i DNS packets handled, %i NXDOMAIN responses, a total of %i \"not found\" domains were found\n",
            total_packets, nxdomain_packets, total_domains);
#ifdef HASHTABLE_GLIB
    name_list = g_hash_table_get_keys(table);
    sorted_name_list = g_list_sort(name_list, compare);
    g_list_foreach(sorted_name_list, print_value, NULL);
#elif HASHTABLE_JUDY
    /* TODO: sort */
    key = (uint8_t *) malloc(MAX_KEY);
    value = NULL;
    JSLF(value, table, key);
    while (value != NULL) {
        printf("%s: %u\n", (char *) key, *((unsigned int *) value));
        JSLN(value, table, key);
    }
#elif HASHTABLE_HAMSTERDB
    if ((status = ham_cursor_create(table, NULL, 0, &cursor)) != HAM_SUCCESS) {
        fprintf(stderr, "ham_cursor_create: %d (%s)\n", status,
                ham_strerror(status));
        return (2);
    }
    sorted_array = malloc(total_domains * sizeof(char *));
    for (i = 0; i < total_domains; i++) {
        if ((status =
             ham_cursor_move(cursor, &key, &record,
                             HAM_CURSOR_NEXT)) != HAM_SUCCESS) {
            if (status == HAM_KEY_NOT_FOUND) {
                assert(i < total_domains);
                break;          // reached the end of the database
            } else {
                fprintf(stderr, "ham_cursor_move: %d (%s) for %ith domain\n", status,
                        ham_strerror(status), i);
                return (2);
            }
        } else {
            /* Key found, OK */
        }
        sorted_array[i] = malloc(key.size + 1);
        strcpy(sorted_array[i], (char *) key.data);
    }
    qsort(sorted_array, total_domains, sizeof(char *), compare_by_occurrences);
    for (i = 0; i < total_domains; i++) {
        printf("%s: %i\n", sorted_array[i], occurrences_of(sorted_array[i]));
    }
#elif HASHTABLE_BERKELEYDB
    sorted_array = malloc(total_domains * sizeof(char *));
    if (!sorted_array) {
        perror("malloc failed");
        goto err;
    }
    if ((status = table->cursor(table, 0, &cursor, 0))) {
        table->err(table, status, "table->cursor failed");
        goto err;
    }
    for (i = 0; i < total_domains; i++) {
        if ((status = cursor->get(cursor, &key, &record, DB_NEXT))) {
            table->err(table, status, "cursor->get failed");
            goto err;
        }
        sorted_array[i] = malloc(key.size);
        assert(sorted_array[i]);
        memcpy(sorted_array[i], key.data, key.size);
    }
    assert(cursor->get(cursor, &key, &record, DB_NEXT) == DB_NOTFOUND);
    if ((status = cursor->close(cursor)))
        table->err(table, status, "cursor->close failed, continuing");
    qsort(sorted_array, total_domains, sizeof(char *), compare_by_occurrences);
    for (i = 0; i < total_domains; i++)
        printf("%s: %u requests\n", sorted_array[i],
               occurrences_of(sorted_array[i]));
    if ((status = table->close(table, 0)))
        table->err(table, status, "table->close(...) failed, continuing");
#endif
    return (0);
}
