import sys
import time
import random
import copy

import numpy as np
from tqdm import tqdm
import torch
from pathlib import Path

from code.data import DeterministicDataset
from code.data.datasets.celeba import CELEBA
from code.evaluation.celeba import CelebAEvaluator
from code.models.celeba import ResNet18, FaceAttributeDecoder
from code.optim.gradnorm.balancer import GradNormBalancer
from code.optim.balancer_factory import get_balancer
from code.tools.utils import fix_seed



def do_train(
    model,
    balancer,
    train_dataset,
    optimizer,
    device,
    deterministic=True,
    **train_kwargs
):
    """Run one pass over ``train_dataset``, stepping the balancer and optimizer per batch.

    Args:
        model: Mapping-style module holding an ``"encoder"`` entry plus one
            decoder per task under the keys ``"0"`` .. ``"39"``.
        balancer: Object exposing ``step(...)`` and a ``losses`` attribute.
            Presumably ``step`` runs the forward/backward pass and fills the
            gradients consumed by ``optimizer.step()`` -- confirm against the
            balancer implementations.
        train_dataset: Dataset yielding ``(data, target)`` pairs; ``target`` is
            indexed as ``target[:, i]`` for each of the 40 tasks.
        optimizer: Optimizer whose gradients are zeroed before and applied
            after every ``balancer.step`` call.
        device: Device the batch tensors are moved to.
        deterministic: When exactly ``True``, the dataset is wrapped in
            ``DeterministicDataset``, its order is refreshed through
            ``update_order()`` and loader shuffling is forced off.
        **train_kwargs: Forwarded to ``torch.utils.data.DataLoader``.
    """
    if deterministic is True:
        # Ordering is delegated to DeterministicDataset, so the loader itself
        # must not shuffle.
        train_kwargs.update({"shuffle": False})
        train_dataset = DeterministicDataset(train_dataset)
        train_dataset.update_order()

    train_loader = torch.utils.data.DataLoader(train_dataset, **train_kwargs)

    model.train()

    if isinstance(balancer, GradNormBalancer):
        # GradNorm receives the final shared layer separately (``layer=``), so
        # the encoder is rebuilt here without ``layer4[1]``.
        # NOTE(review): this assumes the ResNet18 forward pass is exactly this
        # sequence of submodules -- confirm against the model definition.
        backbone = model["encoder"]
        encoder = torch.nn.Sequential(
            backbone.conv1,
            backbone.bn1,
            torch.nn.ReLU(),
            backbone.layer1,
            backbone.layer2,
            backbone.layer3,
            backbone.layer4[0],
        )
        last_shared = backbone.layer4[1]
    else:
        encoder = model["encoder"]
        last_shared = None

    # One decoder and one criterion per task, keyed by the task index as str.
    task_keys = [str(i) for i in range(40)]
    decoders = torch.nn.ModuleDict({key: model[key] for key in task_keys})
    criteria = {key: torch.nn.NLLLoss() for key in task_keys}

    pbar = tqdm(total=len(train_loader))
    try:
        for data, target in train_loader:
            data = data.to(device)
            target = target.to(device)

            targets = {key: target[:, i] for i, key in enumerate(task_keys)}
            optimizer.zero_grad()

            balancer.step(
                input=data,
                targets=targets,
                encoder=encoder,
                decoders=decoders,
                criteria=criteria,
                layer=last_shared,
            )
            losses = balancer.losses

            pbar.set_postfix({"loss": sum(losses)})
            optimizer.step()
            pbar.update(1)
    finally:
        # Always release the progress bar, even when a training step raises,
        # so a failed epoch does not leave a dangling bar on the terminal.
        pbar.clear()
        pbar.close()


def do_test(model, test_dataset, device, **test_kwargs):
    """Switch ``model`` to eval mode and score it on ``test_dataset``.

    Args:
        model: Module to evaluate; ``model.eval()`` is called on it.
        test_dataset: Dataset wrapped in a ``DataLoader`` for evaluation.
        device: Passed through to ``CelebAEvaluator.evaluate``.
        **test_kwargs: Forwarded to ``torch.utils.data.DataLoader``.

    Returns:
        Whatever ``CelebAEvaluator.evaluate(model, loader, device)`` returns.
    """
    model.eval()
    loader = torch.utils.data.DataLoader(test_dataset, **test_kwargs)
    metrics = CelebAEvaluator.evaluate(model, loader, device)
    return metrics


def experiment(args):
    """Train a 40-head CelebA model with the selected balancer and save the best run.

    The seed is looked up from ``args.round``. A ``ResNet18`` encoder and 40
    ``FaceAttributeDecoder`` heads are trained with Adam for ``args.epochs``
    epochs, halving the learning rate on every 10th epoch. After each epoch the
    model is evaluated on the ``'val'`` split; the results with the highest
    ``results[0]`` are kept together with a copy of the model weights.

    Outputs are written to ``<output_path>/<balancer>/<seed>/``:
    ``res.txt`` (overall and per-class accuracy) and ``snapshot.pth`` (weights
    of the best epoch).

    Args:
        args: Namespace providing ``round``, ``gpu``, ``train_batch``,
            ``test_batch``, ``data_path``, ``output_path``, ``balancer``,
            ``lr`` and ``epochs``. It is also passed as-is to ``get_balancer``.

    Raises:
        KeyError: If ``args.round`` has no entry in the seed table.
        Exception: Anything raised by ``get_balancer(args)`` propagates
            unchanged (e.g. the ``ValueError`` the previous code tried to
            catch), because training cannot proceed without a balancer.
    """
    seed_dict = {1: 1984,
                 2: 7777,
                 3: 3883,
                 4: 8338,
                 5: 9889,
                 6: 8998,
                 7: 1111,
                 8: 4858,
                 9: 654321,
                 10: 1627421307}

    seed = seed_dict[args.round]
    print(f"Seed is: {seed}")

    use_cuda = torch.cuda.is_available()
    if use_cuda:
        # Selecting a CUDA device is only valid when CUDA is present; the rest
        # of this function already falls back to the CPU otherwise.
        torch.cuda.set_device(args.gpu)
    fix_seed(seed)

    train_kwargs = {'batch_size': args.train_batch,
                    'drop_last': False,
                    'shuffle': False}
    test_kwargs = {'batch_size': args.test_batch,
                   'shuffle': False}
    cuda_kwargs = {'num_workers': 1,
                   'pin_memory': True}
    train_kwargs.update(cuda_kwargs)
    test_kwargs.update(cuda_kwargs)

    dataset1 = CELEBA(root=args.data_path,
                      is_transform=True,
                      split='train',
                      augmentations=None,
                      img_size=(64, 64))
    dataset2 = CELEBA(root=args.data_path,
                      is_transform=True,
                      split='val',
                      augmentations=None,
                      img_size=(64, 64))

    print(f"Cuda is used: {use_cuda}")
    print(f"Balancer: {args.balancer}")
    device = torch.device("cuda" if use_cuda else "cpu")

    output_path = Path(args.output_path)
    res_path = output_path / args.balancer / f'{seed}'
    res_path.mkdir(parents=True, exist_ok=True)

    # Define the balancer. A failure here is fatal: swallowing it would only
    # postpone the crash to the first use of the missing balancer.
    balancer = get_balancer(args)

    model = torch.nn.ModuleDict()
    model["encoder"] = ResNet18()
    for i in range(40):
        model[str(i)] = FaceAttributeDecoder()
        # Some balancers learn one extra scalar per task; register it on the
        # decoder so that it is part of model.parameters().
        if args.balancer == "gradnorm":
            model[str(i)].weight = torch.nn.Parameter(torch.tensor(1.0), requires_grad=True)
        elif args.balancer == "uncertainty":
            model[str(i)].log_var = torch.nn.Parameter(torch.tensor(0.0), requires_grad=True)

    model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    # ``best`` stays None until the first evaluation has been recorded.
    best = None
    best_state_dict = dict()

    for epoch in range(1, args.epochs + 1):
        print(f"Epoch: {epoch}")
        if epoch % 10 == 0:
            for param_group in optimizer.param_groups:
                param_group['lr'] *= 0.5
            print(f'The new learning rate (LR):{optimizer.param_groups[0]["lr"]}')

        do_train(model, balancer, dataset1, optimizer, device, deterministic=True, **train_kwargs)
        results = do_test(model, dataset2, device, **test_kwargs)
        if best is None or best[0] < results[0]:
            best = results
            # state_dict() returns references to the live tensors; without a
            # deep copy later epochs would overwrite the "best" weights.
            best_state_dict = copy.deepcopy(model.state_dict())

    if best is None:
        # No epoch was run, so there is nothing to report or to snapshot.
        print("No evaluation results were produced; nothing to save.")
        return

    report_lines = [f"Best val acc: {best[0]}"]
    report_lines.extend(
        f"Class: {i}\t{dataset1.class_names[i]}\t, AvgAcc = {best[1][i]}"
        for i in range(40)
    )

    for line in report_lines:
        print(line)

    with open(res_path / 'res.txt', "w") as f:
        f.write("\n".join(report_lines) + "\n")

    torch.save(best_state_dict, res_path / 'snapshot.pth')

def str2bool(value):
    """Parse a command-line boolean such as ``True``/``false``/``1``/``0``.

    ``argparse`` with ``type=bool`` treats every non-empty string -- including
    ``"False"`` -- as ``True``. This parser maps the usual spellings explicitly
    and rejects anything else.

    Args:
        value: A ``bool`` (returned unchanged) or a string to interpret.

    Returns:
        The parsed boolean. The empty string maps to ``False``.

    Raises:
        ValueError: If ``value`` is not a recognised boolean spelling;
            ``argparse`` reports this as an invalid argument value.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ("true", "t", "yes", "y", "1"):
        return True
    if text in ("false", "f", "no", "n", "0", ""):
        return False
    raise ValueError(f"expected a boolean value, got {value!r}")


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description='CelebA settings.')
    parser.add_argument('--benchmark', type=str, default="celeba")
    # Only rounds 1..10 have a seed assigned in experiment().
    parser.add_argument('--round', type=int, default=10, choices=range(1, 11),
                        help='[1..10]')
    parser.add_argument('--data_path', type=str, default="~/Storage/datasets/CelebA",
                        help='path to the CelebA dataset')
    parser.add_argument('--output_path', type=str, default="/data2/output/amtl/celeba",
                        help='path to store the results')
    parser.add_argument('--balancer', type=str, default="mgda",
        help='balancer type: zalign, talign, mgda, mgdaub, pcgrad, gradnorm')
    parser.add_argument('--scale_heads', type=str2bool, default=False, help='scale heads or not')
    parser.add_argument('--train_batch', type=int, default=256, help='train batch size')
    parser.add_argument('--test_batch', type=int, default=256, help='test batch size')
    # The optimizer is built over all model parameters, not just the encoder.
    parser.add_argument('--lr', type=float, default=1e-3, help='learning rate')
    parser.add_argument('--epochs', type=int, default=60, help='number of epochs')
    parser.add_argument('--gpu', type=int, default=0, help='gpu id')

    args = parser.parse_args()

    experiment(args)
