import torch
from torch import nn
from torchsummary import summary
from argparse import ArgumentParser
import os
import sys
from xmeta.maml.maml import explain_test_performance, setup_experiment
import pandas as pd
import numpy as np


def _parse_nums_pos_evs(spec):
    """Parse the ``--nums-pos-evs`` value into a list of ints.

    Args:
        spec: Comma-separated integers such as ``"0, 4, 8"``, or ``None``.

    Returns:
        The parsed integers in order; blank items (e.g. from a trailing comma)
        are skipped. ``None`` or a spec with no items yields ``[-1]``: a single
        run in which ``main`` does not call ``set_src_generalized_matrix``
        (it only does so for non-negative values).

    Raises:
        ValueError: if an item is not an integer.
    """
    if spec is None:
        return [-1]
    values = [int(item) for item in spec.split(',') if item.strip()]
    return values or [-1]


def main(args):
    """Compute the self-rank statistics of an explainer for several n_ev values.

    For every entry ``n_ev`` of ``args.nums_pos_evs`` an experiment is set up
    with ``setup_experiment``. If ``n_ev >= 0``, the explainer's generalized
    matrix is rebuilt with ``n_positive_ev=n_ev``. The training tasks are then
    passed to ``explain_test_performance`` as both the train and the test set.
    For every result row, the position of ``test_task_idx`` inside the row's
    ``train_task_idx`` list is recorded as its ``self_rank``. The mean and std
    of ``self_rank`` for each ``n_ev`` are appended to
    ``selfrank_<explainer basename>.csv`` in the experiment directory. The
    command line is appended to ``args_<script name>.txt`` in that same
    directory.

    Args:
        args: Parsed command-line namespace (see the argument parser below).

    Raises:
        ValueError: if ``args.explainer`` is not set, or if
            ``args.nums_pos_evs`` contains a non-integer item. Both checks run
            before any file is created.
    """
    # Validate inputs before touching the filesystem.
    if args.explainer is None:
        raise ValueError('an explainer path (--explainer) is required')
    nums_pos_evs = _parse_nums_pos_evs(args.nums_pos_evs)

    # Without --dir, outputs are written next to the explainer file.
    dir_path = os.path.dirname(args.explainer) if args.dir is None else args.dir

    # NOTE(review): './cache' is not used in this function; presumably
    # setup_experiment or its helpers write there — confirm before removing.
    os.makedirs('./cache', exist_ok=True)

    # Keep a log of every invocation alongside the results.
    args_file_path = os.path.join(
        dir_path, 'args_' + os.path.basename(__file__) + '.txt')
    with open(args_file_path, mode='a') as f:
        f.write("\n" + " ".join(sys.argv))

    device = torch.device('cuda' if args.cuda else 'cpu')

    csv_name = 'selfrank_' + os.path.basename(args.explainer) + '.csv'
    csv_path = os.path.join(dir_path, csv_name)

    if args.dataset == 'omniglot':
        # The FFT crop size is floor(sqrt(k)); round k down to a perfect
        # square so that k == crop_size ** 2.
        crop_size = int(np.sqrt(args.k))
        k = crop_size ** 2
    else:
        crop_size = None
        k = args.k

    # Stateless loss module; one instance serves every iteration.
    loss = nn.CrossEntropyLoss(reduction='mean')
    means = []
    stds = []
    for n_ev in nums_pos_evs:
        # setup_experiment runs once per n_ev, presumably so each run starts
        # from a freshly loaded explainer — confirm.
        tasks_train, _tasks_test, explainer, _maml, feature, _impurity_dict =\
            setup_experiment(k=k, ways=args.ways, shots=args.shots,
                             num_tasks=args.num_tasks,
                             experiment_dir=dir_path,
                             explainer_path=args.explainer,
                             sift_centroids_path=args.sift_centroids,
                             fft_crop_size=crop_size,
                             dataset=args.dataset
                             )

        # The experiment is built with the effective k (which differs from
        # args.k for omniglot). Summarize on the selected device instead of
        # torchsummary's default 'cuda', which fails on CPU-only runs.
        model = explainer.model.clone()
        summary(model, (k,), device=device.type)

        def _preprocess(batch):
            # Featurize the raw data and move data and labels to the device.
            data, labels = batch
            return [feature(data).to(device), labels.to(device)]

        if n_ev >= 0:
            explainer.set_src_generalized_matrix(n_positive_ev=n_ev)

        # The train tasks are passed as both train and test sets, so every
        # queried task also appears among the candidates.
        df = explain_test_performance(explainer, tasks_train, tasks_train,
                                      preprocess=_preprocess,
                                      loss=loss,
                                      shots=args.shots,
                                      ways=args.ways,
                                      num_train_task=args.num_tasks,
                                      num_test_task=args.num_tasks
                                      )
        # NOTE(review): assumes train_task_idx is a list ordered by explainer
        # score, so the index is a rank — confirm in explain_test_performance.
        df['self_rank'] = df.apply(
            lambda row: row['train_task_idx'].index(row['test_task_idx']),
            axis=1)
        _m = df['self_rank'].mean()
        _s = df['self_rank'].std()
        print(f'n_ev: {n_ev}, mean: {_m}, std: {_s}')

        means.append(_m)
        stds.append(_s)

    df_result = pd.DataFrame({'n_ev': nums_pos_evs, 'mean': means, 'std': stds})
    # Results from repeated runs are appended. Write the header only when the
    # file is new, so the file stays a single parseable CSV.
    df_result.to_csv(csv_path, mode='a', header=not os.path.exists(csv_path))
    print(f'saved {csv_path}')


if __name__ == '__main__':
    parser = ArgumentParser()
    parser.add_argument('--shots', type=int, default=5)
    parser.add_argument('--ways', type=int, default=5)
    parser.add_argument('--cuda', action='store_true')
    parser.add_argument('--k', type=int, default=32)
    parser.add_argument('--dir', type=str, default=None)
    parser.add_argument('--explainer', type=str, default=None)
    parser.add_argument('--num-tasks', type=int, default=-1)
    parser.add_argument('--mask-labels', type=int, default=None)
    parser.add_argument('--mask-tasks', type=int, default=None)
    parser.add_argument('--noise-tasks', type=int, default=None)
    parser.add_argument('--shuffle-tasks', type=int, default=None)
    # parser.add_argument('--num-sift-train-tasks', type=int, default=None)
    parser.add_argument('--sift-centroids', type=str, default=None)
    parser.add_argument('--dataset', type=str, default='mini-imagenet')
    parser.add_argument('--nums-pos-evs', type=str, default=None)
    args = parser.parse_args()

    main(args)
