import os
from os import path as osp
from tqdm import tqdm
import torch
import torch.nn as nn
import numpy as np

from code.data.datasets.posenet import SevenScenesDatasetFactory
from code.optim.balancer_factory import get_balancer
from code.optim.gradnorm.balancer import GradNormBalancer
from code.models.posenet import PoseNetEncoder, OrientationHead, PositionHead
from code.evaluation.posenet import SevenScenesEvaluator


def do_train(
        model,
        balancer,
        train_dataset,
        optimizer,
        device,
        **train_kwargs
):
    """Run one training epoch and return the epoch-average losses.

    Args:
        model: ``ModuleDict`` with ``"encoder"``, ``"left"`` (position head) and
            ``"right"`` (orientation head) entries.
        balancer: multi-task balancer exposing ``step(...)`` and a ``losses``
            sequence holding the per-task losses of the last step (index 0 =
            position/``left``, index 1 = orientation/``right``).
        train_dataset: dataset yielding dicts with ``'img'``, ``'t_gt'`` and
            ``'q_gt'`` tensors.
        optimizer: optimizer over the model parameters.
        device: device the batch tensors are moved to.
        **train_kwargs: forwarded verbatim to ``torch.utils.data.DataLoader``.

    Returns:
        ``(avg_total_loss, avg_t_loss, avg_q_loss)`` as Python floats, each the
        mean over batches of the corresponding per-batch loss.
    """
    train_loader = torch.utils.data.DataLoader(train_dataset, **train_kwargs)
    n_batches = len(train_loader)

    model.train()
    if isinstance(balancer, GradNormBalancer):
        # For GradNorm the encoder is handed over as `layer` (presumably the shared
        # layer whose gradient norms it measures — confirm in the balancer) and an
        # identity is passed as `encoder`.
        encoder = torch.nn.Identity()
        last_shared = model["encoder"]
    else:
        encoder = model["encoder"]
        last_shared = None

    decoders = torch.nn.ModuleDict({"left": model["left"], "right": model["right"]})

    criteria = {"left": nn.MSELoss(), "right": nn.MSELoss()}

    loss_total, t_loss_total, q_loss_total = 0., 0., 0.
    for data in tqdm(train_loader, total=n_batches):
        img = data['img'].to(device)
        t_gt = data['t_gt'].to(device)
        q_gt = data['q_gt'].to(device)

        targets = {'left': t_gt, 'right': q_gt}
        optimizer.zero_grad()

        # No explicit backward() here: the balancer is expected to populate the
        # gradients itself before optimizer.step().
        balancer.step(
            input=img,
            targets=targets,
            encoder=encoder,
            decoders=decoders,
            criteria=criteria,
            layer=last_shared,
        )
        losses = balancer.losses
        # Accumulate plain floats: summing loss tensors directly would keep them on
        # the device and, if they are not detached, chain autograd history across
        # the whole epoch.
        loss_total += float(sum(losses))
        t_loss_total += float(losses[0])
        q_loss_total += float(losses[1])

        optimizer.step()

    avg_total_loss = loss_total / n_batches
    avg_t_loss = t_loss_total / n_batches
    avg_q_loss = q_loss_total / n_batches
    return avg_total_loss, avg_t_loss, avg_q_loss


def do_test(model, test_dataset, device, **test_kwargs):
    """Evaluate ``model`` on ``test_dataset`` and return the evaluator's result.

    The model is put into eval mode and the evaluation itself is delegated to
    ``SevenScenesEvaluator.evaluate``. The caller unpacks the result as
    ``(t_med_err, q_med_err)``.

    Args:
        model: model to evaluate.
        test_dataset: dataset for the test split.
        device: device passed through to the evaluator.
        **test_kwargs: forwarded verbatim to ``torch.utils.data.DataLoader``.
    """
    loader = torch.utils.data.DataLoader(test_dataset, **test_kwargs)
    model.eval()
    result = SevenScenesEvaluator.evaluate(model, loader, device)
    return result


def experiment(args):
    """Train and evaluate PoseNet on one 7-Scenes scene with one loss balancer.

    Seeds numpy/torch from a fixed table indexed by ``args.round``, builds the
    dataset splits, the balancer, the model, an Adam optimizer and a StepLR
    scheduler, then trains for ``args.epochs`` epochs. After every epoch the
    model is evaluated. A checkpoint is saved to
    ``<output_path>/<scene_name>/<balancer>/<round>/best_test.pth`` whenever both
    returned errors improve at the same time. The best errors are finally
    written to ``res.txt`` in the same directory.

    Args:
        args: parsed command-line namespace (fields defined in the ``__main__``
            argument parser).

    Raises:
        KeyError: if ``args.round`` is not one of the predefined rounds 1..10.
        ValueError: re-raised (after being printed) from dataset or balancer
            construction.
    """
    # Fixed (numpy seed, torch seed) pairs so every round is reproducible.
    seed_dict = {1: [1984, 7777],
                 2: [7777, 1984],
                 3: [3883, 8338],
                 4: [8338, 3883],
                 5: [9889, 8998],
                 6: [8998, 9889],
                 7: [1111, 1010],
                 8: [4858, 8548],
                 9: [654321, 123456],
                 10: [1625412766, 1377899733],
                 }

    seed1, seed2 = seed_dict[args.round]
    print(f"Seeds are: {seed1}, {seed2}")

    np.random.seed(seed1)
    torch.manual_seed(seed2)
    # Prefer run-to-run determinism over cuDNN autotuning speed.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    use_cuda = torch.cuda.is_available()

    train_kwargs = {'batch_size': args.train_batch,
                    'drop_last': False,
                    'shuffle': True}
    test_kwargs = {'batch_size': args.test_batch,
                   'drop_last': False,
                   'shuffle': False}
    # Pinned host memory only speeds up host->GPU copies; skip it without CUDA.
    loader_kwargs = {'num_workers': 5,
                     'pin_memory': use_cuda}
    train_kwargs.update(loader_kwargs)
    test_kwargs.update(loader_kwargs)

    res_path = osp.join(args.output_path, args.scene_name, args.balancer, str(args.round))
    os.makedirs(res_path, exist_ok=True)

    # Define the datasets. Re-raise after printing: swallowing the error would leave
    # `train_dataset` undefined and fail later with a confusing UnboundLocalError.
    try:
        train_dataset, test_dataset = SevenScenesDatasetFactory.create_splits(args.data_path, args.scene_name)
    except ValueError as e:
        print(e)
        raise

    print(f"Cuda is used: {use_cuda}")
    print(f"Balancer: {args.balancer}")
    print(f"Scene name: {args.scene_name}")
    device = torch.device("cuda" if use_cuda else "cpu")

    # Define the balancer (same fail-fast policy as for the datasets).
    try:
        balancer = get_balancer(args)
    except ValueError as e:
        print(e)
        raise

    # Define the model
    model = torch.nn.ModuleDict()
    model['encoder'] = PoseNetEncoder()
    model['left'] = PositionHead()
    model['right'] = OrientationHead()

    # Some balancers need learnable per-task scalars; they are attached to the
    # heads so the optimizer below picks them up via model.parameters().
    if args.balancer == "gradnorm":
        model['left'].weight = torch.nn.Parameter(torch.tensor(1.0), requires_grad=True)
        model['right'].weight = torch.nn.Parameter(torch.tensor(1.0), requires_grad=True)
    elif args.balancer == "uncertainty":
        model['left'].log_var = torch.nn.Parameter(torch.tensor(0.0), requires_grad=True)
        model['right'].log_var = torch.nn.Parameter(torch.tensor(0.0), requires_grad=True)
    model.to(device)

    # Define optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    # Scheduler: multiply the LR by lr_decay_factor every lr_decay_steps epochs.
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                step_size=args.lr_decay_steps,
                                                gamma=args.lr_decay_factor)

    t_best, q_best = 1e6, 1e6
    for epoch in range(1, args.epochs + 1):
        print(f'Round: {args.round}; epoch: {epoch}')
        tr_loss, pos_tr_loss, ornt_tr_loss = do_train(model, balancer, train_dataset, optimizer, device, **train_kwargs)
        print(f'Training info: ep: {epoch}, pos_loss: {pos_tr_loss}, orientation loss: {ornt_tr_loss}')
        t_med_err, q_med_err = do_test(model, test_dataset, device, **test_kwargs)
        # Checkpoint only when both errors improve simultaneously.
        if t_med_err < t_best and q_med_err < q_best:
            t_best, q_best = t_med_err, q_med_err
            print('Save the model state')
            model_state = model.state_dict()
            torch.save({'epoch': epoch,
                        'state_dict': model_state,
                        'optimizer': optimizer.state_dict(),
                        'scheduler': scheduler.state_dict(),
                        },
                       osp.join(res_path, 'best_test.pth'))

        scheduler.step()

    print(f'Best result (medium error): pos: {t_best}, ornt: {q_best}')
    with open(osp.join(res_path, 'res.txt'), 'w', encoding='utf-8') as f:
        f.write(f'Best result (medium error): pos: {t_best}, ornt: {q_best}')


def str2bool(value):
    """Parse a boolean command-line value such as ``"True"``, ``"false"``, ``"1"``, ``"no"``.

    ``argparse`` with ``type=bool`` turns every non-empty string, including
    ``"False"``, into ``True``. This converter is used for boolean flags instead.

    Raises:
        ValueError: if ``value`` is not a recognised boolean spelling; argparse
            reports this as an invalid argument value.
    """
    if isinstance(value, bool):
        return value
    normalized = str(value).strip().lower()
    if normalized in ("1", "true", "t", "yes", "y", "on"):
        return True
    if normalized in ("0", "false", "f", "no", "n", "off"):
        return False
    raise ValueError(f"not a boolean value: {value!r}")


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description='PoseNet settings.')
    parser.add_argument('--benchmark', type=str, default="7scenes")
    parser.add_argument('--round', type=int, default=1, help='[1..10]')
    parser.add_argument('--data_path', type=str, default="/ssd/data/7scenes",
                        help='path to the 7scenes dataset')
    parser.add_argument('--output_path', type=str, default="/data2/output/amtl",
                        help='path to store the results')
    parser.add_argument('--balancer', type=str, default="pcgrad",
                        help='balancer type: zalign, talign, mgda, mgdaub, pcgrad, gradnorm dummy')
    # str2bool instead of type=bool: with type=bool, "--scale_heads False" parses as True.
    parser.add_argument('--scale_heads', type=str2bool, default=False, help='scale heads or not')
    parser.add_argument('--train_batch', type=int, default=128, help='train batch size')
    parser.add_argument('--test_batch', type=int, default=128, help='test batch size')
    parser.add_argument('--scene_name', type=str, default="chess",
                        help='scene name: fire, heads, chess, office, pumpkin, redkitchen, stairs')
    parser.add_argument('--epochs', type=int, default=120, help='number of epochs')
    parser.add_argument('--lr', type=float, default=1e-3, help='encoder learning rate')
    parser.add_argument('--lr_decay_steps', type=int, default=40, help='every n epochs decrease LR')
    parser.add_argument('--lr_decay_factor', type=float, default=0.1, help='LR decaying factor')

    args = parser.parse_args()

    experiment(args)
