import argparse
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import learn2learn as l2l
from learn2learn.data.transforms import NWays, KShots, LoadData, RemapLabels
from torchsummary import summary
from xmeta.networks.simple_networks import ConvFeature as Convnet
from xmeta.protonet.protonet import fast_adapt, pairwise_distances_logits, meta_test
from xmeta.utils.data import ImpureTasksets, get_tasksets, RotatedTaskset
from xmeta.utils.seed import set_seed
import os, sys
import datetime
from torch.utils.tensorboard import SummaryWriter
import random


def _sample_task(taskset, num_tasks):
    """Fetch one task from ``taskset``.

    When ``num_tasks > 0`` the taskset is indexed with a position drawn by
    ``random.randrange(num_tasks)``; otherwise ``taskset.sample()`` is called.
    """
    if num_tasks > 0:
        return taskset[random.randrange(num_tasks)]
    return taskset.sample()


def _run_episodes(model, taskset, args, device, optimizer=None):
    """Run ``args.meta_batch_size`` episodes and return ``(mean_loss, mean_acc)``.

    Each episode draws a task, moves its data and labels to ``device`` and
    evaluates it with ``fast_adapt`` using the *train* way/shot/query settings
    (the same settings are used for training and validation episodes).
    When ``optimizer`` is given, one optimizer step is taken after every
    episode; when it is ``None`` no backward pass is performed.

    Raises:
        ZeroDivisionError: if ``args.meta_batch_size`` is 0.
    """
    total_loss = 0.0
    total_acc = 0.0
    for _ in range(args.meta_batch_size):
        data, labels = _sample_task(taskset, args.num_tasks)
        batch = [data.to(device), labels.to(device)]
        loss, acc = fast_adapt(model,
                               batch,
                               args.train_way,
                               args.train_shot,
                               args.train_query,
                               metric=pairwise_distances_logits,
                               device=device)
        total_loss += loss.item()
        total_acc += acc.item()

        if optimizer is not None:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    return total_loss / args.meta_batch_size, total_acc / args.meta_batch_size


def _checkpoint_name(args, iteration):
    """Build the checkpoint file name for ``iteration`` from the run settings."""
    name = (f'protonet_tasks{args.num_tasks}_mbs{args.meta_batch_size}'
            f'_ways{args.train_way}_shots{args.train_shot}')
    if args.noise_tasks is not None:
        name += f'_nt{args.noise_tasks}'
    return f'{name}_{iteration}.pth'


def main(args):
    """Meta-train the model and log/checkpoint the run under ``./cache``.

    A timestamped directory ``./cache/<YYYY-mmdd-HHMMSS>`` receives the
    TensorBoard event files, a text file with the command line, and the model
    checkpoints. Every iteration runs ``args.meta_batch_size`` training
    episodes; validation runs every ``args.validation_interval`` iterations;
    a checkpoint plus a ``meta_test`` evaluation happens every
    ``args.ckpt_interval`` iterations and on the final iteration.

    Args:
        args: namespace with the attributes defined by this script's argument
            parser (``seed``, ``meta_batch_size``, ``num_iterations``,
            ``cuda``, ``deterministic``, ``num_tasks``, ``train_way``,
            ``train_shot``, ``train_query``, ``test_way``, ``test_shot``,
            ``test_query``, ``save_tasksets``, ``validation_interval``,
            ``ckpt_interval``, ``dataset``, ``noise_tasks``).

    Raises:
        ValueError: if ``args.save_tasksets`` is set while ``args.num_tasks``
            is -1.
    """
    # Validate before any directory or event file is created. An explicit
    # raise is used because ``assert`` is stripped under ``python -O``.
    if args.num_tasks == -1 and args.save_tasksets:
        raise ValueError('--save-tasksets requires --num-tasks to be set '
                         '(it cannot be -1)')

    if args.deterministic:
        torch.use_deterministic_algorithms(True)

    timestamp = datetime.datetime.now().strftime('%Y-%m%d-%H%M%S')
    save_dir = os.path.join('./cache', timestamp)
    os.makedirs(save_dir, exist_ok=True)
    writer = SummaryWriter(log_dir=save_dir)

    # The writer buffers events; close it even if training fails so that
    # everything logged so far reaches disk.
    try:
        # Record the exact command line next to the run's outputs.
        args_file_name = 'args_' + os.path.basename(__file__) + '.txt'
        args_file_path = os.path.join(save_dir, args_file_name)
        with open(args_file_path, mode='a') as f:
            f.write(" ".join(sys.argv))

        use_cuda = bool(args.cuda and torch.cuda.device_count())
        if use_cuda:
            print("Using gpu")
        set_seed(args.seed)
        device = torch.device('cuda' if use_cuda else 'cpu')

        set_seed(args.seed)
        if args.dataset == 'omniglot':
            model = Convnet(x_dim=1)
            input_shape = (1, 28, 28)
        else:
            model = Convnet()
            input_shape = (3, 84, 84)
        # Move the model first and tell torchsummary which device to probe on;
        # its default is "cuda", which would not match a model still on the CPU.
        model.to(device)
        summary(model, input_shape, device=device.type)

        tasksets = get_tasksets(seed=args.seed,
                                name=args.dataset,
                                train_ways=args.train_way,
                                train_samples=args.train_shot + args.train_query,
                                test_ways=args.test_way,
                                test_samples=args.test_shot + args.test_query,
                                num_tasks=args.num_tasks,
                                root='~/data',
                                )
        tasksets = ImpureTasksets(tasksets, num_tasks=args.num_tasks,
                                  ways=args.train_way, shots=args.train_shot,
                                  train_noise_tasks=args.noise_tasks,
                                  savedir=save_dir,
                                  seed=args.seed)
        if args.save_tasksets:
            tasksets.save()
        tasks_train = tasksets.train
        tasks_valid = tasksets.validation

        optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
        lr_scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer, step_size=20, gamma=0.5)

        # Number of samples consumed per iteration; used as the x-axis of the
        # TensorBoard curves.
        samples_per_iteration = (args.meta_batch_size
                                 * (args.train_shot + args.train_query)
                                 * args.train_way)
        last_iteration = args.num_iterations - 1

        set_seed(args.seed)
        for iteration in range(args.num_iterations):
            model.train()
            meta_train_error, meta_train_accuracy = _run_episodes(
                model, tasks_train, args, device, optimizer=optimizer)
            lr_scheduler.step()

            print('\n')
            print('Iteration', iteration)
            print('Meta Train Error', meta_train_error)
            print('Meta Train Accuracy', meta_train_accuracy)
            n_data = (iteration + 1) * samples_per_iteration
            writer.add_scalar('loss/MetaTrainError', meta_train_error, n_data)
            writer.add_scalar('eval/MetaTrainAccuracy', meta_train_accuracy, n_data)

            if (iteration % args.validation_interval) == 0:
                model.eval()
                # Only the scalar loss/accuracy values are read during
                # validation, so no autograd graph is kept.
                # NOTE(review): assumes fast_adapt does not itself need
                # gradients (true for a plain prototypical-network forward
                # pass) -- confirm against xmeta.protonet.protonet.
                with torch.no_grad():
                    meta_valid_error, meta_valid_accuracy = _run_episodes(
                        model, tasks_valid, args, device)

                print('Meta Validation Error', meta_valid_error)
                print('Meta Validation Accuracy', meta_valid_accuracy)
                writer.add_scalar('loss/MetaValidationError', meta_valid_error, n_data)
                writer.add_scalar('eval/MetaValidationAccuracy', meta_valid_accuracy, n_data)

            if (iteration % args.ckpt_interval) == 0 or iteration == last_iteration:
                model_path = os.path.join(save_dir, _checkpoint_name(args, iteration))
                torch.save(model.state_dict(), model_path)

                model.eval()
                meta_test_error, meta_test_accuracy = meta_test(
                    model=model, tasksets=tasksets,
                    shots=args.test_shot, ways=args.test_way, queries=args.test_query,
                    num_test_tasks=args.num_tasks, device=device)
                writer.add_scalar('loss/MetaTestError', meta_test_error, n_data)
                writer.add_scalar('eval/MetaTestAccuracy', meta_test_accuracy, n_data)
    finally:
        writer.close()


if __name__ == '__main__':
    # Command-line options as (flag, add_argument keyword arguments) pairs.
    # They are registered in exactly this order, which is also the order in
    # which they appear in ``--help``.
    option_specs = [
        ('--seed', dict(type=int, default=42)),
        ('--meta-batch-size', dict(type=int, default=32)),
        ('--num-iterations', dict(type=int, default=100)),
        ('--cuda', dict(action='store_true')),
        ('--deterministic', dict(action='store_true')),
        ('--num-tasks', dict(type=int, default=-1)),
        ('--train-way', dict(type=int, default=5)),
        ('--train-shot', dict(type=int, default=5)),
        ('--train-query', dict(type=int, default=5)),
        ('--test-way', dict(type=int, default=5)),
        ('--test-shot', dict(type=int, default=5)),
        ('--test-query', dict(type=int, default=5)),
        ('--save-tasksets', dict(action='store_true')),
        ('--validation-interval', dict(type=int, default=8)),
        ('--ckpt-interval', dict(type=int, default=1000)),
        ('--dataset', dict(type=str, default='mini-imagenet')),
        ('--noise-tasks', dict(type=int, default=None)),
    ]

    cli = argparse.ArgumentParser()
    for flag, options in option_specs:
        cli.add_argument(flag, **options)

    # Parse the command line and hand the resulting namespace to main().
    main(cli.parse_args())
