import argparse
import os
import json
import math
import statistics
import pdb
from pprint import pprint

import scipy.stats
import matplotlib.pyplot as plt

from plot_learning_curve import plot_slug_record
from summarize_experiments import ExperimentDirectoryFinder



def make_grid(grouped_records, datasets=None, score_types=None, scoretype2label_fn=None, dataset2label_fn=None, rec2val_fn=None, skip_empty=True, plot_output_prefix=None, finder=None, p_threshold=0.05):
    """Build a table of per-dataset, per-score-type results as a grid of strings.

    Args:
        grouped_records: mapping ``dataset -> score_type -> list of records``.
        datasets: row keys in display order. Defaults to the sorted keys of
            ``grouped_records``.
        score_types: column keys in display order. Defaults to the sorted
            union of score types found in ``grouped_records``.
        scoretype2label_fn: maps a score type to its column header
            (default: identity).
        dataset2label_fn: maps a dataset key to its row label
            (default: identity).
        rec2val_fn: maps a record to a dict with ``'mean'``, ``'std'`` and
            ``'all'`` (per-trial values). The default reads the trapezoidal
            test-AUC fields scaled by 100, with ``'std'`` divided by
            ``sqrt(num_trials)``.
        skip_empty: drop score types with no records for any of ``datasets``.
        plot_output_prefix: if set, call ``plot_slug_record`` for the first
            record of each cell and save ``<prefix>_<dataset>.pdf`` for each
            dataset that had at least one cell.
        finder: passed to ``plot_slug_record``; when omitted, a module-level
            ``finder`` is used (only needed when plotting).
        p_threshold: significance level for t-tests against ``'random'``.

    Returns:
        A list of rows (lists of strings): a header row, one row per dataset
        and a final ``'Average'`` row. A cell whose first record's mean beats
        the ``'random'`` baseline is wrapped in ``\\textbf{...}``; if the
        difference is also significant (Welch's t-test per cell, paired
        t-test for the averages) it is prefixed with ``\\cellcolor{gray!32}``.
        Column averages only use datasets that have every available score type.
    """
    if scoretype2label_fn is None:
        scoretype2label_fn = lambda x: x
    if dataset2label_fn is None:
        dataset2label_fn = lambda x: x
    if rec2val_fn is None:
        rec2val_fn = lambda rec: {'mean': rec['test_trapauc_mean']*100, 'std': rec['test_trapauc_std']/math.sqrt(rec['num_trials'])*100, 'all': [x*100 for x in rec['all_test_aucs']]}
    if datasets is None:
        datasets = sorted(grouped_records)
    datasets = list(datasets)
    if score_types is None:
        score_types = sorted({st for per_dataset in grouped_records.values() for st in per_dataset})
    score_types = list(score_types)

    # Score types that have records for at least one requested dataset.
    avail_score_types = {st for st in score_types if any(st in grouped_records.get(dataset, {}) for dataset in datasets)}

    if skip_empty:
        score_types = [st for st in score_types if st in avail_score_types]

    grid = [['' for _ in range(len(score_types) + 1)] for _ in range(len(datasets) + 1)]
    grid[0][0] = 'Dataset {\\textbackslash} Method'
    for x, score_type in enumerate(score_types):
        grid[0][x + 1] = scoretype2label_fn(score_type)

    # Per-column means, collected only from datasets that have every available
    # score type so that all column averages cover the same datasets.
    col_aggregates = [[] for _ in score_types]

    if plot_output_prefix and finder is None:
        finder = globals().get('finder')

    for y, dataset in enumerate(datasets):
        grid[y + 1][0] = dataset2label_fn(dataset)
        dataset_records = grouped_records.get(dataset, {})
        if plot_output_prefix:
            plotted_anything = False
            plt.clf()
            plt.rc('font', size=14)
            plt.rc('axes', titlesize=16, labelsize=16)
            plt.title(dataset)
            plt.xlabel('Dataset size')
            plt.ylabel('Accuracy')
        for x, score_type in enumerate(score_types):
            if score_type not in dataset_records:
                continue
            # Must be bound before plotting, which reads recs[0].
            recs = dataset_records[score_type]
            if plot_output_prefix:
                plotted_anything = True
                exp_id = recs[0]['experiment_id']
                slug_fmt = 'experiment_v{}' if int(exp_id) >= 1085 else 'multichoice_swag_auto{}'
                plot_slug_record(slug_fmt.format(exp_id), finder, tolerate_errors=True)
            vals = [rec2val_fn(rec) for rec in recs]

            if all(st in dataset_records for st in avail_score_types):
                col_aggregates[x].append(vals[0]['mean'])

            beats_random = False
            stat_sig_beats_random = False
            if 'random' in dataset_records:
                random_val = rec2val_fn(dataset_records['random'][0])
                beats_random = vals[0]['mean'] > random_val['mean']
                if beats_random:
                    pvalue = scipy.stats.ttest_ind(vals[0]['all'], random_val['all'], equal_var=False).pvalue
                    stat_sig_beats_random = pvalue < p_threshold

            ssig = '\\cellcolor{gray!32}' if stat_sig_beats_random else ''
            # `beats` opens a group that the trailing '}}' (a literal '}') closes.
            beats = '\\textbf{' if beats_random else '{'
            grid[y + 1][x + 1] = ', '.join(
                '{ssig}{beats}{mean:0.1f} ({std:0.1f})}}'.format(ssig=ssig, beats=beats, mean=val['mean'], std=val['std'])
                for val in vals)
        if plot_output_prefix and plotted_anything:
            plt.legend(loc='lower right')
            plt.tight_layout()
            plt.savefig('{}_{}.pdf'.format(plot_output_prefix, dataset))

    avg_row = ['Average']
    rand_col_idx = score_types.index('random') if 'random' in score_types else -1
    rand_col = col_aggregates[rand_col_idx] if rand_col_idx >= 0 else []
    for idx, col in enumerate(col_aggregates):
        if not col:
            avg_row.append('')
            continue
        col_mean = statistics.mean(col)
        # Reset per column so no flag leaks from the per-cell loop above.
        beats_random = False
        stat_sig_beats_random = False
        # The random column is never compared against itself.
        if rand_col and idx != rand_col_idx:
            assert len(col) == len(rand_col)
            beats_random = col_mean > statistics.mean(rand_col)
            if beats_random and len(col) > 1:
                stat_sig_beats_random = scipy.stats.ttest_rel(col, rand_col).pvalue < p_threshold
        avg_row.append('{ssig}{beats}{mean:0.2f}}}'.format(
            ssig='\\cellcolor{gray!32}' if stat_sig_beats_random else '',
            beats='\\textbf{' if beats_random else '{',
            mean=col_mean,
        ))
    grid.append(avg_row)

    return grid


def table2str(grid, colsep='&', rowend='  \\\\'):
    """Render a grid of cell strings as the body of a LaTeX tabular.

    Each emitted row is indented by four spaces. Its cells are left-aligned
    and padded to the widest entry of their column, joined with
    ``' <colsep> '``, and the row is terminated with ``rowend`` and a newline.
    Rows whose cells after the first are all empty are skipped. An
    ``\\hline`` line follows the first (header) row. An empty grid renders
    as an empty string.
    """
    if not grid:
        return ''

    col_widths = [max(len(row[x]) for row in grid) for x in range(len(grid[0]))]
    sep = ' {} '.format(colsep)
    lines = []
    for y, row in enumerate(grid):
        if all(cell == '' for cell in row[1:]):
            continue
        cells = sep.join('{text:{width}s}'.format(width=col_widths[x], text=cell) for x, cell in enumerate(row))
        lines.append('    {}{}\n'.format(cells, rowend))
        if y == 0:
            lines.append('    \\hline\n')
    return ''.join(lines)


if __name__ == '__main__':
    # Build and print LaTeX result tables from a JSON-lines records file.
    parser = argparse.ArgumentParser()
    parser.add_argument('--experiments_dir', required=True)
    parser.add_argument('--records_file', required=True)
    parser.add_argument('--data_paths_file', default='runconfigs/data_paths.json',
                        help='JSON mapping dataset display names to data paths')
    args = parser.parse_args()

    finder = ExperimentDirectoryFinder(args.experiments_dir)

    # One JSON record per line; blank lines are ignored.
    with open(args.records_file) as f:
        records = [json.loads(line) for line in f if line.strip()]

    with open(args.data_paths_file) as f:
        data_paths = json.load(f)

    # Data path (relative to finder.base_dir) -> lowercase display name.
    # Looked up below with rec['dataset'] values.
    dataset_translation_key = {
        os.path.relpath(v, finder.base_dir): k.lower().replace('alphanli', 'anli')
        for k, v in data_paths.items()
    }

    score_translation_key = {
            'greedy_coreset': 'Coreset',
            'random': 'Random',
            'entropy': 'Entropy',
            'least_confident': 'LC',
            'bald': 'BALD-MC',
            'batchbald': 'BatchBALD-MC',
            'alps': 'ALPS',
    }

    # grouped_records[modelkey][dataset][al_score_type] -> list of records
    model_keys = set()
    datasets = set()
    score_types = set()
    grouped_records = dict()
    for rec in records:
        modelkey = (rec['model'], 'frozen' if rec['freeze_core_weights'] else 'unfrozen', rec['al_batch_size'], rec.get('num_train_retries', 1))
        rec_dataset, rec_score_type = rec['dataset'], rec['al_score_type']
        grouped_records.setdefault(modelkey, {}).setdefault(rec_dataset, {}).setdefault(rec_score_type, []).append(rec)
        model_keys.add(modelkey)
        datasets.add(rec_dataset)
        score_types.add(rec_score_type)

    print(dataset_translation_key)
    # Datasets missing from the data paths file are labelled with their raw
    # key instead of crashing the sort and the filters below.
    for d in datasets:
        dataset_translation_key.setdefault(d, d)
    model_keys = sorted(model_keys)
    datasets = sorted(datasets, key=lambda x: dataset_translation_key[x].lower())
    score_types = sorted(score_types)

    def quickgrid(modelkey, **kwargs):
        """Call make_grid for `modelkey` with this script's defaults, overridable via kwargs."""
        plot_output_prefix = None

        _kwargs = {
                'datasets': datasets,
                'score_types': score_types,
                'scoretype2label_fn': lambda x: score_translation_key.get(x, x),
                'dataset2label_fn': dataset_translation_key.get,
                'rec2val_fn': lambda rec: {'mean': rec['test_trapauc_mean']*100, 'std': rec['test_trapauc_std']/math.sqrt(rec['num_trials'])*100, 'all': [x*100 for x in rec['all_test_aucs']]},
                'skip_empty': True,
                'plot_output_prefix': plot_output_prefix,
            }
        _kwargs.update(kwargs)

        return make_grid(
                grouped_records[modelkey],
                **_kwargs,
            )

    # All tables with everything included
    for modelkey in model_keys:
        print()
        print(modelkey)
        print(table2str(quickgrid(modelkey)))

    multichoice_dsets = [x.lower() for x in ['SWAG', 'HellaSWAG', 'PIQA', 'aNLI', 'CSQA', 'CODAH']]
    multichoice_prefixes = tuple(x + '-' for x in multichoice_dsets)

    def without_multichoice(dsets):
        """Drop datasets whose label starts with '<multichoice name>-' (e.g. the '-c' variants)."""
        return [d for d in dsets if not dataset_translation_key[d].startswith(multichoice_prefixes)]

    def print_paper_table(title, modelkey, **kwargs):
        """Print a titled table for `modelkey`; multichoice variants are excluded unless `datasets` is given."""
        print()
        print(title, modelkey)
        if modelkey not in grouped_records:
            print('(no records for {})'.format(modelkey))
            return
        kwargs.setdefault('datasets', without_multichoice(datasets))
        print(table2str(quickgrid(modelkey, **kwargs)))

    ##################
    # PAPER TABLES
    ##################

    print_paper_table('Roberta-base normal', ('roberta-base', 'unfrozen', 25, 1))

    print_paper_table('Roberta-large normal', ('roberta-large', 'unfrozen', 25, 1))

    print_paper_table(
        'Roberta-large convex hull', ('roberta-large', 'unfrozen', 25, 1),
        rec2val_fn=lambda rec: {'mean': rec['test_convex_auc_mean']*100, 'std': rec['test_convex_auc_std']/math.sqrt(rec['num_trials'])*100, 'all': [x*100 for x in rec['all_test_convex_aucs']]})

    # Roberta-base retrain (num_train_retries=5) side by side with the
    # matching num_train_retries=1 runs.
    modelkey = ('roberta-base', 'unfrozen', 25, 5)
    normal_modelkey = modelkey[:-1] + (1,)
    print()
    print('Roberta-base retrain', modelkey)
    retrain_score_types = ['batchbald', 'random']
    if modelkey not in grouped_records or normal_modelkey not in grouped_records:
        print('(no records for {} and/or {})'.format(modelkey, normal_modelkey))
    else:
        retrain_datasets = [
            d for d in without_multichoice(datasets)
            if all(st in grouped_records[modelkey].get(d, {}) for st in retrain_score_types)]
        grid_retrain = quickgrid(
            modelkey,
            score_types=retrain_score_types,
            datasets=retrain_datasets)
        grid_normal = quickgrid(
            normal_modelkey,
            score_types=retrain_score_types,
            scoretype2label_fn=lambda x: score_translation_key.get(x, x) + ' (orig)',
            datasets=retrain_datasets,
            skip_empty=False)
        assert len(grid_normal) == len(grid_retrain)
        # Append the normal-run columns (without their row-label column).
        grid = [row_retrain + row_normal[1:] for row_retrain, row_normal in zip(grid_retrain, grid_normal)]
        print(table2str(grid))

        # Pair retrain vs normal trapezoidal AUC means; retrain_datasets already
        # guarantees the retrain side exists, so only the normal side is checked.
        pairs = []
        for d in retrain_datasets:
            for st in retrain_score_types:
                if st not in grouped_records[normal_modelkey].get(d, {}):
                    continue
                pairs.append((
                    grouped_records[modelkey][d][st][0]['test_trapauc_mean'],
                    grouped_records[normal_modelkey][d][st][0]['test_trapauc_mean'],
                ))
        if len(pairs) >= 2:
            retrain_vals, normal_vals = zip(*pairs)
            pvalue = scipy.stats.ttest_rel(retrain_vals, normal_vals).pvalue
            print('p-val for retrain vs normal overall: {}'.format(pvalue))
        else:
            print('p-val for retrain vs normal overall: n/a ({} paired result(s); need at least 2)'.format(len(pairs)))

    print_paper_table('Bert-base normal', ('bert-base-uncased', 'unfrozen', 25, 1))
