
import argparse
import matplotlib
matplotlib.rcParams.update({'font.size': 18})
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
import pandas as pd
import pickle
import sys

parser = argparse.ArgumentParser(description = 'Runs the summary of the analysis of the synthetic results')
parser.add_argument('--mode', type = str, default = '',
                    help = 'Sub-directory of ./Outputs/analysis that holds the per-method result pickles')
parser.add_argument('--comparison', type = str, default = 'all',
                    help = "Which set of methods to compare: 'all' or 'spotlight-hps'")
args = parser.parse_args()

mode = args.mode
comparison = args.comparison

# Column / axis label -> key of that metric inside a method's results pickle.
metrics_performance = {'Discovery Rate': 'dr', 'False Discovery Rate': 'fdr'}
metrics_error = {'Found': 'error_found', 'Not Returned': 'error_returned', 'Mixed Blindspots': 'error_composed', 'Mixed non-Errors': 'error_mixed'}
# Output CSV base name -> the metric group summarised in that table.
meta = {'average_performance': metrics_performance, 'average_error': metrics_error}
metrics_subset = {'Fraction of Blindspots Covered': 'covered'}

# Display name (table row / plot legend) -> base name of the method's results pickle.
if comparison == 'all':
    methods = {'PlaneSpot': 'blindspots-True', 'Barlow': 'barlow-True', 'Spotlight': 'spotlight', 'Domino': 'domino'}
elif comparison == 'spotlight-hps':
    methods = {'Spotlight - 0.01': 'spotlight', 'Spotlight - 0.02': 'spotlight-alt'}
else:
    # Passing the message to sys.exit prints it to stderr and exits with
    # status 1, so callers (shell scripts, CI) see the bad argument as a
    # failure instead of a silent success (the previous exit code was 0).
    sys.exit('Bad "comparison" parameter')

# Load each method's averaged results once. The same pickle used to be
# re-opened and re-unpickled for every (table, metric) combination.
summary = {}
for method in methods.values():
    with open('./Outputs/analysis/{}/{}.pkl'.format(mode, method), 'rb') as f:
        summary[method] = pickle.load(f)['Average']['all']

# Write one CSV per metric group in `meta`. Each row is a method and each
# column is a metric. Every cell is "mean (error)", and both values are
# rounded to 2 decimals. Index 0 of the stored pair is the mean and index 1
# is the error term.
for name, metrics in meta.items():
    cols = {'Method': list(methods)}
    for metric_name, metric in metrics.items():
        cols[metric_name] = [
            '{} ({})'.format(np.round(summary[method][metric][0], 2),
                             np.round(summary[method][metric][1], 2))
            for method in methods.values()
        ]

    df = pd.DataFrame.from_dict(cols)
    df.to_csv('./Outputs/analysis/{}/{}.csv'.format(mode, name), index = False)

def series_for_metric(results, metric):
    """Return ``(x, mean, error)`` arrays for one method, ordered by x value.

    ``results`` maps each value of the swept quantity to a dict of metrics.
    Each metric holds ``(mean, error)`` at indices 0 and 1. Given the 1.96
    factor used when plotting, the error is presumably a standard error
    (TODO confirm).

    Keys are sorted on their original values, so numeric sweeps come out in
    numeric order (1, 2, 10 rather than '1', '10', '2'). Only if the keys
    cannot be compared with each other (e.g. a mix of str and int) does the
    function fall back to sorting by their text. The returned ``x`` holds
    the keys as strings, so matplotlib plots them as categorical ticks.
    """
    try:
        keys = sorted(results)
    except TypeError:
        keys = sorted(results, key = str)

    x = np.array([str(key) for key in keys])
    mean = np.array([results[key][metric][0] for key in keys])
    error = np.array([results[key][metric][1] for key in keys])
    return x, mean, error


# Warning: pulls args from global scope
def plot():
    """Plot one metric against one swept query for every method and save it.

    This function takes no arguments. It reads the module-level globals
    ``methods``, ``mode``, ``query``, ``metric`` and ``metric_name``, which
    the driver loops set before each call.

    For each method it loads ``./Outputs/analysis/<mode>/<method>.pkl`` and
    draws the mean curve for ``query``. Around that curve it shades a band of
    mean +/- 1.96 * error, clipped to [0, 1]. The figure is written to
    ``./Outputs/analysis/<mode>/plots/<query>/<metric_name>.png`` and then
    closed.
    """
    # The output directory depends only on the query. Create it once, before
    # the loop, so that `out_dir` is always defined when the figure is saved.
    out_dir = './Outputs/analysis/{}/plots/{}'.format(mode, query)
    Path(out_dir).mkdir(parents = True, exist_ok = True)

    for method_name, method in methods.items():
        with open('./Outputs/analysis/{}/{}.pkl'.format(mode, method), 'rb') as f:
            results = pickle.load(f)[query]

        x, mean, error = series_for_metric(results, metric)

        plt.plot(x, mean, label = method_name)
        plt.fill_between(x, np.clip(mean - 1.96 * error, 0, 1), np.clip(mean + 1.96 * error, 0, 1), alpha = 0.1)

    plt.xlabel(query)
    plt.ylabel(metric_name)
    plt.legend(loc='center left', bbox_to_anchor=(1, 0.5))
    plt.savefig('{}/{}.png'.format(out_dir, metric_name), bbox_inches = 'tight')
    plt.close()
    
# Every job pairs a list of swept quantities (x-axis) with the metric group
# drawn against each of them. plot() reads `query`, `metric_name` and
# `metric` as module globals, so the loop variables must keep those names.
plot_jobs = [
    (['Num Blindspots', 'Num Dataset Features'], metrics_performance),
    (['Num Blindspot Features',
      'Blindspot uses "relative position"',
      'Blindspot uses "texture"',
      'Blindspot uses "presence of circle"'], metrics_subset),
]

for queries, metric_group in plot_jobs:
    for query in queries:
        for metric_name, metric in metric_group.items():
            plot()