#!/usr/bin/env python

import sys
import matplotlib.pyplot

import dpa_traces
import dpa_des


# Speed up the attack by processing only a window of each trace:
# samples start_sample..end_sample, both ends included (hence the +1).
start_sample = 4000
end_sample = 8000  # (alternative value: 20002)
num_samples = end_sample - start_sample + 1


# accumulator[sbox][key_hyp][hd][sample]
# 8 S-boxes x 64 key hypotheses x 5 classes (a 4-bit Hamming distance is 0..4).
accumulator = [
    [[[0.0] * num_samples for _hd in range(5)] for _key in range(64)]
    for _sbox in range(8)
]

# acc_num_traces[sbox][key_hyp][hd]: number of traces added to each class
acc_num_traces = [
    [[0] * 5 for _key in range(64)]
    for _sbox in range(8)
]

# mean_traces[sbox][key_hyp][hd][sample]
mean_traces = [
    [[[0.0] * num_samples for _hd in range(5)] for _key in range(64)]
    for _sbox in range(8)
]

# diff_traces[sbox][key_hyp][sample]
diff_traces = [
    [[0.0] * num_samples for _key in range(64)]
    for _sbox in range(8)
]

# max_diff_traces[sbox][key_hyp]
max_diff_traces = [[0.0] * 64 for _sbox in range(8)]


# Display a trace
def display_trace(data):
    """Plot a single trace (a sequence of sample values) and show the figure."""
    pyplot = matplotlib.pyplot
    pyplot.plot(data)
    pyplot.show()


# Display all the 64 differential traces for a given SBox
# and highlight one of them (in green)
def display_diff_trace(sbox, highlighted_key):
    """Plot every differential trace of one S-box, highlighting one key in green.

    Args:
        sbox: S-box index (0..7) into the module-level ``diff_traces``.
        highlighted_key: key hypothesis drawn in green.  A value outside the
            range of available hypotheses highlights nothing (all traces red).
    """
    fig = matplotlib.pyplot.figure()
    ax = fig.add_subplot(111)

    traces = diff_traces[sbox]
    for key, trace in enumerate(traces):
        if key != highlighted_key:
            ax.plot(trace, 'r')

    # Draw the highlighted trace last and above the others: plotting it in
    # loop order let every later red trace paint over it and hide it.
    if 0 <= highlighted_key < len(traces):
        ax.plot(traces[highlighted_key], 'g', zorder=3)

    matplotlib.pyplot.show()


# For a given SBox (0..7) and a given key hypothesis (0..63),
# compute the leakage model (here a HD on 4 bits) for the 
# encryption of the given plaintext (trace.plaintext) and 
# accumulate the trace in the corresponding set:
# accumulator[sbox][key_hypothesis][hd]
def accumulate(trace, sbox, key_hypothesis):
    """Add *trace* to the set selected by the leakage model -- NOT IMPLEMENTED.

    Intended contract (from the original specification comments): for S-box
    *sbox* (0..7) and key hypothesis *key_hypothesis* (0..63), compute the
    leakage model -- a Hamming distance on 4 bits, i.e. a class hd in 0..4 --
    for the encryption of ``trace.plaintext``, and accumulate the trace in
    accumulator[sbox][key_hypothesis][hd] (counting it in
    acc_num_traces[sbox][key_hypothesis][hd]).

    NOTE(review): presumably only the window start_sample..end_sample
    (inclusive) of ``trace.samples`` is accumulated, and the HD computation
    relies on helpers from ``dpa_des``, whose API is not visible here -- confirm.
    """
    # TODO: implement.  As written this is a no-op, so every accumulated
    # set stays empty and the downstream statistics are all zero.
    return None

# When the accumulation of all the traces is done, compute the mean
# of all the sets of traces for all sbox and key hypothesis:
# mean_traces[sbox][key][hd]
def compute_mean_traces():
    """Turn the accumulated sums into per-class mean traces.

    For every S-box, key hypothesis and Hamming-distance class:
        mean_traces[sbox][key][hd][i] =
            accumulator[sbox][key][hd][i] / acc_num_traces[sbox][key][hd]

    A class that received no trace gets an all-zero mean of length
    ``num_samples`` instead of raising ZeroDivisionError.  Rare classes can
    easily be empty when only a few traces are used.
    """
    global mean_traces, accumulator, acc_num_traces, num_samples

    for sbox, sbox_sums in enumerate(accumulator):
        for key, key_sums in enumerate(sbox_sums):
            for hd, sums in enumerate(key_sums):
                count = acc_num_traces[sbox][key][hd]
                if count:
                    # float() keeps true division even if the sums are ints (Python 2).
                    mean = [float(total) / count for total in sums]
                else:
                    mean = [0.0] * num_samples
                mean_traces[sbox][key][hd] = mean

# For all Sbox and key hypothesis, compute the differential trace:
# diff_traces[sbox][key]
def compute_diff_traces():
    """Compute the differential trace diff_traces[sbox][key] for all S-boxes/keys.

    Distinguisher: the difference of means between the two extreme
    Hamming-distance classes, i.e. mean of the last class (HD = 4) minus mean
    of the first class (HD = 0).  These two classes are the most separated in
    leakage, so the correct key hypothesis should show the largest peak.

    If either extreme class received no trace (acc_num_traces is 0), the
    difference would just be the other class's mean and could masquerade as a
    peak.  Such differential traces are set to all zeros instead.
    """
    global diff_traces, mean_traces, acc_num_traces, num_samples

    for sbox, (sbox_means, sbox_counts) in enumerate(zip(mean_traces, acc_num_traces)):
        for key, (class_means, class_counts) in enumerate(zip(sbox_means, sbox_counts)):
            if class_counts[0] == 0 or class_counts[-1] == 0:
                diff_traces[sbox][key] = [0.0] * num_samples
                continue
            low, high = class_means[0], class_means[-1]
            diff_traces[sbox][key] = [h - l for h, l in zip(high, low)]

# For all Sbox and key hypothesis, find the maximum of the differential trace:
# max_diff_traces[sbox][key]
def compute_max_diff_traces():
    """Store the peak of every differential trace in max_diff_traces[sbox][key].

    The peak is taken as the largest absolute sample value, so the attack
    works whatever the sign of the leakage in the measured traces.  An empty
    differential trace yields 0.0.
    """
    global max_diff_traces, diff_traces, num_samples

    for sbox, sbox_diffs in enumerate(diff_traces):
        for key, diff in enumerate(sbox_diffs):
            max_diff_traces[sbox][key] = max(map(abs, diff)) if diff else 0.0

# Display the result (for all sbox, find and display which key hypothesis gives
# the maximum value of the differential trace)
def display_key_results():
    """Print, for every S-box, the key hypothesis with the largest DPA peak.

    Reads the module-level ``max_diff_traces`` (indexed [sbox][key]).  On ties,
    the lowest key hypothesis wins, because ``max`` keeps the first maximum.

    Returns:
        The list of best key hypotheses, one per S-box, in S-box order.
        Callers that ignore the result (such as ``main``) are unaffected.
    """
    global max_diff_traces

    best_keys = []
    for sbox, peaks in enumerate(max_diff_traces):
        best_key = max(range(len(peaks)), key=peaks.__getitem__)
        print('SBox %i: key hypothesis %2i (0x%02x), peak %f'
              % (sbox, best_key, best_key, peaks[best_key]))
        best_keys.append(best_key)
    return best_keys

def main():
    """Run the DPA attack from the command line.

    Usage: dpa.py trace_directory number_of_traces_to_use

    Feeds up to ``number_of_traces_to_use`` traces from the directory into
    ``accumulate`` for every S-box (0..7) and key hypothesis (0..63), then
    computes the means, differential traces and peaks and prints the key
    results.  If the directory holds fewer traces than requested, all
    available traces are used; previously the analysis was silently skipped.
    Exits with a message on invalid arguments or when no trace is found.
    """
    usage = "Usage: dpa.py trace_directory number_of_traces_to_use"
    if len(sys.argv) < 3:
        sys.exit(usage)

    # Validate the arguments before touching the trace directory.  A plain
    # assert would be stripped under `python -O`.
    try:
        num_traces = int(sys.argv[2])
    except ValueError:
        sys.exit("number_of_traces_to_use must be an integer\n" + usage)
    if num_traces <= 0:
        sys.exit("number_of_traces_to_use must be positive\n" + usage)

    trace_dir = dpa_traces.Trace_dir(sys.argv[1])

    trace_num = 0
    for trace in trace_dir:
        trace_num += 1
        if trace_num % 20 == 0:
            print('Trace #%i...' % trace_num)

        # display_trace(trace.samples)  # uncomment to inspect each trace

        for sbox in range(8):
            for key_hypothesis in range(64):
                accumulate(trace, sbox, key_hypothesis)

        if trace_num == num_traces:
            break

    if trace_num == 0:
        sys.exit("No traces found in %s" % sys.argv[1])
    if trace_num < num_traces:
        print('Only %i traces available, using all of them' % trace_num)

    compute_mean_traces()
    compute_diff_traces()
    compute_max_diff_traces()
    display_key_results()



if __name__ == '__main__':
    # Run the attack only when executed as a script, not when imported.
    main()
