import os
from cox.readers import CollectionReader
import matplotlib.pyplot as plt
import seaborn as sns

import torch as ch
import numpy as np
import pandas as pd
import argparse
import csv

from helpers import get_classes, load_config, load_files, adjust_if_stl, mix_datasets
from test_statistics import MMD_Stat

# Placeholder root directories; replace with the real locations before running.
config_base = "/PATH/TO/CONFIGS/"
res_dir_base = "/PATH/TO/RESULTS/"
base_out = "/PATH/TO/OUTPUT/"

# Short dataset key -> human-readable dataset name.
nice_mapping = {
    "3db": "3DB",
    "stl": "STL10",
    "pets": "Oxford-IIIT Pet",
    "cifar10": "CIFAR10",
    "in": "ImageNet",
}

# Mixture key used while computing statistics -> label written to the results table.
good_label_mapping = {
    "active_set": "Projected",
    "rand_fixed": "Auxiliary",
}


def compute_mmd_for_eps(targ, aux, eps):
    """Average MMD of the projected and uniform source mixtures to the target.

    For every class, the saved run ``{aux}_{targ}_class_{class}_eps_{eps}`` is
    loaded and three datasets are built:

    * ``'active_set'`` -- the config's source datasets mixed with the weights
      ``z`` stored in the run's ``mmd.pth``;
    * ``'rand_fixed'`` -- the same sources mixed with uniform weights;
    * ``'real'`` -- the config's target datasets mixed with weight ``[1]``.

    The absolute MMD statistic (multiplied by 1e3) between each mixture and the
    real data is computed ``NUM_REPEATS`` times per class.

    Args:
        targ: target dataset key (e.g. ``'cifar10'``, ``'stl'``, ``'pets'``).
        aux: auxiliary dataset key.
        eps: epsilon value that is part of the run/config name.

    Returns:
        ``(avg_aux_mmd, avg_proj_mmd)``: mean statistic over all classes and
        repeats for the uniform ("Auxiliary") and projected ("Projected")
        mixtures respectively. NaN when no results were produced.
    """
    results = []
    NUM_REPEATS = 5

    # Use the class list of the non-CIFAR10 dataset involved ('pets' takes
    # precedence over 'stl'); otherwise the CIFAR10 classes.
    if targ == 'pets' or aux == 'pets':
        all_classes = get_classes('pets')
    elif targ == 'stl' or aux == 'stl':
        all_classes = get_classes('stl')
    else:
        all_classes = get_classes('cifar10')

    for class_name in all_classes:
        print(f'Computing statistics on class {class_name}')
        # The config and the results directory share the same name.
        config_name = f'{aux}_{targ}_class_{class_name}_eps_{eps}'
        config = load_config(config_name)
        d = ch.load(f"{res_dir_base}/{config_name}/mmd.pth")

        data = {}

        # Projected mixture: source weights z saved by the run.
        z = d['z']
        ds, _ = load_files(config_base, config.sources, z)
        data['active_set'] = adjust_if_stl(mix_datasets(ds, z), aux)

        # Auxiliary baseline: same sources, uniform weights.
        z = ch.ones_like(z)
        z = z / z.size(0)
        ds, _ = load_files(config_base, config.sources, z)
        data['rand_fixed'] = adjust_if_stl(mix_datasets(ds, z), aux)

        # Real data: the target dataset(s) with weight 1.
        z = ch.Tensor([1])
        ds, _ = load_files(config_base, config.targets, z)
        data['real'] = adjust_if_stl(mix_datasets(ds, z), targ)

        # Only MMD is registered; the other branches below are kept as hooks
        # for additional statistics added to this dict.
        test_stats = {
            'MMD': MMD_Stat(),
        }

        for repeat in range(NUM_REPEATS):
            for stat_name, test_stat in test_stats.items():
                for subset in ('active_set', 'rand_fixed'):
                    if stat_name == 'MMD':
                        value = test_stat(
                            data[subset], data['real'],
                            alphas=[500], use_new_random_encoders=True,
                        )
                        stat = np.abs(value.detach().cpu().item()) * 1e3
                    elif stat_name == 'RFR':
                        value = test_stat(data[subset], data['real'], seed=repeat)
                        stat = np.abs(value.detach().cpu().item())
                    else:
                        value = test_stat(data[subset], data['real'])
                        stat = np.abs(value.detach().cpu().item())

                    results.append([eps, good_label_mapping[subset], stat, class_name, repeat])

    print('Aggregating stats')
    df = pd.DataFrame(results, columns=['Epsilon', 'Dataset', 'Statistic', 'Class', 'Repeat'])
    # Average only the numeric 'Statistic' column: aggregating the whole frame
    # also tries to average the string 'Class' column, which raises TypeError
    # on pandas >= 2.0. eps is fixed here, so grouping by Dataset alone gives
    # the same means as grouping by (Epsilon, Dataset).
    means = df.groupby('Dataset')['Statistic'].mean()
    avg_aux_mmd = means.get('Auxiliary', np.nan)
    avg_proj_mmd = means.get('Projected', np.nan)
    return avg_aux_mmd, avg_proj_mmd

# Command-line interface. Both dataset keys are needed to build the run/config
# names, so they are required rather than silently defaulting to None.
parser = argparse.ArgumentParser(description='Compute best eps')

parser.add_argument('--targ', dest='targ', type=str, required=True,
                    help='target dataset key (e.g. cifar10, stl, pets)')
parser.add_argument('--aux', dest='aux', type=str, required=True,
                    help='auxiliary dataset key')

def main():
    """Select the eps whose projected mixture is closest (by MMD) to the target.

    Parses ``--targ``/``--aux``, evaluates each eps in a fixed grid with
    ``compute_mmd_for_eps``, and appends one row per eps,
    ``[targ, aux, eps, avg_proj_mmd, avg_aux_mmd]``, to ``best_eps.csv`` under
    ``res_dir_base``. It then prints the best eps and the uniform-weight
    ("Auxiliary") baseline averaged over all eps.
    """
    args = parser.parse_args()
    targ = args.targ
    aux = args.aux
    print(f'Computing best eps for target {targ} and auxiliary dataset {aux}')
    all_eps = [1, 0.1, 0.01, 0.001, 0.0001]

    # Strict '<' keeps the first eps on ties; a NaN result never wins.
    min_eps = -1
    min_proj_mmd = np.inf

    all_avg_aux_mmds = []
    vals = []
    for eps in all_eps:
        avg_aux_mmd, avg_proj_mmd = compute_mmd_for_eps(targ, aux, eps)
        vals.append([targ, aux, eps, avg_proj_mmd, avg_aux_mmd])
        all_avg_aux_mmds.append(avg_aux_mmd)
        if avg_proj_mmd < min_proj_mmd:
            min_proj_mmd = avg_proj_mmd
            min_eps = eps

    # Append so rows from several (targ, aux) invocations accumulate in one file.
    out_csv = os.path.join(res_dir_base, 'best_eps.csv')
    with open(out_csv, 'a', newline='') as f:
        csv.writer(f).writerows(vals)

    print(f'Best epsilon is {min_eps} with MMD distance {min_proj_mmd}')
    actual_avg_aux_mmd = np.mean(all_avg_aux_mmds)
    print(f'And here, the random baseline has MMD distance {actual_avg_aux_mmd}')

if __name__ == '__main__':
    # Run the CLI only when executed as a script, not when imported.
    main()
