import importlib
from argparse import ArgumentParser
from omegaconf import OmegaConf
from os.path import join as pjoin
from pathlib import Path
import os


def parse_args(phase="train"):
    """
    Parse command-line arguments and load config files.

    The configuration is built in layers:

    1. ``<CONFIG_FOLDER>/default.yaml`` (the folder is read from the assets
       file given by ``--cfg_assets``) merged with the file given by ``--cfg``;
    2. unless that result sets ``FULL_CONFIG``, every YAML file under the
       config folder is grafted in by ``get_module_config``;
    3. the assets config is merged on top;
    4. phase-specific command-line values override the result;
    5. debug adjustments and ``resume_config`` are applied last.

    Args:
        phase (str): The phase for which to parse arguments. Can be 'train',
            'test', 'demo', 'render', or 'webui'. A phase not listed here gets
            ``./configs/default.yaml`` as its default ``--cfg`` and no
            phase-specific options.

    Returns:
        OmegaConf: The loaded and updated configuration object.
    """
    parser = _build_arg_parser(phase)
    params = parser.parse_args()

    cfg = _load_config(params)
    _apply_cli_overrides(cfg, params, phase)

    if cfg.DEBUG:
        # Debug runs: tag the run name, keep wandb offline, validate every
        # step and pin the run to device 0.
        cfg.NAME = "debug--" + cfg.NAME
        cfg.LOGGER.WANDB.params.offline = True
        cfg.LOGGER.VAL_EVERY_STEPS = 1
        cfg.DEVICE = [0]

    # Resume config
    cfg = resume_config(cfg)

    return cfg


def _build_arg_parser(phase):
    """Create the ArgumentParser with the common and phase-specific options."""
    parser = ArgumentParser()
    group = parser.add_argument_group("Training options")

    # Assets
    group.add_argument(
        "--cfg_assets",
        type=str,
        default="./configs/assets.yaml",
        help="Config file for asset paths",
    )

    # Default config
    cfg_default = {
        "train": "./configs/default.yaml",
        "test": "./configs/default.yaml",
        "render": "./configs/render.yaml",
        "webui": "./configs/webui.yaml",
    }.get(phase, "./configs/default.yaml")

    group.add_argument(
        "--cfg",
        type=str,
        default=cfg_default,
        help="Config file",
    )

    # Phase-specific arguments
    if phase in ["train", "test"]:
        group.add_argument("--batch_size", type=int, help="Training batch size")
        group.add_argument("--num_nodes", type=int, help="Number of nodes")
        group.add_argument("--device", type=int, nargs="+", help="Training device")
        group.add_argument("--task", type=str, help="Evaluation task type")
        group.add_argument("--nodebug", action="store_true", help="Disable debug mode")

    if phase == "demo":
        group.add_argument(
            "--example", type=str, help="Input text and lengths with txt format"
        )
        group.add_argument("--out_dir", type=str, help="Output directory")
        group.add_argument("--task", type=str, help="Evaluation task type")

    if phase == "render":
        group.add_argument("--npy", type=str, default=None, help="Npy motion files")
        group.add_argument("--dir", type=str, default=None, help="Npy motion folder")
        group.add_argument("--fps", type=int, default=30, help="Render fps")
        group.add_argument(
            "--mode",
            type=str,
            default="sequence",
            help="Render target: video, sequence, frame",
        )

    return parser


def _load_config(params):
    """Load and merge the assets, base, experiment and module YAML configs."""
    # The "eval" resolver lets YAML values use ${eval:...} expressions.
    # register_new_resolver raises ValueError when the name is already
    # registered, so guard it; otherwise a second parse_args() call in the
    # same process would crash.
    # NOTE(security): this runs Python eval() on config text -- only load
    # trusted config files.
    if not OmegaConf.has_resolver("eval"):
        OmegaConf.register_new_resolver("eval", eval)

    cfg_assets = OmegaConf.load(params.cfg_assets)
    cfg_base = OmegaConf.load(pjoin(cfg_assets.CONFIG_FOLDER, "default.yaml"))
    cfg_exp = OmegaConf.merge(cfg_base, OmegaConf.load(params.cfg))

    if not cfg_exp.FULL_CONFIG:
        cfg_exp = get_module_config(cfg_exp, cfg_assets.CONFIG_FOLDER)

    return OmegaConf.merge(cfg_exp, cfg_assets)


def _apply_cli_overrides(cfg, params, phase):
    """Write the phase-specific command-line values into ``cfg`` in place."""
    if phase in ["train", "test"]:
        # An omitted option parses as None (falsy), so the config value is kept.
        cfg.TRAIN.BATCH_SIZE = params.batch_size or cfg.TRAIN.BATCH_SIZE
        cfg.DEVICE = params.device or cfg.DEVICE
        cfg.NUM_NODES = params.num_nodes or cfg.NUM_NODES
        cfg.model.params.task = params.task or cfg.model.params.task
        # --nodebug is a store_true flag, so it is always True/False and never
        # None: debug mode is ON unless --nodebug is passed, and the YAML DEBUG
        # value is always overridden for these phases.
        cfg.DEBUG = not params.nodebug

        if phase == "test":
            cfg.DEBUG = False
            cfg.DEVICE = [0]

    elif phase == "demo":
        cfg.DEMO.EXAMPLE = params.example
        cfg.DEMO.TASK = params.task
        cfg.TEST.FOLDER = params.out_dir if params.out_dir else cfg.TEST.FOLDER
        os.makedirs(cfg.TEST.FOLDER, exist_ok=True)

    elif phase == "render":
        # When both --npy and --dir are given, --dir wins as the input mode.
        if params.npy:
            cfg.RENDER.NPY = params.npy
            cfg.RENDER.INPUT_MODE = "npy"
        if params.dir:
            cfg.RENDER.DIR = params.dir
            cfg.RENDER.INPUT_MODE = "dir"
        cfg.RENDER.FPS = float(params.fps)
        cfg.RENDER.MODE = params.mode


def get_module_config(cfg: OmegaConf, filepath: str = "./configs") -> OmegaConf:
    """
    Load YAML config files from subfolders and update the given configuration object.

    Every ``*.yaml`` file found recursively under ``filepath`` is stored in
    ``cfg`` under a dotted key derived from its relative path. The ``.yaml``
    suffix is dropped and path components are joined with ``.``, so
    ``model/vq/default.yaml`` becomes ``model.vq.default``. Files are
    processed in sorted path order, so the result does not depend on the
    filesystem's directory-listing order.

    Args:
        cfg (omegaconf.DictConfig): The configuration object to be updated
            (modified in place by ``OmegaConf.update``).
        filepath (str): The root directory path to search for YAML config files.

    Returns:
        omegaconf.DictConfig: The updated configuration object (same object as ``cfg``).
    """
    config_root = Path(filepath)

    for yaml_file in sorted(config_root.glob("**/*.yaml")):
        relative_path = yaml_file.relative_to(config_root)
        # Strip only the trailing suffix. A plain str.replace(".yaml", "") would
        # also mangle directory names that contain ".yaml". Joining the path
        # parts is separator-agnostic across platforms.
        node_key = ".".join(relative_path.with_suffix("").parts)
        OmegaConf.update(cfg, node_key, OmegaConf.load(yaml_file))

    return cfg


def get_obj_from_str(string, reload: bool = False) -> object:
    """
    Get an object (e.g., class, function) from a module using a string.

    Args:
        string (str): The full name of the object in the format 'module.submodule.ClassName'.
        reload (bool): Whether to reload the module to ensure the latest version is used.

    Returns:
        object: The object specified by the string.

    Raises:
        ValueError: If ``string`` has no '.' separating module and object name.
        ImportError: If the module cannot be imported (original error chained).
        AttributeError: If the object cannot be found in the module (original
            error chained).
    """
    try:
        module_name, obj_name = string.rsplit(".", 1)
    except ValueError:
        raise ValueError(
            f"Expected an import path of the form 'module.Object', got {string!r}"
        ) from None

    # Keep each try body to the single step it describes. Otherwise an
    # AttributeError raised while the module itself executes would be misreported
    # as a missing attribute.
    try:
        module = importlib.import_module(module_name)
        if reload:
            module = importlib.reload(module)
    except ImportError as e:
        raise ImportError(f"Error importing module {module_name}: {e}") from e

    try:
        return getattr(module, obj_name)
    except AttributeError as e:
        raise AttributeError(
            f"Error accessing attribute {obj_name} in module {module_name}: {e}"
        ) from e


def instantiate_from_config(config) -> object:
    """
    Instantiate an object from a configuration dictionary.

    Args:
        config (Mapping): A dict-like config (e.g. ``dict`` or OmegaConf
            ``DictConfig``) with keys 'target' and 'params'.
            The 'target' key specifies the object's import path.
            The optional 'params' key specifies the keyword arguments for
            object instantiation; a missing or null 'params' means no arguments.

    Returns:
        object: The instantiated object.

    Raises:
        KeyError: If the 'target' key is not in the config.
        Any exception raised by ``get_obj_from_str`` when 'target' cannot be
        resolved, or by the target callable itself.
    """
    if "target" not in config:
        raise KeyError("Expected key 'target' to instantiate.")

    target = config["target"]
    # A YAML entry written as bare ``params:`` loads as None. Treat it like an
    # absent key instead of failing with TypeError on ``**None``.
    params = config.get("params") or {}

    obj = get_obj_from_str(target)
    return obj(**params)


def resume_config(cfg: OmegaConf) -> OmegaConf:
    """
    Point the configuration at an earlier run so that training can resume.

    If ``cfg.TRAIN.RESUME`` is truthy, it is treated as a run directory. Then:

    * ``cfg.TRAIN.PRETRAINED`` is set to ``<run>/checkpoints/last.ckpt``;
    * ``cfg.LOGGER.WANDB.params.id`` is set to the id taken from the first
      ``run-*.wandb`` file in ``<run>/wandb/latest-run`` (the ``run-`` text
      removed from the file stem).

    A falsy ``RESUME`` leaves the configuration untouched.

    Args:
        cfg (OmegaConf): The configuration object; modified in place.

    Returns:
        OmegaConf: The same configuration object.

    Raises:
        ValueError: If the resume directory does not exist, or if no
            ``run-*.wandb`` file is found under ``wandb/latest-run``.
    """
    if not cfg.TRAIN.RESUME:
        return cfg

    run_dir = Path(cfg.TRAIN.RESUME)
    if not run_dir.exists():
        raise ValueError(f"Resume path {run_dir} does not exist.")

    # The checkpoint path is recorded before the wandb lookup, so it is
    # already set on cfg even if that lookup raises below.
    cfg.TRAIN.PRETRAINED = str(run_dir / "checkpoints" / "last.ckpt")

    latest_run_dir = run_dir / "wandb" / "latest-run"
    first_run_file = next(latest_run_dir.glob("run-*.wandb"), None)
    if first_run_file is None:
        raise ValueError("No wandb run file found in the latest-run directory.")

    cfg.LOGGER.WANDB.params.id = first_run_file.stem.replace("run-", "")
    return cfg
