import os
import argparse
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import rc


def draw(args, lists, stds, legends, basedir, colors, lw, loc, v_name,
         plot=True):
    """Write the curves to ``basedir + '.txt'`` and optionally plot them.

    The text file has one tab-separated row per step: the 1-based step
    number, then the value of every series in ``lists``, then the value of
    every series in ``stds``. When ``plot`` is true the curves are also drawn
    with a shaded +/- std band and saved to ``basedir + '.pdf'``.

    Args:
        args: Unused here; accepted for a uniform call signature.
        lists: Sequence of value series. The length of the first series
            defines the number of steps; every series is indexed up to it.
        stds: Sequence of deviation series, paired with ``lists`` by
            position.
        legends: Legend label for each curve.
        basedir: Output path without extension. Its directory part, if any,
            is created when missing.
        colors: Sequence of ``(color, marker)`` pairs. The last pair is
            reused when there are more curves than pairs.
        lw: Line width of the curves.
        loc: Legend location passed to ``ax.legend``.
        v_name: Label of the y axis.
        plot: When false, only the text file is written.
    """
    x_lim = len(lists[0])
    marker_scale = 10 if x_lim >= 1000 else 2

    # dirname() is '' when basedir has no directory component and
    # os.makedirs('') raises, so only create a directory when there is one.
    directory = os.path.dirname(basedir)
    if directory:
        os.makedirs(directory, exist_ok=True)

    with open(basedir + '.txt', 'w') as f:
        for i in range(x_lim):
            columns = [str(i + 1)]
            columns.extend(str(e[i]) for e in lists)
            columns.extend(str(e[i]) for e in stds)
            f.write('\t'.join(columns) + '\n')

    if not plot:
        return

    fig = plt.figure(figsize=(9, 6))
    ax = plt.subplot(1, 1, 1)
    font_size = 24
    ax.tick_params(axis='both', which='major', labelsize=font_size)

    # Light dashed reference lines at y = 100 and y = 50, behind the curves.
    ax.axhline(100, lw=6, c='lightgray', ls='--', zorder=0)
    ax.axhline(50, lw=6, c='lightgray', ls='--', zorder=0)

    # zip() stops at the shortest of lists / stds / legends.
    for i, (l, s, legend) in enumerate(zip(lists, stds, legends)):
        # Fall back to the last (color, marker) pair when colors run out.
        color, marker = colors[min(i, len(colors) - 1)]

        # Prepend a 0 so that array index k lines up with 1-based step k.
        l1 = np.asarray([0] + list(l))
        s1 = np.asarray([0] + list(s))

        # Alternate solid and dashed lines between consecutive curves.
        ls = '-' if i % 2 == 0 else '--'
        ax.plot(l1, lw=lw, markevery=(marker_scale * 10, marker_scale * 20),
                ls=ls,
                marker=marker, markersize=16, markeredgewidth=2,
                markerfacecolor='none', color=color, label=legend)
        ax.fill_between(np.arange(x_lim + 1), l1 - s1, l1 + s1,
                        color=color, alpha=0.2)

    ax.set_xlim([1, x_lim])
    ax.set_ylim([45, 105])
    legend_font_size = 18
    ax.legend(loc=loc, prop={'size': legend_font_size})
    ax.set_xlabel('Inference steps', fontsize=font_size)
    ax.set_ylabel(v_name, fontsize=font_size)
    ax.xaxis.labelpad = 10
    ax.yaxis.labelpad = 1
    plt.savefig(basedir + '.pdf', bbox_inches='tight', pad_inches=0.01)
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)


def get_numbers(args, lines):
    """Parse the step and the value column out of each log line.

    Every line is stripped and split on single spaces. The first field is
    read as the integer step and the third field from the end as the float
    value. ``args`` is not used.

    Returns:
        A ``(steps, values)`` pair of lists, one entry per input line.
    """
    value_column = -3
    # Parse both numbers of a line together so a malformed line fails at the
    # same point as a plain line-by-line loop would.
    parsed = [
        (int(fields[0]), float(fields[value_column]))
        for fields in (raw.strip().split(' ') for raw in lines)
    ]
    steps = [step for step, _ in parsed]
    values = [value for _, value in parsed]
    return steps, values


def load(args, name):
    """Read the two fixed 100-line result sections of the log file ``name``.

    The first section starts at 0-based line 12 and the second at line 113.
    Both are parsed with ``get_numbers``; the step numbers are discarded.

    Returns:
        A 4-tuple ``(original, transfer, original, transfer)`` of value
        lists, i.e. each of the two lists is returned twice.
    """
    with open(name, 'r') as f:
        lines = f.readlines()

    section_len = 100
    original_start, transfer_start = 12, 113
    original_block = lines[original_start:original_start + section_len]
    transfer_block = lines[transfer_start:transfer_start + section_len]

    _, original_values = get_numbers(args, original_block)
    _, transfer_values = get_numbers(args, transfer_block)
    return original_values, transfer_values, original_values, transfer_values


def get_results(args, path):
    """Aggregate the first two series of ``load`` over runs A through E.

    For each suffix in ``A``..``E`` the file ``<path><suffix>/log.txt`` is
    loaded. The first two series returned by ``load`` are collected per run
    and reduced over the run axis.

    Args:
        args: Passed through to ``load``.
        path: Prefix of the run directories; the run suffix is appended
            directly to it.

    Returns:
        ``(means, stds)``: two lists of length 2 holding, for each of the two
        series, the per-step ``np.mean`` and ``np.std`` across the runs.

    Raises:
        ValueError: If the runs do not all have the same number of entries
            for one of the series.
    """
    exps = ['A', 'B', 'C', 'D', 'E']

    results = [[], []]
    for e in exps:
        fn = os.path.join(path + e, "log.txt")
        # load() returns four series; only the first two are aggregated.
        eval1, eval2, _, _ = load(args, fn)
        results[0].append(eval1)
        results[1].append(eval2)

    # Validate both series: a ragged list would not form a proper 2-D matrix
    # below. An explicit raise (instead of assert) also survives ``python -O``.
    for series_index, runs in enumerate(results):
        expected = len(runs[0])
        for e, run in zip(exps, runs):
            if len(run) != expected:
                raise ValueError(
                    "series {} of run '{}' has {} entries, expected {}".format(
                        series_index, path + e, len(run), expected))

    means = []
    stds = []
    for runs in results:
        matrix = np.asarray(runs)
        means.append(np.mean(matrix, axis=0))
        stds.append(np.std(matrix, axis=0))

    return means, stds


def get_params(args):
    """Return the fixed plot configuration.

    Returns:
        ``(file_list, legends, output_list, colors, lw, loc)`` where the
        first, second and fourth items are parallel lists with one entry per
        curve, ``output_list`` holds the single output path
        ``<args.output_dir>/ter``, ``lw`` is the line width and ``loc`` the
        legend location.
    """
    # One (legend, log-directory prefix, (color, marker)) triple per curve.
    curve_specs = (
        ('Baseline', 'logs/overlap_proposed_', ('c', 's')),
        ('Proposed', 'logs/overlap_proposed_', ('g', 'o')),
    )

    legends, file_list, colors = [], [], []
    for legend, prefix, style in curve_specs:
        legends.append(legend)
        file_list.append(prefix)
        colors.append(style)

    output_list = [os.path.join(args.output_dir, 'ter')]
    return file_list, legends, output_list, colors, 2, 'lower right'


def main(args):
    """Load the results for every configured log prefix and draw the plot.

    Only the second series returned by ``get_results`` is used. The plot
    shows a constant reference curve (89.5 with deviation 0.7) and the
    second series of the first log prefix, scaled by 100.
    """
    file_list, legends, output_list, colors, lw, loc = get_params(args)

    second_means = []
    second_stds = []
    for prefix in file_list:
        # Each of means / stds holds two entries; keep only the second.
        (_, mean2), (_, std2) = get_results(args, prefix)
        second_means.append(mean2)
        second_stds.append(std2)

    rc('font', family='serif')

    # Scale by 100 for the percent axis.
    acc_mean2 = [100 * mean for mean in second_means]
    acc_std2 = [100 * std for std in second_stds]

    n_steps = len(acc_mean2[0])

    # First curve: hard-coded constant level with a constant deviation band.
    curves = [[89.5] * n_steps, acc_mean2[0]]
    bands = [[0.7] * n_steps, acc_std2[0]]
    draw(args, curves, bands, legends, output_list[0], colors, lw, loc,
         'TER (%)')


if __name__ == '__main__':
    # Script entry point: read the output directory from the command line
    # (default 'outputs') and run the plotting pipeline.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--output_dir',
        default='outputs',
        type=str,
        help='Output dir.',
    )
    args = parser.parse_args()
    main(args)
