from typing import Any, Tuple
import pandas as pd
import numpy as np

from mledojo.metrics.base import CompetitionMetrics, InvalidSubmissionError


# Name of the identifier column used to align submission rows with ground truth.
ID_COL = "ID"
# Default name of the column holding the target / predicted values.
TARGET_COL = "Duration_Minutes"


def _safe_float_series(s: pd.Series) -> pd.Series:
    """Return *s* coerced to numbers that are safe to feed into ``log1p``.

    Entries that cannot be parsed as numbers, missing entries and +/-infinity
    are all replaced by 0.0; afterwards negative values are raised to 0.0.
    """
    return (
        pd.to_numeric(s, errors="coerce")  # unparseable entries become NaN
        .replace([np.inf, -np.inf], np.nan)  # fold infinities into NaN as well
        .fillna(0.0)
        .clip(lower=0.0)
    )


def _rmsle_score(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """Compute the root mean squared logarithmic error of *y_pred* vs *y_true*.

    Both inputs are converted to float arrays and negative entries are raised
    to 0 before ``log1p`` is applied, so the logarithm is always defined.
    """
    truth = np.maximum(np.asarray(y_true, dtype=float), 0.0)
    guess = np.maximum(np.asarray(y_pred, dtype=float), 0.0)
    log_error = np.log1p(guess) - np.log1p(truth)
    mean_squared = np.mean(log_error * log_error)
    return float(np.sqrt(mean_squared))


class USCongestionDurationMetrics(CompetitionMetrics):
    """RMSLE metric for US Traffic Congestion Duration prediction.

    Expects dataframes with columns [ID, Duration_Minutes]. Rows are matched
    between ground truth and predictions by the string form of their ID.
    """

    def __init__(self, value: str = TARGET_COL, higher_is_better: bool = False):
        """Configure the metric.

        Args:
            value: Name of the column holding the target / predicted values.
            higher_is_better: Forwarded to the ``CompetitionMetrics`` base
                class; defaults to ``False``.
        """
        super().__init__(higher_is_better)
        self.value = value

    @staticmethod
    def _sort_by_id(df: pd.DataFrame) -> pd.DataFrame:
        """Return a copy of *df* ordered by the string form of its ID column."""
        # IDs are compared as strings everywhere else in this class, so sort on
        # the string form too. Sorting on the native dtype would order ints
        # numerically but strings lexicographically, which misaligns frames
        # that carry the same IDs under different dtypes.
        keyed = df.assign(**{ID_COL: df[ID_COL].astype(str)})
        # mergesort is stable, so equal keys keep their input order.
        return keyed.sort_values(by=[ID_COL], kind="mergesort").reset_index(drop=True)

    def evaluate(self, y_true: pd.DataFrame, y_pred: pd.DataFrame) -> float:
        """Score predictions against ground truth with RMSLE.

        Args:
            y_true: Ground-truth frame containing the ID and value columns.
            y_pred: Prediction frame containing the ID and value columns.

        Returns:
            The RMSLE between the aligned value columns. Values are passed
            through ``_safe_float_series`` before scoring.

        Raises:
            InvalidSubmissionError: If either input is not a DataFrame, a
                required column is missing, or the IDs of the two frames do
                not match one-to-one after sorting.
        """
        if not isinstance(y_true, pd.DataFrame) or not isinstance(y_pred, pd.DataFrame):
            raise InvalidSubmissionError("y_true and y_pred must be pandas DataFrames.")

        # Basic column checks
        if self.value not in y_true.columns or ID_COL not in y_true.columns:
            raise InvalidSubmissionError(
                f"Ground truth must contain columns [{ID_COL}, {self.value}]"
            )
        if self.value not in y_pred.columns or ID_COL not in y_pred.columns:
            raise InvalidSubmissionError(
                f"Predictions must contain columns [{ID_COL}, {self.value}]"
            )

        # Sort both by ID for alignment
        y_true_sorted = self._sort_by_id(y_true)
        y_pred_sorted = self._sort_by_id(y_pred)

        # Ensure ID alignment (also fails when the row counts differ)
        if not y_true_sorted[ID_COL].equals(y_pred_sorted[ID_COL]):
            raise InvalidSubmissionError("IDs in predictions do not match ground truth IDs in order.")

        yt = _safe_float_series(y_true_sorted[self.value]).to_numpy()
        yp = _safe_float_series(y_pred_sorted[self.value]).to_numpy()
        return _rmsle_score(yt, yp)

    def validate_submission(self, submission: Any, ground_truth: Any) -> Tuple[bool, str]:
        """Check that *submission* is well-formed relative to *ground_truth*.

        Args:
            submission: Candidate prediction frame.
            ground_truth: Reference frame with the expected IDs.

        Returns:
            ``(True, message)`` when the submission is acceptable, otherwise
            ``(False, reason)`` describing the first problem found.
        """
        if not isinstance(submission, pd.DataFrame):
            return False, "Submission must be a pandas DataFrame."
        if not isinstance(ground_truth, pd.DataFrame):
            return False, "Ground truth must be a pandas DataFrame."

        # Column set must match exactly [ID, value]
        expected_cols = {ID_COL, self.value}
        if set(submission.columns) != expected_cols:
            return False, f"Submission must have exactly columns [{ID_COL}, {self.value}] in any order."
        if set(ground_truth.columns) != expected_cols:
            return False, f"Ground truth must have exactly columns [{ID_COL}, {self.value}] in any order."

        # Lengths
        if len(submission) != len(ground_truth):
            return False, (
                f"Row count mismatch: submission has {len(submission)}, ground truth has {len(ground_truth)}."
            )

        # IDs: same set and no duplicates
        sub_ids = submission[ID_COL].astype(str)
        gt_ids = ground_truth[ID_COL].astype(str)
        if sub_ids.duplicated().any():
            return False, "Submission contains duplicate IDs."
        if gt_ids.duplicated().any():
            return False, "Ground truth contains duplicate IDs."
        if set(sub_ids) != set(gt_ids):
            return False, "Submission IDs must match the ground truth test IDs exactly."

        # Predictions numeric and finite. Inspect the raw coerced values here:
        # _safe_float_series would already have replaced bad entries with 0
        # and clipped negatives, so the checks below could never fail.
        try:
            preds = pd.to_numeric(submission[self.value], errors="coerce").to_numpy(
                dtype=float, na_value=np.nan
            )
        except (TypeError, ValueError):
            # Values that cannot even be coerced (e.g. nested objects).
            return False, "Predictions must be finite numbers."
        if not np.isfinite(preds).all():
            return False, "Predictions must be finite numbers."
        if (preds < 0).any():
            return False, "Predictions must be non-negative."

        return True, "Submission is valid."
