import numpy as np

import torch
import torch.utils.data
import torch.optim as optim

import argparse
import wandb
import random
import pickle
from pathlib import Path
from datetime import datetime

from model_resnet_bmtas import SuperMnistResNet18, BranchMnistResNet18
from tools import read_json, train_search



# Command-line options for the BMTAS MultiMNIST experiment.
# ``--name`` and ``--project`` are forwarded to wandb; ``--name`` also
# prefixes the experiment output directory of the search procedure.
parser = argparse.ArgumentParser(description='Multi-task: Split')

# Data.
parser.add_argument('--dataroot', default='../dataset/multiMNIST/multi_fashion_and_mnist.pickle',
                    type=str, help='path to the pickled dataset file')

# BMTAS procedure. A tuple (not a set) keeps the help/error ordering stable.
parser.add_argument('--procedure', default='search', choices=('branch', 'search'), type=str,
                    help='the procedure of bmtas')
# Fixed typo in the default file name (was 'branch_config.josn').
parser.add_argument('--configuration', default='branch_config.json', type=str,
                    help='JSON file holding the branch configuration (branch procedure)')
parser.add_argument('--resource_loss_weight', default=0.1, type=float,
                    help='weight of resource loss')

# Evaluation / batching.
parser.add_argument('--eval_freq', default=10, type=int, help='the freq of evaluation')
parser.add_argument('--bz', default=256, type=int, help='the batch size')

# Logging.
parser.add_argument('--name', default='bmtas', type=str,
                    help='run name (wandb run name and output directory prefix)')
parser.add_argument('--project', default='multiMNIST', type=str, help='wandb project name')

# Optimisation. The choices mirror the optimizers the script supports.
parser.add_argument('--optimizer', default='Adam', choices=('SGD', 'Adam'), type=str,
                    help='the optimizer')
parser.add_argument('--lr', default=1e-3, type=float, help='learning rate of the network weights')
parser.add_argument('--lr_arch', default=1e-2, type=float,
                    help='learning rate of the architecture parameters')

# Learning-rate decay schedule. The default is a list so the value has the
# same type whether or not the flag is given (nargs='+' yields a list).
parser.add_argument('--milestones', default=[200], type=int, nargs='+',
                    help='epochs at which the learning rate is decayed')
parser.add_argument('--gamma', default=0.1, type=float,
                    help='multiplicative learning-rate decay factor')

# Reproducibility / runtime.
parser.add_argument('--seed', default=0, type=int, help='the seed')
parser.add_argument('--n_epoch', default=100, type=int, help='the number of epoch')
parser.add_argument('--num_threads', default=1, type=int, help='the number of CPU threads')


def init_wandb_logger(opt):
    """Start a fresh Weights & Biases run for this experiment.

    A new run id is generated on every call and ``resume='never'`` is used,
    so each invocation creates a separate run. TensorBoard syncing is
    disabled (``sync_tensorboard=False``).

    Args:
        opt: parsed command-line namespace. ``opt.name`` becomes the run
            name, ``opt.project`` the wandb project, and the whole
            namespace is logged as the run config.

    Returns:
        The run object returned by ``wandb.init`` (previously discarded).
    """
    return wandb.init(
        id=wandb.util.generate_id(),
        resume='never',
        name=opt.name,
        config=opt,
        project=opt.project,
        sync_tensorboard=False,
    )

def init_seed(args=None):
    """Seed every RNG the script uses and make cuDNN deterministic.

    Args:
        args: namespace providing ``seed`` and ``num_threads``. When omitted,
            the module-level ``opt`` (assigned by the ``__main__`` block) is
            used, so the original zero-argument call keeps working.
    """
    if args is None:
        # Global lookup: ``opt`` only exists when this file runs as a script.
        args = opt

    # Give up cuDNN autotuning in exchange for run-to-run reproducibility.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    torch.set_num_threads(args.num_threads)


def to_image_tensor(array):
    """Convert a numpy image array to a float tensor of shape (N, 1, 36, 36).

    The sample count N is inferred (``-1``) instead of being hard-coded, so
    any dataset size works as long as each sample holds 36 * 36 values.
    The returned tensor is on the CPU; callers move it to the GPU.
    """
    return torch.from_numpy(array.reshape(-1, 1, 36, 36)).float()


def split_indices(n, weight_fraction=0.8):
    """Shuffle ``range(n)`` with the global ``random`` RNG and split it in two.

    Returns:
        ``(weight_indices, arch_indices)``: the first
        ``int(weight_fraction * n)`` shuffled indices, then the remainder.
    """
    indices = list(range(n))
    random.shuffle(indices)
    cut = int(weight_fraction * n)
    return indices[:cut], indices[cut:]


if __name__ == '__main__':
    opt = parser.parse_args()
    init_wandb_logger(opt)
    init_seed()

    tasks = ['t1', 't2']

    # NOTE(review): pickle.load can execute arbitrary code; only point
    # --dataroot at a trusted file.
    with open(opt.dataroot, 'rb') as f:
        trainX, trainLabel, testX, testLabel = pickle.load(f)

    # The whole dataset is moved to the GPU once, up front.
    trainX = to_image_tensor(trainX).cuda()
    trainLabel = torch.from_numpy(trainLabel).long().cuda()
    testX = to_image_tensor(testX).cuda()
    testLabel = torch.from_numpy(testLabel).long().cuda()

    train_set = torch.utils.data.TensorDataset(trainX, trainLabel)
    test_set = torch.utils.data.TensorDataset(testX, testLabel)

    if opt.procedure == 'search':
        time_str = datetime.now().strftime(r'%m-%d-%H-%M-%S')
        exp_dir = Path(opt.name) / time_str / 'search'
        exp_dir.mkdir(parents=True, exist_ok=True)

        # 80% of the training data trains the weights, 20% the arch-params.
        weight_idx, arch_idx = split_indices(len(train_set))
        trainset_weight = torch.utils.data.Subset(train_set, weight_idx)
        trainset_arch = torch.utils.data.Subset(train_set, arch_idx)

        train_loader_weight = torch.utils.data.DataLoader(
            dataset=trainset_weight,
            batch_size=opt.bz,
            shuffle=True,
            drop_last=True)

        # NOTE(review): unlike the weight loader, the arch loader is not
        # shuffled; confirm this is intentional.
        train_loader_arch = torch.utils.data.DataLoader(
            dataset=trainset_arch,
            batch_size=opt.bz,
            drop_last=True)

        print('building the model')
        model = SuperMnistResNet18(tasks).cuda()

        # Separate optimizers for the network weights and the architecture
        # parameters, each with its own learning rate.
        if opt.optimizer == 'SGD':
            optimizer_weight = optim.SGD(model.weight_parameters(), lr=opt.lr, momentum=0.9)
            optimizer_arch = optim.SGD(model.arch_parameters(), lr=opt.lr_arch, momentum=0.9)
        elif opt.optimizer == 'Adam':
            optimizer_weight = optim.Adam(model.weight_parameters(), lr=opt.lr)
            optimizer_arch = optim.Adam(model.arch_parameters(), lr=opt.lr_arch)
        else:
            raise ValueError(f'Error: Do not support {opt.optimizer}')

        train_search(opt, tasks, model, train_loader_arch, train_loader_weight,
                     optimizer_arch, optimizer_weight, exp_dir=exp_dir)

    else:
        train_loader = torch.utils.data.DataLoader(
            dataset=train_set,
            batch_size=opt.bz,
            shuffle=True,
            drop_last=True)

        test_loader = torch.utils.data.DataLoader(
            dataset=test_set,
            batch_size=opt.bz,
            shuffle=False)

        branch_config = read_json(opt.configuration)['config']
        print('building the model')
        model = BranchMnistResNet18(tasks, branch_config=branch_config).cuda()
        # TODO(review): in the visible code, the 'branch' procedure stops
        # here. train_loader / test_loader / model are built but no training
        # or evaluation is run.






