from argparse import ArgumentParser
from xmeta.maml.maml import setup_experiment, explain_test_performance
from xmeta.utils.data import ShuffledTaskset
from xmeta.utils.experiment import plot_score_dist
import pickle
import os
import sys
import torch
import pandas as pd
from torch import nn


def _series_correlation(s0, s1, plot_data=False):
    """Return the correlation between the train-task scores of two records.

    Each of ``s0`` and ``s1`` must support item lookup of
    ``'train_task_score'`` (the values) and ``'train_task_idx'`` (the labels
    of those values).  Both are converted to label-sorted ``pd.Series`` and
    compared with ``Series.corr``.

    ``plot_data`` is accepted for interface compatibility but currently has
    no effect.
    """
    def _sorted_scores(record):
        # Scores labelled by training-task index, ordered by that label.
        scores = pd.Series(record['train_task_score'],
                           index=record['train_task_idx'])
        return scores.sort_index()

    first, second = _sorted_scores(s0), _sorted_scores(s1)
    return first.corr(second)


def _df_correlation(df0, df1):
    """Correlate the train-task scores of two result frames, test task by test task.

    Both frames are sorted by ``test_task_idx`` and then paired row by row
    (first with first, second with second, ...).  For each pair the
    correlation of the ``train_task_score`` values is computed with
    ``_series_correlation``.

    Args:
        df0: DataFrame with columns ``test_task_idx``, ``train_task_idx`` and
            ``train_task_score``.
        df1: DataFrame with the same columns and the same number of rows.

    Returns:
        DataFrame indexed ``0..n-1`` (in ``test_task_idx`` order of ``df0``)
        with columns ``test_task_idx`` (from ``df0``), ``score_0``,
        ``score_1`` and ``corr``.

    Raises:
        AssertionError: if the two frames differ in length.
    """
    assert len(df0) == len(df1)
    # Drop the original row labels after sorting.  The DataFrame constructor
    # below aligns Series columns by label while the ``correlations`` list is
    # assigned by position; with stale labels the columns could end up on
    # different rows whenever the inputs were not already in sorted order.
    df0 = df0.sort_values(['test_task_idx']).reset_index(drop=True)
    df1 = df1.sort_values(['test_task_idx']).reset_index(drop=True)

    correlations = [
        _series_correlation(df0.iloc[ii], df1.iloc[ii])
        for ii in range(len(df0))
    ]

    df = pd.DataFrame({'test_task_idx': df0['test_task_idx'],
                       'score_0': df0['train_task_score'],
                       'score_1': df1['train_task_score'],
                       'corr': correlations})
    return df


def _annotate_self_rank(df, tasks_test):
    """Add ``orig_test_task_idx`` and ``self_rank`` columns to ``df`` in place.

    ``orig_test_task_idx`` maps every ``test_task_idx`` through
    ``tasks_test.idxes``; ``self_rank`` is the position of that value inside
    the row's ``train_task_idx`` sequence.  The mean and standard deviation of
    ``self_rank`` are printed.
    """
    df['orig_test_task_idx'] = df['test_task_idx'].apply(
        lambda x: tasks_test.idxes[x])
    df['self_rank'] = df.apply(
        lambda row: row['train_task_idx'].index(row['orig_test_task_idx']),
        axis=1)
    print(f'self_rank(mean): {df["self_rank"].mean()} '
          f'self_rank(std): {df["self_rank"].std()}')


def _load_or_explain(pkl_path, explainer_path, experiment_dir, num_tasks,
                     test_with_train_tasks, n_positive_ev=None,
                     generalize=False):
    """Return the result frame cached at ``pkl_path``, computing it if missing.

    When ``pkl_path`` exists it is unpickled and returned.  Otherwise the
    experiment is set up from ``explainer_path``, the test performance is
    explained, and the resulting DataFrame is pickled to ``pkl_path``.

    Args:
        pkl_path: cache file for the resulting DataFrame.
        explainer_path: explainer file handed to ``setup_experiment``.
        experiment_dir: experiment directory handed to ``setup_experiment``.
        num_tasks: number of training tasks.
        test_with_train_tasks: if true, a ``ShuffledTaskset`` of the training
            tasks replaces the test tasks and self-rank columns are added.
        n_positive_ev: forwarded to ``set_src_generalized_matrix``; only used
            when ``generalize`` is true.
        generalize: if true, call ``set_src_generalized_matrix`` on the
            explainer before explaining.
    """
    if os.path.exists(pkl_path):
        # NOTE: pickle.load can execute arbitrary code; only point this at
        # cache files written by this script.
        with open(pkl_path, 'rb') as f:
            df = pickle.load(f)
        print(f'loaded {pkl_path}')
        return df

    # Fixed experiment settings.
    fft_crop_size = 6
    ways = 5
    shots = 5
    num_test_tasks = 128

    # NOTE(review): the device is hard-coded, so a CUDA device is required
    # whenever results have to be computed rather than loaded.
    device = torch.device('cuda')
    loss = nn.CrossEntropyLoss(reduction='mean')

    tasks_train, tasks_test, explainer, _maml, feature, _impurity_dict = \
        setup_experiment(ways=ways,
                         shots=shots,
                         num_tasks=num_tasks,
                         experiment_dir=experiment_dir,
                         explainer_path=explainer_path,
                         fft_crop_size=fft_crop_size,
                         dataset='omniglot')
    if generalize:
        explainer.set_src_generalized_matrix(n_positive_ev=n_positive_ev)

    def _preprocess(x):
        # Apply the feature transform to the data and move data and labels
        # to the compute device.
        d, lbl = x
        return [feature(d).to(device), lbl.to(device)]

    if test_with_train_tasks:
        tasks_test = ShuffledTaskset(tasks_train)

    df = explain_test_performance(explainer, tasks_train, tasks_test,
                                  preprocess=_preprocess,
                                  loss=loss,
                                  shots=shots,
                                  ways=ways,
                                  num_train_task=num_tasks,
                                  num_test_task=num_test_tasks)
    if test_with_train_tasks:
        _annotate_self_rank(df, tasks_test)

    with open(pkl_path, 'wb') as f:
        pickle.dump(df, f)
    print(f'saved {pkl_path}')
    return df


def _get_correlations(explainer_path_exact, explainer_path_gn, n_positive_ev,
                      num_tasks, test_with_train_tasks=False):
    """Correlate the per-test-task scores of the "exact" and "gn" explainers.

    For each explainer the result frame is loaded from a pickle cache in the
    directory of ``explainer_path_exact`` or, if the cache file is missing,
    computed and saved there.  The "exact" cache name includes
    ``n_positive_ev``; the other one depends only on the explainer file name.

    Args:
        explainer_path_exact: path of the "exact" explainer; its directory is
            used as experiment and cache directory.
        explainer_path_gn: path of the second ("gn" / "opa") explainer.
        n_positive_ev: passed to ``set_src_generalized_matrix`` of the exact
            explainer.
        num_tasks: number of training tasks.
        test_with_train_tasks: use shuffled training tasks as test tasks.

    Returns:
        The DataFrame produced by ``_df_correlation`` for the two result
        frames.
    """
    experiment_dir = os.path.dirname(explainer_path_exact)

    if test_with_train_tasks:
        prefix_exact, prefix_opa = 'df_exact_ttt_', 'df_opa_ttt_'
    else:
        prefix_exact, prefix_opa = 'df_exact_', 'df_opa_'

    pkl_name_exact = (prefix_exact + str(n_positive_ev) + '_' +
                      os.path.basename(explainer_path_exact))
    pkl_name_opa = prefix_opa + os.path.basename(explainer_path_gn)
    pkl_path_exact = os.path.join(experiment_dir, pkl_name_exact)
    pkl_path_opa = os.path.join(experiment_dir, pkl_name_opa)

    df_exact = _load_or_explain(pkl_path_exact, explainer_path_exact,
                                experiment_dir, num_tasks,
                                test_with_train_tasks,
                                n_positive_ev=n_positive_ev,
                                generalize=True)
    df_opa = _load_or_explain(pkl_path_opa, explainer_path_gn,
                              experiment_dir, num_tasks,
                              test_with_train_tasks)

    return _df_correlation(df_exact, df_opa)


def main(args):
    """Compute mean/std correlation tables over the parameter grid and save them.

    For every combination of a ``--params-exact`` value (used as
    ``n_positive_ev``) and a ``--params-gn`` value (substituted into
    ``args.explainer_base_gn``), the per-test-task correlations are obtained
    from ``_get_correlations`` and reduced to their mean and standard
    deviation.  The two resulting tables (rows: gn parameters, columns: exact
    parameters) are printed and written as CSV files next to
    ``args.explainer_exact``.  The command line is appended to an ``args_*``
    text file in the same directory.

    Args:
        args: parsed command-line namespace with ``params_exact``,
            ``params_gn``, ``explainer_exact``, ``explainer_base_gn``,
            ``num_tasks`` and ``test_with_train_tasks``.
    """
    params_exact = args.params_exact.split(',')
    # Fixed: this used to read args.params_exact, so --params-gn was ignored.
    params_gn = args.params_gn.split(',')
    save_dir = os.path.dirname(args.explainer_exact)
    if args.test_with_train_tasks:
        mean_csv_path = os.path.join(save_dir, 'df_ttt_corr_mean.csv')
        std_csv_path = os.path.join(save_dir, 'df_ttt_corr_std.csv')
    else:
        mean_csv_path = os.path.join(save_dir, 'df_corr_mean.csv')
        std_csv_path = os.path.join(save_dir, 'df_corr_std.csv')

    # Keep a log of every invocation alongside the results.
    args_file_name = 'args_' + os.path.basename(__file__) + '.txt'
    args_file_path = os.path.join(save_dir, args_file_name)
    with open(args_file_path, mode='a') as f:
        f.write("\n" + " ".join(sys.argv))

    mean_matrix = {}
    std_matrix = {}
    for _param_ex in params_exact:
        print(f'explainer_exact: {args.explainer_exact}   param: {_param_ex}')
        _means = []
        _stds = []
        for _param_gn in params_gn:
            explainer_path_gn = args.explainer_base_gn.format(_param_gn)
            print('explainer_gn: ' + explainer_path_gn)
            _df_corr = _get_correlations(args.explainer_exact,
                                         explainer_path_gn,
                                         n_positive_ev=int(_param_ex),
                                         num_tasks=args.num_tasks,
                                         test_with_train_tasks=args.test_with_train_tasks
                                         )
            _means.append(_df_corr['corr'].mean())
            _stds.append(_df_corr['corr'].std())
        mean_matrix[_param_ex] = _means
        std_matrix[_param_ex] = _stds

    df_mean = pd.DataFrame(mean_matrix, index=params_gn)
    df_std = pd.DataFrame(std_matrix, index=params_gn)
    print(df_mean)
    print(df_std)
    df_mean.to_csv(mean_csv_path, header=True, index=True)
    df_std.to_csv(std_csv_path, header=True, index=True)


if __name__ == '__main__':
    parser = ArgumentParser(
        description='Correlate the scores of an exact explainer with those '
                    'of a family of gn explainers and save mean/std tables.')
    parser.add_argument('--num-tasks', type=int, default=None,
                        help='number of training tasks')
    # Both paths are dereferenced unconditionally in main(); requiring them
    # here gives a clear usage error instead of a TypeError/AttributeError
    # on None later on.
    parser.add_argument('--explainer-exact', type=str, required=True,
                        help='path of the exact explainer; results are '
                             'written to its directory')
    parser.add_argument('--explainer-base-gn', type=str, required=True,
                        help='format template for the gn explainer path; '
                             '{} is replaced by each --params-gn value')
    parser.add_argument('--params-exact', type=str, default='64,128,256,512,1024',
                        help='comma-separated integers used as n_positive_ev '
                             'for the exact explainer')
    parser.add_argument('--params-gn', type=str, default='64,128,256,512,1024',
                        help='comma-separated values substituted into '
                             '--explainer-base-gn')
    parser.add_argument('--test-with-train-tasks', action='store_true',
                        help='use shuffled training tasks as test tasks')
    args = parser.parse_args()
    main(args)

