import numpy as np

import torch
import torch.utils.data
import torch.optim as optim

import argparse
import wandb
import random
import pickle

# from model_resnet_minist import MnistResNet_v3
from model_lenet import RegressionModel
from model_resnet import MnistResNet_RotoGrad
from model_resnet_fw import MnistResNet_fw
from noise_utils import noisify
from utils import multi_task_mgd_trainer

# Command-line interface for the multi-task training script.
# Options that this file does not read itself are handed to the trainer through
# the parsed namespace; their help text says so instead of guessing semantics.
parser = argparse.ArgumentParser(description='Multi-task: Split')

# --- experiment / model selection ---
parser.add_argument('--type', default='standard', type=str, help='split type: standard, wide, deep')
parser.add_argument('--weight', default='equal', type=str, help='multi-task weighting: equal, uncert, dwa')
parser.add_argument('--dataroot', default='../dataset/multiMNIST/multi_fashion_and_mnist.pickle', type=str, help='path to the pickled dataset file')
parser.add_argument('--optimizer', default='Adam', type=str, help='optimizer: SGD, Adam')
parser.add_argument('--method', default='cagrad', type=str, help='optimization method')
parser.add_argument('--flag', default='RSL', type=str, help='run flag (forwarded to the trainer)')
parser.add_argument('--method_sub', default='nothing', type=str, help='optimization sub-method (forwarded to the trainer)')
parser.add_argument('--base_model', default='lenet', type=str, help='base model: lenet, roto, bmtas, fw, fw_b, fw_ablation')
parser.add_argument('--ablation_file', default='./file_name.json', type=str, help='ablation file passed to the fw_ablation model')
parser.add_argument('--name', default='CAGrad', type=str, help='wandb run name')
parser.add_argument('--project', default='multiMNIST', type=str, help='wandb project name')
parser.add_argument('--branch_mode', default='task_angle', type=str, help='branch mode (forwarded to the trainer)')

# --- optimisation hyper-parameters ---
parser.add_argument('--temp', default=2.0, type=float, help='temperature for DWA (must be positive)')
parser.add_argument('--lr', default=1e-3, type=float, help='the learning rate')
# nargs='+' yields a list when the flag is given, so the default must be a list
# as well: MultiStepLR cannot iterate over a bare int.
parser.add_argument('--milestones', default=[200], type=int, nargs='+', help='milestones for the MultiStepLR scheduler')
parser.add_argument('--noise_rate', default=0.0, type=float, help='label noise rate passed to noisify (0 disables label noise)')
parser.add_argument('--freeze_level', default=0.0, type=float, help='freeze level (forwarded to the trainer)')
parser.add_argument('--alpha', default=0.5, type=float, help='alpha hyper-parameter (forwarded to the trainer)')
parser.add_argument('--sigma', default=2, type=float, help='sigma hyper-parameter (forwarded to the trainer)')
parser.add_argument('--omega', default=0.4, type=float, help='omega hyper-parameter (forwarded to the trainer)')
parser.add_argument('--gamma', default=0.1, type=float, help='decay factor (gamma) for the MultiStepLR scheduler')
parser.add_argument('--stage1', default=0.0, type=float, help='stage1 hyper-parameter (forwarded to the trainer)')
parser.add_argument('--stage2', default=0.4, type=float, help='stage2 hyper-parameter (forwarded to the trainer)')
parser.add_argument('--random_rate', default=0.18, type=float, help='random rate (forwarded to the trainer)')
parser.add_argument('--t1_flood', default=0.08, type=float, help='flood value for task 1 (forwarded to the trainer)')
parser.add_argument('--t2_flood', default=0.27, type=float, help='flood value for task 2 (forwarded to the trainer)')

# --- run control ---
parser.add_argument('--seed', default=0, type=int, help='the random seed')
parser.add_argument('--n_epoch', default=100, type=int, help='number of epochs (forwarded to the trainer)')
parser.add_argument('--start_epoch', default=-1, type=int, help='start epoch (forwarded to the trainer)')
parser.add_argument('--eval_freq', default=4, type=int, help='the freq of evaluation')
parser.add_argument('--flood', action='store_true', help='enable the flood option (forwarded to the trainer)')
parser.add_argument('--ignore', action='store_true', help='enable the ignore option (forwarded to the trainer)')
parser.add_argument('--drop_last', action='store_true', help='enable the drop_last option (forwarded to the trainer)')
parser.add_argument('--bz', default=256, type=int, help='batch size')
parser.add_argument('--topK', default=15, type=int, help='topK value passed to the branched MnistResNet_fw models')
parser.add_argument('--num_threads', default=1, type=int, help='the number of CPU threads')


def init_wandb_logger(opt):
    """Start a new wandb run configured from the parsed options.

    A fresh run id is generated on every call and the run is created with
    ``resume='never'``. TensorBoard syncing is explicitly disabled
    (``sync_tensorboard=False``).

    Args:
        opt: Parsed argument namespace. ``opt.name`` and ``opt.project`` name
            the run and its project; the whole namespace is passed as the
            run config.

    Returns:
        The object returned by ``wandb.init``. Existing callers that ignore
        the return value are unaffected.
    """
    return wandb.init(
        id=wandb.util.generate_id(),
        resume='never',
        name=opt.name,
        config=opt,
        project=opt.project,
        sync_tensorboard=False)

def init_seed(opt):
    """Seed every random number generator used by the script.

    Sets the cuDNN ``benchmark``/``deterministic`` flags, seeds Python's
    ``random``, NumPy and PyTorch (CPU and CUDA) with ``opt.seed``, and sets
    the CPU thread count to ``opt.num_threads``.

    Args:
        opt: Parsed argument namespace providing ``seed`` and ``num_threads``.
    """
    # Turn off cuDNN autotuning and request deterministic kernels.
    # (torch.use_deterministic_algorithms and disabling cuDNN altogether are
    # intentionally left off here.)
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True

    seed = opt.seed
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    torch.set_num_threads(opt.num_threads)



if __name__ == '__main__':
    # Script entry point: parse options, start logging, seed the RNGs, load
    # the pickled dataset, build loaders / model / optimizer / scheduler and
    # hand everything to the multi-task trainer.
    opt = parser.parse_args()
    init_wandb_logger(opt)
    init_seed(opt)

    # NOTE(security): pickle.load can execute arbitrary code while
    # deserialising; only point --dataroot at files from a trusted source.
    with open(opt.dataroot, 'rb') as f:
        trainX, trainLabel, testX, testLabel = pickle.load(f)

    if opt.noise_rate > 0:
        # Corrupt the labels of each task independently. The second value
        # returned by noisify is not used here.
        trainLabel_t1, _ = noisify(train_labels=trainLabel[:, 0], nb_classes=10, noise_rate=opt.noise_rate)
        trainLabel_t2, _ = noisify(train_labels=trainLabel[:, 1], nb_classes=10, noise_rate=opt.noise_rate)

        trainLabel = np.array([trainLabel_t1, trainLabel_t2]).transpose()

    # Reshape to (N, 1, 36, 36); -1 infers N so the script is not tied to a
    # fixed number of samples. The whole dataset is moved to the GPU up front.
    trainX = torch.from_numpy(trainX.reshape(-1, 1, 36, 36)).float().cuda()
    trainLabel = torch.from_numpy(trainLabel).long().cuda()
    testX = torch.from_numpy(testX.reshape(-1, 1, 36, 36)).float().cuda()
    testLabel = torch.from_numpy(testLabel).long().cuda()

    train_set = torch.utils.data.TensorDataset(trainX, trainLabel)
    test_set = torch.utils.data.TensorDataset(testX, testLabel)

    # NOTE(review): drop_last is fixed to True here; the --drop_last flag is
    # not consulted by this loader. Confirm whether that is intended.
    train_loader = torch.utils.data.DataLoader(
        dataset=train_set,
        batch_size=opt.bz,
        shuffle=True,
        drop_last=True)

    test_loader = torch.utils.data.DataLoader(
        dataset=test_set,
        batch_size=opt.bz,
        shuffle=False)

    # Map each --base_model value to a zero-argument constructor so that only
    # the selected model is ever instantiated.
    model_factories = {
        'lenet': lambda: RegressionModel(),
        'roto': lambda: MnistResNet_RotoGrad(latent_size=100),
        'bmtas': lambda: MnistResNet_fw(branched='bmtas', topK=opt.topK),
        'fw': lambda: MnistResNet_fw(),
        'fw_b': lambda: MnistResNet_fw(branched='branch', topK=opt.topK),
        'fw_ablation': lambda: MnistResNet_fw(
            branched='ablation', topK=opt.topK, ablation_file=opt.ablation_file),
    }
    if opt.base_model not in model_factories:
        raise ValueError(
            'Unknown base_model {!r}; expected one of: {}'.format(
                opt.base_model, ', '.join(sorted(model_factories))))
    model = model_factories[opt.base_model]().cuda()

    if opt.optimizer == 'SGD':
        optimizer = optim.SGD(model.parameters(), lr=opt.lr, momentum=0.9)
    elif opt.optimizer == 'Adam':
        optimizer = optim.Adam(model.parameters(), lr=opt.lr)
    else:
        # Without this branch an unknown name would only surface later as a
        # NameError on the undefined optimizer/scheduler.
        raise ValueError(
            'Unknown optimizer {!r}; expected one of: SGD, Adam'.format(opt.optimizer))

    # MultiStepLR needs an iterable of milestones; accept a bare int as a
    # single milestone so an int-valued option does not crash the scheduler.
    if isinstance(opt.milestones, int):
        milestones = [opt.milestones]
    else:
        milestones = list(opt.milestones)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=opt.gamma)

    multi_task_mgd_trainer(train_loader,
                           test_loader,
                           model,
                           optimizer,
                           scheduler,
                           opt)


