import os
import argparse
from pathlib import Path
import torch
import random
import numpy as np
import wandb

from src import dataset
from src import models
from src import modules
from src import criterion
from src import tools
from src import utils
from src import metrics


# Command-line interface. ``choices`` are tuples (not sets) so that the order
# shown in ``--help`` and in invalid-choice error messages is deterministic.
parser = argparse.ArgumentParser()
parser.add_argument('--configuration', type=str, required=True,
                    help='path to the branching configuration file')
parser.add_argument('--seed', default=0, type=int, help='the random seed')
parser.add_argument('--epoch', default=130, type=int, help='the number of epochs')
parser.add_argument('--start_epoch', default=130, type=int, help='the number of epoch for branch search')
parser.add_argument('--bz', default=12, type=int,
                    help='the global batch size (split evenly across GPUs)')
parser.add_argument('--lr', default=0.005, type=float, help='the learning rate')
parser.add_argument('--num_threads', default=4, type=int,
                    help='the number of CPU threads used by torch')
parser.add_argument('--data_root', default='.',
                    type=str, help='dataset root dir')
parser.add_argument('--pretrain', action='store_true', help='whether to use the pretrained weights')
parser.add_argument('--method', default='nothing',
                    choices=('bmtas', 'nothing', 'mgd', 'graddrop', 'pcgrad', 'cagrad', 'branch_layer', 'bmtas_branch'),
                    type=str, help='the multi-task training method')
parser.add_argument('--method_sub', default='nothing', choices=('bmtas', 'nothing', 'pcgrad', 'cagrad'),
                    type=str, help='the secondary (sub) method')
parser.add_argument('--base_model', default='fw', choices=('fw', 'fw_b', 'Normal', 'fw_v2', 'fw_b_v2'),
                    type=str, help='the base model variant')
parser.add_argument('--topK', default=15, type=int, help='the topK layer will be branched')
parser.add_argument('--optimizer', default='SGD', choices=('SGD', 'Adam'),
                    type=str, help='the optimizer')
parser.add_argument('--gamma', default=0.1, type=float, help='the gamma of the scheduler')
parser.add_argument('--alpha', default=0.4, type=float, help='the alpha of CAGrad')
parser.add_argument('--milestones', default=[100], type=int, nargs='+', help='the milestones of the scheduler')

parser.add_argument('--name', default='BMTAS', type=str,
                    help='the experiment name (wandb run name and ./exp_results/ subdirectory)')
parser.add_argument('--project', default='PASCALCONTEXT', type=str, help='the wandb project name')
parser.add_argument('--tasks', default='semseg,human_parts,sal,normals,edge', type=str,
                    help='tasks to train, comma-separated, order matters!')
parser.add_argument('--resume_path', type=str,
                    help='path to model to resume')

# torch.backends.cudnn.benchmark = True


def _build_model(opt, tasks, branch_config):
    """Instantiate the network selected by ``opt.method`` and ``opt.base_model``.

    Args:
        opt: parsed CLI namespace (reads ``method``, ``base_model``, ``topK``,
            ``pretrain``).
        tasks: list of task names.
        branch_config: the ``'config'`` entry of the branching configuration JSON.

    Returns:
        The (unwrapped) model.

    Raises:
        ValueError: if the method, or the method/base_model combination, is not
            supported. Previously an unsupported combination fell through and
            surfaced later as an ``UnboundLocalError`` on ``model``.
    """
    if opt.method in ('nothing', 'cagrad', 'pcgrad', 'branch_layer', 'graddrop', 'mgd'):
        if opt.base_model == 'fw':
            return models.MoblieNetV2_fw(tasks, pretrain=opt.pretrain)
        if opt.base_model == 'fw_b':
            return models.MoblieNetV2_fw(tasks, branched='branch', topK=opt.topK, pretrain=opt.pretrain)
    elif opt.method in ('bmtas', 'bmtas_branch'):
        # NOTE(review): 'nothing' is not one of the --base_model choices, so this
        # branch cannot be reached from the CLI; perhaps 'Normal' was intended — confirm.
        if opt.base_model == 'nothing':
            return models.BranchMobileNetV2(tasks, branch_config=branch_config, pretrain=opt.pretrain)
        if opt.base_model == 'fw_b':
            return models.BranchMobileNetV2_fw(tasks, branch_config=branch_config, branched='branch',
                                               topK=opt.topK, pretrain=opt.pretrain)
        if opt.base_model == 'fw':
            return models.BranchMobileNetV2_fw(tasks, branch_config=branch_config, pretrain=opt.pretrain)
        if opt.base_model == 'fw_b_v2':
            return models.BranchMobileNetV2_fw_v2(tasks, branch_config=branch_config, branched='branch',
                                                  topK=opt.topK, pretrain=opt.pretrain)
        if opt.base_model == 'fw_v2':
            return models.BranchMobileNetV2_fw_v2(tasks, branch_config=branch_config, pretrain=opt.pretrain)
    else:
        raise ValueError(f'Error: Do not support method {opt.method}')
    raise ValueError(f'Error: Do not support base_model {opt.base_model} with method {opt.method}')


def _build_optimizer(opt, model):
    """Create the optimizer and its MultiStepLR scheduler from CLI options.

    Returns:
        ``(optimizer, scheduler)``.

    Raises:
        ValueError: if ``opt.optimizer`` is neither ``'SGD'`` nor ``'Adam'``.
    """
    if opt.optimizer == 'SGD':
        optimizer = torch.optim.SGD(
            lr=opt.lr, momentum=0.9, weight_decay=1e-4, params=model.parameters())
    elif opt.optimizer == 'Adam':
        optimizer = torch.optim.Adam(model.parameters(), lr=opt.lr)
    else:
        raise ValueError(f'Error: Do not support optimizer {opt.optimizer}')
    # Both optimizers share the same step schedule: LR *= gamma at each milestone epoch.
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=opt.milestones, gamma=opt.gamma)
    return optimizer, scheduler


def main(local_rank, world_size, opt):
    """Set up data, model and optimizer, then run training in one process.

    Args:
        local_rank: index of this process/GPU on the (single) node.
        world_size: number of visible CUDA devices; 0 runs on CPU, >1 enables
            distributed training with one process per GPU.
        opt: parsed CLI namespace from ``parser``.

    Raises:
        ValueError: if ``opt.bz`` is not divisible by ``world_size`` in
            multi-GPU mode, or the method/base_model/optimizer is unsupported.
    """
    tasks = opt.tasks.split(',')
    printf = utils.distributed_print(local_rank)

    if world_size > 1:
        # The global batch size is split evenly across processes (see batch_size
        # below), so it must divide exactly. The old check used a hard-coded 16.
        if opt.bz % world_size != 0:
            raise ValueError(
                f'--bz ({opt.bz}) must be divisible by the number of GPUs ({world_size})')
        os.environ['MASTER_ADDR'] = 'localhost'
        os.environ['MASTER_PORT'] = '8888'
        torch.cuda.set_device(local_rank)
        torch.distributed.init_process_group(
            backend='nccl',
            init_method='env://',
            world_size=world_size,
            rank=local_rank,
        )
    device = torch.device(f'cuda:{local_rank}' if world_size > 0 else 'cpu')

    # set up dataloader
    printf('setting up dataloader...')
    n_procs = max(1, world_size)
    trainset = dataset.PASCALContext(
        data_dir=opt.data_root, split='train', transforms=True, tasks=tasks, download=False)
    train_sampler = None
    if world_size > 1:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            trainset,
            num_replicas=world_size,
            rank=local_rank,
            shuffle=True,
        )
    trainloader = torch.utils.data.DataLoader(
        dataset=trainset,
        batch_size=opt.bz // n_procs,
        # 4 loader workers in total, divided (rounded up) across processes.
        num_workers=(4 + n_procs - 1) // n_procs,
        # With a sampler, shuffling is the sampler's job; DataLoader forbids both.
        shuffle=(train_sampler is None),
        sampler=train_sampler,
        pin_memory=True,
        drop_last=True,
    )

    testset = dataset.PASCALContext(
        data_dir=opt.data_root, split='val', transforms=True, tasks=tasks, download=False)
    testloader = torch.utils.data.DataLoader(
        dataset=testset, batch_size=1, shuffle=False, pin_memory=True)

    # build model architecture
    printf('building the model and loss...')
    branch_config = utils.read_json(opt.configuration)['config']
    model = _build_model(opt, tasks, branch_config)
    if world_size > 1:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    # 'nothing' and 'bmtas' use a single weighted-sum loss and tools.train_branched;
    # every other method gets per-task losses and tools.train (presumably so the
    # training loop can combine per-task gradients — confirm in tools.train).
    uses_weighted_sum = opt.method in ('nothing', 'bmtas')
    if uses_weighted_sum:
        loss = criterion.WeightedSumLoss(tasks)
    else:
        loss = criterion.LossOfEachTask(tasks)

    model = model.to(device)
    loss = loss.to(device)
    if world_size > 1:
        model = modules.MyDataParallel(model,
                                       device_ids=[local_rank],
                                       output_device=local_rank)

    # build optimization tools
    printf('building optimization tools...')
    max_epochs = opt.epoch
    optimizer, scheduler = _build_optimizer(opt, model)

    # in case we resume...
    start_epoch = 1
    if opt.resume_path is not None:
        printf('resuming saved model...')
        # torch.load unpickles the file: only resume from trusted checkpoints.
        checkpoint = torch.load(opt.resume_path, map_location=device)
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        scheduler.load_state_dict(checkpoint['scheduler'])
        start_epoch = checkpoint['epoch'] + 1

    printf('setup complete, start training...')

    exp_dir = Path(f'./exp_results/{opt.name}/')
    edge_save_dir = exp_dir / 'edge'
    # parents=True also creates exp_dir itself.
    edge_save_dir.mkdir(parents=True, exist_ok=True)

    metrics_dict = {
        'semseg': metrics.ConfMatrix(task='semseg', n_classes=21),
        'human_parts': metrics.ConfMatrix(task='human_parts', n_classes=7),
        'sal': metrics.ThresholdedMeanIoU(task='sal', thresholds=[x / 20. for x in range(4, 19)]),
        'normals': metrics.Normal_error(task='normals'),
        'edge': metrics.SavePrediction(task='edge', save_dir=edge_save_dir),
    }

    # Arguments shared by both training entry points, in positional order.
    train_args = (
        world_size,
        device,
        start_epoch,
        max_epochs,
        tasks,
        trainloader,
        testloader,
        model,
        loss,
        optimizer,
        scheduler,
        metrics_dict,
        exp_dir,
    )
    if uses_weighted_sum:
        tools.train_branched(local_rank, *train_args)
    else:
        tools.train(opt, local_rank, *train_args)

    printf('training finished!')
    if world_size > 1:
        # Release NCCL resources held by this process's group.
        torch.distributed.destroy_process_group()

def init_seed(opt):
    """Seed the Python, NumPy and torch RNGs and configure deterministic cuDNN.

    Args:
        opt: parsed CLI namespace; reads ``opt.seed`` and ``opt.num_threads``.
    """
    # Turn off cuDNN autotuning and request deterministic kernels, favouring
    # reproducibility over speed.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    seed_value = opt.seed
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed_value)

    torch.set_num_threads(opt.num_threads)

def init_wandb_logger(opt):
    """Start a new Weights & Biases run for this experiment.

    A fresh run id is generated on every call and ``resume='never'`` is passed,
    so each call starts a separate run. Tensorboard syncing is disabled
    (``sync_tensorboard=False``).

    Args:
        opt: parsed CLI namespace; ``opt.name`` is the run name, ``opt.project``
            the wandb project, and the whole namespace is logged as the config.

    Returns:
        The run object returned by ``wandb.init``.
    """
    return wandb.init(
        id=wandb.util.generate_id(),
        resume='never',
        name=opt.name,
        config=opt,
        project=opt.project,
        sync_tensorboard=False,
    )

if __name__ == '__main__':
    opt = parser.parse_args()
    init_seed(opt)
    init_wandb_logger(opt)

    # Single-node training only: one worker process per visible GPU.
    world_size = torch.cuda.device_count()
    if world_size <= 1:
        # 0 GPUs -> main() falls back to CPU; 1 GPU -> plain run on cuda:0.
        main(0, world_size, opt)
    else:
        torch.multiprocessing.spawn(main, args=(world_size, opt), nprocs=world_size)
