from typing import Any
import pandas as pd
from mledojo.metrics.base import CompetitionMetrics, InvalidSubmissionError


class SpeakerRecognitionMetrics(CompetitionMetrics):
    """Macro F1 metric for speaker recognition submissions.

    Expected dataframe schemas:
    - y_true (ground truth): columns ['file_id', 'label']
    - y_pred (submission):   columns ['file_id', 'label']

    The name of the label column is configurable through ``value``.
    """

    def __init__(self, value: str = "label", higher_is_better: bool = True):
        """Store the label column name and the score direction.

        Args:
            value: Name of the column holding the speaker label.
            higher_is_better: Forwarded to the base class constructor.
        """
        super().__init__(higher_is_better)
        self.value = value

    @staticmethod
    def _sort_by_file_id(frame: pd.DataFrame) -> pd.DataFrame:
        """Return a copy of ``frame`` ordered by ``file_id`` with a fresh index."""
        return frame.sort_values(by=["file_id"]).reset_index(drop=True)

    @staticmethod
    def _file_ids_differ(left: pd.DataFrame, right: pd.DataFrame) -> bool:
        """Tell whether two already-sorted frames carry different ``file_id`` columns.

        The row counts are compared first: an element-wise comparison of
        arrays with different lengths either raises or silently broadcasts,
        neither of which is a meaningful answer here.
        """
        left_ids = left["file_id"].values
        right_ids = right["file_id"].values
        if len(left_ids) != len(right_ids):
            return True
        return bool((left_ids != right_ids).any())

    def evaluate(self, y_true: pd.DataFrame, y_pred: pd.DataFrame) -> float:
        """Compute the macro-averaged F1 score of ``y_pred`` against ``y_true``.

        Both frames are sorted by ``file_id`` so rows are compared pairwise;
        the average is taken over the labels that occur in ``y_true``.

        Args:
            y_true: Ground truth with at least 'file_id' and the label column.
            y_pred: Predictions with at least 'file_id' and the label column.

        Returns:
            The macro F1 score as a Python float.

        Raises:
            InvalidSubmissionError: If an argument is not a DataFrame, a
                required column is missing, or the ``file_id`` values
                (including their count) differ between the two frames.
        """
        from sklearn.metrics import f1_score

        # Basic type checks
        if not isinstance(y_true, pd.DataFrame):
            raise InvalidSubmissionError("y_true must be a pandas DataFrame.")
        if not isinstance(y_pred, pd.DataFrame):
            raise InvalidSubmissionError("y_pred must be a pandas DataFrame.")

        # Validate schemas
        required_cols = {"file_id", self.value}
        if not required_cols.issubset(set(y_true.columns)):
            raise InvalidSubmissionError(
                f"y_true must contain columns {sorted(required_cols)}; got {list(y_true.columns)}"
            )
        if not required_cols.issubset(set(y_pred.columns)):
            raise InvalidSubmissionError(
                f"y_pred must contain columns {sorted(required_cols)}; got {list(y_pred.columns)}"
            )

        # Sort and align by file_id
        y_true_sorted = self._sort_by_file_id(y_true)
        y_pred_sorted = self._sort_by_file_id(y_pred)

        # Check same ids (a differing row count counts as a mismatch)
        if self._file_ids_differ(y_true_sorted, y_pred_sorted):
            raise InvalidSubmissionError(
                "file_id values do not match between y_true and y_pred after sorting."
            )

        # Compute macro F1 over labels present in y_true
        labels = sorted(y_true_sorted[self.value].unique().tolist())
        return float(
            f1_score(
                y_true_sorted[self.value], y_pred_sorted[self.value], average="macro", labels=labels
            )
        )

    def validate_submission(self, submission: Any, ground_truth: Any) -> str:
        """Check that ``submission`` is well-formed relative to ``ground_truth``.

        The checks, in order: both arguments are DataFrames; both have exactly
        the columns 'file_id' and the label column; both carry the same
        ``file_id`` values; every predicted label occurs in the ground truth.

        Args:
            submission: Candidate predictions.
            ground_truth: Reference labels.

        Returns:
            The string "Submission is valid." when every check passes.

        Raises:
            InvalidSubmissionError: On the first check that fails.
        """
        if not isinstance(submission, pd.DataFrame):
            raise InvalidSubmissionError(
                "Submission must be a pandas DataFrame. Please provide a valid pandas DataFrame."
            )
        if not isinstance(ground_truth, pd.DataFrame):
            raise InvalidSubmissionError(
                "Ground truth must be a pandas DataFrame. Please provide a valid pandas DataFrame."
            )

        # Columns check: must be exactly ['file_id', self.value]
        sub_cols = list(submission.columns)
        gt_cols = list(ground_truth.columns)
        expected_cols = ["file_id", self.value]
        if set(sub_cols) != set(expected_cols):
            # Column names are not guaranteed to be strings, so sort and join
            # through str() to keep the error report itself from failing.
            extra = sorted(set(sub_cols) - set(expected_cols), key=str)
            missing = sorted(set(expected_cols) - set(sub_cols), key=str)
            parts = []
            if missing:
                parts.append(f"missing columns: {', '.join(map(str, missing))}")
            if extra:
                parts.append(f"unexpected columns: {', '.join(map(str, extra))}")
            raise InvalidSubmissionError(
                f"Submission must have exactly columns {expected_cols}; problems: {'; '.join(parts) if parts else 'unknown'}"
            )
        if set(gt_cols) != set(expected_cols):
            raise InvalidSubmissionError(
                f"Ground truth must have exactly columns {expected_cols}; got {gt_cols}"
            )

        # Sort the submission and ground truth by file_id
        submission = self._sort_by_file_id(submission)
        ground_truth = self._sort_by_file_id(ground_truth)

        # Check if file_id columns are identical (a differing row count counts as a mismatch)
        if self._file_ids_differ(submission, ground_truth):
            raise InvalidSubmissionError(
                "file_id values do not match between submission and ground truth. Please ensure the first column values are identical."
            )

        # Label validity: ensure all predicted labels are from the allowed set (derived from ground truth)
        allowed_labels = set(ground_truth[self.value].unique().tolist())
        pred_labels = set(submission[self.value].unique().tolist())
        # Labels may be integers or NaN; order and render them via str() so
        # the offending values are reported instead of raising TypeError.
        unseen = sorted(pred_labels - allowed_labels, key=str)
        if unseen:
            raise InvalidSubmissionError(
                f"Submission contains labels not present in ground truth set: {', '.join(map(str, unseen))}."
            )

        return "Submission is valid."