from typing import Any, Dict, List, Tuple

import math
import pandas as pd

from mledojo.metrics.base import CompetitionMetrics, InvalidSubmissionError


class HardhatVestDetectionMetrics(CompetitionMetrics):
    """
    COCO-style mAP(0.50:0.95) for object detection with YOLO-normalized boxes.

    Expected inputs
    - y_true: pd.DataFrame with columns [image_id, class_id, cx, cy, w, h]
    - y_pred: pd.DataFrame with columns [image_id, PredictionString]
      where PredictionString is: "class confidence cx cy w h class confidence cx cy w h ..."
      Coordinates are normalized in [0,1]. Confidence in (0,1].

    Scoring
    - Image ids are compared as strings everywhere (validation and matching).
    - Per class and per IoU threshold, detections are visited in descending
      confidence order and greedily matched to the best-overlapping ground
      truth of the same image/class that is still unmatched.
    - AP uses the monotone precision envelope sampled at 101 recall points;
      the final score averages AP over thresholds, then over the classes
      that have at least one ground-truth box.
    """

    def __init__(self, value: str = "PredictionString", higher_is_better: bool = True):
        """
        Args:
            value: Name of the submission column holding the prediction
                strings. If the column is absent, "PredictionString" is used.
            higher_is_better: Forwarded to the base class.
        """
        super().__init__(higher_is_better)
        self.value = value

    # -----------------------------
    # Public API required by base
    # -----------------------------
    def evaluate(self, y_true: pd.DataFrame, y_pred: pd.DataFrame) -> float:
        """
        Score a submission against the ground truth.

        Args:
            y_true: Ground-truth boxes, one row per box.
            y_pred: Submission, one row per image.

        Returns:
            mAP(0.50:0.95) clamped to [0, 1]; 0.0 when the ground truth is
            empty or the computed value is not finite.

        Raises:
            InvalidSubmissionError: If ``validate_submission`` rejects the inputs.
        """
        # validate_submission already performs the type/column checks.
        ok, msg = self.validate_submission(y_pred, y_true)
        if not ok:
            raise InvalidSubmissionError(msg)

        if len(y_true) == 0:
            return 0.0
        # Class ids are assumed to be 0-based and contiguous up to the max seen.
        num_classes = int(y_true["class_id"].max()) + 1

        gts = self._ground_truth_to_dict(y_true)
        preds = self._submission_to_dict(y_pred, num_classes)

        mean_ap, _ = self._evaluate_map(preds, gts, num_classes)
        if not math.isfinite(mean_ap):
            return 0.0
        return float(max(0.0, min(1.0, mean_ap)))

    def validate_submission(self, submission: Any, ground_truth: Any) -> tuple[bool, str]:
        """
        Check that a submission is structurally valid for this ground truth.

        Verifies types and required columns, that image ids are unique,
        non-empty and not paths, that every ground-truth image id appears in
        the submission, and that every prediction string parses.

        Returns:
            ``(True, "Submission is valid.")`` on success, otherwise
            ``(False, <reason>)``. Validation failures are never raised.
        """
        try:
            if not isinstance(submission, pd.DataFrame):
                raise InvalidSubmissionError("Submission must be a pandas DataFrame.")
            if not isinstance(ground_truth, pd.DataFrame):
                raise InvalidSubmissionError("Ground truth must be a pandas DataFrame.")

            self._check_submission_df(submission)
            self._check_ground_truth_df(ground_truth)

            sub_ids = submission["image_id"].astype(str).tolist()
            sub_id_set = set(sub_ids)
            if len(sub_ids) != len(sub_id_set):
                raise InvalidSubmissionError("Duplicate image_id in submission.")
            if not sub_ids:
                raise InvalidSubmissionError("Submission is empty.")
            if any(("/" in x) or ("\\" in x) for x in sub_ids):
                raise InvalidSubmissionError("image_id must be file name only, not a path.")

            # Ground truth may omit images with zero objects, so only require
            # that every ground-truth id is covered by the submission.
            true_ids = set(ground_truth["image_id"].astype(str))
            missing_in_sub = true_ids - sub_id_set
            if missing_in_sub:
                raise InvalidSubmissionError(
                    f"Submission is missing ids present in ground truth: {sorted(missing_in_sub)[:5]}"
                )

            # Parse every row so format errors surface here; the result is discarded.
            num_classes = int(ground_truth["class_id"].max()) + 1 if len(ground_truth) else 1
            self._submission_to_dict(submission, num_classes)

            return True, "Submission is valid."
        except InvalidSubmissionError as e:
            return False, str(e)

    # -----------------------------
    # Internal helpers
    # -----------------------------
    @staticmethod
    def _check_ground_truth_df(df: pd.DataFrame):
        """Raise InvalidSubmissionError unless ``df`` is a DataFrame with all ground-truth columns."""
        required = {"image_id", "class_id", "cx", "cy", "w", "h"}
        if not isinstance(df, pd.DataFrame):
            raise InvalidSubmissionError("Ground truth must be a pandas DataFrame.")
        missing = required - set(df.columns)
        if missing:
            raise InvalidSubmissionError(f"Ground truth missing columns: {sorted(missing)}")

    @staticmethod
    def _check_submission_df(df: pd.DataFrame):
        """Raise InvalidSubmissionError unless ``df`` is a DataFrame with all submission columns."""
        required = {"image_id", "PredictionString"}
        if not isinstance(df, pd.DataFrame):
            raise InvalidSubmissionError("Submission must be a pandas DataFrame.")
        missing = required - set(df.columns)
        if missing:
            raise InvalidSubmissionError(f"Submission missing columns: {sorted(missing)}")

    @staticmethod
    def _parse_prediction_string(s: str, num_classes: int) -> List[Tuple[int, float, float, float, float, float]]:
        """
        Parse one PredictionString into ``(class, conf, cx, cy, w, h)`` tuples.

        ``None``, empty, and missing-value cells ("nan", "<NA>") mean "no
        predictions". Coordinates are clipped to [0, 1]; boxes whose clipped
        width or height is zero are dropped rather than rejected.

        Raises:
            InvalidSubmissionError: If the token count is not a multiple of 6,
                a token is not numeric, the class id is not an integer in
                ``[0, max(1, num_classes))``, the confidence is outside
                (0, 1], or a coordinate is NaN/infinite.
        """
        if s is None:
            return []
        text = str(s).strip()
        # str(float("nan")) == "nan" and str(pd.NA) == "<NA>".
        if not text or text.lower() in ("nan", "<na>"):
            return []
        tokens = text.split()
        if len(tokens) % 6 != 0:
            raise InvalidSubmissionError(
                "PredictionString length must be a multiple of 6: class conf cx cy w h ..."
            )

        class_limit = max(1, num_classes)
        out: List[Tuple[int, float, float, float, float, float]] = []
        for start in range(0, len(tokens), 6):
            try:
                cls_raw, conf, cx, cy, w, h = (float(tok) for tok in tokens[start:start + 6])
                # int() raises ValueError on NaN and OverflowError on infinity.
                cls = int(cls_raw)
            except (ValueError, OverflowError) as exc:
                raise InvalidSubmissionError("Malformed numeric values in PredictionString.") from exc

            # "1.0" is accepted as class 1, but "1.7" must not be silently truncated.
            if cls_raw != cls:
                raise InvalidSubmissionError(
                    f"Class id must be an integer in PredictionString: {tokens[start]}"
                )
            if not (0 <= cls < class_limit):
                raise InvalidSubmissionError(f"Class id out of range in PredictionString: {cls}")
            if not (0.0 < conf <= 1.0):
                raise InvalidSubmissionError("Confidence must be in (0,1].")
            # Without this check NaN would be clipped to 0.0 by max()/min() below.
            if not all(math.isfinite(v) for v in (cx, cy, w, h)):
                raise InvalidSubmissionError("Box coordinates in PredictionString must be finite.")

            cx = min(1.0, max(0.0, cx))
            cy = min(1.0, max(0.0, cy))
            w = min(1.0, max(0.0, w))
            h = min(1.0, max(0.0, h))
            if w <= 0.0 or h <= 0.0:
                continue  # degenerate box: ignored, not an error
            out.append((cls, conf, cx, cy, w, h))
        return out

    def _submission_to_dict(
        self, df: pd.DataFrame, num_classes: int
    ) -> Dict[str, List[Tuple[int, float, float, float, float, float]]]:
        """
        Map each submission image id (as ``str``) to its parsed predictions.

        Reads the column named by ``self.value``, falling back to
        "PredictionString"; if neither exists every image gets no predictions.
        """
        if self.value in df.columns:
            pred_strings = df[self.value]
        elif "PredictionString" in df.columns:
            pred_strings = df["PredictionString"]
        else:
            pred_strings = [""] * len(df)

        preds: Dict[str, List[Tuple[int, float, float, float, float, float]]] = {}
        # Keys are str so they line up with validate_submission and the ground-truth dict.
        for img_id, pred_str in zip(df["image_id"].astype(str), pred_strings):
            preds[img_id] = self._parse_prediction_string(pred_str, num_classes)
        return preds

    @staticmethod
    def _ground_truth_to_dict(
        df: pd.DataFrame,
    ) -> Dict[str, List[Tuple[int, float, float, float, float]]]:
        """Group ground-truth rows by image id (as ``str``) into ``(class, cx, cy, w, h)`` tuples."""
        gts: Dict[str, List[Tuple[int, float, float, float, float]]] = {}
        columns = (
            df["image_id"].astype(str),
            df["class_id"],
            df["cx"],
            df["cy"],
            df["w"],
            df["h"],
        )
        for img_id, cls, cx, cy, w, h in zip(*columns):
            gts.setdefault(img_id, []).append(
                (int(cls), float(cx), float(cy), float(w), float(h))
            )
        return gts

    # ---------- mAP implementation ----------
    @staticmethod
    def _cxcywh_to_xyxy(cx: float, cy: float, w: float, h: float) -> Tuple[float, float, float, float]:
        """Convert a center/size box to corner form ``(x1, y1, x2, y2)``."""
        half_w = w / 2.0
        half_h = h / 2.0
        return (cx - half_w, cy - half_h, cx + half_w, cy + half_h)

    @staticmethod
    def _iou_xyxy(a: Tuple[float, float, float, float], b: Tuple[float, float, float, float]) -> float:
        """Return the intersection-over-union of two corner-form boxes (0.0 if they do not overlap)."""
        ax1, ay1, ax2, ay2 = a
        bx1, by1, bx2, by2 = b
        iw = max(0.0, min(ax2, bx2) - max(ax1, bx1))
        ih = max(0.0, min(ay2, by2) - max(ay1, by1))
        inter = iw * ih
        if inter <= 0.0:
            return 0.0
        area_a = max(0.0, ax2 - ax1) * max(0.0, ay2 - ay1)
        area_b = max(0.0, bx2 - bx1) * max(0.0, by2 - by1)
        denom = area_a + area_b - inter
        if denom <= 0.0:
            return 0.0
        return inter / denom

    @staticmethod
    def _compute_ap(recalls: List[float], precisions: List[float]) -> float:
        """
        Average precision by 101-point interpolation of the precision envelope.

        Args:
            recalls: Cumulative recall after each detection, in rank order.
            precisions: Cumulative precision after each detection, same order.

        Returns:
            Mean of the interpolated precision at recall 0.00, 0.01, ..., 1.00.
        """
        # Sentinels: recall 1.0 with precision 0.0 covers recall levels never reached.
        mrec = [0.0] + list(recalls) + [1.0]
        mpre = [0.0] + list(precisions) + [0.0]
        # Precision envelope: make precision non-increasing from left to right.
        for i in range(len(mpre) - 2, -1, -1):
            mpre[i] = max(mpre[i], mpre[i + 1])

        # Recall targets only increase, so a single forward cursor finds the
        # first point with recall >= target for every one of the 101 targets.
        ap = 0.0
        cursor = 0
        last = len(mrec)
        for step in range(101):
            target = step / 100.0
            while cursor < last and mrec[cursor] < target:
                cursor += 1
            if cursor < last:
                ap += mpre[cursor]
        return ap / 101.0

    def _evaluate_map(
        self,
        preds: Dict[str, List[Tuple[int, float, float, float, float, float]]],
        gts: Dict[str, List[Tuple[int, float, float, float, float]]],
        num_classes: int,
        iou_thresholds: List[float] | None = None,
    ) -> Tuple[float, Dict[int, float]]:
        """
        Compute mAP over IoU thresholds and the per-class AP breakdown.

        Args:
            preds: image id -> list of ``(class, conf, cx, cy, w, h)``.
            gts: image id -> list of ``(class, cx, cy, w, h)``.
            num_classes: Classes ``0 .. num_classes - 1`` are evaluated.
            iou_thresholds: IoU thresholds to average over; defaults to
                0.50, 0.55, ..., 0.95.

        Returns:
            ``(mAP, ap_per_class)``. Classes without ground truth are excluded
            from both; a class with ground truth but no detections scores 0.0.
            ``(0.0, {})`` when no class has ground truth.
        """
        if iou_thresholds is None:
            iou_thresholds = [0.50 + 0.05 * i for i in range(10)]

        # class -> image id -> ground-truth boxes in corner form
        gt_boxes: Dict[int, Dict[str, List[Tuple[float, float, float, float]]]] = {}
        for img_id, items in gts.items():
            for cls, cx, cy, w, h in items:
                gt_boxes.setdefault(cls, {}).setdefault(img_id, []).append(
                    self._cxcywh_to_xyxy(cx, cy, w, h)
                )

        # class -> detections as (conf, image id, box in corner form)
        dets_by_cls: Dict[int, List[Tuple[float, str, Tuple[float, float, float, float]]]] = {}
        for img_id, items in preds.items():
            for cls, conf, cx, cy, w, h in items:
                dets_by_cls.setdefault(cls, []).append(
                    (conf, img_id, self._cxcywh_to_xyxy(cx, cy, w, h))
                )

        ap_per_class: Dict[int, float] = {}
        for c in range(num_classes):
            gt_c = gt_boxes.get(c, {})
            num_gt = sum(len(boxes) for boxes in gt_c.values())
            if num_gt == 0:
                continue  # classes without ground truth do not enter the mean

            # Stable sort: ties keep submission order.
            dets = sorted(dets_by_cls.get(c, []), key=lambda d: d[0], reverse=True)
            if not dets:
                ap_per_class[c] = 0.0
                continue

            # IoUs do not depend on the threshold, so compute them once per class.
            det_ious: List[Tuple[str, List[float]]] = [
                (img_id, [self._iou_xyxy(box, gbox) for gbox in gt_c.get(img_id, ())])
                for _, img_id, box in dets
            ]

            ap_values: List[float] = []
            for thr in iou_thresholds:
                # Fresh matched flags for every threshold.
                matched: Dict[str, List[bool]] = {
                    img_id: [False] * len(boxes) for img_id, boxes in gt_c.items()
                }
                true_pos = 0
                recalls: List[float] = []
                precisions: List[float] = []
                for rank, (img_id, ious) in enumerate(det_ious, start=1):
                    # `ious` is non-empty only when the image has ground truth
                    # for this class, in which case `matched[img_id]` exists.
                    flags = matched.get(img_id, ())
                    best_iou = 0.0
                    best_j = -1
                    for j, iou in enumerate(ious):
                        # Skip ground truths already claimed by a higher-confidence detection.
                        if not flags[j] and iou > best_iou:
                            best_iou = iou
                            best_j = j
                    if best_j >= 0 and best_iou >= thr:
                        flags[best_j] = True
                        true_pos += 1
                    # Every detection is either a TP or an FP, so TP + FP == rank.
                    recalls.append(true_pos / num_gt)
                    precisions.append(true_pos / rank)
                ap_values.append(self._compute_ap(recalls, precisions))

            ap_per_class[c] = sum(ap_values) / len(ap_values) if ap_values else 0.0

        if not ap_per_class:
            return 0.0, {}
        mean_ap = sum(ap_per_class.values()) / len(ap_per_class)
        return mean_ap, ap_per_class