import math
from typing import Any, List, Optional, Sequence

import pandas as pd

from mledojo.metrics.base import CompetitionMetrics, InvalidSubmissionError

# Closed set of class names a submission's label column may contain.
# List order is the canonical class order used when iterating over classes.
LABELS: List[str] = [
    "NotHate",
    "Racist",
    "Sexist",
    "Homophobe",
    "Religion",
    "OtherHate",
]
# Set view of LABELS, for constant-time membership checks.
LABEL_SET = {*LABELS}


def _macro_f1(y_true: list[str], y_pred: list[str]) -> float:
    assert len(y_true) == len(y_pred)

    classes = LABELS
    tp = {c: 0 for c in classes}
    fp = {c: 0 for c in classes}
    fn = {c: 0 for c in classes}

    for g, p in zip(y_true, y_pred):
        if g == p:
            tp[g] += 1
        else:
            fp[p] += 1
            fn[g] += 1

    f1s = []
    for c in classes:
        prec = tp[c] / (tp[c] + fp[c]) if (tp[c] + fp[c]) > 0 else 0.0
        rec = tp[c] / (tp[c] + fn[c]) if (tp[c] + fn[c]) > 0 else 0.0
        if prec + rec == 0:
            f1 = 0.0
        else:
            f1 = 2 * prec * rec / (prec + rec)
        if not math.isfinite(f1):
            f1 = 0.0
        f1s.append(f1)

    return sum(f1s) / len(f1s) if f1s else 0.0


class MultimodalHateSpeechMetrics(CompetitionMetrics):
    """Macro-F1 for Multimodal Hate Speech Classification.

    Expected format:
    - y_true: pandas.DataFrame with columns ['id', <value>] (from test_answer.csv)
    - y_pred: pandas.DataFrame with columns ['id', <value>] (like sample_submission.csv)

    ``<value>`` is the label-column name passed to the constructor
    (``"label"`` by default). Rows are matched by ``id``, so the row order of
    the two frames does not matter.
    """

    def __init__(self, value: str = "label", higher_is_better: bool = True):
        """Create the metric.

        Args:
            value: Name of the label column in both submission and ground truth.
            higher_is_better: Forwarded to ``CompetitionMetrics``. Macro-F1 is
                a score where higher is better.
        """
        super().__init__(higher_is_better)
        self.value = value

    @staticmethod
    def _sort_by_id(frame: pd.DataFrame) -> pd.DataFrame:
        """Return ``frame`` sorted by ``id`` with a fresh 0..n-1 index."""
        return frame.sort_values(by="id").reset_index(drop=True)

    def validate_submission(self, submission: Any, ground_truth: Any) -> tuple[bool, str]:
        """Check that ``submission`` can be scored against ``ground_truth``.

        Returns:
            ``(True, "Submission is valid.")`` on success, otherwise
            ``(False, reason)``. Problems are reported through the returned
            message rather than raised.
        """
        # Type checks
        if not isinstance(submission, pd.DataFrame):
            return False, "Submission must be a pandas DataFrame."
        if not isinstance(ground_truth, pd.DataFrame):
            return False, "Ground truth must be a pandas DataFrame."

        # Use the configured label column (not a hard-coded 'label') so that
        # validation agrees with the column that evaluate() actually reads.
        label_col = self.value

        # The submission must have exactly these columns, in this order.
        required_cols = ["id", label_col]
        if submission.columns.tolist() != required_cols:
            return (
                False,
                f"Submission must have columns exactly {required_cols}, got {submission.columns.tolist()}.",
            )

        if "id" not in ground_truth.columns:
            return False, "Ground truth must contain an 'id' column."
        if label_col not in ground_truth.columns:
            return False, f"Ground truth must contain a '{label_col}' column."

        # Length check
        if len(submission) != len(ground_truth):
            return (
                False,
                f"Row count mismatch: submission={len(submission)} vs ground_truth={len(ground_truth)}.",
            )

        # Sort by id for alignment, then require identical id sequences.
        sub_sorted = self._sort_by_id(submission)
        gt_sorted = self._sort_by_id(ground_truth)
        if not sub_sorted["id"].equals(gt_sorted["id"]):
            return False, "Submission ids do not match ground truth ids."

        # Validate labels are within the allowed set.
        invalid = {lab for lab in sub_sorted[label_col].tolist() if lab not in LABEL_SET}
        if invalid:
            # key=str: the invalid values may mix types (e.g. str and NaN from
            # empty CSV cells), which a plain sorted() cannot compare.
            return False, f"Invalid labels found in submission: {sorted(invalid, key=str)}."

        return True, "Submission is valid."

    def evaluate(self, y_true: Any, y_pred: Any) -> float:
        """Score ``y_pred`` against ``y_true`` with Macro-F1.

        Raises:
            InvalidSubmissionError: if ``validate_submission`` rejects the inputs.
        """
        ok, msg = self.validate_submission(y_pred, y_true)
        if not ok:
            raise InvalidSubmissionError(msg)

        # Sorting both frames by id aligns row i of one with row i of the
        # other; validate_submission already confirmed the sorted ids match.
        gt_labels = self._sort_by_id(y_true)[self.value].tolist()
        pd_labels = self._sort_by_id(y_pred)[self.value].tolist()

        return _macro_f1(gt_labels, pd_labels)
