#!/usr/bin/env python3

import sys
import matplotlib.pyplot
import numpy

from traces import *
from tqdm import tqdm
from des import des_first_round_hd_4bits

# Observation window applied to every trace, as sample indices. Both bounds
# are inclusive, so the window holds num_samples = end_sample - start_sample + 1
# samples. The original traces have a length of 20002 samples. If you can
# locate the leakage more precisely, narrow the window further to speed up
# the attack.
start_sample = 4000
# NOTE(review): the full-trace hint used to read "20002", but with an inclusive
# bound the last index of a 20002-sample trace is 20001 -- confirm the real
# trace length before widening the window to the whole trace.
end_sample = 8000
num_samples = 1 + end_sample - start_sample


# ==============================================================================
# Global data structures
# ==============================================================================

# Every table below is indexed first by S-box (8 of them), then by 6-bit key
# hypothesis (64 of them) and, where relevant, by Hamming-distance class
# (a HD on 4 bits takes the 5 values 0..4).

# accumulator[sbox][key_hyp][hd]: element-wise sum of the windowed traces
# (num_samples long) that fell into each HD class.
accumulator = [
    [[numpy.zeros(num_samples) for _hd in range(5)] for _key in range(64)]
    for _sbox in range(8)
]

# acc_num_traces[sbox][key_hyp][hd]: how many traces were summed into the
# matching accumulator entry.
acc_num_traces = [
    [[0, 0, 0, 0, 0] for _key in range(64)]
    for _sbox in range(8)
]

# mean_traces[sbox][key_hyp][hd]: mean trace of each set.
mean_traces = [
    [[numpy.zeros(num_samples) for _hd in range(5)] for _key in range(64)]
    for _sbox in range(8)
]

# diff_traces[sbox][key_hyp]: one differential trace per key hypothesis.
diff_traces = [
    [numpy.zeros(num_samples) for _key in range(64)]
    for _sbox in range(8)
]

# max_diff_traces[sbox][key_hyp]: peak value of the matching differential trace.
max_diff_traces = [[0.0 for _key in range(64)] for _sbox in range(8)]

# guessed_key[sbox]: best key hypothesis found for each S-box.
guessed_key = [0] * 8


# ==============================================================================
# Helper functions
# ==============================================================================

# Plot a single trace (for instance measurement.trace, or one of the mean or
# differential traces) in a new matplotlib window.
def display_trace(data):
    pyplot = matplotlib.pyplot
    pyplot.plot(data)
    pyplot.show()


# Display all the differential traces (one per key hypothesis) of a given SBox
# in red and highlight one of them in green, drawn on top of the others.
#   sbox            -- S-box index into diff_traces
#   highlighted_key -- key hypothesis to highlight; a value with no matching
#                      trace simply leaves every curve red
def display_diff_trace(sbox, highlighted_key):
    fig = matplotlib.pyplot.figure()
    ax = fig.add_subplot(111)

    sbox_traces = diff_traces[sbox]

    # Plot the highlighted hypothesis last: curves added later are painted
    # over earlier ones, so plotting it inside the loop let every higher key
    # hypothesis hide it under red.
    for key, trace in enumerate(sbox_traces):
        if key != highlighted_key:
            ax.plot(trace, 'r')
    if 0 <= highlighted_key < len(sbox_traces):
        ax.plot(sbox_traces[highlighted_key], 'g')

    ax.set_title('SBox {} differential traces'.format(sbox))
    ax.set_xlabel('sample index within the observation window')
    matplotlib.pyplot.show()


# =============================================================================
# DPA algorithm. These are the functions you need to implement:
# =============================================================================


# Accumulate one measurement for a single (S-box, key hypothesis) pair.
#
# Intended contract (still to be implemented):
#   * sbox is in 0..7 and key_hypothesis in 0..63;
#   * use des_first_round_hd_4bits to compute the predicted leakage, a Hamming
#     distance on 4 bits (0..4), for the encryption of measurement.plaintext;
#   * add the observation window of measurement.trace to
#     accumulator[sbox][key_hypothesis][hd] and increment
#     acc_num_traces[sbox][key_hypothesis][hd].
# The window must contain num_samples samples, i.e. the inclusive slice
# trace[start_sample:end_sample + 1]; any other length will not broadcast
# against the accumulator arrays.
# NOTE(review): the signature of des_first_round_hd_4bits is not visible here
# -- confirm its argument order before implementing this.
def accumulate(measurement, sbox, key_hypothesis):
    global accumulator, acc_num_traces, start_sample, end_sample

    # TODO: implement me! Until then this is a no-op.
    return None


# When the accumulation of all the traces is done, compute the mean of every
# set of traces, for all S-boxes, key hypotheses and HD classes:
#   mean_traces[sbox][key][hd] =
#       accumulator[sbox][key][hd] / acc_num_traces[sbox][key][hd]
# A set that received no trace (count 0) gets an all-zero mean rather than
# the NaN/inf values a division by zero would produce.
# Reads accumulator and acc_num_traces; writes mean_traces in place.
def compute_mean_traces():
    global mean_traces, accumulator, acc_num_traces

    for sbox_sums, sbox_counts, sbox_means in zip(
            accumulator, acc_num_traces, mean_traces):
        for key_sums, key_counts, key_means in zip(
                sbox_sums, sbox_counts, sbox_means):
            for hd, (total, count) in enumerate(zip(key_sums, key_counts)):
                if count:
                    # A new array: the accumulator itself is left untouched.
                    key_means[hd] = total / count
                else:
                    key_means[hd] = numpy.zeros_like(total)


# For every S-box and key hypothesis, build the differential trace and store
# it in diff_traces[sbox][key] (an array of num_samples values).
# The inputs are the five per-class means mean_traces[sbox][key][0..4]; how
# they are combined into a single differential trace is left to the
# implementer.
def compute_diff_traces():
    global diff_traces, mean_traces, num_samples

    # TODO: implement me! Until then diff_traces is left unchanged.
    return None


# For all S-boxes and key hypotheses, find the peak of the differential trace:
#   max_diff_traces[sbox][key] = max(|diff_traces[sbox][key]|)
# The absolute value is taken because the sign of the peak depends on how the
# differential trace is defined and on the polarity of the leakage, so only
# its magnitude is compared between key hypotheses. An empty window scores 0.0.
# Reads diff_traces; writes max_diff_traces in place.
def compute_max_diff_traces():
    global max_diff_traces, diff_traces, num_samples

    for sbox_diffs, sbox_peaks in zip(diff_traces, max_diff_traces):
        for key, trace in enumerate(sbox_diffs):
            magnitudes = numpy.abs(numpy.asarray(trace))
            # float() stores a plain Python float rather than a numpy scalar.
            sbox_peaks[key] = float(magnitudes.max()) if magnitudes.size else 0.0


# For each S-box, find the key hypothesis that gives the maximum value of the
# differential trace (the largest entry of max_diff_traces[sbox]) and store it
# in guessed_key[sbox]. Ties go to the smallest key hypothesis.
# Reads max_diff_traces; writes guessed_key in place.
def compute_guessed_key():
    global max_diff_traces, guessed_key

    for sbox, peaks in enumerate(max_diff_traces):
        # numpy.argmax returns the first index of the maximum; int() stores a
        # plain Python int rather than a numpy integer.
        guessed_key[sbox] = int(numpy.argmax(peaks))


# ==============================================================================
# Main analysis
# ==============================================================================

# =============================================================================
# Main function
# =============================================================================
def main():
    """Run the DPA attack on traces named on the command line.

    Usage: dpa.py trace_directory number_of_traces_to_use

    Feeds the first ``number_of_traces_to_use`` traces of the directory to
    ``accumulate`` for every (S-box, key hypothesis) pair, runs the analysis
    steps in order and prints the guessed key, one hex byte per S-box.
    Exits with an error message on bad arguments or when no trace could be
    read; warns on stderr when fewer traces than requested were available.
    """
    if len(sys.argv) < 3:
        sys.exit("Usage: dpa.py trace_directory number_of_traces_to_use")

    # Validate the count before touching the trace directory. sys.exit is used
    # instead of `assert`, which is stripped when Python runs with -O.
    try:
        num_traces = int(sys.argv[2])
    except ValueError:
        sys.exit("number_of_traces_to_use must be an integer, got {!r}"
                 .format(sys.argv[2]))
    if num_traces <= 0:
        sys.exit("number_of_traces_to_use must be positive, got {}"
                 .format(num_traces))

    trace_dir = Trace_dir(sys.argv[1])

    print('Partitioning traces...')
    trace_num = 0  # stays 0 if the directory yields no trace at all
    for trace_num, measurement in enumerate(
            tqdm(trace_dir, total=num_traces), start=1):
        # display_trace(measurement.trace)

        for sbox in range(8):
            for key_hypothesis in range(64):
                accumulate(measurement, sbox, key_hypothesis)

        if trace_num == num_traces:
            break

    if trace_num == 0:
        sys.exit("No trace could be read from {}".format(sys.argv[1]))
    if trace_num < num_traces:
        # Not fatal: the attack still runs, just on fewer traces than asked.
        print('Warning: only {} of the {} requested traces were available'
              .format(trace_num, num_traces), file=sys.stderr)

    print('Computing mean traces...')
    compute_mean_traces()
    print('Computing differential traces...')
    compute_diff_traces()
    print('Computing maximum samples...')
    compute_max_diff_traces()
    print('Guessing key...')
    compute_guessed_key()

    print('\nThe guessed key is:')
    hex_key = ['{:02x}'.format(x) for x in guessed_key]
    print(' '.join(hex_key))

    # Uncomment the following line to display the differential traces for the first SBox
    # display_diff_trace(0, guessed_key[0])


if __name__ == "__main__":
    main()
