import numpy as np
import random
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import argparse
import torch.utils.data.sampler as sampler
from collections import OrderedDict
import wandb

from create_dataset import *
from model_segnet_fw import SegNet_fw
from utils import *

# Command-line interface of the training script.
parser = argparse.ArgumentParser(description='Multi-task: Split')
parser.add_argument('--type', default='standard', type=str, help='split type: standard, wide, deep')
parser.add_argument('--dataroot', default='nyuv2', type=str, help='dataset root')
parser.add_argument('--optimizer', default='SGD', type=str, help='optimizer: SGD or Adam')
parser.add_argument('--base_model', default='SegNet', type=str,
                    help='model variant: fw, fw_b, fw_ablation, fw_na or fw_b_na')
parser.add_argument('--name', default='CAGrad', type=str, help='experiment name (also used as the wandb run name)')
parser.add_argument('--project', default='Multi-task', type=str, help='wandb project name')
parser.add_argument('--gamma', default=0.5, type=float,
                    help='multiplicative learning-rate decay applied at each milestone')
parser.add_argument('--lr', default=1e-4, type=float, help='the learning rate')
parser.add_argument('--seed', default=0, type=int, help='the seed')
parser.add_argument('--task_id', default=0, type=int, help='the task id')
parser.add_argument('--n_epoch', default=200, type=int, help='the number of training epochs')
# A list default: with nargs='+' the parsed value is a list, and the LR
# scheduler iterates over it, so a bare int default would crash it.
parser.add_argument('--milestones', default=[200], type=int, nargs='+',
                    help='epochs at which the learning rate is multiplied by gamma')
parser.add_argument('--eval_freq', default=20, type=int, help='the freq of evaluation')
parser.add_argument('--apply_augmentation', action='store_true', help='toggle to apply data augmentation')
# NOTE(review): the __main__ block in this file does not read this flag; the
# training loader is built with drop_last=True unconditionally.
parser.add_argument('--drop_last', action='store_true', help='drop the last incomplete batch')
parser.add_argument('--bz', default=8, type=int, help='the batch size')
parser.add_argument('--num_threads', default=8, type=int, help='the number of CPU threads')

def init_wandb_logger(opt):
    """Start a fresh wandb run configured from the parsed command-line options.

    A new run id is generated on every call and ``resume='never'`` is passed,
    so an existing run is never resumed. TensorBoard syncing is disabled
    (``sync_tensorboard=False``).

    Args:
        opt: parsed ``argparse.Namespace``. ``opt.name`` becomes the run name,
            ``opt.project`` the wandb project, and the whole namespace is
            recorded as the run config.

    Returns:
        The run object returned by ``wandb.init`` (previously discarded).
    """
    return wandb.init(
        id=wandb.util.generate_id(),
        resume='never',
        name=opt.name,
        config=opt,
        project=opt.project,
        sync_tensorboard=False,
    )

def init_seed(opt):
    """Seed Python, NumPy and torch RNGs and fix torch's CPU thread count.

    cuDNN autotuning is turned off and deterministic cuDNN kernels are
    requested to favour reproducible runs for a given ``opt.seed``.

    Args:
        opt: options object providing ``seed`` (int) and ``num_threads`` (int).
    """
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True
    # Stricter determinism switches, intentionally left disabled:
    #   torch.use_deterministic_algorithms(True)
    #   torch.backends.cudnn.enabled = False

    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(opt.seed)

    torch.set_num_threads(opt.num_threads)

def seed_worker(worker_id):
    """DataLoader ``worker_init_fn``: seed ``random`` and NumPy from torch.

    The seed is ``torch.initial_seed()`` reduced to its low 32 bits, because
    ``np.random.seed`` only accepts 32-bit unsigned values. ``worker_id`` is
    required by the hook signature but is not used here.
    """
    seed32 = torch.initial_seed() & 0xFFFFFFFF
    random.seed(seed32)
    np.random.seed(seed32)

if __name__ == '__main__':
    opt = parser.parse_args()

    # --base_model values mapped to SegNet_fw keyword arguments. Options are
    # validated before any side effects (wandb run, seeding) so an unsupported
    # value fails fast with a usage message instead of a NameError later on.
    model_kwargs = {
        'fw': dict(branched='empty'),
        'fw_b': dict(branched='branch'),
        'fw_ablation': dict(branched='ablation'),
        'fw_na': dict(branched='empty', attention=False),
        'fw_b_na': dict(branched='branch', attention=False),
    }
    if opt.base_model not in model_kwargs:
        parser.error('unsupported --base_model {!r}; choose from: {}'.format(
            opt.base_model, ', '.join(model_kwargs)))
    if opt.optimizer not in ('SGD', 'Adam'):
        parser.error('unsupported --optimizer {!r}; choose from: SGD, Adam'.format(opt.optimizer))

    init_wandb_logger(opt)
    # control seed
    init_seed(opt)

    # define model, optimiser and scheduler
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    SegNet_MTAN = SegNet_fw(**model_kwargs[opt.base_model]).to(device)

    if opt.optimizer == 'SGD':
        optimizer = optim.SGD(SegNet_MTAN.parameters(), lr=opt.lr, momentum=0.9)
    else:  # 'Adam' (validated above)
        optimizer = optim.Adam(SegNet_MTAN.parameters(), lr=opt.lr)

    # MultiStepLR iterates over its milestones; tolerate a scalar value too.
    milestones = opt.milestones
    if isinstance(milestones, int):
        milestones = [milestones]
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=opt.gamma)

    n_params = count_parameters(SegNet_MTAN)
    # NOTE(review): 24981069 appears to be a reference model's parameter count
    # that REL is measured against — confirm its origin.
    print('Parameter Space: ABS: {:.1f}, REL: {:.4f}'.format(n_params, n_params / 24981069))
    print('LOSS FORMAT: SEMANTIC_LOSS MEAN_IOU PIX_ACC | DEPTH_LOSS ABS_ERR REL_ERR | NORMAL_LOSS MEAN MED <11.25 <22.5 <30')

    # define dataset
    dataset_path = opt.dataroot
    if opt.apply_augmentation:
        train_set = CityScapes(root=dataset_path, train=True, augmentation=True)
        print('Applying data augmentation.')
    else:
        train_set = CityScapes(root=dataset_path, train=True)
        print('Standard training strategy without data augmentation.')

    test_set = CityScapes(root=dataset_path, train=False)

    # Honour --bz (its default, 8, matches the previously hard-coded value).
    batch_size = opt.bz

    # Seed the loader generator from --seed so that different seeds also change
    # the shuffling order and the workers' base seeds (default seed 0 is unchanged).
    g = torch.Generator()
    g.manual_seed(opt.seed)

    # NOTE(review): --drop_last is parsed but not consulted here; the training
    # loader always drops the last incomplete batch.
    train_loader = torch.utils.data.DataLoader(
        dataset=train_set,
        batch_size=batch_size,
        num_workers=4,
        shuffle=True,
        pin_memory=True,
        worker_init_fn=seed_worker,
        generator=g,
        drop_last=True)

    test_loader = torch.utils.data.DataLoader(
        dataset=test_set,
        batch_size=batch_size,
        worker_init_fn=seed_worker,
        generator=g,
        num_workers=4,
        shuffle=False)

    # train_loader, test_loader, single_task_model, optimizer, scheduler, opt, total_epoch=200
    single_task_trainer(train_loader,
                        test_loader,
                        SegNet_MTAN,
                        optimizer,
                        scheduler,
                        opt,
                        opt.n_epoch)
