import torch
from xmeta.networks.simple_networks import\
    one_layer_net, two_layer_net, three_layer_net, simple_layer_net
from xmeta.maml.maml import xfast_adapt, MAMLExplainer, MAMLExplainerOPA
from xmeta.utils.data import ImpureTasksets, get_tasksets
from xmeta.utils.sift import SiftFeature
from xmeta.utils.preprocess import TensorImageFFT
from xmeta.utils.seed import set_seed
import learn2learn as l2l
from torch import nn
from torchsummary import summary
import pickle
from argparse import ArgumentParser
import os
import sys
import numpy as np


def main(
        ways=5,
        shots=5,
        fast_lr=0.5,
        adaptation_steps=1,
        cuda=True,
        seed=42,
        k=10, num_test_task=100,
        ckpt='./cache/model.pth',
        num_tasks=-1,
        layer=None,
        hidden=[],
        mask_labels=None,
        mask_tasks=None,
        noise_tasks=None,
        shuffle_tasks=None,
        train_masks_path=None,
        activation=None,
        num_sift_train_tasks=100,
        weight_decay=0.,
        sift_centroids=None,
        opa=False,
        num_hessian_elements=None,
        dataset='mini-imagenet'
):
    if not os.path.exists('./cache'):
        os.makedirs('./cache')

    save_dir = os.path.dirname(ckpt)
    args_file_name = 'args_' + os.path.basename(__file__) + '.txt'
    args_file_path = os.path.join(save_dir, args_file_name)
    with open(args_file_path, mode='a') as f:
        f.write("\n" + " ".join(sys.argv))

    if cuda:
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    # Create model
    if layer == 1:
        model = one_layer_net(n_in=k, n_out=ways)
    elif layer == 2:
        model = two_layer_net(n_in=k, n_out=ways)
    elif layer == 3:
        model = three_layer_net(n_in=k, n_out=ways)
    else:
        isinstance(hidden, list)
        layer = '-'.join(list(map(str, hidden)))
        model = simple_layer_net(n_in=k, n_out=ways, hidden=hidden,
                                 activation=activation)
    summary(model, (k,))

    # [TODO] add weight decay
    loss = nn.CrossEntropyLoss(reduction='mean')
    
    # load trained weights
    assert ckpt is not None
    assert os.path.exists(ckpt)
    model_name = os.path.basename(ckpt)
    model_path = ckpt

    model.load_state_dict(torch.load(model_path))
    print(f'loaded {model_path}')
    model.to(device)
    maml = l2l.algorithms.MAML(model, lr=fast_lr, first_order=False)

    # Create feature extractor before polluting the data
    tasksets = get_tasksets(seed=seed,
                            name=dataset,
                            train_ways=ways,
                            train_samples=2 * shots,
                            test_ways=ways,
                            test_samples=2 * shots,
                            num_tasks=num_tasks,
                            root='~/data',
                            )
    set_seed(seed)
    if dataset == 'omniglot':
        crop_size = int(np.sqrt(k))
        k = crop_size ** 2
        feature = TensorImageFFT(crop_shape=(crop_size, crop_size))
    else:
        feature = SiftFeature(tasksets.train, k=k, name='mifeature', use_cache=True,
                              n_sample=num_sift_train_tasks, pkl_path=sift_centroids)
    
    tasksets = ImpureTasksets(tasksets, num_tasks=num_tasks, ways=ways, shots=shots,
                              train_mask_labels=mask_labels,
                              train_mask_tasks=mask_tasks,
                              train_noise_tasks=noise_tasks,
                              train_shuffle_tasks=shuffle_tasks,
                              savedir=save_dir
                              )
    train_tasks = tasksets.train
    
    def _preprocess(x):
        d, lbl = x
        x = [feature(d).to(device), lbl.to(device)]
        return x

    if opa:
        explainer = MAMLExplainerOPA(model=maml, adapt_lr=fast_lr,
                                     ortho_vectors=True,
                                     num_hessian_elements=num_hessian_elements,
                                     # savedir=save_dir,
                                     # tag=explainer_tag, 
                                     )
    else:
        explainer = MAMLExplainer(model=maml, adapt_lr=fast_lr)

    for ii in range(num_tasks):
        learner = maml.clone()
        task_train = train_tasks[ii]
        task_train = _preprocess(task_train)
        result = xfast_adapt(batch=task_train, learner=learner, loss=loss,
                             shots=shots, ways=ways)

        if opa:
            explainer.add_src_test_error(error=result['evaluation']['error'] / num_tasks,
                                         model_output=result['evaluation']['prediction'])
        else:
            explainer.add_src_test_error(result['evaluation']['error'] / num_tasks)

        # Print some metrics
        print('\n')
        print('Iteration', ii)
        print('Meta Train ZeroShot Error', result['train']['error'].item())
        print('Meta Train ZeroShot Accuracy', result['train']['accuracy'].item())
        print('Meta Train Adaptation Error', result['adaptation']['error'].item())
        print('Meta Train Adaptation Accuracy', result['adaptation']['accuracy'].item())
        print('Meta Train Evaluation Error', result['evaluation']['error'].item())
        print('Meta Train Evaluation Accuracy', result['evaluation']['accuracy'].item())

    explainer.set_src_param_matrix()

    # save explainer
    explainer_name = 'expl_' + model_name[:-len('pth')] + 'pkl'
    if opa:
        if num_hessian_elements is not None:
            explainer_name = f'opa_nh{num_hessian_elements}_' + explainer_name
        else:
            explainer_name = 'opa_' + explainer_name
    explainer_path = os.path.join(save_dir, explainer_name)
    with open(explainer_path, 'wb') as f:
        pickle.dump(explainer, f)
    print(f'saved {explainer_path}')


if __name__ == '__main__':
    parser = ArgumentParser()
    parser.add_argument('--shots', type=int, default=5)
    parser.add_argument('--ways', type=int, default=5)
    parser.add_argument('--cuda', action='store_true')
    parser.add_argument('--k', type=int, default=16)
    parser.add_argument('--ckpt', type=str, default=None)
    parser.add_argument('--num-tasks', type=int, default=-1)
    parser.add_argument('--layer', type=int, default=None)
    parser.add_argument('--hidden', type=str, default='16')
    parser.add_argument('--mask-labels', type=int, default=None)
    parser.add_argument('--mask-tasks', type=int, default=None)
    parser.add_argument('--noise-tasks', type=int, default=None)
    parser.add_argument('--shuffle-tasks', type=int, default=None)
    parser.add_argument('--train-masks-path', type=str, default=None)
    parser.add_argument('--activation', type=str, default=None)
    parser.add_argument('--num-sift-train-tasks', type=int, default=None)
    parser.add_argument('--weight-decay', type=float, default=0.)
    parser.add_argument('--sift-centroids', type=str, default=None)
    parser.add_argument('--opa', action='store_true')
    parser.add_argument('--num-hessian-elements', type=int, default=None)
    parser.add_argument('--dataset', type=str, default='mini-imagenet')
    args = parser.parse_args()
    hidden = [int(x.strip()) for x in args.hidden.split(',')]

    main(shots=args.shots,
         ways=args.ways,
         cuda=args.cuda,
         k=args.k,
         ckpt=args.ckpt,
         num_tasks=args.num_tasks,
         layer=args.layer,
         hidden=hidden,
         mask_labels=args.mask_labels,
         mask_tasks=args.mask_tasks,
         noise_tasks=args.noise_tasks,
         shuffle_tasks=args.shuffle_tasks,
         train_masks_path=args.train_masks_path,
         activation=args.activation,
         num_sift_train_tasks=args.num_sift_train_tasks,
         sift_centroids=args.sift_centroids,
         opa=args.opa,
         num_hessian_elements=args.num_hessian_elements,
         dataset=args.dataset
         )
