from typing import Any
import pandas as pd
from sklearn.metrics import f1_score

from mledojo.metrics.base import CompetitionMetrics, InvalidSubmissionError


class SpeechCommandsClassificationMetrics(CompetitionMetrics):
    """
    Metric class for Speech Commands classification using macro-averaged F1 score.

    Expected CSV schema (aligned with public/sample_submission.csv and private/test_answer.csv):
    - Two columns: [id, label]
    - The first column is treated as the id column; 'id' values must match exactly
      between prediction and ground truth, while row order is ignored.
    """

    def __init__(self, value: str = "label", higher_is_better: bool = True):
        """
        Args:
            value: Name of the column holding class labels in both frames.
            higher_is_better: Forwarded to ``CompetitionMetrics``; defaults to
                ``True`` because a larger macro F1 is better.
        """
        super().__init__(higher_is_better)
        self.value = value

    @staticmethod
    def _sorted_by_first_column(df: pd.DataFrame, name: str) -> pd.DataFrame:
        """
        Return ``df`` sorted by its first (id) column, with a fresh 0..n-1 index.

        Args:
            df: Frame whose first column holds the ids.
            name: Human-readable name of ``df`` used in error messages.

        Raises:
            InvalidSubmissionError: If ``df`` has no columns, so no id column exists.
        """
        if len(df.columns) == 0:
            raise InvalidSubmissionError(f"{name} has no columns; expected [id, label].")
        return df.sort_values(by=df.columns[0]).reset_index(drop=True)

    def evaluate(self, y_true: pd.DataFrame, y_pred: pd.DataFrame) -> float:
        """
        Compute macro F1 between ground-truth labels and predictions.

        Both frames are aligned by sorting on their first column (the id); the
        ``self.value`` columns of the aligned frames are then scored.

        Args:
            y_true: Ground-truth DataFrame with columns [id, label].
            y_pred: Prediction DataFrame with columns [id, label].

        Returns:
            The macro-averaged F1 score as a Python float.

        Raises:
            InvalidSubmissionError: If either input is not a DataFrame, the row
                counts differ, the label column is missing, or the sorted ids
                differ.
        """
        if not isinstance(y_true, pd.DataFrame) or not isinstance(y_pred, pd.DataFrame):
            raise InvalidSubmissionError("y_true and y_pred must be pandas DataFrames.")

        # Check lengths before comparing id arrays: element-wise comparison of
        # arrays with different lengths either raises a numpy broadcasting error
        # or silently broadcasts a length-1 array, neither of which is a clean
        # validation failure.
        if len(y_true) != len(y_pred):
            raise InvalidSubmissionError(
                f"Number of rows in y_pred ({len(y_pred)}) does not match y_true ({len(y_true)})."
            )

        for frame, frame_name in ((y_true, "y_true"), (y_pred, "y_pred")):
            if self.value not in frame.columns:
                raise InvalidSubmissionError(
                    f"{frame_name} is missing the '{self.value}' column."
                )

        # Sort and align by id (first column)
        y_true_sorted = self._sorted_by_first_column(y_true, "y_true")
        y_pred_sorted = self._sorted_by_first_column(y_pred, "y_pred")

        # Validate ID equality after sorting
        true_ids = y_true_sorted.iloc[:, 0].to_numpy()
        pred_ids = y_pred_sorted.iloc[:, 0].to_numpy()
        if (true_ids != pred_ids).any():
            raise InvalidSubmissionError("IDs do not match between y_true and y_pred after sorting.")

        # Compute macro F1 on the label/value column
        return float(
            f1_score(y_true_sorted[self.value], y_pred_sorted[self.value], average="macro")
        )

    def validate_submission(self, submission: Any, ground_truth: Any) -> str:
        """
        Validate the submission format against the ground truth file.

        Requirements:
          - Both are pandas DataFrames
          - Row counts match exactly
          - Column sets match exactly (should be [id, label])
          - The submission's first column is the ground truth's id column
          - First column values (ids) match exactly (order-agnostic by sorting)

        Returns:
            The string "Submission is valid." when every check passes.

        Raises:
            InvalidSubmissionError: On the first requirement that is violated.
        """
        if not isinstance(submission, pd.DataFrame):
            raise InvalidSubmissionError(
                "Submission must be a pandas DataFrame. Please provide a valid pandas DataFrame."
            )
        if not isinstance(ground_truth, pd.DataFrame):
            raise InvalidSubmissionError(
                "Ground truth must be a pandas DataFrame. Please provide a valid pandas DataFrame."
            )

        if len(submission) != len(ground_truth):
            raise InvalidSubmissionError(
                f"Number of rows in submission ({len(submission)}) does not match ground truth ({len(ground_truth)})."
            )

        # Check the column set before the ids: a wrong schema is the root cause
        # of most id mismatches and yields a more actionable message.
        sub_cols = set(submission.columns)
        true_cols = set(ground_truth.columns)

        missing_cols = true_cols - sub_cols
        extra_cols = sub_cols - true_cols

        if missing_cols:
            raise InvalidSubmissionError(
                f"Missing required columns in submission: {', '.join(sorted(missing_cols))}."
            )
        if extra_cols:
            raise InvalidSubmissionError(
                f"Extra unexpected columns found in submission: {', '.join(sorted(extra_cols))}."
            )

        # Sort by first column (ids) to align; also guards against frames with
        # no columns at all.
        submission_sorted = self._sorted_by_first_column(submission, "Submission")
        ground_truth_sorted = self._sorted_by_first_column(ground_truth, "Ground truth")

        # The set comparison above ignores column order, but ids are located by
        # position, so make sure the submission's first column is the id column.
        if submission.columns[0] != ground_truth.columns[0]:
            raise InvalidSubmissionError(
                f"First column of submission ('{submission.columns[0]}') must be the id column "
                f"'{ground_truth.columns[0]}'."
            )

        # Check first column values match exactly
        if (
            submission_sorted.iloc[:, 0].to_numpy()
            != ground_truth_sorted.iloc[:, 0].to_numpy()
        ).any():
            raise InvalidSubmissionError(
                "First column values (ids) do not match between submission and ground truth."
            )

        return "Submission is valid."
