#!/usr/bin/env python

def ROL64(a, n):
    """Rotate the 64-bit word `a` left by `n` bit positions.

    The rotation amount is taken modulo 64, and the result is reduced
    to the range [0, 2**64).
    """
    shift = n % 64
    # Bits pushed off the top end re-enter at the bottom.
    wrapped = a >> (64 - shift)
    shifted = a << shift
    return (wrapped + shifted) & 0xFFFFFFFFFFFFFFFF

def KeccakP1600onLanes(lanes, nrRounds):
    """Apply the last `nrRounds` of the 24 Keccak-p[1600] rounds.

    `lanes` is a 5x5 nested list of 64-bit integers indexed as
    lanes[x][y]. All 24 round constants are always generated, but the
    round function itself is only applied when
    `roundIndex + nrRounds >= 24`.

    Returns the resulting 5x5 nested list. Whenever at least one round
    is applied, the θ step rebuilds the nested list, so the lists passed
    in by the caller are not modified; if no round is applied the input
    object is returned as is.
    """
    lfsr = 1
    for roundIndex in range(24):
        # The LFSR is stepped 7 times per round whether or not the round
        # is applied, so that skipped rounds still consume their constants.
        roundConstant = 0
        for j in range(7):
            lfsr = ((lfsr << 1) ^ ((lfsr >> 7) * 0x71)) % 256
            if lfsr & 2:
                roundConstant ^= 1 << ((1 << j) - 1)
        if roundIndex + nrRounds < 24:
            continue

        # θ
        parity = []
        for x in range(5):
            acc = lanes[x][0]
            for y in range(1, 5):
                acc ^= lanes[x][y]
            parity.append(acc)
        mix = [parity[(x + 4) % 5] ^ ROL64(parity[(x + 1) % 5], 1)
               for x in range(5)]
        lanes = [[lanes[x][y] ^ mix[x] for y in range(5)] for x in range(5)]

        # ρ and π
        col, row = 1, 0
        carried = lanes[col][row]
        offset = 0
        for t in range(24):
            # Rotation offsets are the triangular numbers 1, 3, 6, ...
            offset += t + 1
            col, row = row, (2 * col + 3 * row) % 5
            displaced = lanes[col][row]
            lanes[col][row] = ROL64(carried, offset)
            carried = displaced

        # χ
        for row in range(5):
            plane = [lanes[col][row] for col in range(5)]
            for col in range(5):
                nextLane = plane[(col + 1) % 5]
                afterNext = plane[(col + 2) % 5]
                lanes[col][row] = plane[col] ^ ((~nextLane) & afterNext)

        # ι
        lanes[0][0] = lanes[0][0] ^ roundConstant
    return lanes

def load64(b):
    """Combine the first eight items of `b` into one integer, little-endian.

    Item 0 is the least significant byte. Items beyond index 7 are not
    read; indexing a shorter sequence fails on the missing position.
    """
    value = 0
    shift = 0
    for index in range(8):
        value += b[index] << shift
        shift += 8
    return value

def store64(a):
    """Serialize the low 64 bits of integer `a` as 8 little-endian bytes.

    Returns a `bytes` object of length 8; bits above position 63 are
    discarded.
    """
    return bytes([(a >> shift) & 0xFF for shift in range(0, 64, 8)])

def KeccakP1600(state, nrRounds):
    """Run the Keccak-p[1600] permutation on a byte-oriented state.

    `state` is sliced into 25 lanes of 8 bytes each; lane (x, y) starts
    at byte offset 8*(x + 5*y). The lanes are permuted with
    `KeccakP1600onLanes(lanes, nrRounds)` and serialized back in the
    same order. Returns a new `bytearray`.
    """
    lanes = []
    for x in range(5):
        column = []
        for y in range(5):
            offset = 8 * (x + 5 * y)
            column.append(load64(state[offset:offset + 8]))
        lanes.append(column)

    lanes = KeccakP1600onLanes(lanes, nrRounds)

    # Write lanes back row by row (y outer, x inner) to match the offsets above.
    permuted = bytearray()
    for y in range(5):
        for x in range(5):
            permuted += store64(lanes[x][y])
    return permuted

def F(inputBytes, delimitedSuffix, outputByteLen):
    """Sponge function over 12-round Keccak-p[1600] with a 168-byte rate.

    Parameters:
        inputBytes: bytes-like message to absorb (indexing must yield ints).
        delimitedSuffix: integer byte XORed into the state right after the
            last absorbed input byte, before the final 0x80 padding bit.
        outputByteLen: number of output bytes to squeeze.

    Returns:
        A `bytes` object of length `outputByteLen` (empty when
        `outputByteLen` is not positive).
    """
    rateInBytes = 1344 // 8
    state = bytearray(200)
    blockSize = 0
    inputOffset = 0
    # === Absorb all the input blocks ===
    while inputOffset < len(inputBytes):
        blockSize = min(len(inputBytes) - inputOffset, rateInBytes)
        for i in range(blockSize):
            state[i] ^= inputBytes[i + inputOffset]
        inputOffset += blockSize
        if blockSize == rateInBytes:
            state = KeccakP1600(state, 12)
            # A full block was consumed, so padding starts at offset 0.
            blockSize = 0
    # === Do the padding and switch to the squeezing phase ===
    state[blockSize] ^= delimitedSuffix
    # If the suffix already occupies the top bit of the last rate byte,
    # the closing 0x80 bit has to go into a fresh block.
    if (delimitedSuffix & 0x80) != 0 and blockSize == rateInBytes - 1:
        state = KeccakP1600(state, 12)
    state[rateInBytes - 1] ^= 0x80
    state = KeccakP1600(state, 12)
    # === Squeeze out all the output blocks ===
    # Blocks are collected and joined once to avoid quadratic re-copying.
    outputBlocks = []
    while outputByteLen > 0:
        blockSize = min(outputByteLen, rateInBytes)
        outputBlocks.append(bytes(state[:blockSize]))
        outputByteLen -= blockSize
        if outputByteLen > 0:
            state = KeccakP1600(state, 12)
    return b''.join(outputBlocks)

def length_encode(x):
    """Encode integer `x` as big-endian bytes followed by a length byte.

    The result is the minimal base-256 big-endian representation of `x`
    (empty when the loop never runs, e.g. for 0) with one trailing byte
    giving the number of preceding bytes.
    """
    digits = []
    remaining = x
    while remaining > 0:
        remaining, low = divmod(remaining, 256)
        digits.append(low)
    # Digits were produced least-significant first; flip to big-endian.
    digits.reverse()
    return bytes(digits) + bytes([len(digits)])

def KangarooTwelve(inputMessage, customizationString, outputByteLen):
    """Hash `inputMessage` with `customizationString` into `outputByteLen` bytes.

    The message, the customization string and the length encoding of the
    customization string are concatenated and cut into 8192-byte chunks.
    A single chunk is hashed directly with F and suffix 0x07. Otherwise
    every chunk after the first is reduced to a 32-byte chaining value
    (suffix 0x0B), and the first chunk plus those values form the final
    node, hashed with suffix 0x06.
    """
    chunkLen = 8192
    chainingValueLen = 256 // 8
    S = inputMessage + customizationString + length_encode(len(customizationString))
    # === Cut the input string into chunks of chunkLen bytes ===
    chunks = [bytes(S[start:start + chunkLen])
              for start in range(0, len(S), chunkLen)]
    if len(chunks) == 1:
        # === Process the tree with only a final node ===
        return F(chunks[0], 0x07, outputByteLen)
    # === Process the tree with kangaroo hopping ===
    leaves = chunks[1:]
    chainingValues = [F(leaf, 0x0B, chainingValueLen) for leaf in leaves]
    finalNode = b''.join(
        [chunks[0], b'\x03\x00\x00\x00\x00\x00\x00\x00']
        + chainingValues
        + [length_encode(len(leaves)), b'\xFF\xFF']
    )
    return F(finalNode, 0x06, outputByteLen)
