from matplotlib import pyplot as plt, font_manager
from matplotlib.patches import FancyArrow, Arc
import torch

# Export settings: embed TrueType (Type 42) fonts in PDF/PS, render SVG text as paths, and use a fixed
# SVG hash salt so the ids matplotlib generates inside SVG files are stable between runs.
plt.rcParams.update(
    {
        'pdf.fonttype': 42,
        'ps.fonttype': 42,
        'svg.fonttype': 'path',
        'svg.hashsalt': 'fixed-salt',  # any constant string for reproducibility
    }
)
# Font used for panel labels, and a smaller one for legend entries.
font = font_manager.FontProperties(size=12.)
font_legend = font_manager.FontProperties(size=10.)
# `savefig(metadata=...)` omits entries whose value is None, so these mappings strip the listed
# SVG / PDF metadata fields (timestamps, creator strings, ...) from the exported files.
strip_svg_meta: dict[str, None] = dict.fromkeys(('Creator', 'Date', 'Format', 'Type'))
strip_pdf_meta: dict[str, None] = dict.fromkeys(
    ('Title', 'Author', 'Subject', 'Keywords', 'Creator', 'Producer', 'CreationDate', 'ModDate', 'Trapped')
)


def get_projection_matrix() -> torch.Tensor:
    """
    Compute the camera projection matrix P = K [R | t] used for the 3D panels.

    The camera has unit focal length, principal point (0, 0) and sits at (10, 10, 10).
    Its world-to-camera rotation R is written in closed form. It is exactly the rotation given by the
    Euler-angle composition Rz(-150 deg) @ Ry(asin(-1 / sqrt(3))) @ Rx(45 deg), i.e. an isometric
    view whose third row is (1, 1, 1) / sqrt(3). The closed form avoids the rounding noise of
    evaluating the sines and cosines, and keeps the exported figures byte-stable.

    :return: (3, 4) float64 tensor.
    """

    dtype: torch.dtype = torch.float64

    focal_length: torch.Tensor = torch.as_tensor(1., dtype=dtype)
    frame_center: torch.Tensor = torch.as_tensor([0., 0.], dtype=dtype)
    camera_position: torch.Tensor = torch.as_tensor([10., 10., 10.], dtype=dtype)

    fx, fy = focal_length, focal_length
    cx, cy = frame_center

    # Intrinsic matrix.
    K: torch.Tensor = torch.as_tensor([
        [fx, 0., cx],
        [0., fy, cy],
        [0., 0., 1.],
    ], dtype=dtype)

    # World-to-camera rotation; its rows are the camera x, y and z axes expressed in world coordinates.
    R: torch.Tensor = torch.as_tensor([
        [-1. / 2. ** .5, 1. / 2. ** .5, 0.],
        [-1. / 6. ** .5, -1. / 6. ** .5, 2. ** .5 / 3. ** .5],
        [1. / 3. ** .5, 1. / 3. ** .5, 1. / 3. ** .5],
    ], dtype=dtype)

    # Translation that maps the camera centre to the origin of camera coordinates.
    t: torch.Tensor = -R @ camera_position[:, None]  # (3, 1)
    P: torch.Tensor = K @ torch.cat([R, t], dim=-1)  # (3, 4)
    return P


def project_points(points: torch.Tensor, orthographic: bool = True) -> torch.Tensor:
    """
    Project 3D points onto the 2D image frame of the camera from ``get_projection_matrix``.

    :param points: (..., 3) points in world coordinates. Any real dtype is accepted; points and
        projection matrix are promoted to a common dtype (float64 for float32/integer input) and
        moved to the points' device before the product.
    :param orthographic: If True, return the first two camera coordinates of ``P @ [x; 1]`` directly.
        If False, divide them by the third (depth) coordinate (perspective projection).
    :return: (..., 2) projected points.
    """

    projection_matrix: torch.Tensor = get_projection_matrix()  # (3, 4)
    # The matrix is float64; matmul requires equal dtypes, so promote both sides to a common one.
    common_dtype: torch.dtype = torch.promote_types(points.dtype, projection_matrix.dtype)
    projection_matrix = projection_matrix.to(device=points.device, dtype=common_dtype)
    points = points.to(dtype=common_dtype)

    points_3d: torch.Tensor = torch.cat([points, torch.ones_like(points[..., :1])], dim=-1)  # (..., 4) homogeneous
    points_2d: torch.Tensor = points_3d @ projection_matrix.t()  # (..., 3)
    if orthographic:
        return points_2d[..., :-1]  # (..., 2)
    return points_2d[..., :-1] / points_2d[..., -1:]  # (..., 2)


def add_arrow(ax, start, end, **kwargs):
    """
    Draw a straight arrow from ``start`` to ``end`` on ``ax`` and return the added patch.

    :param ax: Matplotlib axes that receives the patch.
    :param start: Tail point; any length-2 sequence of numbers (e.g. a (2,) tensor, tuple or list).
    :param end: Head point, same form as ``start``. The head tip lands on it (``length_includes_head``).
    :param kwargs: Extra ``FancyArrow`` keyword arguments (color, linestyle, zorder, ...). They may also
        override any of the default geometry settings below, e.g. ``width`` or ``head_width``.
    :return: The patch returned by ``ax.add_patch``.
    """
    x0, y0 = (float(v) for v in start)
    x1, y1 = (float(v) for v in end)
    arrow_kwargs = {
        'width': .001,
        'length_includes_head': True,
        'head_width': .1,
        'head_length': .1,
        'shape': 'full',
        'overhang': 0,
        'head_starts_at_zero': False,
        **kwargs,  # caller-supplied values win over the defaults above
    }
    arrow = FancyArrow(x0, y0, x1 - x0, y1 - y0, **arrow_kwargs)
    return ax.add_patch(arrow)


def add_line(ax, start, end, **kwargs):
    """
    Draw the segment from ``start`` to ``end`` on ``ax``.

    Only the first two components of each endpoint are used (x, then y).

    :return: The single line artist created by ``ax.plot``.
    """
    xs = [start[0], end[0]]
    ys = [start[1], end[1]]
    lines = ax.plot(xs, ys, **kwargs)
    return lines[0]


def add_point(ax, point, **kwargs):
    """
    Scatter the rows of ``point`` on ``ax``.

    ``point`` is indexed as ``point[:, 0]`` / ``point[:, 1]``, so it must be a 2D array-like whose first
    two columns are the x and y coordinates.

    :return: The collection returned by ``ax.scatter``.
    """
    xs = point[:, 0]
    ys = point[:, 1]
    return ax.scatter(xs, ys, **kwargs)


def plot_obq_argmin():
    """
    Draw the 2D illustration of Babai's nearest-plane step and save it as ``3_obq_2d.pdf`` and ``3_obq_2d.svg``.

    With a 2D basis b_0, b_1 (rows of ``basis``) and target y = zeta_0 b_0 + zeta_1 b_1, the figure shows:
    - b_1 and its component orthogonal to b_0;
    - the inverse-basis vector n_1;
    - the line Span{b_0} and the projection of y onto that line;
    - the error vector split into its b_1 and b_0 components.

    The figure is closed after saving so repeated calls do not accumulate open pyplot figures.
    """
    dtype: torch.dtype = torch.float64

    # Lattice basis (rows) and target coefficients zeta, so that y = zeta @ basis.
    basis = torch.as_tensor([[2., 0.], [1.2, 1.8]], dtype=dtype)
    target_ab = torch.as_tensor([[.2, .5]], dtype=dtype)
    # Presumably keeps round(zeta_1) == 0 so the nearest line is Span{b_0} through the origin, as drawn.
    assert (target_ab[:, 1:] <= .5).all()

    origin = torch.zeros(1, 2, dtype=dtype)
    # Unnormalised Gram-Schmidt vectors from QR: o_basis[k] = r_kk * q_k, i.e. basis[k] minus its
    # projection onto Span(basis[:k]).
    o_basis_t, r = torch.linalg.qr(basis.t(), mode='reduced')
    o_basis = o_basis_t.t() * r.diagonal()[..., None]
    # Inverse basis (dual vectors) n_j as rows: <n_j, b_k> = 1 if j == k else 0.
    i_basis = torch.linalg.inv(basis).t()
    assert ((basis @ basis.t()) @ (i_basis @ i_basis.t())).allclose(torch.eye(2, dtype=dtype), atol=1e-6)
    target = target_ab @ basis
    target_0 = target_ab[:, :1] @ basis[:1]  # zeta_0 b_0
    # Babai's projected point: least-squares projection of y onto Span{b_0}.
    target_z = torch.linalg.lstsq(o_basis[:-1].t(), target.t())[0].t() @ o_basis[:-1]

    span_line_endpoints = torch.stack([-100. * basis[0], 100. * basis[0]], dim=0)

    colors = '#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd', '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22', '#17becf', '#000000'
    colors_0 = 'black', 'gray'
    colors_2 = 'darkblue', 'blue', 'cyan', 'green'

    fig, ax = plt.subplots(1, 1, figsize=(2.7, 2.7))

    # add_arrow(ax, origin[0], basis[0], color=colors_0[0], linestyle='-', linewidth=1., zorder=5.1)  # basis 0
    add_arrow(ax, origin[0], basis[1], color=colors_2[0], linestyle='-', linewidth=1., zorder=5.1)  # basis 1
    add_arrow(ax, origin[0], o_basis[1], color=colors_2[0], linestyle=':', linewidth=1.5, zorder=5.1)  # b_1 minus its b_0 component (projection onto n_1)
    add_arrow(ax, origin[0], i_basis[1], color=colors_2[2], linestyle='-', linewidth=1., zorder=5.)  # inverse basis n_1, perpendicular to b_0
    add_arrow(ax, target[0], target_z[0], color=colors_2[3], linestyle='-', linewidth=1., zorder=5.9)  # error
    add_arrow(ax, target[0], target_0[0], color=colors_2[1], linestyle='-', linewidth=1., zorder=5.7)  # error component along b_1
    add_arrow(ax, target_0[0], target_z[0], color=colors_0[-1], linestyle='-', linewidth=1., zorder=5.5)  # error component along b_0

    add_line(ax, span_line_endpoints[0], span_line_endpoints[1], color=colors_0[0], linestyle='-', linewidth=1., zorder=1.5)  # basis 0 span line
    add_line(ax, basis[1], o_basis[1], color=colors_0[-1], linestyle=':', linewidth=1., zorder=4.)  # basis 1 proj line
    add_line(ax, o_basis[1] * -100. - origin[0], o_basis[1] * 100. - origin[0], color=colors_0[-1], linestyle=':', linewidth=1., zorder=4.)  # gs basis 1 span

    # add_point(ax, origin, color=colors[-1], s=4., zorder=3.)
    add_point(ax, target, color=colors[-1], marker='P', s=49., zorder=3.)
    add_point(ax, target_z, color=colors[-1], marker='X', s=49., zorder=3.)
    # add_point(ax, target_0, color=colors[-1], s=4., zorder=3.)

    ax.set_xlim(-.5, 1.75)
    ax.set_ylim(-.25, 2.)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_aspect('equal')
    ax.set_facecolor('white')

    # Legend proxies: artists drawn at (100, 100), outside the axis limits, exist only to carry the legend styles.
    handles = {}
    handles[r"Auxiliary Line in Orthogonal Directions"], = ax.plot(100., 100., color=colors_0[-1], linestyle=':', linewidth=1.)  # proj line
    handles[r"Basis Vector $\boldsymbol{b}_{{j}_{2}}$"], = ax.plot(100., 100., color=colors_2[0], linestyle='-', linewidth=1.)  # basis 2
    handles[r"Target Point $\boldsymbol{y} \coloneq \Sigma_{j} {\zeta}_{j} \boldsymbol{b}_{j}$"] = ax.scatter(100., 100., color=colors[-1], marker='P', s=49.)
    handles[r"Nearest Hyperplane (Hyperline) $\mathcal{NHP} \coloneq \left\lfloor {\zeta}_{{j}_{2}} \right\rceil \boldsymbol{b}_{{j}_{2}} + \mathrm{Span} \left\{ \boldsymbol{b}_{j} \; | \; {j} \ne {j}_{2} \right\}$"], = ax.plot(100., 100., color=colors_0[0], linestyle='-', linewidth=1.)  # basis 0
    handles[r"Babai's Projected Point $\mathrm{Proj}_{\mathcal{NHP}} \left( \boldsymbol{y} \right) \coloneq \Sigma_{j} \left( {\zeta}_{j} + \Delta {\zeta}_{j} \right) \boldsymbol{b}_{j}$"] = ax.scatter(100., 100., color=colors[-1], marker='X', s=49.)
    handles[r"Error Vector $\Delta \boldsymbol{y} \coloneq \mathrm{Proj}_{\mathcal{NHP}} \left( \boldsymbol{y} \right) - \boldsymbol{y} = \Sigma_{j} \Delta {\zeta}_{j} \boldsymbol{b}_{j}$"], = ax.plot(100., 100., color=colors_2[3], linestyle='-', linewidth=1.)  # error
    handles[r"Error Component Vector $\Delta {\zeta}_{{j}_{2}} \boldsymbol{b}_{{j}_{2}}$"], = ax.plot(100., 100., color=colors_2[1], linestyle='-', linewidth=1.)  # error direction 2
    handles[r"Remaining Error Component Vector $\Sigma_{{j} \ne {j}_{2}} \Delta {\zeta}_{j} \boldsymbol{b}_{j}$"], = ax.plot(100., 100., color=colors_0[-1], linestyle='-', linewidth=1.)  # error direction 0
    handles[r"Inverse Basis Vector $\boldsymbol{n}_{{j}_{2}}: \left< \boldsymbol{n}_{{j}_{2}}, \boldsymbol{b}_{{j}_{2}} \right> = 1; \boldsymbol{n}_{{j}_{2}} \perp \boldsymbol{b}_{j}, \forall {j} \ne {j}_{2}$"], = ax.plot(100., 100., color=colors_2[2], linestyle='-', linewidth=1.)  # inverse basis n_1
    handles[r"Projected Basis Vector $\mathrm{Proj}_{\boldsymbol{n}_{{j}_{2}}} \left( \boldsymbol{b}_{{j}_{2}} \right)$"], = ax.plot(100., 100., color=colors_2[0], linestyle=':', linewidth=1.5)  # b_1 projected onto n_1

    fig.legend(handles.values(), handles.keys(), loc='center left', ncol=1, framealpha=1., bbox_to_anchor=(1., .5), prop=font_legend)
    fig.set_facecolor((1., 1., 1., 0.))
    fig.tight_layout()
    fig.savefig('3_obq_2d.pdf', bbox_inches='tight', pad_inches=.01, transparent=False, metadata=strip_pdf_meta)
    fig.savefig('3_obq_2d.svg', bbox_inches='tight', pad_inches=.01, transparent=False, metadata=strip_svg_meta)
    # fig.show()
    # Release the figure from pyplot's registry; clf() alone would keep it alive.
    plt.close(fig)


def plot_obq_update():
    """
    Draw the 3D/2D illustration of Babai's projection and its relation to the OBQ update, and save it as
    ``2_obq_3d.pdf`` and ``2_obq_3d.svg``.

    Setup: a 3D basis b_0, b_1, b_2 (rows of ``basis``) and target y = zeta @ basis. y is projected onto
    Span{b_0, b_1} (Babai's projected point). The error vector is decomposed into per-basis components.
    Here b_0 lies on the first coordinate axis, so the orthogonal projection plane
    OPP = Span{n_1, n_2} (both inverse-basis vectors are perpendicular to b_0) is the plane of
    coordinates 1 and 2.

    Panels of the 2 x 2 grid:
      (a) top-left: 3D view of the basis, y and its projection onto the nearest hyperplane.
      (b) top-right: same 3D view plus inverse-basis vectors, the OPP and the projected error components.
      (c) bottom-left: the nearest hyperplane Span{b_0, b_1}, plotted as (coordinate 1, coordinate 0)
          with an inverted vertical axis.
      (d) bottom-right: the OPP, plotted as (coordinate 1, coordinate 2).

    3D panels use ``project_points`` (default orthographic mode); 2D panels plot raw coordinates.
    The figure is closed after saving so repeated calls do not accumulate open pyplot figures.
    """
    dtype: torch.dtype = torch.float64

    # Debug aid (needs an existing `ax`): scatter the projected corners of the cube [-1, 1]^3, coloured by
    # the sign pattern of their coordinates, to check the camera orientation.
    # ref_points = torch.as_tensor([
    #     [-1., -1., -1.],
    #     [1., -1., -1.],
    #     [-1., 1., -1.],
    #     [-1., -1., 1.],
    #     [1., 1., -1.],
    #     [1., -1., 1.],
    #     [-1., 1., 1.],
    #     [1., 1., 1.],
    # ], dtype=dtype)
    # ref_points_proj = project_points(ref_points)
    # for i in range(len(ref_points_proj)):
    #     ax.scatter(ref_points_proj[i, 0], ref_points_proj[i, 1], color=f'#{"".join(["dd" if ref_points[i, j] > 0. else "00" for j in range(3)])}')

    # Lattice basis (rows) and target coefficients zeta, so that y = zeta @ basis.
    basis = torch.as_tensor([[2.5, 0., 0.], [1.2, 1.8, 0.], [1.6, -.8, 1.5]], dtype=dtype)
    target_ab = torch.as_tensor([[.2, .5, .5]], dtype=dtype)
    # Presumably keeps round(zeta_j) == 0 for j >= 1 so the nearest hyperplane passes through the origin, as drawn.
    assert (target_ab[:, 1:] <= .5).all()

    origin = torch.zeros(1, 3, dtype=dtype)
    # Unnormalised Gram-Schmidt vectors from QR: o_basis[k] = r_kk * q_k, i.e. basis[k] minus its
    # projection onto Span(basis[:k]). o_basis[1] = Proj_OPP(b_1); o_basis[1], o_basis[2] span the OPP.
    o_basis_t, r = torch.linalg.qr(basis.t(), mode='reduced')
    o_basis = o_basis_t.t() * r.diagonal()[..., None]
    # Same with the order (b_0, b_2, b_1), then restored, so o_basis_2[2] = b_2 minus its b_0 component = Proj_OPP(b_2).
    o_basis_2_t, r_2 = torch.linalg.qr(basis[[0, 2, 1]].t(), mode='reduced')
    o_basis_2 = (o_basis_2_t.t() * r_2.diagonal()[..., None])[[0, 2, 1]]
    # Inverse basis (dual vectors) n_j as rows: <n_j, b_k> = 1 if j == k else 0.
    i_basis = torch.linalg.inv(basis).t()
    assert ((basis @ basis.t()) @ (i_basis @ i_basis.t())).allclose(torch.eye(3, dtype=dtype), atol=1e-6)
    target = target_ab @ basis
    target_1 = target_ab[:, :2] @ basis[:2]  # zeta_0 b_0 + zeta_1 b_1
    # Babai's projected point: least-squares projection of y onto Span{b_0, b_1}.
    target_z = torch.linalg.lstsq(o_basis[:-1].t(), target.t())[0].t() @ o_basis[:-1]
    # Projections onto Span{o_basis[1], o_basis[2]}, i.e. onto the OPP.
    target_x = torch.linalg.lstsq(o_basis[1:].t(), target.t())[0].t() @ o_basis[1:]
    target_xz = torch.linalg.lstsq(o_basis[1:].t(), target_z.t())[0].t() @ o_basis[1:]
    target_1_x = torch.linalg.lstsq(o_basis[1:].t(), target_1.t())[0].t() @ o_basis[1:]
    # Error vector and its coefficients delta_zeta in the lattice basis.
    error = target_z - target
    error_ab = torch.linalg.lstsq(basis.t(), error.t())[0].t()
    error_0 = error_ab[:, :1] @ basis[:1]
    error_1 = error_ab[:, :2] @ basis[:2]
    error_01 = error_1 - error_0 + target_1  # target_1 + delta_zeta_1 b_1

    # Large segments/quads standing in for the infinite line Span{b_0} and the two planes.
    span_line_endpoints = torch.stack([-100. * basis[0], 100. * basis[0]], dim=0)
    nearest_plane_endpoints = torch.stack([-100. * o_basis[0], 100. * o_basis[0], 100. * o_basis[0] + 100. * o_basis[1], -100. * o_basis[0] + 100. * o_basis[1]], dim=0)
    ortho_plane_endpoints = torch.stack([-100. * o_basis[1], 100. * o_basis[1], 100. * o_basis[1] + 100. * o_basis[2], -100. * o_basis[1] + 100. * o_basis[2]], dim=0)

    # Angle between n_1 and n_2, in degrees.
    angle_theta = (i_basis[1].dot(i_basis[2]) / (torch.linalg.vector_norm(i_basis[1]) * torch.linalg.vector_norm(i_basis[2]))).acos() * (180. / torch.pi)

    origin_proj = project_points(origin)
    basis_proj = project_points(basis)
    o_basis_proj = project_points(o_basis)
    o_basis_2_proj = project_points(o_basis_2)
    i_basis_proj = project_points(i_basis)
    target_proj = project_points(target)
    target_1_proj = project_points(target_1)
    target_z_proj = project_points(target_z)
    target_x_proj = project_points(target_x)
    target_xz_proj = project_points(target_xz)
    target_1_x_proj = project_points(target_1_x)
    error_01_proj = project_points(error_01)

    span_line_endpoints_proj = project_points(span_line_endpoints)
    nearest_plane_endpoints_proj = project_points(nearest_plane_endpoints)
    ortho_plane_endpoints_proj = project_points(ortho_plane_endpoints)

    colors = '#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd', '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22', '#17becf', '#000000'
    colors_0 = 'black', 'gray'
    colors_1 = 'darkred', 'red', 'magenta', 'yellow'
    colors_2 = 'darkblue', 'blue', 'cyan', 'green'
    colors_bg = 'pink', 'lightblue'

    fig, axs = plt.subplots(2, 2, figsize=(5.7, 5.7), gridspec_kw={'width_ratios': [1., 1.]})
    ax, ax_01, ax_12, ax_simple = axs[0, 1], axs[1, 0], axs[1, 1], axs[0, 0]

    # Panel (b): full 3D view.
    # add_arrow(ax, origin_proj[0], basis_proj[0], color=colors_0[0], linestyle='-', linewidth=1., zorder=5.1)  # basis 0
    add_arrow(ax, origin_proj[0], basis_proj[1], color=colors_1[0], linestyle='-', linewidth=1., zorder=5.1)  # basis 1
    add_arrow(ax, origin_proj[0], basis_proj[2], color=colors_2[0], linestyle='-', linewidth=1., zorder=5.1)  # basis 2
    add_arrow(ax, origin_proj[0], o_basis_proj[1], color=colors_1[0], linestyle=':', linewidth=1.5, zorder=5.1)  # Proj_OPP(b_1): b_1 minus its b_0 component
    add_arrow(ax, origin_proj[0], o_basis_2_proj[2], color=colors_2[0], linestyle=':', linewidth=1.5, zorder=5.1)  # Proj_OPP(b_2): b_2 minus its b_0 component
    add_arrow(ax, origin_proj[0], i_basis_proj[1], color=colors_1[2], linestyle='-', linewidth=1., zorder=5.8)  # n_1, perpendicular to b_0 and b_2
    add_arrow(ax, origin_proj[0], i_basis_proj[2], color=colors_2[2], linestyle='-', linewidth=1., zorder=5.8)  # n_2, perpendicular to b_0 and b_1
    add_arrow(ax, target_proj[0], target_z_proj[0], color=colors_2[3], linestyle='-', linewidth=1., zorder=5.9)  # error
    add_arrow(ax, target_x_proj[0], target_xz_proj[0], color=colors_2[3], linestyle=':', linewidth=1.5, zorder=5.4)  # error projected onto the OPP
    add_arrow(ax, target_proj[0], target_1_proj[0], color=colors_2[1], linestyle='-', linewidth=1., zorder=5.7)  # error component along b_2
    add_arrow(ax, target_1_proj[0], error_01_proj[0], color=colors_1[1], linestyle='-', linewidth=1., zorder=5.6)  # error component along b_1
    add_arrow(ax, error_01_proj[0], target_z_proj[0], color=colors_0[-1], linestyle='-', linewidth=1., zorder=5.5)  # remaining error component along b_0
    add_arrow(ax, target_x_proj[0], target_1_x_proj[0], color=colors_2[1], linestyle=':', linewidth=1.5, zorder=5.3)  # projected b_2 error component
    add_arrow(ax, target_1_x_proj[0], target_xz_proj[0], color=colors_1[1], linestyle=':', linewidth=1.5, zorder=5.2)  # projected b_1 error component

    # Panel (a): simplified 3D view.
    # add_arrow(ax_simple, origin_proj[0], basis_proj[0], color=colors_0[0], linestyle='-', linewidth=1., zorder=5.1)  # basis 0
    add_arrow(ax_simple, origin_proj[0], basis_proj[1], color=colors_1[0], linestyle='-', linewidth=1., zorder=5.1)  # basis 1
    add_arrow(ax_simple, origin_proj[0], basis_proj[2], color=colors_2[0], linestyle='-', linewidth=1., zorder=5.1)  # basis 2
    add_arrow(ax_simple, target_proj[0], target_z_proj[0], color=colors_2[3], linestyle='-', linewidth=1., zorder=5.9)  # error

    # Panel (d): the OPP in coordinates (1, 2).
    add_arrow(ax_12, origin[0, 1:], o_basis[1, 1:], color=colors_1[0], linestyle=':', linewidth=1.5, zorder=5.1)  # Proj_OPP(b_1)
    add_arrow(ax_12, origin[0, 1:], o_basis_2[2, 1:], color=colors_2[0], linestyle=':', linewidth=1.5, zorder=5.1)  # Proj_OPP(b_2)
    add_arrow(ax_12, origin[0, 1:], i_basis[1, 1:], color=colors_1[2], linestyle='-', linewidth=1., zorder=5.8)  # n_1
    add_arrow(ax_12, origin[0, 1:], i_basis[2, 1:], color=colors_2[2], linestyle='-', linewidth=1., zorder=5.8)  # n_2
    add_arrow(ax_12, target_x[0, 1:], target_xz[0, 1:], color=colors_2[3], linestyle=':', linewidth=1.5, zorder=5.4)  # error projected onto the OPP
    add_arrow(ax_12, target_x[0, 1:], target_1_x[0, 1:], color=colors_2[1], linestyle=':', linewidth=1.5, zorder=5.3)  # projected b_2 error component
    add_arrow(ax_12, target_1_x[0, 1:], target_xz[0, 1:], color=colors_1[1], linestyle=':', linewidth=1.5, zorder=5.2)  # projected b_1 error component

    # Panel (c): the nearest hyperplane, drawn as (coordinate 1, coordinate 0).
    # add_arrow(ax_01, origin[0, :-1].flip(dims=(-1,)), basis[0, :-1].flip(dims=(-1,)), color=colors_0[0], linestyle='-', linewidth=1., zorder=5.1)  # basis 0
    add_arrow(ax_01, origin[0, :-1].flip(dims=(-1,)), basis[1, :-1].flip(dims=(-1,)), color=colors_1[0], linestyle='-', linewidth=1., zorder=5.1)  # basis 1
    add_arrow(ax_01, origin[0, :-1].flip(dims=(-1,)), o_basis[1, :-1].flip(dims=(-1,)), color=colors_1[0], linestyle=':', linewidth=1.5, zorder=5.1)  # Proj_OPP(b_1)
    add_arrow(ax_01, target_1[0, :-1].flip(dims=(-1,)), error_01[0, :-1].flip(dims=(-1,)), color=colors_1[1], linestyle='-', linewidth=1., zorder=5.6)  # error component along b_1
    add_arrow(ax_01, error_01[0, :-1].flip(dims=(-1,)), target_z[0, :-1].flip(dims=(-1,)), color=colors_0[-1], linestyle='-', linewidth=1., zorder=5.5)  # remaining error component along b_0
    add_arrow(ax_01, target_1_x[0, :-1].flip(dims=(-1,)), target_xz[0, :-1].flip(dims=(-1,)), color=colors_1[1], linestyle=':', linewidth=1.5, zorder=5.2)  # projected b_1 error component

    # Auxiliary dotted lines join a point to its projection onto the OPP (i.e. they run along b_0).
    add_line(ax, span_line_endpoints_proj[0], span_line_endpoints_proj[1], color=colors_0[0], linestyle='-', linewidth=1., zorder=1.5)  # basis 0 span proj line
    add_line(ax, basis_proj[1], o_basis_proj[1], color=colors_0[-1], linestyle=':', linewidth=1., zorder=4.)  # basis 1 proj line
    add_line(ax, basis_proj[2], o_basis_2_proj[2], color=colors_0[-1], linestyle=':', linewidth=1., zorder=4.)  # basis 2 proj line
    add_line(ax, target_proj[0], target_x_proj[0], color=colors_0[-1], linestyle=':', linewidth=1., zorder=4.)  # target proj line
    add_line(ax, target_1_x_proj[0], target_1_proj[0], color=colors_0[-1], linestyle=':', linewidth=1., zorder=4.)  # target proj proj line
    add_line(ax, error_01_proj[0], target_xz_proj[0], color=colors_0[-1], linestyle=':', linewidth=1., zorder=4.)  # error proj line
    add_line(ax, origin_proj[0], o_basis_proj[1] * 100. - origin_proj[0], color=colors_0[-1], linestyle=':', linewidth=1., zorder=4.)  # gs basis 1 span
    add_line(ax, origin_proj[0], o_basis_proj[2] * 100. - origin_proj[0], color=colors_0[-1], linestyle=':', linewidth=1., zorder=4.)  # gs basis 2 span

    add_line(ax_simple, span_line_endpoints_proj[0], span_line_endpoints_proj[1], color=colors_0[0], linestyle='-', linewidth=1., zorder=1.5)  # basis 0 span
    add_line(ax_simple, origin_proj[0], o_basis_proj[1] * 100. - origin_proj[0], color=colors_0[-1], linestyle=':', linewidth=1., zorder=4.)  # gs basis 1 span
    add_line(ax_simple, origin_proj[0], o_basis_proj[2] * 100. - origin_proj[0], color=colors_0[-1], linestyle=':', linewidth=1., zorder=4.)  # gs basis 2 span

    add_line(ax_01, span_line_endpoints[0, :-1].flip(dims=(-1,)), span_line_endpoints[1, :-1].flip(dims=(-1,)), color=colors_0[0], linestyle='-', linewidth=1., zorder=1.5)  # basis 0 span
    add_line(ax_01, basis[1, :-1].flip(dims=(-1,)), o_basis[1, :-1].flip(dims=(-1,)), color=colors_0[-1], linestyle=':', linewidth=1., zorder=4.)  # basis 1 proj line
    add_line(ax_01, target_1_x[0, :-1].flip(dims=(-1,)), target_1[0, :-1].flip(dims=(-1,)), color=colors_0[-1], linestyle=':', linewidth=1., zorder=4.)  # target proj proj line
    add_line(ax_01, error_01[0, :-1].flip(dims=(-1,)), target_xz[0, :-1].flip(dims=(-1,)), color=colors_0[-1], linestyle=':', linewidth=1., zorder=4.)  # error proj line
    add_line(ax_01, o_basis[1, :-1].flip(dims=(-1,)) * -100. - origin[0, :-1].flip(dims=(-1,)), o_basis[1, :-1].flip(dims=(-1,)) * 100. - origin[0, :-1].flip(dims=(-1,)), color=colors_0[-1], linestyle=':', linewidth=1., zorder=4.)  # gs basis 1 span

    add_line(ax_12, o_basis[1, 1:] * -100. - origin[0, 1:], o_basis[1, 1:] * 100. - origin[0, 1:], color=colors_0[-1], linestyle=':', linewidth=1., zorder=4.)  # gs basis 1 span
    add_line(ax_12, o_basis[2, 1:] * -100. - origin[0, 1:], o_basis[2, 1:] * 100. - origin[0, 1:], color=colors_0[-1], linestyle=':', linewidth=1., zorder=4.)  # gs basis 2 span

    # add_point(ax, origin_proj, color=colors[-1], s=4., zorder=3.)
    add_point(ax, target_proj, color=colors[-1], marker='P', s=49., zorder=3.)
    add_point(ax, target_z_proj, color=colors[-1], marker='X', s=49., zorder=3.)
    # add_point(ax, target_1_proj, color=colors[-1], s=4., zorder=3.)
    # add_point(ax, error_01_proj, color=colors[-1], s=4., zorder=3.)
    # add_point(ax, target_xz_proj, color=colors[-1], s=4., zorder=3.)
    # add_point(ax, target_x_proj, color=colors[-1], s=4., zorder=3.)
    # add_point(ax, target_1_x_proj, color=colors[-1], s=4., zorder=3.)

    add_point(ax_simple, target_proj, color=colors[-1], marker='P', s=49., zorder=3.)
    add_point(ax_simple, target_z_proj, color=colors[-1], marker='X', s=49., zorder=3.)

    add_point(ax_01, target_z[:, :-1].flip(dims=(-1,)), color=colors[-1], marker='X', s=49., zorder=3.)

    # Plane backgrounds: nearest hyperplane (pink) and OPP (light blue).
    ax.fill(nearest_plane_endpoints_proj[:, 0], nearest_plane_endpoints_proj[:, 1], color=colors_bg[0], alpha=1., linewidth=0., zorder=1.)
    ax.fill(ortho_plane_endpoints_proj[:, 0], ortho_plane_endpoints_proj[:, 1], color=colors_bg[1], alpha=1., linewidth=0., zorder=1.9)

    ax_simple.fill(nearest_plane_endpoints_proj[:, 0], nearest_plane_endpoints_proj[:, 1], color=colors_bg[0], alpha=1., linewidth=0., zorder=1.)

    # Arcs of angular extent theta (degrees); `radius` is passed as the Arc width/height, so it is the arc's diameter.
    radius = .3
    ax_12.add_patch(Arc(xy=origin[0, 1:], width=radius, height=radius, angle=90. - angle_theta, theta1=0., theta2=angle_theta, color=colors_1[3], zorder=3.5))
    ax_12.add_patch(Arc(xy=origin[0, 1:], width=radius, height=radius, angle=180. - angle_theta, theta1=0., theta2=angle_theta, color=colors_1[3], zorder=3.5))
    ax_12.add_patch(Arc(xy=target_1_x[0, 1:], width=radius, height=radius, angle=180. - angle_theta, theta1=0., theta2=angle_theta, color=colors_1[3], zorder=3.5))

    ax.set_xlim(-2., 1.5)
    ax.set_ylim(-1.5, 2.)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_aspect('equal')
    ax.set_facecolor('white')

    ax_simple.set_xlim(-2., 1.5)
    ax_simple.set_ylim(-1.5, 2.)
    ax_simple.set_xticks([])
    ax_simple.set_yticks([])
    ax_simple.set_aspect('equal')
    ax_simple.set_facecolor('white')

    ax_01.set_xlim(-.5, 2.5)
    ax_01.set_ylim(2.5, -.5)  # inverted vertical axis
    ax_01.set_xticks([])
    ax_01.set_yticks([])
    ax_01.set_aspect('equal')
    ax_01.set_facecolor(colors_bg[0])

    ax_12.set_xlim(-1., 2.)
    ax_12.set_ylim(-.75, 2.25)
    ax_12.set_xticks([])
    ax_12.set_yticks([])
    ax_12.set_aspect('equal')
    ax_12.set_facecolor(colors_bg[1])

    ax.set_xlabel("(b) [3D] Babai & OBQ Equivalence", fontproperties=font)
    ax_simple.set_xlabel("(a) [3D] Babai's Projection", fontproperties=font)
    ax_01.set_xlabel("(c) [2D] Nearest Hyperplane", fontproperties=font)
    ax_12.set_xlabel("(d) [2D] Orthogonal Projection Plane", fontproperties=font)

    # Legend proxies: artists drawn at (100, 100), outside the axis limits, exist only to carry the legend styles.
    handles = {}
    handles[r"Auxiliary Line in Orthogonal Directions"], = ax.plot(100., 100., color=colors_0[-1], linestyle=':', linewidth=1.)  # proj line
    handles[r"Basis Vector $\boldsymbol{b}_{{j}_{1}}$"], = ax.plot(100., 100., color=colors_1[0], linestyle='-', linewidth=1.)  # basis 1
    handles[r"Basis Vector $\boldsymbol{b}_{{j}_{2}}$"], = ax.plot(100., 100., color=colors_2[0], linestyle='-', linewidth=1.)  # basis 2
    handles[r"Target Point $\boldsymbol{y} \coloneq \Sigma_{j} {\zeta}_{j} \boldsymbol{b}_{j}$"] = ax.scatter(100., 100., color=colors[-1], marker='P', s=49.)
    handles[r"Nearest Hyperplane $\mathcal{NHP} \coloneq \left\lfloor {\zeta}_{{j}_{2}} \right\rceil \boldsymbol{b}_{{j}_{2}} + \mathrm{Span} \left\{ \boldsymbol{b}_{j} \; | \; {j} \ne {j}_{2} \right\}$"], = ax.fill(100., 100., color=colors_bg[0], alpha=1., linewidth=0.)
    handles[r"Hyperline $\mathcal{HL} \coloneq \left\lfloor {\zeta}_{{j}_{2}} \right\rceil \boldsymbol{b}_{{j}_{2}} + \mathrm{Span} \left\{ \boldsymbol{b}_{j} \; | \; {j} \ne {j}_{1}, {j}_{2} \right\}$"], = ax.plot(100., 100., color=colors_0[0], linestyle='-', linewidth=1.)  # basis 0
    handles[r"Babai's Projected Point $\mathrm{Proj}_{\mathcal{NHP}} \left( \boldsymbol{y} \right) \coloneq \Sigma_{j} \left( {\zeta}_{j} + \Delta {\zeta}_{j} \right) \boldsymbol{b}_{j}$"] = ax.scatter(100., 100., color=colors[-1], marker='X', s=49.)
    handles[r"Error Vector $\Delta \boldsymbol{y} \coloneq \mathrm{Proj}_{\mathcal{NHP}} \left( \boldsymbol{y} \right) - \boldsymbol{y} = \Sigma_{j} \Delta {\zeta}_{j} \boldsymbol{b}_{j}$"], = ax.plot(100., 100., color=colors_2[3], linestyle='-', linewidth=1.)  # error
    handles[r"Error Component Vector $\Delta {\zeta}_{{j}_{1}} \boldsymbol{b}_{{j}_{1}}$"], = ax.plot(100., 100., color=colors_1[1], linestyle='-', linewidth=1.)  # error direction 1
    handles[r"Error Component Vector $\Delta {\zeta}_{{j}_{2}} \boldsymbol{b}_{{j}_{2}}$"], = ax.plot(100., 100., color=colors_2[1], linestyle='-', linewidth=1.)  # error direction 2
    handles[r"Remaining Error Component Vector $\Sigma_{{j} \ne {j}_{1}, {j}_{2}} \Delta {\zeta}_{j} \boldsymbol{b}_{j}$"], = ax.plot(100., 100., color=colors_0[-1], linestyle='-', linewidth=1.)  # error direction 0
    handles[r"Inverse Basis Vector $\boldsymbol{n}_{{j}_{1}}: \left< \boldsymbol{n}_{{j}_{1}}, \boldsymbol{b}_{{j}_{1}} \right> = 1; \boldsymbol{n}_{{j}_{1}} \perp \boldsymbol{b}_{j}, \forall {j} \ne {j}_{1}$"], = ax.plot(100., 100., color=colors_1[2], linestyle='-', linewidth=1.)  # n_1
    handles[r"Inverse Basis Vector $\boldsymbol{n}_{{j}_{2}}: \left< \boldsymbol{n}_{{j}_{2}}, \boldsymbol{b}_{{j}_{2}} \right> = 1; \boldsymbol{n}_{{j}_{2}} \perp \boldsymbol{b}_{j}, \forall {j} \ne {j}_{2}$"], = ax.plot(100., 100., color=colors_2[2], linestyle='-', linewidth=1.)  # n_2
    handles[r"Orthogonal Projection Plane $\mathcal{OPP} \coloneq \mathrm{Span} \left\{ \boldsymbol{n}_{j} \; | \; j = {j}_{1}, {j}_{2} \right\}$"], = ax.fill(100., 100., color=colors_bg[1], alpha=1., linewidth=0.)
    handles[r"Projected Basis Vector $\mathrm{Proj}_{\mathcal{OPP}} \left( \boldsymbol{b}_{{j}_{1}} \right)$"], = ax.plot(100., 100., color=colors_1[0], linestyle=':', linewidth=1.5)  # Proj_OPP(b_1)
    handles[r"Projected Basis Vector $\mathrm{Proj}_{\mathcal{OPP}} \left( \boldsymbol{b}_{{j}_{2}} \right)$"], = ax.plot(100., 100., color=colors_2[0], linestyle=':', linewidth=1.5)  # Proj_OPP(b_2)
    handles[r"Projected Error Vector $\mathrm{Proj}_{\mathcal{OPP}} \left( \Delta \boldsymbol{y} \right) = \Delta \boldsymbol{y} = \Sigma_{{j} = {j}_{1}, {j}_{2}} \Delta {\zeta}_{j} \mathrm{Proj}_{\mathcal{OPP}} \left( \boldsymbol{b}_{j} \right)$"], = ax.plot(100., 100., color=colors_2[3], linestyle=':', linewidth=1.5)  # error proj
    handles[r"Projected Error Component Vector $\Delta {\zeta}_{{j}_{1}} \mathrm{Proj}_{\mathcal{OPP}} \left( \boldsymbol{b}_{{j}_{1}} \right)$"], = ax.plot(100., 100., color=colors_1[1], linestyle=':', linewidth=1.5)  # error direction 1 proj
    handles[r"Projected Error Component Vector $\Delta {\zeta}_{{j}_{2}} \mathrm{Proj}_{\mathcal{OPP}} \left( \boldsymbol{b}_{{j}_{2}} \right)$"], = ax.plot(100., 100., color=colors_2[1], linestyle=':', linewidth=1.5)  # error direction 2 proj
    handles[r"Angle $\theta = \angle \left( \boldsymbol{n}_{{j}_{1}}, \boldsymbol{n}_{{j}_{2}} \right) = \pi - \angle \left( \mathrm{Proj}_{\mathcal{OPP}} \left( \boldsymbol{b}_{{j}_{1}} \right), \mathrm{Proj}_{\mathcal{OPP}} \left( \boldsymbol{b}_{{j}_{2}} \right) \right)$"] = ax.scatter(100., 100., facecolors='none', edgecolors=colors_1[3], s=81., marker='o', linewidths=1.)  # angle theta

    fig.legend(handles.values(), handles.keys(), loc='center left', ncol=1, framealpha=1., bbox_to_anchor=(1., .5), prop=font_legend)
    fig.set_facecolor((1., 1., 1., 0.))
    fig.tight_layout()
    fig.savefig('2_obq_3d.pdf', bbox_inches='tight', pad_inches=.01, transparent=False, metadata=strip_pdf_meta)
    fig.savefig('2_obq_3d.svg', bbox_inches='tight', pad_inches=.01, transparent=False, metadata=strip_svg_meta)
    # fig.show()
    # Release the figure from pyplot's registry; clf() alone would keep it alive.
    plt.close(fig)


if __name__ == '__main__':
    # Render and save both figures: first the 2D argmin illustration, then the 3D update one.
    for render_figure in (plot_obq_argmin, plot_obq_update):
        render_figure()
