from typing import Any, Iterable
import numpy as np
import pandas as pd

from mledojo.metrics.base import CompetitionMetrics, InvalidSubmissionError


def _ccc(y_true: Iterable[float], y_pred: Iterable[float]) -> float:
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)

    # Handle NaNs/Infs in predictions conservatively
    mask = np.isfinite(y_pred)
    if not np.all(mask):
        # replace non-finite predictions with the mean of y_true to avoid exploding errors
        safe_fill = float(np.nanmean(y_true)) if np.isfinite(np.nanmean(y_true)) else 0.0
        y_pred = y_pred.copy()
        y_pred[~mask] = safe_fill

    # Clip to task range
    y_true = np.clip(y_true, 1.0, 9.0)
    y_pred = np.clip(y_pred, 1.0, 9.0)

    mean_t = float(np.mean(y_true))
    mean_p = float(np.mean(y_pred))
    var_t = float(np.var(y_true))
    var_p = float(np.var(y_pred))
    cov_tp = float(np.mean((y_true - mean_t) * (y_pred - mean_p)))

    denom = var_t + var_p + (mean_t - mean_p) ** 2
    if denom <= 1e-12:
        return 0.0
    return (2.0 * cov_tp) / denom


class MusicEmotionRegressionMetrics(CompetitionMetrics):
    """
    Metric class for Music Emotion Prediction using average Concordance Correlation Coefficient (CCC)
    over two targets: valence_mean and arousal_mean.

    Rows of the submission and the ground truth are paired through the
    ``song_id`` column.
    """

    def __init__(self, value: str = None, higher_is_better: bool = True):
        """Store the metric configuration.

        Args:
            value: Stored unchanged on ``self.value``; not read by this class.
            higher_is_better: Forwarded to the base-class constructor.
        """
        super().__init__(higher_is_better)
        self.value = value
        self.id_column = "song_id"

    def evaluate(self, y_true: pd.DataFrame, y_pred: pd.DataFrame) -> float:
        """Return the mean CCC over ``valence_mean`` and ``arousal_mean``.

        Args:
            y_true: Ground-truth frame with the id column and both targets.
            y_pred: Submission frame with the same columns.

        Returns:
            The average of the two per-target CCC values, bounded to
            ``[-1.0, 1.0]``. An undefined (NaN) average is reported as
            ``0.0``, the value ``_ccc`` itself uses for an undefined result.
        """
        # Sort both by id to ensure consistent pairing
        y_true = y_true.sort_values(by=self.id_column).reset_index(drop=True)
        y_pred = y_pred.sort_values(by=self.id_column).reset_index(drop=True)

        # Compute CCC for each target
        v_col, a_col = ("valence_mean", "arousal_mean")
        ccc_v = _ccc(y_true[v_col].values, y_pred[v_col].values)
        ccc_a = _ccc(y_true[a_col].values, y_pred[a_col].values)
        score = float((ccc_v + ccc_a) / 2.0)

        # NaN must be handled before clamping: every comparison with NaN is
        # False, so max(-1.0, min(1.0, nan)) would evaluate to 1.0 and award
        # a perfect score.
        if np.isnan(score):
            return 0.0
        # Bound to [-1, 1] for safety
        return max(-1.0, min(1.0, score))

    def validate_submission(self, submission: Any, ground_truth: Any) -> tuple[bool, str]:
        """Check that a submission can be scored against the ground truth.

        Args:
            submission: Candidate predictions; must be a DataFrame holding
                exactly the id column, ``valence_mean`` and ``arousal_mean``.
            ground_truth: Reference DataFrame; must hold at least those
                three columns.

        Returns:
            ``(True, "Submission is valid.")`` when every check passes.

        Raises:
            InvalidSubmissionError: On a wrong container type, a missing or
                extra column, a row-count mismatch, differing ``song_id``
                values, or a target value that cannot be read as a number.
        """
        if not isinstance(submission, pd.DataFrame):
            raise InvalidSubmissionError("Submission must be a pandas DataFrame.")
        if not isinstance(ground_truth, pd.DataFrame):
            raise InvalidSubmissionError("Ground truth must be a pandas DataFrame.")

        target_cols = ["valence_mean", "arousal_mean"]

        # Required columns must match exactly
        required_cols = {self.id_column, *target_cols}
        sub_cols = set(submission.columns)

        missing = required_cols - sub_cols
        if missing:
            raise InvalidSubmissionError(f"Missing required columns in submission: {', '.join(sorted(missing))}.")
        extra = sub_cols - required_cols
        if extra:
            raise InvalidSubmissionError(f"Extra unexpected columns found in submission: {', '.join(sorted(extra))}.")

        # The ground truth may carry additional columns, but without these
        # the sort below (or the later scoring) would fail with a KeyError.
        gt_missing = required_cols - set(ground_truth.columns)
        if gt_missing:
            raise InvalidSubmissionError(
                f"Missing required columns in ground truth: {', '.join(sorted(gt_missing))}."
            )

        # Check row counts and IDs
        if len(submission) != len(ground_truth):
            raise InvalidSubmissionError(
                f"Number of rows in submission ({len(submission)}) does not match ground truth ({len(ground_truth)})."
            )

        # Sort and check IDs identical
        submission_sorted = submission.sort_values(by=self.id_column).reset_index(drop=True)
        gt_sorted = ground_truth.sort_values(by=self.id_column).reset_index(drop=True)

        # array_equal yields a plain bool even when the two id columns have
        # dtypes that do not compare element-wise.
        sub_ids = submission_sorted[self.id_column].values
        gt_ids = gt_sorted[self.id_column].values
        if not np.array_equal(sub_ids, gt_ids):
            raise InvalidSubmissionError("song_id values do not match between submission and ground truth.")

        # Validate types. Missing (NaN) predictions and out-of-range values
        # are not rejected here; only entries that are present but cannot be
        # parsed as numbers are.
        for col in target_cols:
            raw = submission_sorted[col]
            try:
                coerced = pd.to_numeric(raw, errors="coerce")
            except (TypeError, ValueError) as e:
                raise InvalidSubmissionError(f"Column '{col}' must be numeric: {e}") from e
            # errors="coerce" never raises for bad text, it yields NaN; a
            # value that was present before and is NaN now failed to parse.
            unparseable = coerced.isna() & raw.notna()
            if unparseable.any():
                raise InvalidSubmissionError(
                    f"Column '{col}' must be numeric: "
                    f"{int(unparseable.sum())} value(s) could not be parsed as numbers."
                )

        return True, "Submission is valid."
