from typing import Any, Set

import math
import pandas as pd

from mledojo.metrics.base import CompetitionMetrics, InvalidSubmissionError


class ArxivTaggingMetrics(CompetitionMetrics):
    """Micro-averaged F1 for multi-label arXiv tagging submissions.

    Expected DataFrame schema (aligned with this competition):
    - y_true (ground truth): columns [id, labels]
    - y_pred (submission): columns [id, labels]
    where `labels` is a space-separated string of tags. Empty string is allowed,
    and NaN labels are treated as the empty string.

    Rows are matched by id after sorting both frames, so the row order of the
    inputs does not affect validation or the score.
    """

    def __init__(self, value: str = "labels", higher_is_better: bool = True):
        """Store the label column name and the score direction.

        Args:
            value: Name of the column holding the space-separated tags.
            higher_is_better: Forwarded to the base class; F1 is a
                higher-is-better metric, hence the default of True.
        """
        super().__init__(higher_is_better)
        self.value = value

    def _ensure_dataframe(self, obj: Any, name: str) -> pd.DataFrame:
        """Return `obj` unchanged if it is a DataFrame.

        Raises:
            InvalidSubmissionError: If `obj` is not a pandas DataFrame; `name`
                is used to identify the offending input in the message.
        """
        if not isinstance(obj, pd.DataFrame):
            raise InvalidSubmissionError(f"{name} must be a pandas DataFrame.")
        return obj

    def _sorted_by_id(self, frame: pd.DataFrame, name: str) -> pd.DataFrame:
        """Return a copy of `frame` with NaN labels blanked and rows sorted by id.

        The caller's frame is never mutated. A stable sort is used so that rows
        sharing an id keep their relative input order.

        Raises:
            InvalidSubmissionError: If the id column mixes types that cannot be
                ordered against each other (pandas raises TypeError for these).
        """
        prepared = frame.copy()
        prepared[self.value] = prepared[self.value].fillna("")
        try:
            prepared = prepared.sort_values(by=prepared.columns[0], kind="mergesort")
        except TypeError as err:
            raise InvalidSubmissionError(
                f"{name} ids have mixed, non-comparable types and cannot be sorted."
            ) from err
        return prepared.reset_index(drop=True)

    def _validated_frames(self, submission: Any, ground_truth: Any):
        """Validate both inputs and return them aligned by id.

        Checks, in order: both inputs are DataFrames, both have exactly the
        columns ["id", <label column>] in that order, both have the same number
        of rows, and both contain the same ids once sorted.

        Returns:
            A `(ground_truth, submission)` pair of new DataFrames, each sorted
            by id with a fresh 0..n-1 index and NaN labels replaced by "".

        Raises:
            InvalidSubmissionError: If any of the checks above fails.
        """
        submission = self._ensure_dataframe(submission, "Submission")
        ground_truth = self._ensure_dataframe(ground_truth, "Ground truth")

        # Column checks: exactly two columns in required order
        required_cols = ["id", self.value]
        if submission.columns.tolist() != required_cols:
            raise InvalidSubmissionError(
                f"Submission must have columns {required_cols} in this exact order. Got: {submission.columns.tolist()}"
            )
        if ground_truth.columns.tolist() != required_cols:
            raise InvalidSubmissionError(
                f"Ground truth must have columns {required_cols} in this exact order. Got: {ground_truth.columns.tolist()}"
            )

        if len(submission) != len(ground_truth):
            raise InvalidSubmissionError(
                f"Row count mismatch: submission={len(submission)} ground_truth={len(ground_truth)}"
            )

        submission_sorted = self._sorted_by_id(submission, "Submission")
        ground_sorted = self._sorted_by_id(ground_truth, "Ground truth")

        # Compare as plain Python lists: this avoids NumPy's dtype-dependent
        # elementwise comparison semantics and compares large ints exactly.
        # Lengths are already known to be equal at this point.
        if submission_sorted.iloc[:, 0].tolist() != ground_sorted.iloc[:, 0].tolist():
            raise InvalidSubmissionError("Submission ids do not match ground truth ids or order.")

        return ground_sorted, submission_sorted

    def validate_submission(self, submission: Any, ground_truth: Any) -> str:
        """Check that `submission` is structurally compatible with `ground_truth`.

        Neither input is modified.

        Returns:
            The string "Submission is valid." when every check passes.

        Raises:
            InvalidSubmissionError: If either input is not a DataFrame, has the
                wrong columns, the row counts differ, or the id sets differ.
        """
        self._validated_frames(submission, ground_truth)
        return "Submission is valid."

    def evaluate(self, y_true: Any, y_pred: Any) -> float:
        """Compute the micro-averaged F1 of `y_pred` against `y_true`.

        Each row's labels are split on whitespace into a set of tags; true
        positives, false positives and false negatives are summed over all rows
        before precision and recall are computed.

        Returns:
            The F1 score in [0.0, 1.0]. Returns 0.0 when precision and recall
            are both zero, which includes the case where neither frame contains
            any tag at all.

        Raises:
            InvalidSubmissionError: If the inputs fail validation.
            ValueError: If the computed score is outside [0, 1] or not finite.
        """
        truth, pred = self._validated_frames(y_pred, y_true)

        true_pos = 0
        false_pos = 0
        false_neg = 0

        # Both frames are sorted by id with matching ids, so rows pair up 1:1.
        for truth_raw, pred_raw in zip(truth[self.value], pred[self.value]):
            # str.split() with no argument never yields empty tokens.
            truth_labels = set(str(truth_raw).split())
            pred_labels = set(str(pred_raw).split())

            shared = len(truth_labels & pred_labels)
            true_pos += shared
            false_pos += len(pred_labels) - shared
            false_neg += len(truth_labels) - shared

        precision = true_pos / (true_pos + false_pos) if (true_pos + false_pos) > 0 else 0.0
        recall = true_pos / (true_pos + false_neg) if (true_pos + false_neg) > 0 else 0.0
        if precision + recall == 0.0:
            return 0.0
        f1 = 2.0 * precision * recall / (precision + recall)
        # Defensive sanity check on the final score.
        if not (0.0 <= f1 <= 1.0) or math.isnan(f1) or math.isinf(f1):
            raise ValueError("Computed F1 is invalid")
        return float(f1)
