"""MLFlowLogger class for logging to MLflow."""

# pylint: disable-all
# Temporarily disable all pylint checkers during AST traversal to prevent crash.
# The imports checker crashes when resolving fwh_core package imports due to a bug
# in pylint/astroid: https://github.com/pylint-dev/pylint/issues/10185
# pylint: enable=all
# Re-enable all pylint checkers for the checking phase. This allows other checks
# (code quality, style, undefined names, etc.) to run normally while bypassing
# the problematic imports checker that would crash during AST traversal.

import json
import os
import tempfile
import time
from collections.abc import Mapping
from typing import Any

import dotenv
import matplotlib.figure
import mlflow
import numpy
import PIL.Image
import plotly.graph_objects
from mlflow.entities import Metric, Param, RunTag
from omegaconf import DictConfig, OmegaConf

from fwh_core.logging.logger import Logger
from fwh_core.structured_configs.logging import MLFlowLoggerInstanceConfig
from fwh_core.utils.mlflow_utils import (
    get_experiment,
    get_run,
    maybe_terminate_run,
    resolve_registry_uri,
)

dotenv.load_dotenv()
_DATABRICKS_HOST = os.getenv("DATABRICKS_HOST")


class MLFlowLogger(Logger):
    """Logs to MLflow Tracking."""

    def __init__(
        self,
        experiment_id: str | None = None,
        experiment_name: str | None = None,
        run_id: str | None = None,
        run_name: str | None = None,
        tracking_uri: str | None = None,
        registry_uri: str | None = None,
        downgrade_unity_catalog: bool | None = None,
    ):
        """Initialize MLflow logger."""
        self._downgrade_unity_catalog = downgrade_unity_catalog if downgrade_unity_catalog is not None else True
        resolved_registry_uri = resolve_registry_uri(
            registry_uri=registry_uri,
            tracking_uri=tracking_uri,
            downgrade_unity_catalog=downgrade_unity_catalog,
        )
        self._client = mlflow.MlflowClient(tracking_uri=tracking_uri, registry_uri=resolved_registry_uri)
        experiment = get_experiment(experiment_id=experiment_id, experiment_name=experiment_name, client=self.client)
        assert experiment is not None
        self._experiment_id = experiment.experiment_id
        self._experiment_name = experiment.name
        run = get_run(run_id=run_id, run_name=run_name, experiment_id=self.experiment_id, client=self.client)
        assert run is not None
        self._run_id = run.info.run_id
        self._run_name = run.info.run_name

    @property
    def client(self) -> mlflow.MlflowClient:
        """Expose underlying MLflow client for integrations."""
        return self._client

    @property
    def experiment_name(self) -> str:
        """Expose active MLflow experiment name."""
        return self._experiment_name

    @property
    def experiment_id(self) -> str:
        """Expose active MLflow experiment identifier."""
        return self._experiment_id

    @property
    def run_name(self) -> str | None:
        """Expose active MLflow run name."""
        return self._run_name

    @property
    def run_id(self) -> str:
        """Expose active MLflow run identifier."""
        return self._run_id

    @property
    def tracking_uri(self) -> str | None:
        """Return the tracking URI associated with this logger."""
        return self.client.tracking_uri

    @property
    def registry_uri(self) -> str | None:
        """Return the model registry URI associated with this logger."""
        return self.client._registry_uri  # pylint: disable=protected-access

    @property
    def cfg(self) -> MLFlowLoggerInstanceConfig:
        """Return the configuration of this logger."""
        return MLFlowLoggerInstanceConfig(
            _target_=self.__class__.__qualname__,
            experiment_id=self.experiment_id,
            experiment_name=self.experiment_name,
            run_id=self.run_id,
            run_name=self.run_name,
            tracking_uri=self.tracking_uri,
            registry_uri=self.registry_uri,
            downgrade_unity_catalog=self._downgrade_unity_catalog,
        )

    def log_config(self, config: DictConfig, resolve: bool = False) -> None:
        """Log config to MLflow."""
        with tempfile.TemporaryDirectory() as temp_dir:
            config_path = os.path.join(temp_dir, "config.yaml")
            OmegaConf.save(config, config_path, resolve=resolve)
            self.client.log_artifact(self.run_id, config_path)

    def log_metrics(self, step: int, metric_dict: Mapping[str, Any]) -> None:
        """Log metrics to MLflow."""
        timestamp = int(time.time() * 1000)
        metrics = self._flatten_metric_dict(metric_dict, timestamp, step)
        self._log_batch(metrics=metrics)

    def _flatten_metric_dict(
        self, metric_dict: Mapping[str, Any], timestamp: int, step: int, key_prefix: str = ""
    ) -> list[Metric]:
        """Flatten a dictionary of metrics into a list of Metric entities."""
        metrics = []
        for key, value in metric_dict.items():
            key = f"{key_prefix}/{key}" if key_prefix else key
            if isinstance(value, Mapping):
                nested_metrics = self._flatten_metric_dict(value, timestamp, step, key_prefix=key)
                metrics.extend(nested_metrics)
            else:
                value = float(value)
                metric = Metric(key, value, timestamp, step)
                metrics.append(metric)
        return metrics

    def log_params(self, param_dict: Mapping[str, Any]) -> None:
        """Log params to MLflow."""
        params = self._flatten_param_dict(param_dict)
        self._log_batch(params=params)

    def _flatten_param_dict(self, param_dict: Mapping[str, Any], key_prefix: str = "") -> list[Param]:
        """Flatten a dictionary of params into a list of Param entities."""
        params = []
        for key, value in param_dict.items():
            key = f"{key_prefix}.{key}" if key_prefix else key
            if isinstance(value, Mapping):
                nested_params = self._flatten_param_dict(value, key_prefix=key)
                params.extend(nested_params)
            else:
                value = str(value)
                param = Param(key, value)
                params.append(param)
        return params

    def log_tags(self, tag_dict: Mapping[str, Any]) -> None:
        """Set tags on the MLFlow."""
        tags = [RunTag(k, str(v)) for k, v in tag_dict.items()]
        self._log_batch(tags=tags)

    def log_figure(
        self,
        figure: matplotlib.figure.Figure | plotly.graph_objects.Figure,
        artifact_file: str,
        **kwargs,
    ) -> None:
        """Log a figure to MLflow using MLflowClient.log_figure."""
        self.client.log_figure(self.run_id, figure, artifact_file, **kwargs)

    def log_image(
        self,
        image: numpy.ndarray | PIL.Image.Image | mlflow.Image,
        artifact_file: str | None = None,
        key: str | None = None,
        step: int | None = None,
        **kwargs,
    ) -> None:
        """Log an image to MLflow using MLflowClient.log_image."""
        # Parameter validation - ensure we have either artifact_file or (key + step)
        if not artifact_file and not (key and step is not None):
            raise ValueError("Must provide either artifact_file or both key and step parameters")

        self.client.log_image(self.run_id, image, artifact_file=artifact_file, key=key, step=step, **kwargs)

    def log_artifact(self, local_path: str, artifact_path: str | None = None) -> None:
        """Log an artifact (file or directory) to MLflow."""
        self.client.log_artifact(self.run_id, local_path, artifact_path)

    def log_json_artifact(self, data: dict | list, artifact_name: str) -> None:
        """Log a JSON object as an artifact to MLflow."""
        with tempfile.TemporaryDirectory() as temp_dir:
            json_path = os.path.join(temp_dir, artifact_name)
            with open(json_path, "w", encoding="utf-8") as f:
                json.dump(data, f, indent=2)
            self.client.log_artifact(self.run_id, json_path)

    def close(self) -> None:
        """End the MLflow run."""
        maybe_terminate_run(run_id=self.run_id, client=self.client)

    def _log_batch(self, **kwargs: Any) -> None:
        """Log arbitrary data to MLflow."""
        self.client.log_batch(self.run_id, **kwargs, synchronous=False)
