import numpy as np
import torch
from xmeta.networks.simple_networks import \
    one_layer_net, two_layer_net, three_layer_net, simple_layer_net
from xmeta.utils.seed import set_seed
from xmeta.maml.maml import fast_adapt, meta_test
import learn2learn as l2l
from torch import nn, optim
from torchsummary import summary
from xmeta.utils.sift import SiftFeature
from argparse import ArgumentParser
import os
from xmeta.utils.csv import output_dict
from xmeta.utils.data import ImpureTasksets, get_tasksets
from xmeta.utils.preprocess import TensorImageFFT
import random
import datetime
import sys
from torch.utils.tensorboard import SummaryWriter


def _build_model(layer, hidden, activation, n_in, n_out):
    """Create the classifier network and the layer tag used in checkpoint names.

    ``layer`` values 1, 2 and 3 select the fixed-depth networks; any other
    value (including ``None``) builds ``simple_layer_net`` from ``hidden``.

    Returns:
        ``(model, layer_tag)``. ``layer_tag`` is ``layer`` itself for the
        fixed-depth networks, otherwise the hidden sizes joined with ``-``.
    """
    if layer == 1:
        return one_layer_net(n_in=n_in, n_out=n_out), layer
    if layer == 2:
        return two_layer_net(n_in=n_in, n_out=n_out), layer
    if layer == 3:
        return three_layer_net(n_in=n_in, n_out=n_out), layer

    # Normalise here (not in the signature) so no list is shared across calls.
    hidden = [] if hidden is None else list(hidden)
    layer_tag = '-'.join(str(width) for width in hidden)
    model = simple_layer_net(n_in=n_in, n_out=n_out, hidden=hidden,
                             activation=activation)
    return model, layer_tag


def _checkpoint_name(k, layer_tag, num_tasks, meta_batch_size, ways, shots,
                     mask_labels, mask_tasks, noise_tasks, shuffle_tasks, step):
    """Build the ``.pth`` file name that encodes the run configuration."""
    name = (f'maml_k{k}_layer{layer_tag}_tasks{num_tasks}_mbs{meta_batch_size}'
            f'_ways{ways}_shots{shots}')
    optional_tags = (('mask', mask_labels), ('mt', mask_tasks),
                     ('nt', noise_tasks), ('st', shuffle_tasks))
    for tag, value in optional_tags:
        if value is not None:
            name += f'_{tag}{value}'
    return f'{name}_{step}.pth'


def main(
        ways=5,
        shots=5,
        meta_lr=0.003,
        fast_lr=0.5,
        meta_batch_size=32,
        adaptation_steps=1,
        num_iterations=60000,
        num_test_iterations=100,
        ckpt_interval=100,
        cuda=True,
        seed=42,
        k=10,
        ckpt=None,
        num_tasks=-1,
        layer=2,
        hidden=None,
        mask_labels=None,
        mask_tasks=None,
        noise_tasks=None,
        shuffle_tasks=None,
        activation=None,
        num_sift_train_tasks=None,
        save_tasksets=True,
        weight_decay=0.,
        sift_centroids=None,
        dataset='mini-imagenet'
):
    """Meta-train a small network with MAML and periodically checkpoint/test it.

    Outputs (TensorBoard events, the command line, checkpoints and optionally
    the tasksets) are written to a timestamped directory under ``./cache``.

    Args:
        ways: Number of classes per task; also the model's output size.
        shots: Shots per class; tasksets are requested with ``2 * shots``
            as the sample count.
        meta_lr: Learning rate of the Adam meta-optimizer.
        fast_lr: Learning rate handed to ``l2l.algorithms.MAML``.
        meta_batch_size: Tasks accumulated per meta-update.
        adaptation_steps: Accepted for interface compatibility; not used here.
        num_iterations: Number of meta-updates to run.
        num_test_iterations: Accepted for interface compatibility; not used here.
        ckpt_interval: A checkpoint and meta-test run happen every this many
            iterations, and always after the final iteration.
        cuda: Use the ``cuda`` device when true, otherwise ``cpu``.
        seed: Seed passed to ``get_tasksets`` and ``set_seed``.
        k: Feature dimension fed to the model. For ``'omniglot'`` it is
            rounded down to a perfect square (side ``int(sqrt(k))``).
        ckpt: Optional path of a state dict to load before training.
        num_tasks: When positive, tasks are drawn by random index below this
            value; otherwise ``.sample()`` is used. ``-1`` cannot be combined
            with ``save_tasksets``.
        layer: 1, 2 or 3 selects a fixed-depth network; anything else builds
            ``simple_layer_net`` from ``hidden``.
        hidden: Hidden layer sizes for ``simple_layer_net`` (``None`` means
            an empty list).
        mask_labels, mask_tasks, noise_tasks, shuffle_tasks: Forwarded to
            ``ImpureTasksets`` as the corresponding ``train_*`` arguments and
            appended to checkpoint names when not ``None``.
        activation: Forwarded to ``simple_layer_net``.
        num_sift_train_tasks: Forwarded to ``SiftFeature`` as ``n_sample``.
        save_tasksets: Call ``tasksets.save()`` after training.
        weight_decay: Weight decay of the Adam meta-optimizer.
        sift_centroids: Forwarded to ``SiftFeature`` as ``pkl_path``.
        dataset: Dataset name passed to ``get_tasksets``; ``'omniglot'``
            switches the feature extractor to ``TensorImageFFT``.

    Returns:
        ``(meta_train_error, meta_train_accuracy, meta_test_error,
        meta_test_accuracy)`` from the last iteration / last meta-test run.
        All four are ``None`` when ``num_iterations`` is 0.

    Raises:
        ValueError: If ``num_tasks == -1`` while ``save_tasksets`` is truthy.
        FileNotFoundError: If ``ckpt`` is given but does not exist.
    """
    # Validate before touching the filesystem so a bad call leaves no
    # half-initialised run directory behind (and survives ``python -O``).
    if num_tasks == -1 and save_tasksets:
        raise ValueError('save_tasksets requires a fixed num_tasks '
                         '(num_tasks=-1 samples tasks on the fly)')
    if ckpt is not None and not os.path.exists(ckpt):
        raise FileNotFoundError(f'checkpoint not found: {ckpt}')

    timestamp = datetime.datetime.now().strftime('%Y-%m%d-%H%M%S')
    save_dir = os.path.join('./cache', timestamp)
    writer = SummaryWriter(log_dir=save_dir)

    # try/finally guarantees the event file is flushed and closed even when
    # training aborts part-way through.
    try:
        os.makedirs(save_dir, exist_ok=True)

        # Record the exact command line next to the results.
        args_file_name = 'args_' + os.path.basename(__file__) + '.txt'
        args_file_path = os.path.join(save_dir, args_file_name)
        with open(args_file_path, mode='a') as f:
            f.write(" ".join(sys.argv))

        device = torch.device('cuda' if cuda else 'cpu')

        # Load train/validation/test tasksets using the benchmark interface
        tasksets = get_tasksets(seed=seed,
                                name=dataset,
                                train_ways=ways,
                                train_samples=2 * shots,
                                test_ways=ways,
                                test_samples=2 * shots,
                                num_tasks=num_tasks,
                                root='~/data',
                                )

        # Create feature extractor before polluting the data
        set_seed(seed)
        if dataset == 'omniglot':
            crop_size = int(np.sqrt(k))
            k = crop_size ** 2
            feature = TensorImageFFT(crop_shape=(crop_size, crop_size))
        else:
            feature = SiftFeature(tasksets.train, k=k, name='mifeature',
                                  use_cache=True,
                                  n_sample=num_sift_train_tasks,
                                  pkl_path=sift_centroids)

        set_seed(seed)
        tasksets = ImpureTasksets(tasksets, num_tasks=num_tasks, ways=ways,
                                  shots=shots,
                                  train_mask_labels=mask_labels,
                                  train_mask_tasks=mask_tasks,
                                  train_noise_tasks=noise_tasks,
                                  train_shuffle_tasks=shuffle_tasks,
                                  savedir=save_dir)

        # Create model
        model, layer_tag = _build_model(layer, hidden, activation,
                                        n_in=k, n_out=ways)

        # NOTE(review): summary() runs before model.to(device); confirm that
        # torchsummary's default device agrees with where the model lives.
        summary(model, (k, ))

        if ckpt is not None:
            # map_location lets a checkpoint saved on GPU load on a CPU run.
            model.load_state_dict(torch.load(ckpt, map_location=device))
            print(f'loaded {ckpt}')

        model.to(device)
        maml = l2l.algorithms.MAML(model, lr=fast_lr, first_order=False)

        opt = optim.Adam(maml.parameters(), meta_lr, weight_decay=weight_decay)
        loss = nn.CrossEntropyLoss(reduction='mean')

        def _preprocess(x):
            # Extract features and move both features and labels to `device`.
            d, lbl = x
            return [feature(d).to(device), lbl.to(device)]

        def _draw_task(taskset):
            # Fixed task pool: pick by random index; otherwise sample freshly.
            if num_tasks > 0:
                return taskset[random.randrange(num_tasks)]
            return taskset.sample()

        # Defined up front so the return statement is valid even when the
        # loop body never runs (num_iterations == 0).
        meta_train_error = meta_train_accuracy = None
        meta_test_error = meta_test_accuracy = None

        set_seed(seed)
        for iteration in range(num_iterations):
            opt.zero_grad()
            meta_train_error = 0.0
            meta_train_accuracy = 0.0
            meta_valid_error = 0.0
            meta_valid_accuracy = 0.0
            for _ in range(meta_batch_size):
                # Compute meta-training loss
                learner = maml.clone()
                task_train = _preprocess(_draw_task(tasksets.train))
                evaluation_error, evaluation_accuracy = fast_adapt(
                    task_train, learner, loss, shots, ways)
                evaluation_error.backward()
                meta_train_error += evaluation_error.item()
                meta_train_accuracy += evaluation_accuracy.item()

                # Compute meta-validation loss (no backward: monitoring only)
                learner = maml.clone()
                task_vali = _preprocess(_draw_task(tasksets.validation))
                evaluation_error, evaluation_accuracy = fast_adapt(
                    task_vali, learner, loss, shots, ways)
                meta_valid_error += evaluation_error.item()
                meta_valid_accuracy += evaluation_accuracy.item()

            meta_train_error = meta_train_error / meta_batch_size
            meta_train_accuracy = meta_train_accuracy / meta_batch_size
            meta_valid_error = meta_valid_error / meta_batch_size
            meta_valid_accuracy = meta_valid_accuracy / meta_batch_size

            # Print some metrics
            print('\n')
            print('Iteration', iteration)
            print('Meta Train Error', meta_train_error)
            print('Meta Train Accuracy', meta_train_accuracy)
            print('Meta Validation Error', meta_valid_error)
            print('Meta Validation Accuracy', meta_valid_accuracy)
            # x-axis of the curves: n_data grows with iteration, batch size,
            # shots and ways.
            n_data = iteration * meta_batch_size * 2 * shots * ways
            writer.add_scalar('loss/MetaTrainError', meta_train_error, n_data)
            writer.add_scalar('eval/MetaTrainAccuracy', meta_train_accuracy, n_data)
            writer.add_scalar('loss/MetaValidationError', meta_valid_error, n_data)
            writer.add_scalar('eval/MetaValidationAccuracy', meta_valid_accuracy, n_data)

            # Average the accumulated gradients and optimize
            for p in maml.parameters():
                # Parameters that received no gradient are left untouched.
                if p.grad is not None:
                    p.grad.data.mul_(1.0 / meta_batch_size)
            opt.step()
            step = iteration + 1

            # Save model weights
            if step % ckpt_interval == 0 or step == num_iterations:
                model_name = _checkpoint_name(
                    k, layer_tag, num_tasks, meta_batch_size, ways, shots,
                    mask_labels, mask_tasks, noise_tasks, shuffle_tasks, step)
                model_path = os.path.join(save_dir, model_name)
                torch.save(model.state_dict(), model_path)

                # Evaluate with the configured ways/shots so the test tasks
                # match how the tasksets were built above.
                meta_test_error, meta_test_accuracy = meta_test(
                    model=maml,
                    tasksets=tasksets,
                    preprocess=_preprocess,
                    shots=shots, ways=ways,
                    num_test_tasks=num_tasks,
                )
                writer.add_scalar('loss/MetaTestError', meta_test_error, n_data)
                writer.add_scalar('eval/MetaTestAccuracy', meta_test_accuracy, n_data)

        if save_tasksets:
            tasksets.save()
    finally:
        writer.close()

    return meta_train_error, meta_train_accuracy,\
        meta_test_error, meta_test_accuracy


if __name__ == '__main__':
    # Command-line entry point: parse options, run main(), then append the
    # options together with the final metrics to a CSV under ./cache.
    cli = ArgumentParser()
    # (flag, add_argument keyword arguments), kept in registration order.
    option_specs = (
        ('--meta-batch-size', dict(type=int, default=32)),
        ('--num-iterations', dict(type=int, default=100)),
        ('--num-test-iterations', dict(type=int, default=100)),
        ('--ways', dict(type=int, default=5)),
        ('--shots', dict(type=int, default=5)),
        ('--cuda', dict(action='store_true')),
        ('--k', dict(type=int, default=16)),
        ('--ckpt', dict(type=str, default=None)),
        ('--num-tasks', dict(type=int, default=-1)),
        ('--layer', dict(type=int, default=None)),
        ('--hidden', dict(type=str, default='16')),
        ('--mask-labels', dict(type=int, default=None)),
        ('--mask-tasks', dict(type=int, default=None)),
        ('--noise-tasks', dict(type=int, default=None)),
        ('--shuffle-tasks', dict(type=int, default=None)),
        ('--activation', dict(type=str, default=None)),
        ('--num-sift-train-tasks', dict(type=int, default=None)),
        ('--save-tasksets', dict(action='store_true')),
        ('--weight-decay', dict(type=float, default=0.)),
        ('--sift-centroids', dict(type=str, default=None)),
        ('--dataset', dict(type=str, default='mini-imagenet')),
    )
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    args = cli.parse_args()

    # Every parsed option maps one-to-one onto a main() keyword; only
    # --hidden needs converting from "a,b,c" to a list of ints. A copy is
    # used so the namespace keeps the raw string for the CSV row.
    run_kwargs = dict(vars(args))
    run_kwargs['hidden'] = [int(tok.strip()) for tok in args.hidden.split(',')]

    train_err, train_acc, test_err, test_acc = main(**run_kwargs)

    # vars(args) is the namespace's own dict, so the metrics are appended
    # after the option columns.
    out_dict = vars(args)
    out_dict['train_error'] = train_err
    out_dict['train_accuracy'] = train_acc
    out_dict['test_error'] = test_err
    out_dict['test_accuracy'] = test_acc

    script_stem, _ = os.path.splitext(os.path.basename(__file__))
    csv_path = os.path.join('./cache', script_stem + '.csv')
    output_dict(out_dict=out_dict, csv_path=csv_path)
