import warnings
from collections.abc import Callable
from functools import wraps
from importlib.util import find_spec
from pathlib import Path
from typing import Any

from omegaconf import DictConfig, OmegaConf

from meds_torch.utils import pylogger, rich_utils

log = pylogger.RankedLogger(__name__, rank_zero_only=True)


def configure_logging(cfg: DictConfig) -> None:
    """Runs the optional pre-task utilities selected under ``cfg.extras``.

    Each utility is switched on by a truthy flag in ``cfg.extras`` and they run in this order:
        - ``ignore_warnings``: installs an "ignore" filter for python warnings
        - ``enforce_tags``: calls ``rich_utils.enforce_tags`` on the config
        - ``print_config``: calls ``rich_utils.print_config_tree`` on the resolved config

    If ``cfg.extras`` is missing or empty, a warning is logged and nothing else happens.

    Args:
        cfg (DictConfig): A DictConfig object containing the config tree.
    """
    if not cfg.get("extras"):
        log.warning("Extras config not found! <cfg.extras=null>")
        return

    # (flag name in `cfg.extras`, message logged when the flag is set, action to perform)
    steps = (
        (
            "ignore_warnings",
            "Disabling python warnings! <cfg.extras.ignore_warnings=True>",
            lambda: warnings.filterwarnings("ignore"),
        ),
        (
            "enforce_tags",
            "Enforcing tags! <cfg.extras.enforce_tags=True>",
            lambda: rich_utils.enforce_tags(cfg, save_to_file=True),
        ),
        (
            "print_config",
            "Printing config tree with Rich! <cfg.extras.print_config=True>",
            lambda: rich_utils.print_config_tree(cfg, resolve=True, save_to_file=True),
        ),
    )

    # the flag is looked up right before its own step, after any earlier step has run
    for flag, message, action in steps:
        if cfg.extras.get(flag):
            log.info(message)
            action()


def task_wrapper(task_func: Callable) -> Callable:
    """Optional decorator that controls the failure behavior when executing the task function.

    Before calling the task function, the wrapper:
        - replaces ``cfg.paths.time_output_dir`` with a ``Path`` and creates that directory
        - saves the config to ``hydra_config.yaml`` inside that directory

    Around the call, the wrapper:
        - logs any exception together with its traceback and then re-raises it
        - always logs the output dir path
        - always closes the active wandb run when wandb is installed, even if the task
          function raised (so a multirun won't fail)

    Example:
        ```
        @utils.task_wrapper
        def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
            ...
            return outputs
        ```

    :param task_func: The task function to be wrapped. It is called as
        ``task_func(cfg=cfg, **kwargs)``.
    :return: The wrapped task function.
    """

    @wraps(task_func)
    def wrap(cfg: DictConfig, **kwargs) -> tuple[dict[str, Any], dict[str, Any]]:
        # normalize the output dir to a `Path` on the config itself, then make sure it exists
        output_dir = Path(cfg.paths.time_output_dir)
        cfg.paths.time_output_dir = output_dir
        output_dir.mkdir(parents=True, exist_ok=True)
        OmegaConf.save(config=cfg, f=output_dir / "hydra_config.yaml")

        # execute the task
        try:
            outputs = task_func(cfg=cfg, **kwargs)

        # things to do if exception occurs
        except Exception as ex:  # pragma: no cover
            # log the exception with its traceback so the log shows where the task failed,
            # not only the exception message
            log.exception(ex)

            # some hyperparameter combinations might be invalid or cause out-of-memory errors
            # so when using hparam search plugins like Optuna, you might want to disable
            # raising the below exception to avoid multirun failure
            raise

        # things to always do after either success or exception
        finally:
            # display output dir path in terminal
            log.info(f"Output dir: {cfg.paths.time_output_dir}")

            # always close wandb run (even if exception occurs so multirun won't fail)
            if find_spec("wandb"):  # check if wandb is installed
                import wandb

                if wandb.run:
                    log.info("Closing wandb!")
                    wandb.finish()

        # only reached on success: the except branch above always re-raises
        return outputs

    return wrap


def get_metric_value(metric_dict: dict[str, Any], metric_name: str | None) -> float | None:
    """Safely retrieves value of the metric logged in LightningModule.

    :param metric_dict: A dict containing metric values.
    :param metric_name: If provided, the name of the metric to retrieve.
    :return: If a metric name was provided, the value of the metric.
    """
    if not metric_name:
        log.info("Metric name is None! Skipping metric value retrieval...")
        return None

    if metric_name not in metric_dict:
        raise Exception(
            f"Metric value not found! <metric_name={metric_name}>\n"
            "Make sure metric name logged in LightningModule is correct!\n"
            "Make sure `optimized_metric` name in `hparams_search` config is correct!"
        )

    metric_value = metric_dict[metric_name].item()
    log.info(f"Retrieved metric value! <{metric_name}={metric_value}>")

    return metric_value
