from __future__ import annotations

import math
from dataclasses import dataclass
from typing import Any, Dict, List, Tuple

import pandas as pd

from mledojo.metrics.base import CompetitionMetrics, InvalidSubmissionError

# Type aliases shared by the parsing and scoring helpers below.
# Box: corners (x1, y1, x2, y2) in normalized coordinates; the parsers below
# clamp each coordinate to [0, 1] and order corners so x1 <= x2 and y1 <= y2.
Box = Tuple[float, float, float, float]  # (x1, y1, x2, y2) in [0,1]
# Pred: one scored detection as (confidence score, box).
Pred = Tuple[float, Box]  # (score, box)


def _is_finite_number(x: Any) -> bool:
    try:
        return math.isfinite(float(x))
    except Exception:
        return False


def _parse_pred_string(s: str) -> List[Pred]:
    s = (s or "").strip()
    if s == "":
        return []
    parts = s.split()
    if len(parts) % 5 != 0:
        raise InvalidSubmissionError(
            "PredictionString must contain 5*k numbers: score x1 y1 x2 y2 repeated"
        )
    out: List[Pred] = []
    for i in range(0, len(parts), 5):
        score, x1, y1, x2, y2 = parts[i : i + 5]
        # Validate numeric and finite
        for v in (score, x1, y1, x2, y2):
            if not _is_finite_number(v):
                raise InvalidSubmissionError("Non-finite value in PredictionString")
        sc = float(score)
        x1 = float(x1)
        y1 = float(y1)
        x2 = float(x2)
        y2 = float(y2)
        # Clamp to [0,1]
        x1 = max(0.0, min(1.0, x1))
        y1 = max(0.0, min(1.0, y1))
        x2 = max(0.0, min(1.0, x2))
        y2 = max(0.0, min(1.0, y2))
        # Ensure proper ordering
        if x2 < x1:
            x1, x2 = x2, x1
        if y2 < y1:
            y1, y2 = y2, y1
        # Discard degenerate boxes with zero area
        if (x2 - x1) <= 0.0 or (y2 - y1) <= 0.0:
            continue
        # Scores must be in [0, 1]
        if sc < 0.0:
            sc = 0.0
        if sc > 1.0:
            sc = 1.0
        out.append((sc, (x1, y1, x2, y2)))
    # Sort by score desc for stable evaluation
    out.sort(key=lambda t: t[0], reverse=True)
    return out


def _parse_gt_string(s: str) -> List[Box]:
    s = (s or "").strip()
    if s == "":
        return []
    parts = s.split()
    if len(parts) % 4 != 0:
        raise InvalidSubmissionError(
            "Ground truth 'boxes' must contain 4*k numbers: x1 y1 x2 y2 repeated"
        )
    out: List[Box] = []
    for i in range(0, len(parts), 4):
        x1, y1, x2, y2 = parts[i : i + 4]
        for v in (x1, y1, x2, y2):
            if not _is_finite_number(v):
                raise InvalidSubmissionError("Non-finite value in ground truth boxes")
        x1 = max(0.0, min(1.0, float(x1)))
        y1 = max(0.0, min(1.0, float(y1)))
        x2 = max(0.0, min(1.0, float(x2)))
        y2 = max(0.0, min(1.0, float(y2)))
        if x2 < x1:
            x1, x2 = x2, x1
        if y2 < y1:
            y1, y2 = y2, y1
        if (x2 - x1) <= 0.0 or (y2 - y1) <= 0.0:
            continue
        out.append((x1, y1, x2, y2))
    return out


def _iou(a: Box, b: Box) -> float:
    ax1, ay1, ax2, ay2 = a
    bx1, by1, bx2, by2 = b
    ix1 = max(ax1, bx1)
    iy1 = max(ay1, by1)
    ix2 = min(ax2, bx2)
    iy2 = min(ay2, by2)
    iw = max(0.0, ix2 - ix1)
    ih = max(0.0, iy2 - iy1)
    inter = iw * ih
    if inter <= 0.0:
        return 0.0
    area_a = max(0.0, (ax2 - ax1)) * max(0.0, (ay2 - ay1))
    area_b = max(0.0, (bx2 - bx1)) * max(0.0, (by2 - by1))
    denom = area_a + area_b - inter
    if denom <= 0.0:
        return 0.0
    return inter / denom


class Sku110kDetectionMetrics(CompetitionMetrics):
    """mAP@0.50:0.95 evaluator for SKU110K-style object detection.

    Expected input formats:
    - y_true: pandas.DataFrame with columns ['image_id', 'boxes'] where 'boxes' is
      a whitespace-separated list of normalized xyxy coords (x1 y1 x2 y2 repeated).
    - y_pred: pandas.DataFrame with columns ['image_id', 'PredictionString'] where
      'PredictionString' is 'score x1 y1 x2 y2' repeated.

    Missing cells (NaN / None / pd.NA, e.g. an empty CSV field) in 'boxes' or
    'PredictionString' mean "no boxes for this image". ``image_id`` values are
    compared as whitespace-stripped strings.

    Scoring: for each IoU threshold 0.50, 0.55, ..., 0.95, predictions are
    processed in descending score order and each one claims the still-unmatched
    ground-truth box of the same image with the highest IoU; it is a true
    positive when that IoU reaches the threshold. AP is the area under the
    monotone precision envelope of the global precision/recall curve, and the
    returned value is the mean AP over the ten thresholds, clipped to [0, 1].
    """

    # Evaluation thresholds 0.50:0.95 step 0.05
    _IOU_THRESHOLDS: Tuple[float, ...] = tuple(x / 100.0 for x in range(50, 100, 5))

    def __init__(self, value: str = "PredictionString", higher_is_better: bool = True):
        super().__init__(higher_is_better)
        # Stored on the instance; evaluate/validate_submission do not read it
        # and always use the fixed 'PredictionString' column.
        self.value = value

    @staticmethod
    def _cell_text(value: Any) -> str:
        """Return a DataFrame cell as text, mapping missing values (NaN/None/NA) to ""."""
        # Without this, str(NaN) == "nan" would reach the parsers as a bogus token.
        if pd.api.types.is_scalar(value) and pd.isna(value):
            return ""
        return str(value)

    @staticmethod
    def _normalized_ids(df: pd.DataFrame) -> List[str]:
        """Return ``df['image_id']`` as whitespace-stripped strings, in row order."""
        return [str(v).strip() for v in df["image_id"]]

    @staticmethod
    def _match_image(
        pred_boxes: List[Box], gt_boxes: List[Box], thresholds: Tuple[float, ...]
    ) -> List[int]:
        """Greedily match one image's ranked predictions to its ground-truth boxes.

        ``pred_boxes`` must already be in descending score order. For each
        threshold independently, every prediction claims the still-unmatched GT
        box with the highest positive IoU (the earliest one on ties) and is a
        true positive iff that IoU >= the threshold.

        Returns one bitmask per prediction: bit ``t`` is set iff the prediction
        is a true positive at ``thresholds[t]``.
        """
        # IoU depends only on the box pair, so compute it once per image
        # rather than once per threshold.
        iou_rows = [[_iou(p, g) for g in gt_boxes] for p in pred_boxes]
        masks = [0] * len(pred_boxes)
        for t, thr in enumerate(thresholds):
            matched = [False] * len(gt_boxes)
            for k, row in enumerate(iou_rows):
                best_iou, best_j = 0.0, -1
                for j, iou in enumerate(row):
                    if not matched[j] and iou > best_iou:
                        best_iou, best_j = iou, j
                if best_iou >= thr and best_j >= 0:
                    matched[best_j] = True
                    masks[k] |= 1 << t
        return masks

    @staticmethod
    def _average_precision(tp_flags: List[int], total_gt: int) -> float:
        """All-point interpolated AP for a ranked list of TP (1) / FP (0) flags."""
        if not tp_flags or total_gt <= 0:
            return 0.0
        recalls: List[float] = []
        precisions: List[float] = []
        cum_tp = 0
        for rank, tp in enumerate(tp_flags, start=1):
            cum_tp += tp
            recalls.append(cum_tp / total_gt)
            precisions.append(cum_tp / rank)

        # Precision envelope: make precision non-increasing in rank.
        for i in range(len(precisions) - 2, -1, -1):
            if precisions[i] < precisions[i + 1]:
                precisions[i] = precisions[i + 1]

        # Integrate precision over the points where recall increases.
        ap = 0.0
        prev_recall = 0.0
        for r, p in zip(recalls, precisions):
            if r > prev_recall:
                ap += (r - prev_recall) * p
                prev_recall = r
        return ap

    def evaluate(self, y_true: Any, y_pred: Any) -> float:
        """Return mAP@[0.50:0.95] of ``y_pred`` against ``y_true``, in [0, 1].

        Returns 0.0 when there are no ground-truth boxes or no predictions. If
        ``y_true`` repeats an image_id, that image is scored once using its
        last row.

        Raises:
            InvalidSubmissionError: if an input is not a DataFrame, a required
                column is missing, the image_id sets differ, ``y_pred`` repeats
                an image_id, or a box string is malformed.
        """
        if not isinstance(y_true, pd.DataFrame):
            raise InvalidSubmissionError("y_true must be a pandas DataFrame")
        if not isinstance(y_pred, pd.DataFrame):
            raise InvalidSubmissionError("y_pred must be a pandas DataFrame")

        # Validate columns
        required_true_cols = {"image_id", "boxes"}
        required_pred_cols = {"image_id", "PredictionString"}
        missing = required_true_cols - set(y_true.columns)
        if missing:
            raise InvalidSubmissionError(
                f"y_true missing required columns: {', '.join(sorted(missing))}"
            )
        missing = required_pred_cols - set(y_pred.columns)
        if missing:
            raise InvalidSubmissionError(
                f"y_pred missing required columns: {', '.join(sorted(missing))}"
            )

        # Normalize ids once so the set checks and the map lookups agree.
        true_ids = self._normalized_ids(y_true)
        pred_ids = self._normalized_ids(y_pred)
        if set(true_ids) != set(pred_ids):
            raise InvalidSubmissionError("image_id sets do not match between y_true and y_pred")
        if len(set(pred_ids)) != len(pred_ids):
            raise InvalidSubmissionError("Duplicate image_id rows found in y_pred")

        gt_map: Dict[str, List[Box]] = {
            img_id: _parse_gt_string(self._cell_text(cell))
            for img_id, cell in zip(true_ids, y_true["boxes"])
        }
        preds_map: Dict[str, List[Pred]] = {
            img_id: _parse_pred_string(self._cell_text(cell))
            for img_id, cell in zip(pred_ids, y_pred["PredictionString"])
        }

        # Each image is visited once, even if y_true repeats its id.
        test_ids = sorted(set(true_ids))
        total_gt = sum(len(gt_map[img_id]) for img_id in test_ids)
        if total_gt == 0:
            return 0.0

        thresholds = self._IOU_THRESHOLDS
        # One record per prediction: (score, TP bitmask over thresholds).
        # Records are appended in image order, then per-image rank order, so
        # the stable sort below breaks score ties exactly as a single global
        # ranking of all predictions would.
        ranked: List[Tuple[float, int]] = []
        for img_id in test_ids:
            preds = preds_map.get(img_id, [])
            if not preds:
                continue
            masks = self._match_image([box for _, box in preds], gt_map[img_id], thresholds)
            ranked.extend((score, mask) for (score, _), mask in zip(preds, masks))
        if not ranked:
            return 0.0
        ranked.sort(key=lambda rec: rec[0], reverse=True)

        aps = [
            self._average_precision([(mask >> t) & 1 for _, mask in ranked], total_gt)
            for t in range(len(thresholds))
        ]
        m_ap = sum(aps) / len(aps)
        if not math.isfinite(m_ap):
            m_ap = 0.0
        return max(0.0, min(1.0, m_ap))

    def validate_submission(self, submission: Any, ground_truth: Any) -> tuple[bool, str]:
        """Check that ``submission`` can be scored against ``ground_truth``.

        Returns ``(True, "Submission is valid.")`` on success.

        Raises:
            InvalidSubmissionError: on wrong input types, missing columns,
                duplicate or mismatched image_ids, or malformed box strings.
        """
        if not isinstance(submission, pd.DataFrame):
            raise InvalidSubmissionError("Submission must be a pandas DataFrame")
        if not isinstance(ground_truth, pd.DataFrame):
            raise InvalidSubmissionError("Ground truth must be a pandas DataFrame")

        # Columns check
        required_sub = {"image_id", "PredictionString"}
        required_gt = {"image_id", "boxes"}
        missing_sub = required_sub - set(submission.columns)
        missing_gt = required_gt - set(ground_truth.columns)
        if missing_sub:
            raise InvalidSubmissionError(
                f"Missing required columns in submission: {', '.join(sorted(missing_sub))}."
            )
        if missing_gt:
            raise InvalidSubmissionError(
                f"Missing required columns in ground truth: {', '.join(sorted(missing_gt))}."
            )

        # Row counts and id alignment, using the same normalization as evaluate.
        sub_ids = self._normalized_ids(submission)
        gt_ids = self._normalized_ids(ground_truth)
        if len(set(sub_ids)) != len(sub_ids):
            raise InvalidSubmissionError("Duplicate image_id rows in submission")
        if set(sub_ids) != set(gt_ids):
            raise InvalidSubmissionError(
                "Submission image_ids must exactly match ground truth image_ids"
            )

        # Parse every box string so format errors surface here.
        for cell in submission["PredictionString"]:
            _parse_pred_string(self._cell_text(cell))
        for cell in ground_truth["boxes"]:
            _parse_gt_string(self._cell_text(cell))

        return True, "Submission is valid."
