from typing import Any, List, Tuple, Dict
import pandas as pd
import numpy as np

from mledojo.metrics.base import CompetitionMetrics, InvalidSubmissionError


class RadRoadDetectionMetrics(CompetitionMetrics):
    """
    Object-detection metric for RAD Road Anomaly Detection.

    - Submission format (public/sample_submission.csv):
        columns = ["image_id", "PredictionString"]
        PredictionString is a whitespace-separated sequence of 6-tuples:
            class score x_min y_min x_max y_max
        Example: "HMV 0.92 10 10 40 40 Pedestrian 0.55 50 60 90 120"

    - Ground truth format (private/test_answer.csv):
        columns = ["image_id", "class", "x_min", "y_min", "x_max", "y_max"]
        Multiple rows per image, one per ground-truth box. Images with no GT
        may be absent from this file entirely.

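    Evaluation:
        COCO-style mAP: for each class, AP (area under the interpolated
        precision-recall curve) is averaged over the IoU thresholds
        [0.50, 0.55, ..., 0.95]; the final score is the mean of those
        per-class values over all classes that appear in the ground truth.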
    """

    def __init__(self, value: str = "PredictionString", higher_is_better: bool = True) -> None:
        """Configure the metric.

        Args:
            value: Name of the submission column that holds the prediction
                strings; read later by ``evaluate`` and
                ``validate_submission``.
            higher_is_better: Forwarded unchanged to the base-class
                constructor.
        """
        super().__init__(higher_is_better)
        self.value = value

    # ----------------------- Utility helpers -----------------------
    @staticmethod
    def _parse_prediction_string(pred_str: Any,
                                 allowed_classes: List[str]) -> List[Tuple[int, float, int, int, int, int]]:
        """Parse a PredictionString into a list of (cls_id, score, x1, y1, x2, y2).
        Unknown classes are ignored. Invalid groups are skipped.
        """
        if pd.isna(pred_str):
            return []
        s = str(pred_str).strip()
        if not s:
            return []
        parts = s.split()
        if len(parts) % 6 != 0:
            raise InvalidSubmissionError(
                "PredictionString must consist of groups of 6 tokens: 'class score x_min y_min x_max y_max'"
            )
        name_to_id = {c: i for i, c in enumerate(allowed_classes)}
        out: List[Tuple[int, float, int, int, int, int]] = []
        for i in range(0, len(parts), 6):
            cname = parts[i]
            if cname not in name_to_id:
                # Unknown class -> ignore this group
                continue
            try:
                score = float(parts[i + 1])
                x1 = int(float(parts[i + 2]))
                y1 = int(float(parts[i + 3]))
                x2 = int(float(parts[i + 4]))
                y2 = int(float(parts[i + 5]))
            except Exception:
                # Malformed numbers -> skip this group
                continue
            # sanitize bbox ordering and non-negativity
            if x2 < x1:
                x1, x2 = x2, x1
            if y2 < y1:
                y1, y2 = y2, y1
            x1 = max(0, x1); y1 = max(0, y1); x2 = max(0, x2); y2 = max(0, y2)
            score = np.clip(score, 0.0, 1.0)
            out.append((name_to_id[cname], float(score), x1, y1, x2, y2))
        return out

    @staticmethod
    def _iou(a: Tuple[int, int, int, int], b: Tuple[int, int, int, int]) -> float:
        ax1, ay1, ax2, ay2 = a
        bx1, by1, bx2, by2 = b
        inter_x1 = max(ax1, bx1)
        inter_y1 = max(ay1, by1)
        inter_x2 = min(ax2, bx2)
        inter_y2 = min(ay2, by2)
        iw = max(0, inter_x2 - inter_x1 + 1)
        ih = max(0, inter_y2 - inter_y1 + 1)
        inter = iw * ih
        area_a = max(0, ax2 - ax1 + 1) * max(0, ay2 - ay1 + 1)
        area_b = max(0, bx2 - bx1 + 1) * max(0, by2 - by1 + 1)
        union = area_a + area_b - inter
        if union <= 0:
            return 0.0
        return inter / union

    @staticmethod
    def _ap_from_pr(recalls: List[float], precisions: List[float]) -> float:
        # Interpolated precision envelope and integration
        mrec = [0.0] + list(recalls) + [1.0]
        mpre = [0.0] + list(precisions) + [0.0]
        # make precision monotonically non-increasing
        for i in range(len(mpre) - 2, -1, -1):
            mpre[i] = max(mpre[i], mpre[i + 1])
        ap = 0.0
        for i in range(1, len(mrec)):
            ap += (mrec[i] - mrec[i - 1]) * mpre[i]
        return ap

    # ----------------------- Required API -----------------------
    def evaluate(self, y_true: Any, y_pred: Any) -> float:
        if not isinstance(y_true, pd.DataFrame):
            raise ValueError("y_true must be a pandas DataFrame (private/test_answer.csv)")
        if not isinstance(y_pred, pd.DataFrame):
            raise ValueError("y_pred must be a pandas DataFrame (public/sample_submission.csv or submissions)")

        required_true_cols = {"image_id", "class", "x_min", "y_min", "x_max", "y_max"}
        if not required_true_cols.issubset(set(y_true.columns)):
            raise ValueError(f"y_true must contain columns {sorted(required_true_cols)}")
        required_pred_cols = {"image_id", self.value}
        if not required_pred_cols.issubset(set(y_pred.columns)):
            raise ValueError(f"y_pred must contain columns ['image_id', '{self.value}']")

        # allowed classes come from ground truth
        classes = sorted([str(c) for c in y_true["class"].unique()])
        cls_to_id = {c: i for i, c in enumerate(classes)}

        # organize ground-truth boxes per image and class
        gt_by_img_cls: Dict[str, Dict[int, List[Tuple[int, int, int, int]]]] = {}
        npos_per_class: Dict[int, int] = {i: 0 for i in range(len(classes))}
        for _, r in y_true.iterrows():
            iid = str(r["image_id"])  # ensure string
            cname = str(r["class"]) 
            if cname not in cls_to_id:
                # Safety: ignore unknown class rows in GT
                continue
            cid = cls_to_id[cname]
            box = (int(r["x_min"]), int(r["y_min"]), int(r["x_max"]), int(r["y_max"]))
            gt_by_img_cls.setdefault(iid, {}).setdefault(cid, []).append(box)
            npos_per_class[cid] += 1

        # parse predictions per image
        preds_all: List[Tuple[int, str, float, Tuple[int, int, int, int]]] = []  # (cid, iid, score, box)
        for _, r in y_pred.iterrows():
            iid = str(r["image_id"])  # enforce string
            try:
                items = self._parse_prediction_string(r[self.value], classes)
            except InvalidSubmissionError as e:
                # Forward parsing errors during evaluation to be explicit
                raise ValueError(str(e))
            for (cid, score, x1, y1, x2, y2) in items:
                preds_all.append((cid, iid, float(score), (x1, y1, x2, y2)))

        # evaluate AP per class across IoU thresholds
        iou_thresholds = [0.50 + 0.05 * k for k in range(10)]
        ap_per_class: List[float] = []

        # include all classes (even if npos == 0) to average fairly
        for cid in range(len(classes)):
            # gather predictions for this class
            preds_c = [(iid, score, box) for (cc, iid, score, box) in preds_all if cc == cid]
            # sort by confidence desc
            preds_c.sort(key=lambda x: -x[1])

            ap_at_t: List[float] = []
            for thr in iou_thresholds:
                tp = []
                fp = []
                # matched flags per image for GT boxes of this class
                matched: Dict[str, List[bool]] = {}
                for iid in set([iid for (iid, _, _) in preds_c]) | set(gt_by_img_cls.keys()):
                    gts = gt_by_img_cls.get(iid, {}).get(cid, [])
                    matched[iid] = [False] * len(gts)

                npos = npos_per_class.get(cid, 0)

                for iid, score, box in preds_c:
                    gts = gt_by_img_cls.get(iid, {}).get(cid, [])
                    best_iou = 0.0
                    best_j = -1
                    for j, gbox in enumerate(gts):
                        if matched[iid][j]:
                            continue
                        iou_val = self._iou(box, gbox)
                        if iou_val > best_iou:
                            best_iou = iou_val
                            best_j = j
                    if best_iou >= thr and best_j >= 0:
                        tp.append(1); fp.append(0)
                        matched[iid][best_j] = True
                    else:
                        tp.append(0); fp.append(1)

                # precision-recall curve
                tp_c = 0
                fp_c = 0
                precisions = []
                recalls = []
                for t, f in zip(tp, fp):
                    tp_c += t
                    fp_c += f
                    denom = max(1, tp_c + fp_c)
                    precisions.append(tp_c / denom)
                    recalls.append(0.0 if npos == 0 else tp_c / npos)
                ap = self._ap_from_pr(recalls, precisions)
                ap_at_t.append(ap)

            ap_cls = float(np.mean(ap_at_t)) if len(ap_at_t) else 0.0
            ap_per_class.append(ap_cls)

        if not ap_per_class:
            return 0.0
        return float(np.mean(ap_per_class))

    def validate_submission(self, submission: Any, ground_truth: Any) -> str:
        if not isinstance(submission, pd.DataFrame):
            raise InvalidSubmissionError("Submission must be a pandas DataFrame.")
        if not isinstance(ground_truth, pd.DataFrame):
            raise InvalidSubmissionError("Ground truth must be a pandas DataFrame.")

        # expected columns
        if "image_id" not in submission.columns or self.value not in submission.columns:
            raise InvalidSubmissionError(
                f"Submission must contain columns ['image_id', '{self.value}']"
            )
        required_true_cols = {"image_id", "class", "x_min", "y_min", "x_max", "y_max"}
        if not required_true_cols.issubset(set(ground_truth.columns)):
            raise InvalidSubmissionError(
                "Ground truth must contain columns ['image_id','class','x_min','y_min','x_max','y_max']"
            )

        # ensure unique rows per image in submission
        sub_ids = [str(i) for i in submission["image_id"].tolist()]
        if len(set(sub_ids)) != len(sub_ids):
            raise InvalidSubmissionError("Duplicate image_id rows found in submission.")

        # ensure all ground-truth image_ids are covered by submission (may allow extra ids for empty-GT images)
        gt_ids = sorted([str(i) for i in ground_truth["image_id"].unique()])
        if not set(gt_ids).issubset(set(sub_ids)):
            raise InvalidSubmissionError(
                "Submission is missing some image_ids that appear in the ground truth."
            )

        # validate prediction strings
        allowed_classes = sorted([str(c) for c in ground_truth["class"].unique()])
        for iid, s in zip(submission["image_id"], submission[self.value]):
            if pd.isna(s) or str(s).strip() == "":
                # empty prediction allowed
                continue
            parts = str(s).split()
            if len(parts) % 6 != 0:
                raise InvalidSubmissionError(
                    f"Invalid PredictionString for image {iid}: must be groups of 6 tokens."
                )
            for i in range(0, len(parts), 6):
                cname = parts[i]
                if cname not in allowed_classes:
                    raise InvalidSubmissionError(
                        f"Invalid class '{cname}' in PredictionString for image {iid}."
                    )
                try:
                    float(parts[i + 1])
                    float(parts[i + 2]); float(parts[i + 3]); float(parts[i + 4]); float(parts[i + 5])
                except Exception:
                    raise InvalidSubmissionError(
                        f"Non-numeric score/coordinates in PredictionString for image {iid}."
                    )
        return "Submission is valid."
