import traceback
from functools import partial
from math import pi

import numpy as np
import torch
import torch.nn.functional as F
import x_transformers
from pytorch3d.structures import Meshes
from pytorch3d.renderer import (
    FoVPerspectiveCameras, look_at_view_transform, look_at_rotation, 
    RasterizationSettings, MeshRenderer, MeshRasterizer, BlendParams,
    SoftSilhouetteShader, HardPhongShader, PointLights, TexturesVertex,
)
from torch import nn
from torch.cuda.amp import autocast
from torch.nn import Module
from x_transformers import Encoder

# NOTE(review): `beartype`, `TensorType`, `pack` / `rearrange` / `reduce` and
# `SAGEConv` are used below but are not imported anywhere in the visible part
# of this file -- confirm where they are expected to come from.

class MLP(nn.Module):
    """Feed-forward embedding network that keeps the raw input attached.

    One ``Linear -> ReLU`` stage is built per entry of ``hidden_dims`` and a
    final ``Linear`` projects to ``output_dim``.  ``forward`` concatenates that
    projection with the (dtype-cast) input, so the last dimension of the
    result is ``output_dim + input_dim``.

    Args:
        input_dim: Size of the last dimension of the input.
        hidden_dims: Iterable with the width of each hidden stage.
        output_dim: Width of the final projection.
    """

    def __init__(self, input_dim, hidden_dims, output_dim):
        super().__init__()
        widths = [input_dim, *hidden_dims]
        stages = []
        # Consecutive width pairs give the (in, out) sizes of each hidden stage.
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stages += [nn.Linear(fan_in, fan_out), nn.ReLU()]
        stages.append(nn.Linear(widths[-1], output_dim))
        self.model = nn.Sequential(*stages)

    def forward(self, x):
        """Return ``cat(model(x), x)`` along the last dimension."""
        # Match the input to the dtype of the module's parameters.
        x = x.to(next(self.parameters()).dtype)
        return torch.cat((self.model(x), x), dim=-1)

class MeshUnwarpAutoencoder(Module):
    def __init__(
        self,
        encoder_depth = 8,
        encoder_heads = 8,
        encoder_dim = 512,
        decoder_fine_dim = 128,
        dim_coor_embed = 32,
        dim_normal_vert_embed = 16,
        dim_curve_embed = 16,
        dim_uv_embed = 64,
        pad_id = -1,
        dropout = 0.,
        graph_layers = 5,
        dim_degree_embed = 16,
        patch_size = 1,
        blender_uv_weight = 1.,
        boundary_weight = 0.,
        distort_weight = 1.,
    ):
        """Build the per-vertex embedders, graph encoder, transformer stacks and UV head.

        Args:
            encoder_depth: Depth of the main encoder; each decoder stack uses
                ``encoder_depth // 2``.
            encoder_heads: Attention heads for every transformer stack.
            encoder_dim: Width of the encoder and of the coarse decoder.
            decoder_fine_dim: Width of the fine decoder and of the UV head input.
            dim_coor_embed: Output width of the vertex-coordinate MLP.
            dim_normal_vert_embed: Output width of the vertex-normal MLP.
            dim_curve_embed: Output width of the curvature MLP.
            dim_uv_embed: Output width of the initial-UV MLP.
            pad_id: Value marking padded entries in the input tensors.
            dropout: Attention and feed-forward dropout for all stacks.
            graph_layers: Number of SAGEConv layers (0 disables the graph stage).
            dim_degree_embed: Output width of the vertex-degree MLP.  The
                original body read this name without defining it; the default
                mirrors ``dim_curve_embed`` -- TODO confirm.
            patch_size: Number of consecutive vertices folded into one token.
                The original body read this name without defining it.
                NOTE(review): ``encode`` indexes the patched embedding with the
                per-vertex mask, which only lines up for ``patch_size == 1``.
            blender_uv_weight: Read by ``forward``; a value > 0 makes the
                network predict an offset on top of ``blender_uvs``.
            boundary_weight: Read by ``forward``; when >= 0 the per-vertex
                reconstruction loss is scaled by ``1 + boundary_weight``.
            distort_weight: Read by ``forward``; a non-zero value enables the
                elastic-energy term.

        The defaults of the last three arguments are assumptions made so that
        ``forward`` runs every loss term -- TODO confirm against the training
        configuration.
        """
        super().__init__()
        self.coor_mlp = MLP(input_dim=3, hidden_dims=[dim_coor_embed//4, dim_coor_embed//2], output_dim=dim_coor_embed)
        self.normal_mlp = MLP(input_dim=3, hidden_dims=[dim_normal_vert_embed//4, dim_normal_vert_embed//2], output_dim=dim_normal_vert_embed)
        self.curve_mlp = MLP(input_dim=1, hidden_dims=[dim_curve_embed//4, dim_curve_embed//2], output_dim=dim_curve_embed)
        self.degree_mlp = MLP(input_dim=1, hidden_dims=[dim_degree_embed//4, dim_degree_embed//2], output_dim=dim_degree_embed)
        self.uv_mlp = MLP(input_dim=2, hidden_dims=[dim_uv_embed//4, dim_uv_embed//2], output_dim=dim_uv_embed)
        # Each MLP appends its raw input to its output, hence the "+ 3", "+ 1", "+ 2" terms.
        init_dim = dim_coor_embed + 3 + dim_normal_vert_embed + 3 + dim_curve_embed + 1 + dim_degree_embed + 1 + dim_uv_embed + 2

        self.patch_size = patch_size
        self.project_in = nn.Linear(init_dim * patch_size, encoder_dim)

        # Loss switches consulted by forward().
        self.blender_uv_weight = blender_uv_weight
        self.boundary_weight = boundary_weight
        self.distort_weight = distort_weight

        # graph init
        self.graph_layers = graph_layers
        if self.graph_layers > 0:
            sageconv_kwargs = dict(
                normalize = True,
                project = True,
            )
            self.init_sage_conv = SAGEConv(encoder_dim, encoder_dim, **sageconv_kwargs)
            self.init_encoder_act_and_norm = nn.Sequential(
                nn.SiLU(),
                nn.LayerNorm(encoder_dim)
            )
            # The first conv is created above, so only graph_layers - 1 remain.
            self.graph_encoders = nn.ModuleList([
                SAGEConv(encoder_dim, encoder_dim, **sageconv_kwargs)
                for _ in range(graph_layers - 1)
            ])

        self.encoder = Encoder(
            dim = encoder_dim,
            depth = encoder_depth,
            heads = encoder_heads,
            attn_flash = True,
            attn_dropout = dropout,
            ff_dropout = dropout,
        )
        self.pad_id = pad_id

        self.init_decoder = nn.Sequential(
            nn.Linear(encoder_dim, encoder_dim),
            nn.SiLU(),
            nn.LayerNorm(encoder_dim),
        )

        self.decoder_coarse = Encoder(
            dim = encoder_dim,
            depth = encoder_depth // 2,
            heads = encoder_heads,
            attn_flash = True,
            attn_dropout = dropout,
            ff_dropout = dropout,
        )
        self.coarse_to_fine = nn.Linear(encoder_dim, decoder_fine_dim)
        self.decoder_fine = Encoder(
            dim = decoder_fine_dim,
            depth = encoder_depth // 2,
            heads = encoder_heads,
            attn_flash = True,
            attn_dropout = dropout,
            ff_dropout = dropout,
        )
        # Final Tanh bounds the head's output to [-1, 1].
        self.output_uv = nn.Sequential(
            nn.Linear(decoder_fine_dim, decoder_fine_dim//4),
            nn.ReLU(),
            nn.Linear(decoder_fine_dim//4, decoder_fine_dim//8),
            nn.ReLU(),
            nn.Linear(decoder_fine_dim//8, 2),
            nn.Tanh()
        )
    @beartype
    def encode(
        self,
        *,
        vertices:         TensorType['b', 'nv', 3, float],
        faces:            TensorType['b', 'nf', 3, int],
        edges:       TensorType['b', 'e', 2, int],
        vertices_mask:        TensorType['b', 'nv', bool],
        normals:          TensorType['b', 'nv', 3, float],
        curve:          TensorType['b', 'nv', 1, float],
        degree:          TensorType['b', 'nv', 1, float],
        blender_uvs:      TensorType['b', 'nv', 2, float],
        edges_mask = None,
    ):
        """Embed per-vertex features, run the graph convolutions and the encoder.

        Args:
            vertices: Vertex coordinates.
            faces: Face indices (accepted for interface compatibility; not
                read by this method).
            edges: Vertex-index pairs, local to each batch element.
            vertices_mask: True for real (non-padded) vertices.
            normals: Per-vertex normals.
            curve: Per-vertex curvature value.
            degree: Per-vertex degree value.
            blender_uvs: Initial per-vertex UVs.
            edges_mask: Optional boolean ``(b, e)`` mask, True for real edges.
                When omitted it is derived as "neither endpoint equals
                ``self.pad_id``".

        Returns:
            The output of ``self.encoder`` for the (re-padded) vertex tokens.
        """
        vert_coor_embed = self.coor_mlp(vertices)
        normal_embed = self.normal_mlp(normals)
        curve_embed = self.curve_mlp(curve)
        degree_embed = self.degree_mlp(degree)
        uv_embed = self.uv_mlp(blender_uvs)
        vert_embed, _ = pack([vert_coor_embed, normal_embed, curve_embed, degree_embed, uv_embed], 'b nv *')

        vert_embed = rearrange(vert_embed, 'b (num_patch patch_size) d -> b num_patch (patch_size d)', 
                               patch_size = self.patch_size)
        vert_embed = self.project_in(vert_embed)

        # init graph 
        if self.graph_layers > 0:
            # The padding test must run on the raw indices, before offsets are added.
            if edges_mask is None:
                edges_mask = (edges != self.pad_id).all(dim = -1)

            orig_vert_embed_shape = vert_embed.shape[:2]
            # Flatten the batch into one node list that contains only real vertices.
            vert_embed = vert_embed[vertices_mask]

            # Exclusive cumulative sum of real-vertex counts: the index at which
            # each batch element starts inside the flattened node list.
            vert_index_offsets = reduce(vertices_mask.long(), 'b nv -> b', 'sum')
            vert_index_offsets = F.pad(vert_index_offsets.cumsum(dim = 0), (1, -1), value = 0)
            vert_index_offsets = rearrange(vert_index_offsets, 'b -> b 1 1')

            edges = edges + vert_index_offsets
            # Drop padded edges; this also flattens (b, e, 2) to (be, 2), which
            # the two-axis rearrange pattern below requires.
            edges = edges[edges_mask]
            edges = rearrange(edges, 'be ij -> ij be')

            vert_embed = self.init_sage_conv(vert_embed, edges)
            vert_embed = self.init_encoder_act_and_norm(vert_embed)

            for conv in self.graph_encoders:
                vert_embed = conv(vert_embed, edges)

            # Scatter the node features back into a zero-padded (b, n, d) tensor.
            shape = (*orig_vert_embed_shape, vert_embed.shape[-1])
            vert_embed = vert_embed.new_zeros(shape).masked_scatter(rearrange(vertices_mask, '... -> ... 1'), vert_embed)  

        vert_embed = self.encoder(vert_embed, mask=vertices_mask)

        return vert_embed

    @beartype
    def decode(
        self,
        encoded: TensorType['b', 'n', 'd', float],
        vert_mask:  TensorType['b', 'n', bool]
    ):
        """Run the coarse and fine decoder stacks over encoded vertex tokens.

        Positions where ``vert_mask`` is False are zeroed before decoding, and
        the same mask is handed to both decoder stacks.

        Returns:
            The output of ``self.decoder_fine``.
        """
        padding = ~rearrange(vert_mask, 'b n -> b n 1')
        hidden = self.init_decoder(encoded.masked_fill(padding, 0.))

        coarse = self.decoder_coarse(hidden, mask = vert_mask)
        fine_input = self.coarse_to_fine(coarse)
        return self.decoder_fine(fine_input, mask = vert_mask)
    
    def compute_elastic_energy(self, vertices_batch, uvs_batch, faces_batch, vertices_mask_batch):
        """Mean per-triangle distortion of the 3D -> UV map, averaged over the batch.

        For every face, the linear map ``J`` taking the triangle's 3D corner
        matrix to its UV corner matrix is estimated with a pseudo-inverse.  The
        distortion of a face is the absolute difference between the two
        singular values of ``J``; a mesh's value is the mean over its faces.

        Args:
            vertices_batch: Batched vertex positions, indexed ``[batch, vertex, xyz]``.
            uvs_batch: Batched UVs, indexed ``[batch, vertex, uv]``.
            faces_batch: Batched face indices; rows that are entirely ``-1``
                are treated as padding and dropped.
            vertices_mask_batch: Boolean mask selecting the real vertices.

        Returns:
            A scalar tensor that stays attached to the autograd graph of
            ``uvs_batch`` / ``vertices_batch`` (zero for an empty batch).
        """
        per_mesh = []
        for vertices, uv, faces, vertices_mask in zip(
            vertices_batch, uvs_batch, faces_batch, vertices_mask_batch
        ):
            vertices = vertices[vertices_mask]
            uv = uv[vertices_mask]
            faces = faces[~torch.all(faces == -1, dim=1)]

            # [face, corner, coord] -> [face, coord, corner]
            tri_3d = vertices[faces].permute(0, 2, 1)
            tri_uv = uv[faces].permute(0, 2, 1)

            J = tri_uv @ torch.linalg.pinv(tri_3d)
            # Only the singular values are needed, so skip computing U and Vh.
            S = torch.linalg.svdvals(J)
            per_mesh.append((S[:, 0] - S[:, 1]).abs().mean())

        if not per_mesh:
            return vertices_batch.new_zeros(())

        # torch.stack keeps the autograd graph; torch.tensor(list) would copy
        # the values into a fresh tensor and detach the loss from the model.
        return torch.stack(per_mesh).mean()

    def get_uvs_normals(self, uvs, faces, normals, vertices_mask, faces_mask):
        """Orient per-face normals by the winding of each face in UV space.

        For every face the 2D cross product of its two UV edge vectors (twice
        the signed area) is computed.  Faces with a positive value, and padded
        faces, get mask value 1; all others get 0.  The returned normals are a
        copy of ``normals`` whose component 1 is ``+|n_1|`` where the mask is
        1 and ``-|n_1|`` where it is 0.

        The input ``normals`` tensor is left untouched (the previous version
        overwrote its component 1 in place).

        Args:
            uvs: Per-vertex UVs, indexed ``[batch, vertex, uv]``.
            faces: Face indices, shape ``(b, nf, 3)``.
            normals: Normals to orient -- presumably one per face,
                ``(b, nf, 3)``; TODO confirm against the caller.
            vertices_mask: Unused; kept for interface compatibility.
            faces_mask: Boolean ``(b, nf)`` mask, True for real faces.

        Returns:
            ``(uvs_normals, normals_mask)`` where ``normals_mask`` is an int
            tensor of shape ``(b, nf)``.
        """
        b, nf, _ = faces.shape
        # Build the batch index on the same device as the tensor being indexed.
        batch_indices = torch.arange(b, device=uvs.device).view(-1, 1, 1).expand(-1, nf, 3)
        uvs_faces = uvs[batch_indices, faces]

        p1, p2, p3 = uvs_faces.unbind(dim=2)
        v1 = p2 - p1
        v2 = p3 - p1
        cross_product = v1[..., 0] * v2[..., 1] - v1[..., 1] * v2[..., 0]
        # Padded faces are forced positive so they never count as flipped.
        cross_product = cross_product.masked_fill(~faces_mask, 1)
        normals_mask = (cross_product > 0).int()

        abs_component = normals[:, :, 1].abs()
        uvs_normals = normals.clone()
        uvs_normals[:, :, 1] = torch.where(normals_mask == 0, -abs_component, abs_component)

        return uvs_normals, normals_mask
    
    def cal_overlap_num(self, uvs_batch, faces_batch, normals_batch, vertices_mask, faces_mask, normals_mask_batch, k=3):
        """Average number of entries of ``normals_mask_batch`` that are below 1e-4.

        Entries are counted along dim 1 and the counts are then averaged.
        Only ``normals_mask_batch`` is read; the remaining arguments are
        accepted but unused.
        """
        flipped = torch.lt(normals_mask_batch, 1e-4)
        flipped_counts = flipped.sum(dim=1).float()
        return flipped_counts.mean()
    
    def get_silhouette(self, uv_pred_batch, uvs_batch, faces_batch, vertices_mask, faces_mask):
        """Mean squared difference between rendered silhouettes of two UV layouts.

        Each batch element's predicted and reference UVs are lifted to 3D by
        appending a zero third coordinate, turned into meshes that share the
        same (unpadded) faces, and rendered with a soft silhouette shader at
        256x256.  The loss is the mean squared difference of channel 3 of the
        two renders over the whole batch.

        Args:
            uv_pred_batch: Predicted UVs, indexed ``[batch, vertex, uv]``.
            uvs_batch: Reference UVs, same layout.
            faces_batch: Face indices, indexed ``[batch, face, corner]``.
            vertices_mask: Boolean mask selecting real vertices.
            faces_mask: Boolean mask selecting real faces.

        Returns:
            Scalar tensor with the silhouette loss.
        """
        device = uvs_batch.device
        num_meshes = uvs_batch.shape[0]

        blend = BlendParams(sigma=1e-4, gamma=1e-4)
        settings = RasterizationSettings(
            image_size=256,
            blur_radius=np.log(1. / 1e-4 - 1.) * blend.sigma,
            faces_per_pixel=100,
        )
        rasterizer = MeshRasterizer(
            cameras=FoVPerspectiveCameras(device=device),
            raster_settings=settings,
        )
        render = MeshRenderer(
            rasterizer=rasterizer,
            shader=SoftSilhouetteShader(blend_params=blend),
        )
        R, T = look_at_view_transform(dist=1, elev=0, azim=0.0, device=device)

        def lift_to_3d(uv):
            # Append a zero third coordinate to every 2D point.
            uv = uv.type(torch.float32)
            zeros = torch.zeros(uv.size(0), 1, device=device)
            return torch.cat([uv, zeros], dim=1)

        gt_renders = []
        pred_renders = []
        for idx in range(num_meshes):
            keep_verts = vertices_mask[idx]
            gt_verts = lift_to_3d(uvs_batch[idx][keep_verts])
            pred_verts = lift_to_3d(uv_pred_batch[idx][keep_verts])
            mesh_faces = faces_batch[idx][faces_mask[idx]].to(device)

            # Constant white vertex colours, shared by both meshes.
            textures = TexturesVertex(verts_features=torch.ones_like(gt_verts)[None])

            gt_mesh = Meshes(verts=[gt_verts], faces=[mesh_faces], textures=textures)
            pred_mesh = Meshes(verts=[pred_verts], faces=[mesh_faces], textures=textures)

            gt_renders.append(render(meshes_world=gt_mesh, R=R, T=T))
            pred_renders.append(render(meshes_world=pred_mesh, R=R, T=T))

        # Each render has a leading dimension of 1, so concatenating along it
        # yields one image per batch element.
        gt_alpha = torch.cat(gt_renders, dim=0)[..., 3]
        pred_alpha = torch.cat(pred_renders, dim=0)[..., 3]

        return torch.mean((gt_alpha - pred_alpha) ** 2)

    def forward(
        self,
        *,
        vertices:          TensorType['b', 'nv', 3, float],
        faces:             TensorType['b', 'nf', 3, int],
        uvs:               TensorType['b', 'nv', 2, float],
        normals:           TensorType['b', 'nv', 3, float],
        curve:             TensorType['b', 'nv', 1, float],
        degree:            TensorType['b', 'nv', 1, float],
        blender_uvs:        TensorType['b', 'nv', 2, float],
        edges:             TensorType['b', 'e', 2, int],
        face_normals:      TensorType['b', 'nf', 3, float],
        **kwargs
    ):  
        """Predict UVs for a batch of meshes and compute the training losses.

        Args:
            vertices: Vertex coordinates; entries equal to ``self.pad_id`` mark padding.
            faces: Face indices; entries equal to ``self.pad_id`` mark padding.
            uvs: Target per-vertex UVs.
            normals: Per-vertex normals.
            curve: Per-vertex curvature value.
            degree: Per-vertex degree value.
            blender_uvs: Initial per-vertex UVs fed to the encoder.
            edges: Vertex-index pairs for the graph stage.  The body always
                read this name; it is now an explicit keyword argument.
            face_normals: Normals passed to ``get_uvs_normals`` -- presumably
                one per face; TODO confirm.  The body always read this name;
                it is now an explicit keyword argument.
            **kwargs: Ignored.

        Returns:
            ``(total_loss, uv_pred, distort_loss, overlap_loss, silhouette_loss)``.
        """
        vertices_mask = reduce(vertices != self.pad_id, 'b nv c -> b nv', 'all').contiguous()
        faces_mask = reduce(faces != self.pad_id, 'b nf c -> b nf', 'all').contiguous()

        # Only keywords that encode() declares are passed; the previous extra
        # `edges_mask` keyword raised a TypeError.
        encoded = self.encode(
            vertices = vertices,
            faces = faces,
            edges = edges,
            vertices_mask = vertices_mask,
            normals = normals,
            curve = curve,
            degree = degree,
            blender_uvs = blender_uvs,
        )
        decode = self.decode(
            encoded,
            vert_mask = vertices_mask
        )
        pred_uv_offset = self.output_uv(decode)

        with autocast(enabled = False):
            if self.blender_uv_weight > 0.:
                uv_pred = blender_uvs + pred_uv_offset
            else:
                # NOTE(review): without the Blender prior the head output is
                # used as the UV directly -- confirm this is the intended
                # fallback.  Previously uv_pred was unbound on this path.
                uv_pred = pred_uv_offset

            recon_losses = F.l1_loss(uv_pred, uvs, reduction='none').sum(dim=2)
            if self.boundary_weight >= 0.:
                recon_losses = recon_losses + recon_losses * self.boundary_weight
            recon_loss = recon_losses[vertices_mask].mean()

            # Both UV sets are shifted by -0.5 before rendering.
            silhouette_loss = self.get_silhouette(uv_pred - 0.5, uvs - 0.5, faces, vertices_mask, faces_mask)
            uvs_normals, normals_mask_batch = self.get_uvs_normals(uv_pred, faces, face_normals, vertices_mask, faces_mask)
            overlap_loss = self.cal_overlap_num(uv_pred, faces, uvs_normals, vertices_mask, faces_mask, normals_mask_batch)

            if self.distort_weight != 0.0:
                distort_loss = self.compute_elastic_energy(vertices, uv_pred, faces, vertices_mask)
            else:
                # Keep the name bound so the total and the return value are
                # well defined when the term is disabled.
                distort_loss = recon_loss.new_zeros(())

        # calculate total loss
        total_loss = 1.0 * recon_loss + 1.0 * silhouette_loss + 0.01 * overlap_loss + 0.0001 * distort_loss

        return total_loss, uv_pred, distort_loss, overlap_loss, silhouette_loss