import matplotlib.pyplot as plt
import json
import argparse

def _series(entry, prop, span=(3, 29)):
    """Return the ``(x, y)`` lists to plot for one algorithm's results.

    A list of result dicts is plotted against each dict's
    ``'n_voting_classifiers'`` value. A single result dict has no such axis,
    so it is drawn as a horizontal line spanning ``span``.
    """
    if isinstance(entry, list):
        x = [e['n_voting_classifiers'] for e in entry]
        y = [e[prop] for e in entry]
        return x, y
    return list(span), [entry[prop]] * 2


def plot(d, prop='test_score', loc=None, coord=None, ylim=None, ylabel=None):
    """Plot metric ``prop`` against the number of voting classifiers.

    Args:
        d: Results dict with keys 'Pure Adaboost', 'Overlap Majority',
            'Disjoint Majority' and 'Bagging Majority'. Each value is either
            a single result dict or a list of result dicts that carry an
            'n_voting_classifiers' key.
        prop: Key of the metric to plot from each result dict.
        loc: Matplotlib legend location string; ignored if ``coord`` is set.
        coord: (x, y) anchor for the legend's lower-left corner; takes
            precedence over ``loc``.
        ylim: Optional (bottom, top) y-axis limits.
        ylabel: y-axis label. Defaults to 'Test Accuracy' for 'test_score',
            otherwise a title-cased version of ``prop``.
    """
    # LaTeX text rendering is needed for the \textsc legend labels.
    plt.rc('text', usetex=True)
    plt.rc('text.latex', preamble=r'\usepackage{amsmath}')
    plt.rc('font', size=16)

    x_ticks = [3, 5, 11, 15, 21, 29]
    # Single-valued results are drawn across the full tick range.
    flat_span = (x_ticks[0], x_ticks[-1])

    # (key in d, legend label, colour), in drawing/legend order.
    algorithms = [
        ('Pure Adaboost', "$\\textsc{AdaBoost}$", 'blue'),
        ('Overlap Majority', "$\\textsc{LarsenRitzert}$", 'red'),
        ('Disjoint Majority', "$\\textsc{Majority-of-X}$", 'orange'),
        ('Bagging Majority', "$\\textsc{BaggedAdaBoost}$", 'green'),
    ]
    for key, label, color in algorithms:
        x, y = _series(d[key], prop, flat_span)
        plt.plot(x, y, label=label, color=color)

    if ylabel is None:
        # Underscores would break usetex rendering, so they are replaced.
        ylabel = 'Test Accuracy' if prop == 'test_score' else prop.replace('_', ' ').title()

    plt.xticks(x_ticks)
    plt.xlabel('Number of Voting Classifiers')
    plt.ylabel(ylabel)
    if ylim is not None:
        plt.ylim(ylim)
    plt.grid()

    if coord is not None:
        plt.legend(loc='lower left', bbox_to_anchor=coord)
    elif loc is not None:
        plt.legend(loc=loc)
    else:
        plt.legend(loc='best')

    plt.show()

def create_avg(file_name, n_seeds=5, results_dir='results'):
    """Average per-seed result files and write the average to disk.

    Reads ``{results_dir}/{file_name}{seed}.json`` for seed = 1..n_seeds.
    The metrics 'fit_time', 'sample_score' and 'test_score' of the four
    algorithms are averaged across seeds. All other content is taken from
    the seed-1 file, and a 'random_seed' key listing the seeds used is added.
    The result is written to ``{results_dir}/{file_name}_avg.json``.

    Each algorithm's value may be a single result dict or a list of result
    dicts. List entries are matched across files by position.

    Raises:
        ValueError: if ``n_seeds`` < 1, or if a seed file's list of results
            has a different length or 'n_voting_classifiers' ordering than
            the seed-1 file (averaging would mix different configurations).
    """
    if n_seeds < 1:
        raise ValueError(f'n_seeds must be >= 1, got {n_seeds}')

    algorithms = ['Disjoint Majority', 'Bagging Majority', 'Overlap Majority', 'Pure Adaboost']
    metrics = ['fit_time', 'sample_score', 'test_score']

    def load(seed):
        with open(f'{results_dir}/{file_name}{seed}.json') as f:
            return json.load(f)

    def entries(value):
        # Normalise "one result dict" and "list of result dicts" to a list.
        return value if isinstance(value, list) else [value]

    # The first seed's dict doubles as the accumulator and the output skeleton.
    d = load(1)
    d['random_seed'] = list(range(1, n_seeds + 1))

    for seed in range(2, n_seeds + 1):
        d2 = load(seed)
        for alg in algorithms:
            acc_entries, new_entries = entries(d[alg]), entries(d2[alg])
            if len(acc_entries) != len(new_entries):
                raise ValueError(
                    f'{alg}: seed {seed} has {len(new_entries)} results, '
                    f'seed 1 has {len(acc_entries)}'
                )
            for acc, new in zip(acc_entries, new_entries):
                if acc.get('n_voting_classifiers') != new.get('n_voting_classifiers'):
                    raise ValueError(
                        f'{alg}: n_voting_classifiers mismatch between seed 1 '
                        f'and seed {seed}'
                    )
                for metric in metrics:
                    acc[metric] += new[metric]

    for alg in algorithms:
        for acc in entries(d[alg]):
            for metric in metrics:
                acc[metric] /= n_seeds

    with open(f'{results_dir}/{file_name}_avg.json', 'w') as f:
        json.dump(d, f, indent=4)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        prog="python plot.py",
        description="Compute average of experiments and plot results",
    )
    parser.add_argument("dataset", type=str, help="Dataset to plot the results of")
    parser.add_argument("-a", "--average", action="store_true",
                        help="Average the per-seed result files first and plot the average")
    # --loc and --coord both position the legend; let argparse reject using both
    # instead of silently ignoring --loc.
    legend_position = parser.add_mutually_exclusive_group()
    legend_position.add_argument('-l', '--loc', type=str,
                                 help='Specify location of legend e.g. "upper left" (incompatible with --coord)')
    legend_position.add_argument('-c', '--coord', type=float, nargs=2, metavar=('X', 'Y'),
                                 help="Provide 2 numbers giving the position of the legend's lower-left corner (incompatible with --loc)")
    parser.add_argument('-y', '--ylim', type=float, nargs=2, metavar=('BOTTOM', 'TOP'),
                        help='Set y-axis limits')
    args = parser.parse_args()

    if args.average:
        create_avg(args.dataset)
        results_path = f'results/{args.dataset}_avg.json'
    else:
        results_path = f'results/{args.dataset}.json'
    with open(results_path) as f:
        d = json.load(f)
    plot(d, loc=args.loc, coord=args.coord, ylim=args.ylim)