from argparse import ArgumentParser
from xmeta.protonet.protonet import setup_experiment, explain_test_performance
from xmeta.utils.data import ShuffledTaskset
from xmeta.utils.experiment import plot_score_dist
import pickle
import os
import sys
import torch


# Module-level cross-entropy criterion; CrossEntropyLoss already averages
# over the batch by default (reduction='mean').
# NOTE(review): not referenced anywhere in this script; kept in case another
# module imports it.
loss = torch.nn.CrossEntropyLoss()


def main(args):
    """Evaluate a trained explainer on test tasks and pickle the results.

    Steps:
      1. Append the full command line to ``args_<this script>.txt`` in the
         output directory (``--dir``, or the explainer file's directory).
      2. Rebuild tasksets, explainer and feature extractor via
         ``setup_experiment``.
      3. Score the test tasks with ``explain_test_performance``. With
         ``--test-with-train-tasks`` the training tasks are used instead
         (shuffled), and a ``self_rank`` column is added.
      4. Pickle the resulting DataFrame into the output directory and
         summarize it with ``plot_score_dist(..., stats_only=True)``.

    Args:
        args: Namespace produced by the command-line parser in ``__main__``.

    Raises:
        ValueError: If ``args.explainer`` is not set; every output path is
            derived from it.
    """
    if args.explainer is None:
        raise ValueError('--explainer is required')

    if args.deterministic:
        torch.use_deterministic_algorithms(True)

    # An explicit --dir wins; otherwise everything is written next to the
    # explainer file. Both branches must bind dir_path.
    if args.dir is not None:
        dir_path = args.dir
    else:
        dir_path = os.path.dirname(args.explainer)

    # Append-only log of every invocation, for reproducibility.
    args_file_name = 'args_' + os.path.basename(__file__) + '.txt'
    args_file_path = os.path.join(dir_path, args_file_name)
    with open(args_file_path, mode='a', encoding='utf-8') as f:
        f.write("\n" + " ".join(sys.argv))

    # The number of test tasks falls back to the number of training tasks.
    num_test_tasks = (args.num_tasks if args.num_test_tasks is None
                      else args.num_test_tasks)

    device = torch.device('cuda' if args.cuda else 'cpu')

    # Result file name: df_ntest<N>_[ttt_]<explainer file name>.
    pkl_name = os.path.basename(args.explainer)
    # if args.test_recolor_tasks is not None:
    #     pkl_name = f'trt{args.test_recolor_tasks}_' + pkl_name
    # if args.test_bgr_tasks is not None:
    #     pkl_name = f'tbt{args.test_bgr_tasks}_' + pkl_name
    if args.test_with_train_tasks:
        pkl_name = 'ttt_' + pkl_name
    pkl_name = f'df_ntest{num_test_tasks}_' + pkl_name
    pkl_path = os.path.join(dir_path, pkl_name)

    # Pass the resolved num_test_tasks (not the raw, possibly-None CLI value)
    # so the evaluation matches the count recorded in the result file name.
    tasks_train, tasks_test, explainer, _, feature, impurity_dict = \
        setup_experiment(
            train_way=args.train_way,
            train_shot=args.train_shot,
            train_query=args.train_query,
            test_way=args.test_way,
            test_shot=args.test_shot,
            test_query=args.test_query,
            num_tasks=args.num_tasks,
            num_test_tasks=num_test_tasks,
            experiment_dir=dir_path,
            explainer_path=args.explainer,
            # train_mask_tasks=args.mask_tasks,
            train_noise_tasks=args.noise_tasks,
            # train_shuffle_tasks=args.shuffle_tasks,
            # train_dark_tasks=args.dark_tasks,
            # train_recolor_tasks=args.recolor_tasks,
            # test_recolor_tasks=args.test_recolor_tasks,
            # train_bgr_tasks=args.bgr_tasks,
            # test_bgr_tasks=args.test_bgr_tasks,
            dataset=args.dataset,
            device=device,
            seed=args.seed,
        )
    # NOTE(review): presumably frees state only needed while fitting the
    # explainer; confirm in the explainer class.
    explainer.discard_intermediate()
    print(f'impurity_dict: {impurity_dict}')

    if args.test_with_train_tasks:
        # Evaluate on the training tasks themselves, in shuffled order.
        tasks_test = ShuffledTaskset(tasks_train)

    def _preprocess(batch):
        """Apply the feature extractor to the data and move data/labels to device."""
        data, labels = batch
        return [feature(data).to(device), labels.to(device)]

    # NOTE(review): `ways` uses --train-way while shots/queries use the
    # --test-* values; this only matters when --test-way != --train-way.
    # Confirm which is intended.
    df_test = explain_test_performance(explainer=explainer,
                                       test_taskset=tasks_test,
                                       preprocess=_preprocess,
                                       ways=args.train_way,
                                       shots=args.test_shot,
                                       query_num=args.test_query,
                                       num_train_task=args.num_tasks,
                                       num_test_task=num_test_tasks,
                                       device=device)
    if args.test_with_train_tasks:
        # Map each shuffled test position back to its original training-task
        # index (assumes ShuffledTaskset.idxes holds that mapping), then find
        # where that task appears in the row's train_task_idx list.
        df_test['orig_test_task_idx'] = \
            df_test['test_task_idx'].apply(lambda x: tasks_test.idxes[x])
        df_test['self_rank'] = df_test.apply(
            lambda row: row['train_task_idx'].index(row['orig_test_task_idx']),
            axis=1)
        print(f'self_rank(mean): {df_test["self_rank"].mean()} '
              f'self_rank(std): {df_test["self_rank"].std()}')

    with open(pkl_path, 'wb') as f:
        pickle.dump(df_test, f)

    plot_score_dist(df_test, num_tasks=args.num_tasks, index_dict=impurity_dict,
                    stats_only=True)


if __name__ == '__main__':
    parser = ArgumentParser(
        description='Evaluate a trained explainer on test tasks and pickle '
                    'the per-task results.')
    parser.add_argument('--deterministic', action='store_true',
                        help='Call torch.use_deterministic_algorithms(True).')
    parser.add_argument('--seed', type=int, default=42,
                        help='Seed forwarded to setup_experiment.')
    parser.add_argument('--cuda', action='store_true',
                        help='Run on CUDA instead of CPU.')
    parser.add_argument('--train-way', type=int, default=5,
                        help='Number of classes (ways) per training task.')
    parser.add_argument('--train-shot', type=int, default=5,
                        help='Support examples per class in training tasks.')
    parser.add_argument('--train-query', type=int, default=5,
                        help='Query examples per class in training tasks.')
    parser.add_argument('--test-way', type=int, default=5,
                        help='Number of classes (ways) per test task.')
    parser.add_argument('--test-shot', type=int, default=5,
                        help='Support examples per class in test tasks.')
    parser.add_argument('--test-query', type=int, default=5,
                        help='Query examples per class in test tasks.')
    parser.add_argument('--dir', type=str, default=None,
                        help='Directory for the args log and result pickle; '
                             'also passed to setup_experiment as '
                             'experiment_dir. Defaults to the directory of '
                             '--explainer.')
    # Every output path is derived from the explainer path, so it is
    # mandatory; without it main() cannot run.
    parser.add_argument('--explainer', type=str, required=True,
                        help='Path to the saved explainer file.')
    parser.add_argument('--num-tasks', type=int, default=None,
                        help='Number of training tasks.')
    parser.add_argument('--num-test-tasks', type=int, default=None,
                        help='Number of test tasks; when omitted, main() '
                             'uses --num-tasks.')
    # parser.add_argument('--mask-labels', type=int, default=None)
    # parser.add_argument('--mask-tasks', type=int, default=None)
    parser.add_argument('--noise-tasks', type=int, default=None,
                        help='Forwarded to setup_experiment as '
                             'train_noise_tasks.')
    # parser.add_argument('--shuffle-tasks', type=int, default=None)
    # parser.add_argument('--dark-tasks', type=int, default=None)
    # parser.add_argument('--recolor-tasks', type=int, default=None)
    # parser.add_argument('--test-recolor-tasks', type=int, default=None)
    # parser.add_argument('--bgr-tasks', type=int, default=None)
    # parser.add_argument('--test-bgr-tasks', type=int, default=None)
    parser.add_argument('--test-with-train-tasks', action='store_true',
                        help='Evaluate on a shuffled copy of the training '
                             'tasks instead of the test tasks.')
    parser.add_argument('--dataset', type=str, default='mini-imagenet',
                        help='Dataset name forwarded to setup_experiment.')
    args = parser.parse_args()
    main(args)
