import csv
import itertools
import logging
import os
import pickle
import shutil
import tempfile
import time

import hydra
import matplotlib.pyplot as plt
import mlflow
import numpy as np
import pandas as pd
import torch
from hydra.core.hydra_config import HydraConfig
from omegaconf import DictConfig, OmegaConf

from haipr.utils import (
    compute_classification_metrics,
    compute_regression_metrics,
    plot_pr_curve,
    plot_roc_curve,
    plot_regression,
    plot_residuals,
    plot_qq,
)

# Module-level logger used for all diagnostics emitted by ResultsLogger.
logger = logging.getLogger(__name__)


class ResultsLogger:
    """Log metrics, predictions, plots, configs and models to MLflow.

    Artifacts that must exist on disk before upload (CSV/YAML files) are written
    to a scratch directory: ``/tmp/haipr_results_<run_id>`` once a run ID is
    set, otherwise the shared ``/tmp/haipr_results``. Only directories this
    logger created itself are removed by :meth:`finish` or on destruction, so a
    directory that another logger is still using is never wiped from under it.
    """

    # Root path of the scratch directories; "_<run_id>" is appended per run.
    _TEMP_DIR_ROOT = "/tmp/haipr_results"

    def __init__(self, cfg, run=None):
        """Initialize the results logger.

        Args:
            cfg: Configuration object (presumably an OmegaConf ``DictConfig``;
                it is passed to ``OmegaConf.to_container`` when logging models).
            run: MLflow run object (optional). When given, ``run.info.run_id``
                becomes the run ID used for uploads and the scratch dir name.
        """
        self.cfg = cfg
        self.run = run
        # Set before any temp-dir handling so _ensure_temp_dir can rely on it.
        self.run_id = None
        self.postfix = ""
        self._current_split = None  # Track current split for model naming
        self._created_temp_dirs = set()
        self._temp_dir: str = self._TEMP_DIR_ROOT
        if run is not None:
            # Goes straight to the per-run dir; the shared one is not touched.
            self.set_run_id(run.info.run_id)
        else:
            self._ensure_temp_dir()

    def _ensure_temp_dir(self):
        """Point ``self._temp_dir`` at the scratch directory and create it if missing.

        The directory is run-specific when ``self.run_id`` is set and the shared
        root otherwise. Directories created here are recorded in
        ``self._created_temp_dirs`` so cleanup only removes what we created.
        """
        run_id = getattr(self, "run_id", None)
        self._temp_dir = (
            f"{self._TEMP_DIR_ROOT}_{run_id}" if run_id else self._TEMP_DIR_ROOT
        )
        if os.path.isdir(self._temp_dir):
            return
        os.makedirs(self._temp_dir, exist_ok=True)
        self._created_temp_dirs.add(self._temp_dir)
        logger.debug(f"Created temporary directory: {self._temp_dir}")

    def _cleanup_temp_dirs(self):
        """Best-effort removal of every scratch directory this logger created."""
        created = getattr(self, "_created_temp_dirs", None)
        if not created:
            return
        for temp_dir in list(created):
            if not os.path.isdir(temp_dir):
                created.discard(temp_dir)
                continue
            shutil.rmtree(temp_dir, ignore_errors=True)
            if os.path.exists(temp_dir):
                logger.debug(f"Could not fully remove temporary directory: {temp_dir}")
            else:
                created.discard(temp_dir)
                logger.debug(f"Cleaned up temporary directory: {temp_dir}")

    def __del__(self):
        """Clean up temporary files on object destruction."""
        try:
            self._cleanup_temp_dirs()
        except Exception as e:
            logger.error(f"Error cleaning up temporary files: {str(e)}")

    def set_run_id(self, run_id):
        """Set the run ID and switch to (creating if needed) its scratch directory.

        Temp dirs of a previous run are intentionally not deleted here, to
        avoid races with parallel runs that may still be writing to them.
        """
        self.run_id = run_id
        self._ensure_temp_dir()

    @staticmethod
    def _to_native(value):
        """Convert single-element NumPy/torch values to Python scalars.

        Any other value is returned unchanged.
        """
        if isinstance(value, np.generic):
            return value.item()
        if isinstance(value, np.ndarray) and value.size == 1:
            return value.item()
        if isinstance(value, torch.Tensor) and value.numel() == 1:
            return value.item()
        return value

    def log_metrics(self, metrics, step=None):
        """Log a dict of metrics to the active MLflow run.

        Args:
            metrics: Mapping of metric name to value; NumPy/torch scalars are
                converted to Python numbers first.
            step: Optional step index forwarded to ``mlflow.log_metrics``.

        Raises:
            ValueError: If no run ID has been set.
        """
        if not self.run_id:
            raise ValueError("Run ID must be set before logging metrics")

        logger.debug(f"Logging metrics for {self.run_id}")
        native_metrics = {key: self._to_native(value) for key, value in metrics.items()}
        mlflow.log_metrics(native_metrics, step=step)

    def log_metrics_summary(self, metrics_df, fname="metrics_summary"):
        """Upload ``metrics_df`` as ``<fname>.json`` (row records) and ``<fname>.csv``.

        The DataFrame index is written to neither file, so any row labels must
        already be a column.

        Raises:
            ValueError: If no run ID has been set.
        """
        if not self.run_id:
            raise ValueError("Run ID must be set before logging metrics summary")

        logger.debug(f"Logging metrics summary for {self.run_id}")
        self._ensure_temp_dir()

        mlflow.log_dict(metrics_df.to_dict(orient="records"), f"{fname}.json")
        csv_path = os.path.join(self._temp_dir, f"{fname}.csv")
        metrics_df.to_csv(csv_path, index=False)
        mlflow.log_artifact(csv_path, run_id=self.run_id)
        logger.debug(f"Metrics summary logged to {fname}")

    @staticmethod
    def _tensor_dict_to_frame(data_sample, max_rows):
        """Build a row-wise DataFrame from a dict of batch-first tensors.

        Each row holds the i-th item of every tensor (flattened when it has more
        than one dimension). Non-tensor values are repeated in every row. The
        first row also gets a ``<key>_shape`` column with the non-batch shape.

        Raises:
            ValueError: If the dict holds no tensors, or batch sizes differ.
        """
        batch_sizes = [
            value.shape[0]
            for value in data_sample.values()
            if isinstance(value, torch.Tensor)
        ]
        if not batch_sizes:
            raise ValueError("Dict input sample must contain at least one torch.Tensor")
        if any(size != batch_sizes[0] for size in batch_sizes):
            raise ValueError("All tensors must have the same batch dimension")

        rows = []
        for i in range(min(max_rows, batch_sizes[0])):
            row = {}
            for key, value in data_sample.items():
                if not isinstance(value, torch.Tensor):
                    row[key] = value
                    continue
                if i == 0:
                    row[f"{key}_shape"] = str(value.shape[1:])  # non-batch dims
                # detach/cpu so grad-tracking or GPU tensors can become NumPy.
                sample = value[i].detach().cpu().numpy()
                row[key] = sample.flatten() if sample.ndim > 1 else sample
            rows.append(row)
        return pd.DataFrame(rows)

    def log_input_sample(self, data_sample, context, max_rows=5, tags=None):
        """Upload the first ``max_rows`` rows of an input sample as ``<context>_sample.csv``.

        Args:
            data_sample: A ``(features, labels)`` tuple, a dict of batch-first
                tensors (e.g. ESM features), or a DataFrame.
            context: Prefix for the artifact file name.
            max_rows: Maximum number of rows written.
            tags: Currently unused; kept for interface compatibility.

        Raises:
            ValueError: If no run ID is set or the input type is unsupported.
        """
        if not self.run_id:
            raise ValueError("Run ID must be set before logging input sample")

        logger.debug(f"Logging input sample for {self.run_id}")

        if isinstance(data_sample, tuple):  # features, labels
            features, labels = data_sample
            df = pd.DataFrame(features)
            df["label"] = labels
        elif isinstance(data_sample, dict):  # dict for esm features
            df = self._tensor_dict_to_frame(data_sample, max_rows)
        elif isinstance(data_sample, pd.DataFrame):
            df = data_sample
        else:
            raise ValueError(f"Unsupported data type: {type(data_sample)}")

        df = df.head(max_rows)
        logger.debug(f"Saving sample with shape: {df.shape}")
        self._ensure_temp_dir()
        temp_path = os.path.join(self._temp_dir, f"{context}_sample.csv")
        df.to_csv(temp_path, index=False)
        mlflow.log_artifact(temp_path, run_id=self.run_id)
        logger.debug(f"Input sample logged to {context}_sample.csv")

    @staticmethod
    def _probability_columns(probs):
        """Map raw per-sample probabilities to prediction-table columns.

        1-D input gives a single ``probability`` column; 2-D input of shape
        (n_samples, n_classes) gives ``probability_class_<i>`` columns. Lists,
        NumPy arrays and torch tensors are accepted. Empty input or input with
        any other rank yields no columns.
        """
        if isinstance(probs, torch.Tensor):
            probs = probs.detach().cpu().numpy()
        arr = np.asarray(probs)
        if arr.size == 0:
            return {}
        if arr.ndim == 1:
            return {"probability": arr}
        if arr.ndim == 2:
            return {f"probability_class_{i}": arr[:, i] for i in range(arr.shape[1])}
        logger.warning(
            f"Ignoring probabilities with unsupported shape {arr.shape}"
        )
        return {}

    def _split_predictions_frame(self, split_num, pred_dict, run_data):
        """Build the predictions table for one split (1-based ``split_num``)."""
        # "indices" are positional rows of run_data.data; map them to sample IDs.
        orig_indices = run_data.data.iloc[pred_dict["indices"]]["sample_id"].values
        split_df = pd.DataFrame(
            {
                "split": split_num,
                "sample_id": orig_indices,
                "prediction": np.array(pred_dict["predictions"]).flatten(),
                "true_value": np.array(pred_dict["true_values"]).flatten(),
            }
        )
        probs = pred_dict.get("probabilities")
        if probs is not None:
            for column, values in self._probability_columns(probs).items():
                split_df[column] = values
        return split_df

    @staticmethod
    def _probabilities_from_frame(predictions_df):
        """Return class-probability columns, the binary ``probability`` column, or None."""
        class_columns = [
            column
            for column in predictions_df.columns
            if str(column).startswith("probability_class_")
        ]
        if class_columns:
            return predictions_df[class_columns].values
        if "probability" in predictions_df.columns:
            return predictions_df["probability"].values
        return None

    @staticmethod
    def _save_predictions_to_hydra_dir(predictions_df):
        """Write ``predictions.csv`` into Hydra's run output dir when Hydra is active."""
        if not HydraConfig.initialized():
            logger.debug("Hydra is not initialized; not writing predictions.csv")
            return
        output_dir = HydraConfig.get().runtime.output_dir
        predictions_df.to_csv(os.path.join(output_dir, "predictions.csv"), index=False)

    def log_run_metrics_and_predictions(self, all_metrics, all_predictions, run_data):
        """Log per-split metrics, the combined predictions table and combined plots.

        Args:
            all_metrics: Per-split metric records (anything ``pd.DataFrame``
                accepts, one row per split).
            all_predictions: One dict per split with ``indices`` (positional
                rows of ``run_data.data``), ``predictions``, ``true_values`` and
                optionally ``probabilities``.
            run_data: Object whose ``data`` DataFrame has a ``sample_id`` column.
        """
        logger.info("Logging final metrics and predictions")
        self._ensure_temp_dir()

        metrics_df = pd.DataFrame(all_metrics)
        split_labels = [f"split_{i + 1}" for i in range(len(metrics_df))]
        metrics_df.index = split_labels
        # log_metrics_summary does not write the index, so keep labels as a column.
        if "split" not in metrics_df.columns:
            metrics_df.insert(0, "split", split_labels)
        self.log_metrics_summary(metrics_df)

        if not all_predictions:
            logger.warning("No predictions provided; skipping predictions summary and plots")
            return

        all_predictions_df = pd.concat(
            [
                self._split_predictions_frame(split_num, pred_dict, run_data)
                for split_num, pred_dict in enumerate(all_predictions, start=1)
            ],
            ignore_index=True,
        )

        self.log_predictions(all_predictions_df, "predictions_summary.csv")
        self._save_predictions_to_hydra_dir(all_predictions_df)

        self.log_plots(
            all_predictions_df["true_value"].values,
            all_predictions_df["prediction"].values,
            self._probabilities_from_frame(all_predictions_df),
            sample_indices=all_predictions_df["sample_id"].values,
        )

    @staticmethod
    def _rows_from_split_dicts(predictions_data):
        """Expand per-split prediction dicts into one row dict per prediction."""
        rows = []
        for split_data in predictions_data:
            sample_ids = split_data["sample_id"]
            # A scalar sample_id applies to every row of the split; an
            # array-like one is paired element-wise with the predictions.
            if np.ndim(sample_ids) == 0:
                sample_ids = itertools.repeat(sample_ids)
            probabilities = split_data.get("probabilities")
            if probabilities is None:
                probabilities = itertools.repeat(None)
            for sample_id, pred, true, prob in zip(
                sample_ids,
                split_data["predictions"],
                split_data["true_values"],
                probabilities,
            ):
                rows.append(
                    {
                        "sample_id": sample_id,
                        "split": split_data["split"],
                        "prediction": pred,
                        "true_value": true,
                        "probability": prob,
                    }
                )
        return rows

    def log_predictions(self, predictions_data, fname="predictions_summary.csv"):
        """Write predictions to ``<temp dir>/<fname>`` and upload it to MLflow.

        Args:
            predictions_data: Either a list of per-split dicts with keys
                ``split``, ``sample_id``, ``predictions``, ``true_values`` and
                optionally ``probabilities`` (expanded to one row per
                prediction), or anything ``pd.DataFrame`` accepts (e.g. an
                already-tabular DataFrame), which is written as-is.
            fname: File name of the CSV artifact.
        """
        logger.debug(f"Logging predictions for {self.run_id}")

        if isinstance(predictions_data, list) and all(
            isinstance(item, dict) for item in predictions_data
        ):
            predictions_df = pd.DataFrame(self._rows_from_split_dicts(predictions_data))
        else:
            predictions_df = pd.DataFrame(predictions_data)
            if "sample_id" not in predictions_df.columns:
                logger.warning("sample_id not found in predictions data")

        self._ensure_temp_dir()
        temp_path = os.path.join(self._temp_dir, fname)
        predictions_df.to_csv(temp_path, index=False)
        mlflow.log_artifact(temp_path, run_id=self.run_id)
        logger.debug(f"Predictions logged to {temp_path}")

    def log_plots(self, y_true, y_pred, y_prob=None, sample_indices=None):
        """Log plots for model evaluation.

        Regression (``cfg.num_classes == 0``) gets scatter, residual and QQ
        plots. Classification gets PR and ROC curves, which need ``y_prob``.
        They are skipped with a warning when ``y_prob`` is None.

        Args:
            y_true (array-like): True values
            y_pred (array-like): Predicted values
            y_prob (array-like, optional): Prediction probabilities; a 1-D
                array is treated as the positive-class probability.
            sample_indices (array-like, optional): Original indices of samples
        """
        logger.debug(f"Logging plots for {self.run_id}")
        logger.debug(f"y_true shape: {np.array(y_true).shape}")
        logger.debug(f"y_pred shape: {np.array(y_pred).shape}")
        logger.debug(
            f"y_prob shape: {None if y_prob is None else np.array(y_prob).shape}"
        )

        y_true = np.array(y_true).reshape(-1)
        y_pred = np.array(y_pred).reshape(-1)

        if self.cfg.num_classes == 0:  # Regression task
            self._log_regression_plots(y_true, y_pred, sample_indices)
            return

        if y_prob is None:
            logger.warning("No prediction probabilities available; skipping PR/ROC curves")
            return

        y_prob = np.array(y_prob)
        if y_prob.ndim == 1:  # Binary: expand to [P(class 0), P(class 1)]
            y_prob = np.vstack([1 - y_prob, y_prob]).T
        self._log_classification_plots(y_true, y_prob)

    def set_postfix(self, postfix):
        """Store a free-form postfix string on the logger."""
        self.postfix = postfix

    def set_current_split(self, split_num):
        """Set the current (0-based) split number used for model naming and tags."""
        self._current_split = split_num

    @staticmethod
    def _log_and_close_figure(fig, artifact_file):
        """Upload ``fig`` to MLflow as ``artifact_file``, then close that figure."""
        try:
            mlflow.log_figure(fig, artifact_file)
        finally:
            # Close this specific figure even if the upload failed.
            plt.close(fig)

    def _log_classification_plots(self, y_true, y_prob):
        """Log PR and ROC curve figures for a classification task."""
        label = self.cfg.data.label_column
        self._log_and_close_figure(
            plot_pr_curve(y_true, y_prob, label, self.cfg.num_classes), "pr_curve.png"
        )
        self._log_and_close_figure(
            plot_roc_curve(y_true, y_prob, label, self.cfg.num_classes), "roc_curve.png"
        )

    def _log_regression_plots(self, y_true, y_pred, sample_indices=None):
        """Log scatter, residual and QQ figures for a regression task."""
        y_true = np.asarray(y_true).ravel()
        y_pred = np.asarray(y_pred).ravel()
        label = self.cfg.data.label_column

        self._log_and_close_figure(
            plot_regression(y_true, y_pred, label), "scatter_plot.png"
        )
        self._log_and_close_figure(
            plot_residuals(y_true, y_pred, label), "residual_plot.png"
        )
        self._log_and_close_figure(
            plot_qq(y_true=y_true, y_pred=y_pred, label_name=label), "qq_plot.png"
        )

    def log_sequences(self, sequences, fitness, step=None):
        """Write ``(sequence, fitness)`` pairs to a CSV and upload it under ``sequences/``.

        Values go through the ``csv`` module, so sequences containing commas or
        quotes are quoted rather than corrupting the file. Pairs are formed
        with ``zip``, so surplus items in the longer input are dropped.
        """
        logger.debug(f"Logging sequences for {self.run_id}")
        logger.debug(f"Step: {step}")
        self._ensure_temp_dir()
        temp_path = os.path.join(self._temp_dir, f"sequences_step_{step}.csv")
        with open(temp_path, "w", newline="", encoding="utf-8") as f:
            writer = csv.writer(f, lineterminator="\n")
            writer.writerow(["sequence", "fitness"])
            # Stringify explicitly: csv writes float subclasses (e.g. np.float64)
            # via repr(), which differs from the plain formatting wanted here.
            writer.writerows(
                (f"{seq}", f"{fit}") for seq, fit in zip(sequences, fitness)
            )
        mlflow.log_artifact(temp_path, artifact_path="sequences")

    @staticmethod
    def _summary_stats(key, values):
        """Return mean/std/min/max of ``values`` as ``<key>_<stat>`` entries."""
        return {
            f"{key}_mean": np.mean(values),
            f"{key}_std": np.std(values),
            f"{key}_min": np.min(values),
            f"{key}_max": np.max(values),
        }

    def log_optimization_step(self, step_data, step=None):
        """
        Log metrics for an optimization step (generic for any optimizer).

        Single-element arrays, lists and tensors are logged as scalars. Larger
        ones are logged as ``_mean``/``_std``/``_min``/``_max`` summaries.
        Empty ones are skipped.

        Args:
            step_data: Dict containing optimization metrics
            step: Optional step number

        Returns:
            dict: The metrics dict that was passed to :meth:`log_metrics`.
        """
        metrics = {}
        for key, value in step_data.items():
            if isinstance(value, torch.Tensor):
                value = value.detach().cpu().numpy()
            elif isinstance(value, (list, tuple)):
                value = np.asarray(value)

            if not isinstance(value, np.ndarray):
                metrics[key] = value.item() if isinstance(value, np.generic) else value
            elif value.size == 0:
                logger.debug(f"Skipping empty value for metric '{key}'")
            elif value.size == 1:
                metrics[key] = value.item()
            else:
                metrics.update(self._summary_stats(key, value))

        self.log_metrics(metrics, step=step)
        return metrics

    @staticmethod
    def _active_run_params():
        """Return params already logged on the active MLflow run ({} if unavailable)."""
        try:
            run = mlflow.active_run()
            if run:
                return mlflow.get_run(run.info.run_id).data.params
        except Exception as e:
            logger.debug(f"Could not fetch existing run params: {e}")
        return {}

    def log_config(self, cfg):
        """Upload ``cfg`` as ``config.yaml`` and log its flattened keys as run params.

        Keys already present on the active run are skipped, with a warning when
        the value differs, to avoid an MLflow exception on re-logging. Failures
        while logging params are logged as warnings, not raised.
        """
        self._ensure_temp_dir()
        path = os.path.join(self._temp_dir, "config.yaml")
        with open(path, "w") as f:
            f.write(OmegaConf.to_yaml(cfg))
        mlflow.log_artifact(path)

        try:
            flattened_params = self._flatten_config_safe(cfg)
            existing_params = self._active_run_params()

            new_params = {}
            conflicting_keys = []
            for key, value in flattened_params.items():
                if key not in existing_params:
                    new_params[key] = value
                elif existing_params[key] != value:
                    conflicting_keys.append(key)
                    logger.warning(
                        f"Parameter '{key}' already exists with value '{existing_params[key]}', "
                        f"attempting to log '{value}'. Skipping to avoid MLflow exception."
                    )

            if new_params:
                mlflow.log_params(new_params)
            else:
                logger.debug("All parameters already exist, skipping parameter logging")

            if conflicting_keys:
                logger.warning(
                    f"Skipped {len(conflicting_keys)} conflicting parameters: {conflicting_keys}"
                )
        except Exception as e:
            logger.warning(f"Failed to log config parameters: {e}")

    def _flatten_config_safe(self, cfg, parent_key=""):
        """Flatten a (possibly nested) config into ``{"a.b.c": str_value}``.

        Interpolations are resolved when possible, falling back to unresolved
        values. Leaf values are stringified and truncated to 200 characters,
        and None becomes ``"null"``. Non-mapping input yields ``{}``.
        """
        if not isinstance(cfg, (dict, DictConfig)):
            return {}

        if isinstance(cfg, DictConfig):
            try:
                cfg_dict = OmegaConf.to_container(cfg, resolve=True)
            except Exception:
                cfg_dict = OmegaConf.to_container(cfg, resolve=False)
        else:
            cfg_dict = cfg

        items = {}
        if isinstance(cfg_dict, dict):
            for k, v in cfg_dict.items():
                new_key: str = f"{parent_key}.{k}" if parent_key else str(k)
                if isinstance(v, (dict, DictConfig)):
                    items.update(self._flatten_config_safe(v, new_key))
                else:
                    items[new_key] = "null" if v is None else str(v)[:200]
        return items

    def _get_embedder_model(self):
        """Return ``cfg.embedder.model``, or None when either level is absent or empty."""
        embedder = getattr(self.cfg, "embedder", None)
        if not embedder:
            return None
        return getattr(embedder, "model", None)

    def log_model(self, model, model_name=None, tags=None, metadata=None):
        """
        Log BasePredictor as pyfunc model and register it under ``model_name``.

        Args:
            model: Predictor exposing ``save_model(dir)`` (returning the saved
                path); it is wrapped in ``haipr.predictor.InferenceWrapper``.
            model_name: Registered model name. Defaults to ``cfg.model.name``,
                with ``_split_<n>`` appended when a current split is set.
            tags: Extra run tags merged over the generated ones.
            metadata: Extra entries merged into the pyfunc model metadata.

        Raises:
            ValueError: If no run ID has been set.

        Errors while saving or logging the model are logged and not raised.
        """
        if not self.run_id:
            raise ValueError("Run ID must be set before logging models")

        if model is None:
            logger.warning("Model is None, skipping")
            return

        current_split = getattr(self, "_current_split", None)
        if model_name is None:
            model_name = f"{self.cfg.model.name}"
            if current_split is not None:
                model_name += f"_split_{current_split + 1}"

        tags = {} if tags is None else tags
        metadata = {} if metadata is None else metadata

        experiment_id = mlflow.get_run(self.run_id).info.experiment_id

        model_tags = {
            "experiment_id": experiment_id,
            # Determined for metadata only; logging always goes through pyfunc.
            "model_type": "pytorch" if self._is_neural_model(model) else "sklearn",
            "model_name": model_name,
            "task": ("classification" if self.cfg.num_classes > 0 else "regression"),
            "num_classes": self.cfg.num_classes,
            "benchmark": self.cfg.benchmark.name,
            "split_method": self.cfg.data.split_method,
            "feature_type": self.cfg.data.feature_type,
            "registered_model_name": model_name,
        }
        embedder_model = self._get_embedder_model()
        if embedder_model:
            model_tags["embedder_model"] = embedder_model
        if current_split is not None:
            model_tags["split"] = current_split + 1
            model_tags["split_num"] = current_split
        model_tags.update(tags)

        model_metadata = {
            "experiment_id": experiment_id,
            "experiment_name": self.cfg.mlflow.experiment_name,
            "model_name": model_name,
            "run_id": self.run_id,
            "timestamp": time.time(),
            "config": OmegaConf.to_container(self.cfg, resolve=True),
        }
        model_metadata.update(metadata)

        artifacts_dir = None
        try:
            from haipr.predictor import InferenceWrapper

            artifacts_dir = tempfile.mkdtemp()
            model_path = model.save_model(artifacts_dir)
            config_path = os.path.join(artifacts_dir, "config.yaml")
            with open(config_path, "w") as f:
                OmegaConf.save(self.cfg, f)

            artifacts = {
                "model": model_path,
                "config": config_path,
            }

            # Wrap the model for inference, dropping the Lightning inheritance
            # to avoid a pyfunc call-signature error.
            inference_model = InferenceWrapper(model)
            mlflow.pyfunc.log_model(
                artifact_path="model",
                python_model=inference_model,
                artifacts=artifacts,
                # Model params are stored so they can be changed for inference.
                model_config=OmegaConf.to_container(self.cfg.model, resolve=True),
                registered_model_name=model_name,
                pip_requirements=self._get_pip_requirements(),
                metadata=model_metadata,
            )

            logger.info(f"Successfully logged pyfunc model: {model_name}")
            mlflow.set_tags(model_tags)
        except Exception as e:
            logger.error(f"Failed to log model {model_name}: {e}", exc_info=True)
        finally:
            # Remove the scratch dir on success and failure alike.
            if artifacts_dir is not None:
                shutil.rmtree(artifacts_dir, ignore_errors=True)

    def _is_neural_model(self, model):
        """
        Check if a model is a neural model (PyTorch/Lightning).

        Args:
            model: The model to check

        Returns:
            bool: True if the model has a callable ``state_dict`` attribute,
            used here as a heuristic for PyTorch models.
        """
        return hasattr(model, "state_dict") and callable(getattr(model, "state_dict"))

    def _create_model_signature(self):
        """Create an MLflow model signature: a string ``sequence`` column in,
        ``prediction`` (plus ``probability`` for classification) out."""
        from mlflow.models.signature import ModelSignature
        from mlflow.types.schema import Schema, ColSpec

        input_schema = Schema([ColSpec("string", "sequence")])

        if self.cfg.num_classes > 0:
            output_schema = Schema(
                [
                    ColSpec("double", "prediction"),
                    ColSpec("double", "probability"),
                ]
            )
        else:
            output_schema = Schema([ColSpec("double", "prediction")])

        return ModelSignature(inputs=input_schema, outputs=output_schema)

    def _get_pip_requirements(self):
        """Return the pip requirements recorded with logged pyfunc models.

        ``esm==3.2.0`` is added when the configured embedder model name
        contains "esm" (case-insensitive).
        """
        requirements = [
            "torch>=2.0.0",
            "scikit-learn",
            "numpy",
            "pandas",
            "omegaconf",
            "mlflow>=2.0.0",
        ]

        embedder_model = self._get_embedder_model()
        if embedder_model and "esm" in str(embedder_model).lower():
            requirements.append("esm==3.2.0")

        return requirements

    def finish(self):
        """Remove the scratch directories created by this logger (best effort)."""
        self._cleanup_temp_dirs()
