import numpy as np
import torch
from xmeta.utils.seed import set_seed
from xmeta.maml.maml import fast_adapt, meta_test
import learn2learn as l2l
from torch import nn, optim
from torchsummary import summary
from argparse import ArgumentParser
import os
from xmeta.utils.csv import output_dict
from xmeta.utils.data import ImpureTasksets, get_tasksets, RotatedTaskset
import random
import datetime
import sys
from torch.utils.tensorboard import SummaryWriter


def main(
        ways=5,
        shots=5,
        meta_lr=0.003,
        fast_lr=0.5,
        meta_batch_size=32,
        num_iterations=60000,
        ckpt_interval=1000,
        cuda=True,
        seed=42,
        num_tasks=-1,
        mask_labels=None,
        mask_tasks=None,
        noise_tasks=None,
        shuffle_tasks=None,
        dark_tasks=None,
        recolor_tasks=None,
        bgr_tasks=None,
        save_tasksets=True,
        weight_decay=0.,
        validation_interval=8,
        sgd=False,
        num_rotations=None,
        dataset='mini-imagenet',
):
    """Meta-train a MAML model, logging, checkpointing and meta-testing it.

    Every run writes into a fresh time-stamped directory under ``./cache``:
    the command line, TensorBoard scalars, model checkpoints and (through
    ``ImpureTasksets(savedir=...)``) whatever the taskset wrapper stores.

    Args:
        ways: Number of classes per task; also the model's output size.
        shots: Shots passed to ``fast_adapt``/``meta_test``; tasks are built
            with ``2 * shots`` samples.
        meta_lr: Learning rate of the outer optimizer.
        fast_lr: Learning rate handed to ``l2l.algorithms.MAML``.
        meta_batch_size: Number of training tasks accumulated per iteration.
        num_iterations: Number of outer-loop iterations.
        ckpt_interval: A checkpoint and a meta-test run happen every this
            many iterations, and always on the last iteration.
        cuda: Use the ``cuda`` device when true, otherwise ``cpu``.
        seed: Seed given to ``get_tasksets`` and ``set_seed``.
        num_tasks: When positive, tasks are drawn by random index from a
            pool of this size; otherwise tasks are drawn with ``.sample()``.
        mask_labels, mask_tasks, noise_tasks, shuffle_tasks, dark_tasks,
        recolor_tasks, bgr_tasks: Forwarded to ``ImpureTasksets`` as the
            matching ``train_*`` arguments; each one that is not ``None`` is
            also encoded in the checkpoint file name.
        save_tasksets: Call ``tasksets.save()`` after training. Must be
            ``False`` when ``num_tasks == -1``; note that the two defaults
            violate this, so callers have to override one of them.
        weight_decay: Weight decay of the outer optimizer.
        validation_interval: Within a meta-batch, a validation task is
            evaluated whenever the task index is a multiple of this value.
        sgd: Use ``optim.SGD`` instead of ``optim.Adam``.
        num_rotations: When not ``None``, training tasks are wrapped in
            ``RotatedTaskset`` with ``n_augment=num_rotations``.
        dataset: ``'omniglot'`` selects ``OmniglotCNN``; any other value
            selects ``MiniImagenetCNN``.

    Returns:
        ``(meta_train_error, meta_train_accuracy, meta_test_error,
        meta_test_accuracy)``: the training metrics of the last iteration
        and the result of the last meta-test. All four are ``None`` when
        ``num_iterations`` is 0.

    Raises:
        AssertionError: If ``num_tasks == -1`` while ``save_tasksets`` is
            not ``False``.
    """
    # Fail fast, before any directory, log file or writer is created.
    if num_tasks == -1:
        assert save_tasksets is False, \
            'save_tasksets must be False when num_tasks == -1'

    save_dir = os.path.join(
        './cache', datetime.datetime.now().strftime('%Y-%m%d-%H%M%S'))
    os.makedirs(save_dir, exist_ok=True)

    # Keep the exact command line next to the outputs of the run.
    args_file_name = 'args_' + os.path.basename(__file__) + '.txt'
    args_file_path = os.path.join(save_dir, args_file_name)
    with open(args_file_path, mode='a') as f:
        f.write(" ".join(sys.argv))

    device = torch.device('cuda' if cuda else 'cpu')

    # Load train/validation/test tasksets using the benchmark interface
    tasksets = get_tasksets(seed=seed,
                            name=dataset,
                            train_ways=ways,
                            train_samples=2 * shots,
                            test_ways=ways,
                            test_samples=2 * shots,
                            num_tasks=num_tasks,
                            root='~/data',
                            )
    set_seed(seed)
    tasksets = ImpureTasksets(tasksets, num_tasks=num_tasks, ways=ways, shots=shots,
                              train_mask_labels=mask_labels,
                              train_mask_tasks=mask_tasks,
                              train_noise_tasks=noise_tasks,
                              train_shuffle_tasks=shuffle_tasks,
                              train_dark_tasks=dark_tasks,
                              train_recolor_tasks=recolor_tasks,
                              train_bgr_tasks=bgr_tasks,
                              savedir=save_dir)

    # NOTE(review): summary() runs before the model is moved to `device`;
    # confirm the installed torchsummary feeds it an input on a matching device.
    if dataset == 'omniglot':
        model = l2l.vision.models.OmniglotCNN(ways)
        summary(model, (1, 28, 28))
    else:
        model = l2l.vision.models.MiniImagenetCNN(ways)
        summary(model, (3, 84, 84))

    model.to(device)
    maml = l2l.algorithms.MAML(model, lr=fast_lr, first_order=False)

    optimizer_cls = optim.SGD if sgd else optim.Adam
    opt = optimizer_cls(maml.parameters(), meta_lr, weight_decay=weight_decay)
    loss = nn.CrossEntropyLoss(reduction='mean')

    def _preprocess(x):
        """Move a ``(data, labels)`` task onto the training device."""
        d, lbl = x
        return [d.to(device), lbl.to(device)]

    if num_rotations is None:
        tasks_train = tasksets.train
    else:
        tasks_train = RotatedTaskset(tasksets.train, shots=2 * shots, ways=ways,
                                     n_augment=num_rotations)

    # The checkpoint name only varies in its iteration suffix, so build the
    # run-specific prefix once instead of at every checkpoint.
    model_prefix = (f'maml_tasks{num_tasks}'
                    f'_mbs{meta_batch_size}_ways{ways}_shots{shots}')
    for tag, value in (('mask', mask_labels),
                       ('mt', mask_tasks),
                       ('nt', noise_tasks),
                       ('st', shuffle_tasks),
                       ('dt', dark_tasks),
                       ('rt', recolor_tasks),
                       ('bt', bgr_tasks)):
        if value is not None:
            model_prefix += f'_{tag}{value}'
    if weight_decay > 0:
        model_prefix += f'_wd{weight_decay}'
    if num_rotations is not None:
        model_prefix += f'_nrot{num_rotations}'

    # Pre-bind the results so that a run with num_iterations == 0 returns
    # None values instead of raising UnboundLocalError.
    meta_train_error = meta_train_accuracy = None
    meta_test_error = meta_test_accuracy = None

    writer = SummaryWriter(log_dir=save_dir)
    try:
        set_seed(seed)
        for iteration in range(num_iterations):
            opt.zero_grad()
            meta_train_errors = []
            meta_train_accuracys = []
            meta_valid_errors = []
            meta_valid_accuracys = []
            for ii in range(meta_batch_size):
                # Compute meta-training loss
                learner = maml.clone()

                if num_tasks > 0:
                    jj = random.randrange(num_tasks)
                    task = tasks_train[jj]
                else:
                    task = tasks_train.sample()

                task_train = _preprocess(task)
                evaluation_error, evaluation_accuracy = fast_adapt(task_train,
                                                                   learner,
                                                                   loss,
                                                                   shots,
                                                                   ways,
                                                                   )
                # Gradients accumulate on the MAML parameters across tasks.
                evaluation_error.backward()
                meta_train_errors.append(evaluation_error.item())
                meta_train_accuracys.append(evaluation_accuracy.item())

                # Compute meta-validation loss. ii == 0 always satisfies the
                # condition, so every meta-batch validates at least once.
                if ii % validation_interval == 0:
                    learner = maml.clone()

                    if num_tasks > 0:
                        kk = random.randrange(num_tasks)
                        task_vali = tasksets.validation[kk]
                    else:
                        task_vali = tasksets.validation.sample()

                    task_vali = _preprocess(task_vali)
                    evaluation_error, evaluation_accuracy = fast_adapt(task_vali,
                                                                       learner,
                                                                       loss,
                                                                       shots,
                                                                       ways,
                                                                       )
                    meta_valid_errors.append(evaluation_error.item())
                    meta_valid_accuracys.append(evaluation_accuracy.item())

            meta_train_error = np.mean(meta_train_errors)
            meta_train_accuracy = np.mean(meta_train_accuracys)
            meta_valid_error = np.mean(meta_valid_errors)
            meta_valid_accuracy = np.mean(meta_valid_accuracys)

            # Print some metrics
            print('\n')
            print('Iteration', iteration)
            print('Meta Train Error', meta_train_error)
            print('Meta Train Accuracy', meta_train_accuracy)
            print('Meta Validation Error', meta_valid_error)
            print('Meta Validation Accuracy', meta_valid_accuracy)
            # TensorBoard step shared by all scalars of this iteration.
            n_data = iteration * meta_batch_size * 2 * shots * ways
            writer.add_scalar('loss/MetaTrainError', meta_train_error, n_data)
            writer.add_scalar('eval/MetaTrainAccuracy', meta_train_accuracy, n_data)
            writer.add_scalar('loss/MetaValidationError', meta_valid_error, n_data)
            writer.add_scalar('eval/MetaValidationAccuracy', meta_valid_accuracy, n_data)

            # Average the accumulated gradients and optimize
            for p in maml.parameters():
                # Parameters that received no gradient have nothing to scale.
                if p.grad is not None:
                    p.grad.data.mul_(1.0 / meta_batch_size)
            opt.step()

            # Save model weights
            if (iteration % ckpt_interval) == 0 or (iteration == num_iterations - 1):
                model_path = os.path.join(save_dir,
                                          f'{model_prefix}_{iteration}.pth')
                torch.save(model.state_dict(), model_path)

                # Meta-test with the configuration the model was trained for;
                # a fixed 5-way/5-shot setting only matches the defaults.
                meta_test_error, meta_test_accuracy = meta_test(model=maml,
                                                                tasksets=tasksets,
                                                                preprocess=_preprocess,
                                                                shots=shots, ways=ways,
                                                                num_test_tasks=num_tasks
                                                                )
                writer.add_scalar('loss/MetaTestError', meta_test_error, n_data)
                writer.add_scalar('eval/MetaTestAccuracy', meta_test_accuracy, n_data)

        if save_tasksets:
            tasksets.save()
    finally:
        # Flush and release the event file even when training raises.
        writer.close()

    return meta_train_error, meta_train_accuracy,\
        meta_test_error, meta_test_accuracy


if __name__ == '__main__':
    # Command-line entry point: parse the options, run one training session
    # and append the options together with the final metrics to a CSV file.
    cli = ArgumentParser()
    cli.add_argument('--meta-batch-size', type=int, default=32)
    cli.add_argument('--num-iterations', type=int, default=100)
    cli.add_argument('--ways', type=int, default=5)
    cli.add_argument('--shots', type=int, default=5)
    cli.add_argument('--cuda', action='store_true')
    cli.add_argument('--num-tasks', type=int, default=-1)
    cli.add_argument('--mask-labels', type=int, default=None)
    cli.add_argument('--mask-tasks', type=int, default=None)
    cli.add_argument('--noise-tasks', type=int, default=None)
    cli.add_argument('--shuffle-tasks', type=int, default=None)
    cli.add_argument('--dark-tasks', type=int, default=None)
    cli.add_argument('--recolor-tasks', type=int, default=None)
    cli.add_argument('--bgr-tasks', type=int, default=None)
    cli.add_argument('--save-tasksets', action='store_true')
    cli.add_argument('--weight-decay', type=float, default=0.)
    cli.add_argument('--ckpt-interval', type=int, default=1000)
    cli.add_argument('--sgd', action='store_true')
    cli.add_argument('--meta-lr', type=float, default=0.003)
    cli.add_argument('--num-rotations', type=int, default=None)
    cli.add_argument('--dataset', type=str, default='mini-imagenet')
    args = cli.parse_args()

    # Every option destination above is also the name of a main() keyword
    # argument, so the parsed namespace can be forwarded as-is.
    train_error, train_accuracy, test_error, test_accuracy = main(**vars(args))

    # One CSV row: the parsed options followed by the resulting metrics.
    out_dict = vars(args)
    out_dict.update({
        'train_error': train_error,
        'train_accuracy': train_accuracy,
        'test_error': test_error,
        'test_accuracy': test_accuracy,
    })
    script_stem, _ = os.path.splitext(os.path.basename(__file__))
    csv_path = os.path.join('./cache', script_stem + '.csv')
    output_dict(out_dict=out_dict, csv_path=csv_path)
