import torch
import torch.nn as nn
import crown
import kd_tree
import shapely
import matplotlib
import numpy as np
from shapely.ops import split, unary_union
import matplotlib.pyplot as plt
from descartes import PolygonPatch
# torch.set_default_device(torch.device('cuda:0'))
from auto_LiRPA import BoundedModule
from shapely import wkt
import matplotlib as mpl

# Global matplotlib styling: serif (Times-like) text for every figure this
# module produces, with a custom mathtext set whose roman face matches.
_SERIF_RC_SETTINGS = {
    'font.family': 'serif',
    'font.serif': ['Times New Roman', 'Times', 'Nimbus Roman', 'Liberation Serif'],
    'mathtext.fontset': 'custom',
    'mathtext.rm': 'Times New Roman',
    'font.weight': 'regular',
    'axes.labelweight': 'regular',
}
mpl.rcParams.update(_SERIF_RC_SETTINGS)


def sdf_circle(pts: torch.Tensor, center: torch.Tensor, radius: float) -> torch.Tensor:
    """
    Evaluate the signed distance from each point to a circle.

    The result is negative inside the circle, zero on it and positive outside.

    pts: (N, 2) query points
    center: (2,) circle center
    radius: circle radius
    Returns: (N, 1) signed distances
    """
    offsets = pts - center
    dist_to_center = offsets.norm(dim=1, keepdim=True)
    return dist_to_center - radius

def sdf_square(pts: torch.Tensor, center: torch.Tensor, dims: float) -> torch.Tensor:
    """
    Evaluate the signed distance from each point to an axis-aligned square.

    pts: (N, 2) query points
    center: (2,) square center
    dims: side length of the square
    Returns: (N, 1) signed distances (negative inside, positive outside)
    """
    half_side = dims / 2.0
    # Per-axis offset from the boundary: a component is positive when the
    # point lies beyond the square along that axis.
    q = (pts - center).abs() - half_side
    # Outside: Euclidean length of the positive part of q.
    exterior = q.clamp(min=0).norm(dim=1, keepdim=True)
    # Inside: the (non-positive) offset of the nearest side.
    interior = q.max(dim=1, keepdim=True).values.clamp(max=0)
    return exterior + interior

class CircleSDF(nn.Module):
    """
    Signed distance function of a circle, packaged as an ``nn.Module``.

    center: circle center; any array-like or tensor that broadcasts against
        (N, 2) points, e.g. ``[cx, cy]`` or ``[[cx, cy]]``
    radius: circle radius
    device: device on which the ``center`` buffer is created
    """

    def __init__(self, center, radius, device='cpu'):
        super().__init__()
        # as_tensor accepts lists, arrays and tensors alike without the
        # "copy construct from a tensor" warning that torch.tensor(tensor)
        # emits. It may alias the caller's tensor, so clone to keep the
        # buffer independent of later in-place edits by the caller.
        center_t = torch.as_tensor(center, dtype=torch.float32, device=device)
        self.register_buffer('center', center_t.detach().clone())
        self.radius = radius

    def forward(self, pts):
        """Return the (N, 1) signed distances of the (N, 2) points ``pts``."""
        # Spelled out as sqrt(sum(square)) rather than a norm call.
        # NOTE(review): presumably so the bound-propagation tooling used with
        # this module can trace it -- confirm before simplifying.
        dist = pts - self.center
        return torch.sqrt(torch.sum(dist ** 2, dim=1, keepdim=True)) - self.radius

class SquareSDF(nn.Module):
    """
    Signed distance function of an axis-aligned square as an ``nn.Module``.

    center: square center; a scalar applies to both axes, a 2-vector gives a
        per-axis center (default -0.5, the previously hard-coded value)
    dims: side length (default 2, the previously hard-coded value)
    """

    def __init__(self, center=-0.5, dims=2.):
        super().__init__()
        dtype = torch.get_default_dtype()
        # Clone so the buffers never alias a tensor owned by the caller.
        self.register_buffer('center', torch.as_tensor(center, dtype=dtype).detach().clone())
        self.register_buffer('dims', torch.as_tensor(dims, dtype=dtype).detach().clone())

    def forward(self, pts):
        """Return the (N, 1) signed distances of the (N, 2) points ``pts``."""
        # Per-axis offset from the boundary (positive = beyond the square).
        q = torch.abs(pts - self.center) - self.dims / 2.
        q0 = q[:, 0:1]
        q1 = q[:, 1:2]
        # Outside distance: length of the positive part of q.
        pow_outside_dist = torch.relu(q0) ** 2 + torch.relu(q1) ** 2
        outside_dist = torch.sqrt(pow_outside_dist)
        # max(q0, q1) written as q0 + relu(q1 - q0), and min(., 0) written as
        # -relu(-.), so the whole forward pass uses only relu/abs/sqrt.
        # NOTE(review): presumably to stay traceable by the bound-propagation
        # tooling used with this module -- confirm before simplifying.
        q2 = torch.relu(q1 - q0)
        q3 = q0 + q2
        inside_dist = -torch.relu(-q3)
        return outside_dist + inside_dist

def project_line_onto_square(a1, a2, b, x1_min, x1_max, x2_min, x2_max):
    """
    Clip the line ``a1*x1 + a2*x2 + b = 0`` to an axis-aligned box.

    The line is sampled at the box's two x1 extremes (or, when ``a2`` is
    zero, drawn vertically between the two x2 extremes) and intersected with
    the box ``[x1_min, x1_max] x [x2_min, x2_max]``.

    Returns the shapely geometry of that intersection. ``a1`` and ``a2`` must
    not both be zero: the vertical case divides by ``a1``.
    """
    box = shapely.geometry.box(x1_min, x2_min, x1_max, x2_max)

    if a2 == 0:
        # Vertical line: x1 is constant.
        x1_const = -b / a1
        endpoints = [(x1_const, x2_min), (x1_const, x2_max)]
    else:
        # Solve for x2 at the left and right edges of the box.
        endpoints = [(x1, (-a1 * x1 - b) / a2) for x1 in (x1_min, x1_max)]

    return shapely.geometry.LineString(endpoints).intersection(box)

def carve(net, deep=True, inner=False, bbox=((-0.5, -0.5), (0.5, 0.5))):
    """
    Build a polygon around the negative region of ``net`` inside ``bbox``.

    A k-d tree over the bounding box is built with
    ``kd_tree.construct_hybrid_unknown_tree``. Every node it reports as
    negative is kept whole. Every "unknown" node is cut by a line taken from
    its affine coefficients, and only the pieces whose centroid evaluates
    negative under those coefficients are kept.
    NOTE(review): lA/lb and uA/ub look like linear lower/upper bounds of the
    network over each node -- confirm against kd_tree/crown.

    net: the module handed to crown/kd_tree
    deep: tree depths 15/18 when True, 7/8 when False
    inner: cut with the ``uA, ub`` coefficients when True, otherwise with
        the ``lA, lb`` coefficients
    bbox: ((x1_min, x2_min), (x1_max, x2_max))

    Returns the union of all kept pieces; if that union is a MultiPolygon,
    only its largest-area component is returned.
    """
    print("Carving")
    lower = torch.tensor(bbox[0])
    upper = torch.tensor(bbox[1])
    func = crown.CrownImplicitFunction(None, crown_func=net, crown_mode='crown', input_dim=2)

    base_depth, max_depth = (15, 18) if deep else (7, 8)
    tree_outputs = kd_tree.construct_hybrid_unknown_tree(
        func, net, lower, upper, base_depth=base_depth, max_depth=max_depth,
        node_dim=2, include_pos_neg=True)
    (lowers, uppers, lAs, lbs, uAs, ubs,
     pos_lowers, pos_uppers, neg_lowers, neg_uppers) = (
        t.detach().cpu().numpy() for t in tree_outputs)

    # Nodes reported as negative are kept as whole boxes.
    kept = [
        shapely.geometry.Polygon([lo, (lo[0], hi[1]), hi, (hi[0], lo[1])]).buffer(0)
        for lo, hi in zip(neg_lowers, neg_uppers)
    ]

    # Unknown nodes: cut each one and keep only the negative side.
    for lo, hi, lA, lb, uA, ub in zip(lowers, uppers, lAs, lbs, uAs, ubs):
        cell = shapely.geometry.Polygon([lo, (lo[0], hi[1]), hi, (hi[0], lo[1])])
        coeffs, offset = (uA, ub) if inner else (lA, lb)
        if coeffs[0] == 0 and coeffs[1] == 0:
            # Degenerate coefficients define no line to cut with.
            continue
        cut = project_line_onto_square(coeffs[0], coeffs[1], offset,
                                       bbox[0][0], bbox[1][0], bbox[0][1], bbox[1][1])

        for piece in split(cell, cut).geoms:
            if piece.geom_type != 'Polygon':
                print(piece.geom_type)
                continue
            centroid = shapely.centroid(piece)
            centroid_xy = np.array([centroid.x, centroid.y])
            if np.dot(coeffs, centroid_xy) + offset < 0.:
                kept.append(piece.buffer(0))

    merged = shapely.ops.unary_union(kept)

    if merged.geom_type != 'MultiPolygon':
        print('polygon')
        return merged

    # Several disconnected pieces: keep only the largest one.
    print('multipolygon')
    return max(merged.geoms, key=lambda poly: poly.area)


import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np


def make_sdf_cm():
    """
    Build the default SDF colormap.

    The color stops run white -> blue up to just past the midpoint and
    cream -> orange after it. Two stops share the same position
    (0.5 + 1e-5), so the color changes abruptly there. The stops are turned
    into segment data and wrapped in an ``SDF_cmap`` with a 0.95 darkening
    factor.
    """
    jump = 0.5 + 1e-5
    color_stops = [
        (0, '#ffffff'),
        (jump, '#489acc'),
        (jump, '#fffae3'),
        (1, '#ff7424'),
    ]
    base = mpl.colors.LinearSegmentedColormap.from_list('SDF', color_stops, N=256)
    return SDF_cmap(0.95, 'sdf_cmap', base._segmentdata)


"""
This class overrides the Colormap class and does some preprocessing before the main call to the
colormap to add alternating darkening to the isolines. It is only designed to work with data plotted
using mpl's contour/contourf methods (e.g. through plot_sdf below).
"""


class SDF_cmap(mpl.colors.LinearSegmentedColormap):
    """
    Segmented colormap that darkens every other band of its input.

    ``set_min_max`` must be called before the colormap is used, because
    ``__call__`` reads ``minv`` and ``maxv``.

    alph: factor applied to the RGB channels of the darkened bands; the
        remaining constructor arguments go to ``LinearSegmentedColormap``
    """

    def __init__(self, alph, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.alph = alph

    def set_contours(self, cont):
        """Store ``cont`` on the instance (nothing in this class reads it)."""
        self.cont = cont

    def set_min_max(self, minv, maxv):
        """Record the data range used by ``__call__`` to remap its input."""
        self.minv, self.maxv = minv, maxv

    def __call__(self, X, **kwargs):
        """
        Map the normalized values ``X`` to RGBA rows.

        ``X`` must be an array that is at least 1-D after ``squeeze()``; its
        length is taken as the number of bands.
        """
        X = X.squeeze()
        n_bands = X.shape[0]

        # With v = minv + X * (maxv - minv), this equals 0.5 * (v / maxv + 1),
        # so the value v = 0 always lands on the colormap midpoint.
        ratio = self.minv / self.maxv
        remapped = 0.5 * ((-X + 1) * ratio + X + 1)
        rgba = super().__call__(remapped, **kwargs)

        # Index of the band each entry falls in, clamped to the last band.
        band_idx = np.minimum(np.floor(X * n_bands), n_bands - 1)
        odd_band = band_idx % 2 == 1
        odd_mask = np.repeat(odd_band[:, None], 4, axis=1).squeeze()

        # Scale odd bands by alph, then reset the alpha channel to 1.
        scale = np.ones(rgba.shape).squeeze()
        scale[odd_mask] = self.alph
        scale[:, 3] = 1
        return scale * rgba

def plot_sdf(sdf, X, Y, **kwargs):
    """
    Evaluate ``sdf`` on the grid ``(X, Y)`` and draw it as filled contours.

    sdf: callable taking an (N, 2) array of points and returning N values
    X, Y: coordinate grids of equal shape (e.g. from ``make_grid``)
    kwargs: forwarded to ``plot_sdf_from_vals``, whose result is returned
    """
    xs = X.flatten()
    ys = Y.flatten()
    samples = np.stack([xs, ys], axis=1)
    values = sdf(samples)
    return plot_sdf_from_vals(X, Y, values, **kwargs)


def plot_sdf_from_vals(X, Y, Z, ax=None, colorbar=True, cmap=None, N=25, alpha=1, levels=None, return_levels=False):
    """
    Draw precomputed SDF values as filled contours.

    X, Y: coordinate grids of equal shape
    Z: values on the grid; reshaped to ``X.shape``
    ax: axes to draw into; a new figure is created when None
    colorbar: attach a colorbar to the figure that owns ``ax``
    cmap: colormap; defaults to ``make_sdf_cm()``
    N: passed to ``MaxNLocator`` to choose levels when ``levels`` is None
    alpha: opacity passed to ``contourf``
    levels: explicit contour levels (a sequence)
    return_levels: when True, return ``(contour_set, levels)``

    Returns the contour set, or ``(contour_set, levels)``.
    """
    if cmap is None:
        cmap = make_sdf_cm()

    Z = np.reshape(Z, X.shape)

    if ax is None:
        fig = plt.figure()
        ax = fig.gca()
    else:
        # Use the figure that owns `ax`. pyplot's current figure may be a
        # different one, which would put the colorbar on the wrong figure.
        fig = ax.figure

    if levels is None:
        levels = mpl.ticker.MaxNLocator(N).tick_values(np.min(Z), np.max(Z))

    # Only colormaps that expose set_min_max (i.e. SDF_cmap) need the level
    # range. Checking for the method explicitly, instead of a bare
    # try/except, keeps genuine errors from being silently swallowed.
    set_min_max = getattr(cmap, 'set_min_max', None)
    if callable(set_min_max):
        set_min_max(levels[0], levels[-1])

    cp = ax.contourf(X, Y, Z, levels, cmap=cmap, alpha=alpha)

    if colorbar:
        fig.colorbar(cp, ax=ax)

    ax.set_aspect('equal', 'box')

    if return_levels:
        return cp, levels
    return cp


def make_grid(*dims, res=100):
    """
    Build coordinate grids covering the given ranges.

    dims: one ``(start, stop)`` pair per axis
    res: number of samples along every axis

    Returns what ``np.meshgrid`` returns for the per-axis sample vectors:
    one coordinate array per axis.
    """
    axes = tuple(np.linspace(*bounds, res) for bounds in dims)
    return np.meshgrid(*axes)


import time

import warp as wp
import numpy as np


# Warp kernel: unsigned distance from each point to the nearest edge of a
# polygon. Each thread handles the point at its thread index.
#   points:        query points
#   polygon:       polygon vertices; edge i joins vertex i to vertex
#                  (i + 1) % num_edges, so the outline is treated as closed
#   num_edges:     number of edges to visit
#   out_distances: output, one distance per point
@wp.kernel
def point_to_polygon_distance_kernel(
        points: wp.array(dtype=wp.vec2),
        polygon: wp.array(dtype=wp.vec2),
        num_edges: int,
        out_distances: wp.array(dtype=wp.float32)
):
    tid = wp.tid()

    p = points[tid]
    # Large sentinel so the first edge always becomes the current minimum.
    # NOTE(review): the float(...) wrapper presumably declares a mutable Warp
    # local -- confirm against the Warp documentation.
    min_dist = float(1e30)

    for i in range(num_edges):
        a = polygon[i]
        b = polygon[(i + 1) % num_edges]

        ab = b - a
        ap = p - a
        ab_len2 = wp.dot(ab, ab)

        # Projection factor t
        # (the 1e-8 keeps the division finite for a zero-length edge)
        t = wp.dot(ap, ab) / (ab_len2 + 1e-8)
        # Clamp to [0, 1] so the projection stays on the segment.
        t = wp.clamp(t, 0.0, 1.0)

        # Closest point on segment
        proj = a + t * ab
        d = wp.length(p - proj)

        if d < min_dist:
            min_dist = d

    out_distances[tid] = min_dist

def compute_point_to_polygon_distance_warp(points_np: np.ndarray, polygon_np: np.ndarray) -> np.ndarray:
    """
    Compute the unsigned distance from each point to a polygon outline.

    The kernel is launched twice: once untimed as a warm-up, then once inside
    a ``wp.ScopedTimer`` labelled "mesh SD".

    points_np: (N, 2) query points
    polygon_np: (M, 2) polygon vertices
    Returns: (N,) float32 array of distances. Only the distances are
        returned, so the annotation is ``np.ndarray`` rather than a tuple.
    """
    assert points_np.shape[1] == 2
    assert polygon_np.shape[1] == 2

    device = wp.get_preferred_device()

    points = wp.array(points_np.astype(np.float32), dtype=wp.vec2, device=device)
    polygon = wp.array(polygon_np.astype(np.float32), dtype=wp.vec2, device=device)
    distances = wp.empty(len(points_np), dtype=wp.float32, device=device)

    def launch_and_wait():
        # One kernel launch followed by a synchronize.
        wp.launch(
            kernel=point_to_polygon_distance_kernel,
            dim=len(points_np),
            inputs=[points, polygon, polygon.shape[0], distances],
            device=device
        )
        wp.synchronize()

    # Warmup kernel (for accurate timing)
    launch_and_wait()

    # Timed kernel
    with wp.ScopedTimer("mesh SD"):
        launch_and_wait()

    return distances.numpy()


import warp as wp

# Warp kernel: signed distance and closest boundary point from each point to
# a polygon. Each thread handles the point at its thread index.
#   points:               query points
#   polygon:              polygon vertices; edge i joins vertex i to vertex
#                         (i + 1) % num_edges, so the outline is closed
#   num_edges:            number of edges to visit
#   out_signed_distances: output; distance to the nearest edge, negated when
#                         the point tests as inside
#   out_closest_points:   output; nearest point on the polygon outline
@wp.kernel
def signed_point_to_polygon_distance_kernel(
    points: wp.array(dtype=wp.vec2),
    polygon: wp.array(dtype=wp.vec2),
    num_edges: int,
    out_signed_distances: wp.array(dtype=wp.float32),
    out_closest_points: wp.array(dtype=wp.vec2),
):
    tid = wp.tid()
    p = points[tid]
    px = p[0]
    py = p[1]

    # Large sentinel so the first edge always becomes the current minimum.
    min_dist = float(1e30)
    closest_point = wp.vec2(0.0, 0.0)
    # Parity flag for the crossing test below; toggled per edge crossing.
    inside = bool(False)

    for i in range(num_edges):
        a = polygon[i]
        b = polygon[(i + 1) % num_edges]

        # Project p onto segment ab, clamped to the segment's extent.
        # (the 1e-8 keeps the division finite for a zero-length edge)
        ab = b - a
        ap = p - a
        ab_len2 = wp.dot(ab, ab)
        t = wp.dot(ap, ab) / (ab_len2 + 1e-8)
        t = wp.clamp(t, 0.0, 1.0)
        proj = a + t * ab
        d = wp.length(p - proj)

        if d < min_dist:
            min_dist = d
            closest_point = proj

        # Ray casting for point-in-polygon test
        ay = a[1]
        by = b[1]
        ax = a[0]
        bx = b[0]

        # True when the edge straddles the horizontal line y = py; one end
        # is compared with <= and the other with >, so an edge with both
        # ends at the same height never counts.
        cond1 = (ay <= py and by > py) or (by <= py and ay > py)
        if cond1:
            # x coordinate where the edge meets y = py.
            slope = (bx - ax) / (by - ay + 1e-8)
            x_intersect = ax + slope * (py - ay)
            # Count only crossings to the right of the point (+x direction).
            if px < x_intersect:
                inside = not inside

    # An odd number of crossings marks the point as inside: negative sign.
    if inside:
        signed_dist = -min_dist
    else:
        signed_dist = min_dist

    out_signed_distances[tid] = signed_dist
    out_closest_points[tid] = closest_point


def compute_signed_distance_warp(points_np: np.ndarray, polygon_np: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """
    Compute signed distances and closest boundary points for a polygon.

    Runs ``signed_point_to_polygon_distance_kernel`` on Warp's preferred
    device with one thread per point.

    points_np: (N, 2) query points
    polygon_np: (M, 2) polygon vertices
    Returns: ``(signed_distances, closest_points)`` as NumPy arrays with one
        entry per query point.
    """
    assert points_np.shape[1] == 2
    assert polygon_np.shape[1] == 2

    target = wp.get_preferred_device()
    n_points = len(points_np)

    def to_vec2(arr):
        # Upload a float32 copy of `arr` as a Warp array of vec2.
        return wp.array(arr.astype(np.float32), dtype=wp.vec2, device=target)

    pts_wp = to_vec2(points_np)
    poly_wp = to_vec2(polygon_np)

    sd_out = wp.empty(n_points, dtype=wp.float32, device=target)
    cp_out = wp.empty(n_points, dtype=wp.vec2, device=target)

    wp.launch(
        kernel=signed_point_to_polygon_distance_kernel,
        dim=n_points,
        inputs=[pts_wp, poly_wp, poly_wp.shape[0], sd_out, cp_out],
        device=target,
    )
    wp.synchronize()

    return sd_out.numpy(), cp_out.numpy()

def csg_sdf(x, outer, inner, sdf1, sdf2):
    """
    Evaluate ``min(sdf1, sdf2)`` using an inner and an outer polygon.

    Each query point is handled according to where it lies:
      * inside ``inner``: the SDFs are evaluated at the closest point on the
        inner polygon and the (negative) signed distance to it is added;
      * outside ``outer``: the SDFs are evaluated at the closest point on the
        outer polygon and the (positive) signed distance to it is added;
      * otherwise: the SDFs are evaluated at the point itself.

    x: (N, 2) NumPy array of query points
    outer, inner: (M, 2) vertex arrays of the outer / inner polygon
    sdf1, sdf2: callables taking an (N, 2) torch tensor and returning a
        tensor with N values
    Returns: (N,) NumPy array of values.
    """
    dist_to_outer, closest_pt_to_outer = compute_signed_distance_warp(x, outer)
    dist_to_inner, closest_pt_to_inner = compute_signed_distance_warp(x, inner)

    mask_in = dist_to_inner < 0.
    mask_out = dist_to_outer > 0.
    mask_between = (dist_to_inner >= 0) & (dist_to_outer <= 0)

    inputs = np.empty_like(x)
    inputs[mask_in] = closest_pt_to_inner[mask_in]
    inputs[mask_out] = closest_pt_to_outer[mask_out]
    inputs[mask_between] = x[mask_between]

    # Use CUDA when present and fall back to CPU otherwise, instead of
    # crashing on hosts without a GPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    inputs = torch.from_numpy(inputs).to(device)

    # reshape(-1) rather than squeeze(): a single-point batch must stay 1-D
    # so the boolean-mask updates below still index correctly.
    output = torch.minimum(sdf1(inputs), sdf2(inputs)).reshape(-1).detach().cpu().numpy()
    output[mask_in] += dist_to_inner[mask_in]
    output[mask_out] += dist_to_outer[mask_out]
    return output

def csg_ani(square_se, circle_se, square_io, circle_io, num_frames=300, domain=None,
            square_side=2, circle_radius=1.2):
    """
    Evaluate the union SDF of a moving square and circle, frame by frame.

    Both shapes move linearly from their start to their end position. In each
    frame the inner/outer polygons are translated by the step since the
    previous frame, unioned, and passed to ``csg_sdf`` together with the
    analytic SDFs at the current positions. Points are sampled on a 500x500
    grid over ``domain``.

    square_se, circle_se: indexable as ``[start, end]``, each a 2-vector
    square_io, circle_io: ``(inner, outer)`` shapely polygons, already placed
        at the start position
    num_frames: number of frames along the path
    domain: ``(x_range, y_range)``; defaults to ``([-2, 2], [-2, 2])``
    square_side: side length given to ``sdf_square`` (default 2)
    circle_radius: radius given to ``sdf_circle`` (default 1.2)

    Returns ``(preds, fps)``: per-frame value arrays, and per-frame frames
    per second (floor of 1 / elapsed seconds).
    """
    if domain is None:
        domain = ([-2, 2], [-2, 2])

    X, Y = make_grid(*domain, res=500)
    test_pts = np.stack((X.flatten(), Y.flatten()), 1).astype('float32')

    square_path = np.linspace(square_se[0], square_se[1], num_frames)
    circle_path = np.linspace(circle_se[0], circle_se[1], num_frames)

    square_inner, square_outer = square_io
    circle_inner, circle_outer = circle_io

    preds = []
    fps = []
    square_pos_last = square_path[0]
    circle_pos_last = circle_path[0]
    for square_pos, circle_pos in zip(square_path, circle_path):
        time_start = time.perf_counter()
        square_trans = square_pos - square_pos_last
        circle_trans = circle_pos - circle_pos_last

        # The current centers are bound as default arguments so each closure
        # keeps this frame's position, and are moved to the device of the
        # incoming points instead of assuming CUDA.
        def square_sdf(x, center=torch.from_numpy(square_pos)):
            return sdf_square(x, center.to(x.device), square_side)

        def circle_sdf(x, center=torch.from_numpy(circle_pos)):
            return sdf_circle(x, center.to(x.device), circle_radius)

        # Move the polygons by this frame's step along the path.
        square_inner = shapely.affinity.translate(square_inner, square_trans[0], square_trans[1])
        square_outer = shapely.affinity.translate(square_outer, square_trans[0], square_trans[1])
        circle_inner = shapely.affinity.translate(circle_inner, circle_trans[0], circle_trans[1])
        circle_outer = shapely.affinity.translate(circle_outer, circle_trans[0], circle_trans[1])
        square_pos_last = square_pos
        circle_pos_last = circle_pos

        # NOTE: reading .exterior assumes each union is a single polygon,
        # i.e. the two shapes overlap in every frame.
        union_outer = shapely.union(square_outer, circle_outer)
        union_inner = shapely.union(square_inner, circle_inner)
        outer_coords = np.asarray(union_outer.exterior.coords).astype(np.float32)
        inner_coords = np.asarray(union_inner.exterior.coords).astype(np.float32)

        pred = csg_sdf(test_pts, outer_coords, inner_coords, square_sdf, circle_sdf)
        time_end = time.perf_counter()
        fps.append(1 // (time_end - time_start))
        preds.append(pred)

    return preds, fps


if __name__ == '__main__':
    import os

    # Carve outer and inner polygons around each analytic shape.
    # The center is passed as a plain list, so no tensor is re-wrapped.
    circle_bbox = ((-0.7, -0.7), (1.7, 1.7))
    circle_sdf = CircleSDF([[0.5, 0.5]], 1.2)
    circle_outer = carve(circle_sdf, deep=True, bbox=circle_bbox)
    circle_inner = carve(circle_sdf, deep=True, inner=True, bbox=circle_bbox)

    square_bbox = ((-2.5, -2.5), (2.5, 2.5))
    square_sdf = SquareSDF()
    square_outer = carve(square_sdf, deep=True, bbox=square_bbox)
    square_inner = carve(square_sdf, deep=True, inner=True, bbox=square_bbox)

    circle_start = np.array([0.8, 0.8])
    circle_end = np.array([0.2, 0.2])
    square_start = np.array([-0.8, -0.8])
    square_end = np.array([-0.2, -0.2])

    # Shift the carved polygons from the carving centers (0.5 and -0.5) to
    # the animation start positions (0.8 and -0.8).
    circle_outer = shapely.affinity.translate(circle_outer, 0.3, 0.3)
    circle_inner = shapely.affinity.translate(circle_inner, 0.3, 0.3)

    square_outer = shapely.affinity.translate(square_outer, -0.3, -0.3)
    square_inner = shapely.affinity.translate(square_inner, -0.3, -0.3)

    square_se = np.vstack([square_start, square_end])
    circle_se = np.vstack([circle_start, circle_end])
    square_io = (square_inner, square_outer)
    circle_io = (circle_inner, circle_outer)

    # Short run whose result is discarded.
    # NOTE(review): presumably a warm-up so one-time start-up costs do not
    # skew the FPS figures of the real run -- confirm.
    csg_ani(square_se, circle_se, square_io, circle_io, num_frames=5)

    all_frames, fps = csg_ani(square_se, circle_se, square_io, circle_io, num_frames=120)

    # savefig does not create missing directories.
    frame_dir = 'assets/csg_frames'
    os.makedirs(frame_dir, exist_ok=True)

    # Same grid for every frame (matches csg_ani's default domain and
    # resolution), so build it once.
    domain = ([-2, 2], [-2, 2])
    X, Y = make_grid(*domain, res=500)

    for i, (frame, frame_fps) in enumerate(zip(all_frames, fps)):
        viz_fig, viz_ax = plt.subplots(figsize=(8, 8))
        plot_sdf_from_vals(X, Y, frame, colorbar=False, ax=viz_ax, N=50, alpha=1)
        plt.xticks([])
        plt.yticks([])
        plt.tight_layout()
        viz_fig.text(0.05, 0.95, f'FPS: {frame_fps}', ha='left', va='top', fontsize=40)
        plt.savefig(f'{frame_dir}/frame_{i:03d}.png')
        plt.close()
