from typing import Any
import math
import pandas as pd

from mledojo.metrics.base import CompetitionMetrics, InvalidSubmissionError

# The ten UrbanSound8K class names; any label outside this list is rejected
# by the metric, and the macro average is taken over exactly these entries.
LABELS = [
    "air_conditioner",
    "car_horn",
    "children_playing",
    "dog_bark",
    "drilling",
    "engine_idling",
    "gun_shot",
    "jackhammer",
    "siren",
    "street_music",
]


def _macro_f1(y_true: list[str], y_pred: list[str]) -> float:
    classes = LABELS
    tp = {c: 0 for c in classes}
    fp = {c: 0 for c in classes}
    fn = {c: 0 for c in classes}

    for t, p in zip(y_true, y_pred):
        if p == t:
            tp[t] += 1
        else:
            if p in fp:
                fp[p] += 1
            if t in fn:
                fn[t] += 1

    f1s = []
    for c in classes:
        prec_den = tp[c] + fp[c]
        rec_den = tp[c] + fn[c]
        precision = tp[c] / prec_den if prec_den > 0 else 0.0
        recall = tp[c] / rec_den if rec_den > 0 else 0.0
        if precision + recall == 0:
            f1 = 0.0
        else:
            f1 = 2 * precision * recall / (precision + recall)
        # Guard against numerical issues
        if not (math.isfinite(f1) and 0.0 <= f1 <= 1.0):
            f1 = 0.0
        f1s.append(f1)

    return float(sum(f1s) / len(f1s))


class UrbanSound8KMetrics(CompetitionMetrics):
    """
    Competition metric for UrbanSound8K classification.

    Scores label submissions with the macro-averaged F1 over the entries of
    ``LABELS``. Inputs are pandas DataFrames: the first column is treated as
    the row identifier used for alignment, and the column named by ``value``
    holds the class label.
    """

    def __init__(self, value: str = "label", higher_is_better: bool = True):
        """Store the label column name and forward the score direction.

        Args:
            value: Name of the column holding predicted/true labels.
            higher_is_better: Passed through to ``CompetitionMetrics``.
        """
        super().__init__(higher_is_better)
        self.value = value  # column that contains predicted/true labels

    @staticmethod
    def _sort_by_first_column(frame: pd.DataFrame, name: str) -> pd.DataFrame:
        """Return ``frame`` sorted by its first column with a fresh index.

        Raises:
            InvalidSubmissionError: If ``frame`` has no columns, since there
                is then no identifier column to sort on.
        """
        if frame.shape[1] == 0:
            raise InvalidSubmissionError(f"{name} has no columns.")
        return frame.sort_values(by=frame.columns[0]).reset_index(drop=True)

    def evaluate(self, y_true: Any, y_pred: Any) -> float:
        """Compute the macro F1 of ``y_pred`` against ``y_true``.

        Both frames are sorted by their first column, checked for matching
        identifiers and valid labels, and then scored.

        Args:
            y_true: Ground-truth DataFrame.
            y_pred: Submission DataFrame.

        Returns:
            The macro-averaged F1 score.

        Raises:
            InvalidSubmissionError: If either input is not a DataFrame, has
                no columns, lacks the label column, differs in row count,
                has non-matching identifiers, or contains a label outside
                ``LABELS``.
        """
        if not isinstance(y_true, pd.DataFrame):
            raise InvalidSubmissionError("y_true must be a pandas DataFrame.")
        if not isinstance(y_pred, pd.DataFrame):
            raise InvalidSubmissionError("y_pred must be a pandas DataFrame.")

        # Sort both dataframes by first column before calculating score
        y_true_sorted = self._sort_by_first_column(y_true, "y_true")
        y_pred_sorted = self._sort_by_first_column(y_pred, "y_pred")

        # Columns must be present
        if self.value not in y_true_sorted.columns:
            raise InvalidSubmissionError(f"Column '{self.value}' missing in y_true.")
        if self.value not in y_pred_sorted.columns:
            raise InvalidSubmissionError(f"Column '{self.value}' missing in y_pred.")

        # The elementwise ID comparison below is only meaningful for
        # equal-length arrays, so reject a row-count mismatch explicitly.
        if len(y_true_sorted) != len(y_pred_sorted):
            raise InvalidSubmissionError(
                f"Number of rows in y_pred ({len(y_pred_sorted)}) does not match "
                f"y_true ({len(y_true_sorted)})."
            )

        # The first columns (ids) must match exactly
        if (y_true_sorted.iloc[:, 0].values != y_pred_sorted.iloc[:, 0].values).any():
            raise InvalidSubmissionError("IDs in y_true and y_pred do not align.")

        true_labels = y_true_sorted[self.value].astype(str).tolist()
        pred_labels = y_pred_sorted[self.value].astype(str).tolist()

        # Validate label values; the first offending label is reported.
        valid_labels = set(LABELS)
        for source, labels in (("y_true", true_labels), ("y_pred", pred_labels)):
            for lbl in labels:
                if lbl not in valid_labels:
                    raise InvalidSubmissionError(f"Invalid label in {source}: {lbl}")

        return _macro_f1(true_labels, pred_labels)

    def validate_submission(self, submission: Any, ground_truth: Any) -> str:
        """Check that ``submission`` is structurally compatible with ``ground_truth``.

        Args:
            submission: Candidate submission DataFrame.
            ground_truth: Reference DataFrame.

        Returns:
            The string ``"Submission is valid."`` when every check passes.

        Raises:
            InvalidSubmissionError: If either input is not a DataFrame, the
                row counts differ, either frame has no columns, the first
                column values differ after sorting, the column sets differ,
                the label column is absent, or a label is outside ``LABELS``.
        """
        if not isinstance(submission, pd.DataFrame):
            raise InvalidSubmissionError(
                "Submission must be a pandas DataFrame. Please provide a valid pandas DataFrame."
            )
        if not isinstance(ground_truth, pd.DataFrame):
            raise InvalidSubmissionError(
                "Ground truth must be a pandas DataFrame. Please provide a valid pandas DataFrame."
            )

        if len(submission) != len(ground_truth):
            raise InvalidSubmissionError(
                f"Number of rows in submission ({len(submission)}) does not match ground truth ({len(ground_truth)})."
            )

        # Sort by id (first column) for alignment
        submission_sorted = self._sort_by_first_column(submission, "Submission")
        ground_truth_sorted = self._sort_by_first_column(ground_truth, "Ground truth")

        # Check if first columns are identical
        if (
            submission_sorted.iloc[:, 0].values != ground_truth_sorted.iloc[:, 0].values
        ).any():
            raise InvalidSubmissionError(
                "First column values (ids) do not match between submission and ground truth."
            )

        sub_cols = set(submission.columns)
        true_cols = set(ground_truth.columns)

        missing_cols = true_cols - sub_cols
        extra_cols = sub_cols - true_cols

        if missing_cols:
            raise InvalidSubmissionError(
                f"Missing required columns in submission: {', '.join(sorted(missing_cols))}."
            )
        if extra_cols:
            raise InvalidSubmissionError(
                f"Extra unexpected columns found in submission: {', '.join(sorted(extra_cols))}."
            )

        # Validate label values
        if self.value not in submission.columns:
            raise InvalidSubmissionError(f"Submission must contain a '{self.value}' column.")
        valid_labels = set(LABELS)
        invalid = {
            lbl
            for lbl in submission[self.value].astype(str).tolist()
            if lbl not in valid_labels
        }
        if invalid:
            raise InvalidSubmissionError(
                f"Submission contains invalid label values: {sorted(invalid)}."
            )

        return "Submission is valid."
