import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np
import crown
from kd_tree import *
import shapely
from shapely.ops import split
import trimesh
from skimage import measure
import time
from prettytable import PrettyTable
def to_numpy(x):
    """Detach a tensor from autograd, copy it to host memory and return it as a NumPy array."""
    return x.detach().cpu().numpy()


# Report CUDA availability at import time (script-style diagnostic).
print(torch.cuda.is_available())
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

class TimeVaryingCircleSDF(nn.Module):
    """Fixed-radius circle whose centre follows a Lissajous-like path over time.

    ``forward`` maps space-time points ``(x, y, t)`` to the Euclidean distance
    from ``(x, y)`` to the centre at time ``t`` minus ``radius``: negative
    inside the circle, zero on it, positive outside.

    NOTE(review): ``main`` passes this module to ``crown.CrownImplicitFunction``
    as ``crown_func``; presumably its exact op sequence is traced to compute
    bounds, so even algebraically equivalent rewrites of ``trajectory`` or
    ``forward`` may change the resulting bounds -- re-validate before editing.
    """

    def __init__(self, radius):
        super().__init__()
        # Stored as a plain attribute (not registered as an nn.Parameter).
        self.radius = radius

    def trajectory(self, t):
        """Swirling self-intersecting path: a Lissajous-like curve.

        Returns the circle centre ``(0.5*sin(2*pi*t), 0.5*cos(6*pi*t))`` stacked
        on a new last axis, so a 1-D ``t`` of length N yields an ``(N, 2)`` tensor.
        """
        x = 0.5 * torch.sin(1 * torch.pi * 2 * t)
        y = 0.5 * torch.cos(3 * torch.pi * 2 * t)
        return torch.stack([x, y], dim=-1)

    def forward(self, pts):
        """Signed distance to the moving circle.

        pts: (N, 3), where the last coordinate is time
        Returns an (N, 1) tensor: ``||pts[:, :2] - trajectory(t)|| - radius``.
        """
        spatial_pts = pts[:, :2]  # (N, 2)
        t = pts[:, 2]  # (N,)
        center = self.trajectory(t)
        # NOTE(review): t is 1-D here, so center is already (N, 2) and this branch
        # appears unreachable from forward; it would only trigger for a 0-d t.
        if center.dim() == 1:
            center = center.unsqueeze(0)
        dist = spatial_pts - center
        # keepdim=True keeps the output as a column: (N, 1).
        return torch.sqrt(torch.sum(dist ** 2, dim=1, keepdim=True)) - self.radius

class TimeVaryingPulsingCircleSDF(nn.Module):
    """Moving circle whose radius oscillates over time.

    ``forward`` maps space-time points ``(x, y, t)`` to
    ``||(x, y) - trajectory(t)|| - radius_function(t)`` (negative inside).

    Args:
        base_radius: Mean radius.
        amplitude: Amplitude of the radius oscillation. The radius stays
            positive only while ``amplitude < base_radius`` (true for the
            defaults, which give a radius in ``[0.05, 0.15]``).
        freq: Number of radius oscillations per unit of time.

    NOTE(review): if this is passed to CROWN like ``TimeVaryingCircleSDF`` is in
    ``main``, the op sequence in ``forward`` presumably determines the bounds;
    re-validate before restructuring it.
    """

    def __init__(self, base_radius=0.1, amplitude=0.05, freq=3.0):
        super().__init__()
        self.base_radius = base_radius
        self.amplitude = amplitude
        self.freq = freq

    def trajectory(self, t):
        """Centre position ``(0.5*sin(2*pi*t), 0.5*cos(3*pi*t))``.

        The two coordinates oscillate at different frequencies, so the path is
        a Lissajous-like curve rather than a circle. Output shape is
        ``t.shape + (2,)``.
        """
        x = 0.5 * torch.sin(2 * torch.pi * t)
        y = 0.5 * torch.cos(3 * torch.pi * t)
        return torch.stack([x, y], dim=-1)  # (N, 2)

    def radius_function(self, t):
        """Time-varying radius ``base_radius + amplitude * sin(2*pi*freq*t)`` (same shape as ``t``)."""
        return self.base_radius + self.amplitude * torch.sin(2 * torch.pi * self.freq * t)  # (N,)

    def forward(self, pts):
        """
        pts: (N, 3), where last dim is (x, y, t)
        Returns an (N, 1) tensor of signed distances to the pulsing circle.
        """
        spatial_pts = pts[:, :2]  # (N, 2)
        t = pts[:, 2]             # (N,)
        center = self.trajectory(t)  # (N, 2)
        radius = self.radius_function(t)  # (N,)

        dist_vec = spatial_pts - center
        dist = torch.sqrt(torch.sum(dist_vec ** 2, dim=1, keepdim=True))
        sdf = dist - radius.unsqueeze(1)  # (N, 1)
        return sdf  # (N, 1)

class TimeVaryingOvalSDF(nn.Module):
    """Rotating ellipse whose centre moves along a Lissajous-like path.

    ``forward`` returns ``sqrt((u/radius_x)**2 + (v/radius_y)**2) - 1``, where
    ``(u, v)`` are the point's coordinates in the ellipse's rotating local
    frame. The value is negative inside, zero on and positive outside the
    ellipse. It is a normalised implicit function rather than a Euclidean
    signed distance: for ``radius_x == radius_y == r`` it equals
    ``(||(u, v)|| - r) / r``.

    NOTE(review): if this is passed to CROWN like ``TimeVaryingCircleSDF`` is in
    ``main``, the op sequence in ``forward`` (including the einsum) presumably
    determines the bounds; re-validate before restructuring it.
    """

    def __init__(self, radius_x=0.3, radius_y=0.15):
        super().__init__()
        self.radius_x = radius_x
        self.radius_y = radius_y

    def trajectory(self, t):
        """Swirling self-intersecting path: a Lissajous-like curve.

        Returns the centre ``(0.5*sin(2*pi*t), 0.5*cos(3*pi*t))`` with shape
        ``t.shape + (2,)``.
        """
        x = 0.5 * torch.sin(2 * torch.pi * t)
        y = 0.5 * torch.cos(3 * torch.pi * t)
        return torch.stack([x, y], dim=-1)  # (N, 2)

    def rotation_matrix(self, t):
        """Time-varying 2D rotation matrix.

        Rotation by ``theta = 2*pi*t`` (one full turn per unit of time), laid
        out as ``[[cos, -sin], [sin, cos]]``. Shape is ``t.shape + (2, 2)``.
        """
        theta = 2 * torch.pi * t  # Rotation angle changes over time
        cos_t = torch.cos(theta)
        sin_t = torch.sin(theta)
        R = torch.stack([
            torch.stack([cos_t, -sin_t], dim=-1),
            torch.stack([sin_t,  cos_t], dim=-1)
        ], dim=-2)  # Shape: (N, 2, 2)
        return R

    def forward(self, pts):
        """
        pts: (N, 3), where the last coordinate is time
        Returns an (N, 1) tensor of normalised ellipse implicit values.
        """
        spatial_pts = pts[:, :2]  # (N, 2)
        t = pts[:, 2]             # (N,)
        center = self.trajectory(t)  # (N, 2)

        # Relative coordinates
        rel_pts = spatial_pts - center  # (N, 2)

        # Apply R^T (the inverse rotation, angle -theta) per point to express
        # rel_pts in the ellipse's rotating local frame.
        R = self.rotation_matrix(t)     # (N, 2, 2)
        rotated_pts = torch.einsum('nij,nj->ni', R.transpose(1, 2), rel_pts)  # (N, 2)

        # Normalised ellipse function with radii radius_x and radius_y
        # (zero on the ellipse, but not a Euclidean distance).
        x = rotated_pts[:, 0] / self.radius_x
        y = rotated_pts[:, 1] / self.radius_y
        sdf_val = torch.sqrt(x**2 + y**2) - 1.0
        return sdf_val.unsqueeze(-1)  # (N, 1)

def project_line_onto_square(a1, a2, b, x1_min, x1_max, x2_min, x2_max):
    """Clip the line ``a1*x1 + a2*x2 + b = 0`` to an axis-aligned rectangle.

    Args:
        a1, a2, b: Line coefficients. At least one of ``a1`` and ``a2`` must be
            non-zero.
        x1_min, x1_max, x2_min, x2_max: Extent of the rectangle.

    Returns:
        The shapely intersection of the line with the rectangle. This is
        usually a ``LineString`` chord, but it can be empty (the line misses
        the rectangle) or a ``Point`` (the line only touches a corner).

    Raises:
        ValueError: If ``a1 == a2 == 0``, since the equation then describes no line.
    """
    if a1 == 0 and a2 == 0:
        raise ValueError("a1 and a2 cannot both be zero: the equation does not describe a line")

    square = shapely.geometry.box(x1_min, x2_min, x1_max, x2_max)

    # Divide by the larger-magnitude coefficient. Dividing by a tiny a2 (a
    # near-vertical line) would push the endpoints to huge x2 values and lose
    # precision before clipping. Either parametrisation spans the rectangle's
    # full extent along its axis, so the clipped chord is the same.
    if abs(a2) >= abs(a1):
        # x2 as a function of x1, over the full x1 range.
        line = shapely.geometry.LineString([
            (x1_min, (-a1 * x1_min - b) / a2),
            (x1_max, (-a1 * x1_max - b) / a2),
        ])
    else:
        # x1 as a function of x2, over the full x2 range.
        line = shapely.geometry.LineString([
            ((-a2 * x2_min - b) / a1, x2_min),
            ((-a2 * x2_max - b) / a1, x2_max),
        ])

    return line.intersection(square)


def _swept_sdf_min(sdf, xx, yy, time_steps, time_chunk):
    """Return the per-grid-point minimum of ``sdf`` over ``time_steps`` uniform samples of t in [0, 1]."""
    H, W = xx.shape
    xy = torch.stack([xx.reshape(-1), yy.reshape(-1)], dim=1)  # (H*W, 2)
    times = torch.linspace(0, 1., time_steps)
    swept = None
    with torch.no_grad():
        # Evaluate a few time samples per call, so peak memory scales with
        # time_chunk * H * W instead of time_steps * H * W. The min is exact,
        # so chunking does not change the result.
        for t_chunk in torch.split(times, max(1, int(time_chunk))):
            T = t_chunk.shape[0]
            t = t_chunk.repeat_interleave(H * W)  # (T*H*W,), time-major like the xy repeat below
            pts = torch.cat([xy.repeat(T, 1), t.unsqueeze(1)], dim=1)  # (T*H*W, 3)
            chunk_min = sdf(pts).reshape(T, H, W).min(dim=0).values
            swept = chunk_min if swept is None else torch.minimum(swept, chunk_min)
    return swept


def _plot_polygon_outline(poly):
    """Draw the exterior and interior rings of a shapely Polygon in red."""
    x, y = poly.exterior.xy
    plt.plot(x, y, color='r')
    for hole in poly.interiors:
        hole_coords = np.array(hole.coords)
        plt.plot(hole_coords[:, 0], hole_coords[:, 1], color='r')


def visualize_swept_region(sdf, time_steps=100, grid_res=256, bounds=6.0, outer_mesh=None, time_chunk=64):
    """Plot the 2D region swept by a time-varying SDF for t in [0, 1].

    A ``grid_res x grid_res`` grid over ``[-1, 1]^2`` is evaluated at
    ``time_steps`` uniform times. A grid point is shaded when its minimum SDF
    value over those times is negative. ``outer_mesh`` (if given) is overlaid
    in red. ``sdf.trajectory`` is drawn with markers at ``t = i/16``.

    Args:
        sdf: Callable mapping ``(N, 3)`` points ``(x, y, t)`` to N values, and
            exposing ``trajectory(t)`` that returns ``(len(t), 2)`` positions.
        time_steps: Number of time samples in ``[0, 1]``; must be >= 1.
        grid_res: Grid resolution per spatial axis.
        bounds: Currently unused; the grid always spans ``[-1, 1]``. It is kept
            for signature compatibility.
        outer_mesh: Optional shapely ``Polygon`` or ``MultiPolygon``. ``None``
            and empty geometries are skipped.
        time_chunk: Number of time samples per SDF call. This only affects peak
            memory, not the plot.

    Raises:
        ValueError: If ``time_steps < 1``.
        NotImplementedError: If ``outer_mesh`` is a non-empty geometry other
            than a ``Polygon`` or ``MultiPolygon``.
    """
    if time_steps < 1:
        raise ValueError(f"time_steps must be >= 1, got {time_steps}")

    x = torch.linspace(-1., 1., grid_res)
    y = torch.linspace(-1., 1., grid_res)
    xx, yy = torch.meshgrid(x, y, indexing='ij')

    print("getting sdf vals")
    swept_sdf = _swept_sdf_min(sdf, xx, yy, time_steps, time_chunk).cpu().numpy()

    print("plotting sdf")
    plt.figure(figsize=(8, 8), facecolor="white")
    plt.contourf(xx.cpu(), yy.cpu(), swept_sdf < 0, levels=1)

    if outer_mesh is not None and not outer_mesh.is_empty:
        if outer_mesh.geom_type == 'MultiPolygon':
            polygons = list(outer_mesh.geoms)
        elif outer_mesh.geom_type == 'Polygon':
            polygons = [outer_mesh]
        else:
            raise NotImplementedError(f"Cannot plot outer_mesh of type {outer_mesh.geom_type}")
        for poly in polygons:
            _plot_polygon_outline(poly)

    # Trajectory line, with markers and labels at t = i/16.
    with torch.no_grad():
        t_traj_full = torch.linspace(0, 1, 200)  # dense samples for a smooth line
        traj_xy_full = sdf.trajectory(t_traj_full).cpu().numpy()
        plt.plot(traj_xy_full[:, 0], traj_xy_full[:, 1], 'r-', label='Trajectory')

        t_tick = torch.linspace(0, 1, 17)  # i/16 for i in 0..16
        traj_tick = sdf.trajectory(t_tick).cpu().numpy()
        for i, (x_tick, y_tick) in enumerate(traj_tick):
            plt.plot(x_tick, y_tick, 'ko')  # black tick mark
            plt.text(x_tick + 0.01, y_tick + 0.01, f"{i}/16", fontsize=8, color='black')

    plt.title("Swept Region of Time-Varying Circle SDF")
    plt.axis("equal")
    plt.show()

# def visualize_swept_region(sdf, time_steps=100, grid_res=256, bounds=6.0, outer_mesh=None):
#     x = torch.linspace(-1., 1., grid_res)
#     y = torch.linspace(-1., 1., grid_res)
#     xx, yy = torch.meshgrid(x, y, indexing='ij')
#     times = torch.linspace(0, 1., time_steps)
#
#     H, W = xx.shape
#     T = time_steps
#
#     x_flat = xx.reshape(-1)
#     y_flat = yy.reshape(-1)
#
#     xy = torch.stack([x_flat, y_flat], dim=1)
#     xy = xy.repeat(T, 1)
#
#     t = times.repeat_interleave(H * W)
#     pts = torch.cat([xy, t.unsqueeze(1)], dim=1)
#
#     sdf_vals = sdf(pts).reshape(time_steps, H, W)
#     swept_sdf = torch.min(sdf_vals, dim=0).values
#     swept_sdf = swept_sdf.reshape(grid_res, grid_res).cpu().numpy()
#
#     # High-contrast background (white)
#     # Set up figure with white background
#     fig, ax = plt.subplots(figsize=(8, 8), facecolor='white')
#     ax.set_facecolor('white')
#
#     # Fill swept region with semi-transparent color
#     plt.contourf(xx.cpu(), yy.cpu(), swept_sdf < 0, levels=1)
#                  # colors=["#66b3ff"], alpha=0.6)  # light blue, 60% opacity
#
#     # Draw outer mesh lines AFTER fill
#     if outer_mesh:
#         if outer_mesh.geom_type == 'MultiPolygon':
#             for poly in outer_mesh.geoms:
#                 x, y = poly.exterior.xy
#                 plt.plot(x, y, color="red", linewidth=2.5, zorder=3)
#         elif outer_mesh.geom_type == 'Polygon':
#             x, y = outer_mesh.exterior.xy
#             plt.plot(x, y, color="red", linewidth=2.5, zorder=3)
#
#     # Trajectory line in yellow for high contrast
#     with torch.no_grad():
#         t_traj_full = torch.linspace(0, 1, 200)
#         traj_xy_full = sdf.trajectory(t_traj_full).cpu().numpy()
#         plt.plot(traj_xy_full[:, 0], traj_xy_full[:, 1], color='yellow', linewidth=1.5, label='Trajectory')
#
#         t_tick = torch.linspace(0, 1, 17)
#         traj_tick = sdf.trajectory(t_tick).cpu().numpy()
#         for i in range(17):
#             x_tick, y_tick = traj_tick[i]
#             plt.plot(x_tick, y_tick, 'ko', markersize=5)
#             plt.text(x_tick + 0.01, y_tick + 0.01, f"{i}/16", fontsize=8, color='black')
#
#     plt.title("Swept Region of Time-Varying Circle SDF", color="black", fontsize=14)
#     plt.axis("equal")
#     plt.show()


def visualize_sdf_trimesh(sdf, time_steps=100, grid_res=128, bounds=6.0, threshold=0.0, show=True):
    """Extract the ``threshold`` level set of a space-time SDF as a triangle mesh.

    The SDF is sampled on a ``grid_res x grid_res x time_steps`` grid over
    ``[-bounds, bounds]^2 x [0, 1]``. The level set is extracted with marching
    cubes, and vertices are shifted so x/y span ``[-bounds, bounds]`` while t
    spans ``[0, 1]``.

    Args:
        sdf: Callable mapping ``(N, 3)`` points ``(x, y, t)`` to N values.
        time_steps: Number of time samples; must be >= 2.
        grid_res: Samples per spatial axis; must be >= 2.
        bounds: Half-width of the spatial domain.
        threshold: Iso-level passed to marching cubes.
        show: Whether to open the trimesh viewer (the previous behaviour).

    Returns:
        The ``trimesh.Trimesh`` built from the marching-cubes output.

    Raises:
        ValueError: If ``grid_res`` or ``time_steps`` is below 2, because the
            grid spacing would be undefined.
    """
    # Validate before the (potentially expensive) SDF evaluation.
    if grid_res < 2 or time_steps < 2:
        raise ValueError(
            f"grid_res and time_steps must both be >= 2, got grid_res={grid_res}, time_steps={time_steps}")

    # Prepare 3D grid
    x = torch.linspace(-bounds, bounds, grid_res)
    y = torch.linspace(-bounds, bounds, grid_res)
    t = torch.linspace(0, 1, time_steps)

    xx, yy, tt = torch.meshgrid(x, y, t, indexing='ij')  # shape: (X, Y, T)
    pts = torch.stack([xx, yy, tt], dim=-1).reshape(-1, 3)  # shape: (N, 3)

    # Evaluate the SDF without building an autograd graph, so .numpy() works
    # even if the SDF has trainable parameters.
    with torch.no_grad():
        sdf_vals = sdf(pts).reshape(grid_res, grid_res, time_steps).cpu().numpy()

    # Extract mesh using marching cubes
    verts, faces, normals, _ = measure.marching_cubes(sdf_vals, level=threshold, spacing=(
        (2 * bounds) / (grid_res - 1),  # x step
        (2 * bounds) / (grid_res - 1),  # y step
        1.0 / (time_steps - 1)          # time step
    ))

    # Vertices start at index 0, so shift x/y back to [-bounds, bounds];
    # time already starts at 0.
    verts[:, 0] -= bounds
    verts[:, 1] -= bounds

    mesh = trimesh.Trimesh(vertices=verts, faces=faces, vertex_normals=normals)
    if show:
        mesh.show()
    return mesh


def collapse_bounds(lowers, uppers, As, bs):
    """Eliminate time from linear lower bounds by minimising over each time interval.

    Each row encodes the bound ``A[:-1] . xy + A[-1] * t + b``: the last column
    of ``As`` multiplies time. Over ``t in [t0, t1]`` the minimum is attained
    at an endpoint, giving the time-free bound
    ``A[:-1] . xy + min(A[-1]*t0 + b, A[-1]*t1 + b)``.

    Args:
        lowers, uppers: Optional node box corners, with time as the last
            coordinate (shape ``(..., G, D)``, matching ``As``). When both are
            given, row i uses ``[lowers[..., i, -1], uppers[..., i, -1]]`` as its
            time interval. When either is ``None``, the G rows along dim ``-2``
            are assumed to be consecutive, equal-width slices of ``[0, 1]`` in
            order.
        As: Linear coefficients, shape ``(..., G, D)``.
        bs: Offsets, shape ``(..., G)``.

    Returns:
        ``(new_As, new_bs)`` with shapes ``(..., G, D - 1)`` and ``(..., G)``.
    """
    time_coeff = As[..., -1]
    if lowers is not None and uppers is not None:
        # Use each node's real time extent. An adaptive tree does not produce
        # G equal, time-ordered slices.
        t0 = lowers[..., -1].to(As.device)
        t1 = uppers[..., -1].to(As.device)
    else:
        num_slices = As.shape[-2]
        # Create on As's device so the product below never mixes CPU and CUDA.
        ticks = torch.linspace(0, 1, num_slices + 1, device=As.device)
        t0 = ticks[:-1]  # broadcasts against (..., G)
        t1 = ticks[1:]

    new_As = As[..., :-1]
    new_bs = torch.minimum(time_coeff * t0 + bs, time_coeff * t1 + bs)
    return new_As, new_bs

def _to_numpy_array(x):
    """Return ``x`` as a NumPy array; torch tensors are detached and copied to host first."""
    if isinstance(x, torch.Tensor):
        return x.detach().cpu().numpy()
    return np.asarray(x)


def _xy_rect(lo, hi):
    """Return the shapely rectangle spanned by the first two coordinates of ``lo`` and ``hi``."""
    return shapely.geometry.Polygon([lo[:2], (lo[0], hi[1]), hi[:2], (hi[0], lo[1])])


def _negative_part(square, lA, lb, bbox):
    """Return the largest piece of ``square`` where ``lA . (x, y) + lb < 0``, or ``None``.

    The square is cut along the bound's zero line, and each resulting piece is
    classified by the sign of the bound at its centroid.
    """
    if lA[0] == 0 and lA[1] == 0:
        # The bound is constant on the square: either all of it is negative or none is.
        return square if lb < 0. else None

    cut = project_line_onto_square(lA[0], lA[1], lb, bbox[0][0], bbox[1][0], bbox[0][1], bbox[1][1])
    if cut.is_empty or cut.geom_type not in ('LineString', 'MultiLineString'):
        # The zero line misses the domain box, or only touches it at a point.
        # Splitting is then unnecessary (empty cut) or unsupported by shapely
        # (Point splitter), so classify the square as a whole.
        pieces = [square]
    else:
        pieces = list(split(square, cut).geoms)

    negative = []
    for piece in pieces:
        if piece.geom_type != 'Polygon':
            print(piece.geom_type)
            continue
        c = piece.centroid
        if np.dot(lA, np.array([c.x, c.y])) + lb < 0.:
            negative.append(piece.buffer(0))
    return max(negative, key=lambda p: p.area) if negative else None


def carve_sv(lowers, uppers, lAs, lbs, neg_lowers, neg_uppers, bbox=((-1., -1., 0), (1., 1., 1.))):
    """Build a 2D approximation of the swept region as a single shapely geometry.

    The result is the union of:
      * the xy footprint of every negative node, and
      * for every unknown node, the largest part of its xy footprint on which
        the 2D linear lower bound ``lA . (x, y) + lb`` is negative.

    Args:
        lowers, uppers: Unknown-node box corners, ``(N, >=2)``; only x and y are used.
        lAs, lbs: Per-node 2D lower-bound coefficients ``(N, 2)`` and offsets
            ``(N,)``, e.g. as produced by ``collapse_bounds``.
        neg_lowers, neg_uppers: Negative-node box corners, ``(M, >=2)``.
        bbox: Domain box ``(min_corner, max_corner)``. Its x/y extent clips
            each bound's zero line.

        Array arguments may be torch tensors (on any device) or array-likes.

    Returns:
        The ``shapely.ops.unary_union`` of all collected polygons.
    """
    lowers, uppers = _to_numpy_array(lowers), _to_numpy_array(uppers)
    lAs, lbs = _to_numpy_array(lAs), _to_numpy_array(lbs)
    neg_lowers, neg_uppers = _to_numpy_array(neg_lowers), _to_numpy_array(neg_uppers)

    # Negative nodes contribute their whole xy footprint.
    convex_poly_list = [_xy_rect(n_l, n_u).buffer(0) for n_l, n_u in zip(neg_lowers, neg_uppers)]

    # Unknown nodes contribute only the part where the lower bound is negative.
    for l, u, lA, lb in zip(lowers, uppers, lAs, lbs):
        part = _negative_part(_xy_rect(l, u), lA, lb, bbox)
        if part is not None:
            convex_poly_list.append(part)

    return shapely.ops.unary_union(convex_poly_list)

def build_unk_cover(lowers, uppers):
    """Union the xy footprints of all boxes and return the merged region.

    When the union is disconnected (a MultiPolygon), only its largest
    component by area is returned. A diagnostic line naming the union's kind
    is printed either way.
    """
    lo_np = lowers.detach().cpu().numpy()
    hi_np = uppers.detach().cpu().numpy()
    footprints = [
        shapely.geometry.Polygon([lo[:2], (lo[0], hi[1]), hi[:2], (hi[0], lo[1])])
        for lo, hi in zip(lo_np, hi_np)
    ]
    cover = shapely.ops.unary_union(footprints)

    if cover.geom_type != 'MultiPolygon':
        print('polygon')
        return cover

    print('multipolygon')
    return max(cover.geoms, key=lambda component: component.area)

def _group_and_collapse(node_lower, node_upper, lAs, lbs):
    """Group nodes by identical xy footprint and collapse each group's lower bounds over time.

    Returns ``(collapsed_lAs, collapsed_lbs, reordered_lower, reordered_upper)``.
    Rows are ordered by footprint (``torch.unique`` order) and, within one
    footprint, by original node index, so all four outputs stay row-aligned.
    """
    if node_lower.shape[0] == 0:
        return lAs[:0, :-1], lbs[:0], node_lower[:0], node_upper[:0]

    xy_keys = torch.stack([
        node_lower[:, 0],
        node_lower[:, 1],
        node_upper[:, 0],
        node_upper[:, 1],
    ], dim=1)  # [N, 4]
    _, inverse = torch.unique(xy_keys, dim=0, return_inverse=True)

    # One stable sort plus a split groups every node in O(N log N). Scanning
    # `inverse == k` once per footprint k was O(N * num_footprints). The stable
    # sort keeps ascending original indices within each group.
    order = torch.sort(inverse, stable=True).indices
    counts = torch.bincount(inverse).tolist()

    collapsed_lAs, collapsed_lbs, reordered_lower, reordered_upper = [], [], [], []
    for indices in torch.split(order, counts):
        lower_g, upper_g = node_lower[indices], node_upper[indices]
        # Pass each node's own box: its time extent is the interval the bound
        # must be minimised over. An adaptive tree does not yield equal,
        # time-ordered slices of [0, 1].
        collapsed_lA, collapsed_lb = collapse_bounds(lower_g, upper_g, lAs[indices], lbs[indices])
        collapsed_lAs.append(collapsed_lA)
        collapsed_lbs.append(collapsed_lb)
        reordered_lower.append(lower_g)
        reordered_upper.append(upper_g)

    return (torch.vstack(collapsed_lAs), torch.hstack(collapsed_lbs),
            torch.vstack(reordered_lower), torch.vstack(reordered_upper))


def main():
    """Build a space-time CROWN tree for a moving circle and plot its swept region.

    Steps:
      1. Build the hybrid unknown tree over ``[-1, 1]^2 x [0, 1]``.
      2. Save node boxes and linear bounds to ``trees/2d_sv.npz``.
      3. Group unknown nodes by xy footprint and collapse their lower bounds
         over time.
      4. Carve a 2D approximation with ``carve_sv`` and overlay it on a dense
         sampling of the swept region.
    """
    import os

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    sdf = TimeVaryingCircleSDF(radius=0.1)
    # sdf = TimeVaryingPulsingCircleSDF()
    # sdf = TimeVaryingOvalSDF(radius_x=0.05, radius_y=0.15)
    print(device)
    # visualize_sdf_trimesh(sdf, time_steps=1000, grid_res=128, bounds=6.0)

    implicit_func = crown.CrownImplicitFunction(None, crown_func=sdf, crown_mode='crown', input_dim=3)
    start_time = time.perf_counter()

    node_lower, node_upper, lAs, lbs, uAs, ubs, pos_lower, pos_upper, neg_lower, neg_upper = construct_hybrid_unknown_tree(
        implicit_func, None, torch.tensor((-1, -1, 0)), torch.tensor((1, 1, 1)), base_depth=15, max_depth=24, delta=0.01,
        batch_size=2048, include_pos_neg=True)
    tree_time = time.perf_counter() - start_time

    print("neg lower shape: ", neg_lower.shape)
    print("node lower shape: ", node_lower.shape)
    print("lAs shape: ", lAs.shape)
    print(f"Time to build tree: {tree_time}")

    # No node filtering is applied at present: every unknown node is kept.
    node_lower_valid = node_lower
    node_upper_valid = node_upper
    num_valid = node_lower_valid.shape[0]

    first_stage_time = time.perf_counter() - start_time
    print("First pass time: ", first_stage_time)

    print(num_valid)
    # Save all bounds and node boxes to .npz to later compute the mesh of the
    # object. Convert explicitly, so CUDA or grad-tracking tensors do not break
    # np.savez.
    out_valid = {
        'lower': to_numpy(node_lower_valid),
        'upper': to_numpy(node_upper_valid),
        'mA': to_numpy(0.5 * lAs + 0.5 * uAs),
        'mb': to_numpy(0.5 * lbs + 0.5 * ubs),
        'lA': to_numpy(lAs),
        'lb': to_numpy(lbs),
        'uA': to_numpy(uAs),
        'ub': to_numpy(ubs),
        'pos_lower': to_numpy(pos_lower),
        'pos_upper': to_numpy(pos_upper),
        'neg_lower': to_numpy(neg_lower),
        'neg_upper': to_numpy(neg_upper),
        'plane_constraints_lower': np.array([]),
        'plane_constraints_upper': np.array([]),
    }
    os.makedirs("trees", exist_ok=True)
    np.savez("trees/2d_sv.npz", **out_valid)

    collapsed_lAs, collapsed_lbs, reordered_node_lower, reordered_node_upper = _group_and_collapse(
        node_lower_valid, node_upper_valid, lAs, lbs)

    print("collapsed_lAs shape: ", collapsed_lAs.shape)
    print("node lower shape: ", reordered_node_lower.shape)
    total_time = time.perf_counter() - start_time

    # Helper function to convert seconds to hours, minutes, seconds, milliseconds
    def format_time(seconds):
        hours = int(seconds // 3600)
        minutes = int((seconds % 3600) // 60)
        secs = int(seconds % 60)
        millis = int((seconds * 1000) % 1000)
        return f"{hours:02}:{minutes:02}:{secs:02}.{millis:03}"

    # Create a table and add rows for each runtime
    table = PrettyTable()
    table.field_names = ["Stage", "Run Time (hh:mm:ss.ms)"]
    table.add_row(["Tree Building Time", format_time(tree_time)])
    table.add_row(["CROWN Pass", format_time(first_stage_time)])
    table.add_row(["Total", format_time(total_time)])

    # Print the table
    print(table)

    # Alternative overlay: the raw union of unknown-node footprints.
    # unk_cover = build_unk_cover(node_lower_valid, node_upper_valid)
    # visualize_swept_region(sdf, time_steps=1000, grid_res=256, outer_mesh=unk_cover)
    outer_mesh = carve_sv(reordered_node_lower, reordered_node_upper, collapsed_lAs, collapsed_lbs, neg_lower, neg_upper)
    visualize_swept_region(sdf, time_steps=1000, grid_res=256, outer_mesh=outer_mesh)

if __name__ == '__main__':
    main()