## license: MIT

import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import matplotlib.patches as patches
from mpl_toolkits.mplot3d.art3d import Poly3DCollection

from .color import *

## 2D net
AS = 0.3  # arrow size in data units (grid cells are 1x1); also read by arrow_3d

# Unit direction vector for each supported 2D arrow orientation.
_DIRECTIONS_2D = {
    'up': (0, 1),
    'down': (0, -1),
    'right': (1, 0),
    'left': (-1, 0),
}


def arrow_2d(x, y, orientation, color):
    """Build a filled arrow patch centred at ``(x + 0.5, y + 0.5)``.

    That point is the middle of the unit cell whose lower-left corner is ``(x, y)``.

    Args:
        x, y: lower-left corner of the cell the arrow is drawn in.
        orientation: one of ``'up'``, ``'down'``, ``'right'``, ``'left'``.
        color: key into ``COLOR_RGB`` (from ``.color``). Each channel is divided
            by 255 and an opaque alpha of 1.0 is appended.

    Returns:
        A ``matplotlib.patches.FancyArrow`` ready for ``ax.add_patch``.

    Raises:
        ValueError: if ``orientation`` is not one of the four supported values
            (matches the error ``arrow_3d`` raises for a bad orientation).
    """
    try:
        dx, dy = _DIRECTIONS_2D[orientation]
    except KeyError:
        # Previously an unknown orientation silently returned None.
        raise ValueError('Invalid orientation') from None

    color_rgba = [float(c) / 255.0 for c in COLOR_RGB[color]] + [1.0]
    cx, cy = x + 0.5, y + 0.5
    # The shaft starts AS behind the centre and runs AS up to it. With FancyArrow's
    # default length_includes_head=False, the head (length AS) is drawn beyond the
    # centre, so the whole arrow is symmetric about the cell centre.
    return patches.FancyArrow(
        cx - AS * dx, cy - AS * dy, AS * dx, AS * dy,
        color=color_rgba, width=AS, head_width=AS * 2, head_length=AS,
    )

def plot_2d(net_view, save_path, show=False, axis=False):
    """Draw a 2D net as outlined unit cells, each holding a coloured arrow, and save it.

    Args:
        net_view: mapping from a cell key to an ``(orientation, color)`` pair.
            The cell's x is ``int(key[0])`` and its y is ``int(key[2])``.
            NOTE(review): presumably keys look like ``"x,y"`` with single-digit
            coordinates; confirm against the caller.
        save_path: destination passed to ``plt.savefig`` (written at 100 dpi).
        show: if False the figure is closed after saving; if True it is left open.
            This function never calls ``plt.show()`` itself.
        axis: if True draw the axes with ``x``/``y`` labels, otherwise hide them.
    """
    figure = plt.figure(figsize=(3, 3))
    canvas = figure.add_subplot(111)

    for key, (orientation, colour) in net_view.items():
        # Shift by half a cell so each unit square is centred on its integer coordinate.
        left = int(key[0]) - 0.5
        bottom = int(key[2]) - 0.5
        canvas.add_patch(plt.Rectangle((left, bottom), 1, 1, fill=None, edgecolor='black'))
        canvas.add_patch(arrow_2d(left, bottom, orientation, colour))

    # Fixed viewport for the net layout.
    canvas.set_xlim(-1, 4)
    canvas.set_ylim(-1, 5)
    canvas.set_aspect('equal')

    if not axis:
        canvas.axis('off')
    else:
        canvas.axis('on')
        canvas.set_xlabel('x')
        canvas.set_ylabel('y')

    plt.tight_layout()
    plt.savefig(save_path, dpi=100)
    if not show:
        plt.close()

## 3D net
# Moves the base arrow (drawn in the y=0 plane, pointing toward +z) onto a cube face.
_FACE_PLACEMENT = {
    'front': lambda p: [p[0], p[1], p[2]],      # stays on y=0, points +z ('top')
    'top': lambda p: [p[0], p[2], 1 - p[1]],    # lifted onto z=1, points +y ('back')
    'right': lambda p: [1 - p[1], p[0], p[2]],  # moved onto x=1, points +z ('top')
}

# In-plane turn applied after placement, per face and per requested orientation.
_FACE_TURNS = {
    'front': {
        'top': lambda p: p,                          # +z
        'right': lambda p: [p[2], p[1], p[0]],       # +x
        'bottom': lambda p: [p[0], p[1], 1 - p[2]],  # -z
        'left': lambda p: [1 - p[2], p[1], p[0]],    # -x
    },
    'top': {
        'back': lambda p: p,                         # +y
        'right': lambda p: [p[1], 1 - p[0], p[2]],   # +x
        'front': lambda p: [p[0], 1 - p[1], p[2]],   # -y
        'left': lambda p: [1 - p[1], p[0], p[2]],    # -x
    },
    'right': {
        'top': lambda p: p,                          # +z
        'back': lambda p: [p[0], p[2], 1 - p[1]],    # +y
        'bottom': lambda p: [p[0], p[1], 1 - p[2]],  # -z
        'front': lambda p: [p[0], 1 - p[2], p[1]],   # -y
    },
}


def arrow_3d(face, orientation, size=None):
    """Return the vertices of a flat arrow polygon drawn on one face of the unit cube.

    The arrow is first built in the ``y = 0`` plane pointing toward ``+z``. It is then
    placed on the requested face and turned to the requested orientation.

    Args:
        face: ``'front'`` (the ``y = 0`` plane), ``'top'`` (``z = 1``) or
            ``'right'`` (``x = 1``).
        orientation: direction the arrow points on that face.
            front: 'top' (+z), 'right' (+x), 'bottom' (-z), 'left' (-x).
            top: 'back' (+y), 'right' (+x), 'front' (-y), 'left' (-x).
            right: 'top' (+z), 'back' (+y), 'bottom' (-z), 'front' (-y).
        size: arrow size in cube units. Defaults to the module-level ``AS``.

    Returns:
        A list of seven ``[x, y, z]`` vertex lists, usable as one polygon for
        ``Poly3DCollection``.

    Raises:
        ValueError: ``'Invalid face'`` for an unknown face (previously this silently
            returned the front-face arrow), or ``'Invalid orientation'`` when the
            orientation is not valid for the given face.
    """
    if face not in _FACE_PLACEMENT:
        raise ValueError('Invalid face')
    turns = _FACE_TURNS[face]
    if orientation not in turns:
        raise ValueError('Invalid orientation')

    s = AS if size is None else size
    # Shaft (first four points) then head (last three), with the tip at z = 0.5 + s.
    base = [
        [0.5 - s / 2, 0, 0.5], [0.5 - s / 2, 0, 0.5 - s], [0.5 + s / 2, 0, 0.5 - s],
        [0.5 + s / 2, 0, 0.5], [0.5 + s, 0, 0.5], [0.5, 0, 0.5 + s], [0.5 - s, 0, 0.5],
    ]
    place = _FACE_PLACEMENT[face]
    turn = turns[orientation]
    return [turn(place(p)) for p in base]


def plot_3d(cube_view, save_path, show=False):
    """Draw the top, right and front faces of a unit cube, each with a coloured arrow, and save it.

    Args:
        cube_view: mapping with keys ``'front'``, ``'top'`` and ``'right'``. Each
            value is an ``(orientation, color)`` pair: the orientation goes to
            ``arrow_3d`` and the color is a key into ``COLOR_RGB``.
        save_path: destination passed to ``plt.savefig`` (written at 100 dpi).
        show: if False the figure is closed after saving; if True it is left open.
            This function never calls ``plt.show()`` itself.
    """
    figure = plt.figure(figsize=(3, 3))
    axes3d = figure.add_subplot(111, projection='3d')

    # Outlines of the three faces visible from the chosen viewpoint.
    top_face = [[0, 0, 1], [1, 0, 1], [1, 1, 1], [0, 1, 1]]      # z = 1
    right_face = [[1, 0, 0], [1, 1, 0], [1, 1, 1], [1, 0, 1]]    # x = 1
    front_face = [[0, 0, 0], [0, 0, 1], [1, 0, 1], [1, 0, 0]]    # y = 0
    # NOTE(review): collection-level alpha may also apply to the edge colour,
    # depending on the matplotlib version. Confirm the black outline actually renders.
    outline = Poly3DCollection(
        [top_face, right_face, front_face],
        facecolors='white', linewidths=1, edgecolors='black', alpha=0.0,
    )
    axes3d.add_collection3d(outline)

    for face_name in ('front', 'top', 'right'):
        orientation, colour = cube_view[face_name]
        rgba = [float(v) / 255.0 for v in COLOR_RGB[colour]]
        rgba.append(1.0)
        polygon = arrow_3d(face_name, orientation)
        axes3d.add_collection3d(
            Poly3DCollection([polygon], facecolors=rgba, linewidths=1, edgecolors='white')
        )

    axes3d.set_box_aspect([1, 1, 0.9])
    axes3d.view_init(elev=35, azim=-55)
    axes3d.axis('off')

    plt.savefig(save_path, dpi=100)
    if not show:
        plt.close()

def plot_comp(save_path, path_list, subplots):
    """Tile saved image files into a ``row x col`` grid and save the composite.

    Args:
        save_path: destination file for the composite image.
        path_list: image files read with ``mpimg.imread``, placed in row-major order.
        subplots: ``(row, col)`` shape of the grid.

    Raises:
        ValueError: if ``path_list`` holds more images than the grid has cells.
    """
    row, col = subplots
    if len(path_list) > row * col:
        # An explicit raise instead of ``assert``, so the check survives ``python -O``.
        raise ValueError(
            f'{len(path_list)} images do not fit in a {row}x{col} grid'
        )

    images = [mpimg.imread(img_file) for img_file in path_list]

    # squeeze=False always returns a 2-D array of Axes. With the default squeeze=True,
    # a 1x1 grid yields a bare Axes that has no .ravel().
    fig, axes = plt.subplots(row, col, squeeze=False)
    flat_axes = axes.ravel()
    for ax, img in zip(flat_axes, images):
        ax.imshow(img)
    # Hide ticks and frames on every cell, including unused trailing cells.
    for ax in flat_axes:
        ax.axis('off')

    fig.tight_layout()  # prevent overlapping
    fig.savefig(save_path)
    plt.close(fig)