from typing import Any
import pandas as pd
from mledojo.metrics.base import CompetitionMetrics, InvalidSubmissionError


class BirdsongClassificationMetrics(CompetitionMetrics):
    """Macro-F1 metric for Birdsong single-label classification.

    Expected CSV schema (aligned with public/sample_submission.csv and private/test_answer.csv):
    - Columns: [id, label]
    - id: anonymized audio filename
    - label: class name (string)

    Rows of the two DataFrames are aligned on the ``id`` column by name, so
    the physical column order of either DataFrame does not matter.
    """

    def __init__(self, value: str = "label", higher_is_better: bool = True):
        super().__init__(higher_is_better)
        # Name of the column holding the class label that is scored.
        self.value = value

    @staticmethod
    def _sort_by_id(df: pd.DataFrame) -> pd.DataFrame:
        """Return *df* sorted by its ``id`` column, re-indexed 0..n-1."""
        # Sort by the column *name*, never by position: a CSV written as
        # [label, id] must still be aligned on ids.
        return df.sort_values(by="id").reset_index(drop=True)

    def evaluate(self, y_true: pd.DataFrame, y_pred: pd.DataFrame) -> float:
        """Compute the macro-averaged F1 score of *y_pred* against *y_true*.

        Args:
            y_true: Ground-truth DataFrame with columns ``id`` and ``self.value``.
            y_pred: Prediction DataFrame with columns ``id`` and ``self.value``.

        Returns:
            Macro-F1 as a Python ``float``. Labels are compared as strings and
            averaged over every label appearing in either input (scikit-learn's
            default label set when ``labels`` is not given).

        Raises:
            InvalidSubmissionError: If either input is not a DataFrame, lacks
                the required columns, has a different row count, or the ids
                do not match row-for-row after sorting.
        """
        from sklearn.metrics import f1_score

        if not isinstance(y_true, pd.DataFrame):
            raise InvalidSubmissionError("y_true must be a pandas DataFrame.")
        if not isinstance(y_pred, pd.DataFrame):
            raise InvalidSubmissionError("y_pred must be a pandas DataFrame.")

        # Validate schemas contain the expected columns
        for df, name in [(y_true, "y_true"), (y_pred, "y_pred")]:
            if "id" not in df.columns or self.value not in df.columns:
                raise InvalidSubmissionError(
                    f"{name} must contain columns ['id', '{self.value}']"
                )

        # Align both frames on the id column so labels are compared per-id.
        y_true_sorted = self._sort_by_id(y_true)
        y_pred_sorted = self._sort_by_id(y_pred)

        # Ensure ids match exactly
        if len(y_true_sorted) != len(y_pred_sorted):
            raise InvalidSubmissionError(
                f"Row count mismatch between y_true ({len(y_true_sorted)}) and y_pred ({len(y_pred_sorted)})."
            )
        if (y_true_sorted["id"].values != y_pred_sorted["id"].values).any():
            raise InvalidSubmissionError("Mismatch in id column between y_true and y_pred.")

        return float(
            f1_score(
                y_true_sorted[self.value].astype(str).values,
                y_pred_sorted[self.value].astype(str).values,
                average="macro",
            )
        )

    def validate_submission(self, submission: Any, ground_truth: Any) -> str:
        """Check that *submission* is well-formed and covers exactly the ground-truth ids.

        Args:
            submission: Candidate submission; must be a DataFrame whose columns
                are exactly ``id`` and ``self.value`` (in any order).
            ground_truth: Expected to be private/test_answer.csv (id, label) as
                a DataFrame; must contain an ``id`` column.

        Returns:
            The string ``"Submission is valid."`` when every check passes.

        Raises:
            InvalidSubmissionError: If either input is not a DataFrame, the
                submission has missing/extra columns, the ground truth has no
                ``id`` column, row counts or ids differ, or any label is null
                or blank.
        """
        if not isinstance(submission, pd.DataFrame):
            raise InvalidSubmissionError(
                "Submission must be a pandas DataFrame. Please provide a valid pandas DataFrame."
            )
        if not isinstance(ground_truth, pd.DataFrame):
            raise InvalidSubmissionError(
                "Ground truth must be a pandas DataFrame. Please provide a valid pandas DataFrame."
            )

        # Submission columns must be exactly {id, label}; order is irrelevant.
        required_cols = {"id", self.value}
        sub_cols = set(submission.columns)
        missing_cols = required_cols - sub_cols
        extra_cols = sub_cols - required_cols
        if missing_cols:
            raise InvalidSubmissionError(
                f"Missing required columns in submission: {', '.join(sorted(missing_cols))}."
            )
        if extra_cols:
            raise InvalidSubmissionError(
                f"Extra unexpected columns found in submission: {', '.join(sorted(extra_cols))}."
            )
        if "id" not in ground_truth.columns:
            raise InvalidSubmissionError("Ground truth must contain an 'id' column.")

        # Check row counts and ids (must match exactly with ground truth/test ids)
        sub_sorted = self._sort_by_id(submission)
        true_sorted = self._sort_by_id(ground_truth)

        if len(sub_sorted) != len(true_sorted):
            raise InvalidSubmissionError(
                f"Number of rows in submission ({len(sub_sorted)}) does not match ground truth ({len(true_sorted)})."
            )

        if (sub_sorted["id"].values != true_sorted["id"].values).any():
            raise InvalidSubmissionError(
                "Values in the 'id' column do not match between submission and ground truth."
            )

        # Non-null, non-empty labels
        if not submission[self.value].notna().all():
            raise InvalidSubmissionError("Submission contains null labels.")
        if (submission[self.value].astype(str).str.strip() == "").any():
            raise InvalidSubmissionError("Submission contains empty labels.")

        return "Submission is valid."
