import numpy as np

import torch
import torch.utils.data
import torch.optim as optim

import argparse
import wandb
import random
import pickle

# from model_resnet_minist import MnistResNet_v3
from model_lenet import RegressionModel
# from model_resnet import MnistResNet
from model_resnet_fw import MnistResNet_fw

from noise_utils import noisify
from utils import multi_task_mgd_trainer, single_task_mgd_trainer

parser = argparse.ArgumentParser(description='Multi-task: Split')
parser.add_argument('--type', default='standard', type=str, help='split type: standard, wide, deep')
parser.add_argument('--dataroot', default='../dataset/multiMNIST/multi_fashion_and_mnist.pickle', type=str,
                    help='path to the pickled dataset')
parser.add_argument('--optimizer', default='Adam', type=str, help='optimizer: SGD or Adam')
parser.add_argument('--base_model', default='lenet', type=str, help='backbone network: lenet or fw')
parser.add_argument('--name', default='CAGrad', type=str, help='wandb run name')
parser.add_argument('--project', default='multiMNIST', type=str, help='wandb project name')
parser.add_argument('--lr', default=1e-3, type=float, help='the learning rate')
parser.add_argument('--gamma', default=0.1, type=float, help='LR decay factor for the SGD MultiStepLR schedule')
# argparse does not pass a non-string default through ``type``/``nargs``, so the
# default must already be a list to match what ``nargs='+'`` yields on the CLI
# (MultiStepLR needs a list of milestones, not a bare int).
parser.add_argument('--milestones', default=[200], type=int, nargs='+',
                    help='milestones for the SGD MultiStepLR schedule')
parser.add_argument('--seed', default=0, type=int, help='the seed')
parser.add_argument('--n_epoch', default=100, type=int, help='number of training epochs')
parser.add_argument('--eval_freq', default=4, type=int, help='the freq of evaluation')
parser.add_argument('--task_id', default=0, type=int, help='index of the task')
parser.add_argument('--drop_last', action='store_true',
                    help='currently ignored: the training loader always drops the last incomplete batch')
parser.add_argument('--bz', default=256, type=int, help='batch size for the train and test loaders')
parser.add_argument('--num_threads', default=1, type=int, help='the number of CPU threads')


def init_wandb_logger(opt):
    """Start a new Weights & Biases run for this experiment.

    A fresh random run id is generated on every call and ``resume='never'`` is
    passed, so each invocation creates its own run. The parsed command-line
    options are recorded as the run config. TensorBoard syncing is disabled.

    Args:
        opt: parsed ``argparse.Namespace``; ``opt.name`` is the run name and
            ``opt.project`` the wandb project.

    Returns:
        The run object returned by ``wandb.init`` (callers may ignore it).
    """
    return wandb.init(
        id=wandb.util.generate_id(),
        resume='never',
        name=opt.name,
        config=opt,
        project=opt.project,
        sync_tensorboard=False,
    )

def init_seed(opt):
    """Make runs reproducible and cap PyTorch's CPU thread pool.

    Turns off cuDNN autotuning and requests deterministic cuDNN kernels. Then
    seeds Python's ``random``, NumPy's global RNG and the torch CPU/CUDA
    generators with ``opt.seed``. Finally limits torch intra-op parallelism to
    ``opt.num_threads`` threads.
    """
    cudnn = torch.backends.cudnn
    cudnn.benchmark, cudnn.deterministic = False, True
    # Stricter options, intentionally left disabled:
    #   torch.use_deterministic_algorithms(True)
    #   torch.backends.cudnn.enabled = False

    seed = opt.seed
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

    torch.set_num_threads(opt.num_threads)


if __name__ == '__main__':
    opt = parser.parse_args()
    init_wandb_logger(opt)
    init_seed(opt)

    with open(opt.dataroot, 'rb') as f:
        trainX, trainLabel, testX, testLabel = pickle.load(f)

    trainX = torch.from_numpy(trainX.reshape(120000,1,36,36)).float().cuda()
    trainLabel = torch.from_numpy(trainLabel).long().cuda()
    testX = torch.from_numpy(testX.reshape(20000,1,36,36)).float().cuda()
    testLabel = torch.from_numpy(testLabel).long().cuda()

    train_set = torch.utils.data.TensorDataset(trainX, trainLabel)

    test_set = torch.utils.data.TensorDataset(testX, testLabel)

    train_loader = torch.utils.data.DataLoader(
        dataset=train_set,
        batch_size=opt.bz,
        shuffle=True,
        drop_last=True)

    test_loader = torch.utils.data.DataLoader(
        dataset=test_set,
        batch_size=opt.bz,
        shuffle=False
    )

    if opt.base_model == 'lenet':
        model = RegressionModel()
        model = model.cuda()
    # elif opt.base_model == 'resnet18':
    #     model = MnistResNet()
    #     model = model.cuda()
    # elif opt.base_model == 'resnet18_v2':
    #     model = MnistResNet_v2()
    #     model = model.cuda()
    # elif opt.base_model == 'resnet18_v3':
    #     model = MnistResNet_v3()
    #     model = model.cuda()
    elif opt.base_model == 'fw':
        model = MnistResNet_fw()
        model = model.cuda()
    else:
        raise ValueError('Error')

    if opt.optimizer == 'SGD':
        optimizer = optim.SGD(model.parameters(), lr=opt.lr, momentum=0.9)
        scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=opt.milestones, gamma=opt.gamma)

    elif opt.optimizer == 'Adam':
        optimizer = optim.Adam(model.parameters(), lr=opt.lr)
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.5)

    single_task_mgd_trainer(train_loader,
                           test_loader,
                           model,
                           optimizer,
                           scheduler,
                           opt)


