import torch
from xmeta.maml.maml import xfast_adapt
import learn2learn as l2l
from torch import nn
from torchsummary import summary
import pickle
from argparse import ArgumentParser
import os
import sys
from xmeta.utils.data import ImpureTasksets, get_tasksets, RotatedTaskset
import numpy as np


def main(
        ways=5,
        shots=5,
        fast_lr=0.5,
        cuda=True,
        seed=42,
        ckpt='./cache/model.pth',
        num_tasks=-1,
        mask_labels=None,
        mask_tasks=None,
        noise_tasks=None,
        shuffle_tasks=None,
        dark_tasks=None,
        recolor_tasks=None,
        bgr_tasks=None,
        num_rotations=None,
        test_with_test_tasks=False,
        dataset='mini-imagenet',
):
    """Evaluate a trained MAML checkpoint task by task and print the metrics.

    The function appends the invoking command line to an
    ``args_<script name>.txt`` file in the checkpoint's directory, builds the
    tasksets, restores the model weights from ``ckpt``, wraps the model in
    ``l2l.algorithms.MAML`` and runs ``xfast_adapt`` on tasks ``0`` to
    ``num_tasks - 1``. After every task it prints that task's metrics and the
    running mean of the evaluation error/accuracy collected so far.

    Args:
        ways: Number of classes per task; also the model's output size.
        shots: Passed to ``xfast_adapt``; the tasksets are requested with
            ``2 * shots`` samples (presumably split between adaptation and
            evaluation -- confirm in ``xfast_adapt``).
        fast_lr: Learning rate handed to ``l2l.algorithms.MAML`` as ``lr``.
        cuda: Run on ``'cuda'`` when true, otherwise on ``'cpu'``.
        seed: Seed forwarded to ``get_tasksets``.
        ckpt: Path of the state-dict file to load. Must exist.
        num_tasks: Number of task indices to evaluate; also forwarded to
            ``get_tasksets`` and ``ImpureTasksets``. Values <= 0 evaluate
            nothing here.
        mask_labels, mask_tasks, noise_tasks, shuffle_tasks, dark_tasks,
        recolor_tasks, bgr_tasks: Forwarded to ``ImpureTasksets`` as the
            corresponding ``train_*`` keyword arguments; their meaning is
            defined there.
        num_rotations: When not ``None``, the chosen taskset is wrapped in
            ``RotatedTaskset`` (``n_augment=num_rotations``) and every task
            index is evaluated ``num_rotations`` times.
        test_with_test_tasks: Evaluate ``tasksets.test`` instead of
            ``tasksets.train``.
        dataset: ``'omniglot'`` selects ``OmniglotCNN``; any other name
            selects ``MiniImagenetCNN``. Also forwarded to ``get_tasksets``.

    Raises:
        ValueError: If ``ckpt`` is ``None``.
        FileNotFoundError: If ``ckpt`` does not exist.
    """
    # Fail fast with an explicit error: the path is needed right away for the
    # args log, and an `assert` would be stripped under `python -O`.
    if ckpt is None:
        raise ValueError('a checkpoint path is required (ckpt is None)')
    if not os.path.exists(ckpt):
        raise FileNotFoundError(f'checkpoint not found: {ckpt}')

    if not os.path.exists('./cache'):
        os.makedirs('./cache')

    # Keep a log of every invocation next to the checkpoint.
    save_dir = os.path.dirname(ckpt)
    args_file_name = 'args_' + os.path.basename(__file__) + '.txt'
    args_file_path = os.path.join(save_dir, args_file_name)
    with open(args_file_path, mode='a') as f:
        f.write("\n" + " ".join(sys.argv))

    device = torch.device('cuda' if cuda else 'cpu')

    # Load train/validation/test tasksets using the benchmark interface
    tasksets = get_tasksets(seed=seed,
                            name=dataset,
                            train_ways=ways,
                            train_samples=2 * shots,
                            test_ways=ways,
                            test_samples=2 * shots,
                            num_tasks=num_tasks,
                            root='~/data',
                            )

    if dataset == 'omniglot':
        model = l2l.vision.models.OmniglotCNN(ways)
        input_size = (1, 28, 28)
    else:
        model = l2l.vision.models.MiniImagenetCNN(ways)
        input_size = (3, 84, 84)

    # Move the model before summarising it.
    # NOTE(review): torchsummary appears to feed CUDA inputs by default when
    # CUDA is available, which would clash with a CPU-resident model --
    # confirm against the installed torchsummary version.
    model.to(device)
    summary(model, input_size)

    # [TODO] add weight decay
    loss = nn.CrossEntropyLoss(reduction='mean')

    # Load trained weights. `map_location` lets a checkpoint saved on one
    # device (e.g. GPU) be restored on another (e.g. CPU).
    model_path = ckpt
    model.load_state_dict(torch.load(model_path, map_location=device))
    print(f'loaded {model_path}')
    maml = l2l.algorithms.MAML(model, lr=fast_lr, first_order=False)

    tasksets = ImpureTasksets(tasksets, num_tasks=num_tasks, ways=ways, shots=shots,
                              train_mask_labels=mask_labels,
                              train_mask_tasks=mask_tasks,
                              train_noise_tasks=noise_tasks,
                              train_shuffle_tasks=shuffle_tasks,
                              train_dark_tasks=dark_tasks,
                              train_recolor_tasks=recolor_tasks,
                              train_bgr_tasks=bgr_tasks,
                              savedir=save_dir
                              )
    taskset = tasksets.test if test_with_test_tasks else tasksets.train

    if num_rotations is not None:
        taskset = RotatedTaskset(taskset, shots=2 * shots, ways=ways,
                                 n_augment=num_rotations)

    def _to_device(task):
        # A task is a (data, labels) pair; move both tensors to the device.
        data, labels = task
        return [data.to(device), labels.to(device)]

    if num_tasks <= 0:
        # range(num_tasks) is empty in this case, so say so instead of
        # finishing without any output.
        print(f'warning: num_tasks={num_tasks}; no tasks will be evaluated',
              file=sys.stderr)

    # Each task index is evaluated once, or `num_rotations` times when the
    # rotated taskset is in use.
    repeats = 1 if num_rotations is None else num_rotations

    errors = []
    accuracies = []
    for ii in range(num_tasks):
        result = None
        for _ in range(repeats):
            # Adapt a fresh clone so tasks never influence each other.
            learner = maml.clone()
            task_train = _to_device(taskset[ii])
            result = xfast_adapt(task_train, learner, loss, shots, ways)
            errors.append(result['evaluation']['error'].item())
            accuracies.append(result['evaluation']['accuracy'].item())

        if result is None:
            # No evaluation ran (num_rotations <= 0): nothing to report.
            continue

        # Print some metrics (of the last evaluation for this task index)
        print('\n')
        print('Iteration', ii)
        print('Meta Train ZeroShot Error', result['train']['error'].item())
        print('Meta Train ZeroShot Accuracy', result['train']['accuracy'].item())
        print('Meta Train Adaptation Error', result['adaptation']['error'].item())
        print('Meta Train Adaptation Accuracy', result['adaptation']['accuracy'].item())
        print('Meta Train Evaluation Error', result['evaluation']['error'].item())
        print('Meta Train Evaluation Accuracy', result['evaluation']['accuracy'].item())

        # Running averages over every evaluation performed so far.
        print(f'average error: {np.mean(errors)}')
        print(f'average accuracy: {np.mean(accuracies)}')


if __name__ == '__main__':
    # Command-line entry point: parse the options and forward them to main().
    parser = ArgumentParser()
    parser.add_argument('--shots', type=int, default=5)
    parser.add_argument('--ways', type=int, default=5)
    # Defaults of the next two options mirror main()'s own defaults.
    parser.add_argument('--fast-lr', type=float, default=0.5)
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--cuda', action='store_true')
    # main() cannot run without a checkpoint path, so let argparse report a
    # missing one as a usage error instead of failing later with a traceback.
    parser.add_argument('--ckpt', type=str, required=True)
    parser.add_argument('--num-tasks', type=int, default=-1)
    parser.add_argument('--mask-labels', type=int, default=None)
    parser.add_argument('--mask-tasks', type=int, default=None)
    parser.add_argument('--noise-tasks', type=int, default=None)
    parser.add_argument('--shuffle-tasks', type=int, default=None)
    parser.add_argument('--dark-tasks', type=int, default=None)
    parser.add_argument('--recolor-tasks', type=int, default=None)
    parser.add_argument('--bgr-tasks', type=int, default=None)
    parser.add_argument('--num-rotations', type=int, default=None)
    parser.add_argument('--test-with-test-tasks', action='store_true')
    parser.add_argument('--dataset', type=str, default='mini-imagenet')
    args = parser.parse_args()

    main(shots=args.shots,
         ways=args.ways,
         fast_lr=args.fast_lr,
         cuda=args.cuda,
         seed=args.seed,
         ckpt=args.ckpt,
         num_tasks=args.num_tasks,
         mask_labels=args.mask_labels,
         mask_tasks=args.mask_tasks,
         noise_tasks=args.noise_tasks,
         shuffle_tasks=args.shuffle_tasks,
         dark_tasks=args.dark_tasks,
         recolor_tasks=args.recolor_tasks,
         bgr_tasks=args.bgr_tasks,
         num_rotations=args.num_rotations,
         test_with_test_tasks=args.test_with_test_tasks,
         dataset=args.dataset,
         )
