import torch
import torch.nn as nn
import numpy as np

import os
# from TorchBatchify import Batchifier

try:
    from pytorch3d.structures import Meshes

    use_textures = True
except:
    from pytorch3d.structures import Meshes

    use_textures = False

try:
    from VoGE.Meshes import GaussianMeshesNaive as GaussianMesh
    from VoGE.Converter.Converters import naive_vertices_converter

    enable_voge = True
except:
    enable_voge = False

from pytorch3d.renderer import MeshRasterizer
from pytorch3d.transforms import Transform3d
from sklearn.neighbors import KDTree

from corr.utils import (
    forward_interpolate,
    forward_interpolate_voge,
    pre_process_mesh_pascal,
    vertex_memory_to_face_memory,
    campos_to_R_T,
)


def func_reselect(meshes, indexs, **kwargs):
    """Build a new ``Meshes`` batch holding only the meshes at positions ``indexs``.

    Args:
        meshes: Source ``pytorch3d.structures.Meshes`` batch.
        indexs: Iterable of integer positions into ``meshes``; the output keeps
            this order, and repeated positions produce repeated meshes.
        **kwargs: Unused; accepted so call sites may pass extra options.

    Returns:
        Tuple ``(selected, selected.verts_padded())``, where ``selected`` is the new
        ``Meshes`` batch moved to ``meshes.device``.
    """
    # The public accessors compute the per-mesh lists on demand. The private
    # ``_verts_list``/``_faces_list`` attributes may still be None, e.g. for
    # meshes constructed from padded tensors.
    verts_all = meshes.verts_list()
    faces_all = meshes.faces_list()
    selected = Meshes(
        verts=[verts_all[i] for i in indexs],
        faces=[faces_all[i] for i in indexs],
    ).to(meshes.device)
    return selected, selected.verts_padded()


def MeshDeformModule(*args, **kwargs):
    """Create the rendering module that matches the ``rasterizer`` keyword argument.

    Only ``kwargs['rasterizer']`` is inspected. A rasterizer passed positionally is
    therefore treated as "not a ``MeshRasterizer``". All arguments are forwarded
    unchanged to the chosen class.

    Returns:
        ``MeshDeformModuleMesh`` when ``rasterizer`` is a PyTorch3D
        ``MeshRasterizer``, otherwise ``MeshDeformModuleVoGE``.

    Raises:
        ImportError: if the VoGE path is required but VoGE failed to import.
    """
    rasterizer = kwargs.get('rasterizer')
    if isinstance(rasterizer, MeshRasterizer):
        return MeshDeformModuleMesh(*args, **kwargs)
    # An explicit raise, unlike ``assert``, still fires under ``python -O``.
    if not enable_voge:
        raise ImportError(
            "VoGE is required for a non-MeshRasterizer rasterizer but could not be imported"
        )
    return MeshDeformModuleVoGE(*args, **kwargs)


def rotation_matrix(azimuth, elevation, theta):
    """Compose a 3x3 rotation ``Rz(azimuth) @ Rx(elevation) @ Ry(theta)``.

    Each angle (in radians) is a tensor. ``Rz`` rotates about the z axis,
    ``Rx`` about the x axis and ``Ry`` about the y axis. The constant 0/1
    entries are created with the same dtype and device as each angle's cosine,
    so angles living on a GPU (or in float64) can be stacked without a device
    or dtype mismatch.

    Returns:
        A ``(3, 3)`` tensor.
    """
    def _trig(angle):
        # cos, sin and matching 0/1 constants for one angle.
        cos, sin = torch.cos(angle), torch.sin(angle)
        return cos, sin, torch.zeros_like(cos), torch.ones_like(cos)

    # Azimuth: rotation about the z axis.
    ca, sa, zero, one = _trig(azimuth)
    Rz = torch.stack([
        ca, -sa, zero,
        sa, ca, zero,
        zero, zero, one,
    ]).reshape(3, 3)

    # Elevation: rotation about the x axis.
    ce, se, zero, one = _trig(elevation)
    Rx = torch.stack([
        one, zero, zero,
        zero, ce, -se,
        zero, se, ce,
    ]).reshape(3, 3)

    # Theta: rotation about the y axis.
    ct, st, zero, one = _trig(theta)
    Ry = torch.stack([
        ct, zero, st,
        zero, one, zero,
        -st, zero, ct,
    ]).reshape(3, 3)

    return torch.matmul(Rz, torch.matmul(Rx, Ry))


class MeshDeformModuleVoGE(nn.Module):
    """VoGE-based renderer of per-vertex memories, the VoGE counterpart of ``MeshDeformModuleMesh``.

    Args:
        vertices: Mesh vertices; converted with ``pre_process_mesh_pascal`` and
            then to a Gaussian mesh with ``naive_vertices_converter``.
        faces: Mesh faces passed to ``naive_vertices_converter``.
        memory_bank: Memory tensor; ``forward`` repeats it once per camera along dim 0.
        rasterizer: VoGE renderer; its ``cameras`` are moved in ``to``.
        post_process: Optional callable applied to the rendered output.
        off_set_mesh: If True, ``forward`` offsets the vertices by ``deform_verts``.
    """

    def __init__(self, vertices, faces, memory_bank, rasterizer, post_process=None, off_set_mesh=False):
        # Zero-argument super() does not depend on a hard-coded class name.
        super().__init__()

        self.memory = None
        self.update_memory(memory_bank=memory_bank)

        # Convert PASCAL3D+-style vertices before building the Gaussian mesh.
        verts = pre_process_mesh_pascal(vertices)
        self.meshes = GaussianMesh(*naive_vertices_converter(verts, faces, percentage=0.5))

        self.rasterizer = rasterizer
        self.post_process = post_process
        self.off_set_mesh = off_set_mesh

    def update_memory(self, memory_bank):
        """Replace the stored memory bank."""
        self.memory = memory_bank

    def to(self, *args, **kwargs):
        """Move the module, camera, memory and mesh to a device.

        The device is taken from ``kwargs['device']`` or else the first positional argument.
        """
        device = kwargs['device'] if 'device' in kwargs else args[0]
        super().to(device)
        self.rasterizer.cameras = self.rasterizer.cameras.to(device)
        self.memory = self.memory.to(device)
        self.meshes = self.meshes.to(device)
        return self

    def cuda(self, device=None):
        """Move to CUDA; ``device`` may be None (default GPU), an index, or a device/string."""
        if device is None:
            device = torch.device("cuda")
        elif isinstance(device, int):
            device = torch.device("cuda", device)
        return self.to(device)

    def forward(self, campos, theta, deform_verts=None, **kwargs):
        """Render the memory for every camera given by ``campos``/``theta``.

        ``campos`` and ``theta`` are converted by ``campos_to_R_T``. Extra
        ``kwargs`` are accepted but unused.

        Raises:
            ValueError: if ``off_set_mesh`` is set but ``deform_verts`` is None.
        """
        R, T = campos_to_R_T(campos, theta, device=campos.device)

        if self.off_set_mesh:
            if deform_verts is None:
                raise ValueError("deform_verts is required when off_set_mesh=True")
            meshes = self.meshes.offset_verts(deform_verts)
        else:
            meshes = self.meshes

        # One copy of the memory per camera (R has one rotation per camera).
        memory = self.memory.repeat(R.shape[0], 1)
        get = forward_interpolate_voge(R, T, meshes, memory, rasterizer=self.rasterizer)

        if self.post_process is not None:
            get = self.post_process(get)
        return get


class MeshDeformModuleMesh(nn.Module):
    """Render per-vertex feature memories of one or more meshes with PyTorch3D.

    The per-vertex ``features`` are converted to per-face memories with
    ``vertex_memory_to_face_memory`` and interpolated onto the image by
    ``forward_interpolate``. The ``forward*`` variants can additionally re-pose
    the geometry with per-part transforms and/or a learned vertex-deformation
    network.

    ``part_poses`` is a dict with the keys ``offset``, ``xscale``, ``yscale``,
    ``zscale``, ``azimuth``, ``elevation`` and ``theta``, each indexed by camera
    (and, in ``forward_whole``, additionally by part).

    Args:
        vertices: Per-mesh vertex tensors for ``pytorch3d.structures.Meshes``.
        faces: Per-mesh face tensors; also kept for later ``update_memory`` calls.
        features: Per-mesh vertex memories, paired one-to-one with ``faces``.
        rasterizer: PyTorch3D rasterizer; its ``cameras`` are moved in ``to``.
        post_process: Optional callable applied to every rendered output.
        off_set_mesh: If True, ``forward``/``forward_whole`` add ``deform_verts``
            to the vertices before rendering.
    """

    # Number of vertices sent through the deformation network per chunk.
    _deform_batch_size = 1000000

    def __init__(
        self,
        vertices,
        faces,
        features,
        rasterizer,
        post_process=None,
        off_set_mesh=False,
    ):
        super().__init__()

        # Per-face memories are derived from the per-vertex features.
        self.faces = faces
        self.face_memory = None
        self.update_memory(memory_bank=features, faces=faces)

        # A single Meshes batch holds every mesh so they can be rendered together.
        self.meshes = Meshes(verts=vertices, faces=faces, textures=None)

        self.rasterizer = rasterizer
        self.post_process = post_process
        self.off_set_mesh = off_set_mesh

    def update_memory(self, memory_bank, faces=None):
        """Recompute ``self.face_memory`` from per-vertex ``memory_bank``.

        ``memory_bank`` and ``faces`` are zipped pairwise (one entry per mesh);
        ``faces`` defaults to the faces given at construction. Each face memory
        is placed on the device of its vertex memory.
        """
        if faces is None:
            faces = self.faces
        self.face_memory = [
            vertex_memory_to_face_memory(m, f).to(m.device)
            for m, f in zip(memory_bank, faces)
        ]

    def to(self, *args, **kwargs):
        """Move the module, camera, face memories and meshes to a device.

        The device is taken from ``kwargs['device']`` or else the first positional argument.
        """
        device = kwargs["device"] if "device" in kwargs else args[0]
        super().to(device)
        self.rasterizer.cameras = self.rasterizer.cameras.to(device)
        self.face_memory = [memory.to(device) for memory in self.face_memory]
        self.meshes = self.meshes.to(device)
        return self

    def update_rasterizer(self, rasterizer):
        """Swap in ``rasterizer``, moving its cameras to the current camera device."""
        device = self.rasterizer.cameras.device
        self.rasterizer = rasterizer
        self.rasterizer.cameras = self.rasterizer.cameras.to(device)

    def cuda(self, device=None):
        """Move to CUDA; ``device`` may be None (default GPU), an index, or a device/string."""
        if device is None:
            device = torch.device("cuda")
        elif isinstance(device, int):
            device = torch.device("cuda", device)
        return self.to(device)

    def net_deform(self, net, encoder, chosen_vert0, latent_mix0):
        """Evaluate ``net(encoder(vertex), latent)`` for every (latent row, vertex) pair.

        Args:
            net: Callable ``(encoded_verts, latents) -> output``.
            encoder: Callable applied to the flattened ``(N * V, 3)`` vertices.
            chosen_vert0: ``(V, 3)`` vertices.
            latent_mix0: ``(N, L)`` latent codes.

        Returns:
            ``net`` output for ``N * V`` rows, ordered latent-major (all V
            vertices for latent 0, then for latent 1, ...).
        """
        N = latent_mix0.shape[0]
        vert0 = chosen_vert0.to(self.meshes.device)
        V = vert0.shape[0]
        vert0 = vert0.unsqueeze(0).expand(N, -1, -1).contiguous().view(-1, 3)

        latent_get = latent_mix0.unsqueeze(1).expand(-1, V, -1).contiguous().view(-1, latent_mix0.shape[-1])
        return net(encoder(vert0), latent_get)

    # ------------------------------------------------------------------ helpers

    def _camera_R_T(self, campos, theta, device, **kwargs):
        """Convert camera parameters with ``campos_to_R_T`` and move R, T to ``device``."""
        R, T = campos_to_R_T(campos, theta, device=campos.device, **kwargs)
        return R.to(device), T.to(device)

    @staticmethod
    def _part_transform(offset, xscale, yscale, zscale, azimuth, elevation, theta, device):
        """Build the scale -> rotate -> translate ``Transform3d`` for one part pose."""
        rotate = rotation_matrix(azimuth, elevation, theta)
        transform = Transform3d(device=device)
        transform = transform.scale(x=xscale.to(device), y=yscale.to(device), z=zscale.to(device))
        transform = transform.rotate(rotate.to(device))
        return transform.translate(offset.to(device))

    def _posed_copies(self, meshes, part_poses, n_poses, device):
        """Return a Meshes batch with the first mesh of ``meshes`` posed by each of ``n_poses`` poses."""
        verts = meshes.verts_list()[0]
        faces = meshes.faces_list()[0]
        vert_list, face_list = [], []
        for idx in range(n_poses):
            transform = self._part_transform(
                part_poses['offset'][idx][None],
                part_poses['xscale'][idx],
                part_poses['yscale'][idx],
                part_poses['zscale'][idx],
                part_poses['azimuth'][idx],
                part_poses['elevation'][idx],
                part_poses['theta'][idx],
                device,
            )
            vert_list.append(transform.transform_points(verts))
            face_list.append(faces)
        return Meshes(verts=vert_list, faces=face_list).to(meshes.device)

    def _assemble_parts(self, meshes, part_poses, idx, device):
        """Pose every part mesh for camera ``idx`` and merge them into one ``(verts, faces)`` pair."""
        offsets = part_poses['offset'][idx]
        xscales = part_poses['xscale'][idx]
        yscales = part_poses['yscale'][idx]
        zscales = part_poses['zscale'][idx]
        azimuths = part_poses['azimuth'][idx]
        elevations = part_poses['elevation'][idx]
        thetas = part_poses['theta'][idx]

        verts_parts, faces_parts = [], []
        n_verts = 0
        for part_id in range(len(offsets)):
            transform = self._part_transform(
                offsets[part_id][None],
                xscales[part_id],
                yscales[part_id],
                zscales[part_id],
                azimuths[part_id],
                elevations[part_id],
                thetas[part_id],
                device,
            )
            part = meshes[part_id]
            verts = transform.transform_points(part.verts_list()[0])
            # Re-index this part's faces past the vertices merged so far.
            faces_parts.append(part.faces_list()[0] + n_verts)
            verts_parts.append(verts)
            n_verts += verts.shape[0]
        return torch.cat(verts_parts, dim=0), torch.cat(faces_parts, dim=0)

    def _deform_first_mesh(self, meshes, base_verts, deform_net, deform_encoder, deform_latent):
        """Deform ``base_verts`` with the deformation network and rebuild a single-mesh batch.

        ``deform_latent`` is softmax-normalised over dim 1 before use. The
        vertices are processed in chunks of ``_deform_batch_size``. Because the
        result is added to ``base_verts`` directly, ``deform_latent`` is
        expected to have a single row.
        # NOTE(review): N > 1 would not broadcast against base_verts — confirm callers pass N == 1.

        Returns:
            ``(deformed_meshes, deformation)``, where ``deformed_meshes`` combines
            the deformed vertices with the first mesh's faces.
        """
        device = meshes.device
        weights = torch.nn.functional.softmax(deform_latent, dim=1)
        chunk = self._deform_batch_size
        # ``max(..., 1)`` keeps one (empty) chunk for an empty mesh. Otherwise
        # no empty trailing chunk is produced when the count divides evenly.
        # ``net_deform`` already expands the chunk once per latent row.
        get_list = [
            self.net_deform(deform_net, deform_encoder, base_verts[start:start + chunk].to(device), weights)
            for start in range(0, max(len(base_verts), 1), chunk)
        ]
        deformation = torch.cat(get_list, dim=0)

        verts = base_verts + deformation
        faces = meshes.faces_list()[0]
        return Meshes(verts=[verts], faces=[faces]).to(device), deformation

    def _render(self, R, T, meshes, face_memory, n_cam, blur_radius, mode):
        """Interpolate ``face_memory`` over ``meshes`` for every camera, then post-process."""
        if n_cam > 1:
            # One copy of the face memory per camera; a single camera uses it as-is.
            face_memory = face_memory.repeat(n_cam, 1, 1)
        get = forward_interpolate(
            R,
            T,
            meshes,
            face_memory,
            rasterizer=self.rasterizer,
            blur_radius=blur_radius,
            mode=mode,
        )
        if self.post_process is not None:
            get = self.post_process(get)
        return get

    # ------------------------------------------------------------ entry points

    def forward(self, campos, theta, blur_radius=0, deform_verts=None, mode="bilinear", indexs=None, part_poses=None,
                **kwargs):
        """Render the (optionally re-selected, offset and re-posed) meshes for each camera.

        Args:
            campos, theta: Camera parameters passed to ``campos_to_R_T`` with ``kwargs``.
            blur_radius, mode: Forwarded to ``forward_interpolate``.
            deform_verts: Vertex offsets, used only when ``off_set_mesh`` is set.
            indexs: Optional mesh positions to render instead of all meshes.
            part_poses: Optional per-camera poses; when given, the first mesh is
                posed once per camera.
        """
        if indexs is not None:
            meshes, _ = func_reselect(self.meshes, indexs)
            face_memory = torch.cat([self.face_memory[idx] for idx in indexs], dim=0)
        else:
            meshes = self.meshes
            face_memory = torch.cat(self.face_memory, dim=0)
        if self.off_set_mesh:
            # Offset the meshes actually being rendered (the re-selected subset
            # when ``indexs`` is given) so they stay aligned with ``face_memory``.
            meshes = meshes.offset_verts(deform_verts)

        device = meshes.device
        R, T = self._camera_R_T(campos, theta, device, **kwargs)

        if part_poses is not None:
            meshes = self._posed_copies(meshes, part_poses, len(campos), device)

        return self._render(R, T, meshes, face_memory, campos.shape[0], blur_radius, mode)

    def forward_part_with_deform(self, campos, theta, blur_radius=0, mode="bilinear",
                                 part_poses=None, deform_net=None, deform_encoder=None, deform_latent=None, **kwargs):
        """Deform the first mesh, pose it per camera with ``part_poses``, then render.

        The first mesh is shifted by ``part_poses['offset'][0]`` before the
        deformation network is evaluated.

        Returns:
            ``(rendered, deformation)``.
        """
        meshes = self.meshes
        face_memory = torch.cat(self.face_memory, dim=0)

        device = meshes.device
        R, T = self._camera_R_T(campos, theta, device, **kwargs)

        # Add the offset out of place so repeated calls do not mutate the
        # vertices stored in ``self.meshes``.
        base_verts = meshes.verts_list()[0] + part_poses['offset'][0].to(device)
        deformed, deformation = self._deform_first_mesh(
            meshes, base_verts, deform_net, deform_encoder, deform_latent
        )

        posed = self._posed_copies(deformed, part_poses, len(campos), device)
        get = self._render(R, T, posed, face_memory, campos.shape[0], blur_radius, mode)
        return get, deformation

    def forward_with_deform(self, campos, theta, blur_radius=0, mode="bilinear", deform_net=None,
                            deform_encoder=None, deform_latent=None, **kwargs):
        """Deform the first mesh with the deformation network and render it.

        Returns:
            ``(rendered, deformation)``.
        """
        meshes = self.meshes
        face_memory = torch.cat(self.face_memory, dim=0)

        device = meshes.device
        R, T = self._camera_R_T(campos, theta, device, **kwargs)

        deformed, deformation = self._deform_first_mesh(
            meshes, meshes.verts_list()[0], deform_net, deform_encoder, deform_latent
        )
        get = self._render(R, T, deformed, face_memory, campos.shape[0], blur_radius, mode)
        return get, deformation

    def forward_whole(self, campos, theta, blur_radius=0, deform_verts=None, mode="bilinear", part_poses=None, **kwargs):
        """Pose every part per camera, merge the parts into one mesh and render it.

        When ``kwargs['near_pairs']`` is given (an iterable of vertex-index
        pairs into the merged mesh), a hinge penalty ``max(0, dist - threshold)``
        is summed over the pairs for every camera. The threshold is taken from
        ``kwargs['near_threshold']`` (default 0.01).

        Returns:
            ``(rendered, loss)``.
        """
        meshes = self.meshes
        face_memory = torch.cat(self.face_memory, dim=0)
        if self.off_set_mesh:
            meshes = meshes.offset_verts(deform_verts)

        device = meshes.device
        # ``kwargs`` may carry ``near_pairs``, so it is not forwarded to campos_to_R_T.
        R, T = self._camera_R_T(campos, theta, device)

        loss = torch.zeros((), device=device)
        if part_poses is not None:
            vert_list, face_list = [], []
            for idx in range(len(campos)):
                whole_vert, whole_face = self._assemble_parts(meshes, part_poses, idx, device)
                vert_list.append(whole_vert)
                face_list.append(whole_face)

                if 'near_pairs' in kwargs:
                    threshold = kwargs.get('near_threshold', 0.01)
                    for pair in kwargs['near_pairs']:
                        dist = torch.norm(whole_vert[pair[0]] - whole_vert[pair[1]])
                        # Only distances beyond the threshold are penalised.
                        if dist > threshold:
                            loss += dist - threshold

            meshes = Meshes(verts=vert_list, faces=face_list).to(meshes.device)

        get = self._render(R, T, meshes, face_memory, campos.shape[0], blur_radius, mode)
        return get, loss
