import torch
from torchsummary import summary
from argparse import ArgumentParser
import os
import sys
from xmeta.maml.maml import setup_experiment
import pandas as pd
from xmeta.utils.seed import set_seed
from xmeta.maml.degraded_task import loop_alpha, loop_ratio
import numpy as np


def _degradation_stats(control, ranks, scores):
    """Correlate one degradation sweep with the ranks/scores it produced.

    Args:
        control: the swept control values (alphas or ratios).
        ranks: per-control-value ranks returned by the sweep.
        scores: per-control-value scores returned by the sweep.

    Returns:
        ``(corr_rank, corr_score, rank_immune, score_immune)``. The
        correlations come from ``pandas.Series.corr`` (Pearson by default,
        NaN when either series is constant). A ``*_immune`` flag is True when
        the corresponding sequence has the same max and min, i.e. it did not
        change at all over the sweep.
    """
    control_s = pd.Series(control)
    corr_rank = control_s.corr(pd.Series(ranks))
    corr_score = control_s.corr(pd.Series(scores))
    # bool() keeps the flags plain Python booleans even if the elements are
    # numpy scalars, matching the original True/False assignment.
    rank_immune = bool(max(ranks) == min(ranks))
    score_immune = bool(max(scores) == min(scores))
    return corr_rank, corr_score, rank_immune, score_immune


def test_with_degraded_task(explainer, taskset, task_idx,
                            alphas=None,
                            ratios=None,
                            seed=42,
                            preprocess=(lambda x: x)
                            ):
    """Measure how an explainer's ranks and scores follow a degraded task.

    Two sweeps are run on task ``task_idx`` of ``taskset``. ``set_seed(seed)``
    is called before each sweep so both are reproducible.

    * ``loop_alpha`` over ``alphas`` with ``ratio=1``
    * ``loop_ratio`` over ``ratios`` with ``alpha=1``

    Args:
        explainer: explainer object forwarded to ``loop_alpha``/``loop_ratio``.
        taskset: task collection forwarded to the sweep helpers.
        task_idx: index of the task to degrade.
        alphas: alpha values to sweep. Defaults to 11 evenly spaced values in
            [0, 1] (``np.linspace(0, 1, 11)``).
        ratios: ratio values to sweep. Same default as ``alphas``.
        seed: seed passed to ``set_seed`` before each sweep.
        preprocess: callable forwarded to the sweep helpers.

    Returns:
        ``(corr_ar, corr_as, corr_rr, corr_rs,
        ar_immune, as_immune, rr_immune, rs_immune)``. Here ``ar``/``as`` are
        alpha-vs-rank / alpha-vs-score and ``rr``/``rs`` are ratio-vs-rank /
        ratio-vs-score. The ``*_immune`` flags are True when the ranks or
        scores stayed constant over the sweep; the matching correlation is
        then NaN.
    """
    # Build the defaults per call instead of sharing one list object created
    # at definition time (mutable default argument pitfall).
    if alphas is None:
        alphas = list(np.linspace(0, 1, 11))
    if ratios is None:
        ratios = list(np.linspace(0, 1, 11))

    set_seed(seed)
    # loop_alpha also returns accuracies and errors, which are not needed here.
    ranks, scores, _, _ = loop_alpha(explainer, taskset, task_idx,
                                     ratio=1, alphas=alphas,
                                     preprocess=preprocess
                                     )
    corr_ar, corr_as, ar_immune, as_immune = \
        _degradation_stats(alphas, ranks, scores)

    set_seed(seed)
    ranks, scores, _, _ = loop_ratio(explainer, taskset, task_idx,
                                     ratios=ratios, alpha=1,
                                     preprocess=preprocess
                                     )
    corr_rr, corr_rs, rr_immune, rs_immune = \
        _degradation_stats(ratios, ranks, scores)

    return (corr_ar, corr_as, corr_rr, corr_rs,
            ar_immune, as_immune, rr_immune, rs_immune)


def _parse_nums_pos_evs(spec):
    """Parse the comma-separated ``--nums-pos-evs`` string into a list of ints.

    ``None`` maps to ``[-1]``. In ``main`` a negative value means
    ``explainer.set_src_generalized_matrix`` is not called for that run.
    Empty items (e.g. from a trailing comma) are ignored.
    """
    if spec is None:
        return [-1]
    return [int(item.strip()) for item in spec.split(',') if item.strip()]


def _evaluate_tasks(explainer, taskset, num_tasks, preprocess, n_ev):
    """Run ``test_with_degraded_task`` on tasks ``0 .. num_tasks - 1``.

    Returns:
        ``(correlations, valid)``.

        * ``correlations`` maps 'alpha_rank', 'alpha_score', 'ratio_rank' and
          'ratio_score' to the per-task correlation lists.
        * ``valid`` maps 'ar', 'as', 'rr' and 'rs' to the number of tasks
          whose ranks/scores were NOT immune (i.e. actually varied) for that
          sweep.
    """
    correlations = {'alpha_rank': [], 'alpha_score': [],
                    'ratio_rank': [], 'ratio_score': []}
    valid = {'ar': 0, 'as': 0, 'rr': 0, 'rs': 0}
    for task_idx in range(num_tasks):
        (corr_ar, corr_as, corr_rr, corr_rs,
         ar_immune, as_immune, rr_immune, rs_immune) = \
            test_with_degraded_task(explainer, taskset, task_idx,
                                    preprocess=preprocess)
        print(f'n_ev: {n_ev}, '
              f'task_idx: {task_idx}, '
              f'ar: {corr_ar}, as: {corr_as}, rr: {corr_rr}, rs: {corr_rs} '
              f'ar_immune: {ar_immune}, as_immune: {as_immune}, '
              f'rr_immune: {rr_immune}, rs_immune: {rs_immune}, '
              )
        for key, corr in zip(correlations,
                             (corr_ar, corr_as, corr_rr, corr_rs)):
            correlations[key].append(corr)
        for key, immune in zip(valid,
                               (ar_immune, as_immune, rr_immune, rs_immune)):
            valid[key] += int(not immune)
    return correlations, valid


def _format_summary(n_ev, correlations, valid):
    """Build the three-line mean/std/valid-count summary for one ``n_ev``.

    NaN correlations (from constant ranks/scores) are ignored by using
    ``np.nanmean`` / ``np.nanstd``.
    """
    avg = {key: np.nanmean(vals) for key, vals in correlations.items()}
    std = {key: np.nanstd(vals) for key, vals in correlations.items()}
    return (f'n_ev: {n_ev}, '
            f'average '
            f'ar: {avg["alpha_rank"]}, as: {avg["alpha_score"]}, '
            f'rr: {avg["ratio_rank"]}, rs: {avg["ratio_score"]}\n'
            f'n_ev: {n_ev}, '
            f'std '
            f'ar: {std["alpha_rank"]}, as: {std["alpha_score"]}, '
            f'rr: {std["ratio_rank"]}, rs: {std["ratio_score"]}\n'
            f'n_ev: {n_ev}, '
            f'valid_ar: {valid["ar"]}, '
            f'valid_as: {valid["as"]}, '
            f'valid_rr: {valid["rr"]}, '
            f'valid_rs: {valid["rs"]}\n'
            )


def main(args):
    """Compute degradation-correlation statistics for each requested n_ev.

    For every value in ``args.nums_pos_evs`` the experiment is rebuilt with
    ``setup_experiment``. If the value is non-negative,
    ``explainer.set_src_generalized_matrix`` is applied. Each task in
    ``range(args.num_tasks)`` is then scored with ``test_with_degraded_task``.

    Output files, all in the output directory (``args.dir``, or the directory
    of ``args.explainer``):

    * ``args_<script>.txt``: the command line is appended.
    * ``stats_<explainer>.txt``: mean/std/valid-count summaries are appended.
    * ``corr_nev<n>_<explainer>.csv``: per-task correlations are appended.

    Raises:
        ValueError: if ``args.explainer`` is not set.
    """
    if args.explainer is None:
        raise ValueError('--explainer is required')

    dir_path = os.path.dirname(args.explainer) if args.dir is None else args.dir

    # NOTE(review): './cache' is not read in this file; presumably it is used
    # by setup_experiment or the feature extractors -- confirm.
    os.makedirs('./cache', exist_ok=True)

    nums_pos_evs = _parse_nums_pos_evs(args.nums_pos_evs)

    args_file_name = 'args_' + os.path.basename(__file__) + '.txt'
    with open(os.path.join(dir_path, args_file_name), mode='a') as f:
        f.write("\n" + " ".join(sys.argv))

    device = torch.device('cuda' if args.cuda else 'cpu')

    explainer_name = os.path.basename(args.explainer)
    log_path = os.path.join(dir_path, 'stats_' + explainer_name + '.txt')

    if args.dataset == 'omniglot':
        # Omniglot uses a square FFT crop, so k is rounded down to the nearest
        # perfect square.
        crop_size = int(np.sqrt(args.k))
        k = crop_size ** 2
    else:
        crop_size = None
        k = args.k

    # NOTE(review): with the CLI default --num-tasks -1, range(-1) below is
    # empty and no task is evaluated, although setup_experiment receives -1
    # (possibly meaning "all tasks") -- confirm intended behaviour.
    for n_ev in nums_pos_evs:
        csv_path = os.path.join(dir_path,
                                f'corr_nev{n_ev}_' + explainer_name + '.csv')
        # Rebuilt on every iteration: set_src_generalized_matrix mutates the
        # explainer, so each n_ev starts from a freshly loaded one.
        tasks_train, _, explainer, _, feature, _ = \
            setup_experiment(k=k, ways=args.ways, shots=args.shots,
                             num_tasks=args.num_tasks,
                             experiment_dir=dir_path,
                             explainer_path=args.explainer,
                             sift_centroids_path=args.sift_centroids,
                             fft_crop_size=crop_size,
                             dataset=args.dataset
                             )

        model = explainer.model.clone()
        # Use the k that was actually passed to setup_experiment; args.k
        # differs from it for omniglot when args.k is not a perfect square.
        summary(model, (k,))

        def _preprocess(x, feature=feature):
            # Bind this iteration's feature extractor explicitly.
            d, lbl = x
            return [feature(d).to(device), lbl.to(device)]

        if n_ev >= 0:
            explainer.set_src_generalized_matrix(n_positive_ev=n_ev)

        correlations, valid = _evaluate_tasks(explainer, tasks_train,
                                              args.num_tasks, _preprocess,
                                              n_ev)

        print_str = _format_summary(n_ev, correlations, valid)
        print(print_str)
        with open(log_path, mode='a') as f:
            f.write(print_str)
        print(f'saved {log_path}')

        num_rows = len(correlations['alpha_rank'])
        df_result = pd.DataFrame({'n_ev': [n_ev] * num_rows, **correlations})
        # Append without repeating the header. A header row in the middle of
        # the file would turn the numeric columns into strings when re-read.
        df_result.to_csv(csv_path, mode='a',
                         header=not os.path.exists(csv_path))
        print(f'saved {csv_path}')


if __name__ == '__main__':
    parser = ArgumentParser()
    parser.add_argument('--shots', type=int, default=5)
    parser.add_argument('--ways', type=int, default=5)
    parser.add_argument('--cuda', action='store_true',
                        help='move preprocessed inputs to the CUDA device')
    parser.add_argument('--k', type=int, default=32,
                        help='feature dimension (rounded down to a perfect '
                             'square for omniglot)')
    parser.add_argument('--dir', type=str, default=None,
                        help='output directory (default: directory of '
                             '--explainer)')
    # main() cannot run without an explainer path, so fail fast at parse time.
    parser.add_argument('--explainer', type=str, default=None, required=True,
                        help='path to the explainer')
    parser.add_argument('--num-tasks', type=int, default=-1,
                        help='number of tasks to evaluate')
    # NOTE: the four options below are not read by main() in this file; they
    # are kept so existing command lines continue to parse.
    parser.add_argument('--mask-labels', type=int, default=None)
    parser.add_argument('--mask-tasks', type=int, default=None)
    parser.add_argument('--noise-tasks', type=int, default=None)
    parser.add_argument('--shuffle-tasks', type=int, default=None)
    parser.add_argument('--sift-centroids', type=str, default=None)
    parser.add_argument('--dataset', type=str, default='mini-imagenet')
    parser.add_argument('--nums-pos-evs', type=str, default='-1',
                        help='comma-separated n_positive_ev values; a '
                             'negative value skips '
                             'set_src_generalized_matrix')
    args = parser.parse_args()

    main(args)
