import fire
import os
import json
from typing import List, Dict, Tuple, Optional
import numpy as np
import shapely
from shapely.geometry import Polygon as ShpPolygon, LineString, Point, box
from shapely.strtree import STRtree

from src.qa_pairs_generation.utils import get_polygon_centroid, generate_qa_pairs_with_subsampling, save_and_info


# Random seed passed to the QA-pair generator so sampling is reproducible.
SEED: int = 42

# Small tolerance intended for floating-point comparisons.
epsilon: float = 1e-9


# ---- helpers ----


def find_objects_intersecting_line(room_data, obj1, obj2, mode="any"):
    """Find objects whose footprint crosses the segment joining two objects' centroids.

    Every entry of ``room_data["objects"]`` other than ``obj1`` and ``obj2``
    (compared by identity, ``is``) is turned into a shapely polygon built from
    its ``"points"`` list of ``{"x": ..., "y": ...}`` dicts. The object is
    skipped when its points cannot form a polygon: fewer than three points, a
    closed ring with fewer than four coordinates, or an invalid or empty
    polygon.

    Args:
        room_data: Room dict; its ``"objects"`` list (default: empty) is scanned.
        obj1: Object whose centroid is the segment's start point.
        obj2: Object whose centroid is the segment's end point.
        mode: ``"any"`` keeps every polygon that intersects the segment.
            ``"proper"`` additionally drops polygons that fully contain the
            segment or only touch it. Any other value behaves like ``"any"``.

    Returns:
        ``(objects, (dx, dy))``, where:

        * ``objects`` lists the matching object dicts in
          ``room_data["objects"]`` order.
        * ``(dx, dy)`` is the vector from obj1's centroid to obj2's centroid.
    """

    def polygon_from_points(points):
        # Closed shapely polygon for `points`, or None if they cannot form a
        # valid, non-empty polygon.
        if not points or len(points) < 3:
            return None
        coords = [(p["x"], p["y"]) for p in points]
        if coords[0] != coords[-1]:
            coords.append(coords[0])
        # A closed ring needs >= 4 coordinates. An already-closed 3-point list
        # such as [a, b, a] would make the Polygon constructor raise.
        if len(coords) < 4:
            return None
        poly = ShpPolygon(coords)
        return poly if poly.is_valid and not poly.is_empty else None

    def centroid_xy(points):
        # Area centroid when a proper polygon exists, otherwise the plain mean
        # of the vertices (degenerate or invalid outlines).
        poly = polygon_from_points(points)
        if poly and poly.area > 0:
            c = poly.centroid
            return (float(c.x), float(c.y))
        xs = [p["x"] for p in points]
        ys = [p["y"] for p in points]
        return (sum(xs) / len(points), sum(ys) / len(points))

    def intersects_line(poly, seg):
        # "proper" ignores polygons that swallow the whole segment or merely
        # touch it; every other mode is a plain intersection test.
        if mode == "proper" and (poly.contains(seg) or poly.touches(seg)):
            return False
        return poly.intersects(seg)

    c1 = centroid_xy(obj1["points"])
    c2 = centroid_xy(obj2["points"])
    direction = (c2[0] - c1[0], c2[1] - c1[1])
    seg = LineString([c1, c2])

    # Build polygons and keep the originating object dicts index-aligned.
    polys, keep = [], []
    for o in room_data.get("objects", []):
        if o is obj1 or o is obj2:
            continue
        p = polygon_from_points(o.get("points", []))
        if p is not None:
            polys.append(p)
            keep.append(o)

    if not polys:
        return [], direction

    # Shapely 2.x: STRtree.query returns integer indices into `polys` for the
    # candidates whose bounding boxes meet the segment's bounding box. The
    # exact predicate is applied below. Sorting makes the output follow
    # room_data["objects"] order instead of tree-traversal order.
    tree = STRtree(polys)
    candidates = sorted(np.atleast_1d(tree.query(seg)).tolist())
    out = [keep[i] for i in candidates if intersects_line(polys[i], seg)]

    return out, direction


def process_single_file(file_path: str, file: str, room_type: str) -> Dict:
    """Build one obstruction QA record from a room-layout JSON file.

    Two distinct objects are drawn at random, using numpy's global RNG, from
    ``data["objects"]``. The answer lists the labels of the other objects that
    ``find_objects_intersecting_line(..., mode="proper")`` reports along the
    segment joining the two objects' centroids.

    Args:
        file_path: Path of the JSON file to read.
        file: File name. ``layout_id`` is derived from it by removing
            ``".json"``, ``"real_"`` and ``"room_"``.
        room_type: Room type label. When it is ``"unknown"``, the value is read
            from ``data["room"]["room_type"]``.

    Returns:
        Dict with:

        * the answer list (also stored under ``"intersecting_objects"``);
        * layout id and room type;
        * target and reference labels and their centroids (via
          ``get_polygon_centroid``);
        * the number of objects in the file.

    Raises:
        ValueError: If the file lists fewer than two objects.
    """
    with open(file_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    # Clean layout_id derived from the file name.
    layout_id = file.replace(".json", "").replace("real_", "").replace("room_", "")
    # Use the provided room_type, or fall back to the one stored in the file.
    if room_type == "unknown":
        room_type = data["room"]["room_type"]

    objects = data["objects"]
    if len(objects) < 2:
        raise ValueError(f"{file_path}: need at least two objects, found {len(objects)}")

    # Draw exactly one pair of distinct objects. Sampling indices follows the
    # same numpy code path as sampling the list itself. It also returns the
    # original dicts, which find_objects_intersecting_line excludes by identity.
    target_idx, reference_idx = np.random.choice(len(objects), 2, replace=False)
    target_object = objects[target_idx]
    reference_object = objects[reference_idx]

    objects_intersecting_line, _line_vector = find_objects_intersecting_line(
        data, target_object, reference_object, mode="proper"
    )
    objects_intersecting_line_names = [obj["label"] for obj in objects_intersecting_line]

    return {
        "answer": objects_intersecting_line_names,
        "layout_id": layout_id,
        "room_type": room_type,
        "reference_object": reference_object["label"],
        "target_object": target_object["label"],
        "center_reference": get_polygon_centroid(reference_object["points"]),
        "center_target": get_polygon_centroid(target_object["points"]),
        "intersecting_objects": objects_intersecting_line_names,
        "N_objects": len(objects),
    }


def main_obstruction(
    input_dir: str = "data/real_data",
    output_csv: str = "benchmark/{parent_folder_name}/{parent_folder_name}_qa_real_data.csv",
    enable_subsampling: bool = False,
    bedrooms_count: int = 80,
    living_rooms_count: int = 80,
    kitchens_count: int = 40,
):
    """Generate obstruction QA pairs from ``input_dir`` and save them to a CSV.

    Args:
        input_dir: Directory of room-layout files, handed to
            ``generate_qa_pairs_with_subsampling``.
        output_csv: Output path. ``{parent_folder_name}`` is replaced with the
            name of the directory containing this script.
        enable_subsampling: When True, a per-room-type ``subsample_config`` is
            passed to the generator. Otherwise ``None`` is passed.
        bedrooms_count: Value for the ``"bedrooms"`` subsampling key.
        living_rooms_count: Value for the ``"living_rooms"`` subsampling key.
        kitchens_count: Value for the ``"kitchens"`` subsampling key.
    """
    parent_folder_name = os.path.basename(os.path.dirname(os.path.realpath(__file__)))
    output_csv = output_csv.format(parent_folder_name=parent_folder_name)

    # Create the output directory if needed. A bare file name has no
    # directory part, and os.makedirs("") would raise FileNotFoundError.
    output_dir = os.path.dirname(output_csv)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # Configure subsampling.
    subsample_config = None
    if enable_subsampling:
        subsample_config = {"bedrooms": bedrooms_count, "living_rooms": living_rooms_count, "kitchens": kitchens_count}
        print(f"Subsampling enabled: {subsample_config}")
    else:
        print("Processing all available files")

    qa_pairs = generate_qa_pairs_with_subsampling(
        input_dir=input_dir,
        process_single_file=process_single_file,
        SEED=SEED,
        subsample_config=subsample_config,
    )
    save_and_info(qa_pairs, output_csv=output_csv)
