from typing import Any
import pandas as pd
import numpy as np

from mledojo.metrics.base import CompetitionMetrics, InvalidSubmissionError


# Class labels accepted in a submission's "class" column. The macro F1 part of
# the score is averaged over exactly these labels, in this order.
VALID_CLASSES = [
    "Lung Opacity",
    "No Lung Opacity / Not Normal",
    "Normal",
]


def _normalize_rle(s: Any) -> str:
    if s is None:
        return ""
    if isinstance(s, float) and np.isnan(s):
        return ""
    return str(s).strip()


def rle_decode(rle: str, shape: tuple[int, int]) -> np.ndarray:
    """
    Decode an RLE string (1-indexed, column-major) into a binary mask of shape (H, W).

    The string is a whitespace-separated sequence of ``start length`` pairs.
    An empty or missing string yields an all-zeros mask, and so does a string
    containing any token that is not an integer. Runs with a start below 1 or
    a non-positive length are skipped; runs extending past the end of the
    image are clipped. A trailing start without a length is ignored.

    Args:
        rle: Encoded mask; normalized with ``_normalize_rle`` before parsing.
        shape: ``(height, width)`` of the mask to produce.

    Returns:
        A ``uint8`` array of shape ``(height, width)`` holding 0/1 values.
    """
    h, w = shape
    size = h * w
    flat = np.zeros(size, dtype=np.uint8)

    tokens = _normalize_rle(rle).split()
    try:
        starts = [int(tok) for tok in tokens[0::2]]
        lengths = [int(tok) for tok in tokens[1::2]]
    except ValueError:
        # Malformed input decodes to an empty mask instead of raising.
        starts, lengths = [], []

    for start, length in zip(starts, lengths):
        begin = start - 1  # 1-indexed to 0-indexed
        if begin < 0 or length <= 0:
            continue
        flat[begin:min(begin + length, size)] = 1

    # Filling a (W, H) array row by row and transposing is the same as
    # filling an (H, W) array column by column.
    return flat.reshape((w, h)).T


def dice_score(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """
    Compute the Dice coefficient between two masks.

    Both inputs are binarized with ``> 0`` before comparison. When neither
    mask has any positive pixel the score is defined as 1.0.

    Args:
        y_true: Reference mask.
        y_pred: Predicted mask, broadcast-compatible with ``y_true``.

    Returns:
        ``2 * |A & B| / (|A| + |B|)`` as a float.
    """
    truth = (y_true > 0).astype(np.uint8)
    guess = (y_pred > 0).astype(np.uint8)
    overlap = float(np.sum(truth & guess))
    total = float(np.sum(truth) + np.sum(guess))
    # Two empty masks agree perfectly; this also avoids dividing by zero.
    return 1.0 if total == 0.0 else 2.0 * overlap / total


class RSNAPneumoniaMetrics(CompetitionMetrics):
    """
    Metric for the RSNA Pneumonia processed dataset.

    Final score = 0.5 * macro F1 over the classes in ``VALID_CLASSES``
    + 0.5 * mean Dice over the rows whose ground-truth mask is non-empty.

    Submission format: DataFrame with columns [id, class, mask_rle].
    Ground-truth format: DataFrame with columns [id, class, height, width, mask_rle].
    """

    def __init__(self, value: str = "mask_rle", higher_is_better: bool = True):
        """Record ``value`` on the instance and pass ``higher_is_better`` to the base class."""
        super().__init__(higher_is_better)
        self.value = value

    @staticmethod
    def _sorted_by_id(frame: pd.DataFrame) -> pd.DataFrame:
        """Return a copy of ``frame`` with ``id`` cast to str, sorted by ``id``, index reset."""
        ordered = frame.copy()
        ordered["id"] = ordered["id"].astype(str)
        # A stable sort keeps rows sharing an id in their original relative
        # order, so truth/prediction pairing is deterministic.
        return ordered.sort_values("id", kind="mergesort").reset_index(drop=True)

    @staticmethod
    def _macro_f1(y_t: np.ndarray, y_p: np.ndarray) -> float:
        """Return the unweighted mean of per-class F1 over ``VALID_CLASSES``."""
        f1s = []
        for c in VALID_CLASSES:
            tp = float(np.sum((y_t == c) & (y_p == c)))
            fp = float(np.sum((y_t != c) & (y_p == c)))
            fn = float(np.sum((y_t == c) & (y_p != c)))
            # A class with no predictions or no true rows scores 0 rather
            # than dividing by zero.
            prec = 0.0 if (tp + fp) == 0 else tp / (tp + fp)
            rec = 0.0 if (tp + fn) == 0 else tp / (tp + fn)
            f1 = 0.0 if (prec + rec) == 0 else (2.0 * prec * rec / (prec + rec))
            f1s.append(f1)
        return float(np.mean(f1s)) if f1s else 0.0

    @staticmethod
    def _mean_positive_dice(y_true: pd.DataFrame, y_pred: pd.DataFrame) -> float:
        """
        Return the mean Dice over rows whose ground-truth mask has a positive pixel.

        Both frames must already be row-aligned. Returns 0.0 when no row has
        a positive ground-truth mask.
        """
        dice_vals = []
        rows = zip(
            y_true["height"], y_true["width"], y_true["mask_rle"], y_pred["mask_rle"]
        )
        for height, width, gt_rle, pr_rle in rows:
            # Missing dimensions produce an empty mask, which excludes the row.
            h = 0 if pd.isna(height) else int(height)
            w = 0 if pd.isna(width) else int(width)
            gt_mask = rle_decode(gt_rle, (h, w))
            if not (gt_mask > 0).any():
                continue
            pr_mask = rle_decode(pr_rle, (h, w))
            dice_vals.append(dice_score(gt_mask, pr_mask))
        return float(np.mean(dice_vals)) if dice_vals else 0.0

    @staticmethod
    def _is_well_formed_rle(cell: Any) -> bool:
        """Return True for an empty/missing cell or an even count of integer tokens."""
        tokens = _normalize_rle(cell).split()
        if len(tokens) % 2 != 0:
            return False
        try:
            for tok in tokens:
                int(tok)
        except ValueError:
            return False
        return True

    def evaluate(self, y_true: pd.DataFrame, y_pred: pd.DataFrame) -> float:
        """
        Score a prediction frame against the ground truth.

        Rows of both frames are matched by their ``id`` column (compared as
        strings) after sorting, so input row order does not matter.

        Args:
            y_true: Ground truth with columns id, class, height, width, mask_rle.
            y_pred: Prediction with columns id, class, mask_rle.

        Returns:
            ``0.5 * macro F1 + 0.5 * mean Dice``, clamped to ``[0.0, 1.0]``.

        Raises:
            InvalidSubmissionError: If an input is not a DataFrame, a required
                column is missing, or the ids of the two frames differ.
        """
        if not isinstance(y_true, pd.DataFrame) or not isinstance(y_pred, pd.DataFrame):
            raise InvalidSubmissionError("Both y_true and y_pred must be pandas DataFrames.")
        if "id" not in y_true.columns or "id" not in y_pred.columns:
            raise InvalidSubmissionError("Both y_true and y_pred must contain an 'id' column.")

        # Sort and align by id
        y_true = self._sorted_by_id(y_true)
        y_pred = self._sorted_by_id(y_pred)
        if len(y_true) != len(y_pred) or (y_true["id"].values != y_pred["id"].values).any():
            raise InvalidSubmissionError("IDs in prediction do not match IDs in ground truth.")

        # Macro F1 across classes
        if "class" not in y_true.columns or "class" not in y_pred.columns:
            raise InvalidSubmissionError("Both y_true and y_pred must contain a 'class' column.")
        macro_f1 = self._macro_f1(
            y_true["class"].astype(str).values,
            y_pred["class"].astype(str).values,
        )

        # Dice on positives (where GT mask has any positive pixel)
        if not {"height", "width", "mask_rle"}.issubset(y_true.columns):
            raise InvalidSubmissionError("Ground truth must contain columns: height, width, mask_rle.")
        if "mask_rle" not in y_pred.columns:
            raise InvalidSubmissionError("Prediction must contain column: mask_rle.")
        mean_dice = self._mean_positive_dice(y_true, y_pred)

        final = 0.5 * macro_f1 + 0.5 * mean_dice
        return float(max(0.0, min(1.0, final)))

    def validate_submission(self, submission: Any, ground_truth: Any) -> str:
        """
        Check that a submission is structurally compatible with the ground truth.

        Verifies, in order: both inputs are DataFrames; the submission has the
        columns id, class and mask_rle; the ground truth has an id column; row
        counts match; the sorted ids match; every class is in
        ``VALID_CLASSES``; every mask_rle cell is empty or an even number of
        integer tokens.

        Args:
            submission: Candidate submission.
            ground_truth: Reference answers.

        Returns:
            The string ``"Submission is valid."`` when all checks pass.

        Raises:
            InvalidSubmissionError: On the first failed check.
        """
        if not isinstance(submission, pd.DataFrame):
            raise InvalidSubmissionError("Submission must be a pandas DataFrame.")
        if not isinstance(ground_truth, pd.DataFrame):
            raise InvalidSubmissionError("Ground truth must be a pandas DataFrame.")

        req_cols = {"id", "class", "mask_rle"}
        missing = req_cols - set(submission.columns)
        if missing:
            raise InvalidSubmissionError(f"Missing required columns in submission: {sorted(missing)}")

        if "id" not in ground_truth.columns:
            raise InvalidSubmissionError("Ground truth must contain an 'id' column.")

        # Check row count and IDs
        if len(submission) != len(ground_truth):
            raise InvalidSubmissionError(
                f"Number of rows in submission ({len(submission)}) "
                f"does not match ground truth ({len(ground_truth)})."
            )
        sub = self._sorted_by_id(submission)
        gt = self._sorted_by_id(ground_truth)
        if (sub["id"].values != gt["id"].values).any():
            raise InvalidSubmissionError("IDs in submission do not align with ground truth IDs.")

        # Classes validity
        known = sub["class"].astype(str).isin(VALID_CLASSES)
        if not known.all():
            bad = sub.loc[~known, "class"].unique().tolist()
            raise InvalidSubmissionError(
                f"Invalid classes in submission: {bad}. Valid classes: {VALID_CLASSES}"
            )

        # RLE basic format validation (even number of integers, or empty)
        if not sub["mask_rle"].map(self._is_well_formed_rle).all():
            raise InvalidSubmissionError("Malformed RLE strings detected in submission.")

        return "Submission is valid."
