import torch
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader
from tensorboardX import SummaryWriter

import torch.multiprocessing as mp
from torch.utils.data.distributed import DistributedSampler

import os
import sys
import ast
import click
import shutil
import random
import argparse
import warnings
import numpy as np

import _init_paths
import loss as custom_loss
import dataset as custom_dataset
from data_transform.transform_wrapper import get_transform
from config import cfg, update_config
from utils.utils import (
    create_logger,
    get_optimizer,
    get_scheduler,
    get_model,
    get_category_list,
)
from core.function import train_model, valid_model, test_model
from core.trainer import Trainer
from utils.reprod import fix_seed
from utils.dist import setup, cleanup

import horovod
import horovod.torch as hvd
from filelock import FileLock
from core.hvd_func import train_mixed_precision, valid_hvd


def parse_args():
    """Parse the command line for the training entry point.

    Returns:
        argparse.Namespace with:
          * ``cfg``  -- path of the YAML config file to load.
          * ``opts`` -- every remaining token, kept verbatim; presumably
            ``KEY VALUE`` override pairs consumed by ``update_config``.
    """
    cli = argparse.ArgumentParser(description="Codes for BBN")
    cli.add_argument(
        "--cfg",
        type=str,
        default="configs/cifar10.yaml",
        required=False,
        help="decide which cfg to use",
    )
    # REMAINDER swallows everything after the recognised options, so
    # arbitrary config overrides can follow on the command line.
    cli.add_argument(
        "opts",
        nargs=argparse.REMAINDER,
        default=None,
        help="modify config options using the command-line",
    )
    return cli.parse_args()


def _save_checkpoint(path, model, optimizer, scheduler, epoch, best_result, best_epoch):
    """Write a training checkpoint to ``path`` with ``torch.save``.

    The saved dict holds the model and optimizer state dicts, the scheduler
    state (``None`` when no scheduler is used), the current epoch, and the
    best validation result/epoch recorded so far.
    """
    torch.save(
        {
            'state_dict': model.state_dict(),
            'epoch': epoch,
            'best_result': best_result,
            'best_epoch': best_epoch,
            'scheduler': scheduler.state_dict() if scheduler is not None else None,
            'optimizer': optimizer.state_dict(),
        },
        path,
    )


def _prepare_output_dirs(cfg):
    """Create the per-seed output tree and snapshot the project source into it.

    Layout: ``<output_dir>/<name>/seedNNN/{models,codes,tensorboard}``.

    Returns:
        ``(model_dir, code_dir, tensorboard_dir)``. ``tensorboard_dir`` is
        ``None`` when ``cfg.train.tensorboard.enable`` is false.
    """
    run_dir = os.path.join(cfg.output_dir, cfg.name, 'seed{:03d}'.format(cfg.seed_num))
    model_dir = os.path.join(run_dir, "models")
    code_dir = os.path.join(run_dir, "codes")
    tensorboard_dir = (
        os.path.join(run_dir, "tensorboard") if cfg.train.tensorboard.enable else None
    )

    if not os.path.exists(model_dir):
        os.makedirs(model_dir)
    elif tensorboard_dir is not None and os.path.exists(tensorboard_dir):
        # Re-running into an existing run directory: start the TensorBoard
        # logs fresh. Existing files in model_dir are left untouched.
        shutil.rmtree(tensorboard_dir)
    # copytree() refuses an existing destination, so drop any stale snapshot.
    # The existence check avoids FileNotFoundError when no snapshot was written.
    if os.path.exists(code_dir):
        shutil.rmtree(code_dir)
    print("=> output model will be saved in {}".format(model_dir))

    # Snapshot the parent of this file's directory (presumably the repo root),
    # skipping compiled files, checkpoints, build/output trees and datasets.
    current_dir = os.path.dirname(__file__)
    ignore = shutil.ignore_patterns(
        '*.pyc', '*.so', '*.out', '*pycache*', '*.pth', '*build*', '*output*', '*datasets*'
    )
    shutil.copytree(os.path.join(current_dir, '..'), code_dir, ignore=ignore)
    return model_dir, code_dir, tensorboard_dir


def main_worker(args):
    """Run Horovod data-parallel, mixed-precision training in this process.

    Every Horovod process runs this function. Only rank 0 (``verbose``)
    creates the logger, writes checkpoints, snapshots the code, and writes
    TensorBoard summaries.

    Args:
        args: namespace from ``parse_args()``. It is applied to the global
            ``cfg`` via ``update_config``.
    """
    update_config(cfg, args)
    logger = None
    writer = None
    model_dir = None

    hvd.init()
    fix_seed(cfg.seed_num)
    # Pin this process to its node-local GPU.
    torch.cuda.set_device(hvd.local_rank())
    torch.set_num_threads(1)

    verbose = hvd.rank() == 0
    if verbose:
        logger, _ = create_logger(cfg)
        warnings.filterwarnings("ignore")

    transform_tr = get_transform(cfg, mode='train')
    transform_ts = get_transform(cfg, mode='test')

    # Build BOTH splits under the file lock. Both are constructed with
    # download=True, so concurrent ranks must not fetch/extract simultaneously.
    dataset_cls = getattr(custom_dataset, cfg.dataset.dataset)
    with FileLock(os.path.expanduser("~/.horovod_lock")):
        train_set = dataset_cls(cfg, train=True, download=True, transform=transform_tr)
        # For these datasets, the train split's class_map is handed to the
        # validation split (presumably so both use the same label mapping).
        class_map = (
            train_set.class_map if cfg.dataset.dataset in ('ImageNetLT', 'Places') else None
        )
        valid_set = dataset_cls(
            cfg, train=False, download=True, transform=transform_ts, class_map=class_map)

    if not isinstance(train_set.targets, torch.Tensor):
        train_set.targets = torch.tensor(train_set.targets, dtype=torch.long)
    num_classes = len(torch.unique(train_set.targets))
    num_class_list, ctgy_list = get_category_list(train_set.targets, num_classes, cfg)

    param_dict = {
        'num_classes': num_classes,
        'num_class_list': num_class_list,
        'cfg': cfg,
        'rank': hvd.rank(),
    }

    # Each rank iterates over its own shard of the data.
    trainsampler = DistributedSampler(train_set, num_replicas=hvd.size(), rank=hvd.rank())
    validsampler = DistributedSampler(valid_set, num_replicas=hvd.size(), rank=hvd.rank())

    dl_kwargs = {'num_workers': cfg.train.num_workers, 'pin_memory': True}
    # Use the 'forkserver' start method for loader workers when this torch
    # build advertises context support and the platform provides it.
    if (dl_kwargs['num_workers'] > 0
            and getattr(mp, '_supports_context', False)
            and 'forkserver' in mp.get_all_start_methods()):
        dl_kwargs['multiprocessing_context'] = 'forkserver'

    trainloader = DataLoader(
        train_set, batch_size=cfg.train.batch_size, sampler=trainsampler, **dl_kwargs)
    # Validation uses twice the training batch size.
    validloader = DataLoader(
        valid_set, batch_size=2 * cfg.train.batch_size, sampler=validsampler, **dl_kwargs)

    num_epochs = cfg.train.num_epochs
    # ----- BEGIN model builder -----
    # NOTE(review): Trainer, get_model and valid_hvd receive the *global* rank.
    # If they use it as a CUDA device index, it is out of range on multi-node
    # runs (hvd.local_rank() is the node-local index). Confirm against their
    # definitions.
    trainer = Trainer(cfg, hvd.rank())
    # .cuda() takes a device index on this node, hence local_rank, not rank.
    criterion = getattr(custom_loss, cfg.loss.loss_type)(
        param_dict=param_dict).cuda(hvd.local_rank())

    model = get_model(cfg, num_classes, hvd.rank())

    optimizer = get_optimizer(cfg, model)
    scheduler = get_scheduler(cfg, optimizer)

    # Start every rank from rank 0's weights and optimizer state.
    hvd.broadcast_parameters(model.state_dict(), root_rank=0)
    hvd.broadcast_optimizer_state(optimizer, root_rank=0)

    # Average gradients across ranks, compressing them to fp16 for the allreduce.
    optimizer = hvd.DistributedOptimizer(
        optimizer,
        named_parameters=model.named_parameters(),
        compression=hvd.Compression.fp16,
        op=hvd.Average,
        gradient_predivide_factor=1.0)
    scaler = torch.cuda.amp.GradScaler()
    # ----- END model builder -----

    if verbose:
        model_dir, _, tensorboard_dir = _prepare_output_dirs(cfg)
        if tensorboard_dir is not None:
            writer = SummaryWriter(log_dir=tensorboard_dir)
            # Trace only the pooling sub-module into the TensorBoard graph.
            # tuple() accepts cfg.input_size given as either a list or a tuple.
            dummy_input = torch.rand((1, 3) + tuple(cfg.input_size)).cuda(hvd.local_rank())
            # NOTE(review): this function never wraps `model`; the `.module`
            # branch depends on cfg.ddp and on what get_model returns.
            pooling_module = model.module.pooling if cfg.ddp else model.pooling
            writer.add_graph(pooling_module, (dummy_input,))

    best_result, best_epoch, start_epoch = 0, 0, 1
    # save_step == -1 means "checkpoint only at the final epoch".
    save_step = cfg.save_step if cfg.save_step != -1 else num_epochs

    if verbose:
        logger.info(
            "-------------------Train start: {} {} {} | {} {} / {}--------------------".format(
                cfg.backbone.type, cfg.pooling.type, cfg.reshape.type,
                cfg.classifier.type, cfg.scaling.type,
                cfg.train.trainer.type,
            )
        )

    for epoch in range(start_epoch, num_epochs + 1):
        # The first epoch runs at the scheduler's initial LR; step once at the
        # start of each later epoch.
        if epoch > start_epoch and scheduler is not None:
            scheduler.step()
        # Give the distributed sampler a new epoch so its shuffling changes.
        trainsampler.set_epoch(epoch)

        # train
        train_acc, train_loss = train_mixed_precision(
            trainloader, model,
            epoch, num_epochs,
            optimizer, trainer, criterion,
            cfg, logger, verbose, scaler
        )
        if verbose and epoch % save_step == 0:
            _save_checkpoint(
                os.path.join(model_dir, 'epoch_{}.pth'.format(epoch)),
                model, optimizer, scheduler, epoch, best_result, best_epoch,
            )
        loss_dict, acc_dict = {'train_loss': train_loss}, {'train_acc': train_acc}

        # valid
        if cfg.valid_step != -1 and epoch % cfg.valid_step == 0:
            valid_acc, valid_loss = valid_hvd(
                validloader, model,
                epoch,
                criterion,
                cfg, logger, verbose, hvd.rank()
            )
            loss_dict['valid_loss'], acc_dict['valid_acc'] = valid_loss, valid_acc
            if verbose:
                if valid_acc >= best_result:
                    best_result, best_epoch = valid_acc, epoch
                    _save_checkpoint(
                        os.path.join(model_dir, 'best_model.pth'),
                        model, optimizer, scheduler, epoch, best_result, best_epoch,
                    )
                logger.info(
                    "-------------Best Epoch:{:>3d}   Best Acc:{:>5.2f}%------------".format(
                        best_epoch, best_result * 100
                    )
                )

        # writer is non-None only on rank 0 with TensorBoard enabled.
        if writer is not None:
            writer.add_scalars('scalar/acc', acc_dict, epoch)
            writer.add_scalars('scalar/loss', loss_dict, epoch)
    if writer is not None:
        writer.close()
    if verbose:
        logger.info(
            "-------------------Train Finished: {} (seed:{})-------------------".format(cfg.name, cfg.seed_num)
        )


if __name__ == "__main__":
    # main_worker applies the command-line overrides to cfg itself (its first
    # statement is update_config), so no separate call is needed here.
    main_worker(parse_args())

