"""
Shapes that can be rendered with OpenGL
"""

from collections import namedtuple
import copy

import numpy as np
from OpenGL.GL import *

from . import linalg

"""A simple coordinate"""
coord = namedtuple("coord", ["x", "y", "z"], defaults=[0.0, 0.0, 0.0])

ORIGIN = coord(x=0.0, y=0.0, z=0.0)


class Shape:
    """Base class for shapes described by vertices plus index lists, drawn with legacy OpenGL.

    Subclasses populate some or all of these attributes (all start as ``None``):

    * ``vertices``: array of ``(x, y, z)`` rows; the lists below index into it.
    * ``edges``: pairs of vertex indices, drawn by :meth:`draw_lines`.
    * ``quad_surfaces``: 4-tuples of vertex indices, drawn as ``GL_QUADS``.
    * ``trifan_surfaces``: sequences of vertex indices, each drawn as one
      ``GL_TRIANGLE_FAN``.

    :param opacity: alpha component used for every color this shape emits.
    """

    def __init__(self, opacity=1.0):
        self.vertices = None
        self.edges = None
        self.quad_surfaces = None
        self.trifan_surfaces = None
        self.opacity = opacity

    def draw_lines(self):
        """Draw every edge as a white line segment.

        :raises NotImplementedError: if the shape has no vertices or no edges.
        """
        if self.vertices is None or self.edges is None:
            raise NotImplementedError(f"Cannot draw lines on a {type(self).__name__} shape")
        glBegin(GL_LINES)
        glColor4fv((1.0, 1.0, 1.0, self.opacity)) # white color
        for edge in self.edges:
            glVertex3fv(self.vertices[edge[0]])
            glVertex3fv(self.vertices[edge[1]])
        glEnd()

    def draw_surfaces(self):
        """Draw the quad and triangle-fan surfaces that are defined.

        Surfaces cycle through red, green and blue. Quads get a per-corner
        tint: on three corners one channel is raised by 0.5, clamped to 1.0.
        Missing surface lists (or missing vertices) are silently skipped.
        """
        colors = [
            (1.0, 0.0, 0.0, self.opacity),
            (0.0, 1.0, 0.0, self.opacity),
            (0.0, 0.0, 1.0, self.opacity),
        ]
        if self.vertices is not None and self.quad_surfaces is not None:
            glBegin(GL_QUADS)
            for i, surface in enumerate(self.quad_surfaces):
                r,g,b,o = colors[i % 3]
                glColor4fv((min(r + 0.5, 1.0), g, b, o))
                glVertex3fv(self.vertices[surface[0]])
                glColor4fv((r, min(g + 0.5, 1.0), b, o))
                glVertex3fv(self.vertices[surface[1]])
                glColor4fv((r, g, min(b + 0.5, 1.0), o))
                glVertex3fv(self.vertices[surface[2]])
                glColor4fv((r, g, b, o))
                glVertex3fv(self.vertices[surface[3]])
            glEnd()
        if self.vertices is not None and self.trifan_surfaces is not None:
            for i, surf in enumerate(self.trifan_surfaces):
                # Each fan needs its own glBegin/glEnd pair: a single
                # GL_TRIANGLE_FAN primitive cannot hold several fans.
                glBegin(GL_TRIANGLE_FAN)
                glColor4fv(colors[i % 3])
                for vertex_idx in surf:
                    glVertex3fv(self.vertices[vertex_idx])
                glEnd()

    def rotate(self, x_angle=0.0, y_angle=0.0, z_angle=0.0):
        """Rotate all vertices in place about the origin.

        Angles are in radians (they go straight into ``np.cos`` / ``np.sin``).
        The combined matrix is ``Rx @ Ry @ Rz`` acting on column vectors, so
        the z rotation is applied first, then y, then x.

        :raises NotImplementedError: if the shape has no vertices.
        """
        if self.vertices is None:
            raise NotImplementedError(f"Cannot rotate a {type(self).__name__} shape")
        m = np.array([
            [1, 0, 0],
            [0, 1, 0],
            [0, 0, 1],
        ], dtype=np.float32)
        # Skip zero angles: their factor would be the identity anyway.
        if x_angle != 0.0:
            m = np.matmul(m, np.array([
                [1, 0, 0],
                [0, np.cos(x_angle), -np.sin(x_angle)],
                [0, np.sin(x_angle), np.cos(x_angle)],
            ], dtype=np.float32))
        if y_angle != 0.0:
            m = np.matmul(m, np.array([
                [np.cos(y_angle), 0, np.sin(y_angle)],
                [0, 1, 0],
                [-np.sin(y_angle), 0, np.cos(y_angle)],
            ], dtype=np.float32))
        if z_angle != 0.0:
            m = np.matmul(m, np.array([
                [np.cos(z_angle), -np.sin(z_angle), 0],
                [np.sin(z_angle), np.cos(z_angle), 0],
                [0, 0, 1],
            ], dtype=np.float32))
        # Vertices are stored as rows, so v' = m @ v becomes V' = V @ m.T.
        self.vertices = np.matmul(
            self.vertices,
            m.T,
        )

    def translate(self, x=0.0, y=0.0, z=0.0):
        """Shift all vertices in place by ``(x, y, z)``.

        :raises NotImplementedError: if the shape has no vertices (consistent
            with :meth:`rotate` and :meth:`draw_lines`).
        """
        if self.vertices is None:
            raise NotImplementedError(f"Cannot translate a {type(self).__name__} shape")
        self.vertices += np.array([x, y, z], dtype=np.float32)


class Cube(Shape):
    """An axis-aligned box centered on the origin.

    The three extents may differ, so despite the name this is a general
    rectangular box.

    :param width: extent along x; must be > 0.
    :param height: extent along y; must be > 0.
    :param depth: extent along z; must be > 0.
    :param kwargs: forwarded to :class:`Shape` (e.g. ``opacity``).
    """

    # Corners of the unit cube [0, 1]^3; a corner's position in this tuple is
    # its vertex index in ``self.vertices``.
    _UNIT_VERTICES = (
        (0.0, 0.0, 0.0),
        (0.0, 0.0, 1.0),
        (0.0, 1.0, 0.0),
        (0.0, 1.0, 1.0),
        (1.0, 0.0, 0.0),
        (1.0, 0.0, 1.0),
        (1.0, 1.0, 0.0),
        (1.0, 1.0, 1.0),
    )
    # The 12 edges: pairs of corners differing in exactly one coordinate.
    _IDX_EDGES = (
        (0, 1), (0, 2), (0, 4),
        (1, 3), (1, 5),
        (2, 3), (2, 6),
        (3, 7),
        (4, 5), (4, 6),
        (5, 7),
        (6, 7),
    )
    # The 6 faces, each as 4 corners listed in perimeter order (for GL_QUADS).
    _IDX_SURFACES = (
        (0, 1, 3, 2),
        (0, 1, 5, 4),
        (0, 2, 6, 4),
        (4, 5, 7, 6),
        (2, 3, 7, 6),
        (1, 3, 7, 5),
    )

    def __init__(self, width=1.0, height=1.0, depth=1.0, **kwargs):
        super().__init__(**kwargs)
        assert width > 0
        assert height > 0
        assert depth > 0
        # Fresh lists per instance, so mutating one cube's topology never
        # affects another cube or the class-level tables.
        self.edges = list(self._IDX_EDGES)
        self.quad_surfaces = list(self._IDX_SURFACES)
        # Scale the unit cube, then shift by -0.5 per axis to center it on the
        # origin. Stored as float32 to match Cylinder and the float32 arrays
        # used by rotate()/translate().
        self.vertices = np.array(
            [
                [(x - 0.5) * width, (y - 0.5) * height, (z - 0.5) * depth]
                for x, y, z in self._UNIT_VERTICES
            ],
            dtype=np.float32,
        )


class Cylinder(Shape):
    """A closed cylinder centered on the origin with its axis along y.

    The side wall is approximated by ``granularity`` quads. Vertex layout:
    index ``2*k`` lies on the bottom ring and ``2*k + 1`` on the top ring,
    for ``k`` in ``range(granularity)``. These are followed by the
    bottom-cap center and then the top-cap center, which are the last two
    vertices. Each cap is a triangle fan around its center.

    :param height: length along y; must be > 0.
    :param diameter: ring diameter; must be > 0.
    :param granularity: number of segments around the circumference; an int > 2.
    :param kwargs: forwarded to :class:`Shape` (e.g. ``opacity``).
    """

    def __init__(self, height=1.0, diameter=1.0, granularity=8, **kwargs):
        super().__init__(**kwargs)
        assert height > 0
        assert diameter > 0
        assert isinstance(granularity, int) and granularity > 2

        ring_total = 2 * granularity  # vertices on both rings combined
        r = diameter / 2.0
        top_center = coord(0.0, height/2.0, 0.0)
        bottom_center = coord(0.0, -height/2.0, 0.0)

        points = []
        edge_list = []
        quad_list = []
        for k in range(granularity):
            theta = 2.0 * np.pi * (k / granularity)
            px = r * np.sin(theta)
            pz = r * np.cos(theta)
            lo = 2 * k
            hi = lo + 1
            # Neighbouring ring vertices wrap back to the start after the last segment.
            lo_next = (lo + 2) % ring_total
            hi_next = (lo + 3) % ring_total

            points.append(coord(px, bottom_center.y, pz))
            points.append(coord(px, top_center.y, pz))
            # Axial edge, then the bottom-ring and top-ring segments.
            edge_list.extend([(lo, hi), (lo, lo_next), (hi, hi_next)])
            # Side-wall quad, corners in perimeter order.
            quad_list.append((lo, hi, hi_next, lo_next))

        self.edges = edge_list
        self.quad_surfaces = quad_list

        points.append(bottom_center)
        points.append(top_center)
        bottom_idx = len(points) - 2
        top_idx = len(points) - 1
        # Each fan: center, its ring in order, then the first ring vertex again to close it.
        self.trifan_surfaces = [
            (bottom_idx, *range(0, ring_total, 2), 0),
            (top_idx, *range(1, ring_total, 2), 1),
        ]

        self.vertices = np.array([list(p) for p in points], dtype=np.float32)


class Rod(Cylinder):
    """A cylinder whose axis runs from ``c_from`` to ``c_to``.

    :param c_from: start point; a ``coord`` or any indexable whose items
        0, 1, 2 are x, y, z.
    :param c_to: end point; same accepted forms as ``c_from``.
    :param thickness: diameter of the rod.
    :param kwargs: forwarded to :class:`Cylinder` (e.g. ``granularity``, ``opacity``).
    """

    def __init__(self, c_from, c_to, thickness=1.0, **kwargs):
        if not isinstance(c_from, coord):
            c_from = coord(x=c_from[0], y=c_from[1], z=c_from[2])
        if not isinstance(c_to, coord):
            c_to = coord(x=c_to[0], y=c_to[1], z=c_to[2])

        v_dir = coord(
            x=c_to.x - c_from.x,
            y=c_to.y - c_from.y,
            z=c_to.z - c_from.z,
        )

        length = np.sqrt(v_dir.x**2 + v_dir.y**2 + v_dir.z**2)
        # Length of the direction projected onto the x/z plane.
        flat_length = np.sqrt(v_dir.x**2 + v_dir.z**2)

        # x_angle tilts the cylinder's +y axis toward +z until its angle with
        # +y matches v_dir's. Guard length == 0: without the guard the
        # division gives 0/0 -> NaN (RuntimeWarning) before Cylinder's
        # height check rejects the rod anyway. Clamp because rounding can
        # push the ratio a hair outside arccos's [-1, 1] domain.
        x_angle = 0.0
        if length > 0:
            x_angle = np.arccos(np.clip(v_dir.y / length, -1.0, 1.0))
        # y_angle then swings the tilted axis around y to v_dir's heading;
        # arccos only yields [0, pi], so the sign of x picks the side.
        y_angle = 0.0
        if flat_length > 0:
            y_angle = np.arccos(np.clip(v_dir.z / flat_length, -1.0, 1.0))
            if v_dir.x < 0:
                y_angle *= -1
        super().__init__(height=length, diameter=thickness, **kwargs)
        # Cylinder stores its bottom-cap center as the second-to-last vertex;
        # move it to the origin so both rotations pivot on the rod's start.
        c_bottom = self.vertices[-2].copy()
        self.translate(x=-c_bottom[0], y=-c_bottom[1], z=-c_bottom[2])
        # Two separate calls: x must be applied before y, whereas a single
        # rotate(x, y, 0) call would apply y first.
        self.rotate(x_angle, 0.0,     0.0)
        self.rotate(0.0,     y_angle, 0.0)
        self.translate(x=c_from.x, y=c_from.y, z=c_from.z)
