from collections import Counter
from typing import Any, List

import pandas as pd

from mledojo.metrics.base import CompetitionMetrics, InvalidSubmissionError


def _parse_labels_str(labels_str: str) -> List[int]:
    if not isinstance(labels_str, str):
        # allow lists already
        if isinstance(labels_str, list):
            return [int(x) for x in labels_str]
        raise InvalidSubmissionError("Labels must be a space-separated string of integers per row.")
    s = labels_str.strip()
    if s == "":
        return []
    out: List[int] = []
    for tok in s.split(" "):
        if tok == "":
            continue
        try:
            v = int(tok)
        except Exception:
            v = int(float(tok))
        if v < 0:
            v = 0
        out.append(v)
    return out


def _iou_per_shape(pred: List[int], true: List[int]) -> float:
    assert len(pred) == len(true), "Prediction and ground-truth label lists must have equal length for each id."
    labels_union = set(true) | set(pred)
    if not labels_union:
        return 1.0
    ious = []
    for c in labels_union:
        inter = sum(1 for p, t in zip(pred, true) if p == c and t == c)
        union = sum(1 for p, t in zip(pred, true) if (p == c) or (t == c))
        ious.append(0.0 if union == 0 else inter / union)
    return sum(ious) / len(ious)


class ShapenetPartSegmentationMetrics(CompetitionMetrics):
    """Mean IoU over shapes for ShapeNetPart-style per-point segmentation.

    Expects two pandas DataFrames, representing:
    - y_true: test_answer.csv with columns [id, labels]
    - y_pred: submission.csv with columns [id, labels]

    The first column of each frame is treated as the id column. The labels
    column (named by ``value``) is a space-separated string of integer part
    labels whose length must equal the number of ground-truth labels for
    that id.
    """

    def __init__(self, value: str = "labels", higher_is_better: bool = True):
        """Store the name of the labels column and the score direction.

        Args:
            value: Name of the column holding the label strings.
            higher_is_better: Forwarded to the base class.
        """
        super().__init__(higher_is_better)
        self.value = value

    def evaluate(self, y_true: pd.DataFrame, y_pred: pd.DataFrame) -> float:
        """Return the mean per-shape IoU of ``y_pred`` against ``y_true``.

        Args:
            y_true: Ground-truth frame (id column first, plus the labels column).
            y_pred: Submission frame with the same layout.

        Returns:
            The average of the per-shape IoUs, or ``0.0`` for empty frames.

        Raises:
            InvalidSubmissionError: If the submission fails validation.
        """
        # Basic validation first (raises InvalidSubmissionError with helpful message)
        self.validate_submission(y_pred, y_true)

        # Sort by id for aligned computation
        y_true = y_true.sort_values(by=y_true.columns[0]).reset_index(drop=True)
        y_pred = y_pred.sort_values(by=y_pred.columns[0]).reset_index(drop=True)

        scores = []
        for true_cell, pred_cell in zip(y_true[self.value], y_pred[self.value]):
            # Coerce with str() exactly as validate_submission does, so a cell
            # that passed validation (e.g. a bare integer) also parses here.
            true_labels = _parse_labels_str(str(true_cell))
            pred_labels = _parse_labels_str(str(pred_cell))
            scores.append(_iou_per_shape(pred_labels, true_labels))
        return float(sum(scores) / len(scores)) if scores else 0.0

    def validate_submission(self, submission: Any, ground_truth: Any) -> str:
        """Check that ``submission`` is structurally comparable to ``ground_truth``.

        Verifies, in order: both arguments are DataFrames, both have at least
        two columns including the labels column, the submission's id column
        name exists in the ground truth, row counts match, the sorted id
        values match, and every row has as many predicted labels as
        ground-truth labels.

        Args:
            submission: The submitted predictions.
            ground_truth: The reference answers.

        Returns:
            ``"Submission is valid."`` when every check passes.

        Raises:
            InvalidSubmissionError: On the first check that fails.
        """
        # Type checks
        if not isinstance(submission, pd.DataFrame):
            raise InvalidSubmissionError(
                "Submission must be a pandas DataFrame. Please provide a valid pandas DataFrame."
            )
        if not isinstance(ground_truth, pd.DataFrame):
            raise InvalidSubmissionError(
                "Ground truth must be a pandas DataFrame. Please provide a valid pandas DataFrame."
            )

        # Column presence
        for df, name in [(submission, "submission"), (ground_truth, "ground truth")]:
            if df.shape[1] < 2:
                raise InvalidSubmissionError(f"{name.title()} must have at least two columns: id and {self.value}.")
            if self.value not in df.columns:
                raise InvalidSubmissionError(
                    f"{name.title()} must contain columns: id and {self.value}. Found: {list(df.columns)}"
                )

        # The submission's id column must be a column the ground truth knows.
        # Reported against the submission, since that is the frame at fault.
        id_col = submission.columns[0]
        if id_col not in ground_truth.columns:
            raise InvalidSubmissionError(
                f"Submission id column {id_col!r} (its first column) is not present in the ground truth. "
                f"Ground truth columns: {list(ground_truth.columns)}"
            )

        # Row count check
        if len(submission) != len(ground_truth):
            raise InvalidSubmissionError(
                f"Number of rows in submission ({len(submission)}) does not match ground truth ({len(ground_truth)})."
            )

        # Sort for comparison
        submission_sorted = submission.sort_values(by=submission.columns[0]).reset_index(drop=True)
        ground_sorted = ground_truth.sort_values(by=ground_truth.columns[0]).reset_index(drop=True)

        # IDs must match exactly
        sub_ids = submission_sorted.iloc[:, 0].astype(str).tolist()
        true_ids = ground_sorted.iloc[:, 0].astype(str).tolist()
        if sub_ids != true_ids:
            raise InvalidSubmissionError("First column id values do not match the ground truth ids exactly.")

        # Check labels are parseable and lengths are consistent with ground truth
        rows = zip(true_ids, submission_sorted[self.value], ground_sorted[self.value])
        for rid, pred_cell, true_cell in rows:
            try:
                pred_list = _parse_labels_str(str(pred_cell))
            except (ValueError, OverflowError) as exc:
                # Surface parse failures as a submission error tied to the row.
                raise InvalidSubmissionError(
                    f"For id {rid}, predicted labels could not be parsed as integers: {exc}"
                ) from exc
            true_list = _parse_labels_str(str(true_cell))
            if len(pred_list) != len(true_list):
                raise InvalidSubmissionError(
                    f"For id {rid}, number of predicted labels ({len(pred_list)}) does not match ground truth ({len(true_list)})."
                )
        return "Submission is valid."
