from argparse import ArgumentParser
from xmeta.maml.maml import explain_test_performance, setup_experiment
# from xmeta.utils.experiment import setup_experiment
from xmeta.utils.data import ShuffledTaskset
from xmeta.utils.experiment import plot_score_dist
import pickle
import os
import sys
import torch

# Cross-entropy criterion with mean reduction; main() hands it to
# explain_test_performance as the task loss.
loss = torch.nn.CrossEntropyLoss(reduction="mean")


def main(args):
    """Evaluate a saved explainer on test tasks and pickle the results table.

    Steps:
      1. append the invoking command line to an args log in the output dir;
      2. build tasksets, explainer and feature extractor via
         ``setup_experiment``;
      3. run ``explain_test_performance`` on the test tasks (or on a
         ``ShuffledTaskset`` of the training tasks with
         ``--test-with-train-tasks``, in which case self-rank statistics are
         added and printed);
      4. pickle the resulting table and call ``plot_score_dist`` with
         ``stats_only=True``.

    Args:
        args: Parsed command-line namespace (see the command-line parser at
            the bottom of this file). ``args.explainer`` must be set.

    Raises:
        ValueError: If ``args.explainer`` is None. The default output
            directory and the output pickle name are both derived from it.
    """
    if args.explainer is None:
        raise ValueError(
            'args.explainer (--explainer) must be set: the output directory '
            'default and the output pickle name are derived from it')

    # Outputs go next to the explainer unless --dir overrides it.
    save_dir = args.dir if args.dir is not None else os.path.dirname(args.explainer)

    # Append (never overwrite) the exact invocation so each output in
    # save_dir can be traced back to the command that produced it.
    args_file_path = os.path.join(
        save_dir, 'args_' + os.path.basename(__file__) + '.txt')
    with open(args_file_path, mode='a') as f:
        f.write("\n" + " ".join(sys.argv))

    # --num-test-tasks falls back to --num-tasks. The resolved value is used
    # both in the output file name and in the calls below, so the name
    # always describes what was actually evaluated.
    num_test_tasks = (args.num_tasks if args.num_test_tasks is None
                      else args.num_test_tasks)

    device = torch.device('cuda' if args.cuda else 'cpu')

    # Output name: df_ntest<N>_[ttt_][tbt<B>_][trt<R>_]<explainer basename>
    prefixes = [f'df_ntest{num_test_tasks}_']
    if args.test_with_train_tasks:
        prefixes.append('ttt_')
    if args.test_bgr_tasks is not None:
        prefixes.append(f'tbt{args.test_bgr_tasks}_')
    if args.test_recolor_tasks is not None:
        prefixes.append(f'trt{args.test_recolor_tasks}_')
    pkl_name = ''.join(prefixes) + os.path.basename(args.explainer)
    pkl_path = os.path.join(save_dir, pkl_name)

    tasks_train, tasks_test, explainer, _, feature, impurity_dict = \
        setup_experiment(k=args.k, ways=args.ways, shots=args.shots,
                         num_tasks=args.num_tasks,
                         num_test_tasks=num_test_tasks,
                         experiment_dir=save_dir,
                         explainer_path=args.explainer,
                         train_mask_tasks=args.mask_tasks,
                         train_noise_tasks=args.noise_tasks,
                         train_shuffle_tasks=args.shuffle_tasks,
                         train_dark_tasks=args.dark_tasks,
                         train_recolor_tasks=args.recolor_tasks,
                         test_recolor_tasks=args.test_recolor_tasks,
                         train_bgr_tasks=args.bgr_tasks,
                         test_bgr_tasks=args.test_bgr_tasks,
                         sift_centroids_path=args.sift_centroids,
                         dataset=args.dataset,
                         device=device
                         )
    # NOTE(review): presumably releases state only needed while fitting the
    # explainer — confirm in the explainer class.
    explainer.discard_intermediate()
    print(f'impurity_dict: {impurity_dict}')

    if args.test_with_train_tasks:
        # Evaluate on the training tasks themselves, wrapped in ShuffledTaskset.
        tasks_test = ShuffledTaskset(tasks_train)

    def _preprocess(batch):
        # Featurize the raw data and move data and labels onto `device`.
        data, labels = batch
        return [feature(data).to(device), labels.to(device)]

    df_test = explain_test_performance(explainer, tasks_train, tasks_test,
                                       preprocess=_preprocess,
                                       loss=loss,
                                       shots=args.shots,
                                       ways=args.ways,
                                       num_train_task=args.num_tasks,
                                       num_test_task=num_test_tasks
                                       )

    if args.test_with_train_tasks:
        # NOTE(review): assumes tasks_test.idxes[i] is the tasks_train index
        # of shuffled task i, and that each row's 'train_task_idx' is a list
        # of train-task indices in ranked order — confirm both.
        df_test['orig_test_task_idx'] = \
            df_test['test_task_idx'].apply(lambda i: tasks_test.idxes[i])
        # Position of each task within its own train-task list.
        df_test['self_rank'] = df_test.apply(
            lambda row: row['train_task_idx'].index(row['orig_test_task_idx']),
            axis=1)
        print(f'self_rank(mean): {df_test["self_rank"].mean()} '
              f'self_rank(std): {df_test["self_rank"].std()}')

    with open(pkl_path, 'wb') as f:
        pickle.dump(df_test, f)

    plot_score_dist(df_test, num_tasks=args.num_tasks, index_dict=impurity_dict,
                    stats_only=True)


def build_parser():
    """Return the command-line parser for this evaluation script.

    Flag types and defaults match what ``main`` reads. ``--explainer`` is
    required because ``main`` derives the output file name (and, without
    ``--dir``, the output directory) from it.
    """
    parser = ArgumentParser(
        description='Evaluate a saved explainer on test tasks and pickle '
                    'the per-task results.')
    parser.add_argument('--shots', type=int, default=5,
                        help='number of shots per task')
    parser.add_argument('--ways', type=int, default=5,
                        help='number of ways per task')
    parser.add_argument('--cuda', action='store_true',
                        help='run on the CUDA device instead of the CPU')
    parser.add_argument('--k', type=int, default=None,
                        help='forwarded to setup_experiment as k')
    parser.add_argument('--dir', type=str, default=None,
                        help='directory for the args log and output pickle '
                             '(default: directory of --explainer)')
    parser.add_argument('--explainer', type=str, required=True,
                        help='path to the saved explainer')
    parser.add_argument('--sift-centroids', type=str, default=None,
                        help='forwarded to setup_experiment as '
                             'sift_centroids_path')
    parser.add_argument('--num-tasks', type=int, default=None,
                        help='number of training tasks')
    parser.add_argument('--num-test-tasks', type=int, default=None,
                        help='number of test tasks')
    parser.add_argument('--mask-labels', type=int, default=None,
                        help='accepted for compatibility; not used by this '
                             'script')
    parser.add_argument('--mask-tasks', type=int, default=None,
                        help='forwarded to setup_experiment as '
                             'train_mask_tasks')
    parser.add_argument('--noise-tasks', type=int, default=None,
                        help='forwarded to setup_experiment as '
                             'train_noise_tasks')
    parser.add_argument('--shuffle-tasks', type=int, default=None,
                        help='forwarded to setup_experiment as '
                             'train_shuffle_tasks')
    parser.add_argument('--dark-tasks', type=int, default=None,
                        help='forwarded to setup_experiment as '
                             'train_dark_tasks')
    parser.add_argument('--recolor-tasks', type=int, default=None,
                        help='forwarded to setup_experiment as '
                             'train_recolor_tasks')
    parser.add_argument('--test-recolor-tasks', type=int, default=None,
                        help='forwarded to setup_experiment as '
                             'test_recolor_tasks')
    parser.add_argument('--bgr-tasks', type=int, default=None,
                        help='forwarded to setup_experiment as '
                             'train_bgr_tasks')
    parser.add_argument('--test-bgr-tasks', type=int, default=None,
                        help='forwarded to setup_experiment as '
                             'test_bgr_tasks')
    parser.add_argument('--test-with-train-tasks', action='store_true',
                        help='evaluate on the training tasks (shuffled) and '
                             'report self-rank statistics')
    parser.add_argument('--dataset', type=str, default='mini-imagenet',
                        help='dataset name forwarded to setup_experiment')
    return parser


if __name__ == '__main__':
    main(build_parser().parse_args())
