import numpy as np

import torch
import torch.utils.data
import torch.optim as optim

import argparse
import wandb
import random
import pickle

# from rotograd import RotoGrad
from utils import train_rotograd

# Command-line options. Every parsed value is stored in the wandb run config and
# the full namespace is forwarded to ``train_rotograd``; options not used in
# this file are documented as such.
parser = argparse.ArgumentParser(description='Multi-task: Split')
parser.add_argument('--type', default='standard', type=str, help='split type: standard, wide, deep')
parser.add_argument('--weight', default='equal', type=str, help='multi-task weighting: equal, uncert, dwa')
parser.add_argument('--dataroot', default='../dataset/multiMNIST/multi_fashion_and_mnist.pickle', type=str,
                    help='path to the pickled (trainX, trainLabel, testX, testLabel) dataset')
parser.add_argument('--optimizer', default='Adam', type=str, choices=['Adam', 'SGD'],
                    help='optimizer used for both the model and the RotoGrad optimizer: Adam or SGD')
parser.add_argument('--method', default='rotograd', type=str, help='optimization method')
parser.add_argument('--base_model', default='fw', type=str,
                    help='base model identifier (forwarded to train_rotograd)')
parser.add_argument('--name', default='CAGrad', type=str, help='wandb run name')

parser.add_argument('--project', default='multiMNIST', type=str, help='wandb project name')
parser.add_argument('--branch_mode', default='task_angle', type=str,
                    help='branch mode (forwarded to train_rotograd)')

parser.add_argument('--lr', default=1e-3, type=float, help='learning rate of the model optimizer')
parser.add_argument('--lr_R', default=1e-3, type=float, help='learning rate of the RotoGrad optimizer')
parser.add_argument('--milestones', default=[50], type=int, nargs='+',
                    help='MultiStepLR milestones at which the learning rates are multiplied by --gamma')

parser.add_argument('--gamma', default=0.1, type=float, help='MultiStepLR decay factor applied at each milestone')
parser.add_argument('--seed', default=0, type=int, help='random seed')
parser.add_argument('--n_epoch', default=100, type=int, help='number of epochs (forwarded to train_rotograd)')
parser.add_argument('--start_epoch', default=-1, type=int, help='start epoch (forwarded to train_rotograd)')

parser.add_argument('--eval_freq', default=4, type=int, help='the freq of evaluation')
parser.add_argument('--bz', default=256, type=int, help='batch size for the train and test loaders')
parser.add_argument('--topK', default=15, type=int, help='top-K setting (forwarded to train_rotograd)')
parser.add_argument('--num_threads', default=1, type=int, help='the number of CPU threads')


def init_wandb_logger(opt):
    """Start a new wandb run configured from the parsed command-line options.

    A fresh random run id is generated on every call and ``resume='never'`` is
    passed, so each invocation creates its own run. TensorBoard syncing is
    disabled (``sync_tensorboard=False``).

    Args:
        opt: parsed ``argparse`` namespace. ``opt.name`` and ``opt.project``
            name the run and project; the whole namespace is stored as the
            run config.

    Returns:
        The run object returned by ``wandb.init``.
    """
    return wandb.init(
        id=wandb.util.generate_id(),
        resume='never',
        name=opt.name,
        config=opt,
        project=opt.project,
        sync_tensorboard=False,
    )

def init_seed(args=None):
    """Seed every RNG used by training and configure cuDNN for determinism.

    Args:
        args: namespace providing ``seed`` (int) and ``num_threads`` (int).
            When omitted, the module-level ``opt`` parsed in the ``__main__``
            block is used, so the existing zero-argument call keeps working.
            Pass it explicitly when this module is imported rather than run.
    """
    if args is None:
        # Backward-compatible fallback: ``opt`` is only defined when this file
        # runs as a script.
        args = opt

    # Turn off cuDNN autotuning and request deterministic kernels to reduce
    # run-to-run nondeterminism on GPU.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    seed = args.seed
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.cuda.manual_seed_all(seed)  # seeds every visible GPU
    torch.set_num_threads(args.num_threads)


if __name__ == '__main__':
    opt = parser.parse_args()

    # The model optimizer and the RotoGrad optimizer use the same family.
    # Previously only the Adam branch defined optimizer_main/optimizer_R, so
    # SGD or any other value crashed with a NameError at train_rotograd after
    # the data had been loaded. Validate before wandb/data setup instead.
    optimizer_factories = {
        'Adam': lambda groups, lr: optim.Adam(groups, lr=lr),
        'SGD': lambda groups, lr: optim.SGD(groups, lr=lr, momentum=0.9),
    }
    if opt.optimizer not in optimizer_factories:
        raise ValueError(
            f"Unsupported --optimizer {opt.optimizer!r}; "
            f"expected one of {sorted(optimizer_factories)}")
    make_optimizer = optimizer_factories[opt.optimizer]

    init_wandb_logger(opt)
    init_seed()

    device = torch.device('cuda')

    # NOTE(review): pickle.load can execute arbitrary code embedded in the
    # file; only point --dataroot at a trusted dataset file.
    with open(opt.dataroot, 'rb') as f:
        trainX, trainLabel, testX, testLabel = pickle.load(f)

    # Reshape images to (N, 1, 36, 36). Inferring N with -1 instead of
    # hard-coding 120000/20000 lets differently sized splits load.
    trainX = torch.from_numpy(trainX.reshape(-1, 1, 36, 36)).float().to(device)
    trainLabel = torch.from_numpy(trainLabel).long().to(device)
    testX = torch.from_numpy(testX.reshape(-1, 1, 36, 36)).float().to(device)
    testLabel = torch.from_numpy(testLabel).long().to(device)

    train_set = torch.utils.data.TensorDataset(trainX, trainLabel)
    test_set = torch.utils.data.TensorDataset(testX, testLabel)

    train_loader = torch.utils.data.DataLoader(
        dataset=train_set,
        batch_size=opt.bz,
        shuffle=True,
        drop_last=True,
    )
    test_loader = torch.utils.data.DataLoader(
        dataset=test_set,
        batch_size=opt.bz,
        shuffle=False,
    )

    # NOTE(review): MnistResNet_RotoGrad is not imported anywhere in this
    # file; add the import from its defining module or this line raises
    # NameError.
    model = MnistResNet_RotoGrad(latent_size=100).to(device)

    # Model optimizer: one param group per sub-module (backbone + each head).
    params_model = [{'params': m.parameters()} for m in [model.backbone] + model.heads]
    # RotoGrad optimizer: everything returned by model.parameters(), at --lr_R.
    params_leader = [{'params': model.parameters()}]

    optimizer_main = make_optimizer(params_model, opt.lr)
    optimizer_R = make_optimizer(params_leader, opt.lr_R)
    scheduler_main = optim.lr_scheduler.MultiStepLR(optimizer_main, milestones=opt.milestones, gamma=opt.gamma)
    scheduler_R = optim.lr_scheduler.MultiStepLR(optimizer_R, milestones=opt.milestones, gamma=opt.gamma)

    train_rotograd(train_loader, test_loader, model,
                   optimizer_main, scheduler_main, optimizer_R, scheduler_R, opt)


