from typing import Any, Optional, Tuple

import numpy as np
import pandas as pd

from mledojo.metrics.base import CompetitionMetrics, InvalidSubmissionError


def _parse_rle_pairs(encoding: str) -> Tuple[np.ndarray, np.ndarray]:
    """
    Parse an RLE string into starts (1-indexed) and lengths arrays.
    Returns empty arrays for empty/NaN encodings.
    """
    if encoding is None:
        return np.array([], dtype=np.int64), np.array([], dtype=np.int64)
    s = str(encoding).strip()
    if s == "" or s.lower() == "nan":
        return np.array([], dtype=np.int64), np.array([], dtype=np.int64)

    parts = s.split()
    if len(parts) % 2 != 0:
        raise InvalidSubmissionError("RLE string has an odd number of elements; expected start/length pairs.")

    try:
        starts = np.array([int(p) for p in parts[0::2]], dtype=np.int64)
        lengths = np.array([int(p) for p in parts[1::2]], dtype=np.int64)
    except Exception:
        raise InvalidSubmissionError("RLE contains non-integer tokens.")

    if (lengths < 0).any() or (starts < 0).any():
        raise InvalidSubmissionError("RLE contains negative values.")

    return starts, lengths


def _rle_decode_1d(starts: np.ndarray, lengths: np.ndarray, size: int) -> np.ndarray:
    """Decode RLE to a 1D binary array of given size (row-major flatten)."""
    arr = np.zeros(size, dtype=np.uint8)
    if starts.size == 0:
        return arr
    # Convert potentially 1-indexed starts to 0-indexed
    base = 1 if (starts > 0).any() and (starts.min() == 1 or starts.min() > 0) else 0
    lo = starts - (1 if base == 1 else 0)
    hi = lo + lengths
    lo = np.maximum(lo, 0)
    hi = np.minimum(hi, size)
    for l, h in zip(lo.tolist(), hi.tolist()):
        if h > l:
            arr[l:h] = 1
    return arr


class UrbanSegmentationISPRSMetrics(CompetitionMetrics):
    """Mean IoU metric for ISPRS-style urban segmentation submissions.

    Expects DataFrames with columns: [image_id, class_id, encoding]
    where encoding is an RLE string of a binary mask for that class.

    Every image must have exactly one row for each class id in
    ``0..NUM_CLASSES - 1``. Per-class intersections and unions are summed over
    all images. The score is the mean IoU over the classes whose total union
    is non-empty, or 0.0 if no class has any pixels at all.
    """

    # Number of semantic classes per image; class ids are 0..NUM_CLASSES - 1.
    NUM_CLASSES = 6
    # Exact column set both DataFrames must have.
    REQUIRED_COLUMNS = frozenset({"image_id", "class_id", "encoding"})

    def __init__(self, value: str = "encoding", higher_is_better: bool = True):
        super().__init__(higher_is_better)
        self.value = value  # kept for API symmetry; not directly used

    @staticmethod
    def _with_int_class_ids(df: pd.DataFrame, name: str) -> pd.DataFrame:
        """Return a copy of ``df`` whose ``class_id`` column is cast to ``int``.

        Raises:
            InvalidSubmissionError: if ``class_id`` holds values that cannot be
                cast to ``int`` (e.g. NaN or non-numeric strings).
        """
        out = df.copy()
        try:
            out["class_id"] = out["class_id"].astype(int)
        except (TypeError, ValueError) as err:
            raise InvalidSubmissionError(
                f"{name} contains class_id values that are not integers."
            ) from err
        return out

    def _check_image_classes(self, rows: pd.DataFrame, name: str, image_id: Any) -> None:
        """Raise unless ``rows`` holds exactly one row per class id 0..NUM_CLASSES - 1."""
        class_ids = rows["class_id"].tolist()
        # The length check matters: a duplicated class id would still pass the
        # set comparison on its own.
        if len(class_ids) != self.NUM_CLASSES or set(class_ids) != set(range(self.NUM_CLASSES)):
            raise InvalidSubmissionError(
                f"{name} for {image_id} must contain class_ids 0..{self.NUM_CLASSES - 1} exactly once."
            )

    @staticmethod
    def _intersection_union(enc_true: Any, enc_pred: Any) -> Tuple[int, int]:
        """Return ``(intersection, union)`` pixel counts of two RLE-encoded masks."""
        st_t, ln_t = _parse_rle_pairs(enc_true)
        st_p, ln_p = _parse_rle_pairs(enc_pred)
        # ``start + length`` is an exclusive upper bound on any covered index
        # for both 0- and 1-indexed starts, so no pixel is cut off.
        # ``initial=0`` handles empty encodings and keeps size >= 0.
        size = int(max((st_t + ln_t).max(initial=0), (st_p + ln_p).max(initial=0)))
        # NOTE(review): _rle_decode_1d infers the index base separately for each
        # mask, so a prediction containing a 0 start is decoded 0-indexed even
        # when the ground truth is 1-indexed. Confirm this leniency is intended.
        mask_t = _rle_decode_1d(st_t, ln_t, size)
        mask_p = _rle_decode_1d(st_p, ln_p, size)
        inter = int(np.logical_and(mask_t, mask_p).sum())
        union = int(np.logical_or(mask_t, mask_p).sum())
        return inter, union

    def evaluate(self, y_true: pd.DataFrame, y_pred: pd.DataFrame) -> float:
        """Return the mean per-class IoU of ``y_pred`` against ``y_true``.

        Raises:
            InvalidSubmissionError: on non-DataFrame inputs, wrong columns,
                non-integer class ids, mismatched image id sets, an image
                without exactly one row per class, or a malformed RLE string.
        """
        if not isinstance(y_true, pd.DataFrame) or not isinstance(y_pred, pd.DataFrame):
            raise InvalidSubmissionError("Both y_true and y_pred must be pandas DataFrames.")

        if set(y_true.columns) != self.REQUIRED_COLUMNS or set(y_pred.columns) != self.REQUIRED_COLUMNS:
            raise InvalidSubmissionError("DataFrames must have exactly the columns: image_id, class_id, encoding.")

        y_true = self._with_int_class_ids(y_true, "Ground truth")
        y_pred = self._with_int_class_ids(y_pred, "Prediction")

        true_groups = y_true.groupby("image_id")
        pred_groups = y_pred.groupby("image_id")
        if set(true_groups.groups) != set(pred_groups.groups):
            raise InvalidSubmissionError("The set of image_ids in predictions does not match ground truth.")

        n_classes = self.NUM_CLASSES
        intersections = [0] * n_classes
        unions = [0] * n_classes

        for image_id in sorted(true_groups.groups):
            gt_rows = true_groups.get_group(image_id)
            pr_rows = pred_groups.get_group(image_id)
            self._check_image_classes(gt_rows, "Ground truth", image_id)
            self._check_image_classes(pr_rows, "Prediction", image_id)

            # Pair encodings by their class_id value, not by row position.
            gt_by_class = dict(zip(gt_rows["class_id"].tolist(), gt_rows["encoding"].tolist()))
            pr_by_class = dict(zip(pr_rows["class_id"].tolist(), pr_rows["encoding"].tolist()))
            for cid in range(n_classes):
                inter, union = self._intersection_union(gt_by_class[cid], pr_by_class[cid])
                intersections[cid] += inter
                unions[cid] += union

        # Classes with no ground-truth and no predicted pixels overall are
        # ignored. inter <= union by construction, so each IoU lies in [0, 1].
        class_ious = [inter / union for inter, union in zip(intersections, unions) if union > 0]
        if not class_ious:
            return 0.0
        return float(np.mean(class_ious))

    def validate_submission(self, submission: Any, ground_truth: Any) -> tuple[bool, str]:
        """Check that ``submission`` is structurally compatible with ``ground_truth``.

        Returns:
            ``(True, "Submission is valid.")`` when every check passes.

        Raises:
            InvalidSubmissionError: describing the first check that failed.
        """
        if not isinstance(submission, pd.DataFrame):
            raise InvalidSubmissionError(
                "Submission must be a pandas DataFrame. Please provide a valid pandas DataFrame."
            )
        if not isinstance(ground_truth, pd.DataFrame):
            raise InvalidSubmissionError(
                "Ground truth must be a pandas DataFrame. Please provide a valid pandas DataFrame."
            )

        if set(submission.columns) != self.REQUIRED_COLUMNS:
            raise InvalidSubmissionError(
                "Submission must have exactly the columns: image_id, class_id, encoding."
            )
        if set(ground_truth.columns) != self.REQUIRED_COLUMNS:
            raise InvalidSubmissionError(
                "Ground truth must have exactly the columns: image_id, class_id, encoding."
            )

        submission = self._with_int_class_ids(submission, "Submission")
        ground_truth = self._with_int_class_ids(ground_truth, "Ground truth")

        if submission.shape[0] != ground_truth.shape[0]:
            raise InvalidSubmissionError(
                f"Row count mismatch: submission has {submission.shape[0]} rows, ground truth has {ground_truth.shape[0]} rows."
            )

        sub_ids = set(submission["image_id"].unique().tolist())
        true_ids = set(ground_truth["image_id"].unique().tolist())
        if sub_ids != true_ids:
            raise InvalidSubmissionError("Image IDs in submission do not match ground truth.")

        n_classes = self.NUM_CLASSES
        expected = set(range(n_classes))
        for df, name in ((submission, "Submission"), (ground_truth, "Ground truth")):
            # Count rows rather than distinct ids so duplicated class rows fail here.
            rows_per_image = df.groupby("image_id").size()
            bad = rows_per_image[rows_per_image != n_classes]
            if not bad.empty:
                raise InvalidSubmissionError(
                    f"{name} must contain exactly {n_classes} class rows per image "
                    f"(0..{n_classes - 1}). Offenders: {bad.to_dict()}"
                )
            for img_id, classes in df.groupby("image_id")["class_id"]:
                if set(classes.tolist()) != expected:
                    raise InvalidSubmissionError(
                        f"{name} for {img_id} must contain class_ids 0..{n_classes - 1} exactly once."
                    )

        # Light RLE validation: every submitted encoding must parse.
        for enc in submission["encoding"].tolist():
            _parse_rle_pairs(enc)  # raises InvalidSubmissionError if malformed

        return True, "Submission is valid."
