"""Obstruction detection module for finding objects between two points."""
import copy
import hashlib
import json
import os
import random
from functools import partial
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.collections import LineCollection
from shapely.geometry import LineString, Point
from shapely.geometry import Polygon as ShpPolygon
from shapely.ops import unary_union

from src.qa_pairs_generation.utils import (
    _label,
    euclidean_distance,
    generate_qa_pairs_with_subsampling,
    get_polygon_centroid,
    merge_objects_and_openings,
    render_layout_pair,
    save_and_info,
)


# Label-swap ablation: helpers for reassigning labels among a room's objects.


def _stable_seed_from_layout(layout_id: Union[int, str]) -> int:
    """Derive a reproducible 32-bit seed from a layout identifier.

    The identifier is stringified, hashed with MD5, and the digest is
    reduced modulo 2**32, so the same id yields the same seed in every
    process (unlike the built-in ``hash`` on strings).

    Args:
        layout_id: Layout identifier; converted with ``str`` before hashing.

    Returns:
        Integer in the range ``[0, 2**32)``.
    """
    digest_hex = hashlib.md5(str(layout_id).encode()).hexdigest()
    full_value = int(digest_hex, 16)
    return full_value % (2**32)


def swap_object_labels(
    room: dict,
    strategy: str = "rotate",  # "rotate" | "reverse" | "shuffle"
    seed: Optional[int] = None,  # used only when strategy="shuffle"
) -> dict:
    """
    Build a deep copy of ``room`` whose object labels have been permuted.

    The input dictionary is never modified. Objects without a "label" key
    contribute an empty string to the pool of labels being permuted.

    Args:
        room: Room data dictionary; labels are read from ``room["objects"]``.
        strategy: Permutation to apply - "rotate" (shift left by one),
            "reverse", or "shuffle".
        seed: Seed for the "shuffle" strategy; ``None`` behaves like ``0``.

    Returns:
        The deep copy. It is returned unchanged when "objects" is not a
        list or holds fewer than two entries.

    Raises:
        ValueError: If ``strategy`` is not one of the supported names (only
            checked when there are at least two objects).
    """
    result = copy.deepcopy(room)
    entries = result.get("objects")
    if not (isinstance(entries, list) and len(entries) >= 2):
        # Fewer than two labels: no permutation is possible.
        return result

    original = [entry.get("label", "") for entry in entries]
    rotated = original[1:] + original[:1]

    if strategy == "rotate":
        permuted = rotated
    elif strategy == "reverse":
        permuted = original[::-1]
    elif strategy == "shuffle":
        permuted = list(original)
        random.Random(0 if seed is None else seed).shuffle(permuted)
        if permuted == original:
            # The shuffle left everything in place; fall back to a rotation
            # so that the labels are (where possible) actually moved.
            permuted = rotated
    else:
        raise ValueError("strategy must be 'rotate', 'reverse', or 'shuffle'")

    for entry, label in zip(entries, permuted):
        entry["label"] = label
    return result


def _points_to_polygon(points: List[Dict[str, float]]) -> ShpPolygon | None:
    """Build a Shapely polygon from ``{"x": ..., "y": ...}`` point dicts.

    The ring is closed automatically when the first and last vertices
    differ.

    Returns:
        The polygon, or ``None`` when fewer than three points are given or
        the resulting geometry is invalid, empty, or has non-positive area.
    """
    if not points:
        return None
    if len(points) < 3:
        return None

    ring = [(float(pt["x"]), float(pt["y"])) for pt in points]
    if ring[-1] != ring[0]:
        # Repeat the first vertex so the ring is explicitly closed.
        ring.append(ring[0])

    polygon = ShpPolygon(ring)
    usable = (
        polygon.is_valid
        and not polygon.is_empty
        and not polygon.area <= 0
    )
    return polygon if usable else None


def _collect_entities_with_openings(data: Dict) -> List[Dict]:
    """Gather objects and openings into one list of object-like dicts.

    Entries from ``data["objects"]`` are passed through as-is (same dict
    instances). Each window and door under ``data["openings"]`` is turned
    into a fresh ``{"label", "points"}`` dict; an opening without a label
    is labelled with its group key ("windows" or "doors").

    Returns:
        A new list: objects first, then windows, then doors.
    """
    entities: List[Dict] = list(data.get("objects") or [])
    opening_groups = data.get("openings") or {}
    for kind in ("windows", "doors"):
        entities.extend(
            {
                "label": item.get("label") or kind,
                "points": item.get("points") or [],
            }
            for item in opening_groups.get(kind) or []
        )
    return entities


def find_objects_intersecting_line_poly(
    data: Dict,
    start_label: str,
    goal_label: str,
    *,
    include_openings: bool = True,
    ignore_labels: List[str] | None = None,
) -> Tuple[List[Dict], Tuple[float, float]]:
    """
    Find objects intersecting a line between two labeled objects.

    The line runs between the centroids of the start and goal entities; an
    entity counts as "on the way" when its polygon intersects that segment
    in anything other than exactly the segment's own endpoints.

    Labels are compared case-insensitively and with surrounding whitespace
    removed, on both the query side and the entity side. When several
    entities share a normalised label, the last one in the list is used as
    start/goal.

    Args:
        data: Room data dictionary
        start_label: Label of start object
        goal_label: Label of goal object
        include_openings: Whether to include doors/windows
        ignore_labels: Labels to ignore in intersection check (the start
            and goal labels are always ignored)

    Returns:
        Tuple of (objects_on_the_way, (dx, dy)) where objects_on_the_way
        are dicts from entities list and (dx, dy) is the goal centroid minus
        the start centroid. Returns ``([], (0.0, 0.0))`` when either label
        cannot be resolved.
    """
    if include_openings:
        entities = _collect_entities_with_openings(data)
    else:
        entities = data.get("objects") or []

    def _norm(text: str) -> str:
        # Single normalisation used for map keys, lookups, and the ignore
        # set, so a label with stray whitespace resolves consistently.
        return text.strip().lower()

    label_lc_map = {_norm(_label(o)): o for o in entities if _label(o)}

    s_ent = label_lc_map.get(_norm(start_label))
    g_ent = label_lc_map.get(_norm(goal_label))
    if not s_ent or not g_ent:
        return [], (0.0, 0.0)

    s_cent = get_polygon_centroid(s_ent["points"])
    g_cent = get_polygon_centroid(g_ent["points"])
    seg = LineString([s_cent, g_cent])
    dx, dy = g_cent[0] - s_cent[0], g_cent[1] - s_cent[1]

    ignore = {_norm(_label(s_ent)), _norm(_label(g_ent))}
    if ignore_labels:
        ignore.update(_norm(lab) for lab in ignore_labels)

    # Touching the segment only at its own endpoints does not count as an
    # obstruction. Built once: it does not depend on the entity under test.
    allowed_endpoints = (Point(s_cent), Point(g_cent))

    def _is_endpoint(geom) -> bool:
        return any(geom.equals(pt) for pt in allowed_endpoints)

    on_the_way: List[Dict] = []
    for obj in entities:
        lab = _label(obj)
        if not lab or _norm(lab) in ignore:
            continue
        poly = _points_to_polygon(obj.get("points") or [])
        if poly is None:
            continue

        inter = seg.intersection(poly)
        if inter.is_empty:
            continue

        if inter.geom_type == "Point":
            if _is_endpoint(inter):
                continue
        elif inter.geom_type == "MultiPoint":
            if all(_is_endpoint(g) for g in inter.geoms):
                continue

        on_the_way.append(obj)

    return on_the_way, (dx, dy)


def render_direct_line(
    data: Dict,
    start_label: str,
    goal_label: str,
    on_the_way: List[Dict],
    out_path: str,
    *,
    show_grid: bool = True,
    show_axes: bool = True,
    dpi: int = 150,
) -> str:
    """Render a visualization showing the direct line between two objects.

    Draws, in this order: the room boundary (if present), walls, windows,
    doors, object polygons (those in ``on_the_way`` highlighted), and the
    straight line between the start and goal centroids with its endpoint
    markers. The image is written to ``out_path``; missing parent
    directories are created.

    Args:
        data: Room data dictionary (reads "room_boundary", "walls" or
            "room"/"walls", "openings", and "objects").
        start_label: Label of the start entity (case/whitespace-insensitive).
        goal_label: Label of the goal entity (case/whitespace-insensitive).
        on_the_way: Entities to highlight as obstructions.
        out_path: Destination image path.
        show_grid: Whether to draw a grid.
        show_axes: Whether to label the X/Y axes.
        dpi: Resolution used for both the figure and the saved file.

    Returns:
        ``out_path`` unchanged.
    """
    Path(out_path).parent.mkdir(parents=True, exist_ok=True)

    COL_ROOM = "#444444"
    COL_WALL = "#000000"
    COL_WINDOW = "#1f77b4"
    COL_DOOR = "#22aa55"
    COL_OBJ = "#999999"
    COL_OBJ_EDGE = "#666666"
    COL_HIT = "#f39c12"  # highlight objects on the way
    COL_LINE = "#e74c3c"  # line between centroids
    COL_START = "#2ecc71"
    COL_GOAL = "#1f77b4"

    fig, ax = plt.subplots(figsize=(8, 8), dpi=dpi)
    # try/finally guarantees the figure is released even when malformed
    # input raises part-way through; otherwise a batch run accumulates
    # open figures.
    try:
        # room boundary (optional)
        rb = data.get("room_boundary") or []
        if rb:
            arr = np.array([[p["x"], p["y"]] for p in rb], dtype=float)
            ax.plot(
                arr[:, 0], arr[:, 1], "--", color=COL_ROOM,
                linewidth=1.6, alpha=0.9, label="Room boundary"
            )

        # walls
        wall_segs = []
        walls = (
            data.get("walls")
            or (data.get("room", {}) or {}).get("walls")
            or []
        )
        for w in walls:
            s, e = w.get("start"), w.get("end")
            if s and e:
                wall_segs.append([
                    (float(s["x"]), float(s["y"])),
                    (float(e["x"]), float(e["y"]))
                ])
        if wall_segs:
            ax.add_collection(
                LineCollection(
                    wall_segs, linewidths=2.5, colors=[COL_WALL], label="Walls"
                )
            )

        # openings (for context): windows first, then doors
        openings = data.get("openings") or {}
        opening_styles = (
            ("windows", COL_WINDOW, "Window"),
            ("doors", COL_DOOR, "Door"),
        )
        for key, colour, legend_name in opening_styles:
            for ent in openings.get(key) or []:
                pts = ent.get("points") or []
                if len(pts) >= 2:
                    arr = np.array([[p["x"], p["y"]] for p in pts], dtype=float)
                    ax.plot(
                        arr[:, 0], arr[:, 1], "-", color=colour,
                        linewidth=2.0, alpha=0.9, label=legend_name
                    )

        # objects (fill), but highlight on-the-way
        objs = data.get("objects") or []
        hit_set = {(_label(o).strip().lower()) for o in on_the_way}

        def _fill_poly(o, face, edge, alpha=0.35, lw=1.0, z=1):
            # Polygons with fewer than three points cannot be filled.
            pts = o.get("points") or []
            if len(pts) < 3:
                return
            arr = np.array([[p["x"], p["y"]] for p in pts], dtype=float)
            ax.fill(
                arr[:, 0], arr[:, 1], facecolor=face, edgecolor=edge,
                alpha=alpha, linewidth=lw, zorder=z
            )

        for o in objs:
            lab_l = _label(o).strip().lower()
            if lab_l in hit_set:
                _fill_poly(o, COL_HIT, COL_HIT, alpha=0.40, lw=1.4, z=3)
            else:
                _fill_poly(o, COL_OBJ, COL_OBJ_EDGE, alpha=0.28, lw=1.0, z=1)

        # Objects plus doors/windows, gathered once and reused for both the
        # start/goal lookup and the plot bounds.
        entities = _collect_entities_with_openings(data)

        # determine start/goal centroids
        ent_map = {_label(o).strip().lower(): o for o in entities}
        s = ent_map.get(start_label.strip().lower())
        g = ent_map.get(goal_label.strip().lower())
        if s and g:
            s_cent = get_polygon_centroid(s["points"])
            g_cent = get_polygon_centroid(g["points"])
            ax.plot(
                [s_cent[0], g_cent[0]], [s_cent[1], g_cent[1]], "-",
                color=COL_LINE, linewidth=2.6, zorder=4, label="Direct line"
            )
            ax.plot(
                s_cent[0], s_cent[1], "o", color=COL_START,
                markersize=7, zorder=5, label="Start"
            )
            ax.plot(
                g_cent[0], g_cent[1], "*", color=COL_GOAL,
                markersize=11, zorder=5, label="Target"
            )

        # bounds: cover the boundary, the walls, and every entity vertex
        xs = [p["x"] for p in rb]
        ys = [p["y"] for p in rb]
        for (x1, y1), (x2, y2) in wall_segs:
            xs += [x1, x2]
            ys += [y1, y2]
        for o in entities:
            for p in o.get("points") or []:
                xs.append(p["x"])
                ys.append(p["y"])
        if xs:
            minx, maxx = min(xs), max(xs)
            miny, maxy = min(ys), max(ys)
            # 5% padding of the larger extent; the floor keeps the padding
            # positive when every vertex coincides.
            pad = 0.05 * max(maxx - minx, maxy - miny, 1e-6)
            ax.set_xlim(minx - pad, maxx + pad)
            ax.set_ylim(miny - pad, maxy + pad)
        ax.set_aspect("equal", adjustable="box")

        if show_axes:
            ax.set_xlabel("X (m)")
            ax.set_ylabel("Y (m)")
        if show_grid:
            ax.grid(True, which="both", linewidth=0.6, alpha=0.7)

        # dedup legend: keep the first handle seen for each label text
        handles, texts = ax.get_legend_handles_labels()
        first_handle_by_text = {}
        for handle, text in zip(handles, texts):
            first_handle_by_text.setdefault(text, handle)
        if len(first_handle_by_text) >= 2:
            ax.legend(
                list(first_handle_by_text.values()),
                list(first_handle_by_text.keys()),
                loc="upper right",
                fontsize=8,
            )

        # Operate on this figure explicitly rather than on pyplot's
        # implicit "current figure".
        fig.tight_layout(pad=0.2)
        fig.savefig(out_path, dpi=dpi, bbox_inches="tight", pad_inches=0.05)
    finally:
        plt.close(fig)
    return out_path


def process_single_file(
    file_path: str,
    file: str,
    room_type_arg: str,
    out_dir: str
) -> Dict:
    """Process a single JSON file and return QA pair data.

    Two entities (objects, windows, or doors) are drawn from the layout
    with an RNG seeded from the layout id. Obstructions between them are
    then computed on the *swapped-labels* version of the layout, while the
    centroids and the rendered image use the original layout.

    Args:
        file_path: Path of the layout JSON to load.
        file: File name; used to derive the layout id when the JSON has no
            "layout_id" (".json", "real_" and "room_" are removed).
        room_type_arg: Room type to report; when it is "unknown" the value
            is looked up in the JSON instead.
        out_dir: Directory that receives ``<layout_id>.png``.

    Returns:
        Dict with the obstruction answer labels, layout metadata, the two
        selected labels, their point counts and centroids, and the total
        number of entities.

    Raises:
        ValueError: If the layout has fewer than two entities.
        OSError: If either JSON file cannot be opened (e.g. the
            swapped-labels file for this layout does not exist).
    """
    with open(file_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    # layout_id
    layout_id = data.get("layout_id")
    if layout_id is None:
        layout_id = file.replace(".json", "").replace("real_", "").replace("room_", "")

    swapped_path = (
        f"data/hssd_data/json_simplified/"
        f"swapped_labels/room_{layout_id}.json"
    )
    with open(swapped_path, "r", encoding="utf-8") as f:
        swapped_labels_data = json.load(f)

    # room_type
    room_type = room_type_arg
    if room_type == "unknown":
        room_type = (
            data.get("room_type") or
            (data.get("room", {}) or {}).get("room_type") or
            "unknown"
        )

    # objects (include windows/doors as objects)
    objects = merge_objects_and_openings(data)
    if len(objects) < 2:
        raise ValueError(
            f"Need at least 2 objects (incl. windows/doors) in {file_path}"
        )

    # Deterministic RNG keyed on the layout id. The built-in hash() of a
    # str is salted per process (PYTHONHASHSEED), so string ids go through
    # the MD5-based helper instead; non-string ids keep their previous
    # seed, which was already reproducible.
    if isinstance(layout_id, str):
        seed = _stable_seed_from_layout(layout_id)
    else:
        seed = abs(hash(layout_id))
    rng = np.random.default_rng(seed=seed)
    obj1, obj2 = rng.choice(objects, 2, replace=False)

    # Obstructions are evaluated against the swapped-labels layout.
    on_the_way, (dx, dy) = find_objects_intersecting_line_poly(
        swapped_labels_data,
        start_label=_label(obj1),
        goal_label=_label(obj2),
        include_openings=True,
    )

    # centroids
    center1 = get_polygon_centroid(obj1["points"])
    center2 = get_polygon_centroid(obj2["points"])

    # render
    out_png = Path(out_dir) / f"{layout_id}.png"
    out_png.parent.mkdir(parents=True, exist_ok=True)

    render_direct_line(data, _label(obj1), _label(obj2), on_the_way, out_png)

    return {
        "answer": [_label(obj) for obj in on_the_way],
        "layout_id": layout_id,
        "room_type": room_type,
        "object_1": _label(obj1),
        "object_2": _label(obj2),
        "N_points_obj_1": len(obj1["points"]),
        "N_points_obj_2": len(obj2["points"]),
        "center_1": center1,
        "center_2": center2,
        "N_objects": len(objects),
    }


def main_obstruction(
    input_dir: str = "data/hssd_data/new_format",
    output_csv: str = "benchmark/{parent_folder_name}/{parent_folder_name}_qa_hssd_data.csv",
    output_img: str = "benchmark/{parent_folder_name}/{parent_folder_name}_qa_hssd_images/",
    enable_subsampling: bool = False,
    bedrooms_count: int = 80,
    living_rooms_count: int = 80,
    kitchens_count: int = 40,
):
    """Generate obstruction QA pairs.

    Args:
        input_dir: Directory handed to the QA-pair generator as its input.
        output_csv: CSV destination; ``{parent_folder_name}`` is replaced
            with the name of the directory containing this module.
        output_img: Image output directory; the same placeholder applies.
        enable_subsampling: When True, a per-room-type count configuration
            is passed to the generator; otherwise ``None`` is passed.
        bedrooms_count: Count used for the "bedrooms" entry.
        living_rooms_count: Count used for the "living_rooms" entry.
        kitchens_count: Count used for the "kitchens" entry.
    """
    parent_folder_name = os.path.basename(os.path.dirname(os.path.realpath(__file__)))
    output_csv = output_csv.format(parent_folder_name=parent_folder_name)
    output_img = output_img.format(parent_folder_name=parent_folder_name)

    # Create output directory if it doesn't exist. A bare file name has an
    # empty dirname, and os.makedirs("") raises, so skip it in that case.
    csv_dir = os.path.dirname(output_csv)
    if csv_dir:
        os.makedirs(csv_dir, exist_ok=True)

    # Configure subsampling
    subsample_config = None
    if enable_subsampling:
        subsample_config = {
            "bedrooms": bedrooms_count,
            "living_rooms": living_rooms_count,
            "kitchens": kitchens_count
        }
        print(f"Subsampling enabled: {subsample_config}")
    else:
        print("Processing all available files")

    qa_pairs = generate_qa_pairs_with_subsampling(
        input_dir=input_dir,
        process_single_file=partial(process_single_file, out_dir=output_img),
        subsample_config=subsample_config
    )

    save_and_info(qa_pairs, output_csv=output_csv)
