from __future__ import annotations

import csv
import math
import tempfile
from pathlib import Path
from typing import Any, Dict, List, Tuple

import pandas as pd

from mledojo.metrics.base import CompetitionMetrics, InvalidSubmissionError


# Detection classes scored in this competition. Ground-truth rows and predicted
# boxes whose class label is not in this list are ignored when parsing.
CLASSES: List[str] = ["rach", "mop_lom", "tray_son", "mat_bo_phan"]
# Set form of CLASSES for constant-time membership checks.
CLASS_SET = {*CLASSES}


# ------------------------------
# Geometry and parsing utilities
# ------------------------------

def _iou(b1: Tuple[float, float, float, float], b2: Tuple[float, float, float, float]) -> float:
    x1 = max(b1[0], b2[0])
    y1 = max(b1[1], b2[1])
    x2 = min(b1[2], b2[2])
    y2 = min(b1[3], b2[3])
    inter_w = max(0.0, x2 - x1)
    inter_h = max(0.0, y2 - y1)
    inter = inter_w * inter_h
    if inter <= 0:
        return 0.0
    a1 = max(0.0, (b1[2] - b1[0])) * max(0.0, (b1[3] - b1[1]))
    a2 = max(0.0, (b2[2] - b2[0])) * max(0.0, (b2[3] - b2[1]))
    denom = a1 + a2 - inter
    if denom <= 0:
        return 0.0
    return float(inter / denom)


def _parse_prediction_string(pred_str: Any) -> List[Tuple[str, float, float, float, float, float]]:
    """Turn one ``PredictionString`` cell into ``(class, score, x1, y1, x2, y2)`` tuples.

    The cell is a whitespace-separated run of six-token groups
    ``class score x_min y_min x_max y_max``. Parsing is lenient:

    * NaN / non-string cells are treated as text (NaN as empty);
    * a trailing incomplete group is ignored;
    * a group is dropped if a numeric field does not parse, its class is not in
      CLASS_SET, its score is NaN/inf, or its box has non-positive width/height;
    * scores are clamped to [0, 1].
    """
    if isinstance(pred_str, str):
        text = pred_str
    else:
        # csv/pandas readers can yield float('nan') for an empty cell, and NaN is truthy.
        try:
            is_nan = isinstance(pred_str, float) and math.isnan(pred_str)
            text = "" if is_nan else str(pred_str)
        except Exception:
            text = ""

    tokens = text.split()
    if not tokens:
        return []

    usable = len(tokens) // 6 * 6  # drop any incomplete trailing group
    detections: List[Tuple[str, float, float, float, float, float]] = []
    for start in range(0, usable, 6):
        label = tokens[start]
        try:
            score, x_min, y_min, x_max, y_max = map(float, tokens[start + 1:start + 6])
        except ValueError:
            continue
        if label not in CLASS_SET or not math.isfinite(score):
            continue
        # Positive form of the size check so NaN coordinates are rejected too.
        if x_max > x_min and y_max > y_min:
            detections.append((label, max(0.0, min(1.0, score)), x_min, y_min, x_max, y_max))
    return detections


def _average_precision(recalls: List[float], precisions: List[float]) -> float:
    if not recalls:
        return 0.0
    paired = sorted(zip(recalls, precisions))
    recalls_sorted = [p[0] for p in paired]
    precisions_sorted = [p[1] for p in paired]
    for i in range(len(precisions_sorted) - 2, -1, -1):
        precisions_sorted[i] = max(precisions_sorted[i], precisions_sorted[i + 1])
    ap = 0.0
    prev_r = 0.0
    for r, p in zip(recalls_sorted, precisions_sorted):
        r_clamped = max(0.0, min(1.0, r))
        dr = r_clamped - prev_r
        if dr > 0:
            ap += p * dr
            prev_r = r_clamped
    return float(max(0.0, min(1.0, ap)))


def _compute_map50_from_csv(gt_csv: str, sub_csv: str) -> float:
    """Compute mean Average Precision at IoU >= 0.5 from two CSV files.

    Only classes in CLASSES with at least one ground-truth box contribute to the
    mean. Within each image, predictions are matched greedily in descending
    score order: each is assigned to the not-yet-matched ground-truth box of the
    same class with the highest IoU, and counts as a true positive when that IoU
    is >= 0.5. Predictions for images without ground truth are false positives.

    Args:
        gt_csv: path to the ground-truth CSV
            (``image_id,class,x_min,y_min,x_max,y_max``).
        sub_csv: path to the submission CSV (``image_id,PredictionString``).

    Returns:
        The mean AP in [0, 1], or 0.0 when no class has ground-truth boxes.

    Raises:
        ValueError: if the ground-truth header is not the expected one.
        InvalidSubmissionError: if the submission header is not the expected one.
    """
    gt_by_img_cls = _load_gt_boxes(gt_csv)
    preds_by_img = _load_predictions(sub_csv)

    aps: List[float] = []
    for cls in CLASSES:
        ap = _class_ap50(cls, gt_by_img_cls, preds_by_img)
        if ap is not None:  # None: class has no ground truth, excluded from the mean
            aps.append(ap)

    if not aps:
        return 0.0
    return float(max(0.0, min(1.0, sum(aps) / len(aps))))


def _load_gt_boxes(gt_csv: str) -> Dict[Tuple[str, str], List[Tuple[float, float, float, float]]]:
    """Read the ground-truth CSV and group its valid boxes by ``(image_id, class)``.

    Rows with missing/non-numeric coordinates, a class outside CLASS_SET, or a
    box with non-positive width/height are skipped.

    Raises:
        ValueError: if the header is not ``image_id,class,x_min,y_min,x_max,y_max``.
    """
    gt_by_img_cls: Dict[Tuple[str, str], List[Tuple[float, float, float, float]]] = {}
    # utf-8-sig drops a leading BOM (common in spreadsheet exports) that would
    # otherwise corrupt the first column name; newline="" is what csv expects.
    with open(gt_csv, "r", encoding="utf-8-sig", newline="") as f:
        reader = csv.DictReader(f)
        # Explicit raise rather than `assert`, which is stripped under `python -O`.
        if reader.fieldnames != ["image_id", "class", "x_min", "y_min", "x_max", "y_max"]:
            raise ValueError(f"Unexpected ground truth columns: {reader.fieldnames}")
        for row in reader:
            cls = row["class"]
            try:
                x1 = float(row["x_min"])
                y1 = float(row["y_min"])
                x2 = float(row["x_max"])
                y2 = float(row["y_max"])
            except (TypeError, ValueError):
                # TypeError: DictReader fills fields missing from short rows with None.
                continue
            if cls not in CLASS_SET:
                continue
            if not (x2 > x1 and y2 > y1):
                continue
            gt_by_img_cls.setdefault((row["image_id"], cls), []).append((x1, y1, x2, y2))
    return gt_by_img_cls


def _load_predictions(sub_csv: str) -> Dict[str, List[Tuple[str, float, float, float, float, float]]]:
    """Read the submission CSV into ``image_id -> detections`` (descending score).

    If an image_id appears on several rows, the last row wins.

    Raises:
        InvalidSubmissionError: if the header is not ``image_id,PredictionString``.
    """
    preds_by_img: Dict[str, List[Tuple[str, float, float, float, float, float]]] = {}
    with open(sub_csv, "r", encoding="utf-8-sig", newline="") as f:
        reader = csv.DictReader(f)
        if reader.fieldnames != ["image_id", "PredictionString"]:
            raise InvalidSubmissionError(f"Unexpected submission columns: {reader.fieldnames}")
        for row in reader:
            preds = _parse_prediction_string(row.get("PredictionString", ""))
            preds.sort(key=lambda det: det[1], reverse=True)
            preds_by_img[row["image_id"]] = preds
    return preds_by_img


def _class_ap50(
    cls: str,
    gt_by_img_cls: Dict[Tuple[str, str], List[Tuple[float, float, float, float]]],
    preds_by_img: Dict[str, List[Tuple[str, float, float, float, float, float]]],
    iou_threshold: float = 0.5,
) -> float | None:
    """Average precision of one class, or None if the class has no ground-truth boxes."""
    total_gt = sum(len(boxes) for (_, c), boxes in gt_by_img_cls.items() if c == cls)
    if total_gt == 0:
        return None

    scored_flags: List[Tuple[float, int]] = []
    # Visit images in a fixed (sorted) order: the stable score sort below keeps
    # this order for equal scores across images, and iterating a set of str would
    # make that tie-break -- and hence AP -- vary with PYTHONHASHSEED.
    for iid in sorted(preds_by_img):
        gt_boxes = gt_by_img_cls.get((iid, cls), [])
        matched = [False] * len(gt_boxes)
        for (p_cls, score, px1, py1, px2, py2) in preds_by_img[iid]:
            if p_cls != cls:
                continue
            best_iou = 0.0
            best_j = -1
            for j, g in enumerate(gt_boxes):
                if matched[j]:
                    continue
                cur_iou = _iou((px1, py1, px2, py2), g)
                if cur_iou > best_iou:
                    best_iou = cur_iou
                    best_j = j
            if best_j >= 0 and best_iou >= iou_threshold:
                matched[best_j] = True
                scored_flags.append((score, 1))
            else:
                scored_flags.append((score, 0))

    if not scored_flags:
        return 0.0

    scored_flags.sort(key=lambda item: item[0], reverse=True)
    tp = 0
    fp = 0
    precisions: List[float] = []
    recalls: List[float] = []
    for _, is_tp in scored_flags:
        if is_tp:
            tp += 1
        else:
            fp += 1
        precisions.append(float(max(0.0, min(1.0, tp / (tp + fp)))))
        recalls.append(float(max(0.0, min(1.0, tp / total_gt))))
    return _average_precision(recalls, precisions)


# ------------------------------
# Metric implementation class
# ------------------------------

class VehicleDamageDetectionMetrics(CompetitionMetrics):
    """mAP@0.50 for Vehicle Damage Detection submissions.

    Expected CSV schemas:
    - Ground truth (test_answer.csv): columns [image_id, class, x_min, y_min, x_max, y_max]
    - Submission (sample_submission.csv): columns [image_id, PredictionString]

    A ``PredictionString`` holds zero or more whitespace-separated groups of six
    tokens: ``class score x_min y_min x_max y_max``.
    """

    def __init__(self, value: str = "map50", higher_is_better: bool = True):
        super().__init__(higher_is_better)
        # Metric name; stored for callers, not read within this class.
        self.value = value

    def evaluate(self, y_true: Any, y_pred: Any) -> float:
        """Score a submission against the ground truth with mAP@0.50.

        Args:
            y_true: ground-truth CSV path (str/Path) or DataFrame.
            y_pred: submission CSV path (str/Path) or DataFrame.

        Returns:
            Mean AP over classes that have ground-truth boxes, in [0, 1].

        Raises:
            InvalidSubmissionError: if either argument is neither a path nor a
                DataFrame, or the submission header is wrong.
        """
        # DataFrames are written into a private, auto-deleted temp directory
        # instead of fixed /tmp paths: concurrent evaluations would otherwise
        # overwrite each other's files, and the fixed paths are not portable.
        with tempfile.TemporaryDirectory(prefix="vdd_metric_") as tmp_dir:
            gt_path = self._as_csv_path(y_true, Path(tmp_dir) / "gt.csv")
            sub_path = self._as_csv_path(y_pred, Path(tmp_dir) / "sub.csv")
            if not gt_path or not sub_path:
                raise InvalidSubmissionError("evaluate expects file paths or pandas DataFrames for y_true and y_pred")
            return _compute_map50_from_csv(gt_path, sub_path)

    @staticmethod
    def _as_csv_path(data: Any, spill_path: Path) -> str | None:
        """Return a CSV path for *data*, writing a DataFrame to *spill_path* first.

        Returns None when *data* is neither a DataFrame nor a str/Path.
        """
        if isinstance(data, pd.DataFrame):
            data.to_csv(spill_path, index=False)
            return str(spill_path)
        if isinstance(data, (str, Path)):
            return str(data)
        return None

    @staticmethod
    def _prediction_string_error(cell: Any) -> str | None:
        """Describe the first format problem in a PredictionString cell, or None if valid.

        Missing/empty cells are valid (an image may have no predictions). A
        non-empty cell must consist of complete six-token groups whose last five
        tokens (score and box coordinates) parse as floats.
        """
        if not isinstance(cell, str):
            if cell is None or (pd.api.types.is_scalar(cell) and pd.isna(cell)):
                return None
            cell = str(cell)
        tokens = cell.split()
        if len(tokens) % 6 != 0:
            return (
                "expected groups of 6 tokens (class score x_min y_min x_max y_max), "
                f"got {len(tokens)} tokens"
            )
        for start in range(0, len(tokens), 6):
            for token in tokens[start + 1:start + 6]:
                try:
                    float(token)
                except ValueError:
                    return f"non-numeric value {token!r} in group starting with {tokens[start]!r}"
        return None

    def validate_submission(self, submission: Any, ground_truth: Any) -> tuple[bool, str]:
        """Check that a submission is well-formed before scoring.

        Checks exact column lists for both inputs, that submission image_ids are
        unique, that every ground-truth image_id appears in the submission, and
        that every non-empty PredictionString is made of complete six-token
        groups with numeric score/coordinates.

        Returns:
            ``(True, "Submission is valid.")`` on success, otherwise ``(False, reason)``.
        """
        # Coerce to DataFrame for schema checks
        if isinstance(submission, (str, Path)):
            try:
                submission_df = pd.read_csv(str(submission))
            except Exception as e:
                return False, f"Failed to read submission CSV: {e}"
        elif isinstance(submission, pd.DataFrame):
            submission_df = submission.copy()
        else:
            return False, "Submission must be a file path or pandas DataFrame."

        if isinstance(ground_truth, (str, Path)):
            try:
                gt_df = pd.read_csv(str(ground_truth))
            except Exception as e:
                return False, f"Failed to read ground-truth CSV: {e}"
        elif isinstance(ground_truth, pd.DataFrame):
            gt_df = ground_truth.copy()
        else:
            return False, "Ground truth must be a file path or pandas DataFrame."

        # Expected columns
        expected_sub_cols = ["image_id", "PredictionString"]
        expected_gt_cols = ["image_id", "class", "x_min", "y_min", "x_max", "y_max"]

        if list(submission_df.columns) != expected_sub_cols:
            return False, (
                f"Submission must have columns {expected_sub_cols}, found {submission_df.columns.tolist()}"
            )
        if list(gt_df.columns) != expected_gt_cols:
            return False, (
                f"Ground truth must have columns {expected_gt_cols}, found {gt_df.columns.tolist()}"
            )

        # image_id checks
        sub_ids = submission_df["image_id"].tolist()
        if len(sub_ids) != len(set(sub_ids)):
            return False, "Duplicate image_id entries found in submission. Each image_id must appear exactly once."

        # Ground truth may omit images with zero objects, so only require that all
        # GT image_ids are present in the submission. No sorting: ordering is not
        # needed and sorted() raises TypeError on mixed-type id columns.
        gt_ids = set(gt_df["image_id"].unique().tolist())
        if not gt_ids.issubset(set(sub_ids)):
            return False, "Submission is missing some image_ids present in ground truth."

        # The scoring parser silently drops malformed groups, so check the format
        # explicitly here to give actionable feedback. Empty cells remain valid.
        for image_id, cell in zip(sub_ids, submission_df["PredictionString"].tolist()):
            problem = self._prediction_string_error(cell)
            if problem is not None:
                return False, f"Invalid PredictionString for image_id {image_id!r}: {problem}"

        return True, "Submission is valid."
