from typing import Any, List, Tuple

import numpy as np
import pandas as pd

from mledojo.metrics.base import CompetitionMetrics, InvalidSubmissionError


# Lower bound used when clipping probabilities so that log() is never taken at zero.
EPS = 1e-15


class RedditSubredditClassificationMetrics(CompetitionMetrics):
    """
    Multiclass log-loss metric for the Reddit Subreddit Classification task.

    Expects:
    - y_true: DataFrame with an id column and a label column named by
      ``value`` (default 'subreddit'), e.g. private/test_answer.csv
    - y_pred: DataFrame with columns ['id', <subreddit_1>, ..., <subreddit_K>]
      (e.g., public/sample_submission.csv); column order does not matter.

    Score is the average negative log-probability of the true class.
    Lower is better (higher_is_better=False).
    """

    def __init__(self, value: str = "subreddit", higher_is_better: bool = False):
        """
        Args:
            value: Name of the ground-truth column holding the class label.
            higher_is_better: Forwarded to the base class; log loss is a
                lower-is-better metric, hence the default of False.
        """
        super().__init__(higher_is_better=higher_is_better)
        self.value = value

    # -------------- helpers --------------
    @staticmethod
    def _classes_from_truth(y_true: pd.DataFrame, column: str = "subreddit") -> List[str]:
        """Return the sorted distinct labels found in ``y_true[column]``."""
        return sorted(y_true[column].unique().tolist())

    @staticmethod
    def _label_cols_from_submission(y_pred: pd.DataFrame) -> List[str]:
        """Return every submission column except 'id', in file order."""
        return [c for c in y_pred.columns if c != "id"]

    @staticmethod
    def _ensure_dataframe(obj: Any, name: str) -> pd.DataFrame:
        """Return ``obj`` unchanged, raising InvalidSubmissionError if it is not a DataFrame."""
        if not isinstance(obj, pd.DataFrame):
            raise InvalidSubmissionError(f"{name} must be a pandas DataFrame.")
        return obj

    @staticmethod
    def _id_column(frame: pd.DataFrame) -> Any:
        """Return the label of the id column: 'id' when present, else the first column."""
        return "id" if "id" in frame.columns else frame.columns[0]

    # -------------- public API --------------
    def evaluate(self, y_true: pd.DataFrame, y_pred: pd.DataFrame) -> float:
        """
        Compute the multiclass log loss of ``y_pred`` against ``y_true``.

        Args:
            y_true: Ground-truth frame with an id column and ``self.value``.
            y_pred: Submission frame with 'id' plus one column per class.

        Returns:
            Mean negative log-probability assigned to the true class.

        Raises:
            InvalidSubmissionError: If validation fails or the resulting
                loss is not finite.
        """
        # Validate first to make sure the format is correct and aligned
        self.validate_submission(y_pred, y_true)

        # Sort both frames by their id column (not by position) so that row i
        # of the predictions corresponds to row i of the truth.
        y_true = y_true.sort_values(by=self._id_column(y_true)).reset_index(drop=True)
        y_pred = y_pred.sort_values(by="id").reset_index(drop=True)

        label_cols = self._label_cols_from_submission(y_pred)
        label_to_idx = {label: i for i, label in enumerate(label_cols)}

        # Map ground-truth labels to column indices
        try:
            y_true_idx = y_true[self.value].map(label_to_idx).to_numpy()
        except KeyError as exc:
            raise InvalidSubmissionError(
                f"Ground truth must contain a '{self.value}' column.") from exc

        P = y_pred[label_cols].to_numpy(dtype=float)
        # Numerical safety: clip away exact zeros, then renormalise each row
        P = np.clip(P, EPS, 1.0)
        P = P / P.sum(axis=1, keepdims=True)

        true_probs = P[np.arange(P.shape[0]), y_true_idx]
        # Renormalisation can push a clipped value slightly below EPS again
        true_probs = np.clip(true_probs, EPS, 1.0)
        loss = -float(np.mean(np.log(true_probs)))

        if not np.isfinite(loss):
            raise InvalidSubmissionError("Computed log loss is not finite.")
        return loss

    def validate_submission(self, submission: Any, ground_truth: Any) -> Tuple[bool, str]:
        """
        Check that ``submission`` is a well-formed prediction for ``ground_truth``.

        Checks, in order: both are DataFrames; required columns exist; row
        counts match; submission ids are unique; ids match the ground truth;
        the class columns equal the set of ground-truth labels; all
        probabilities are numeric, finite, non-negative and sum to 1 per row
        (tolerance 1e-3).

        Returns:
            ``(True, "Submission is valid.")`` when every check passes.

        Raises:
            InvalidSubmissionError: On the first check that fails.
        """
        y_pred = self._ensure_dataframe(submission, "Submission")
        y_true = self._ensure_dataframe(ground_truth, "Ground truth")

        # Basic columns presence
        if "id" not in y_pred.columns:
            raise InvalidSubmissionError("Submission must include an 'id' column.")
        if self.value not in y_true.columns:
            raise InvalidSubmissionError(
                f"Ground truth must include a '{self.value}' column.")

        # Length must match
        if len(y_pred) != len(y_true):
            raise InvalidSubmissionError(
                f"Row count mismatch: submission={len(y_pred)} vs ground_truth={len(y_true)}")

        # Check uniqueness of ids before comparing them, so duplicates are
        # reported as such rather than as a generic id mismatch.
        if not y_pred["id"].is_unique:
            raise InvalidSubmissionError("Duplicate ids found in submission.")

        # Sort both by id and ensure one-to-one mapping. The id column is
        # looked up by name so the submission's column order is irrelevant.
        true_id_col = self._id_column(y_true)
        y_pred_sorted = y_pred.sort_values(by="id").reset_index(drop=True)
        y_true_sorted = y_true.sort_values(by=true_id_col).reset_index(drop=True)
        ids_match = np.array_equal(
            y_pred_sorted["id"].to_numpy(), y_true_sorted[true_id_col].to_numpy())
        if not ids_match:
            raise InvalidSubmissionError(
                "The values in the first column (ids) do not match between submission and ground truth.")

        # Check columns and classes
        sub_cols = self._label_cols_from_submission(y_pred)
        if len(sub_cols) < 2:
            raise InvalidSubmissionError(
                "Submission must contain probability columns for each class (at least two).")
        classes = self._classes_from_truth(y_true, self.value)

        missing_cols = sorted(set(classes) - set(sub_cols))
        extra_cols = sorted(set(sub_cols) - set(classes))
        if missing_cols:
            raise InvalidSubmissionError(
                f"Missing required class columns in submission: {', '.join(missing_cols)}.")
        if extra_cols:
            raise InvalidSubmissionError(
                f"Unexpected extra columns in submission: {', '.join(extra_cols)}.")

        # Validate numeric content and probability simplex per row
        try:
            probs = y_pred_sorted[sub_cols].to_numpy(dtype=float)
        except (TypeError, ValueError) as exc:
            # Non-numeric cells are a submission problem, not an internal error
            raise InvalidSubmissionError(
                "Submission contains non-numeric probability values.") from exc
        if not np.all(np.isfinite(probs)):
            raise InvalidSubmissionError("Submission contains NaN or Inf values.")
        if (probs < -1e-12).any():
            raise InvalidSubmissionError("Submission contains negative probabilities.")
        row_sums = probs.sum(axis=1)
        if not np.all(np.abs(row_sums - 1.0) <= 1e-3):
            raise InvalidSubmissionError(
                "Each row of the submission must sum to 1 within tolerance 1e-3.")

        return True, "Submission is valid."
