"""Keeps track of experiments with MLflow -- an add-on on top of Hydra."""

import os
import warnings
from pathlib import Path
from typing import cast

import mlflow
import mlflow.tracking
import pandas as pd
from mlflow.exceptions import RestException
from omegaconf import DictConfig, OmegaConf
from torch import nn

from ml_utils.proxies import with_proxies
from ml_utils.utils import get_param_info

# ------------------------- Hydra ---------------------------------


def log_hydra(cfg: DictConfig, config_path: str = "config.yaml"):
    """Log a Hydra config to the active MLflow run in two forms.

    First the resolved config is stored as a YAML artifact at ``config_path``,
    then every leaf value is logged as a flattened (dot-separated) run param.
    """
    log_config_as_yaml(cfg, config_path)
    log_config_params(cfg)


def log_config_params(cfg: DictConfig):
    """Log every leaf of a Hydra config as an MLflow param.

    Interpolations are resolved first; nested keys are joined with ``.``
    (e.g. ``optim.lr``) so the whole config fits MLflow's flat param namespace.
    """
    resolved = OmegaConf.to_container(cfg, resolve=True)
    params = _flatten_dict(resolved)
    mlflow.log_params(params)


def log_config_as_yaml(cfg: DictConfig, config_path: str):
    """Log the resolved Hydra config directly as a YAML text artifact.

    Args:
        cfg: Hydra/OmegaConf config; interpolations are resolved before dumping.
        config_path: Run-relative artifact path the YAML text is written to.

    If the upload fails because the artifact already exists (the error message
    contains both "Resource Conflict" and "already exists"), a warning naming
    ``config_path`` is emitted and the call is a no-op. Any other error is re-raised.
    """
    # Dump outside the try: only the upload can produce the conflict we tolerate.
    yaml_str = OmegaConf.to_yaml(cfg, resolve=True)
    try:
        mlflow.log_text(yaml_str, artifact_file=config_path)
    except Exception as e:
        msg = str(e)
        if "Resource Conflict" in msg and "already exists" in msg:
            # Report the artifact actually being logged, not a hard-coded name.
            warnings.warn(f"mlflow: {config_path} already logged, skipping.", stacklevel=2)
        else:
            raise


def _flatten_dict(d, parent_key="", sep="."):
    """Recursively flattens a nested dict for MLflow logging"""
    items = []
    for k, v in d.items():
        new_key = f"{parent_key}{sep}{k}" if parent_key else k
        if isinstance(v, dict):
            items.extend(_flatten_dict(v, new_key, sep=sep).items())
        else:
            items.append((new_key, v))
    return dict(items)


def log_params(mdl: nn.Module, path: str) -> None:
    """Log the parameter info of ``mdl`` as a dict artifact at ``path``.

    The dict comes from ``get_param_info(mdl=mdl)``; its exact content is defined
    there (presumably per-parameter metadata -- confirm in ``ml_utils.utils``).

    Args:
        mdl: Model whose parameters are described.
        path: Run-relative artifact path passed to ``mlflow.log_dict``.

    If the upload fails because the artifact already exists (the error message
    contains both "Resource Conflict" and "already exists"), a warning naming
    ``path`` is emitted and the call is a no-op. Any other error is re-raised.
    """
    # Compute outside the try: only the upload can produce the conflict we tolerate.
    param_info = get_param_info(mdl=mdl)
    try:
        mlflow.log_dict(param_info, artifact_file=path)
    except Exception as e:
        msg = str(e)
        if "Resource Conflict" in msg and "already exists" in msg:
            # Previously this wrongly said "config.yaml"; name the real artifact.
            warnings.warn(f"mlflow: {path} already logged, skipping.", stacklevel=2)
        else:
            raise


# ------------------------- mlflow ---------------------------------


def init_mlflow():
    """Configure MLflow tracking for the current process.

    When the ``AZUREML_RUN_ID`` environment variable is absent, runs go to a local
    file store under ``./mlruns``, in an experiment named after the current working
    directory. Async logging is enabled in every case.
    """
    on_azureml = "AZUREML_RUN_ID" in os.environ
    # NOTE(review): on AzureML the tracking URI / experiment are presumably supplied
    # by the environment, which is why they are left untouched there -- confirm.
    if not on_azureml:
        mlflow.set_tracking_uri("file:./mlruns")
        mlflow.set_experiment(Path.cwd().name)
    mlflow.config.enable_async_logging()  # type: ignore


@with_proxies
def get_runs(
    tracking_uri: str | None = None, exp_name: str | None = None, job_name: str | None = None, *, with_hist: bool = True
) -> tuple[pd.DataFrame, pd.DataFrame] | pd.DataFrame:
    """Fetch MLflow runs of an experiment and, optionally, their full metric histories.

    Args:
        tracking_uri: MLflow tracking URI. Falls back to the ``MLFLOW_TRACKING_URI``
            environment variable. It is applied globally via ``mlflow.set_tracking_uri``.
        exp_name: Experiment name. Defaults to the name of the current working directory.
        job_name: If given, only the run whose run ID equals this value is fetched.
            Otherwise all runs with status ``'Finished'`` are fetched.
        with_hist: If True, also fetch the per-step history of each metric.

    Returns:
        ``df_runs`` (as returned by ``mlflow.search_runs``) when ``with_hist`` is False.
        Otherwise ``(df_runs, df_metrics)``, where ``df_metrics`` has the columns
        ``run_id``, ``key`` (the ``metrics.``-prefixed column name), ``value`` and
        ``step``. It only contains metrics that have more than one logged point.

    Raises:
        KeyError: if ``tracking_uri`` is falsy and ``MLFLOW_TRACKING_URI`` is unset.
    """
    # Setup
    tracking_uri = tracking_uri or os.environ["MLFLOW_TRACKING_URI"]
    mlflow.set_tracking_uri(tracking_uri)

    # Point-based logs
    exp_name = exp_name or Path.cwd().name
    # NOTE(review): MLflow's own stores use upper-case 'FINISHED'; 'Finished' presumably
    # matches the remote (AzureML) backend this is used with -- confirm before reuse.
    fltr_string = f"attributes.run_id = '{job_name}'" if job_name else "attributes.status = 'Finished'"
    df_runs = cast(pd.DataFrame, mlflow.search_runs(experiment_names=[exp_name], filter_string=fltr_string))

    if not with_hist:
        return df_runs

    # Full metric history
    client = mlflow.tracking.MlflowClient()
    metric_names = [col for col in df_runs.columns if col.startswith("metrics.")]
    metric_history = []
    # Walk each row once instead of re-filtering df_runs for every (run, metric) pair.
    for _, row in df_runs.iterrows():
        run_id = row["run_id"]
        for metric_name in metric_names:
            if pd.isna(row[metric_name]):  # metric not logged for this run
                continue
            try:
                # Strip only the column prefix: metric keys may themselves contain dots.
                metric = client.get_metric_history(run_id=run_id, key=metric_name.removeprefix("metrics."))
            except RestException:
                continue
            if len(metric) > 1:  # skip metrics logged only once (scalar values)
                metric_history.extend(
                    {"run_id": run_id, "key": metric_name, "value": point.value, "step": point.step}
                    for point in metric
                )
    # Explicit columns keep the schema stable even when no history was collected.
    df_metrics = pd.DataFrame(metric_history, columns=["run_id", "key", "value", "step"])

    return df_runs, df_metrics
