import math
import argparse
import pprint
from distutils.util import strtobool
from pathlib import Path
from loguru import logger as loguru_logger

import pytorch_lightning as pl
from pytorch_lightning.utilities import rank_zero_only
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
from pytorch_lightning.plugins import DDPPlugin

from src.config.PAConfig import get_cfg_defaults
from src.utils.misc import get_rank_zero_only_logger, setup_gpus
from src.utils.profiler import build_profiler
from src.lightning.data import MultiSceneDataModule
from src.lightning.lightning_PAloftr import PL_PALoFTR

# import torch
# torch.autograd.set_detect_anomaly(True)

# Rebind the module-level loguru logger to the wrapper returned by
# get_rank_zero_only_logger (presumably so that only the rank-0 process emits
# log lines when training with DDP -- confirm in src.utils.misc).
loguru_logger = get_rank_zero_only_logger(loguru_logger)

# ---------- config files (the two positional arguments of parse_args) ----------
data_cfg_path = "configs/data/scannet_100scene.py"
# data_cfg_path = "configs/data/scannet25k_debug.py"
main_cfg_path = "configs/loftr/indoor/depth_matcher.py"

# ---------- checkpoint paths ----------
# Whole-model checkpoint; not forwarded while the '--ckpt_path' entry of
# MyArgs below stays commented out.
ckpt_path = None
# Per-sub-module pretrained weights consumed by main(); a falsy value means
# "do not load this part".
ckpt_backBone = "weights/ResNetFPN.ckpt"
ckpt_depth_predictor = None
ckpt_pose_coarse = "weights/loftr_partial/loftr_coarse.ckpt"
ckpt_pose_preprocess = "weights/loftr_partial/fine_preprocess.ckpt"
ckpt_pose_fine = "weights/loftr_partial/loftr_fine.ckpt"
ckpt_pose_proj = "weights/loftr_partial/neck_proj.ckpt"

# ---------- cluster / loader settings ----------
n_nodes = 1
n_gpus_per_node = 4
torch_num_workers = 4
batch_size = 2          # per-GPU batch size
pin_memory = True
# The experiment name encodes the global batch size (nodes * gpus * per-GPU bs).
exp_name = f"test-ds-bs={n_gpus_per_node * n_nodes * batch_size}"

# Hard-coded argument vector handed to parse_args() instead of sys.argv.
# It is assembled in three parts: positionals, values derived from the
# settings above, and fixed pl.Trainer / script flags.
MyArgs = [data_cfg_path, main_cfg_path] + [
    # f'--ckpt_path={ckpt_path}',
    f'--exp_name={exp_name}',
    f'--gpus={n_gpus_per_node}',
    f'--num_nodes={n_nodes}',
    f'--batch_size={batch_size}',
    f'--num_workers={torch_num_workers}',
    f'--pin_memory={pin_memory}',
] + [
    '--check_val_every_n_epoch=1',
    '--log_every_n_steps=100',
    '--flush_logs_every_n_steps=100',
    '--limit_val_batches=1.',
    '--num_sanity_val_steps=32',
    '--benchmark=True',
    '--max_epochs=30',
    '--parallel_load_data',
]

# Define the parser
def parse_args(args=None):
    """Build the training CLI parser and parse ``args``.

    The parser holds the script-specific options below plus every flag that
    ``pl.Trainer.add_argparse_args`` contributes.

    Args:
        args: Optional list of argument strings. When ``None``, argparse falls
            back to ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with the parsed options.
    """
    # init a custom parser which will be added into pl.Trainer parser
    # check documentation: https://pytorch-lightning.readthedocs.io/en/latest/common/trainer.html#trainer-flags
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument(
        'data_cfg_path', type=str, help='data config path')
    parser.add_argument(
        'main_cfg_path', type=str, help='main config path')
    parser.add_argument(
        '--exp_name', type=str, default='default_exp_name')
    parser.add_argument(
        '--batch_size', type=int, default=4, help='batch_size per gpu')
    parser.add_argument(
        '--num_workers', type=int, default=4)
    # nargs='?' accepts "--pin_memory", "--pin_memory=false", etc. Without
    # const=True, a bare "--pin_memory" would yield None (falsy) instead of True.
    parser.add_argument(
        '--pin_memory', type=lambda x: bool(strtobool(x)),
        nargs='?', const=True, default=True,
        help='whether loading data to pinned memory or not')
    parser.add_argument(
        '--ckpt_path', type=str, default=None,
        help='pretrained checkpoint path, helpful for using a pre-trained coarse-only LoFTR')
    parser.add_argument(
        '--disable_ckpt', action='store_true',
        help='disable checkpoint saving (useful for debugging).')
    parser.add_argument(
        '--profiler_name', type=str, default=None,
        help='options: [inference, pytorch], or leave it unset')
    parser.add_argument(
        '--parallel_load_data', action='store_true',
        help='load datasets in with multiple processes.')

    parser = pl.Trainer.add_argparse_args(parser)
    # argparse already treats args=None as "parse sys.argv[1:]", so no branch is needed.
    return parser.parse_args(args)


def main():
    """Configure, build and train the PALoFTR model with PyTorch Lightning.

    Arguments come from the module-level ``MyArgs`` list (passed to
    ``parse_args``), not from the real command line. The module-level
    ``ckpt_*`` paths choose which pretrained sub-modules are loaded before
    ``trainer.fit`` runs.
    """
    args = parse_args(MyArgs)
    rank_zero_only(pprint.pprint)(vars(args))

    # ----------------- Prepare Configuration -------------------
    # init default-cfg and merge with main and data cfg
    config = get_cfg_defaults()
    config.merge_from_file(args.main_cfg_path)
    config.merge_from_file(args.data_cfg_path)
    pl.seed_everything(config.TRAINER.SEED)

    # scale lr and warmup-step automatically, relative to the canonical batch size
    args.gpus = _n_gpus = setup_gpus(args.gpus)
    config.TRAINER.WORLD_SIZE = _n_gpus * args.num_nodes
    config.TRAINER.TRUE_BATCH_SIZE = config.TRAINER.WORLD_SIZE * args.batch_size
    _scaling = config.TRAINER.TRUE_BATCH_SIZE / config.TRAINER.CANONICAL_BS
    config.TRAINER.SCALING = _scaling
    config.TRAINER.TRUE_LR = config.TRAINER.CANONICAL_LR * _scaling
    config.TRAINER.WARMUP_STEP = math.floor(config.TRAINER.WARMUP_STEP / _scaling)

    # ----------------- Setup Model and Data -------------------
    logger = TensorBoardLogger(save_dir='logs/tb_logs', name=args.exp_name, default_hp_metric=False)
    log_dir = Path(logger.log_dir)
    # parents=True: 'logs/tb_logs/<exp_name>' may not exist yet on a fresh run.
    # exist_ok=True: avoids the check-then-create race between DDP processes.
    log_dir.mkdir(parents=True, exist_ok=True)
    ckpt_dir = log_dir / 'checkpoints'

    # lightning module
    profiler = build_profiler(args.profiler_name)
    model = PL_PALoFTR(config, pretrained_ckpt=None, profiler=profiler)

    # ----------- load pretrained parameters ------------
    # (checkpoint path, sub-module name, inTrain flag), loaded in this order;
    # entries whose path is falsy are skipped.
    pretrained_parts = (
        (ckpt_backBone, 'backbone', False),
        (ckpt_depth_predictor, 'depth_predictor', True),
        (ckpt_pose_proj, 'proj', False),
        (ckpt_pose_coarse, 'pose_coarse', False),
        (ckpt_pose_preprocess, 'pose_preprocess', False),
        (ckpt_pose_fine, 'pose_fine', False),
    )
    for part_ckpt, part_name, in_train in pretrained_parts:
        if part_ckpt:
            model.load_partial_ckpt(part_ckpt, part_name, inTrain=in_train)
    # ---------------------------------------------------
    loguru_logger.info("DepthLoFTR LightningModule initialized!")

    # lightning data
    data_module = MultiSceneDataModule(args, config)
    loguru_logger.info("InstLoFTR DataModule initialized!")

    # Callbacks: the LR monitor is always on; checkpointing only unless disabled.
    callbacks = [LearningRateMonitor(logging_interval='step')]
    if not args.disable_ckpt:
        callbacks.append(ModelCheckpoint(
            monitor='auc@10', verbose=True, save_top_k=5, mode='max',
            save_last=True,
            dirpath=str(ckpt_dir),
            filename='{epoch}-{auc@5:.3f}-{auc@10:.3f}-{auc@20:.3f}'
        ))

    # ------------ Train the Model -------------
    # Same flag passed both to the DDP plugin and to the Trainer.
    sync_bn = config.TRAINER.WORLD_SIZE > 0
    trainer = pl.Trainer.from_argparse_args(
        args,
        gradient_clip_val=config.TRAINER.GRADIENT_CLIPPING,
        plugins=DDPPlugin(
            find_unused_parameters=False,
            num_nodes=args.num_nodes,
            sync_batchnorm=sync_bn
        ),
        # resume_from_checkpoint=ckpt_path,
        # precision=16,
        accelerator='ddp',
        gpus=args.gpus,
        callbacks=callbacks,
        logger=logger,
        sync_batchnorm=sync_bn,
        replace_sampler_ddp=False,
        reload_dataloaders_every_epoch=False,
        weights_summary='full',
        profiler=profiler
    )

    # Fitting the model
    loguru_logger.info("Trainer initialized")
    loguru_logger.info("Start Training")
    trainer.fit(model, datamodule=data_module)


if __name__ == '__main__':
    main()