'''
Author: 
Email: 
Date: 2020-09-26 15:59:00
LastEditTime: 2021-05-30 16:40:01
Description: 
    This renderer includes two kinds of objects, and this can be extended to multiple objects
'''

import torch
import torch.nn as nn
import numpy as np
from skimage.io import imread

from utils import CUDA, CPU, COLOR
import neural_renderer as nr


class Renderer(nn.Module):
    """Differentiable renderer for a scene built from two kinds of objects (boxes and planes).

    Each object kind is loaded once from an OBJ file.  Its face list is pre-tiled
    ``max_box_num`` / ``max_plane_num`` times (see ``process_objects``), so a scene
    with up to that many copies is built by concatenating transformed copies of the
    vertices and slicing the tiled face list.  ``forward`` renders the scene with
    ``neural_renderer`` and returns the L2 distance to a reference image.
    """

    def __init__(self, filename_box, filename_plane, filename_ref):
        super(Renderer, self).__init__()

        # create two kinds of objects
        self.max_box_num = 20
        self.max_plane_num = 5
        vertices_box, faces_box, box_face_num = self.process_objects(filename_box, self.max_box_num)
        self.box = {'vertices': vertices_box, 'faces': faces_box, 'face_num': box_face_num}
        vertices_plane, faces_plane, plane_face_num = self.process_objects(filename_plane, self.max_plane_num)
        if filename_plane.split('/')[-1] == 'circle.obj':
            # special-case geometry for the circle mesh: scale by 220, rotate pi/2 via Ry, shift x by -1
            vertices_plane = vertices_plane*220
            vertices_plane = vertices_plane.matmul(self.Ry(CUDA(torch.tensor(np.pi/2))))
            vertices_plane[:, :, 0] -= 1

        self.plane = {'vertices': vertices_plane, 'faces': faces_plane, 'face_num': plane_face_num}

        # load reference image: HxWxC uint8 -> [1, 3, H, W] float in [0, 1]
        # (only the first three channels are kept, so an alpha channel is dropped)
        image_ref = CUDA(torch.from_numpy(imread(filename_ref).astype(np.float32)/255.))
        self.image_ref = image_ref[:, :, :3].permute(2, 0, 1).contiguous()[None]

        # setup renderer
        self.renderer = nr.Renderer(camera_mode='look_at', light_direction=[1, 0, 0], background_color=[1.0, 1.0, 1.0])
        self.renderer.eye = nr.get_points_from_angles(20, 0, 90)

        # geometry/texture of the most recently built scene, consumed by render()
        self.V = None
        self.F = None
        self.T = None

        print(COLOR.GREEN+'Multiple Object Renderer:')
        print('\tBox filename:', filename_box)
        print('\tPlane filename:', filename_plane)
        print('\tReference filename:', filename_ref)
        print('\tMax box number:', self.max_box_num)
        print('\tMax plane number:', self.max_plane_num)
        print(COLOR.WHITE+'')

    @staticmethod
    def process_objects(filename, max_number):
        """Load one OBJ mesh and pre-tile its faces for ``max_number`` copies.

        Returns:
            vertices_output: the mesh vertices with a leading batch axis of size 1.
            faces_output: faces repeated ``max_number`` times along the face axis, where
                copy ``k`` has its vertex indices shifted by ``k * num_vertices``, plus a
                leading batch axis of size 1.
            face_num: number of faces in a single copy of the mesh.
        """
        # NOTE: we should turn off the normalization since we have multiple objects
        vertices, faces = nr.load_obj(filename, normalization=False)

        vertices_output = vertices[None, :, :]
        face_num = faces.shape[0]
        num_vertices = int(vertices.shape[0])
        faces_output = torch.cat([faces + num_vertices*o_i for o_i in range(max_number)], dim=0)[None]

        return vertices_output, faces_output, face_num

    @staticmethod
    def color_to_texture(colors, faces, face_num):
        """Build a per-face constant-colour texture tensor.

        Object ``i`` occupies faces ``[i*face_num, (i+1)*face_num)`` and all of its
        texels get ``colors[i][0:3]``.  Faces beyond ``len(colors)`` objects keep the
        default value 1.
        NOTE(review): an earlier comment stated textures "should be in [-1, 1]";
        confirm the expected range against neural_renderer.
        """
        texture_size = 2
        texture = CUDA(torch.ones(1, faces.shape[1], texture_size, texture_size, texture_size, 3, dtype=torch.float32))

        for o_i in range(colors.shape[0]):
            face_slice = slice(o_i*face_num, (o_i+1)*face_num)
            for channel in range(3):
                texture[:, face_slice, :, :, :, channel] = colors[o_i][channel]
        return texture

    def apply_transformation(self, object, poses, colors):
        """Place one copy of ``object`` per pose and colour it.

        Args:
            object: dict with 'vertices', 'faces' (pre-tiled) and 'face_num'.
            poses: indexed as ``poses[0, i, 0:2]`` (translation) and ``poses[0, i, 2]``
                (rotation angle passed to ``Rx``); only the first batch entry is used.
            colors: per-object colours, indexed as ``colors[0][i]``.

        Returns:
            (V, F, T) vertices, faces and textures for all placed copies.
            NOTE(review): only the first max_number copies get faces/textures because
            the face list is tiled that many times; extra poses contribute unused vertices.

        Raises:
            ValueError: if ``colors`` is None.
        """
        # Validate before indexing: the old check ran after colors[0] and never raised.
        if colors is None:
            raise ValueError('No textures provided in texture enable mode')

        trans = poses[0, :, 0:2] # only one batch
        rotations = poses[0, :, 2]
        colors = colors[0]
        num_objects = trans.shape[0]

        # the number of object could be different
        V = torch.cat([object['vertices'].matmul(self.Rx(rotations[o_i])) + self.Translation(trans[o_i]) for o_i in range(num_objects)], dim=1)
        F = object['faces'][:, 0:num_objects*object['face_num'], :]
        texture = self.color_to_texture(colors, object['faces'], object['face_num'])

        # the generate object number may not be equal to object_num
        T = texture[:, 0:F.shape[1], :, :, :, :]

        return V, F, T

    def forward(self, box_poses, box_colors, plane_poses, plane_colors):
        """Render the scene and return the L2 distance to the reference image.

        poses: [1, N, 3]  (x/y translation, rotation angle)
        colors: [1, N, 3]
        Either pose argument may be None to omit that object kind, but not both.

        Raises:
            ValueError: if both ``box_poses`` and ``plane_poses`` are None, or a
                pose tensor is given without colours.
        """
        if box_poses is None and plane_poses is None:
            raise ValueError('At least one of box_poses / plane_poses must be provided')

        if box_poses is None:
            V_box = F_box = T_box = None
        else:
            # apply operations to all objects separately
            V_box, F_box, T_box = self.apply_transformation(self.box, box_poses, box_colors)

        if plane_poses is None:
            self.V, self.F, self.T = V_box, F_box, T_box
        else:
            V_plane, F_plane, T_plane = self.apply_transformation(self.plane, plane_poses, plane_colors)
            if F_box is None:
                # only planes: use the plane textures (the old code used T_box, which is None here)
                self.V, self.F, self.T = V_plane, F_plane, T_plane
            else:
                # Faces index vertices, and plane vertices are appended after all box
                # vertices, so shift by the box vertex count.  max(F_box)+1 is wrong
                # when the mesh has vertices not referenced by any face.
                F_plane = F_plane + V_box.shape[1]
                # combine all objects
                self.V = torch.cat([V_box, V_plane], dim=1)
                self.F = torch.cat([F_box, F_plane], dim=1)
                self.T = torch.cat([T_box, T_plane], dim=1)

        image = self.renderer(self.V, self.F, self.T, mode='rgb')
        loss = torch.sum((image - self.image_ref)**2)**0.5
        return loss

    def reference(self):
        """Render a fixed layout of 8 boxes with uniform colour 1 and return it as a uint8 image.

        The old version used attributes that are never defined (``self.vertices``,
        ``self.faces``, ``self.textures``, ``self.one_object_face_num``) and always
        raised AttributeError.  It now builds the same layout through the box mesh.
        """
        trans = np.array([[-8, -8], [-8, -4], [-4, -8], [-4, -4], [8, 8], [8, 4], [4, 8], [4, 4]], dtype=np.float32)
        trans = CUDA(torch.tensor(trans))
        rotations = CUDA(torch.zeros(trans.shape[0],))
        poses = torch.cat([trans, rotations[:, None]], dim=1)[None]
        colors = CUDA(torch.ones(1, trans.shape[0], 3))

        self.V, self.F, self.T = self.apply_transformation(self.box, poses, colors)

        images = self.renderer(self.V, self.F, self.T, mode='rgb')
        return self._to_image(images)

    def render(self):
        """Render the scene built by the last ``forward`` call and return a uint8 image."""
        if self.V is None:
            raise ValueError('No object targted, call .forward() before calling .render()')
        images = self.renderer(self.V, self.F, self.T, mode='rgb')
        return self._to_image(images)

    @staticmethod
    def _to_image(images):
        """Convert a rendered batch (first item [3, H, W], values in [0, 1]) to an [H, W, 3] uint8 array."""
        images = (CPU(images)[0]*255.0).astype(np.uint8) # [3, 256, 256]
        return np.ascontiguousarray(images[:3].transpose(1, 2, 0))

    @staticmethod
    def Rx(rad):
        """3x3 rotation about the x axis, meant for row vectors (``v @ Rx``).

        NOTE(review): the sin signs are transposed relative to ``Ry``/``Rz`` below, so
        ``v @ Rx(a)`` rotates by +a while ``v @ Ry(a)`` / ``v @ Rz(a)`` rotate by -a.
        Kept as-is because it changes rendered output.
        """
        # NOTE: we should not directly define new tensors, otherwise the gradient flow will be broke down
        rotation = CUDA(torch.eye(3))
        rotation[1, 1] = torch.cos(rad)
        rotation[1, 2] = torch.sin(rad)
        rotation[2, 1] = -torch.sin(rad)
        rotation[2, 2] = torch.cos(rad)
        return rotation

    @staticmethod
    def Ry(rad):
        """3x3 rotation about the y axis (standard column-vector form)."""
        # NOTE: we should not directly define new tensors, otherwise the gradient flow will be broke down
        rotation = CUDA(torch.eye(3))
        rotation[0, 0] = torch.cos(rad)
        rotation[0, 2] = torch.sin(rad)
        rotation[2, 0] = -torch.sin(rad)
        rotation[2, 2] = torch.cos(rad)
        return rotation

    @staticmethod
    def Rz(rad):
        """3x3 rotation about the z axis (standard column-vector form)."""
        # NOTE: we should not directly define new tensors, otherwise the gradient flow will be broke down
        rotation = CUDA(torch.eye(3))
        rotation[0, 0] = torch.cos(rad)
        rotation[0, 1] = -torch.sin(rad)
        rotation[1, 0] = torch.sin(rad)
        rotation[1, 1] = torch.cos(rad)
        return rotation

    @staticmethod
    def Translation(xyz):
        """Map a 2-D offset (xyz[0], xyz[1]) onto the y/z axes; x translation is always 0."""
        translation = CUDA(torch.zeros(3))
        translation[1] = xyz[0]
        translation[2] = xyz[1]
        return translation
