import os
import argparse, os, sys, datetime, glob, importlib
from typing import Any

from omegaconf import OmegaConf
import torch
import time

# Allow TF32 tensor-core math for CUDA matmuls and cuDNN convolutions
# (faster on GPUs that support TF32, at slightly reduced float32 precision).
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

import pytorch_lightning as pl
from pytorch_lightning import seed_everything
from pytorch_lightning.trainer import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, Callback, LearningRateMonitor, TQDMProgressBar
from pytorch_lightning.utilities.rank_zero import rank_zero_only
import wandb
from pytorch_lightning.plugins.environments import SLURMEnvironment
import signal
from webdataset.utils import make_seed

def get_parser(**parser_kwargs):
    """Build the command-line parser for the training script.

    Args:
        **parser_kwargs: forwarded verbatim to ``argparse.ArgumentParser``.

    Returns:
        The configured ``argparse.ArgumentParser``. Arguments it does not know
        about are expected to be collected by ``parse_known_args`` and treated
        as config overrides by the caller.
    """

    def str2bool(v):
        """Parse common textual booleans ("yes"/"no", "1"/"0", ...) into ``bool``."""
        if isinstance(v, bool):
            return v
        if v.lower() in ("yes", "true", "t", "y", "1"):
            return True
        elif v.lower() in ("no", "false", "f", "n", "0"):
            return False
        else:
            raise argparse.ArgumentTypeError("Boolean value expected.")

    parser = argparse.ArgumentParser(**parser_kwargs)
    parser.add_argument(
        "-d",
        "--debug",
        type=str2bool,
        nargs="?",
        const=True,
        default=False,
        help="enable post-mortem debugging",
    )
    parser.add_argument(
        "--wandb_offline",
        type=str2bool,
        nargs="?",
        const=True,
        default=False,
        help="run the Weights & Biases logger in offline mode",
    )
    parser.add_argument(
        "--wandb_project_name",
        type=str,
        default="dismo",
    )
    parser.add_argument(
        "-b",
        "--base",
        nargs="*",
        metavar="base_config.yaml",
        help="paths to base configs. Loaded from left-to-right. "
        "Parameters can be overwritten or added with command-line options of the form `--key value`.",
        default=list(),
    )
    parser.add_argument(
        "--gpus",
        type=str,
        default=None,
    )
    parser.add_argument(
        "--n_gpus",
        type=int,
        default=None,
    )
    parser.add_argument(
        "--n_nodes",
        type=int,
        default=1,
    )
    parser.add_argument(
        "--bs",
        # Parsed as int: every consumer converts it with int() anyway.
        type=int,
        default=None,
        help="batch size",
    )
    parser.add_argument(
        "--lr",
        type=float,
        default=1e-4,
        help="learning rate",
    )
    parser.add_argument(
        "--wd",
        type=float,
        default=0,
        help="weight decay",
    )
    parser.add_argument(
        "--precision",
        type=str,
        # other common values: "32", "bf16-mixed"
        default="16-mixed",
        help="whether to use mixed precision during training",
    )
    parser.add_argument(
        "--acc_bs",
        type=int,
        default=1,
        help="How much gradient batches should be accumulated",
    )
    parser.add_argument(
        "-r",
        "--resume",
        type=str,
        const=True,
        default="",
        nargs="?",
        help="resume from logdir or checkpoint in logdir",
    )
    parser.add_argument("--ckpt_path", type=str, default=None)
    parser.add_argument(
        "-p",
        "--prefix",
        type=str,
        const=True,
        default="",
        nargs="?",
        help="prefix for logdir",
    )
    parser.add_argument(
        "--wandb_prefix",
        type=str,
        const=True,
        default="",
        nargs="?",
    )
    parser.add_argument(
        "-l",
        "--log",
        help="root directory for run logs",
        default=None,
    )
    parser.add_argument(
        "--tar_base",
        type=str,
        default=None,
    )
    parser.add_argument(
        "--slurm",
        type=str2bool,
        default=False,
    )
    parser.add_argument(
        "-t",
        "--mode",
        type=str,
        const=True,
        default="train",
        nargs="?",
        help="whether to run an entire training process or a single validation set",
    )
    parser.add_argument(
        "--no-test",
        type=str2bool,
        const=True,
        default=False,
        nargs="?",
        help="disable test",
    )
    parser.add_argument(
        "-s",
        "--seed",
        type=int,
        default=23,
    )
    parser.add_argument(
        "--randomize_train_loader",
        type=str2bool,
        default=True,
    )
    parser.add_argument(
        "--compile",
        type=str2bool,
        default=True,
    )
    parser.add_argument(
        "--ckpt_every",
        type=int,
        default=None,
    )

    return parser


def get_obj_from_str(string, reload=False):
    """Return the attribute named by a dotted path such as ``"pkg.mod.Name"``.

    Args:
        string: fully-qualified ``module.attribute`` path; the part after the
            last dot is the attribute, everything before it is the module.
        reload: if true, re-execute the module via ``importlib.reload`` before
            looking the attribute up.

    Raises:
        ValueError: if ``string`` contains no dot.
        ImportError / AttributeError: if the module or attribute is missing.
    """
    module_path, attr_name = string.rsplit(".", 1)
    if reload:
        importlib.reload(importlib.import_module(module_path))
    target_module = importlib.import_module(module_path, package=None)
    return getattr(target_module, attr_name)


def instantiate_from_config(config):
    """Instantiate ``config["target"]`` with keyword arguments ``config["params"]``.

    Args:
        config: mapping with a dotted-path ``target`` and an optional
            ``params`` mapping (defaults to no arguments).

    Raises:
        KeyError: if ``target`` is missing.
    """
    if "target" not in config:
        raise KeyError("Expected key `target` to instantiate.")
    factory = get_obj_from_str(config["target"])
    kwargs = config.get("params", dict())
    return factory(**kwargs)


class SetupCallback(Callback):
    """Create the run's log/checkpoint/config directories and save the config.

    The directories are created eagerly from ``__init__`` (by calling
    ``on_fit_start`` with ``None`` arguments), and again when fitting starts.
    ``config`` is written to ``<cfgdir>/config.yaml`` each time.
    """

    def __init__(self, resume, logdir, ckptdir, cfgdir, config):
        super().__init__()
        self.resume, self.logdir = resume, logdir
        self.ckptdir, self.cfgdir = ckptdir, cfgdir
        self.config = config
        print("MAKEDIRS INIT")

        # on_fit_start does not touch trainer/pl_module, so None is safe here.
        self.on_fit_start(None, None)

    def on_fit_start(self, trainer, pl_module):
        print("MAKEDIRS1")
        # NOTE(review): there is no global_rank guard, so every process that
        # runs this creates the directories and rewrites config.yaml.
        for directory in (self.logdir, self.ckptdir, self.cfgdir):
            os.makedirs(directory, exist_ok=True, mode=0o777)
        OmegaConf.save(self.config, os.path.join(self.cfgdir, "config.yaml"))


def setup(opt, unknown):
    """Resolve run names and output directories from the parsed CLI options.

    Side effects: appends the current working directory to ``sys.path``,
    rewrites relative ``opt.base`` entries to live under ``configs/``, sets
    ``opt.resume_from_checkpoint`` when resuming, and seeds RNGs via
    ``seed_everything(opt.seed)``.

    New runs get ``logdir = <opt.log or "logs">/<cfg_name>/<precision>--[<prefix>--]<timestamp>``,
    where ``cfg_name`` is the first base config's stem suffixed with the total
    effective batch size. Resumed runs reuse the directory of ``opt.resume``.

    Args:
        opt: namespace produced by ``get_parser()``.
        unknown: unparsed CLI arguments (not used here).

    Returns:
        ``(cfg_name, run_name, logdir, ckptdir, cfgdir)``.

    Raises:
        ValueError: if ``opt.resume`` points to a path that does not exist.
    """
    now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")

    # add cwd for convenience and to make classes in this file available when
    # running as `python run.py`
    sys.path.append(os.getcwd())

    assert opt.gpus is None or opt.n_gpus is None, "Either gpus or n_gpus"

    if opt.gpus is not None:
        n_devices = len(opt.gpus.split(","))
    elif opt.n_gpus is not None:
        n_devices = int(opt.n_gpus)
    else:
        # No explicit device flags: no `devices` is passed to the Trainer, so
        # count the visible GPUs (at least one) for the batch-size tag.
        n_devices = max(torch.cuda.device_count(), 1)

    n_nodes = opt.n_nodes or 1

    if opt.base:
        # Relative config paths are resolved against the `configs/` directory.
        opt.base = [cfg if os.path.isabs(cfg) else os.path.join("configs", cfg) for cfg in opt.base]

    if opt.resume:
        if not os.path.exists(opt.resume):
            raise ValueError("Cannot find {}".format(opt.resume))
        if os.path.isfile(opt.resume):
            # <logdir>/checkpoints/<file> -> <logdir>
            logdir = "/".join(opt.resume.split("/")[:-2])
            ckpt = opt.resume
        else:
            assert os.path.isdir(opt.resume), opt.resume
            logdir = opt.resume.rstrip("/")
            ckpt = os.path.join(logdir, "checkpoints", "last.ckpt")
        opt.resume_from_checkpoint = ckpt
        run_name = logdir.split("/")[-1]
        cfg_name = logdir.split("/")[-2]
    else:
        if opt.base:
            cfg_stem = os.path.splitext(os.path.basename(opt.base[0]))[0]
            total_bs = n_nodes * n_devices * int(opt.bs) * int(opt.acc_bs)
            cfg_name = f"{cfg_stem}--{total_bs}bs"
        else:
            # Previously left unbound, which crashed at the return statement.
            cfg_name = ""
        name = cfg_name

        run_name = now
        if opt.prefix:
            run_name = opt.prefix + "--" + run_name
        run_name = str(opt.precision) + "--" + run_name

        # `opt.log` defaults to None; fall back to ./logs unless a path was given.
        log_root = opt.log if opt.log else "logs"
        logdir = os.path.join(log_root, name, run_name)

    ckptdir = os.path.join(logdir, "checkpoints")
    cfgdir = os.path.join(logdir, "configs")

    seed_everything(opt.seed)

    return cfg_name, run_name, logdir, ckptdir, cfgdir


def create_trainer_from_config(opt, config, cfg_name, run_name, datamodule):
    """Build a ``pytorch_lightning.Trainer`` from CLI options and ``config.lightning``.

    NOTE: ``logdir``, ``ckptdir`` and ``cfgdir`` are read from module-level
    globals (assigned in the ``__main__`` block), not from parameters, so this
    function only works after those globals exist.

    Args:
        opt: namespace from ``get_parser()``.
        config: full OmegaConf config. Its ``lightning`` section is deep-copied;
            ``logger``, ``modelcheckpoint`` and ``callbacks`` entries are merged
            over the defaults below, ``limit_val_samples`` is converted to
            ``limit_val_batches``, and all remaining keys go to ``Trainer``.
        cfg_name: config name used in the W&B run name/id.
        run_name: run name used in the W&B run name/id.
        datamodule: currently unused.

    Returns:
        The configured ``Trainer``.
    """
    import copy

    lightning_config = copy.deepcopy(config.get("lightning", OmegaConf.create()))

    trainer_kwargs = dict()

    # Logger: W&B by default; a `logger` entry in the lightning config is merged
    # over this (and may replace the target entirely).
    wandb_run_name = opt.wandb_prefix + cfg_name + "--" + run_name
    default_logger_cfg = {
        "target": "pytorch_lightning.loggers.WandbLogger",
        "params": {
            "name": wandb_run_name,
            "save_dir": logdir,
            "offline": opt.debug or opt.wandb_offline,
            "id": wandb_run_name,
            "project": opt.wandb_project_name,
        },
    }
    logger_cfg = OmegaConf.merge(default_logger_cfg, lightning_config.pop("logger", OmegaConf.create()))
    trainer_kwargs["logger"] = instantiate_from_config(logger_cfg)

    # Callbacks
    trainer_kwargs["callbacks"] = []

    # Checkpointing
    default_modelckpt_cfg = {
        "target": "pytorch_lightning.callbacks.ModelCheckpoint",
        "params": {
            "dirpath": ckptdir,
            "verbose": True,
            "save_last": True,
            "save_top_k": 3,
            "every_n_epochs": 0,
            "monitor": "val/loss",
        },
    }
    modelckpt_cfg = OmegaConf.merge(default_modelckpt_cfg, lightning_config.pop("modelcheckpoint", OmegaConf.create()))
    trainer_kwargs["callbacks"].append(instantiate_from_config(modelckpt_cfg))

    # Directory initialization and image logging.
    # NOTE(review): the target assumes this file is importable as `ltx_train`.
    default_callbacks_cfg = {
        "setup_callback": {
            "target": "ltx_train.SetupCallback",
            "params": {"resume": opt.resume, "logdir": logdir, "ckptdir": ckptdir, "cfgdir": cfgdir, "config": config},
        },
    }
    callbacks_cfg = OmegaConf.merge(default_callbacks_cfg, lightning_config.pop("callbacks", OmegaConf.create()))
    try:
        # Best effort: cap the image logger at one batch when both values exist.
        callbacks_cfg.image_logger.params.max_images = min(
            callbacks_cfg.image_logger.params.max_images, config["data"]["params"]["batch_size"]
        )
    except Exception:
        # Missing image_logger / batch_size entries are expected; keep going.
        pass
    trainer_kwargs["callbacks"] += [instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg]

    # Log learning rate scheduler
    trainer_kwargs["callbacks"].append(LearningRateMonitor(logging_interval="step"))

    # Number of nodes
    trainer_kwargs["num_nodes"] = int(opt.n_nodes)

    # Specify devices
    if opt.gpus is not None:
        trainer_kwargs["devices"] = [int(device) for device in opt.gpus.split(",")]
        n_devices = len(trainer_kwargs["devices"])
    elif opt.n_gpus is not None:
        trainer_kwargs["devices"] = int(opt.n_gpus)
        n_devices = trainer_kwargs["devices"]
    else:
        # No device flags: Lightning's default device selection applies, so base
        # the strategy choice on the visible GPU count (was an UnboundLocalError).
        n_devices = torch.cuda.device_count()

    trainer_kwargs["strategy"] = "ddp" if n_devices > 1 else "auto"

    # Convert a sample budget into a batch budget for validation.
    if "limit_val_samples" in lightning_config:
        trainer_kwargs["limit_val_batches"] = int(lightning_config["limit_val_samples"]) // int(opt.bs)
        del lightning_config["limit_val_samples"]

    # Any remaining lightning keys (e.g. accumulate_grad_batches) go straight to Trainer.
    trainer_kwargs.update(OmegaConf.to_container(lightning_config))

    plugins = []
    if opt.slurm:
        plugins.append(SLURMEnvironment(auto_requeue=False, requeue_signal=signal.SIGUSR2))

    trainer = Trainer(
        accelerator="gpu",
        precision=opt.precision,
        plugins=plugins,
        use_distributed_sampler=False,  # we handle that ourselves
        num_sanity_val_steps=0,
        **trainer_kwargs,
    )

    return trainer


def run(opt, trainer: pl.Trainer, datamodule, model):
    """Install signal handlers and run ``trainer.fit`` or ``trainer.validate``.

    SIGUSR1 triggers the ModelCheckpoint callbacks' ``on_validation_end`` on
    rank 0, and SIGUSR2 drops rank 0 into ``pudb``. If ``opt.resume_from_checkpoint``
    is set, it is passed as ``ckpt_path``. In debug mode, a fresh (non-resumed)
    run directory is moved into a sibling ``debug_runs/`` folder afterwards.

    NOTE: ``logdir`` is read from the module-level global set in ``__main__``.

    Raises:
        ValueError: if ``opt.mode`` is neither ``"train"`` nor ``"validate"``.
    """

    # allow checkpointing via USR1
    def melk(*args, **kwargs):
        # run all checkpoint hooks
        if trainer.global_rank == 0:
            print("Summoning checkpoint callbacks.")
            for callback in trainer.callbacks:
                if isinstance(callback, ModelCheckpoint):
                    callback.on_validation_end(trainer, None)

    # allow debugging via USR2
    def divein(*args, **kwargs):
        if trainer.global_rank == 0:
            import pudb

            pudb.set_trace()

    signal.signal(signal.SIGUSR1, melk)
    signal.signal(signal.SIGUSR2, divein)

    try:
        kwargs = {}
        if "resume_from_checkpoint" in opt and opt.resume_from_checkpoint is not None:
            kwargs["ckpt_path"] = opt.resume_from_checkpoint

        if opt.mode == "validate":
            trainer.validate(model, datamodule, **kwargs)
        elif opt.mode == "train":
            trainer.fit(model, datamodule, **kwargs)
        else:
            # Previously an unknown mode (e.g. `-t` without a value -> True) was a silent no-op.
            raise ValueError(f"Unknown mode {opt.mode!r}; expected 'train' or 'validate'.")
    finally:
        # move newly created debug project to debug_runs
        if opt.debug and not opt.resume and trainer.global_rank == 0:
            dst, name = os.path.split(logdir)
            dst = os.path.join(dst, "debug_runs", name)
            os.makedirs(os.path.split(dst)[0], exist_ok=True)
            os.rename(logdir, dst)


if __name__ == "__main__":
    # torch.compile / dynamo settings and float32 matmul precision.
    torch._dynamo.config.cache_size_limit = 64
    torch._dynamo.config.suppress_errors = False
    torch.set_float32_matmul_precision("high")

    parser = get_parser()
    opt, unknown = parser.parse_known_args()
    # Unrecognized CLI args are parsed as an OmegaConf dotlist of overrides.
    cli = OmegaConf.from_dotlist(unknown)

    # NOTE: logdir/ckptdir/cfgdir stay module-level globals on purpose:
    # create_trainer_from_config() and run() read them.
    cfg_name, run_name, logdir, ckptdir, cfgdir = setup(opt, unknown)

    resume_from_checkpoint = "resume_from_checkpoint" in opt and os.path.isfile(opt.resume_from_checkpoint)
    if not resume_from_checkpoint:
        opt.resume_from_checkpoint = None

    # Load configs (either from resumed training or from opt.base)
    if resume_from_checkpoint and not opt.base:
        configs = [OmegaConf.load(os.path.join(logdir, "configs/config.yaml"))]
    else:
        configs = [OmegaConf.load(cfg) for cfg in opt.base]
    config = OmegaConf.merge(*configs, cli)

    config.data.params.train.batch_size = int(opt.bs)
    config.lightning.accumulate_grad_batches = int(opt.acc_bs)
    if opt.lr is not None:
        config.model.learning_rate = float(opt.lr)
    if opt.wd is not None:
        config.model.weight_decay = float(opt.wd)
    if opt.ckpt_every is not None:
        if "lightning" in config and "modelcheckpoint" in config["lightning"]:
            config.lightning.modelcheckpoint.params["every_n_train_steps"] = int(opt.ckpt_every)
    datamodule = instantiate_from_config(config.data)

    # Build model
    if "target" not in config.model:
        raise KeyError("Expected key `target` to instantiate.")

    config.model["params"]["compile"] = opt.compile
    model: pl.LightningModule = instantiate_from_config(config.model)
    if opt.ckpt_path is not None:
        # `LightningModule.load_from_checkpoint` is a classmethod that returns a
        # *new* model (and expects a path), so the old call never changed `model`.
        # Load the weights in place instead. Load to CPU so every rank does not
        # allocate on the checkpoint's original device.
        # NOTE: weights_only=False unpickles arbitrary objects; only use trusted checkpoints.
        ckpt = torch.load(opt.ckpt_path, map_location="cpu", weights_only=False)
        state_dict = ckpt.get("state_dict", ckpt)
        missing, unexpected = model.load_state_dict(state_dict, strict=False)
        print(
            f"Loaded checkpoint from {opt.ckpt_path} "
            f"({len(missing)} missing keys, {len(unexpected)} unexpected keys)"
        )

    trainer = create_trainer_from_config(opt, config, cfg_name, run_name, datamodule)

    # Configure learning rate and weight decay
    bs, ngpu = int(config.data.params.train.batch_size), len(trainer.device_ids)
    if "base_learning_rate" in config.model:
        base_lr = config.model.base_learning_rate
        model.learning_rate = base_lr * trainer.accumulate_grad_batches * ngpu * bs
        lr_message = (
            "Setting learning rate to {:.2e} = "
            "{} (accumulate_grad_batches) * "
            "{} (num_gpus) * "
            "{} (batchsize) * "
            "{:.2e} (base_lr)".format(model.learning_rate, trainer.accumulate_grad_batches, ngpu, bs, base_lr)
        )
    elif "learning_rate" in config.model:
        # Absolute learning rate: no scaling by batch size / GPUs.
        model.learning_rate = config.model.learning_rate
        lr_message = "Setting learning rate to {:.2e}".format(model.learning_rate)
    else:
        # Raising a plain string is a TypeError in Python 3; raise a real exception.
        raise KeyError("Expected `base_learning_rate` or `learning_rate` in config.model.")
    model.weight_decay = config.model.weight_decay if "weight_decay" in config.model else 0.0
    model.max_training_steps = config.lightning.max_steps if "max_steps" in config.lightning else -1
    model.lr_schedule_params = config.model.lr_schedule_params if "lr_schedule_params" in config.model else None

    print(lr_message)

    # Run
    run(opt, trainer, datamodule, model)
