import torch
from xmeta.maml.maml import xfast_adapt, OPAExplainer, save_explainer
import learn2learn as l2l
from torch import nn
from torchsummary import summary
import pickle
from argparse import ArgumentParser
import os
import sys
from xmeta.utils.data import ImpureTasksets, get_tasksets, RotatedTaskset
import numpy as np


def _merge_rotation_rows(matrix, num_rotations):
    """Collapse the trailing ``num_rotations`` rows of ``matrix`` into one summed row.

    The row at index ``-num_rotations`` is overwritten in place with the sum of
    the last ``num_rotations`` rows; the rows after it are then removed with
    ``np.delete`` (which returns a new array).

    Args:
        matrix: 2-D array supporting ``[i, :]`` assignment, ``.sum(axis=0)`` and
            ``np.delete``.
        num_rotations: Number of trailing rows to merge; must be >= 1.

    Returns:
        The merged matrix (the same object when ``num_rotations == 1``).
    """
    matrix[-num_rotations, :] = matrix[-num_rotations:, :].sum(axis=0)
    if num_rotations > 1:
        matrix = np.delete(matrix, slice(-num_rotations + 1, None, None), 0)
    return matrix


def _print_metrics(iteration, result):
    """Print zero-shot, adaptation and evaluation metrics of one ``xfast_adapt`` result."""
    print('\n')
    print('Iteration', iteration)
    for phase, label in (('train', 'ZeroShot'),
                         ('adaptation', 'Adaptation'),
                         ('evaluation', 'Evaluation')):
        print(f'Meta Train {label} Error', result[phase]['error'].item())
        print(f'Meta Train {label} Accuracy', result[phase]['accuracy'].item())


def main(
        ways=5,
        shots=5,
        fast_lr=0.5,
        cuda=True,
        seed=42,
        ckpt='./cache/model.pth',
        num_tasks=-1,
        mask_labels=None,
        mask_tasks=None,
        noise_tasks=None,
        shuffle_tasks=None,
        dark_tasks=None,
        recolor_tasks=None,
        bgr_tasks=None,
        num_rotations=None,
        num_hessian_elements=None,
        discard_intermediate=False,
        ortho_vectors=True,
        exact_inverse=True,
        dataset='mini-imagenet',
):
    """Fit an OPA explainer over the first ``num_tasks`` training tasks of a MAML model.

    Loads the ``state_dict`` at ``ckpt`` into an l2l CNN, wraps it in
    ``l2l.algorithms.MAML``, adapts a clone to each training task with
    ``xfast_adapt`` and feeds the evaluation error/prediction into an
    ``OPAExplainer``. The explainer is saved with prefix ``'tmp'`` after the task
    loop, and again after ``set_src_param_matrix`` (prefix ``'einv'`` when
    ``exact_inverse`` is set, otherwise the ``save_explainer`` default).

    Args:
        ways, shots: Classes per task and examples per class; tasksets are
            requested with ``2 * shots`` samples per class.
        fast_lr: Inner-loop learning rate for MAML and the explainer.
        cuda: Run on ``cuda`` instead of ``cpu``.
        seed: Seed forwarded to ``get_tasksets``.
        ckpt: Path to the trained model ``state_dict``. Its directory is used
            as the save directory and its file stem as the explainer tag.
        num_tasks: Number of training tasks to process; must be >= 1.
        mask_labels, mask_tasks, noise_tasks, shuffle_tasks, dark_tasks,
        recolor_tasks, bgr_tasks: Forwarded to ``ImpureTasksets`` as the
            corresponding ``train_*`` arguments.
        num_rotations: If given (>= 1), each task is adapted ``num_rotations``
            times through ``RotatedTaskset`` and the trailing rows of
            ``explainer.grad_src_test_errors`` (presumably one per
            ``add_src_test_error`` call) are summed into a single row.
        num_hessian_elements, ortho_vectors: Forwarded to ``OPAExplainer``.
        discard_intermediate: Call ``explainer.discard_intermediate()`` before
            the final save.
        exact_inverse: Passed as ``taylor_series=not exact_inverse`` to
            ``set_src_param_matrix``; also selects the ``'einv'`` save prefix.
        dataset: ``'omniglot'`` selects ``OmniglotCNN``; anything else uses
            ``MiniImagenetCNN``.

    Raises:
        ValueError: If ``ckpt`` is None, ``num_tasks < 1`` or ``num_rotations < 1``.
        FileNotFoundError: If ``ckpt`` does not exist.
    """
    # Validate before any side effect (directory creation, data loading).
    if ckpt is None:
        raise ValueError('ckpt must point to a trained model state_dict')
    if num_tasks < 1:
        # range(num_tasks) would be empty and an empty explainer would be saved.
        raise ValueError(f'num_tasks must be >= 1, got {num_tasks}')
    if num_rotations is not None and num_rotations < 1:
        raise ValueError(f'num_rotations must be >= 1, got {num_rotations}')
    if not os.path.exists(ckpt):
        raise FileNotFoundError(f'checkpoint not found: {ckpt}')

    os.makedirs('./cache', exist_ok=True)

    # Append the command line next to the checkpoint for reproducibility.
    save_dir = os.path.dirname(ckpt)
    args_file_path = os.path.join(
        save_dir, 'args_' + os.path.basename(__file__) + '.txt')
    with open(args_file_path, mode='a', encoding='utf-8') as f:
        f.write("\n" + " ".join(sys.argv))

    device = torch.device('cuda' if cuda else 'cpu')

    # Load train/validation/test tasksets using the benchmark interface
    tasksets = get_tasksets(seed=seed,
                            name=dataset,
                            train_ways=ways,
                            train_samples=2 * shots,
                            test_ways=ways,
                            test_samples=2 * shots,
                            num_tasks=num_tasks,
                            root='~/data',
                            )

    if dataset == 'omniglot':
        model = l2l.vision.models.OmniglotCNN(ways)
        input_shape = (1, 28, 28)
    else:
        model = l2l.vision.models.MiniImagenetCNN(ways)
        input_shape = (3, 84, 84)
    # The model is still on the CPU here; torchsummary defaults to
    # device="cuda", which would feed CUDA inputs to CPU weights on GPU hosts.
    summary(model, input_shape, device='cpu')

    # [TODO] add weight decay
    loss = nn.CrossEntropyLoss(reduction='mean')

    # load trained weights; map_location allows restoring a GPU-saved
    # checkpoint on a CPU-only run.
    model_path = ckpt
    model_name = os.path.basename(model_path)
    model.load_state_dict(torch.load(model_path, map_location=device))
    print(f'loaded {model_path}')
    model.to(device)
    maml = l2l.algorithms.MAML(model, lr=fast_lr, first_order=False)

    tasksets = ImpureTasksets(tasksets, num_tasks=num_tasks, ways=ways, shots=shots,
                              train_mask_labels=mask_labels,
                              train_mask_tasks=mask_tasks,
                              train_noise_tasks=noise_tasks,
                              train_shuffle_tasks=shuffle_tasks,
                              train_dark_tasks=dark_tasks,
                              train_recolor_tasks=recolor_tasks,
                              train_bgr_tasks=bgr_tasks,
                              savedir=save_dir
                              )
    if num_rotations is None:
        tasks_train = tasksets.train
    else:
        tasks_train = RotatedTaskset(tasksets.train, shots=2 * shots, ways=ways,
                                     n_augment=num_rotations)

    def _preprocess(x):
        d, lbl = x
        return [d.to(device), lbl.to(device)]

    # Tag the explainer with the checkpoint file stem (extension stripped).
    explainer_name = os.path.splitext(model_name)[0]
    explainer = OPAExplainer(model=maml, adapt_lr=fast_lr,
                             num_hessian_elements=num_hessian_elements,
                             ortho_vectors=ortho_vectors,
                             savedir=save_dir,
                             tag=explainer_name,
                             )

    num_adaptations = 1 if num_rotations is None else num_rotations
    for ii in range(num_tasks):
        for _ in range(num_adaptations):
            learner = maml.clone()
            task_train = _preprocess(tasks_train[ii])
            result = xfast_adapt(task_train, learner, loss, shots, ways)
            # Each task contributes error / num_tasks, presumably so the summed
            # contributions form the mean over tasks.
            explainer.add_src_test_error(
                error=result['evaluation']['error'] / num_tasks,
                model_output=result['evaluation']['prediction'])
        if num_rotations is not None:
            # Fold the per-rotation rows of this task into a single row.
            explainer.grad_src_test_errors = _merge_rotation_rows(
                explainer.grad_src_test_errors, num_rotations)

        # Metrics of the last adaptation for this task.
        _print_metrics(ii, result)

    explainer.normalize_src_test_hessian()
    save_explainer(explainer, prefix='tmp', postfix=str(num_tasks - 1),
                   model_path=model_path)

    explainer.set_src_param_matrix(taylor_series=(not exact_inverse))
    if discard_intermediate:
        explainer.discard_intermediate()

    # save explainer
    if exact_inverse:
        save_explainer(explainer, prefix='einv', model_path=model_path)
    else:
        save_explainer(explainer, model_path=model_path)


if __name__ == '__main__':
    parser = ArgumentParser(
        description='Build and save an OPA explainer for a trained MAML checkpoint.')
    parser.add_argument('--shots', type=int, default=5,
                        help='examples per class in each task')
    parser.add_argument('--ways', type=int, default=5,
                        help='classes per task')
    parser.add_argument('--fast-lr', type=float, default=0.5,
                        help='inner-loop learning rate passed to MAML and the explainer')
    parser.add_argument('--seed', type=int, default=42,
                        help='seed forwarded to get_tasksets')
    parser.add_argument('--cuda', action='store_true',
                        help='run on the GPU instead of the CPU')
    parser.add_argument('--ckpt', type=str, required=True,
                        help='path to the trained model state_dict')
    parser.add_argument('--num-tasks', type=int, default=-1,
                        help='number of training tasks to run the explainer over')
    parser.add_argument('--mask-labels', type=int, default=None)
    parser.add_argument('--mask-tasks', type=int, default=None)
    parser.add_argument('--noise-tasks', type=int, default=None)
    parser.add_argument('--shuffle-tasks', type=int, default=None)
    parser.add_argument('--dark-tasks', type=int, default=None)
    parser.add_argument('--recolor-tasks', type=int, default=None)
    parser.add_argument('--bgr-tasks', type=int, default=None)
    parser.add_argument('--num-rotations', type=int, default=None,
                        help='number of rotated copies adapted per task')
    parser.add_argument('--num-hessian-elements', type=int, default=None)
    parser.add_argument('--discard-intermediate', action='store_true')
    parser.add_argument('--dataset', type=str, default='mini-imagenet')
    # NOTE: ortho_vectors and exact_inverse are not exposed on the command line;
    # main() uses its defaults (both True) for them.
    args = parser.parse_args()

    main(shots=args.shots,
         ways=args.ways,
         fast_lr=args.fast_lr,
         seed=args.seed,
         cuda=args.cuda,
         ckpt=args.ckpt,
         num_tasks=args.num_tasks,
         mask_labels=args.mask_labels,
         mask_tasks=args.mask_tasks,
         noise_tasks=args.noise_tasks,
         shuffle_tasks=args.shuffle_tasks,
         dark_tasks=args.dark_tasks,
         recolor_tasks=args.recolor_tasks,
         bgr_tasks=args.bgr_tasks,
         num_rotations=args.num_rotations,
         num_hessian_elements=args.num_hessian_elements,
         discard_intermediate=args.discard_intermediate,
         dataset=args.dataset,
         )
