import sys
import time
import random
import copy

from tqdm import tqdm

import torch
from torchvision import transforms
import numpy as np

from code.data import DeterministicDataset
from code.data.datasets.mmnist import MNIST
from code.evaluation.mmnist import MMnistEvaluator
from code.optim import *
from code.models.mmnist import LeNetEncoder, CLSTaskHead



def do_train(
    model,
    balancer,
    train_dataset,
    optimizer,
    device,
    deterministic=True,
    **train_kwargs
):
    """Run one training epoch in which ``balancer`` drives each optimisation step.

    Args:
        model: ``torch.nn.ModuleDict`` holding ``"encoder"``, ``"left"`` and
            ``"right"`` modules.
        balancer: multi-task balancer (from ``code.optim``). Its ``step`` is
            called between ``optimizer.zero_grad()`` and ``optimizer.step()``,
            so it is presumably responsible for the forward/backward pass; its
            ``losses`` attribute is summed for the progress-bar display.
        train_dataset: dataset whose items unpack as ``(data, target1, target2)``.
        optimizer: optimizer over the model parameters.
        device: device each batch is moved to.
        deterministic: when truthy, wrap the dataset in ``DeterministicDataset``,
            call its ``update_order()`` and disable DataLoader shuffling.
        **train_kwargs: forwarded to ``torch.utils.data.DataLoader``.
    """
    if deterministic:
        # Sample order is delegated to DeterministicDataset.update_order(), so
        # the DataLoader must not reshuffle on top of it.
        train_kwargs["shuffle"] = False
        train_dataset = DeterministicDataset(train_dataset)
        train_dataset.update_order()

    train_loader = torch.utils.data.DataLoader(train_dataset, **train_kwargs)

    model.train()
    if isinstance(balancer, GradNormBalancer):
        # GradNorm receives the shared encoder as `layer` and an identity as
        # `encoder`; presumably it needs direct access to the last shared
        # layer's weights/gradients -- TODO confirm against code.optim.
        encoder = torch.nn.Identity()
        last_shared = model["encoder"]
    else:
        encoder = model["encoder"]
        last_shared = None

    decoders = torch.nn.ModuleDict(
        {"left": model["left"], "right": model["right"]}
    )
    criteria = {
        "left": torch.nn.NLLLoss(),
        "right": torch.nn.NLLLoss(),
    }

    pbar = tqdm(total=len(train_loader))
    try:
        for data, target1, target2 in train_loader:
            data = data.to(device)
            targets = {"left": target1.to(device), "right": target2.to(device)}
            optimizer.zero_grad()

            balancer.step(
                input=data,
                targets=targets,
                encoder=encoder,
                decoders=decoders,
                criteria=criteria,
                layer=last_shared,
            )

            pbar.set_postfix({"loss": sum(balancer.losses)})
            optimizer.step()
            pbar.update(1)
    finally:
        # Always tear the bar down, even if a step raises, so tqdm does not
        # keep a dangling instance.
        pbar.clear()
        pbar.close()


def do_test(model, test_dataset, device, **test_kwargs):
    """Evaluate ``model`` on ``test_dataset`` in eval mode.

    Args:
        model: the model to evaluate; it is switched to ``eval()`` mode.
        test_dataset: dataset wrapped in a ``torch.utils.data.DataLoader``.
        device: device passed through to the evaluator.
        **test_kwargs: forwarded to the DataLoader.

    Returns:
        Whatever ``MMnistEvaluator.evaluate`` returns for this model/loader.
    """
    loader = torch.utils.data.DataLoader(test_dataset, **test_kwargs)
    model.eval()
    result = MMnistEvaluator.evaluate(model, loader, device)
    return result


def _make_balancer(name, scale_heads):
    """Instantiate the balancer selected by ``name``.

    Raises:
        ValueError: if ``name`` is not one of the supported balancers.
    """
    # Lambdas defer construction so only the selected balancer is built.
    factories = {
        "zalign": lambda: ZAlignedBalancer(scale_heads),
        "talign": lambda: ThetaAlignedBalancer(scale_heads),
        "mgda": lambda: MGDABalancer(scale_heads),
        "mgdaub": lambda: MGDAUBBalancer(scale_heads),
        "pcgrad": lambda: PCGradBalancer(),
        "gradnorm": lambda: GradNormBalancer(2.0),
        "uncertainty": lambda: HomoscedasticUncertaintyBalancer(),
    }
    if name not in factories:
        # An explicit raise (unlike `assert`) still fires under `python -O`.
        raise ValueError(
            f"Unknown balancer {name!r}; expected one of {list(factories)}"
        )
    return factories[name]()


def _build_model(balancer_name, device):
    """Build the shared LeNet encoder plus left/right heads on ``device``.

    For ``"gradnorm"`` each head gets a learnable scalar ``weight`` (init 1.0);
    for ``"uncertainty"`` each head gets a learnable scalar ``log_var``
    (init 0.0). These are presumably read by the matching balancer.
    """
    model = torch.nn.ModuleDict()
    model["encoder"] = LeNetEncoder()
    model["left"] = CLSTaskHead()
    model["right"] = CLSTaskHead()

    if balancer_name == "gradnorm":
        for head in ("left", "right"):
            model[head].weight = torch.nn.Parameter(torch.tensor(1.0), requires_grad=True)
    elif balancer_name == "uncertainty":
        for head in ("left", "right"):
            model[head].log_var = torch.nn.Parameter(torch.tensor(0.0), requires_grad=True)
    model.to(device)
    return model


def experiment(args):
    """Train and evaluate a two-task MultiMNIST model, printing the best accuracies.

    Each epoch trains with the balancer named by ``args.balancer``, then
    evaluates on the test split. The best pair of accuracies (by their sum) is
    printed at the end.

    Args:
        args: namespace with ``data_path``, ``balancer``, ``scale_heads``,
            ``train_batch``, ``test_batch``, ``e_lr``, ``d_lr`` and ``epochs``.

    Raises:
        ValueError: if ``args.balancer`` is not a supported balancer name.
    """
    seed1, seed2 = 1625412766, 1377899733
    print(f"Seeds are: {seed1}, {seed2}")

    # Seed every RNG source in play (Python's `random` was previously left unseeded).
    random.seed(seed1)
    np.random.seed(seed1)
    torch.manual_seed(seed2)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    use_cuda = torch.cuda.is_available()

    # Pinned host memory only helps host->GPU copies; requesting it without
    # CUDA just makes the DataLoader emit a warning.
    loader_kwargs = {"num_workers": 1, "pin_memory": use_cuda}
    train_kwargs = {
        "batch_size": args.train_batch,
        "drop_last": False,
        "shuffle": False,
        **loader_kwargs,
    }
    test_kwargs = {
        "batch_size": args.test_batch,
        "shuffle": False,
        **loader_kwargs,
    }

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])
    dataset1 = MNIST(args.data_path, train=True, download=True, transform=transform, multi=True)
    dataset2 = MNIST(args.data_path, train=False, transform=transform, multi=True)

    print(f"Cuda is used: {use_cuda}")
    print(f"Balancer: {args.balancer}")
    device = torch.device("cuda" if use_cuda else "cpu")

    balancer = _make_balancer(args.balancer, args.scale_heads)
    model = _build_model(args.balancer, device)

    # The top-level lr=1e-1 is only a default; every group overrides it.
    optimizer = torch.optim.SGD([
        {"params": model["encoder"].parameters(), "lr": args.e_lr},
        {"params": model["left"].parameters(), "lr": args.d_lr},
        {"params": model["right"].parameters(), "lr": args.d_lr},
    ], lr=1e-1, momentum=0.9)

    lr_decay_start, lr_decay_factor = 15, 0.7
    best = [0, 0]

    for epoch in range(1, args.epochs + 1):
        print(f"Epoch: {epoch}")

        # From epoch 15 onward, multiply every group's lr by 0.7 before that
        # epoch's training pass (compounding each epoch).
        if epoch >= lr_decay_start:
            for param_group in optimizer.param_groups:
                param_group["lr"] *= lr_decay_factor

        do_train(model, balancer, dataset1, optimizer, device, deterministic=True, **train_kwargs)
        accs = do_test(model, dataset2, device, **test_kwargs)
        if sum(best) < sum(accs):
            best = accs

    print(f"Best accuracies: {best[0]} -- {best[1]}")

if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description='Multi MNIST settings.')
    parser.add_argument('--data_path', type=str, default="~/Storage/datasets/mmnist",
        help='path to the MultiMNIST dataset')
    parser.add_argument('--balancer', type=str, default="zalign", 
        help='balancer type: zalign, talign, mgda, mgdaub, pcgrad, uncertainty, gradnorm')
    parser.add_argument('--scale_heads', type=bool, default=False, help='scale heads or not')
    parser.add_argument('--train_batch', type=int, default=256, help='train batch size')
    parser.add_argument('--test_batch', type=int, default=512, help='test batch size')
    parser.add_argument('--e_lr', type=float, default=1e-2, help='encoder learning rate')
    parser.add_argument('--d_lr', type=float, default=1e-2, help='decoders learning rate')
    parser.add_argument('--epochs', type=int, default=25, help='number of epochs')

    args = parser.parse_args()

    experiment(args)