import argparse
import os.path
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import rc


def smooth_rdd(data):
    """Apply the signed log transform ``sign(x) * log(1 + |x|)`` elementwise.

    Compresses large magnitudes while keeping the sign and mapping 0 to 0.

    Args:
        data: numeric array (or scalar).

    Returns:
        Array of the same shape as ``data`` with the transform applied.
    """
    # log1p(|x|) equals log(1 + |x|) but stays accurate when |x| is tiny,
    # where ``1 + |x|`` would round to exactly 1.
    return np.sign(data) * np.log1p(np.abs(data))


def get_rate(data):
    """Convert the sign of each value into a percentage.

    Negative values map to 0, zero maps to 50 and positive values map to 100.
    When these are averaged over several runs, the result is the percentage
    of runs with a positive value, with zeros counting as half.
    """
    halfway = (np.sign(data) + 1) / 2  # -1 / 0 / +1  ->  0 / 0.5 / 1
    return halfway * 100


def get_loss(data):
    """Return the loss values untouched; losses are plotted on their raw scale."""
    loss = data
    return loss


def get_acc(data):
    """Express accuracy values as percentages by scaling them by 100."""
    percent = data * 100
    return percent


def get_metric(args, data):
    """Select and transform the rows of ``data`` requested by ``args.metric``.

    Row layout of ``data`` along its first axis:
        0: train_loss, 1: test_loss, 2: train_acc, 3: test_acc,
        4: rdd_train,  5: rdd_all,   6: aa,        7: ab,
        8: ca,         9: cb

    Args:
        args: namespace with a ``metric`` attribute, one of
            'rdd', 'rate', 'loss' or 'acc'.
        data: array whose first axis indexes the quantities listed above.

    Returns:
        The selected rows after the metric-specific transform.

    Raises:
        ValueError: if ``args.metric`` is not a supported metric.
    """
    metric = args.metric
    if metric == 'rdd':
        return smooth_rdd(data[4:6])
    if metric == 'rate':
        # NOTE(review): every other metric selects two rows (one per legend
        # entry), but this slice selects only row 8 ("ca"). If "cb" should
        # be plotted as well, this ought to be ``data[8:10]`` -- confirm.
        return get_rate(data[8:9])
    if metric == 'loss':
        return get_loss(data[0:2])
    if metric == 'acc':
        return get_acc(data[2:4])
    # Raise explicitly: an ``assert False`` here is stripped under
    # ``python -O``, which would make the function silently return None.
    raise ValueError(
        f"Unknown metric {metric!r}; expected one of 'rdd', 'rate', 'loss', 'acc'."
    )


def average_rdd(data, window):
    """Downsample the last axis by averaging non-overlapping windows.

    Trailing samples that do not fill a complete window are dropped. All
    leading axes are flattened into one, so the result is always 2-D with
    shape ``(-1, data.shape[-1] // window)``.

    Args:
        data: array whose last axis is averaged over.
        window: number of consecutive samples averaged into one output
            point; must be a positive integer no larger than the length
            of the last axis.

    Returns:
        2-D array of window averages.

    Raises:
        ValueError: if ``window`` is not positive, or is larger than the
            last axis (no complete window would remain).
    """
    data = np.asarray(data)
    if window < 1:
        raise ValueError(f"window must be a positive integer, got {window}")
    original_length = data.shape[-1]
    smoothed_length = original_length // window
    if smoothed_length == 0:
        raise ValueError(
            f"window ({window}) is larger than the series length ({original_length})"
        )
    # Drop the incomplete trailing window so the reshape below is exact.
    data = data[..., :smoothed_length * window]
    data = np.reshape(data, [-1, smoothed_length, window])
    return np.average(data, -1)


def plot_results(args, mean, std, name, legends, colors, window, labels, font_size=24):
    """Plot mean curves with +/- std bands and save them as a PDF.

    Args:
        args: namespace providing ``show_legend``, ``output_dir`` and
            ``experiment_id``.
        mean: sequence of 1-D curves, all of the same length.
        std: sequence of 1-D spreads matching ``mean``; each band spans
            ``mean - std`` to ``mean + std``.
        name: output file stem; the file is written as ``<name>.pdf``.
        legends: legend label per curve.
        colors: colour per curve, used for both the line and its band.
        window: factor used to relabel x ticks; presumably the averaging
            window applied to the curves (as ``main`` does via
            ``average_rdd``), so tick labels read in original iterations.
        labels: ``[x_label, y_label]`` axis titles.
        font_size: font size for legend, axis titles and tick labels.

    The figure is saved to ``<output_dir>/<experiment_id>/<name>.pdf``
    (directories are created if needed) and then closed.
    """
    x_lim = len(mean[0])
    x = np.arange(x_lim)
    fig = plt.figure(figsize=(9, 6))
    plt.xlim([0, x_lim - 1])
    for curve, spread, legend, color in zip(mean, std, legends, colors):
        # Pass ``color`` to the line too; otherwise it takes the default
        # property cycle and does not match its band's colour.
        plt.plot(x, curve, label=legend, color=color, lw=2)
        plt.fill_between(x, curve + spread, curve - spread, color=color, alpha=0.2)
    if args.show_legend:
        plt.legend(framealpha=0.5, prop={'size': font_size}, loc='upper left')
    plt.xlabel(labels[0], fontsize=font_size)
    plt.ylabel(labels[1], fontsize=font_size)
    plt.tick_params(axis='both', which='major', labelsize=font_size)
    # One tick every 10 plotted points, labelled scaled by ``window``.
    ticks = np.arange(0, x_lim, 10)
    plt.xticks(ticks, window * ticks)

    out_folder = os.path.join(args.output_dir, args.experiment_id)
    os.makedirs(out_folder, exist_ok=True)
    fn = os.path.join(out_folder, name + ".pdf")
    plt.savefig(fn, bbox_inches='tight', pad_inches=0.01)
    # Close (not just clear) so repeated calls do not accumulate open figures.
    plt.close(fig)


def main(args):
    """Aggregate the five runs of one experiment and plot the chosen metric.

    For i = 1..5, loads
    ``<log_dir>/<experiment_id>/<experiment_id>_<i>/rdd.npy``, extracts
    ``args.metric`` from each run, and computes the mean and standard
    deviation across runs. Both are downsampled with ``average_rdd`` using
    ``args.window`` and then plotted to
    ``<output_dir>/<experiment_id>/<experiment_id>_<metric>.pdf``.

    Raises:
        ValueError: if ``args.metric`` is not one of
            'rdd', 'rate', 'loss', 'acc'.
    """
    # The y-axis label for each supported metric; also the list of
    # accepted metrics. Validate before any file I/O or global rc changes.
    y_labels = {
        'rdd': 'UDDR',
        'rate': 'Rate (%)',
        'loss': 'Loss',
        'acc': 'Acc (%)',
    }
    if args.metric not in y_labels:
        raise ValueError(
            f"Unknown metric {args.metric!r}; expected one of {sorted(y_labels)}"
        )

    font = {'family': 'serif'}
    rc('font', **font)

    eid = args.experiment_id

    runs = []
    for i in range(1, 6):
        path = os.path.join(args.log_dir, eid, eid + '_' + str(i), 'rdd.npy')
        # Transpose so the first axis indexes the logged quantities, which is
        # the axis ``get_metric`` slices.
        data_points = np.transpose(np.load(path))
        runs.append(get_metric(args, data_points))
    runs = np.asarray(runs)
    # Statistics across runs (axis 0), then downsample along time.
    mean = average_rdd(np.average(runs, 0), args.window)
    std = average_rdd(np.std(runs, 0), args.window)

    name = eid + '_' + args.metric
    labels = ['Iterations', y_labels[args.metric]]
    legends = ['Train', 'Alternative']
    colors = ['blue', 'orange']
    plot_results(args, mean, std, name, legends, colors, args.window, labels)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Plot a metric averaged over the runs of an experiment.')
    # ``choices`` rejects unknown metrics up front with a usage message,
    # instead of failing later inside the plotting code.
    parser.add_argument('--metric', type=str, default='rdd',
                        choices=['rdd', 'rate', 'loss', 'acc'],
                        help='Metric to plot.')
    parser.add_argument('--log_dir', type=str, default='logs',
                        help='Log directory.')
    parser.add_argument('--experiment_id', type=str, default='natural_dnn',
                        help='Experiment ID.')
    parser.add_argument('--output_dir', type=str, default='outputs',
                        help='Output directory.')
    parser.add_argument('--window', type=int, default=50, help='Window to compute average.')
    parser.add_argument('--show_legend', action='store_true', default=False,
                        help='Show legend.')
    main(parser.parse_args())
