from typing import Any
import pandas as pd
from sklearn.metrics import f1_score

from mledojo.metrics.base import CompetitionMetrics, InvalidSubmissionError


class MERCMetrics(CompetitionMetrics):
    """
    Metrics for Multimodal Emotion Recognition in Conversations (MERC).

    - Task: multiclass classification among seven emotions
    - Evaluation: macro-averaged F1 (higher is better)
    - Expected columns: id, emotion

    The id column is taken to be the first column of the ground truth and is
    looked up by name in the submission, so the submission may list its
    columns in any order. Labels are compared case-insensitively.
    """

    EMOTIONS = [
        "anger",
        "disgust",
        "sadness",
        "joy",
        "neutral",
        "surprise",
        "fear",
    ]

    def __init__(self, value: str = "emotion", higher_is_better: bool = True):
        """
        Args:
            value: Name of the column holding the emotion label.
            higher_is_better: Forwarded to ``CompetitionMetrics.__init__``;
                macro F1 is a higher-is-better score.
        """
        super().__init__(higher_is_better)
        self.value = value

    @staticmethod
    def _normalize_labels(labels: pd.Series) -> pd.Series:
        """Return *labels* as lower-case strings so 'Joy' and 'joy' compare equal."""
        return labels.astype(str).str.lower()

    def evaluate(self, y_true: pd.DataFrame, y_pred: pd.DataFrame) -> float:
        """
        Compute the macro-averaged F1 score over the seven emotion classes.

        Rows are aligned by sorting both frames on the ground truth's id
        column (its first column). Run ``validate_submission`` first to make
        sure both frames contain the same set of ids; this method does not
        re-check that.

        Args:
            y_true: Ground-truth frame whose first column is the id and which
                has a label column named ``self.value``.
            y_pred: Submission frame with the same id and label columns.

        Returns:
            Macro F1 score as a Python float.
        """
        id_col = y_true.columns[0]
        # Look up the id column by name so a submission with reordered
        # columns is still aligned correctly; fall back to the submission's
        # first column (the previous behaviour) if the name is absent.
        pred_id_col = id_col if id_col in y_pred.columns else y_pred.columns[0]

        y_true = y_true.sort_values(by=id_col).reset_index(drop=True)
        y_pred = y_pred.sort_values(by=pred_id_col).reset_index(drop=True)

        yt = self._normalize_labels(y_true[self.value]).values
        yp = self._normalize_labels(y_pred[self.value]).values
        # labels= pins the class set so every emotion in EMOTIONS takes part
        # in the macro average, even if it is missing from y_true or y_pred.
        return float(f1_score(yt, yp, labels=self.EMOTIONS, average="macro"))

    def validate_submission(self, submission: Any, ground_truth: Any) -> str:
        """
        Check that *submission* can be scored against *ground_truth*.

        Checks, in order:
        - both inputs are DataFrames;
        - the row counts match;
        - the column sets match exactly;
        - the label column ``self.value`` is present;
        - after sorting, the ids (the ground truth's first column) match one-to-one;
        - every label is one of ``EMOTIONS`` (case-insensitive).

        Returns:
            ``"Submission is valid."`` when every check passes.

        Raises:
            InvalidSubmissionError: describing the first check that failed.
        """
        if not isinstance(submission, pd.DataFrame):
            raise InvalidSubmissionError(
                "Submission must be a pandas DataFrame. Please provide a valid pandas DataFrame."
            )
        if not isinstance(ground_truth, pd.DataFrame):
            raise InvalidSubmissionError(
                "Ground truth must be a pandas DataFrame. Please provide a valid pandas DataFrame."
            )

        if len(submission) != len(ground_truth):
            raise InvalidSubmissionError(
                f"Number of rows in submission ({len(submission)}) does not match ground truth ({len(ground_truth)})."
            )

        # Check the column sets before indexing any column. A submission
        # with missing or renamed columns then gets a clear error instead of
        # an IndexError/KeyError or a misleading id mismatch.
        required_cols = set(ground_truth.columns)
        sub_cols = set(submission.columns)

        missing_cols = required_cols - sub_cols
        extra_cols = sub_cols - required_cols
        if missing_cols:
            raise InvalidSubmissionError(
                f"Missing required columns in submission: {', '.join(sorted(missing_cols))}."
            )
        if extra_cols:
            raise InvalidSubmissionError(
                f"Extra unexpected columns found in submission: {', '.join(sorted(extra_cols))}."
            )

        # Reached when the ground truth itself lacks the label column.
        if self.value not in submission.columns:
            raise InvalidSubmissionError(
                f"Submission must contain column '{self.value}'."
            )

        # The column sets are equal here, so the ground truth's id column is
        # guaranteed to exist in the submission. Match it by name rather than
        # by position so column order in the submission does not matter.
        id_col = ground_truth.columns[0]
        sub_ids = submission[id_col].sort_values().values
        gt_ids = ground_truth[id_col].sort_values().values
        if (sub_ids != gt_ids).any():
            raise InvalidSubmissionError(
                "First column values (ids) do not match between submission and ground truth."
            )

        is_known_label = self._normalize_labels(submission[self.value]).isin(
            self.EMOTIONS
        )
        if not bool(is_known_label.all()):
            raise InvalidSubmissionError(
                "Submission contains invalid label values. Allowed labels: "
                + ", ".join(self.EMOTIONS)
            )

        return "Submission is valid."
