import numpy as np

import torch
import torch.utils.data
import torch.optim as optim

import argparse
import wandb
import random
import pickle

# from rotograd import RotoGrad
from create_dataset import *
from model_roto_segnet import SegNet_RotoGrad
from utils import train_rotograd

# Command-line options for RotoGrad multi-task training on CityScapes.
parser = argparse.ArgumentParser(description='Multi-task: Split')
parser.add_argument('--type', default='standard', type=str, help='split type: standard, wide, deep')
parser.add_argument('--dataroot', default='../dataset/cityspaces', type=str, help='dataset root')
# Only these two optimizers are wired up in the training entry point; reject
# anything else at parse time instead of failing after data/model setup.
parser.add_argument('--optimizer', default='Adam', type=str, choices=['SGD', 'Adam'],
                    help='optimizer used for both the model and the RotoGrad parameters')
parser.add_argument('--name', default='RotoGrad', type=str, help='wandb run name')
parser.add_argument('--project', default='DEBUG', type=str, help='wandb project name')

parser.add_argument('--lr', default=1e-3, type=float, help='learning rate of the backbone and heads')
parser.add_argument('--lr_R', default=1e-3, type=float, help='learning rate of the RotoGrad (leader) parameters')
parser.add_argument('--milestones', default=[50], type=int, nargs='+',
                    help='MultiStepLR milestones at which the learning rates are multiplied by --gamma')

parser.add_argument('--gamma', default=0.1, type=float, help='multiplicative learning-rate decay at each milestone')
parser.add_argument('--seed', default=0, type=int, help='the seed')
parser.add_argument('--n_epoch', default=200, type=int, help='number of training epochs')

parser.add_argument('--eval_freq', default=20, type=int, help='the freq of evaluation')
parser.add_argument('--bz', default=8, type=int, help='training batch size')
parser.add_argument('--topK', default=15, type=int, help='top-K hyperparameter')
parser.add_argument('--num_threads', default=1, type=int, help='the number of CPU threads')


def init_wandb_logger(opt):
    """Start a new Weights & Biases run for this training job.

    Each call generates a fresh run id and disables resuming, so every
    invocation creates a separate run. The run is named ``opt.name``, filed
    under ``opt.project``, and the whole ``opt`` namespace is stored as the
    run config. TensorBoard syncing is turned off.

    Args:
        opt: parsed command-line namespace; ``name`` and ``project`` are read.

    Returns:
        The run object returned by ``wandb.init``, so callers can log to it
        or call ``finish()`` on it.
    """
    return wandb.init(
        id=wandb.util.generate_id(),
        resume='never',
        name=opt.name,
        config=opt,
        project=opt.project,
        sync_tensorboard=False,
    )

def init_seed(opt):
    """Make a training run reproducible and pin the CPU thread count.

    cuDNN is switched to deterministic kernels with autotuning disabled.
    The Python, NumPy, torch CPU and torch CUDA (current and all devices)
    generators are all seeded with ``opt.seed``. Torch's intra-op thread
    pool is then limited to ``opt.num_threads``.
    """
    # Autotuning can pick different kernels between runs; disable it and
    # require deterministic cuDNN algorithms instead.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    seed_value = opt.seed
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed_value)

    torch.set_num_threads(opt.num_threads)



def build_optimizers(model, opt):
    """Create the two optimizers and their LR schedulers for RotoGrad training.

    * ``optimizer_main`` covers ``model.backbone`` and every module in
      ``model.heads``, with one param group per module and learning rate
      ``opt.lr``.
    * ``optimizer_R`` covers ``model.parameters()`` with learning rate
      ``opt.lr_R``.

    Each optimizer gets a ``MultiStepLR`` scheduler configured by
    ``opt.milestones`` and ``opt.gamma``. SGD uses momentum 0.9.

    Args:
        model: object exposing ``backbone``, ``heads`` (a list of modules)
            and ``parameters()``.
        opt: parsed command-line namespace; ``optimizer``, ``lr``, ``lr_R``,
            ``milestones`` and ``gamma`` are read.

    Returns:
        ``(optimizer_main, scheduler_main, optimizer_R, scheduler_R)``.

    Raises:
        ValueError: if ``opt.optimizer`` is neither ``'SGD'`` nor ``'Adam'``.
    """
    params_model = [{'params': m.parameters()} for m in [model.backbone] + model.heads]
    # NOTE(review): confirm which parameters SegNet_RotoGrad.parameters()
    # yields. If it also includes backbone/head weights, those weights are
    # stepped by both optimizers.
    params_leader = [{'params': model.parameters()}]

    if opt.optimizer == 'SGD':
        optimizer_main = optim.SGD(params_model, lr=opt.lr, momentum=0.9)
        optimizer_R = optim.SGD(params_leader, lr=opt.lr_R, momentum=0.9)
    elif opt.optimizer == 'Adam':
        optimizer_main = optim.Adam(params_model, lr=opt.lr)
        optimizer_R = optim.Adam(params_leader, lr=opt.lr_R)
    else:
        raise ValueError(
            f"Unsupported optimizer {opt.optimizer!r}; expected 'SGD' or 'Adam'")

    scheduler_main = optim.lr_scheduler.MultiStepLR(
        optimizer_main, milestones=opt.milestones, gamma=opt.gamma)
    scheduler_R = optim.lr_scheduler.MultiStepLR(
        optimizer_R, milestones=opt.milestones, gamma=opt.gamma)
    return optimizer_main, scheduler_main, optimizer_R, scheduler_R


if __name__ == '__main__':
    opt = parser.parse_args()
    init_wandb_logger(opt)
    init_seed(opt)

    train_set = CityScapes(root=opt.dataroot, train=True, augmentation=True)
    test_set = CityScapes(root=opt.dataroot, train=False)

    # drop_last keeps every training batch at exactly opt.bz samples.
    train_loader = torch.utils.data.DataLoader(
        dataset=train_set,
        batch_size=opt.bz,
        shuffle=True,
        drop_last=True,
    )
    test_loader = torch.utils.data.DataLoader(
        dataset=test_set,
        batch_size=1,
        shuffle=False,
    )

    model = SegNet_RotoGrad(latent_size=1024).to(torch.device('cuda'))

    optimizer_main, scheduler_main, optimizer_R, scheduler_R = build_optimizers(model, opt)

    train_rotograd(train_loader, test_loader, model,
                   optimizer_main, scheduler_main, optimizer_R, scheduler_R, opt)


