import itertools

from matplotlib import pyplot as plt, font_manager
from matplotlib.patches import FancyArrow, Polygon
from scipy.spatial import Voronoi
import torch

# Embed fonts as TrueType (Type 42) in PDF/PS output and render SVG text as paths.
# The SVG hash salt is pinned so repeated runs produce the same SVG element ids.
plt.rcParams.update(
    {
        'pdf.fonttype': 42,
        'ps.fonttype': 42,
        'svg.fonttype': 'path',
        'svg.hashsalt': 'fixed-salt',  # any constant string for reproducibility
    }
)

# Font for the panel labels (`font`) and for the figure legend (`font_legend`).
font = font_manager.FontProperties(size=10)
font_legend = font_manager.FontProperties(size=10)

# `metadata=` dictionaries for `savefig`: every listed key maps to None, which asks
# matplotlib to leave that entry out of the written SVG / PDF file.
strip_svg_meta: dict[str, None] = dict.fromkeys(('Creator', 'Date', 'Format', 'Type'))
strip_pdf_meta: dict[str, None] = dict.fromkeys(
    ('Title', 'Author', 'Subject', 'Keywords', 'Creator', 'Producer', 'CreationDate', 'ModDate', 'Trapped')
)


def add_arrow(ax, origin, vector, color):
    """Draw a full-headed arrow on ``ax`` from ``origin`` along ``vector``.

    The head is counted in the arrow length (``length_includes_head=True``), so the
    tip ends at ``origin + vector``. The arrow uses zorder 5, above the zorder-3/4
    scatter markers drawn by ``plot_2d``.

    :param ax: Axes to draw on.
    :param origin: Tail position; its elements are passed as FancyArrow's ``x, y``.
    :param vector: Displacement; its elements are passed as FancyArrow's ``dx, dy``.
    :param color: Matplotlib color for the whole arrow.
    :return: The patch returned by ``ax.add_patch``.
    """
    arrow_style = dict(
        width=.001,
        length_includes_head=True,
        head_width=.1,
        head_length=.1,
        shape='full',
        overhang=0,
        head_starts_at_zero=False,
        color=color,
        zorder=5.,
    )
    return ax.add_patch(FancyArrow(*origin, *vector, **arrow_style))


def _gram_schmidt(basis):
    """Return the unnormalized Gram-Schmidt vectors of the rows of ``basis``.

    Row ``k`` of the result is ``basis[k]`` minus its projection onto the span of
    ``basis[:k]``. It is computed from a reduced QR decomposition of ``basis^T``:
    scaling each orthonormal row of ``Q^T`` by the matching diagonal entry of ``R``
    restores the length and sign of the Gram-Schmidt vector.

    :param basis: (N, D) tensor whose rows are the basis vectors.
    :return: (N, D) tensor of Gram-Schmidt vectors, row-aligned with ``basis``.
    """
    q_t, r = torch.linalg.qr(basis.transpose(-2, -1), mode='reduced')
    return q_t.transpose(-2, -1) * r.diagonal(dim1=-2, dim2=-1)[..., None]  # (N, D)


def _babai_nearest_plane(basis, gs_basis, target):
    """Run Babai's nearest-plane algorithm and record each intermediate projection.

    Basis vectors are processed from the last row to the first. In each step, the
    coefficient of the residual along ``basis[k]`` (measured via the Gram-Schmidt
    vector ``gs_basis[k]``) is rounded to pick the nearest hyperplane. The running
    point is then moved onto that hyperplane along ``gs_basis[k]``.

    :param basis: (N, D) tensor whose rows are the basis vectors.
    :param gs_basis: (N, D) Gram-Schmidt vectors of ``basis`` (see ``_gram_schmidt``).
    :param target: (1, D) target point.
    :return: ``(planes, points)``. ``planes[i]`` is the rounded 0-d coefficient chosen
        in step ``i``, and ``points[i]`` is the (1, D) point after step ``i``.
        ``points[-1]`` is the lattice point Babai returns.
    """
    residual = target  # target minus the lattice vectors fixed so far
    point = target  # target projected onto every hyperplane chosen so far
    planes, points = [], []
    for k in reversed(range(basis.size(-2))):
        proj = gs_basis[k].dot(residual[0]) / gs_basis[k].dot(basis[k])  # ()
        plane = proj.round()  # ()
        point = point - (proj - plane) * gs_basis[k]  # (1, D)
        residual = residual - plane * basis[k]  # (1, D)
        planes.append(plane)
        points.append(point)
    return planes, points


def _coefficient_grid(sizes, dtype, offset=0.):
    """Return the grid of coefficient vectors ``-s - offset, ..., s + 1 - offset`` (step 1) per axis.

    With ``offset=0.`` axis ``k`` holds the integers ``-sizes[k] .. sizes[k]``. With
    ``offset=.5`` it holds the half-integers ``-sizes[k] - .5 .. sizes[k] + .5``.

    :param sizes: Per-axis extents (iterable of integer scalars / 0-d tensors).
    :param dtype: Floating dtype of the result.
    :param offset: Shift subtracted from the start of every axis.
    :return: (..., D) tensor, ``'ij'``-indexed so that axis ``k`` varies coefficient ``k``.
    """
    axes = [torch.arange(-s - offset, s + 1, dtype=dtype) for s in sizes]
    return torch.stack(torch.meshgrid(axes, indexing='ij'), dim=-1)


def _plot_grid_lines(ax, grid_xy, **line_kwargs):
    """Draw both families of straight lines through a 2-D grid of points.

    :param ax: Axes to draw on.
    :param grid_xy: (M, K, 2) grid of points, e.g. lattice points indexed by coefficients.
    :param line_kwargs: Forwarded to ``ax.plot``.
    """
    # One line per column, from the first to the last row.
    ax.plot(grid_xy[[0, -1], :, 0], grid_xy[[0, -1], :, 1], **line_kwargs)
    # One line per row, from the first to the last column.
    ax.plot(grid_xy[:, [0, -1], 0].transpose(0, 1), grid_xy[:, [0, -1], 1].transpose(0, 1), **line_kwargs)


def plot_2d():
    """Draw the 2x4 figure illustrating Babai's algorithm on a 2-D lattice and save it.

    Top row: (a) the closest vector problem, (b) a reduced basis of the same lattice,
    and (c)/(d) the two projection steps of Babai's nearest-plane algorithm.
    Bottom row: the rounding boundaries of (e) exact CVP (Voronoi cells),
    (f) round-to-nearest in basis coordinates, and (g)/(h) Babai's algorithm with
    the two possible basis orders.

    Writes ``1_babai.pdf`` and ``1_babai.svg`` to the working directory, then calls
    ``fig.show()``.
    """
    dtype = torch.float64

    # ---------------------------------------------------------------- geometry
    basis_xy = torch.as_tensor([[.4, .1], [.6, .6]], dtype=dtype)  # (N, D)
    reduce_transform = torch.as_tensor([[1, 0], [-2, 1]], dtype=dtype)  # (N, N)
    target_xy = torch.as_tensor([[.4, .33]], dtype=dtype)  # (1, D)
    basis_vector_origin_ab = torch.as_tensor([[-0., -1.]], dtype=dtype)  # (1, D)
    # Corners of the plotted window [-1, 1]^2, listed so that consecutive entries are
    # adjacent corners. This lets them double as Polygon vertices for the Babai cells
    # (itertools.product order would not).
    corner_coords_xy = torch.as_tensor([[1., 1.], [-1., 1.], [-1., -1.], [1., -1.]], dtype=dtype)  # (2^D, D)

    reduced_basis_xy = reduce_transform @ basis_xy  # (N, D)
    flipped_basis_xy = basis_xy.flip(dims=(-2,))  # (N, D), same lattice, reversed processing order
    basis_vector_origin_xy = basis_vector_origin_ab @ basis_xy  # (1, D)
    o_basis_1_xy = _gram_schmidt(basis_xy)  # (N, D)
    o_basis_2_xy = _gram_schmidt(flipped_basis_xy)  # (N, D)

    # Round-to-nearest: solve c @ basis = target for c, then round c.
    target_rtn_ab = torch.linalg.solve(basis_xy, target_xy, left=False).round()  # (1, D)
    target_rtn_xy = target_rtn_ab @ basis_xy  # (1, D)

    babai_1_planes, babai_1_points = _babai_nearest_plane(basis_xy, o_basis_1_xy, target_xy)
    plane_1_p1 = babai_1_planes[0]  # (), rounded coefficient of the last basis vector
    target_1_p1_xy = babai_1_points[0]  # (1, D), target projected onto the first chosen line
    target_1_p2_xy = babai_1_points[-1]  # (1, D), lattice point returned by Babai
    _, babai_2_points = _babai_nearest_plane(flipped_basis_xy, o_basis_2_xy, target_xy)
    target_2_p2_xy = babai_2_points[-1]  # (1, D), lattice point returned by Babai (other order)

    # Coefficient extents that cover the window; + 1 to prevent missing Voronoi edges.
    lattice_sizes = torch.linalg.solve(basis_xy, corner_coords_xy, left=False).abs().amax(dim=0).ceil().to(dtype=torch.int64) + 1  # (D)
    lattice_xy = _coefficient_grid(lattice_sizes, dtype) @ basis_xy  # (..., D)
    lattice_mid_xy = _coefficient_grid(lattice_sizes, dtype, offset=.5) @ basis_xy  # (..., D), half-integer coefficients
    # Babai cells: boxes spanned by +-1/2 of each Gram-Schmidt vector around every lattice point.
    lattice_cuboid_1_xy = lattice_xy[..., None, :] + .5 * corner_coords_xy @ o_basis_1_xy  # (..., 2^D, D)
    lattice_cuboid_2_xy = lattice_xy[..., None, :] + .5 * corner_coords_xy @ o_basis_2_xy  # (..., 2^D, D)
    reduced_lattice_sizes = torch.linalg.solve(reduced_basis_xy, corner_coords_xy, left=False).abs().amax(dim=0).ceil().to(dtype=torch.int64)  # (D)
    reduced_lattice_xy = _coefficient_grid(reduced_lattice_sizes, dtype) @ reduced_basis_xy  # (..., D)

    # Index of the first chosen line along axis 1 of lattice_xy (the last basis vector's coefficient).
    plane_id = int(plane_1_p1) + int(lattice_sizes[-1])  # int

    vor = Voronoi(lattice_xy.flatten(end_dim=-2))
    # Ridges touching vertex -1 extend to infinity; keep only finite ones.
    finite_ridge_vertices = [rv for rv in vor.ridge_vertices if -1 not in rv]  # list[list[int]], (?, 2)
    vor_ridges = vor.vertices[finite_ridge_vertices]  # (?, 2, 2)

    # Exact CVP by brute force over the plotted lattice patch.
    lattice_points_xy = lattice_xy.flatten(end_dim=-2)  # (M, D)
    closest_xy = lattice_points_xy[torch.linalg.vector_norm(lattice_points_xy - target_xy, ord=2, dim=-1).argmin(), None]  # (1, D)

    # ---------------------------------------------------------------- styles
    colors = '#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd', '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22', '#17becf', '#000000'
    # Shared between the plotted artists and the legend proxies so the two cannot drift apart.
    lattice_style = dict(color=colors[-1], s=25., marker='o', linewidths=0., zorder=3.)
    target_style = dict(color=colors[2], s=49., marker='P', linewidths=0., zorder=4.)
    returned_style = dict(color=colors[-1], edgecolors=colors[2], s=49., marker='o', linewidths=2., zorder=4.)
    projected_style = dict(color=colors[2], s=49., marker='X', linewidths=0., zorder=4.)
    basis_direction_style = dict(color=colors[0], linewidth=1., linestyle=':')
    hyperplane_style = dict(color=colors[1], linewidth=1., linestyle='--')
    boundary_style = dict(color=colors[7], linewidth=1., linestyle='-')

    # ---------------------------------------------------------------- drawing
    fig, axs = plt.subplots(nrows=2, ncols=4, figsize=(8., 4.))

    for i, ax in enumerate(axs.flat):
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_xlim(-1., 1.)
        ax.set_ylim(-1., 1.)
        ax.set_aspect('equal')
        ax.set_facecolor((1., 1., 1., 1.))

        ax.scatter(lattice_xy[..., 0], lattice_xy[..., 1], **lattice_style)
        if i != 1:  # panel (b) shows only the basis reduction, without the target
            ax.scatter(target_xy[:, 0], target_xy[:, 1], **target_style)

    for ridge in vor_ridges:
        axs[1, 0].plot(ridge[:, 0], ridge[:, 1], **boundary_style)

    _plot_grid_lines(axs[0, 0], lattice_xy, **basis_direction_style)
    _plot_grid_lines(axs[0, 1], reduced_lattice_xy, **basis_direction_style)
    # Step 1: every line parallel to the first basis vector through lattice points.
    axs[0, 2].plot(lattice_xy[[0, -1], :, 0], lattice_xy[[0, -1], :, 1], **hyperplane_style)
    # Step 2: only the line chosen in step 1.
    axs[0, 3].plot(lattice_xy[[0, -1], plane_id, 0], lattice_xy[[0, -1], plane_id, 1], **hyperplane_style)
    _plot_grid_lines(axs[1, 1], lattice_mid_xy, **boundary_style)

    for ax, cuboids_xy in ((axs[1, 2], lattice_cuboid_1_xy), (axs[1, 3], lattice_cuboid_2_xy)):
        for cuboid_xy in cuboids_xy.flatten(end_dim=-3):
            ax.add_patch(Polygon(cuboid_xy, closed=True, fill=False, edgecolor=colors[7], linewidth=1., linestyle='-'))

    axs[0, 0].scatter(closest_xy[:, 0], closest_xy[:, 1], **returned_style)
    axs[0, 2].scatter(target_1_p1_xy[:, 0], target_1_p1_xy[:, 1], **projected_style)
    axs[0, 3].scatter(target_1_p1_xy[:, 0], target_1_p1_xy[:, 1], **projected_style)
    axs[0, 3].scatter(target_1_p2_xy[:, 0], target_1_p2_xy[:, 1], **returned_style)
    axs[1, 0].scatter(closest_xy[:, 0], closest_xy[:, 1], **returned_style)
    axs[1, 1].scatter(target_rtn_xy[:, 0], target_rtn_xy[:, 1], **returned_style)
    axs[1, 2].scatter(target_1_p2_xy[:, 0], target_1_p2_xy[:, 1], **returned_style)
    axs[1, 3].scatter(target_2_p2_xy[:, 0], target_2_p2_xy[:, 1], **returned_style)

    # (axes, vectors drawn as arrows from the common origin, arrow color)
    arrow_specs = (
        (axs[0, 0], basis_xy, colors[0]),
        (axs[0, 1], reduced_basis_xy, colors[0]),
        (axs[0, 2], o_basis_1_xy, colors[1]),
        (axs[0, 3], o_basis_1_xy, colors[1]),
        (axs[1, 1], basis_xy, colors[0]),
        (axs[1, 2], o_basis_1_xy, colors[1]),
        (axs[1, 3], o_basis_2_xy, colors[1]),
    )
    for ax, vectors_xy, color in arrow_specs:
        for vector_xy in vectors_xy:
            add_arrow(ax=ax, origin=basis_vector_origin_xy[0], vector=vector_xy, color=color)

    # Legend proxies are placed at (10, 10), outside the fixed [-1, 1] limits, so they stay invisible.
    ax = axs[0, 0]
    handles = {}
    handles['Lattice Point'] = ax.scatter(10., 10., **lattice_style)
    handles['Target Point'] = ax.scatter(10., 10., **target_style)
    handles['Returned Lattice Point'] = ax.scatter(10., 10., **returned_style)
    handles['Babai\'s Projected Point'] = ax.scatter(10., 10., **projected_style)
    handles['Basis Vector'], = ax.plot(10., 10., color=colors[0], linewidth=1., linestyle='-')
    handles['Basis Direction'], = ax.plot(10., 10., **basis_direction_style)
    handles['Gram-Schmidt Vector'], = ax.plot(10., 10., color=colors[1], linewidth=1., linestyle='-')
    handles['Babai\'s Hyperplane'], = ax.plot(10., 10., **hyperplane_style)
    handles['Rounding Boundary'], = ax.plot(10., 10., **boundary_style)
    fig.legend(handles.values(), handles.keys(), loc='center left', ncol=1, framealpha=1., bbox_to_anchor=(1., .5), prop=font_legend)

    panel_xlabels = (
        '(a) Closest Vector Problem',
        '(b) Basis Reduction',
        '(c) Projection Step 1',
        '(d) Projection Step 2',
        '(e) Optimal / Voronoi',
        '(f) Round-to-Nearest',
        '(g) Babai',
        '(h) Babai (Another Order)',
    )
    for ax, xlabel in zip(axs.flat, panel_xlabels):  # axs.flat is row-major: (0, 0), (0, 1), ...
        ax.set_xlabel(xlabel, fontproperties=font)
    axs[0, 0].set_ylabel('Babai\'s Algorithm', fontproperties=font)
    axs[1, 0].set_ylabel('Rounding Boundaries', fontproperties=font)

    fig.set_facecolor((1., 1., 1., 0.))
    fig.tight_layout()
    fig.savefig('1_babai.pdf', bbox_inches='tight', pad_inches=.01, transparent=False, metadata=strip_pdf_meta)
    fig.savefig('1_babai.svg', bbox_inches='tight', pad_inches=.01, transparent=False, metadata=strip_svg_meta)
    fig.show()


if __name__ == '__main__':
    # Script entry point: render the figure and write 1_babai.pdf / 1_babai.svg (see plot_2d).
    plot_2d()
