from typing import Any
import pandas as pd
from sklearn.metrics import f1_score

from mledojo.metrics.base import CompetitionMetrics, InvalidSubmissionError


class BirdSpeciesClassificationMetrics(CompetitionMetrics):
    """
    Metrics for 200-bird-species classification.
    Uses macro-averaged F1 score, where higher is better.

    Both the ground truth and the submission are expected to be DataFrames
    holding an identifier column and a label column. The identifier column
    is taken to be the *first column of the ground truth*; the label column
    name is configurable through ``value``.
    """

    def __init__(self, value: str = "label", higher_is_better: bool = True):
        """
        Args:
            value: Name of the column that holds the class label.
            higher_is_better: Forwarded to the base class.
        """
        super().__init__(higher_is_better)
        self.value = value

    @staticmethod
    def _sort_by_id(frame: pd.DataFrame, id_col: Any) -> pd.DataFrame:
        """Return ``frame`` ordered by the string form of ``id_col``, re-indexed from 0."""
        # Sorting on the string form keeps the order identical across frames
        # whose id dtypes differ (e.g. int64 in one, object/str in the other);
        # a native sort would be numeric in one and lexicographic in the other.
        # A stable sort keeps the result deterministic when ids repeat.
        return frame.sort_values(
            by=id_col,
            key=lambda ids: ids.astype(str),
            kind="stable",
        ).reset_index(drop=True)

    def evaluate(self, y_true: pd.DataFrame, y_pred: pd.DataFrame) -> float:
        """
        Compute the macro-averaged F1 score of ``y_pred`` against ``y_true``.

        Args:
            y_true: Ground-truth frame; its first column is used as the id.
            y_pred: Prediction frame containing the label column.

        Returns:
            The macro F1 score as a Python float.

        Raises:
            InvalidSubmissionError: If either input is not a DataFrame or
                lacks the label column.
        """
        # Expect dataframes with columns: id, label
        if not isinstance(y_true, pd.DataFrame) or not isinstance(y_pred, pd.DataFrame):
            raise InvalidSubmissionError("Both y_true and y_pred must be pandas DataFrames.")
        if self.value not in y_true.columns or self.value not in y_pred.columns:
            raise InvalidSubmissionError(f"Missing required column '{self.value}' in inputs.")

        # Align rows on the ground truth's id column by *name*: the prediction
        # frame may list its columns in a different order. Only fall back to
        # the prediction's own first column when that name is absent.
        id_col = y_true.columns[0]
        pred_id_col = id_col if id_col in y_pred.columns else y_pred.columns[0]
        y_true_sorted = self._sort_by_id(y_true, id_col)
        y_pred_sorted = self._sort_by_id(y_pred, pred_id_col)

        # Compute macro F1
        return float(
            f1_score(
                y_true_sorted[self.value],
                y_pred_sorted[self.value],
                average="macro",
            )
        )

    def validate_submission(self, submission: Any, ground_truth: Any) -> str:
        """
        Check that ``submission`` is structurally compatible with ``ground_truth``.

        The checks, in order: both are DataFrames, the column sets are equal,
        the row counts are equal, the ids (compared as strings) are the same,
        and every submitted label occurs among the ground-truth labels.

        Args:
            submission: Candidate submission.
            ground_truth: Reference frame; its first column is used as the id.

        Returns:
            The string ``"Submission is valid."`` when all checks pass.

        Raises:
            InvalidSubmissionError: On the first failed check.
        """
        if not isinstance(submission, pd.DataFrame):
            raise InvalidSubmissionError(
                "Submission must be a pandas DataFrame. Please provide a valid pandas DataFrame."
            )
        if not isinstance(ground_truth, pd.DataFrame):
            raise InvalidSubmissionError(
                "Ground truth must be a pandas DataFrame. Please provide a valid pandas DataFrame."
            )

        # Columns must match expected set exactly
        expected_cols = set(ground_truth.columns)
        sub_cols = set(submission.columns)
        if sub_cols != expected_cols:
            missing = expected_cols - sub_cols
            extra = sub_cols - expected_cols
            msgs = []
            if missing:
                msgs.append(f"Missing required columns in submission: {', '.join(sorted(missing))}.")
            if extra:
                msgs.append(f"Extra unexpected columns found in submission: {', '.join(sorted(extra))}.")
            raise InvalidSubmissionError(" ".join(msgs))

        if len(submission) != len(ground_truth):
            raise InvalidSubmissionError(
                f"Number of rows in submission ({len(submission)}) does not match ground truth ({len(ground_truth)})."
            )

        # The id column is identified by the ground truth's first column name.
        # The column-set check above guarantees the submission has it too,
        # regardless of the order in which the submission lists its columns.
        id_col = ground_truth.columns[0]
        submission_sorted = self._sort_by_id(submission, id_col)
        ground_sorted = self._sort_by_id(ground_truth, id_col)

        # Ids must match exactly once both sides are viewed as strings
        sub_ids = submission_sorted[id_col].astype(str).tolist()
        gt_ids = ground_sorted[id_col].astype(str).tolist()
        if sub_ids != gt_ids:
            raise InvalidSubmissionError(
                "First column values (ids) do not match between submission and ground truth."
            )

        # Validate labels are from allowed set (derived from ground truth labels)
        if self.value in ground_sorted.columns:
            allowed = set(ground_sorted[self.value].unique())
            invalid = [x for x in submission_sorted[self.value].tolist() if x not in allowed]
            if invalid:
                unique_invalid = set(invalid)
                try:
                    preview = sorted(unique_invalid)
                except TypeError:
                    # Mixed label types (e.g. NaN alongside strings) cannot be
                    # ordered natively; order by string form so the intended
                    # error is raised instead of a TypeError.
                    preview = sorted(unique_invalid, key=str)
                raise InvalidSubmissionError(
                    f"Submission contains invalid labels not present in ground truth: {preview[:5]}"
                )

        return "Submission is valid."
