import logging
import re
import math
import numpy as np
import torch
import torch.nn.functional as F
from pytorch3d.ops.iou_box3d import _box_planes, _box_triangles
from pytorch3d import _C

from datasets import Dataset

# Module logger, looked up under the name "lmms-eval".
eval_logger = logging.getLogger("lmms-eval")

# Metric names reported by this task. refbbox3d_rec_process_result emits one
# "refbbox3d_<metric>" entry per name, and refbbox3d_rec_aggregation_result
# knows how to score each of them.
REFBBOX3D_REC_METRICS = ["IoU", "ACC@0.1", "ACC@0.3", "ACC@0.5", "ACC@0.7", "ACC@0.9"]


def refbbox3d_rec_preprocess_dataset(dataset: Dataset, start: int = 0, end: int = 100):
    """
    Interleave a mixed dataset round-robin by source, then keep the [start, end) window.

    The source of a sample is taken from the first path-like field that holds a
    ``datasets/<name>/`` component; samples without one are grouped under
    "unknown". Sources are visited in order of first appearance, one sample per
    source per round, until every source is exhausted.

    Args:
        dataset: Original mixed Dataset
        start: Start index (inclusive) in the interleaved order
        end: End index (exclusive) in the interleaved order; None means "up to the end"

    Returns:
        A new Dataset restricted to the window. Every row gains ``image_width`` /
        ``image_height`` and has its ``class_name`` field moved into ``answer``.
        An empty Dataset is returned when the window selects nothing.
    """
    path_fields = ("image_path", "image_path_str", "image_file", "image_id", "file_path")

    def _source_of(record):
        # First usable path decides the source name.
        for field in path_fields:
            path = record.get(field)
            if not isinstance(path, str) or not path:
                continue
            hit = re.search(r"datasets/([^/]+)/", path)
            if hit:
                return hit.group(1)
            if path.startswith("datasets/"):
                pieces = path.split("/")
                if len(pieces) > 1:
                    return pieces[1]
        return "unknown"

    def _with_image_size(row):
        # Decode raw bytes when needed and record the image dimensions.
        picture = row["image"]
        if isinstance(picture, bytes):
            from PIL import Image
            from io import BytesIO
            picture = Image.open(BytesIO(picture))
            width, height = picture.width, picture.height
        elif hasattr(picture, "width"):
            width, height = picture.width, picture.height
        else:
            print("Unknown Type:", type(picture))
            width = height = -1
        return {**row, "image": picture, "image_width": width, "image_height": height}

    # Group row indices per source, keeping first-seen source order.
    groups = {}
    for row_idx, record in enumerate(dataset):
        groups.setdefault(_source_of(record), []).append(row_idx)

    # Round k contributes the k-th index of every source that still has one.
    buckets = list(groups.values())
    rounds = max((len(bucket) for bucket in buckets), default=0)
    interleaved = [
        bucket[k]
        for k in range(rounds)
        for bucket in buckets
        if k < len(bucket)
    ]

    # Clamp the requested window into [0, total] with lo <= hi.
    total = len(interleaved)
    lo = min(max(0, int(start)), total)
    hi = total if end is None else max(0, int(end))
    hi = min(max(hi, lo), total)
    chosen = interleaved[lo:hi]

    if not chosen:
        return Dataset.from_list([])

    dataset = dataset.select(chosen).map(_with_image_size)

    # Move "class_name" into a leading "answer" field, one output row per input row.
    rows = []
    for example in dataset:
        answer = example.pop("class_name")
        rows.append({"answer": answer, **example})

    new_dataset = Dataset.from_list(rows)
    print(f"Exploded dataset from {len(dataset)} to {len(new_dataset)} rows")

    return new_dataset


def refbbox3d_rec_doc_to_visual(doc):
    """
    Return the document image as a one-element list holding an RGB image.

    Args:
        doc: a dataset row whose "image" field is an image object exposing
            ``convert``, raw encoded bytes, or a dict carrying those bytes
            under "bytes".

    Returns:
        ``[image]`` where ``image`` is the result of ``convert("RGB")``.
    """
    image = doc["image"]
    if isinstance(image, dict) and "bytes" in image:
        image = image["bytes"]
    if isinstance(image, bytes):
        from io import BytesIO

        from PIL import Image

        image = Image.open(BytesIO(image))
    # Convert exactly once; a second convert() would only build a redundant copy.
    return [image.convert("RGB")]


def refbbox3d_rec_doc_to_text(doc):
    """Build the prompt asking for the 9-DoF box of the object named in doc["answer"]."""
    target = doc["answer"]
    assert isinstance(target, str), "Answer must be a string"
    prompt = "Please output the 9-DoF 3D bounding box parameters (x, y, z, w, h, l, r1, r2, r3, all normalized to 0-999) for: "
    return f"{prompt}{target}"


def parse_float_sequence_within(input_str):
    """
    Extract the first sequence of nine numbers from a string.

    A bracketed, comma-separated list of exactly nine numbers is preferred.
    If none is present, the first nine numbers found anywhere in the text are
    used instead.

    Args:
        input_str (str): Text that may contain the nine values. Anything that
            is not a string (e.g. None) is treated as "nothing found".

    Returns:
        list: Nine floats if found, otherwise nine ``0.0`` values.
    """
    if not isinstance(input_str, str):
        # A missing or non-text prediction is treated as "no box", not a crash.
        return [0.0] * 9

    number = r"-?\d+(?:\.\d+)?"
    # "[ n, n, ..., n ]" with nine captured numbers.
    bracketed = r"\[\s*" + r",\s*".join([f"({number})"] * 9) + r"\s*\]"
    match = re.search(bracketed, input_str)
    if match:
        return [float(group) for group in match.groups()]

    loose = re.findall(number, input_str)
    if len(loose) >= 9:
        return [float(token) for token in loose[:9]]

    return [0.0] * 9


def _euler_to_matrix_xyz(rx: float, ry: float, rz: float) -> np.ndarray:
    """Build the rotation matrix ``Rz @ Ry @ Rx`` from per-axis angles.

    Args:
        rx, ry, rz: rotations in radians around X, Y, Z axes respectively.
    Returns:
        3x3 numpy rotation matrix.
    """

    def _axis_rotation(angle, i, j):
        # Identity with a 2D rotation written into the (i, j) coordinate plane.
        c, s = math.cos(angle), math.sin(angle)
        m = np.eye(3, dtype=float)
        m[i, i] = c
        m[i, j] = -s
        m[j, i] = s
        m[j, j] = c
        return m

    about_x = _axis_rotation(rx, 1, 2)
    about_y = _axis_rotation(ry, 2, 0)
    about_z = _axis_rotation(rz, 0, 1)
    return about_z @ about_y @ about_x


def _bbox3d_complete_to_corners(box3d: list) -> np.ndarray:
    """Convert bbox3d_complete [x,y,z,w,h,l,r1,r2,r3] to 8 corner points (8,3).

    Coordinate convention: length l along X, height h along Y, width w along Z.
    The box is built around the origin, rotated with
    ``_euler_to_matrix_xyz(r1, r2, r3)`` and then translated to (x, y, z).

    Args:
        box3d: sequence with at least nine entries; only the first nine are
            used, so longer inputs are accepted.

    Returns:
        (8, 3) float array of corners, or None when the input is missing, too
        short, non-numeric, non-finite, or has a non-positive dimension.
    """
    if box3d is None or len(box3d) < 9:
        return None

    # Take exactly nine values: the guard above admits longer sequences, and a
    # plain tuple-unpack of those would raise. float() also rejects None/text.
    try:
        x, y, z, w, h, l, r1, r2, r3 = (float(v) for v in box3d[:9])
    except (TypeError, ValueError):
        print(f"WARNING: Invalid box3d parameters: {box3d}")
        return None

    if not all(np.isfinite([x, y, z, w, h, l, r1, r2, r3])):
        print(f"WARNING: Invalid box3d parameters: {box3d}")
        return None

    if w <= 0 or h <= 0 or l <= 0:
        print(f"WARNING: Invalid box dimensions: w={w}, h={h}, l={l}")
        return None

    hx, hy, hz = l / 2.0, h / 2.0, w / 2.0
    # Corner order: the four corners at -hz, then the four at +hz.
    local = np.array(
        [
            [-hx, -hy, -hz],
            [hx, -hy, -hz],
            [hx, hy, -hz],
            [-hx, hy, -hz],
            [-hx, -hy, hz],
            [hx, -hy, hz],
            [hx, hy, hz],
            [-hx, hy, hz],
        ],
        dtype=float,
    )
    R = _euler_to_matrix_xyz(r1, r2, r3)

    if not np.all(np.isfinite(R)):
        print(f"WARNING: Invalid rotation matrix: {R}")
        return None

    # Row-vector form of "R @ corner + center".
    world = (local @ R.T) + np.array([x, y, z], dtype=float)

    if not np.all(np.isfinite(world)):
        print(f"WARNING: Invalid corner points: {world}")
        return None

    return world


def _refbbox3d_intrinsics(doc, img_w, img_h):
    """Return (fx, fy, cx, cy) from doc["intrinsics"], falling back to image-size defaults."""
    intr = doc.get("intrinsics") or {}

    def _pick(key, fallback):
        # A key that is present but None counts as missing.
        value = intr.get(key)
        return float(fallback if value is None else value)

    # NOTE(review): both focal-length defaults use the image height; confirm
    # this is intended rather than width for fx.
    return (
        _pick("fx", img_h),
        _pick("fy", img_h),
        _pick("cx", img_w / 2.0),
        _pick("cy", img_h / 2.0),
    )


def _refbbox3d_backproject(u, v, depth, intrinsics):
    """Lift pixel (u, v) at the given depth to camera-frame (x, y) with a pinhole model."""
    fx, fy, cx, cy = intrinsics
    x3d = depth * (u - cx) / max(fx, 1e-6)
    y3d = depth * (v - cy) / max(fy, 1e-6)
    return x3d, y3d


def _refbbox3d_dataset_name(doc):
    """Guess the source dataset from the first path-like field containing ``datasets/<name>``."""
    for key in ("image_path", "image_path_str", "image_file", "image_id", "file_path"):
        path = doc.get(key)
        if not isinstance(path, str) or not path:
            continue
        m = re.search(r"datasets/([^/]+)/", path)
        if m:
            return m.group(1)
        if path.startswith("datasets/"):
            parts = path.split("/")
            if len(parts) > 1:
                return parts[1]
    return "unknown"


def refbbox3d_rec_process_result(doc, result):
    """
    Turn one model response into the per-sample record used by every metric.

    The nine predicted values are read as 0-999 normalized quantities: the
    first two scale to pixel coordinates, the third to a log-depth in
    [-4, 5] that is exponentiated, the next three to sizes in [0, 15], and the
    last three to angles in [0, 2*pi]. The pixel centre is then back-projected
    with the document intrinsics. The ground-truth box gets the same
    back-projection of its first two entries using its third entry as depth.
    When the image size is unknown, both boxes are used as given.

    Args:
        doc: a instance of the eval dataset
        result: [pred]
    Returns:
        a dictionary with key: metric name, value: the shared per-sample record
    """
    pred = result[0] if len(result) > 0 else ""
    pred_norm = parse_float_sequence_within(pred)
    # An all-zero vector is what the parser returns when nothing was found.
    is_all_zero_pred = (
        isinstance(pred_norm, (list, tuple))
        and len(pred_norm) >= 9
        and all(float(x) == 0.0 for x in pred_norm[:9])
    )

    img_w = doc.get("image_width")
    img_h = doc.get("image_height")
    has_image_size = img_w is not None and img_h is not None

    if not is_all_zero_pred and has_image_size:
        x_pix = pred_norm[0] / 999.0 * img_w
        y_pix = pred_norm[1] / 999.0 * img_h
        z_min_log, z_max_log = -4.0, 5.0
        log_z = pred_norm[2] / 999.0 * (z_max_log - z_min_log) + z_min_log
        z_m = float(np.exp(log_z))
        x3d, y3d = _refbbox3d_backproject(
            x_pix, y_pix, z_m, _refbbox3d_intrinsics(doc, img_w, img_h)
        )
        width_m = pred_norm[3] / 999.0 * 15.0
        height_m = pred_norm[4] / 999.0 * 15.0
        length_m = pred_norm[5] / 999.0 * 15.0
        two_pi = 2 * math.pi
        r1 = pred_norm[6] / 999.0 * two_pi
        r2 = pred_norm[7] / 999.0 * two_pi
        r3 = pred_norm[8] / 999.0 * two_pi
        pred_complete = [x3d, y3d, z_m, width_m, height_m, length_m, r1, r2, r3]
    else:
        pred_complete = None if is_all_zero_pred else pred_norm

    ann_id = doc["question_id"]

    gt_complete_in = doc.get("bbox3d_complete")
    gt_complete = None
    if isinstance(gt_complete_in, (list, tuple)) and len(gt_complete_in) >= 9 and has_image_size:
        gx_pix, gy_pix, gz = gt_complete_in[0], gt_complete_in[1], gt_complete_in[2]
        gx3d, gy3d = _refbbox3d_backproject(
            gx_pix, gy_pix, gz, _refbbox3d_intrinsics(doc, img_w, img_h)
        )
        gt_complete = [gx3d, gy3d, gz, *gt_complete_in[3:9]]

    gt_box = gt_complete if gt_complete is not None else gt_complete_in
    pred_corners = _bbox3d_complete_to_corners(pred_complete)
    gt_corners = _bbox3d_complete_to_corners(gt_box)

    data_dict = {
        "answer": doc["answer"],
        "pred": pred_complete,
        "pred_normalized": pred_norm,
        "ann_id": ann_id,
        # Reuse the values fetched with .get() above so optional fields stay
        # optional instead of raising KeyError here.
        "bbox3d_complete": gt_box,
        "image_width": img_w,
        "image_height": img_h,
        "pred_corners": pred_corners.tolist() if isinstance(pred_corners, np.ndarray) else None,
        "gt_corners": gt_corners.tolist() if isinstance(gt_corners, np.ndarray) else None,
        "dataset": _refbbox3d_dataset_name(doc),
    }
    return {f"refbbox3d_{metric}": data_dict for metric in REFBBOX3D_REC_METRICS}


def _check_coplanar(boxes: torch.Tensor, eps: float = 1e-4) -> torch.BoolTensor:
    """Flag boxes whose face vertices pass the coplanarity test. Returns (B,) bool tensor.

    For every face listed in ``_box_planes`` a unit normal is built from the
    first three vertices and dotted with the offset of the fourth vertex. The
    batched matrix product adds those per-face dot products into a single
    value per box, and a box passes when the magnitude of that sum is below ``eps``.
    """
    plane_index = torch.tensor(_box_planes, dtype=torch.int64, device=boxes.device)
    num_planes, verts_per_plane = plane_index.shape
    batch = boxes.shape[0]

    gathered = boxes.index_select(index=plane_index.view(-1), dim=1)
    p0, p1, p2, p3 = gathered.reshape(batch, num_planes, verts_per_plane, 3).unbind(2)

    # Unit normal of each face from two normalized edges leaving p0.
    edge_a = F.normalize(p1 - p0, dim=-1)
    edge_b = F.normalize(p2 - p0, dim=-1)
    face_normal = F.normalize(torch.cross(edge_a, edge_b, dim=-1), dim=-1)

    # (B, 1, P*3) @ (B, P*3, 1) -> (B, 1, 1)
    offsets = (p3 - p0).view(batch, 1, -1)
    normals_column = face_normal.view(batch, -1, 1)
    residual = offsets.bmm(normals_column).abs()
    return (residual < eps).view(batch)


def _check_nonzero(boxes: torch.Tensor, eps: float = 1e-8) -> torch.BoolTensor:
    """Flag boxes whose triangles (from ``_box_triangles``) all have area above ``eps``. Returns (B,) bool tensor."""
    tri_index = torch.tensor(_box_triangles, dtype=torch.int64, device=boxes.device)
    num_tris, verts_per_tri = tri_index.shape
    batch = boxes.shape[0]

    gathered = boxes.index_select(index=tri_index.view(-1), dim=1)
    a, b, c = gathered.reshape(batch, num_tris, verts_per_tri, 3).unbind(2)

    # Triangle area is half the norm of the cross product of two edges.
    cross = torch.cross(b - a, c - a, dim=-1)
    areas = cross.norm(dim=-1) / 2
    every_face_ok = (areas > eps).all(1)
    return every_face_ok.view(batch)


def _box3d_overlap_omni(boxes_dt: torch.Tensor, boxes_gt: torch.Tensor, eps_coplanar: float = 1e-4, eps_nonzero: float = 1e-8) -> torch.Tensor:
    """Compute 3D IoU between corner-defined boxes, zeroing rows of invalid ``boxes_dt``.

    Only ``boxes_dt`` is validated (coplanar faces, non-zero triangle areas);
    ``boxes_gt`` is passed through unchecked. The IoU tensor is the second
    element returned by ``_C.iou_box3d``, and its first dimension is indexed by
    ``boxes_dt`` (presumably shape (N, M) for N dt and M gt boxes; confirm
    against pytorch3d).
    """
    not_coplanar = ~_check_coplanar(boxes_dt, eps=eps_coplanar)
    zero_area = ~_check_nonzero(boxes_dt, eps=eps_nonzero)

    ious = _C.iou_box3d(boxes_dt, boxes_gt)[1]

    # Offending dt boxes get zero IoU against every gt box.
    for bad_rows in (not_coplanar, zero_area):
        if bad_rows.any():
            ious[bad_rows] = 0
    return ious


def compute_iou(corners1, corners2):
    """Compute IoU of two 3D boxes from their 8x3 vertices arrays.

    Args:
        corners1, corners2: list / tuple / ndarray of shape (8, 3). ``corners1``
            is the box handed to ``_box3d_overlap_omni`` as ``boxes_dt``.

    Returns:
        IoU as a float in [0, 1]. Returns 0.0 for missing, mis-shaped or
        non-finite input, for a non-finite or negative IoU, and when the
        computation raises. Values above 1.0 are clamped to 1.0.
    """
    if corners1 is None or corners2 is None:
        return 0.0
    array_like = (list, tuple, np.ndarray)
    if not isinstance(corners1, array_like) or not isinstance(corners2, array_like):
        return 0.0
    try:
        corners1 = np.array(corners1)
        corners2 = np.array(corners2)
        if corners1.shape != (8, 3) or corners2.shape != (8, 3):
            return 0.0
        if np.any(~np.isfinite(corners1)) or np.any(~np.isfinite(corners2)):
            return 0.0
        t1 = torch.tensor(corners1, dtype=torch.float32).view(1, 8, 3)
        t2 = torch.tensor(corners2, dtype=torch.float32).view(1, 8, 3)
        with torch.no_grad():
            ious = _box3d_overlap_omni(t1, t2)
            iou_val = float(ious[0, 0].item())
            if not np.isfinite(iou_val) or iou_val > 1.0:
                print(f"DEBUG IoU: corners1 range=[{corners1.min():.3f}, {corners1.max():.3f}]")
                print(f"DEBUG IoU: corners2 range=[{corners2.min():.3f}, {corners2.max():.3f}]")
                print(f"DEBUG IoU: raw iou_val={iou_val}, isfinite={np.isfinite(iou_val)}")
                print(f"DEBUG IoU: ious tensor={ious}")
            if not np.isfinite(iou_val) or iou_val < 0:
                print(f"WARNING: Invalid IoU {iou_val}, returning 0.0")
                return 0.0
            if iou_val > 1.0:
                print(f"WARNING: IoU {iou_val} > 1.0, clamping to 1.0")
                return 1.0
        return iou_val
    except Exception as exc:
        # Scoring stays best-effort (one bad sample must not abort the run),
        # but the failure is reported: a silent 0.0 would hide systematic
        # problems such as a broken backend.
        print(f"WARNING: IoU computation failed ({type(exc).__name__}: {exc}), returning 0.0")
        return 0.0


def compute_accuracy(box1, box2, threshold=0.5):
    """Tell whether the IoU of the two boxes (via compute_iou) reaches ``threshold``."""
    return compute_iou(box1, box2) >= threshold


def refbbox3d_rec_aggregation_result(results, metric):
    """Aggregate per-sample records into the mean score for ``metric``.

    Args:
        results: iterable of records carrying "gt_corners", "pred_corners" and
            optionally "dataset" (defaults to "unknown").
        metric: "IoU" or one of "ACC@0.1", "ACC@0.3", "ACC@0.5", "ACC@0.7",
            "ACC@0.9".

    Returns:
        Mean score over all samples (0.0 when there are none). A per-dataset
        breakdown is printed as a side effect.

    Raises:
        KeyError: if ``metric`` is not one of the supported names.
    """
    acc_thresholds = {
        "ACC@0.1": 0.1,
        "ACC@0.3": 0.3,
        "ACC@0.5": 0.5,
        "ACC@0.7": 0.7,
        "ACC@0.9": 0.9,
    }
    if metric != "IoU" and metric not in acc_thresholds:
        # Fail up front with the supported names instead of a bare KeyError
        # raised only once the first sample is scored.
        supported = ["IoU", *acc_thresholds]
        raise KeyError(f"Unsupported metric {metric!r}; expected one of {supported}")

    def _score(gt_corners, pred_corners):
        # NOTE(review): ground truth is passed first, so the validity checks
        # that _box3d_overlap_omni applies to its first argument run on the GT
        # box rather than the prediction; confirm this is intended.
        if metric == "IoU":
            return compute_iou(gt_corners, pred_corners)
        return compute_accuracy(gt_corners, pred_corners, acc_thresholds[metric])

    overall_scores = []
    per_dataset_scores = {}

    for result in results:
        dataset_name = result.get("dataset", "unknown")
        score = _score(result["gt_corners"], result["pred_corners"])
        if np.isfinite(score) and score <= 1.0:
            overall_scores.append(score)
            per_dataset_scores.setdefault(dataset_name, []).append(score)
        else:
            print(f"WARNING: Invalid score {score} for {dataset_name}, skipping")

    overall = sum(overall_scores) / len(overall_scores) if overall_scores else 0.0

    # Every per-dataset list has at least one entry, so the mean is always defined.
    per_ds_str = ", ".join(
        f"{ds}: {sum(arr) / len(arr):.4f} (n={len(arr)})"
        for ds, arr in per_dataset_scores.items()
    )
    print(f"Aggregated {metric} overall: {overall:.4f}; per-dataset => {per_ds_str}")

    return overall


def refbbox3d_rec_iou(results):
    """Aggregate the per-sample records with the "IoU" metric."""
    return refbbox3d_rec_aggregation_result(results, metric="IoU")


def refbbox3d_rec_acc01(results):
    """Aggregate the per-sample records with the "ACC@0.1" metric."""
    return refbbox3d_rec_aggregation_result(results, metric="ACC@0.1")


def refbbox3d_rec_acc03(results):
    """Aggregate the per-sample records with the "ACC@0.3" metric."""
    return refbbox3d_rec_aggregation_result(results, metric="ACC@0.3")


def refbbox3d_rec_acc05(results):
    """Aggregate the per-sample records with the "ACC@0.5" metric."""
    return refbbox3d_rec_aggregation_result(results, metric="ACC@0.5")


def refbbox3d_rec_acc07(results):
    """Aggregate the per-sample records with the "ACC@0.7" metric."""
    return refbbox3d_rec_aggregation_result(results, metric="ACC@0.7")


def refbbox3d_rec_acc09(results):
    """Aggregate the per-sample records with the "ACC@0.9" metric."""
    return refbbox3d_rec_aggregation_result(results, metric="ACC@0.9")


def refbbox3d_rec_center_acc(results):
    """Aggregate the per-sample records with the "Center_ACC" metric."""
    # NOTE(review): "Center_ACC" is not among the metric names that
    # refbbox3d_rec_aggregation_result can score in this file (only "IoU" and
    # the "ACC@..." thresholds), and it is absent from REFBBOX3D_REC_METRICS.
    # As written, this call raises KeyError for any non-empty `results`.
    # Either implement a centre-accuracy scorer or drop this entry point.
    return refbbox3d_rec_aggregation_result(results, "Center_ACC")
