
import argparse
import copy
import os
import os.path as osp
import time
import numpy as np
import subprocess
import mmcv
import torch

from mmcv.utils import Config, get_git_hash
from mmseg import __version__
from mmseg.apis import set_random_seed
from mmseg.datasets import build_dataset
from mmseg.models import build_segmentor
from mmseg.utils import collect_env
from mmcv_custom import train_segmentor
from config import get_config
import torch.distributed as dist
import torch.backends.cudnn as cudnn
from torch.utils.tensorboard import SummaryWriter
from logger import create_logger


# apex is an optional dependency: when it cannot be imported, ``amp`` is set to
# None and the entry point checks for that before training with an AMP opt
# level other than "O0".
try:
    # noinspection PyUnresolvedReferences
    from apex import amp
except ImportError:
    amp = None



# def parse_args():
#     parser.add_argument(
#         '--load-from', help='the checkpoint file to load weights from')
#     parser.add_argument(
#         '--resume-from', help='the checkpoint file to resume from')

def main(config, log_writer, seed, args):
    """Build the segmentor and datasets from an mmseg config and start training.

    This function logs through a module-level ``logger`` which the
    ``__main__`` section of this file creates before calling it.

    Args:
        config: Project config object; only ``config.SEED`` is read here.
        log_writer: Accepted for interface compatibility; not used here.
        seed: Accepted for interface compatibility. The value is recomputed as
            ``config.SEED + dist.get_rank()`` before it is applied.
        args: Parsed CLI namespace. Reads ``config``, ``pretrained``,
            ``work_dir``, ``load_from``, ``resume_from``, ``world_size``,
            ``seed``, ``distributed`` and ``no_validate``.
    """
    cfg = Config.fromfile(args.config)
    cfg.pretrained = args.pretrained
    cfg.model.pretrained = args.pretrained

    # Hard-coded overrides applied on top of whatever the config file says.
    cfg.model.backbone.use_checkpoint = True
    cfg.model.backbone.out_with_norm = True
    cfg.optimizer.lr = 3e-5

    # work_dir is determined in this priority: CLI > segment in file > filename
    if args.work_dir is not None:
        cfg.work_dir = args.work_dir
    elif cfg.get('work_dir', None) is None:
        config_stem = osp.splitext(osp.basename(args.config))[0]
        cfg.work_dir = osp.join('./work_dirs', config_stem)
    if args.load_from is not None:
        cfg.load_from = args.load_from
    if args.resume_from is not None:
        cfg.resume_from = args.resume_from

    cfg.gpu_ids = list(range(args.world_size))

    # create work_dir
    mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))

    logger.info("cfg is:")
    logger.info(cfg)

    # Timestamp handed to train_segmentor below.
    timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())

    # The RNG seed is offset by the process rank; this overrides the ``seed``
    # argument.
    seed = config.SEED + dist.get_rank()
    set_random_seed(seed)

    # meta dict records environment info, seed and experiment name
    meta = dict()
    env_info_dict = collect_env()
    env_info = '\n'.join(f'{k}: {v}' for k, v in env_info_dict.items())
    dash_line = '-' * 60 + '\n'
    logger.info('Environment info:\n' + dash_line + env_info + '\n' +
                dash_line)
    meta['env_info'] = env_info

    # NOTE(review): cfg/meta record ``args.seed`` while the RNG above was
    # seeded with ``config.SEED + rank`` -- confirm this split is intended.
    cfg.seed = args.seed
    meta['seed'] = args.seed
    meta['exp_name'] = osp.basename(args.config)

    model = build_segmentor(
        cfg.model,
        train_cfg=cfg.get('train_cfg'),
        test_cfg=cfg.get('test_cfg'))

    logger.info(model)

    datasets = [build_dataset(cfg.data.train)]
    if len(cfg.workflow) == 2:
        # A two-stage workflow also needs the val split, built with the
        # training pipeline.
        val_dataset = copy.deepcopy(cfg.data.val)
        val_dataset.pipeline = cfg.data.train.pipeline
        datasets.append(build_dataset(val_dataset))
    if cfg.checkpoint_config is not None:
        # save mmseg version, config and class names in checkpoints as meta
        # data
        cfg.checkpoint_config.meta = dict(
            mmseg_version=f'{__version__}+{get_git_hash()[:7]}',
            config=cfg,
            CLASSES=datasets[0].CLASSES,
            PALETTE=datasets[0].PALETTE)
    # add an attribute for visualization convenience
    model.CLASSES = datasets[0].CLASSES

    train_segmentor(
        model,
        datasets,
        cfg,
        distributed=bool(args.distributed),
        validate=(not args.no_validate),
        timestamp=timestamp,
        meta=meta)



def parse_option():
    """Parse the command line and build the project config from it.

    Returns:
        tuple: ``(args, config)`` where ``args`` is the parsed
        ``argparse.Namespace`` and ``config`` is the object returned by
        ``get_config(args)``.
    """
    parser = argparse.ArgumentParser('MFM Train a segmentor script', add_help=False)
    parser.add_argument('config', help='train config file path')
    parser.add_argument(
        '--no-validate',
        action='store_true',
        help='whether not to evaluate the checkpoint during training')
    parser.add_argument('--seed', type=int, default=0, help='random seed')
    parser.add_argument(
        '--fast-eval',
        action='store_true',
        help='whether to perform fast evaluation.')
    parser.add_argument('--work-dir', help='the dir to save logs and models')
    parser.add_argument(
        '--load-from', help='the checkpoint file to load weights from')
    parser.add_argument(
        '--resume-from', help='the checkpoint file to resume from')

    parser.add_argument('--cfg', type=str, required=True, metavar="FILE", help='path to config file', )
    parser.add_argument(
        "--opts",
        help="Modify config options by adding 'KEY VALUE' pairs. ",
        default=None,
        nargs='+',
    )

    # easy config modification
    parser.add_argument('--batch-size', type=int, help="batch size for single GPU")
    parser.add_argument('--pretrained', type=str, help='path to pre_trained model')
    parser.add_argument('--data-path', type=str, help='path to dataset')
    parser.add_argument('--resume', help='resume from checkpoint')
    parser.add_argument('--accumulation-steps', type=int, help="gradient accumulation steps")
    parser.add_argument('--use-checkpoint', action='store_true',
                        help="whether to use gradient checkpointing to save memory")
    parser.add_argument('--amp-opt-level', type=str, default='O1', choices=['O0', 'O1', 'O2'],
                        help='mixed precision opt level, if O0, no amp is used')
    parser.add_argument('--output', default='output', type=str, metavar='PATH',
                        help='root of output folder, the full path is <output>/<model_name>/<tag> (default: output)')

    parser.add_argument('--tag', help='tag of experiment')
    # distributed training
    parser.add_argument("--distributed", action="store_true", help="Using distributed")
    parser.add_argument('--world-size', default=-1, type=int, help='number of nodes for distributed training')
    # The default is None rather than 29500: the launcher tests
    # ``args.port is not None`` and only falls back to the MASTER_PORT
    # environment variable (and then to 29500) when no port was given. A
    # numeric default made those fallbacks unreachable.
    parser.add_argument('--port', type=int, default=None, help='port only works when launcher=="slurm"')
    parser.add_argument('--rank', default=-1, type=int, help='node rank for distributed training')
    parser.add_argument('--dist-url', default='env://', type=str, help='url used to set up distributed training')
    parser.add_argument('--dist-backend', default='nccl', type=str, help='distributed backend')
    parser.add_argument('--local_rank', default=-1, type=int, help='local rank for distributed training')

    args = parser.parse_args()
    config = get_config(args)

    return args, config


def _init_slurm_distributed(port):
    """Initialise ``torch.distributed`` (NCCL) from SLURM environment variables.

    Reads ``SLURM_PROCID``, ``SLURM_NTASKS`` and ``SLURM_NODELIST``, binds the
    process to GPU ``proc_id % num_gpus``, exports ``MASTER_ADDR``,
    ``MASTER_PORT``, ``WORLD_SIZE`` and ``RANK`` and then calls
    ``dist.init_process_group``.

    Args:
        port: Rendezvous port. When ``None``, an existing ``MASTER_PORT``
            environment variable is kept; if there is none, 29500 is used.

    Returns:
        int: Number of CUDA devices visible to this process.

    Raises:
        KeyError: If one of the SLURM environment variables is missing.
        RuntimeError: If no CUDA device is visible.
    """
    proc_id = int(os.environ['SLURM_PROCID'])
    print(f"proc_id {proc_id}")
    ntasks = int(os.environ['SLURM_NTASKS'])
    print(f"ntasks {ntasks}")
    node_list = os.environ['SLURM_NODELIST']
    print(f"node_list {node_list}")
    num_gpus = torch.cuda.device_count()
    print(f"num_gpus {num_gpus}")
    print(f"torch.cuda.is_available() {torch.cuda.is_available()}")
    if num_gpus == 0:
        # Without this guard the modulo below fails with an opaque
        # ZeroDivisionError.
        raise RuntimeError("No CUDA device is visible; cannot bind the SLURM task to a GPU.")
    torch.cuda.set_device(proc_id % num_gpus)

    # The first host of the SLURM node list acts as the rendezvous master.
    addr = subprocess.getoutput(
        f'scontrol show hostname {node_list} | head -n1')

    # Port priority: explicit argument > MASTER_PORT env var > 29500
    # (29500 is the torch.distributed default port).
    if port is not None:
        os.environ['MASTER_PORT'] = str(port)
    elif 'MASTER_PORT' not in os.environ:
        os.environ['MASTER_PORT'] = '29500'
    os.environ['MASTER_ADDR'] = addr
    os.environ['WORLD_SIZE'] = str(ntasks)
    os.environ['RANK'] = str(proc_id)
    dist.init_process_group(backend='nccl')
    return num_gpus


def _scale_learning_rates(config):
    """Linearly scale BASE_LR / WARMUP_LR / MIN_LR in ``config.TRAIN`` in place.

    Each rate is multiplied by ``DATA.BATCH_SIZE * world_size / 512`` and, when
    ``TRAIN.ACCUMULATION_STEPS > 1``, additionally by the accumulation steps.
    Requires an initialised process group (uses ``dist.get_world_size()``).
    """
    scaled = {}
    for key in ('BASE_LR', 'WARMUP_LR', 'MIN_LR'):
        # linear scale the learning rate according to total batch size, may
        # not be optimal
        value = getattr(config.TRAIN, key) * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0
        # gradient accumulation also need to scale the learning rate
        if config.TRAIN.ACCUMULATION_STEPS > 1:
            value = value * config.TRAIN.ACCUMULATION_STEPS
        scaled[key] = value

    config.defrost()
    for key, value in scaled.items():
        setattr(config.TRAIN, key, value)
    config.freeze()


if __name__ == '__main__':
    args, config = parse_option()

    # Explicit raise instead of ``assert`` so the check survives ``python -O``.
    if config.AMP_OPT_LEVEL != "O0" and amp is None:
        raise RuntimeError("amp not installed!")

    cuda_version = torch.version.cuda
    cudnn_version = torch.backends.cudnn.version()
    print(f"CUDA Version: {cuda_version}")
    print(f"cuDNN Version: {cudnn_version}")

    # initialize slurm distributed training environment
    # NOTE(review): ``world_size`` here is the number of GPUs visible to this
    # process, not ``dist.get_world_size()`` -- confirm which one
    # ``args.world_size`` should carry for multi-node jobs.
    world_size = _init_slurm_distributed(args.port)
    args.world_size = world_size
    dist.barrier()

    # Per-rank seed so every process draws a different random stream.
    seed = config.SEED + dist.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    print(f"seed {seed}")
    cudnn.benchmark = True

    _scale_learning_rates(config)

    os.makedirs(config.OUTPUT, exist_ok=True)
    # ``logger`` must stay a module-level name: main() logs through it.
    logger = create_logger(output_dir=config.OUTPUT, dist_rank=dist.get_rank(), name=f"{config.MODEL.NAME}")
    logger.info(f'World Size: {world_size}')

    # Only rank 0 dumps the config and owns a TensorBoard writer.
    log_writer = None
    if dist.get_rank() == 0:
        path = os.path.join(config.OUTPUT, "config.json")
        with open(path, "w") as f:
            f.write(config.dump())
        logger.info(f"Full config saved to {path}")
        log_writer = SummaryWriter(log_dir=config.OUTPUT)

    # print config
    logger.info(config.dump())

    try:
        main(config, log_writer, seed, args)
    finally:
        # Close the event file even when training raises.
        if log_writer is not None:
            log_writer.close()