import itertools
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
from scipy.spatial.transform.rotation import Rotation as R


def plot3d(figsize=(3,3),
           dpi=40,
           tight_layout=True,
           lim=0.7,
           view=(-47, -90),
           axis='off',
           dist=None,
          ):
    """Create a new figure holding a single 3D axes with common view settings.

    Args:
        figsize: Figure size in inches, forwarded to ``plt.figure``.
        dpi: Figure resolution, forwarded to ``plt.figure``.
        tight_layout: If true, apply ``tight_layout`` to the new figure.
        lim: Half-width of the symmetric x/y/z axis limits ``(-lim, lim)``;
            ``None`` leaves matplotlib's default limits.
        view: Positional arguments for ``ax.view_init`` (``(elev, azim)``);
            ``None`` keeps the default camera.
        axis: ``'off'`` hides the axes; any other value leaves them shown.
        dist: If not ``None``, assigned to ``ax.dist``.
            NOTE(review): ``Axes3D.dist`` is deprecated in newer matplotlib
            releases -- confirm against the version this project pins.

    Returns:
        ``(figure, axes)`` tuple.
    """
    fig = plt.figure(figsize=figsize, dpi=dpi)
    # Let the axes span the whole canvas (no margins).
    fig.subplots_adjust(0, 0, 1, 1)
    ax = fig.add_subplot(projection='3d')

    if lim is not None:
        for set_limits in (ax.set_xlim, ax.set_ylim, ax.set_zlim):
            set_limits(-lim, lim)

    if tight_layout:
        fig.tight_layout()

    if view is not None:
        ax.view_init(*view)

    if dist is not None:
        ax.dist = dist

    if axis == 'off':
        ax.axis('off')

    return fig, ax


def add_colored_polygons(ax,
                         faces: np.ndarray,
                         colors: np.ndarray,
                         **kwargs):
    """Add each face to ``ax`` as its own polygon, filled with its color.

    ``faces`` and ``colors`` are walked in lockstep with ``zip``, so surplus
    entries in the longer of the two are silently ignored. Any remaining
    keyword arguments are forwarded to ``add_polygon`` for every face.
    """
    for polygon, fill in zip(faces, colors):
        vertices = np.asarray(polygon)
        add_polygon(ax, vertices, color=fill, **kwargs)


def label_vertices(ax,
                   vertices,
                   size=12,
                   color='k',
                   fontweight='bold',
                   occluding=True,
                   start_from=0,
                   **kwargs,
                  ):
    """Write each vertex's index as a text label at the vertex position.

    Args:
        ax: 3D axes; its ``azim`` and ``elev`` view angles are read to decide
            which vertices are considered visible.
        vertices: Sequence of 3D points.
        size: Font size forwarded to ``ax.text``.
        color: Text color forwarded to ``ax.text``.
        fontweight: Font weight forwarded to ``ax.text``.
        occluding: If true, only label vertices whose projection onto the
            view direction is strictly greater than the centroid's.
        start_from: Label of the first vertex; later labels count up from it.
        **kwargs: Extra keyword arguments forwarded to ``ax.text``.
    """
    # View direction in spherical coordinates: azimuth theta and polar angle
    # phi = 90 deg - elevation.
    theta, phi = np.radians((ax.azim, 90 - ax.elev))
    view_dir = np.array((np.cos(theta) * np.sin(phi),
                         np.sin(theta) * np.sin(phi),
                         np.cos(phi)))
    centroid = np.mean(vertices, axis=0)

    for offset, point in enumerate(vertices):
        # With occlusion on, skip vertices not in front of the centroid.
        if occluding and not (view_dir.dot(point) > view_dir.dot(centroid)):
            continue
        ax.text(*point, f'{start_from + offset}',
                size=size, color=color, fontweight=fontweight, **kwargs)

def add_edges(ax,
              edges,
              color=(0,0,0,1),
              linewidth=1,
             ):
    """Draw every edge as a solid line on ``ax``.

    Args:
        ax: Axes to draw on (a 3D axes when edges carry z coordinates).
        edges: Iterable of edges. Each edge is array-like with one row per
            point and one column per coordinate, e.g.
            ``[(x0, y0, z0), (x1, y1, z1)]``. Plain lists and tuples are
            accepted as well as ndarrays.
        color: Line color used for every edge.
        linewidth: Line width used for every edge.
    """
    for edge in edges:
        # Transpose points-by-coordinates into one sequence per axis
        # (xs, ys, zs). np.asarray lets list/tuple edges work too; it is a
        # no-op for ndarray input.
        coords = np.asarray(edge).T
        ax.plot(*coords, '-', color=color, linewidth=linewidth)

def add_polygon(ax,
                vertices: np.ndarray,
                color=(1,1,1,0.8),
                edgecolor=(0,0,0,1),
                linewidth=1,
                zsort='average'):
    """Add a single 3D polygon to ``ax``.

    This is ``add_polygons`` applied to a one-element list containing
    ``vertices``. The styling arguments are forwarded unchanged.

    Args:
        ax: 3D axes receiving the polygon.
        vertices: The polygon's points.
        color: Color passed to ``Poly3DCollection.set(color=...)``.
        edgecolor: Edge color passed to ``set(edgecolor=...)``.
        linewidth: Edge line width.
        zsort: Depth-sorting mode passed to ``set(zsort=...)``.
    """
    add_polygons(ax,
                 [vertices],
                 color=color,
                 edgecolor=edgecolor,
                 linewidth=linewidth,
                 zsort=zsort)

def add_polygons(ax,
                 vertices: np.ndarray,
                 color=(1,1,1,0.8),
                 edgecolor=(0,0,0,1),
                 linewidth=1,
                 zsort='average'):
    """Add several 3D polygons to ``ax`` as one ``Poly3DCollection``.

    Args:
        ax: 3D axes receiving the collection.
        vertices: Sequence of polygons, each a sequence of points. It is
            passed directly to ``Poly3DCollection``.
        color: Color applied with ``set(color=...)``. It is listed before
            ``edgecolor``, presumably so the edge color overrides it. Confirm
            against matplotlib's ``Artist.set`` ordering.
        edgecolor: Edge color applied with ``set(edgecolor=...)``.
        linewidth: Edge line width.
        zsort: Depth-sorting mode applied with ``set(zsort=...)``.
    """
    # Insertion order of this dict is the order ``set`` receives the keywords.
    style = dict(color=color,
                 edgecolor=edgecolor,
                 linewidth=linewidth,
                 zsort=zsort)
    collection = Poly3DCollection(vertices)
    collection.set(**style)
    ax.add_collection3d(collection)

def add_axis_visual(ax,
                    origin=(-1,-1,-1),
                    length=0.4,
                    linewidth=5,
                    colors='rgb',
                   ):
    """Draw an x/y/z axis triad as three arrows starting at ``origin``.

    The arrow along axis ``k`` (0 = x, 1 = y, 2 = z) has length ``length``
    and color ``colors[k]``. Only the first three entries of ``colors`` are
    used.
    """
    for k in range(3):
        # Unit-axis direction scaled to ``length``, e.g. (length, 0, 0) for x.
        direction = [0, 0, 0]
        direction[k] = length
        ax.quiver(*origin, *direction, color=colors[k], linewidth=linewidth)


def render(f, render_mode='rgb_array'):
    """Show or rasterize a matplotlib figure according to ``render_mode``.

    Args:
        f: The matplotlib figure to render.
        render_mode: One of

            - ``'human'``: call ``plt.show()``.
            - ``'background'``: draw ``f`` and briefly run ``plt.pause`` so
              an interactive window can refresh.
            - ``'rgb_array'``: draw ``f``, copy its pixels, then close ``f``.

            Any other value (e.g. ``None``) is ignored and nothing happens.

    Returns:
        For ``'rgb_array'``, a ``(height, width, 3)`` ``uint8`` array of the
        rendered canvas. Otherwise ``None``.
    """
    if render_mode == 'human':
        plt.show()
    elif render_mode == 'background':
        f.canvas.draw()
        plt.pause(0.01)
    elif render_mode == 'rgb_array':
        f.canvas.draw()
        # ``tostring_rgb`` is deprecated since matplotlib 3.8 and removed in
        # later releases. ``buffer_rgba`` is the supported accessor on
        # Agg-based canvases. Its shape is in physical pixels, so unlike a
        # reshape by ``get_width_height()`` (logical pixels) it is also
        # correct on HiDPI canvases.
        rgba = np.asarray(f.canvas.buffer_rgba())
        # Drop alpha. The non-contiguous slice forces a copy, so the result
        # does not alias the renderer buffer of the figure closed below.
        rgb_array = np.ascontiguousarray(rgba[..., :3])
        plt.close(f)
        return rgb_array


def gridify_images(images,
                   border_size=0,
                   border_val=0,
                  ):
    """Tile a 2D grid of images into a single image.

    Args:
        images: Nested sequence ``images[row][col]`` of image arrays.
            - Images in a row are joined along axis 1 (width), and rows are
              joined along axis 0 (height).
            - Heights must therefore match within a row, and total row widths
              must match across rows.
        border_size: Pixels of padding added on every side of each image
            before tiling, so adjacent tiles end up ``2 * border_size``
            pixels apart. Only the first two axes are padded. ``0`` (the
            default) adds no padding.
        border_val: Fill value used for the padding.

    Returns:
        The tiled array.

    Raises:
        ValueError: If ``border_size`` is negative.
    """
    if border_size < 0:
        raise ValueError(f'border_size must be non-negative, got {border_size}')

    def _pad(img):
        # Pad the two spatial axes only; leave any channel axes untouched.
        img = np.asarray(img)
        pad_width = [(border_size, border_size)] * 2 + [(0, 0)] * (img.ndim - 2)
        return np.pad(img, pad_width, mode='constant', constant_values=border_val)

    rows = []
    for row in images:
        if border_size:
            row = [_pad(img) for img in row]
        rows.append(np.concatenate(row, axis=1))

    return np.concatenate(rows, axis=0)


def generate_video(imgs, name: str=None, show: bool=False):
    """Assemble a sequence of images into a matplotlib ``ArtistAnimation``.

    Args:
        imgs: Indexable sequence of image arrays. The first image's
            width/height ratio sets the figure's aspect ratio.
        name: If not ``None``, the animation is saved with ``ani.save(name)``.
        show: If true, call ``plt.show()`` before the figure is closed.

    The figure is one inch tall at 200 dpi, frames are 20 ms apart, and the
    figure is always closed before returning.
    """
    first = imgs[0]
    height, width = first.shape[0], first.shape[1]
    fig = plt.figure(figsize=(width / height, 1), dpi=200)
    # Fill the whole canvas with the image and hide the axes.
    fig.subplots_adjust(0, 0, 1, 1)
    plt.axis('off')

    # ArtistAnimation takes one list of artists per frame.
    frames = [[plt.imshow(img, animated=True)] for img in imgs]

    ani = animation.ArtistAnimation(fig, frames, interval=20, repeat_delay=0, blit=False)
    if name is not None:
        ani.save(name)
    if show:
        plt.show()
    plt.close(fig)
