import sys
import time
import random
import copy

from tqdm import tqdm

import torch
from torch.nn import functional as F
from torchvision import transforms
import numpy as np

from code.data import DeterministicDataset
from code.data.datasets.cityscapes import CITYSCAPES
from code.data.augmentation.cityscapes import *
from code.evaluation.cityscapes import CityScapesEvaluator
from code.optim import *
from code.models.cityscapes import ResNet50Dilated, SegmentationDecoder


def l1_loss_depth(input, target, val=False):
    """Mean absolute error over the entries whose depth target is positive.

    Args:
        input: predicted tensor, indexable by a boolean mask built from
            ``target``.
        target: ground-truth tensor; entries ``<= 0`` are excluded.
        val: accepted for signature compatibility; not used.

    Returns:
        The mean L1 loss over the selected entries, or ``None`` when no
        entry of ``target`` is positive.
    """
    valid = target > 0

    # Nothing to supervise: signal it with None rather than a loss value.
    if not valid.any():
        return None

    return F.l1_loss(input[valid], target[valid], reduction="mean")


def l1_loss_instance(input, target, val=False):
    """Mean absolute error over the entries whose target is not 250.

    Args:
        input: predicted tensor, indexable by a boolean mask built from
            ``target``.
        target: ground-truth tensor; entries equal to ``250`` are excluded.
        val: accepted for signature compatibility; not used.

    Returns:
        The mean L1 loss over the selected entries, or ``None`` when every
        entry of ``target`` equals ``250``.
    """
    valid = target != 250

    # Nothing to supervise: signal it with None rather than a loss value.
    if not valid.any():
        return None

    return F.l1_loss(input[valid], target[valid], reduction="mean")

def do_train(
    model,
    balancer,
    train_dataset,
    optimizer,
    device,
    deterministic=True,
    **train_kwargs
):
    """Run one pass over ``train_dataset``, stepping the balancer and optimizer.

    Args:
        model: mapping-style module with the keys ``"encoder"``, ``"IS"``,
            ``"SS"`` and ``"DE"``.
        balancer: object exposing ``step(...)`` and a ``losses`` attribute.
        train_dataset: dataset handed to ``torch.utils.data.DataLoader``.
        optimizer: optimizer whose ``zero_grad``/``step`` bracket each batch.
        device: device the batch tensors are moved to.
        deterministic: when exactly ``True``, the dataset is wrapped in
            ``DeterministicDataset`` and loader shuffling is switched off.
        **train_kwargs: forwarded to the ``DataLoader`` constructor.
    """
    if deterministic is True:
        train_kwargs["shuffle"] = False
        train_dataset = DeterministicDataset(train_dataset)
        train_dataset.update_order()

    train_loader = torch.utils.data.DataLoader(train_dataset, **train_kwargs)

    model.train()

    backbone = model["encoder"]
    if isinstance(balancer, GradNormBalancer):
        # For GradNorm the encoder is truncated after layer3, and layer4 is
        # passed to the balancer separately through the `layer` argument.
        encoder = torch.nn.Sequential(
            backbone.conv1,
            backbone.bn1,
            torch.nn.ReLU(),
            backbone.maxpool,
            backbone.layer1,
            backbone.layer2,
            backbone.layer3,
        )
        last_shared = backbone.layer4
    else:
        encoder = backbone
        last_shared = None

    # Register the heads one by one so the insertion order stays IS, SS, DE.
    decoders = torch.nn.ModuleDict()
    for task in ("IS", "SS", "DE"):
        decoders[task] = model[task]

    criteria = {
        "IS": l1_loss_instance,
        "DE": l1_loss_depth,
        "SS": torch.nn.NLLLoss(ignore_index=250),
    }

    progress = tqdm(total=len(train_loader))
    for batch in train_loader:
        targets = {
            "SS": batch[1].to(device),
            "IS": batch[2].to(device),
            "DE": batch[3].to(device),
        }

        optimizer.zero_grad()

        # NOTE(review): there is no explicit backward() call here, so
        # balancer.step presumably populates the gradients — confirm.
        balancer.step(
            input=batch[0].to(device),
            targets=targets,
            encoder=encoder,
            decoders=decoders,
            criteria=criteria,
            layer=last_shared,
        )
        task_losses = balancer.losses

        # Assumes balancer.losses is ordered IS, SS, DE — TODO confirm.
        progress.set_postfix({
            "loss": sum(task_losses),
            "IS": task_losses[0],
            "SS": task_losses[1],
            "DE": task_losses[2],
        })
        optimizer.step()
        progress.update(1)

    progress.clear()
    progress.close()


def do_test(model, test_dataset, device, **test_kwargs):
    """Evaluate ``model`` on ``test_dataset`` with ``CityScapesEvaluator``.

    Args:
        model: module to evaluate; it is switched to eval mode first.
        test_dataset: dataset handed to ``torch.utils.data.DataLoader``.
        device: device passed through to the evaluator.
        **test_kwargs: forwarded to the ``DataLoader`` constructor.

    Returns:
        Whatever ``CityScapesEvaluator.evaluate`` returns for the loader.
    """
    model.eval()
    loader = torch.utils.data.DataLoader(test_dataset, **test_kwargs)
    metrics = CityScapesEvaluator.evaluate(model, loader, device)
    return metrics


def experiment(args):
    """Train a multi-task CityScapes model and save the best checkpoint.

    Builds the train/val datasets, the selected gradient balancer and a
    shared-encoder model with three decoders (``IS``, ``SS``, ``DE``), then
    alternates ``do_train`` and ``do_test`` for ``args.epochs`` epochs. The
    epoch with the lowest ``results[1] + results[2] - results[0]`` is kept;
    its metrics are written to ``./output/cityscapes/<balancer>.txt`` and its
    weights to ``./output/cityscapes/weights/<balancer>.pth``.

    Args:
        args: namespace providing ``gpu``, ``data_path``, ``balancer``,
            ``scale_heads``, ``train_batch``, ``test_batch``, ``lr``,
            ``epochs``, ``milestone``, ``period`` and ``scaler``.

    Raises:
        ValueError: if ``args.balancer`` is not a supported balancer name.
    """
    # Local import keeps this function self-contained; `os` is only needed
    # to create the output directories.
    import os

    # Factories are lazy so that only the requested balancer is constructed.
    balancer_factories = {
        "zalign": lambda: ZAlignedBalancer(args.scale_heads),
        "talign": lambda: ThetaAlignedBalancer(args.scale_heads),
        "mgda": lambda: MGDABalancer(args.scale_heads),
        "mgdaub": lambda: MGDAUBBalancer(args.scale_heads),
        "pcgrad": lambda: PCGradBalancer(),
        "uncertainty": lambda: HomoscedasticUncertaintyBalancer(),
        "gradnorm": lambda: GradNormBalancer(alpha=2.0),
    }
    # Fail fast: previously an unknown name surfaced only as an unbound
    # `balancer` variable after the datasets had already been built.
    if args.balancer not in balancer_factories:
        raise ValueError(
            f"Unknown balancer {args.balancer!r}; expected one of "
            f"{sorted(balancer_factories)}"
        )

    seed1, seed2 = 1627421307, 1406924764
    print(f"Seeds are: {seed1}, {seed2}")

    # Only select a GPU when CUDA exists, so the CPU fallback below is usable.
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        torch.cuda.set_device(args.gpu)

    # NOTE(review): Python's `random` module is not seeded here; if the
    # augmentations draw from it, runs are not fully reproducible — confirm.
    np.random.seed(seed1)
    torch.manual_seed(seed2)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Create the output tree before training so a missing directory cannot
    # discard the results of a full run.
    os.makedirs("./output/cityscapes/weights", exist_ok=True)

    train_kwargs = {'batch_size': args.train_batch,
                    'drop_last': False,
                    'shuffle': False}
    test_kwargs = {'batch_size': args.test_batch,
                   'shuffle': False}
    cuda_kwargs = {'num_workers': 1,
                   'pin_memory': True}
    train_kwargs.update(cuda_kwargs)
    test_kwargs.update(cuda_kwargs)

    augs = Compose([RandomRotate(10), RandomHorizontallyFlip()])

    dataset1 = CITYSCAPES(root=args.data_path, is_transform=True, split=['train'], augmentations=augs)
    dataset2 = CITYSCAPES(root=args.data_path, is_transform=True, split=['val'])

    print(f"Cuda is used: {use_cuda}")
    print(f"Balancer: {args.balancer}")
    device = torch.device("cuda" if use_cuda else "cpu")

    balancer = balancer_factories[args.balancer]()
    print(type(balancer))

    model = torch.nn.ModuleDict()
    # NOTE(review): the positional True presumably enables pretrained
    # weights — confirm against ResNet50Dilated.
    model["encoder"] = ResNet50Dilated(True)
    model["IS"] = SegmentationDecoder(num_class=2, task_type="R")
    model["SS"] = SegmentationDecoder(num_class=19, task_type="C")
    model["DE"] = SegmentationDecoder(num_class=1, task_type="R")
    # Some balancers need one extra learnable scalar per task head.
    if args.balancer == "gradnorm":
        model["IS"].weight = torch.nn.Parameter(torch.tensor(1.0), requires_grad=True)
        model["SS"].weight = torch.nn.Parameter(torch.tensor(1.0), requires_grad=True)
        model["DE"].weight = torch.nn.Parameter(torch.tensor(1.0), requires_grad=True)
    elif args.balancer == "uncertainty":
        model["IS"].log_var = torch.nn.Parameter(torch.tensor(0.0), requires_grad=True)
        model["SS"].log_var = torch.nn.Parameter(torch.tensor(0.0), requires_grad=True)
        model["DE"].log_var = torch.nn.Parameter(torch.tensor(0.0), requires_grad=True)

    model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    best = [0, 1000, 1000]
    best_state_dict = dict()

    for epoch in range(1, args.epochs+1):
        print("Epoch: {}".format(epoch))

        # From `milestone` onwards, scale the learning rate on every epoch
        # that is a multiple of `period`.
        if epoch >= args.milestone and epoch % args.period == 0:
            for param_group in optimizer.param_groups:
                param_group['lr'] *= args.scaler

        do_train(model, balancer, dataset1, optimizer, device, deterministic=True, **train_kwargs)
        results = do_test(model, dataset2, device, **test_kwargs)
        if best[1]+best[2]-best[0] > results[1]+results[2]-results[0]:
            best = results
            # state_dict() returns references to the live tensors, which
            # later epochs update in place; snapshot them so the saved
            # checkpoint really is the best epoch, not the last one.
            best_state_dict = copy.deepcopy(model.state_dict())

    print(f"BEST: mIou={best[0]}, ISL1={best[1]}, DEL2={best[2]}")

    with open(f"./output/cityscapes/{args.balancer}.txt", "w") as file:
        file.write(f"BEST: mIou={best[0]}, ISL1={best[1]}, DEL2={best[2]}")

    torch.save(best_state_dict, f"./output/cityscapes/weights/{args.balancer}.pth")

def str2bool(value):
    """Parse a command-line boolean flag value.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so
    every non-empty value (including ``"False"``) becomes ``True``. This
    parser maps the usual spellings explicitly instead.

    Args:
        value: a ``bool`` (returned unchanged) or a string such as
            ``"true"``, ``"False"``, ``"1"``, ``"no"``. The empty string is
            treated as ``False``, matching the old ``bool("")`` behaviour.

    Returns:
        The parsed boolean.

    Raises:
        ValueError: if the text is not a recognised boolean spelling;
            argparse reports this as an invalid argument value.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ("true", "t", "yes", "y", "on", "1"):
        return True
    if text in ("false", "f", "no", "n", "off", "0", ""):
        return False
    raise ValueError(f"expected a boolean value, got {value!r}")


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description='CityScapes settings.')
    parser.add_argument('--data_path', type=str, default="~/Storage/datasets/CityScapes",
        help='path to the CityScapes dataset')
    # Restrict to the names experiment() knows how to build.
    parser.add_argument('--balancer', type=str, default="zalign",
        choices=["zalign", "talign", "mgda", "mgdaub", "pcgrad", "uncertainty", "gradnorm"],
        help='balancer type: zalign, talign, mgda, mgdaub, pcgrad, uncertainty, gradnorm')
    # type=bool would turn "--scale_heads False" into True; parse explicitly.
    parser.add_argument('--scale_heads', type=str2bool, default=False, help='scale heads or not')
    parser.add_argument('--train_batch', type=int, default=8, help='train batch size')
    parser.add_argument('--test_batch', type=int, default=4, help='test batch size')
    parser.add_argument('--lr', type=float, default=1e-4, help='learning rate (applied to all model parameters)')
    parser.add_argument('--epochs', type=int, default=70, help='number of epochs')
    parser.add_argument('--gpu', type=int, default=0, help='gpu id')
    parser.add_argument('--milestone', type=int, default=40,
        help='first epoch at which learning-rate decay may apply')
    parser.add_argument('--period', type=int, default=3,
        help='decay the learning rate on epochs divisible by this value')
    parser.add_argument('--scaler', type=float, default=0.7,
        help='multiplicative learning-rate decay factor')

    args = parser.parse_args()
    print(args)

    experiment(args)