"""Repositioning module for calculating max slide distance of objects."""
from __future__ import annotations

import copy
import hashlib
import json
import math
import os
import random
from functools import partial
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.collections import LineCollection
from shapely.affinity import translate
from shapely.geometry import MultiPolygon, Polygon
from shapely.geometry import Polygon as ShpPolygon, Point
from shapely.ops import unary_union
from shapely.validation import make_valid

from src.qa_pairs_generation.utils import (
    _label,
    build_room_polygon,
    generate_qa_pairs_with_subsampling,
    get_polygon_centroid,
    is_light_label,
    is_rug_label,
    merge_objects_and_openings,
    save_and_info,
)


# Swap labels ablation


def _stable_seed_from_layout(layout_id: Union[int, str]) -> int:
    """Stable 32-bit seed from layout_id (consistent across runs)."""
    return int(hashlib.md5(str(layout_id).encode()).hexdigest(), 16) % (2**32)


def swap_object_labels(
    room: dict,
    strategy: str = "rotate",  # "rotate" | "reverse" | "shuffle"
    seed: Optional[int] = None,  # used only when strategy="shuffle"
) -> dict:
    """
    Return a copy of room with all object labels swapped.

    Args:
        room: Room data dictionary
        strategy: Swap strategy - "rotate", "reverse", or "shuffle"
        seed: Random seed for shuffle strategy

    Returns:
        Copy of room with swapped labels
    """
    out = copy.deepcopy(room)
    objs = out.get("objects")
    if not isinstance(objs, list) or len(objs) < 2:
        return out  # nothing to do

    labels = [obj.get("label", "") for obj in objs]
    new_labels = labels[:]

    if strategy == "rotate":
        new_labels = labels[1:] + labels[:1]
    elif strategy == "reverse":
        new_labels = list(reversed(labels))
    elif strategy == "shuffle":
        rng = random.Random(seed if seed is not None else 0)
        rng.shuffle(new_labels)
        # ensure we actually swapped at least one position; if not, rotate
        if new_labels == labels and len(labels) > 1:
            new_labels = labels[1:] + labels[:1]
    else:
        raise ValueError("strategy must be 'rotate', 'reverse', or 'shuffle'")

    for obj, new_label in zip(objs, new_labels):
        obj["label"] = new_label
    return out


obj_for_movement = {
    "kitchen": ["stove", "fridge", "sink", "dishwasher"],
    "living_room": [
        "sofa", "loveseat", "armchair", "coffee_table",
        "side_table", "tv_stand", "bookshelf", "plant"
    ],
    "bedroom": [
        "bed", "dresser", "wardrobe", "desk", "chair", "bookshelf",
        "ottoman", "closet_alcove", "floor_lamp", "plant"
    ],
}

GEOMETRY_TOLERANCE = 1e-4


def _poly(points: List[Dict[str, float]]) -> Optional[ShpPolygon]:
    """Convert point list to Shapely polygon."""
    if not points or len(points) < 3:
        return None
    coords = [(float(p["x"]), float(p["y"])) for p in points]
    if coords[0] != coords[-1]:
        coords.append(coords[0])
    poly = ShpPolygon(coords)
    poly = make_valid(poly)
    if poly.is_empty or not poly.is_valid or poly.area <= 0:
        return None
    return poly


def _dir_vec(direction: str) -> Tuple[float, float]:
    """Convert direction string to unit vector."""
    d = direction.lower()
    if d in ("up", "top"):
        return (0.0, 1.0)
    if d in ("down", "bottom"):
        return (0.0, -1.0)
    if d == "left":
        return (-1.0, 0.0)
    if d == "right":
        return (1.0, 0.0)
    raise ValueError(f"Invalid direction '{direction}'. Use top/bottom/left/right.")


def calculate_max_slide_distance(
    object_to_move: Polygon,
    room_boundary: Polygon,
    obstacles: List[Polygon],
    direction: Tuple[float, float],
    step_distance: float = 0.01  # Move by 1cm at a time
) -> float:
    """
    Calculate max slide distance by moving iteratively in small steps.

    This avoids the "tunneling" problem of binary search.
    """
    if not object_to_move or object_to_move.is_empty:
        return 0.0

    obstacle_union = unary_union(obstacles)

    # Initial position sanity check
    if not room_boundary.covers(object_to_move):
        return 0.0
    if isinstance(obstacle_union, (Polygon, MultiPolygon)):
        if object_to_move.overlaps(obstacle_union):
            return 0.0

    vx, vy = direction

    # Iterative stepping loop
    current_dist = 0.0
    # Set a reasonable limit to prevent infinite loops (e.g., room diagonal)
    min_x, min_y, max_x, max_y = room_boundary.bounds
    max_possible_dist = np.sqrt((max_x - min_x) ** 2 + (max_y - min_y) ** 2)

    num_steps = int(max_possible_dist / step_distance)

    for i in range(1, num_steps + 1):
        next_dist = i * step_distance
        moved_object = translate(
            object_to_move, xoff=vx * next_dist, yoff=vy * next_dist
        )

        # Check for collision at the new step
        # A) Is it outside the room?
        if not room_boundary.buffer(-GEOMETRY_TOLERANCE).covers(moved_object):
            break  # We've hit a wall, stop here.

        # B) Does it overlap an obstacle? (Allowing touch)
        intersection = moved_object.intersection(obstacle_union)
        if not intersection.is_empty and intersection.area > GEOMETRY_TOLERANCE:
            break  # We've hit an obstacle, stop here.

        # If we reach here, the step was valid. Update our distance.
        current_dist = next_dist

    return current_dist


def move_object_and_get_distance(
    data: Dict,
    object_to_move: Dict,
    direction: str,
) -> Tuple[Optional[List[Dict[str, float]]], float]:
    """
    Prepare data and use the robust calculation engine to find the max
    move distance and new object position.
    """
    # Prepare data
    room_poly = build_room_polygon(data)
    poly_to_move = _poly(object_to_move.get("points") or [])

    if room_poly is None or poly_to_move is None:
        return object_to_move.get("points"), 0.0

    obstacle_polys = []
    for o in data.get("objects", []):
        if o is object_to_move:
            continue
        lab = _label(o)
        if is_rug_label(lab) or is_light_label(lab):
            continue
        p = _poly(o.get("points") or [])
        if p:
            obstacle_polys.append(p)

    dir_vector = _dir_vec(direction)

    # Call the geometry engine
    distance = calculate_max_slide_distance(
        object_to_move=poly_to_move,
        room_boundary=room_poly,
        obstacles=obstacle_polys,
        direction=dir_vector
    )

    # Calculate final position and return
    vx, vy = dir_vector
    P_new = translate(poly_to_move, xoff=vx * distance, yoff=vy * distance)

    coords = list(P_new.exterior.coords)[:-1]
    after_pts = [{"x": float(x), "y": float(y)} for x, y in coords]

    return after_pts, float(distance)


def _segments_from_points(points: List[Dict[str, float]], close=True):
    """Convert points to line segments."""
    if not points or len(points) < 2:
        return []
    segs = []
    n = len(points)
    for i in range(n - 1):
        a, b = points[i], points[i + 1]
        segs.append([(a["x"], a["y"]), (b["x"], b["y"])])
    if close and n >= 3:
        a, b = points[-1], points[0]
        segs.append([(a["x"], a["y"]), (b["x"], b["y"])])
    return segs


def render_move_subplot(
    data: Dict,
    moved_label: str,
    points_before: List[Dict[str, float]],
    points_after: List[Dict[str, float]],
    distance_moved: float,
    direction: str,
    out_path: str,
    *,
    dpi: int = 150,
):
    """Render a before/after visualization of object movement."""
    Path(out_path).parent.mkdir(parents=True, exist_ok=True)

    COL_ROOM = "#444444"
    COL_WALL = "#000000"
    COL_OBJ = "#A0A0A0"
    COL_OBJ_EDGE = "#666666"
    COL_RUG = "#b2df8a"
    COL_LIGHT = "#ffd166"
    COL_MOVED_BEFORE = "#1f77b4"
    COL_MOVED_AFTER = "#e74c3c"
    COL_ARROW = "#e67e22"

    def _bounds():
        xs, ys = [], []
        rb = data.get("room_boundary") or []
        for p in rb:
            xs.append(p["x"])
            ys.append(p["y"])
        walls = data.get("walls") or (data.get("room", {}) or {}).get("walls") or []
        for w in walls:
            s, e = w.get("start"), w.get("end")
            if s and e:
                xs += [s["x"], e["x"]]
                ys += [s["y"], e["y"]]
        for o in data.get("objects") or []:
            for p in o.get("points") or []:
                xs.append(p["x"])
                ys.append(p["y"])
        if not xs:
            return (-1, -1, 1, 1)
        return (min(xs), min(ys), max(xs), max(ys))

    def _draw(ax, use_after=False):
        rb = data.get("room_boundary") or []
        if rb:
            arr = np.array([[p["x"], p["y"]] for p in rb], dtype=float)
            ax.plot(arr[:, 0], arr[:, 1], "--", color=COL_ROOM, linewidth=1.6, alpha=0.9)

        # walls
        wall_segs = []
        walls = data.get("walls") or (data.get("room", {}) or {}).get("walls") or []
        for w in walls:
            s, e = w.get("start"), w.get("end")
            if s and e:
                wall_segs.append([
                    (float(s["x"]), float(s["y"])),
                    (float(e["x"]), float(e["y"]))
                ])
        if wall_segs:
            ax.add_collection(LineCollection(wall_segs, linewidths=2.5, colors=[COL_WALL]))

        # objects
        for o in data.get("objects") or []:
            lab = (o.get("label") or o.get("name") or "").lower()
            pts = o.get("points") or []
            if not pts:
                continue
            arr = np.array([[p["x"], p["y"]] for p in pts], dtype=float)

            if lab == moved_label.lower():
                draw_pts = points_after if use_after else points_before
                arr2 = np.array([[p["x"], p["y"]] for p in draw_pts], dtype=float)
                color = COL_MOVED_AFTER if use_after else COL_MOVED_BEFORE
                ax.fill(
                    arr2[:, 0], arr2[:, 1], facecolor=color,
                    edgecolor="#222222", alpha=0.45, linewidth=1.5
                )
            else:
                if is_rug_label(lab):
                    ax.fill(
                        arr[:, 0], arr[:, 1], facecolor=COL_RUG,
                        edgecolor=COL_OBJ_EDGE, alpha=0.25
                    )
                elif is_light_label(lab):
                    ax.fill(
                        arr[:, 0], arr[:, 1], facecolor=COL_LIGHT,
                        edgecolor=COL_OBJ_EDGE, alpha=0.25
                    )
                else:
                    ax.fill(
                        arr[:, 0], arr[:, 1], facecolor=COL_OBJ,
                        edgecolor=COL_OBJ_EDGE, alpha=0.30
                    )

        # arrow
        if points_before and points_after:
            c0 = get_polygon_centroid(points_before)
            c1 = get_polygon_centroid(points_after)
            ax.annotate(
                "", xy=c1, xytext=c0,
                arrowprops=dict(arrowstyle="->", lw=2.0, color=COL_ARROW)
            )
            ax.plot([c0[0]], [c0[1]], "o", color="#2ecc71", ms=6, zorder=5)
            ax.plot([c1[0]], [c1[1]], "*", color="#d62728", ms=9, zorder=5)

        ax.set_aspect("equal", adjustable="box")
        minx, miny, maxx, maxy = _bounds()
        pad = 0.05 * max(maxx - minx, maxy - miny, 1e-6)
        ax.set_xlim(minx - pad, maxx + pad)
        ax.set_ylim(miny - pad, maxy + pad)
        ax.grid(True, which="both", linewidth=0.6, alpha=0.7)

    # subplot with before and after
    fig, axes = plt.subplots(1, 2, figsize=(12, 6), dpi=dpi)
    _draw(axes[0], use_after=False)
    axes[0].set_title("Before movement")
    _draw(axes[1], use_after=True)
    axes[1].set_title("After movement")

    fig.suptitle(
        f"Object '{moved_label}' moved by {distance_moved:.2f} meters {direction}",
        fontsize=14
    )
    plt.tight_layout(rect=[0, 0, 1, 0.95])
    plt.savefig(out_path, dpi=dpi, bbox_inches="tight", pad_inches=0.05)
    plt.close(fig)


def process_single_file(
    file_path: str,
    file: str,
    room_type: str,
    out_dir: str
) -> Optional[Dict]:
    """Process a single JSON file and return QA pair data."""
    with open(file_path, "r") as f:
        data = json.load(f)

    # layout_id
    layout_id = data.get("layout_id")
    if layout_id is None:
        layout_id = file.replace(".json", "").replace("real_", "").replace("room_", "")

    swapped_path = (
        f"data/hssd_data/json_simplified/"
        f"swapped_labels/room_{layout_id}.json"
    )
    with open(swapped_path, "r") as f:
        swapped_labels_data = json.load(f)

    # room_type
    if room_type == "unknown":
        room_type = (
            data.get("room_type") or
            (data.get("room", {}) or {}).get("room_type") or
            "unknown"
        )

    # Always filter objects based on the determined room_type.
    if room_type in obj_for_movement and room_type in file_path:
        def check_any_label(obj):
            lab = _label(obj).lower()
            for target in obj_for_movement[room_type]:
                if target in lab:
                    return True
            return False

        objects = [obj for obj in data["objects"] if check_any_label(obj)]
    else:
        # Fallback for unknown or unsupported room types: use all objects.
        objects = data["objects"]

    # pick eligible movable objects (no openings, no rugs/lights)
    movable = [
        o for o in objects
        if (not is_rug_label(_label(o))) and (not is_light_label(_label(o)))
    ]
    if not movable:
        print("NONE movable objects found in", objects, room_type, file_path)
        return None

    seed = int(hashlib.md5(str(layout_id).encode()).hexdigest(), 16) % (2**32)
    rng = np.random.default_rng(seed=seed)
    rng = np.random.default_rng(seed=abs(hash(layout_id)))
    mv = rng.choice(movable)
    mv_label = _label(mv)

    before_pts = mv["points"]

    tried_dirs = set()
    directions = ["top", "bottom", "left", "right"]
    direction = rng.choice(directions)

    swapped_labels_objects = swapped_labels_data["objects"]
    mv = next((obj for obj in swapped_labels_objects if _label(obj) == mv_label), mv)

    after_pts, dist = move_object_and_get_distance(swapped_labels_data, mv, direction)
    tried_dirs.add(direction)

    # Try up to 4 times with different directions if dist == 0
    attempts = 1
    while dist < 1e-2 and attempts < 4:
        remaining_dirs = [d for d in directions if d not in tried_dirs]
        if not remaining_dirs:
            break
        direction = rng.choice(remaining_dirs)
        after_pts, dist = move_object_and_get_distance(data, mv, direction)
        tried_dirs.add(direction)
        attempts += 1

    if after_pts is None:
        after_pts = before_pts
        dist = 0.0

    # render
    out_png = Path(out_dir) / f"{layout_id}.png"
    out_png.parent.mkdir(parents=True, exist_ok=True)

    render_move_subplot(data, mv_label, before_pts, after_pts, dist, direction, out_png)

    return {
        "layout_id": layout_id,
        "room_type": room_type,
        "object_to_move": mv_label,
        "direction": direction,
        "answer": round(float(dist), 3),
        "N_objects": len(objects),
    }


def main_repositioning(
    input_dir: str = "data/hssd_data/new_format",
    output_csv: str = "benchmark/{parent_folder_name}/{parent_folder_name}_qa_hssd_data.csv",
    output_img: str = "benchmark/{parent_folder_name}/{parent_folder_name}_qa_hssd_images/",
    enable_subsampling: bool = False,
    bedrooms_count: int = 80,
    living_rooms_count: int = 80,
    kitchens_count: int = 40,
):
    """Generate repositioning QA pairs."""
    parent_folder_name = os.path.basename(os.path.dirname(os.path.realpath(__file__)))
    output_csv = output_csv.format(parent_folder_name=parent_folder_name)
    output_img = output_img.format(parent_folder_name=parent_folder_name)

    # Create output directory if it doesn't exist
    os.makedirs(os.path.dirname(output_csv), exist_ok=True)

    # Configure subsampling
    subsample_config = None
    if enable_subsampling:
        subsample_config = {
            "bedrooms": bedrooms_count,
            "living_rooms": living_rooms_count,
            "kitchens": kitchens_count
        }
        print(f"Subsampling enabled: {subsample_config}")
    else:
        print("Processing all available files")

    qa_pairs = generate_qa_pairs_with_subsampling(
        input_dir=input_dir,
        process_single_file=partial(process_single_file, out_dir=output_img),
        subsample_config=subsample_config
    )

    save_and_info(qa_pairs, output_csv=output_csv)
