from typing import Any
import numpy as np
import pandas as pd
from sklearn.metrics import cohen_kappa_score

from mledojo.metrics.base import CompetitionMetrics, InvalidSubmissionError


class DisneylandReviewMetrics(CompetitionMetrics):
    """Quadratic Weighted Kappa metric for Disneyland Reviews rating prediction.

    Expected format:
    - Ground truth (y_true): DataFrame with columns [Review_ID, Rating]
    - Predictions (y_pred): DataFrame with columns [Review_ID, Rating]

    The first column of the ground truth is treated as the row identifier.
    The submission is matched on that column *by name*, so the order of the
    submission's columns does not matter.

    Predicted ratings are coerced to integers in [MIN_RATING, MAX_RATING] via
    rounding and clipping (non-numeric/NaN predictions are imputed). Ground
    truth ratings must already be numeric and are only cast to ``int``.
    """

    # Inclusive bounds of the rating scale used for clipping predictions and
    # for building the fixed label set passed to the kappa computation.
    MIN_RATING = 1
    MAX_RATING = 5

    def __init__(self, value: str = "Rating", higher_is_better: bool = True):
        """Create the metric.

        Args:
            value: Name of the column holding the rating in both frames.
            higher_is_better: Forwarded to the base class; a higher QWK
                means better agreement.
        """
        super().__init__(higher_is_better)
        self.value = value

    @staticmethod
    def _safe_int_clip(arr: Any, lo: int = 1, hi: int = 5) -> np.ndarray:
        """Coerce arbitrary rating-like values to integers in ``[lo, hi]``.

        Values that cannot be parsed as numbers become NaN and are replaced
        by the median of the parseable values, or by the midpoint of
        ``[lo, hi]`` when nothing is parseable. The result is rounded with
        ``np.rint`` (ties go to the nearest even integer) and clipped.

        Args:
            arr: Array-like of raw predictions (numbers and/or strings).
            lo: Smallest allowed rating (inclusive).
            hi: Largest allowed rating (inclusive).

        Returns:
            Integer ndarray with every element in ``[lo, hi]``.
        """
        a = np.asarray(pd.to_numeric(arr, errors="coerce").astype(float))
        missing = np.isnan(a)
        if missing.all():
            # Nothing usable at all: fall back to the middle of the scale.
            a = np.full_like(a, (lo + hi) / 2.0, dtype=float)
        else:
            a = np.where(missing, np.nanmedian(a), a)
        a = np.clip(np.rint(a), lo, hi)
        return a.astype(int)

    def evaluate(self, y_true: pd.DataFrame, y_pred: pd.DataFrame) -> float:
        """Compute the Quadratic Weighted Kappa between truth and predictions.

        Args:
            y_true: Ground truth frame; its first column is the row id.
            y_pred: Submission frame with the same id and rating columns.

        Returns:
            The QWK score as a float (1.0 means perfect agreement).

        Raises:
            ValueError: If a ground truth rating is not numeric.
        """
        # Align rows on the id column. Prefer looking it up by name so that a
        # submission with reordered columns is still aligned correctly; fall
        # back to the positional first column when the name is absent.
        id_col = y_true.columns[0]
        pred_key = id_col if id_col in y_pred.columns else y_pred.columns[0]
        y_true = y_true.sort_values(by=id_col).reset_index(drop=True)
        y_pred = y_pred.sort_values(by=pred_key).reset_index(drop=True)

        # Coerce predictions to valid integer ratings; truth is parsed strictly.
        y_pred_vals = self._safe_int_clip(
            y_pred[self.value].values, lo=self.MIN_RATING, hi=self.MAX_RATING
        )
        y_true_vals = pd.to_numeric(y_true[self.value], errors="raise").astype(int).values

        # Perfect agreement is 1.0 by definition. Short-circuiting also avoids
        # the 0/0 (NaN) case when truth and predictions are one identical class.
        if y_true_vals.size and np.array_equal(y_true_vals, y_pred_vals):
            return 1.0

        # The quadratic weights are derived from label *positions*, so the
        # label set must be the full contiguous rating scale. Otherwise a
        # rating level that happens to be absent from the data would shrink
        # the distance between the remaining levels. The range is widened if
        # the ground truth ever falls outside the nominal scale so that no
        # sample is dropped.
        lo, hi = self.MIN_RATING, self.MAX_RATING
        if y_true_vals.size:
            lo = min(lo, int(y_true_vals.min()))
            hi = max(hi, int(y_true_vals.max()))
        labels = np.arange(lo, hi + 1)

        return float(
            cohen_kappa_score(y_true_vals, y_pred_vals, labels=labels, weights="quadratic")
        )

    def validate_submission(self, submission: Any, ground_truth: Any) -> str:
        """Check that a submission is structurally compatible with the truth.

        Args:
            submission: Candidate predictions; must be a DataFrame.
            ground_truth: Reference labels; must be a DataFrame.

        Returns:
            The string ``"Submission is valid."`` when every check passes.

        Raises:
            InvalidSubmissionError: On wrong types, row-count mismatch,
                missing/extra columns, id mismatch, or ratings that are
                entirely non-numeric.
        """
        if not isinstance(submission, pd.DataFrame):
            raise InvalidSubmissionError(
                "Submission must be a pandas DataFrame. Please provide a valid pandas DataFrame."
            )
        if not isinstance(ground_truth, pd.DataFrame):
            raise InvalidSubmissionError(
                "Ground truth must be a pandas DataFrame. Please provide a valid pandas DataFrame."
            )

        if len(submission) != len(ground_truth):
            raise InvalidSubmissionError(
                f"Number of rows in submission ({len(submission)}) does not match ground truth ({len(ground_truth)}). Please ensure both have the same number of rows."
            )

        # Columns check: must contain exactly Review_ID and Rating (order can
        # vary). Done before the id comparison so the id column is known to
        # exist in the submission under the same name.
        sub_cols = set(submission.columns)
        true_cols = set(ground_truth.columns)
        missing_cols = true_cols - sub_cols
        extra_cols = sub_cols - true_cols
        if missing_cols:
            raise InvalidSubmissionError(
                f"Missing required columns in submission: {', '.join(sorted(missing_cols))}."
            )
        if extra_cols:
            raise InvalidSubmissionError(
                f"Extra unexpected columns found in submission: {', '.join(sorted(extra_cols))}."
            )

        # Compare the sorted ids, locating the id column by the ground
        # truth's name for it rather than by position in the submission.
        id_col = ground_truth.columns[0]
        submission_ids = submission[id_col].sort_values().to_numpy()
        ground_truth_ids = ground_truth[id_col].sort_values().to_numpy()
        if not np.array_equal(submission_ids, ground_truth_ids):
            raise InvalidSubmissionError(
                "First column values do not match between submission and ground truth. Please ensure the first column values are identical."
            )

        # Basic value checks for Rating column if present
        if self.value in submission.columns:
            vals = pd.to_numeric(submission[self.value], errors="coerce")
            if vals.isnull().all():
                raise InvalidSubmissionError(
                    "All provided ratings are non-numeric/NaN. Please submit numeric ratings."
                )
        return "Submission is valid."
