from typing import Any, List, Tuple
import pandas as pd

from mledojo.metrics.base import CompetitionMetrics, InvalidSubmissionError


class PersonSegmentationMetrics(CompetitionMetrics):
    """
    Metric class for single-class person segmentation using mean Dice coefficient.

    Submission/GT format: CSV with columns [id, rle]
    - rle is Run-Length Encoding in column-major order with 1-indexed starts.
    - Empty mask should be an empty string. Missing values (None / NaN / NA)
      are treated as an empty mask too, because an empty CSV cell is loaded
      as NaN by pandas' default CSV parsing.
    Note: For robustness, we accept start positions of 0 in submissions as well.

    Runs are compared purely in 1D index space, exactly as written in the RLE
    (no index shifting is applied), so the image shape is never needed.
    """

    def __init__(self, value: str = "rle", higher_is_better: bool = True):
        """
        Args:
            value: Name of the column holding the RLE strings.
            higher_is_better: Forwarded to the base class constructor.
        """
        super().__init__(higher_is_better)
        self.value = value

    @staticmethod
    def _id_column(frame: pd.DataFrame) -> Any:
        """
        Return the label of the column used to align rows.

        Prefers the column named "id"; falls back to the first column when no
        "id" column exists, so frames that rely on positional layout still work.
        """
        return "id" if "id" in frame.columns else frame.columns[0]

    @staticmethod
    def _parse_rle(rle: Any) -> List[Tuple[int, int]]:
        """
        Parse an RLE string into a list of (start, length) as integers.

        We accept starts >= 0 (to be lenient with formatting) and lengths > 0.
        Empty string, None and NaN/NA -> empty list.
        The runs are returned in the order they appear in the string.

        Raises:
            InvalidSubmissionError: odd number of tokens, non-integer tokens,
                non-positive lengths or negative starts.
        """
        if rle is None:
            return []
        # A missing cell (NaN/NA) means "empty mask"; without this check
        # str(nan) == "nan" would be rejected as a malformed RLE.
        if pd.api.types.is_scalar(rle) and pd.isna(rle):
            return []
        text = str(rle).strip()
        if not text:
            return []
        parts = text.split()
        if len(parts) % 2 != 0:
            raise InvalidSubmissionError("RLE must have even number of integers")
        try:
            starts = [int(x) for x in parts[0::2]]
            lengths = [int(x) for x in parts[1::2]]
        except ValueError as err:
            raise InvalidSubmissionError("RLE must contain only integers") from err
        if any(length <= 0 for length in lengths):
            raise InvalidSubmissionError("RLE lengths must be positive integers")
        if any(start < 0 for start in starts):
            raise InvalidSubmissionError("RLE starts must be non-negative integers")
        return list(zip(starts, lengths))

    @staticmethod
    def _normalize_runs(runs: List[Tuple[int, int]]) -> List[Tuple[int, int]]:
        """
        Return the runs sorted by start with overlapping/adjacent runs merged.

        The result covers exactly the same set of indices as the input, but
        every index is covered once, which is what both the pixel count and
        the two-pointer intersection sweep require.
        """
        merged: List[Tuple[int, int]] = []
        for start, length in sorted(runs):
            end = start + length
            if merged:
                prev_start, prev_len = merged[-1]
                prev_end = prev_start + prev_len
                if start <= prev_end:
                    # Overlaps or touches the previous run: extend it in place.
                    merged[-1] = (prev_start, max(prev_end, end) - prev_start)
                    continue
            merged.append((start, length))
        return merged

    @staticmethod
    def _runs_sum(runs: List[Tuple[int, int]]) -> int:
        """Return the total length of all runs (the pixel count when runs are normalized)."""
        return sum(length for _, length in runs)

    @staticmethod
    def _runs_intersection(a: List[Tuple[int, int]], b: List[Tuple[int, int]]) -> int:
        """
        Compute intersection size between two RLE run lists in 1D (half-open [start, start+len)).

        This operates purely in 1D index space and does not require image shape.
        Both lists must be sorted by start and free of overlaps (see
        `_normalize_runs`); the sweep below gives wrong counts otherwise.
        """
        i = j = 0
        inter = 0
        while i < len(a) and j < len(b):
            sa, la = a[i]
            sb, lb = b[j]
            ea = sa + la
            eb = sb + lb
            # overlap of [sa, ea) and [sb, eb)
            start = max(sa, sb)
            end = min(ea, eb)
            if end > start:
                inter += end - start
            # advance the run which ends first
            if ea <= eb:
                i += 1
            else:
                j += 1
        return inter

    def evaluate(self, y_true: pd.DataFrame, y_pred: pd.DataFrame) -> float:
        """
        Return the mean Dice coefficient over all rows.

        Rows are aligned by sorting both frames on their ID column. A row where
        both masks are empty scores 1.0. An empty input yields 0.0.

        Raises:
            InvalidSubmissionError: row counts differ, IDs differ, or an RLE
                string is malformed.
        """
        true_id = self._id_column(y_true)
        pred_id = self._id_column(y_pred)

        # Sort both by the ID column to align
        y_true = y_true.sort_values(by=true_id).reset_index(drop=True)
        y_pred = y_pred.sort_values(by=pred_id).reset_index(drop=True)

        # Basic checks
        if len(y_true) != len(y_pred):
            raise InvalidSubmissionError(
                f"Mismatched lengths: y_true={len(y_true)} vs y_pred={len(y_pred)}"
            )
        # Plain list comparison avoids dtype-dependent array comparison quirks.
        if y_true[true_id].tolist() != y_pred[pred_id].tolist():
            raise InvalidSubmissionError("IDs in y_true and y_pred do not match in the same order")

        dices: List[float] = []
        for gt_rle, pr_rle in zip(y_true[self.value].tolist(), y_pred[self.value].tolist()):
            # Normalizing guarantees each index is counted once, so the Dice
            # value is always within [0, 1] without any clamping.
            gt_runs = self._normalize_runs(self._parse_rle(gt_rle))
            pr_runs = self._normalize_runs(self._parse_rle(pr_rle))
            s_gt = self._runs_sum(gt_runs)
            s_pr = self._runs_sum(pr_runs)
            if s_gt == 0 and s_pr == 0:
                d = 1.0
            else:
                inter = self._runs_intersection(gt_runs, pr_runs)
                # Denominator is positive here: at least one mask is non-empty.
                d = (2.0 * inter) / float(s_gt + s_pr)
            dices.append(float(d))

        return float(sum(dices) / len(dices)) if dices else 0.0

    def validate_submission(self, submission: Any, ground_truth: Any) -> str:
        """
        Check that a submission is structurally valid against the ground truth.

        Verifies both inputs are DataFrames, that the submission has exactly
        the columns {"id", self.value}, that row counts and sorted IDs match,
        and that every RLE string in the submission parses.

        Returns:
            The string "Submission is valid." when all checks pass.

        Raises:
            InvalidSubmissionError: on the first failed check.
        """
        if not isinstance(submission, pd.DataFrame):
            raise InvalidSubmissionError(
                "Submission must be a pandas DataFrame. Please provide a valid pandas DataFrame."
            )
        if not isinstance(ground_truth, pd.DataFrame):
            raise InvalidSubmissionError(
                "Ground truth must be a pandas DataFrame. Please provide a valid pandas DataFrame."
            )

        required_cols = {"id", self.value}
        sub_cols = set(submission.columns)

        # Submission must contain exactly [id, value]
        missing = required_cols - sub_cols
        if missing:
            raise InvalidSubmissionError(
                f"Missing required columns in submission: {', '.join(sorted(missing))}."
            )
        extra = sub_cols - required_cols
        if extra:
            raise InvalidSubmissionError(
                f"Extra unexpected columns found in submission: {', '.join(sorted(extra))}."
            )

        # Length match
        if len(submission) != len(ground_truth):
            raise InvalidSubmissionError(
                f"Number of rows in submission ({len(submission)}) does not match ground truth ({len(ground_truth)})."
            )

        # Sort and check IDs align exactly. The submission is known to have an
        # "id" column at this point, whatever its column order is.
        truth_id = self._id_column(ground_truth)
        submission_ids = submission.sort_values(by="id")["id"].tolist()
        ground_truth_ids = ground_truth.sort_values(by=truth_id)[truth_id].tolist()
        if submission_ids != ground_truth_ids:
            raise InvalidSubmissionError(
                "First column values (IDs) do not match between submission and ground truth."
            )

        # Validate that all RLE strings are syntactically valid
        for rle in submission[self.value].tolist():
            self._parse_rle(rle)

        return "Submission is valid."