from typing import Any
import numpy as np
import pandas as pd

from mledojo.metrics.base import CompetitionMetrics, InvalidSubmissionError


# Label columns scored by the metric. Both the ground truth and the
# submission must contain exactly these columns, in this order, right after
# the leading id column (the column lists are compared for exact equality).
TARGET_LABELS = [
    "Prolongation",
    "Block",
    "SoundRep",
    "WordRep",
    "Interjection",
    "NoStutteredWords",
    "PoorAudioQuality",
    "DifficultToUnderstand",
    "NaturalPause",
    "Music",
    "NoSpeech",
]

# Predictions are clipped to [EPS, 1 - EPS] before the logarithms are taken,
# so log(0) is never evaluated.
EPS = 1e-15


class Sep28kSoftBCEMetrics(CompetitionMetrics):
    """
    Evaluation for SEP-28k: average soft binary cross-entropy across all labels and clips.
    Lower is better, so higher_is_better defaults to False.

    Both the ground truth and the predictions are DataFrames whose first column
    holds the row id and whose remaining columns are exactly TARGET_LABELS, in
    that order. Rows are aligned by sorting each frame on its id column.
    """

    # Predictions may fall outside [0, 1] by at most this much and still be
    # accepted (presumably to tolerate float round-off); evaluate() clips them
    # before taking logs.
    _RANGE_TOL = 1e-6

    def __init__(self, value: str | None = None, higher_is_better: bool = False):
        # `value` is unused for this multi-label task but kept for API compatibility
        super().__init__(higher_is_better)
        self.value = value

    @staticmethod
    def _id_column(frame: pd.DataFrame, what: str) -> Any:
        """Return the label of the first (id) column, rejecting column-less frames."""
        if len(frame.columns) == 0:
            raise InvalidSubmissionError(f"{what} has no columns.")
        return frame.columns[0]

    @staticmethod
    def _sorted_by_id(frame: pd.DataFrame, id_col: Any) -> pd.DataFrame:
        """Return a copy of *frame* ordered by *id_col* with a fresh 0..n-1 index."""
        return frame.sort_values(by=id_col).reset_index(drop=True)

    @staticmethod
    def _label_values(frame: pd.DataFrame, what: str) -> np.ndarray:
        """Return the TARGET_LABELS columns as a float array, rejecting non-numeric cells."""
        try:
            return frame[TARGET_LABELS].to_numpy(dtype=float)
        except (TypeError, ValueError) as err:
            # The float conversion fails on cells such as arbitrary strings or None;
            # surface that as a submission problem rather than a raw conversion error.
            raise InvalidSubmissionError(
                f"{what} contains non-numeric values in the label columns."
            ) from err

    @classmethod
    def _check_probabilities(
        cls, values: np.ndarray, non_finite_msg: str, out_of_range_msg: str
    ) -> None:
        """Raise unless every entry is finite and within [0, 1] (up to _RANGE_TOL)."""
        if not np.isfinite(values).all():
            raise InvalidSubmissionError(non_finite_msg)
        if (values < -cls._RANGE_TOL).any() or (values > 1 + cls._RANGE_TOL).any():
            raise InvalidSubmissionError(out_of_range_msg)

    def evaluate(self, y_true: pd.DataFrame, y_pred: pd.DataFrame) -> float:
        """
        Compute the mean soft binary cross-entropy over all rows and TARGET_LABELS.

        Args:
            y_true: Ground truth; first column is the id, followed by TARGET_LABELS.
                Target values are clipped to [0, 1] before scoring.
            y_pred: Predictions with the same layout; the id column may carry a
                different name than the one in y_true.

        Returns:
            The loss averaged over every (row, label) cell.

        Raises:
            InvalidSubmissionError: If either argument is not a DataFrame, the
                column layout is wrong, row counts or ids differ, or predictions
                are non-numeric, non-finite, or outside [0, 1].
        """
        if not isinstance(y_true, pd.DataFrame) or not isinstance(y_pred, pd.DataFrame):
            raise InvalidSubmissionError("Both y_true and y_pred must be pandas DataFrames.")

        # Validate the column layout before sorting so that malformed frames are
        # reported as invalid instead of failing inside pandas.
        id_col_true = self._id_column(y_true, "Ground truth")
        id_col_pred = self._id_column(y_pred, "Predictions")
        expected_cols_true = [id_col_true] + TARGET_LABELS
        expected_cols_pred = [id_col_pred] + TARGET_LABELS
        if list(y_true.columns) != expected_cols_true:
            raise InvalidSubmissionError(
                f"Ground truth must have columns: {expected_cols_true}"
            )
        if list(y_pred.columns) != expected_cols_pred:
            raise InvalidSubmissionError(
                f"Predictions must have columns: {expected_cols_pred}"
            )

        # The element-wise id comparison below needs equally long arrays; check
        # the row counts first so a mismatch is reported as an invalid submission.
        if len(y_pred) != len(y_true):
            raise InvalidSubmissionError(
                f"Number of rows in predictions ({len(y_pred)}) does not match ground truth ({len(y_true)})."
            )

        # Sort both frames by their id column to align rows
        y_true = self._sorted_by_id(y_true, id_col_true)
        y_pred = self._sorted_by_id(y_pred, id_col_pred)
        if (y_true[id_col_true].values != y_pred[id_col_pred].values).any():
            raise InvalidSubmissionError("Row order or ids do not match between y_true and y_pred.")

        pred = self._label_values(y_pred, "Predictions")
        target = self._label_values(y_true, "Ground truth")
        self._check_probabilities(
            pred,
            "Predictions contain NaN or Inf values.",
            "Predictions must be within [0,1].",
        )

        # Clip for numerical stability: keeps log() away from 0
        pred = np.clip(pred, EPS, 1.0 - EPS)
        target = np.clip(target, 0.0, 1.0)

        loss = -(target * np.log(pred) + (1.0 - target) * np.log(1.0 - pred))
        return float(loss.mean())

    def validate_submission(self, submission: Any, ground_truth: Any) -> str:
        """
        Check that *submission* can be scored against *ground_truth*.

        Args:
            submission: Candidate predictions; must be a DataFrame with the same
                columns (id column name included) and ids as the ground truth.
            ground_truth: Reference DataFrame: id column followed by TARGET_LABELS.

        Returns:
            The string "Submission is valid." when every check passes.

        Raises:
            InvalidSubmissionError: On wrong types, differing row counts, wrong
                columns, mismatching ids, or label values that are non-numeric,
                non-finite, or outside [0, 1].
        """
        if not isinstance(submission, pd.DataFrame):
            raise InvalidSubmissionError("Submission must be a pandas DataFrame.")
        if not isinstance(ground_truth, pd.DataFrame):
            raise InvalidSubmissionError("Ground truth must be a pandas DataFrame.")

        # Must have the same number of rows
        if len(submission) != len(ground_truth):
            raise InvalidSubmissionError(
                f"Number of rows in submission ({len(submission)}) does not match ground truth ({len(ground_truth)})."
            )

        # Column checks: must match exactly id + TARGET_LABELS. The expected id
        # name comes from the ground truth so that a mismatch is attributed to
        # the submission, and the check runs before any positional indexing or
        # sorting of the submission.
        id_col = self._id_column(ground_truth, "Ground truth")
        expected_cols = [id_col] + TARGET_LABELS
        if list(ground_truth.columns) != expected_cols:
            raise InvalidSubmissionError(
                f"Ground truth must have columns: {expected_cols}."
            )
        if list(submission.columns) != expected_cols:
            raise InvalidSubmissionError(f"Submission must have columns: {expected_cols}.")

        # Sort by the id column and check id equality
        sub = self._sorted_by_id(submission, id_col)
        gtr = self._sorted_by_id(ground_truth, id_col)
        if (sub[id_col].values != gtr[id_col].values).any():
            raise InvalidSubmissionError(
                "First column values (ids) do not match between submission and ground truth."
            )

        # Value checks for submission
        vals = self._label_values(submission, "Submission")
        self._check_probabilities(
            vals,
            "Submission contains NaN or Inf values.",
            "Submission predictions must be in [0,1].",
        )

        return "Submission is valid."
