from jaxtyping import install_import_hook
import argparse  
import contextlib 
import importlib 
import logging 
import os 
import sys  
import time  
import traceback  
class ColoredFilter(logging.Filter):
    """
    A logging filter that colors records of the standard levels with ANSI escapes.

    The record is modified in place and always allowed through:
    ``levelname`` becomes ``"<color>[LEVEL]"`` and ``RESET`` is appended to
    ``msg``, so the color spans from the level name to the end of the message.
    Running the filter twice on the same record is a no-op the second time,
    because the rewritten ``levelname`` no longer matches a key of ``COLORS``.

    Because the record object is shared, any handler that formats the same
    record after this filter ran also sees the colored fields.
    """

    RESET = "\033[0m"  # ANSI "reset all attributes"
    RED = "\033[31m"
    GREEN = "\033[32m"
    YELLOW = "\033[33m"
    BLUE = "\033[34m"
    MAGENTA = "\033[35m"
    CYAN = "\033[36m"

    # Level name -> color escape used as its prefix.
    COLORS = {
        "WARNING": YELLOW,
        "INFO": GREEN,
        "DEBUG": BLUE,
        "CRITICAL": MAGENTA,
        "ERROR": RED,
    }

    def filter(self, record):
        """Color ``record`` if its level is in ``COLORS``; always return True."""
        color_start = self.COLORS.get(record.levelname)
        if color_start is not None:
            record.levelname = f"{color_start}[{record.levelname}]"
            # Reset after the message so the color does not bleed into later output.
            record.msg = f"{record.msg}{self.RESET}"
        return True


def load_custom_module(module_path):
    """Import a single custom module from a ``.py`` file or a package directory.

    The module is registered in ``sys.modules`` before it is executed, under a
    name derived from ``module_path``:
    - for a file, the path with its extension removed (the directory part is kept);
    - for a directory, its base name.

    Import errors are printed (traceback plus a one-line message) rather than
    raised. A module that fails to import is removed from ``sys.modules`` again.

    Args:
        module_path: path to a ``.py`` file, or to a directory that contains an
            ``__init__.py``.

    Returns:
        True if the module was imported successfully, False otherwise.
    """
    # `import importlib` alone does not guarantee the `importlib.util`
    # submodule has been loaded, so import it explicitly.
    import importlib.util

    if os.path.isfile(module_path):
        # Remove the file extension; the directory part stays in the name.
        module_name = os.path.splitext(module_path)[0]
        location = module_path
    else:
        # A package directory: execute its __init__.py.
        module_name = os.path.basename(module_path)
        location = os.path.join(module_path, "__init__.py")

    module = None
    try:
        module_spec = importlib.util.spec_from_file_location(module_name, location)
        if module_spec is None or module_spec.loader is None:
            # e.g. the file extension has no registered loader.
            raise ImportError(f"no loader found for {location}")
        module = importlib.util.module_from_spec(module_spec)
        sys.modules[module_name] = module
        module_spec.loader.exec_module(module)
        return True
    except Exception as e:
        # Drop the half-initialized module we inserted (and only that one), so
        # later imports do not pick up a broken object.
        if module is not None and sys.modules.get(module_name) is module:
            del sys.modules[module_name]
        print(traceback.format_exc())
        print(f"Cannot import {module_path} module for custom nodes:", e)
        return False


def load_custom_modules(node_paths=None):
    """Import every custom module found in the given directories and report timings.

    Which entries of each directory are loaded with ``load_custom_module``:
    - loaded: ``.py`` files and sub-directories (treated as packages);
    - skipped: other files, ``__pycache__``, and entries whose name ends in
      ``_disabled``.

    Entries are visited in sorted order so the import order is deterministic.
    Directories that do not exist are skipped instead of raising.

    After loading, a table of per-module import times is printed, fastest first,
    with failed imports marked.

    Args:
        node_paths: directories to scan. Defaults to ``["custom"]``, resolved
            relative to the current working directory.
    """
    if node_paths is None:
        node_paths = ["custom"]

    node_import_times = []  # (seconds, module_path, success) per attempted import
    for custom_node_path in node_paths:
        if not os.path.isdir(custom_node_path):
            continue  # optional location; nothing to load

        for possible_module in sorted(os.listdir(custom_node_path)):
            if possible_module == "__pycache__":
                continue
            module_path = os.path.join(custom_node_path, possible_module)
            if (
                os.path.isfile(module_path)
                and os.path.splitext(module_path)[1] != ".py"
            ):  # Skipping non-Python files
                continue
            if module_path.endswith("_disabled"):  # Skipping disabled modules
                continue
            time_before = time.perf_counter()
            success = load_custom_module(module_path)
            node_import_times.append(
                (time.perf_counter() - time_before, module_path, success)
            )

    if node_import_times:
        print("\nImport times for custom modules:")
        for elapsed, module_path, success in sorted(node_import_times):
            import_message = "" if success else " (IMPORT FAILED)"
            print("{:6.1f} seconds{}:".format(elapsed, import_message), module_path)
        print()


def _configure_visible_gpus(gpu) -> int:
    """Pin CUDA_VISIBLE_DEVICES according to ``gpu`` and return the GPU count.

    If CUDA_VISIBLE_DEVICES is already non-empty (e.g. set by SLURM ``srun`` or
    a higher-level script), it is respected and ``gpu`` is ignored.

    Args:
        gpu: one of the following:
            - a comma-separated string of GPU ids (``"0"``, ``"1,2"``);
            - a single int id;
            - a ``torch.device``. A CPU device hides every GPU, and a CUDA
              device without an index selects GPU 0.

    Returns:
        The number of GPUs made visible to the process.

    Raises:
        TypeError: if ``gpu`` is none of the supported types.
    """
    import torch

    # Number devices by PCI bus id instead of CUDA's default fastest-first order.
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    env_gpus_str = os.environ.get("CUDA_VISIBLE_DEVICES", None)
    if env_gpus_str:
        # CUDA_VISIBLE_DEVICES was set already, e.g. within SLURM srun or higher-level script.
        return len(env_gpus_str.split(","))

    if isinstance(gpu, str):
        os.environ["CUDA_VISIBLE_DEVICES"] = gpu
        return len(gpu.split(","))
    if isinstance(gpu, int):
        os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu)
        return 1
    if isinstance(gpu, torch.device):
        if gpu.type == "cpu":
            os.environ["CUDA_VISIBLE_DEVICES"] = ""
            return 0
        index = 0 if gpu.index is None else gpu.index
        os.environ["CUDA_VISIBLE_DEVICES"] = str(index)
        return 1
    raise TypeError(
        f"Unsupported type for `gpu`: {type(gpu).__name__}; "
        "expected str, int or torch.device"
    )


def main(args, extras) -> tuple:
    """Set up devices and logging, build the experiment, and run the selected mode.

    Args:
        args: parsed launcher arguments. Reads ``config``, ``gpu``, ``train``,
            ``validate``, ``test``, ``export``, ``gradio``, ``verbose`` and
            ``typecheck``.
        extras: unparsed command-line tokens. They are forwarded to
            ``load_config`` as ``cli_args``.

    Returns:
        ``(system.output_rgb, ckpt_path)``. ``ckpt_path`` is
        ``<trial_dir>/ckpts/last.ckpt``; the file's existence is not checked.

    Raises:
        TypeError: if CUDA_VISIBLE_DEVICES is not already set and ``args.gpu``
            is not a str, int or ``torch.device``.
    """
    import torch

    # set CUDA_VISIBLE_DEVICES if needed, then import pytorch-lightning
    n_gpus = _configure_visible_gpus(args.gpu)

    # Always rely on CUDA_VISIBLE_DEVICES if specific GPU ID(s) are specified.
    # As far as Pytorch Lightning is concerned, we always use all available GPUs
    # (possibly filtered by CUDA_VISIBLE_DEVICES).
    if n_gpus > 0:
        accelerator, devices = "gpu", -1
    else:
        # Every GPU is hidden (a CPU device was requested).
        # - Lightning's GPU accelerator cannot start without a visible GPU.
        # - Its CPU accelerator needs a positive device count.
        accelerator, devices = "cpu", 1

    import pytorch_lightning as pl
    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
    from pytorch_lightning.loggers import CSVLogger, TensorBoardLogger
    from pytorch_lightning.utilities.rank_zero import rank_zero_only

    if args.typecheck:
        from jaxtyping import install_import_hook

        # Installed before threestudio is imported so its modules are instrumented.
        install_import_hook("threestudio", "typeguard.typechecked")

    import threestudio
    from threestudio.systems.base import BaseSystem
    from threestudio.utils.callbacks import (
        CodeSnapshotCallback,
        ConfigSnapshotCallback,
        CustomProgressBar,
        ProgressCallback,
    )
    from threestudio.utils.config import ExperimentConfig, load_config
    from threestudio.utils.misc import get_rank
    from threestudio.utils.typing import Optional

    logger = logging.getLogger("pytorch_lightning")
    if args.verbose:
        logger.setLevel(logging.DEBUG)

    for handler in logger.handlers:
        # Only restyle handlers that write to stderr. Some handler types have no
        # `stream` attribute at all (or a None stream), so skip those too.
        stream = getattr(handler, "stream", None)
        if stream is None or stream != sys.stderr:
            continue
        if args.gradio:
            handler.setFormatter(logging.Formatter("[%(levelname)s] %(message)s"))
        else:
            handler.setFormatter(logging.Formatter("%(levelname)s %(message)s"))
            # Don't stack duplicate filters if main() runs repeatedly in one process.
            if not any(isinstance(flt, ColoredFilter) for flt in handler.filters):
                handler.addFilter(ColoredFilter())

    # load_custom_modules()

    # parse YAML config to OmegaConf
    cfg: ExperimentConfig
    cfg = load_config(args.config, cli_args=extras, n_gpus=n_gpus)

    # set a different seed for each device
    pl.seed_everything(cfg.seed + get_rank(), workers=True)

    dm = threestudio.find(cfg.data_type)(cfg.data)
    system: BaseSystem = threestudio.find(cfg.system_type)(
        cfg.system, resumed=cfg.resume is not None
    )
    system.set_save_dir(os.path.join(cfg.trial_dir, "save"))

    fh = None
    if args.gradio:
        fh = logging.FileHandler(os.path.join(cfg.trial_dir, "logs"))
        fh.setLevel(logging.DEBUG if args.verbose else logging.INFO)
        fh.setFormatter(logging.Formatter("[%(levelname)s] %(message)s"))
        logger.addHandler(fh)

    def write_to_text(file, lines):
        with open(file, "w", encoding="utf-8") as f:
            for line in lines:
                f.write(line + "\n")

    def set_system_status(system: BaseSystem, ckpt_path: Optional[str]):
        # Restore epoch/global_step from a checkpoint; no-op without one.
        if ckpt_path is None:
            return
        ckpt = torch.load(ckpt_path, map_location="cpu")
        system.set_resume_status(ckpt["epoch"], ckpt["global_step"])

    try:
        callbacks = []
        if args.train:
            callbacks += [
                ModelCheckpoint(
                    dirpath=os.path.join(cfg.trial_dir, "ckpts"), **cfg.checkpoint
                ),
                # LearningRateMonitor(logging_interval="step"),
                # CodeSnapshotCallback(
                #     os.path.join(cfg.trial_dir, "code"), use_version=False
                # ),
                ConfigSnapshotCallback(
                    args.config,
                    cfg,
                    os.path.join(cfg.trial_dir, "configs"),
                    use_version=False,
                ),
            ]
            if args.gradio:
                callbacks += [
                    ProgressCallback(save_path=os.path.join(cfg.trial_dir, "progress"))
                ]
            else:
                callbacks += [CustomProgressBar(refresh_rate=1)]

        loggers = []
        if args.train:
            pass
            # make tensorboard logging dir to suppress warning
            # rank_zero_only(
            #     lambda: os.makedirs(os.path.join(cfg.trial_dir, "tb_logs"), exist_ok=True)
            # )()
            # loggers += [
            #     TensorBoardLogger(cfg.trial_dir, name="tb_logs"),
            #     CSVLogger(cfg.trial_dir, name="csv_logs"),
            # ] + system.get_loggers()
            # rank_zero_only(
            #     lambda: write_to_text(
            #         os.path.join(cfg.trial_dir, "cmd.txt"),
            #         ["python " + " ".join(sys.argv), str(args)],
            #     )
            # )()

        trainer = Trainer(
            callbacks=callbacks,
            logger=loggers,
            inference_mode=False,
            accelerator=accelerator,
            devices=devices,
            **cfg.trainer,
        )

        if args.train:
            trainer.fit(system, datamodule=dm, ckpt_path=cfg.resume)
            trainer.test(system, datamodule=dm)
            if args.gradio:
                # also export assets if in gradio mode
                trainer.predict(system, datamodule=dm)
        elif args.validate:
            # manually set epoch and global_step as they cannot be automatically resumed
            set_system_status(system, cfg.resume)
            trainer.validate(system, datamodule=dm, ckpt_path=cfg.resume)
        elif args.test:
            # manually set epoch and global_step as they cannot be automatically resumed
            set_system_status(system, cfg.resume)
            trainer.test(system, datamodule=dm, ckpt_path=cfg.resume)
        elif args.export:
            set_system_status(system, cfg.resume)
            trainer.predict(system, datamodule=dm, ckpt_path=cfg.resume)

        return_images = system.output_rgb
        return_last_ckpt = os.path.join(cfg.trial_dir, "ckpts", "last.ckpt")
        return return_images, return_last_ckpt
    finally:
        if fh is not None:
            # Detach and close this run's log file, so later calls don't keep
            # writing into an earlier trial directory or leak the file handle.
            logger.removeHandler(fh)
            fh.close()


if __name__ == "__main__":
    cli = argparse.ArgumentParser()
    cli.add_argument("--config", required=True, help="path to config file")
    cli.add_argument(
        "--gpu",
        default="0",
        help=(
            "GPU(s) to be used. 0 means use the 1st available GPU. "
            "1,2 means use the 2nd and 3rd available GPU. "
            "If CUDA_VISIBLE_DEVICES is set before calling `launch.py`, "
            "this argument is ignored and all available GPUs are always used."
        ),
    )

    # Exactly one run mode is required.
    run_mode = cli.add_mutually_exclusive_group(required=True)
    for mode_flag in ("--train", "--validate", "--test", "--export"):
        run_mode.add_argument(mode_flag, action="store_true")

    cli.add_argument(
        "--gradio", action="store_true", help="if true, run in gradio mode"
    )
    cli.add_argument(
        "--verbose", action="store_true", help="if true, set logging level to DEBUG"
    )
    cli.add_argument(
        "--typecheck",
        action="store_true",
        help="whether to enable dynamic type checking",
    )

    # Unknown tokens are handed to main(), which passes them on to load_config.
    args, extras = cli.parse_known_args()

    # In gradio mode, Python-level writes to stdout are sent to stderr instead.
    # FIXME (original note): no effect, stdout is not captured
    stdout_ctx = (
        contextlib.redirect_stdout(sys.stderr)
        if args.gradio
        else contextlib.nullcontext()
    )
    with stdout_ctx:
        main(args, extras)
