import pandas as pd
import seaborn as sns
import torch as ch
import numpy as np
import math
import glob
import os
import matplotlib.pyplot as plt
import json
import argparse
from tqdm import tqdm

from helpers import get_classes, load_config, load_files, shuffle_and_subset, mix_datasets, adjust_if_stl

from test_statistics import MMD_Stat, Luminance_Stat, RFR_Stat, Contrast_Stat
from kernel_tests import MMD
from pdd import dict2namedtuple, subsample

# Filesystem locations used by this script (placeholders to be edited per machine).
config_base = "/PATH/TO/CONFIGS/"   # passed to load_files as the config root
res_dir_base = "/PATH/TO/RESULTS"   # parent of the per-run result directories
base_out = "/PATH/TO/OUTPUT/"       # where the generated PDF figures are written

# Dataset key (as given on the command line) -> name shown in the plot title.
nice_mapping = {
    "3db": "3DB",
    "stl": "STL10",
    "pets": "Oxford-IIIT Pet",
    "cifar10": "CIFAR10",
    "in": "ImageNet",
}

# Internal subset key -> label shown in the plot legend.
good_label_mapping = {
    "active_set": "Projected",
    "rand_fixed": "Auxiliary",
}


def _classes_for(targ, aux):
    """Return the class list to iterate over for a (target, auxiliary) pair.

    The 'pets' classes are used when either dataset is 'pets', otherwise the
    'stl' classes when either dataset is 'stl', otherwise the 'cifar10' classes.
    """
    if 'pets' in (targ, aux):
        return get_classes('pets')
    if 'stl' in (targ, aux):
        return get_classes('stl')
    return get_classes('cifar10')


def _load_mixture(sources, z, dataset_name):
    """Load `sources`, mix them with weights `z`, and apply `adjust_if_stl`."""
    ds, _ = load_files(config_base, sources, z)
    return adjust_if_stl(mix_datasets(ds, z), dataset_name)


def _load_class_datasets(targ, aux, class_name):
    """Build the three datasets compared for one class.

    Returns a dict with:
      * 'active_set' - auxiliary sources mixed with the weights `z` stored in
        the run's ``mmd.pth`` file,
      * 'rand_fixed' - the same sources mixed with uniform weights,
      * 'real'       - the target data (single source with weight 1).
    """
    config_name = f'{aux}_{targ}_class_{class_name}_eps_0.001'
    config = load_config(config_name)
    saved = ch.load(f"{res_dir_base}/{config_name}/mmd.pth")

    z = saved['z']
    uniform_z = ch.ones_like(z) / z.size(0)
    return {
        'active_set': _load_mixture(config.sources, z, aux),
        'rand_fixed': _load_mixture(config.sources, uniform_z, aux),
        'real': _load_mixture(config.targets, ch.Tensor([1]), targ),
    }


def _evaluate_stat(stat_name, test_stat, candidate, real, seed):
    """Return the absolute value of one test statistic between two datasets.

    'MMD' is evaluated with fresh random encoders and its value is multiplied
    by 1e3 (presumably to bring it onto a scale comparable with the other
    statistics in the plot - confirm); 'RFR' receives `seed`; every other
    statistic is called with the two datasets only.
    """
    if stat_name == 'MMD':
        raw = test_stat(candidate, real, alphas=[500], use_new_random_encoders=True)
    elif stat_name == 'RFR':
        raw = test_stat(candidate, real, seed=seed)
    else:
        raw = test_stat(candidate, real)

    value = np.abs(raw.detach().cpu().item())
    return value * 1e3 if stat_name == 'MMD' else value


def _plot_aggregate(df, title, targ, aux):
    """Average the statistics over classes and save the grouped bar plot."""
    # Average only the numeric 'Statistic' column: pandas >= 2.0 no longer
    # silently drops non-numeric columns (such as 'Class') when taking a mean.
    agg_res = df.groupby(
        ['Test statistic', 'Dataset', 'Repeat'], as_index=False
    )['Statistic'].mean()

    # Alphabetical order with MMD moved to the end of the plot.
    stat_order = sorted(
        agg_res['Test statistic'].unique(),
        key=lambda name: (name == 'MMD', name),
    )

    sns.set(rc={'figure.figsize': (12, 8)}, font_scale=2.5)
    plt.clf()
    # Each bar aggregates the per-repeat means; ci=68 draws a 68% interval.
    ax = sns.barplot(x='Test statistic', y='Statistic', hue='Dataset',
                     data=agg_res, order=stat_order, ci=68)
    ax.set_title(title)
    plt.legend()
    plt.savefig(f'{base_out}/stats_witherr_allclasses_{targ}_{aux}.pdf', bbox_inches='tight')
    # NOTE(review): second copy under a 'TEST' name looks like a debugging
    # leftover; kept so the script's outputs are unchanged.
    plt.savefig(f'{base_out}/TEST{targ}_{aux}.pdf', bbox_inches='tight')


def plot_metrics(targ, aux):
    """Compute test statistics for every class and plot their class averages.

    For each class, the projected ('active_set') and uniformly mixed
    ('rand_fixed') auxiliary datasets are each compared against the real
    target data with the Luminance, RFR, Contrast and MMD statistics. Every
    comparison is repeated NUM_REPEATS times. The values are averaged over
    classes and drawn as a bar plot, which is saved as PDF under `base_out`.

    Args:
        targ: key of the target dataset (must be a key of `nice_mapping`).
        aux: key of the auxiliary dataset (must be a key of `nice_mapping`).

    Raises:
        KeyError: if `targ` or `aux` is not a key of `nice_mapping`.
    """
    NUM_REPEATS = 5

    # Resolve the display names up front so that an unknown dataset key fails
    # immediately instead of after all statistics have been computed.
    title = f'Projecting {nice_mapping[targ]} onto {nice_mapping[aux]}'

    results = []
    for class_name in _classes_for(targ, aux):
        print(f'Computing statistics on class {class_name}')
        data = _load_class_datasets(targ, aux, class_name)

        stat_lum, stat_rfr, stat_mmd, stat_con = (
            Luminance_Stat(), RFR_Stat(), MMD_Stat(), Contrast_Stat()
        )
        test_stats = {
            'Luminance': stat_lum,
            'RFR': stat_rfr,
            'Contrast': stat_con,
            'MMD': stat_mmd,
        }

        for repeats in range(NUM_REPEATS):
            for stat_name, test_stat in test_stats.items():
                print(stat_name)
                for subset in ['active_set', 'rand_fixed']:
                    stat = _evaluate_stat(
                        stat_name, test_stat, data[subset], data['real'], repeats
                    )
                    results.append(
                        [stat_name, good_label_mapping[subset], stat, class_name, repeats]
                    )

    print('Aggregating stats and plotting')

    df = pd.DataFrame(
        results,
        columns=['Test statistic', 'Dataset', 'Statistic', 'Class', 'Repeat'],
    )
    _plot_aggregate(df, title, targ, aux)


# Command-line interface: both dataset keys are mandatory, because the rest of
# the script cannot do anything meaningful when either one is missing (None).
parser = argparse.ArgumentParser(description='Plot stats')

parser.add_argument('--targ', dest='targ', type=str, required=True,
                    help="key of the target dataset, e.g. 'cifar10', 'stl' or 'pets'")
parser.add_argument('--aux', dest='aux', type=str, required=True,
                    help="key of the auxiliary dataset, e.g. 'cifar10', 'stl' or 'pets'")

def main():
    """Read --targ/--aux from the command line and plot the statistics for that pair."""
    cli_args = parser.parse_args()
    targ, aux = cli_args.targ, cli_args.aux
    print(f'Computing aggregate statistics for target {targ} and auxiliary dataset {aux}')
    plot_metrics(targ, aux)


# Run only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
