import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import mpl_toolkits.mplot3d.art3d as art3d

from app.utils.voxel_utils import voxel_project

# Small 3x3x3 sample grid rendered by the ``__main__`` demo; non-zero entries mark
# filled voxels. Successive slices along the first axis hold 7, 5 and 1 voxels.
TEST_VOXEL = np.array(
    [
        [[1, 1, 1], [1, 1, 1], [1, 0, 0]],
        [[1, 1, 1], [1, 1, 0], [0, 0, 0]],
        [[1, 0, 0], [0, 0, 0], [0, 0, 0]],
    ]
)


class VoxelRenderer:
    """Draw voxel grids onto matplotlib axes.

    Three views are provided:

    * ``_render_2d`` - a layered 2D view of the result of ``voxel_project``;
    * ``_render_2d_projection`` - a single 2D grid drawn in white;
    * ``_render_3d`` - the exposed faces of a 3D grid on 3D axes.

    In the 2D views a cell is filled when its value is > 0; in the 3D view a
    voxel is solid when its value is != 0.
    """

    def __init__(self):
        pass

    @staticmethod
    def _depth_colors(color_map, count):
        """Return the palette that layer/coordinate indices cycle through.

        ``None`` yields a single white entry; otherwise ``count`` evenly spaced
        samples are taken from the named matplotlib colormap.
        """
        if color_map is None:
            return ["white"]
        return plt.colormaps[color_map](np.linspace(0, 1, count))

    @staticmethod
    def _outline_segments(grid):
        """Return the unit edges separating cells of the 2D ``grid`` whose values differ.

        The grid is padded with zeros before differencing, so edges between a
        cell on the border and the outside are included. Cell ``(y, x)``
        occupies the square ``[x, x + 1] x [y, y + 1]``. Each segment is
        ``((x0, y0), (x1, y1))``: all horizontal edges come first, then all
        vertical edges, each in row-major order.
        """
        horizontal = np.argwhere(np.diff(np.pad(grid, ((1, 1), (0, 0))), axis=0) != 0)
        vertical = np.argwhere(np.diff(np.pad(grid, ((0, 0), (1, 1))), axis=1) != 0)
        segments = [((int(x), int(y)), (int(x) + 1, int(y))) for y, x in horizontal]
        segments += [((int(x), int(y)), (int(x), int(y) + 1)) for y, x in vertical]
        return segments

    def _draw_outline(self, ax, grid, zorder=None):
        """Stroke the cell boundaries of ``grid`` on ``ax`` as black lines (width 1.5).

        ``zorder`` is only forwarded when given, so by default the lines keep
        matplotlib's standard stacking order.
        """
        style = {"color": "k", "linewidth": 1.5}
        if zorder is not None:
            style["zorder"] = zorder
        for (x0, y0), (x1, y1) in self._outline_segments(grid):
            ax.plot([x0, x1], [y0, y1], **style)

    def _render_2d(self, voxel_array, ax, projection=None, color_map="plasma"):
        """Draw a layered 2D view of ``voxel_array`` on ``ax``.

        The array is first passed through ``voxel_project(voxel_array, projection)``,
        whose result is unpacked as ``(depth, height, width)``. For each layer,
        cells > 0 are drawn as unit squares coloured by layer index, then the
        layer's cell boundaries are outlined. Layer ``z`` fills use zorder ``z``
        and its outline zorder ``z + 1``, so later layers stack above earlier ones.

        Args:
            voxel_array: voxel grid accepted by ``voxel_project``.
            ax: 2D matplotlib axes to draw on.
            projection: view name forwarded to ``voxel_project`` (the demo passes
                "top", "front" and "right").
            color_map: matplotlib colormap name, or ``None`` for white fills.
        """
        voxel_array = voxel_project(voxel_array, projection)

        ax.set_aspect("equal")
        ax.axis("off")

        depth, height, width = voxel_array.shape
        colors = self._depth_colors(color_map, depth)

        for z, layer in enumerate(voxel_array):
            color = colors[z % len(colors)]
            for y, x in np.argwhere(layer > 0):
                ax.add_patch(
                    patches.Rectangle(
                        (x, y), 1, 1, linewidth=0, facecolor=color, zorder=z
                    )
                )
            # NOTE(review): outlines follow value changes (!= 0) while fills use > 0,
            # so adjacent cells with different non-zero values get a border line and
            # negative cells are outlined but not filled; confirm this is intended.
            self._draw_outline(ax, layer, zorder=z + 1)

        ax.set_xlim(-0.5, width + 0.5)
        ax.set_ylim(-0.5, height + 0.5)

    def _render_2d_projection(self, projection, ax):
        """Draw the 2D grid ``projection`` on ``ax``.

        Cells > 0 become white unit squares, and cell boundaries are outlined
        in black. Both use matplotlib's default zorder.
        """
        ax.set_aspect("equal")
        ax.axis("off")

        height, width = projection.shape

        for y, x in np.argwhere(projection > 0):
            ax.add_patch(
                patches.Rectangle((x, y), 1, 1, linewidth=0, facecolor="white")
            )
        self._draw_outline(ax, projection)

        ax.set_xlim(-0.5, width + 0.5)
        ax.set_ylim(-0.5, height + 0.5)

    @staticmethod
    def _parse_direction(direction):
        """Split a face direction such as ``"-z"`` into ``(step, axis)``.

        Returns:
            ``(+1 or -1, "x" | "y" | "z")``.

        Raises:
            ValueError: if ``direction`` is not "+" or "-" followed by x, y or z.
        """
        if (
            len(direction) != 2
            or direction[0] not in ("+", "-")
            or direction[1] not in ("x", "y", "z")
        ):
            raise ValueError(
                f"invalid face direction {direction!r}; expected e.g. '-z' or '+x'"
            )
        return (1 if direction[0] == "+" else -1), direction[1]

    def _check_face(self, x, y, z, dir, voxel_array):
        """Return True if the ``dir`` face of voxel ``(x, y, z)`` is exposed.

        A face is exposed when the voxel lies inside ``voxel_array`` and is
        non-zero, and its neighbour one step along ``dir`` is either outside
        the grid or zero.

        Raises:
            ValueError: if ``dir`` is malformed (e.g. ``"+X"`` or ``"z+"``).
        """
        step, axis = self._parse_direction(dir)
        nx, ny, nz = voxel_array.shape

        def inside(i, j, k):
            return 0 <= i < nx and 0 <= j < ny and 0 <= k < nz

        if not inside(x, y, z) or voxel_array[x, y, z] == 0:
            return False

        nbr_x = x + (step if axis == "x" else 0)
        nbr_y = y + (step if axis == "y" else 0)
        nbr_z = z + (step if axis == "z" else 0)

        # Faces on the grid boundary are always exposed; interior faces only
        # when the neighbouring voxel is empty.
        return not (
            inside(nbr_x, nbr_y, nbr_z) and voxel_array[nbr_x, nbr_y, nbr_z] != 0
        )

    def _draw_face(self, ax, x, y, z, dir, voxel_array, face_color):
        """Add the ``dir`` face of voxel ``(x, y, z)`` to the 3D axes ``ax`` if exposed.

        The face is a unit square with a black edge. It is built as a 2D
        rectangle in the plane of the two other axes and lifted into 3D by
        ``art3d.pathpatch_2d_to_3d`` at ``level`` along ``axis``. A "+" face
        sits at coordinate + 1 and a "-" face at the voxel's own coordinate.

        Raises:
            ValueError: if ``dir`` is not "+" or "-" followed by x, y or z.
        """
        step, axis = self._parse_direction(dir)
        if not self._check_face(x, y, z, dir, voxel_array):
            return

        offset = 1 if step > 0 else 0
        # In-plane coordinates are ordered so that pathpatch_2d_to_3d, for the
        # given zdir, maps the square back onto (x, y, z) world axes.
        if axis == "z":
            rect_coords, level = (x, y), z + offset
        elif axis == "y":
            rect_coords, level = (x, z), y + offset
        else:
            rect_coords, level = (y, z), x + offset

        rect = patches.Rectangle(rect_coords, 1, 1, facecolor=face_color, edgecolor="k")
        ax.add_patch(rect)
        art3d.pathpatch_2d_to_3d(rect, z=level, zdir=axis)

    def _render_3d(self, voxel_array, ax, color_map="plasma"):
        """Draw every exposed face of ``voxel_array`` on the 3D axes ``ax``.

        ``voxel_array`` is indexed as ``[x, y, z]``. Each non-zero voxel
        contributes the faces it does not share with a non-zero neighbour.
        Faces normal to an axis are coloured by the voxel's coordinate along
        that axis. Panes, axis lines and ticks are hidden, limits are set to
        the grid extent, and the camera is fixed at elev=30, azim=45.

        Args:
            voxel_array: 3D array of voxel values.
            ax: matplotlib 3D axes to draw on.
            color_map: matplotlib colormap name, or ``None`` for white faces.
        """
        nx, ny, nz = voxel_array.shape
        colors = self._depth_colors(color_map, max(nx, ny, nz))

        # argwhere walks non-zero voxels in the same row-major order as nested
        # x/y/z loops; tolist() yields plain Python ints.
        for x, y, z in np.argwhere(voxel_array != 0).tolist():
            for direction, coord in (
                ("-z", z),
                ("+z", z),
                ("-y", y),
                ("+y", y),
                ("-x", x),
                ("+x", x),
            ):
                self._draw_face(
                    ax, x, y, z, direction, voxel_array, colors[coord % len(colors)]
                )

        ax.set_aspect("equal")

        ax.set_xlabel("Right", fontsize=20)
        ax.set_ylabel("Front", fontsize=20)

        transparent = (0.0, 0.0, 0.0, 0.0)
        for axis in (ax.xaxis, ax.yaxis, ax.zaxis):
            axis.pane.set_visible(False)
            axis.pane.set_facecolor(transparent)
            axis.line.set_color(transparent)

        ax.set_xlim(0, nx)
        ax.set_ylim(0, ny)
        ax.set_zlim(0, nz)

        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_zticks([])

        ax.view_init(elev=30, azim=45)


if __name__ == "__main__":
    # Demo: a large 3D view of TEST_VOXEL on top, with one labelled 2D panel per
    # projection underneath, laid out `cols` panels per row.
    color_map = "cool"
    projections = ["top", "front", "right"]
    labels = ["A", "B", "C"]

    renderer = VoxelRenderer()
    fig = plt.figure(figsize=(15, 10))

    cols = 3
    # Ceiling division: a partially filled last row still needs its own grid row.
    # Floor division would drop projections beyond a multiple of `cols`, and
    # fewer than `cols` projections would produce no 2D panels at all.
    projection_rows = (len(projections) + cols - 1) // cols
    gs = fig.add_gridspec(2 + projection_rows, cols)

    ax_3d = fig.add_subplot(gs[0:2, 0:cols], projection="3d")
    renderer._render_3d(TEST_VOXEL, ax_3d, color_map)

    # Exactly one 2D axes per projection (row-major), so no empty panels are created.
    axs = [
        fig.add_subplot(gs[2 + i // cols, i % cols]) for i in range(len(projections))
    ]

    for ax, proj, label in zip(axs, projections, labels):
        renderer._render_2d(TEST_VOXEL, ax, proj, color_map)
        ax.set_title(label, fontsize=20)

    plt.tight_layout()

    plt.show()
