from typing import Any
import pandas as pd
from sklearn.metrics import f1_score

from mledojo.metrics.base import CompetitionMetrics, InvalidSubmissionError


class UASpeechKeywordRecognitionMetrics(CompetitionMetrics):
    """
    Metric class for Noise-Reduced UASPEECH keyword recognition.
    Primary metric: macro-averaged F1 over the set of true classes present in the test set.

    Both the ground truth and the submission are DataFrames with exactly the
    columns ``["id", value]``. Ids are compared by their string form, so an
    integer id column in one frame matches a string id column in the other.
    """

    def __init__(self, value: str = "label", higher_is_better: bool = True):
        """
        Args:
            value: Name of the label column expected in both DataFrames.
            higher_is_better: Forwarded to ``CompetitionMetrics.__init__``.
        """
        super().__init__(higher_is_better)
        self.value = value

    def evaluate(self, y_true: pd.DataFrame, y_pred: pd.DataFrame) -> float:
        """
        Compute the macro-averaged F1 score of ``y_pred`` against ``y_true``.

        Args:
            y_true: Ground truth with columns ``["id", self.value]``.
            y_pred: Submission with columns ``["id", self.value]`` and the same
                id set as ``y_true`` (row order is irrelevant).

        Returns:
            Macro F1 averaged over the distinct labels present in ``y_true``
            (labels compared as strings). Predicted labels that never occur in
            ``y_true`` do not add a class to the average; a true class with no
            correct predictions contributes an F1 of 0.

        Raises:
            InvalidSubmissionError: If ``validate_submission`` rejects the input.
        """
        # Validate first; will raise InvalidSubmissionError on problems
        self.validate_submission(y_pred, y_true)

        # Pair rows by the *string* id, the same normalisation that
        # validate_submission uses. Sorting each frame by its raw id column
        # would misalign rows whenever the two id columns have different
        # dtypes (e.g. ints 1, 2, 10 vs strings "1", "10", "2").
        true_ids = y_true["id"].astype(str).tolist()
        true_labels = y_true[self.value].astype(str).tolist()
        pred_by_id = dict(
            zip(y_pred["id"].astype(str), y_pred[self.value].astype(str))
        )
        # Lookup cannot miss: validation guarantees identical, duplicate-free id sets.
        pred_labels = [pred_by_id[row_id] for row_id in true_ids]

        # Restrict the averaging set to the true classes only, as per task definition
        label_set = sorted(set(true_labels))
        score = f1_score(
            true_labels,
            pred_labels,
            average="macro",
            labels=label_set,
            zero_division=0,
        )
        return float(score)

    def validate_submission(self, submission: Any, ground_truth: Any) -> str:
        """
        Validate that ``submission`` can be scored against ``ground_truth``.

        Checks, in order:
        - Both are pandas DataFrames.
        - Both have exactly the columns ``["id", self.value]`` (in that order).
        - Row counts match, neither frame has duplicate ids, and the id sets
          are identical after converting ids to strings (order-insensitive).
        - Every submission label is present (not NaN/None) and is non-blank
          once converted to a string.

        Returns:
            The message ``"Submission is valid."``.

        Raises:
            InvalidSubmissionError: On the first failed check.
        """
        if not isinstance(submission, pd.DataFrame):
            raise InvalidSubmissionError("Submission must be a pandas DataFrame.")
        if not isinstance(ground_truth, pd.DataFrame):
            raise InvalidSubmissionError("Ground truth must be a pandas DataFrame.")

        expected_cols = ["id", self.value]
        if list(ground_truth.columns) != expected_cols:
            raise InvalidSubmissionError(
                f"Ground truth must have columns {expected_cols}, got {list(ground_truth.columns)}."
            )
        if list(submission.columns) != expected_cols:
            raise InvalidSubmissionError(
                f"Submission must have columns {expected_cols}, got {list(submission.columns)}."
            )

        # Compare ids as strings so int vs str id columns are treated alike.
        sub_ids = submission["id"].astype(str).tolist()
        true_ids = ground_truth["id"].astype(str).tolist()

        if len(sub_ids) != len(true_ids):
            raise InvalidSubmissionError(
                f"Row count mismatch between submission ({len(sub_ids)}) and ground truth ({len(true_ids)})."
            )

        if len(sub_ids) != len(set(sub_ids)):
            raise InvalidSubmissionError("Duplicate ids found in submission.")
        if len(true_ids) != len(set(true_ids)):
            raise InvalidSubmissionError("Duplicate ids found in ground truth.")

        if set(sub_ids) != set(true_ids):
            raise InvalidSubmissionError("Submission ids do not match ground truth ids.")

        label_column = submission[self.value]
        # astype(str) would turn NaN/None into the non-blank strings "nan"/"None",
        # letting missing predictions slip past the blank-label check below.
        if label_column.isna().any():
            raise InvalidSubmissionError("Submission contains missing labels.")

        # Basic label validation: non-empty strings
        labels = label_column.astype(str).tolist()
        if any(lbl.strip() == "" for lbl in labels):
            raise InvalidSubmissionError("Submission contains empty labels.")

        return "Submission is valid."
