import torch
from torch import nn
import numpy as np
from einops import rearrange
from simple_knn._C import distCUDA2
from pytorch3d.ops.knn import knn_points
from pytorch3d.transforms import so3_exponential_map
from pytorch3d.transforms.rotation_conversions import quaternion_to_matrix, matrix_to_quaternion

from modules.mlp import MLP
from modules.pos_embedding import get_embedder
from modules.pts_embedding import PointCloudEncoder
from utils.general_utils import inverse_sigmoid
from modules.camera import CameraModule

class GaussianHeadModule(nn.Module):
    """Learnable set of 3D Gaussians conditioned on expression and head pose.

    The module stores per-Gaussian positions, features, log-scales, rotation
    quaternions and opacity logits as parameters.  ``generate`` blends the
    outputs of expression-conditioned and pose-conditioned MLPs per Gaussian,
    weighted by each Gaussian's distance to the neutral 3D landmarks, and
    ``forward`` additionally renders the result through ``CameraModule``.

    Args:
        cfg: Mapping read for ``exp_color_mlp``, ``pose_color_mlp``,
            ``exp_attributes_mlp``, ``pose_attributes_mlp``, ``exp_deform_mlp``,
            ``pose_deform_mlp``, ``pos_freq``, ``dist_threshold_near``,
            ``dist_threshold_far``, ``deform_scale``, ``attributes_scale`` and,
            when ``add_mouth_points`` is true, ``num_add_mouth_points``.
        xyz: ``(N, 3)`` initial Gaussian centres.
        feature: ``(N, C)`` initial per-Gaussian features.
        landmarks_3d_neutral: 3D landmarks of the neutral face; rows 48..65 are
            used as the mouth region when mouth points are added.
        add_mouth_points: If true and ``cfg['num_add_mouth_points'] > 0``,
            extra Gaussians are appended around the mouth.

    Note:
        ``mouth_points_indices``, ``teeth_indices`` and ``inner_mouth_indices``
        are plain attributes (not buffers) and only exist when mouth points
        were actually added.
    """

    def __init__(self, cfg, xyz, feature, landmarks_3d_neutral, add_mouth_points=False):
        super().__init__()

        # Number of Gaussians supplied by the caller, before any mouth points.
        self.original_pts_count = xyz.shape[0]

        if add_mouth_points and cfg['num_add_mouth_points'] > 0:
            num_mouth = cfg['num_add_mouth_points']

            mouth_keypoints = landmarks_3d_neutral[48:66]
            mouth_center = torch.mean(mouth_keypoints, dim=0, keepdim=True)
            # Move the centre to the smallest z found among the mouth landmarks.
            mouth_center[:, 2] = mouth_keypoints[:, 2].min()
            # Per-axis half extent of the mouth landmarks around that centre.
            max_dist = (mouth_keypoints - mouth_center).abs().max(0)[0]

            # One third of the budget goes to each teeth row; the rows are
            # offset by +/-10% of the y extent from the mouth centre.
            teeth_upper_center = mouth_center.clone()
            teeth_upper_center[:, 1] += max_dist[1] * 0.1
            teeth_upper_points = self._teeth_arc(teeth_upper_center, max_dist, num_mouth // 3, xyz.device)

            teeth_lower_center = mouth_center.clone()
            teeth_lower_center[:, 1] -= max_dist[1] * 0.1
            teeth_lower_points = self._teeth_arc(teeth_lower_center, max_dist, num_mouth // 3, xyz.device)

            num_teeth = teeth_upper_points.shape[0] + teeth_lower_points.shape[0]

            # The remaining budget is scattered uniformly in a box of 1.2x the
            # mouth extent centred on the mouth centre.
            inner_points = (torch.rand([num_mouth - num_teeth, 3],
                            device=xyz.device) - 0.5) * max_dist * 1.2 + mouth_center

            points_add = torch.cat([teeth_upper_points, teeth_lower_points, inner_points], dim=0)

            # Constant initial features: 0.8 for teeth, 0.3 for the mouth interior.
            teeth_features = torch.ones([num_teeth, feature.shape[1]],
                                        device=feature.device) * 0.8
            inner_features = torch.ones([inner_points.shape[0], feature.shape[1]],
                                        device=feature.device) * 0.3
            mouth_features = torch.cat([teeth_features, inner_features], dim=0)

            # Index ranges of the appended points inside the concatenated xyz.
            first_new = xyz.shape[0]
            self.mouth_points_indices = torch.arange(first_new, first_new + points_add.shape[0], device=xyz.device)
            self.teeth_indices = torch.arange(first_new, first_new + num_teeth, device=xyz.device)
            self.inner_mouth_indices = torch.arange(first_new + num_teeth,
                                                    first_new + points_add.shape[0], device=xyz.device)

            xyz = torch.cat([xyz, points_add])
            feature = torch.cat([feature, mouth_features])

        self.xyz = nn.Parameter(xyz)
        self.feature = nn.Parameter(feature)
        self.register_buffer('landmarks_3d_neutral', landmarks_3d_neutral)

        # distCUDA2 needs CUDA input; its (clamped) output initialises the
        # isotropic log-scale of every Gaussian.  The result is moved back to
        # the device of xyz so all parameters start on the same device.
        dist2 = torch.clamp_min(distCUDA2(self.xyz.cuda()), 0.0000001).to(xyz.device)
        scales = torch.log(torch.sqrt(dist2))[..., None].repeat(1, 3)
        self.scales = nn.Parameter(scales)

        # Identity quaternion (1, 0, 0, 0) for every Gaussian.
        rots = torch.zeros((xyz.shape[0], 4), device=xyz.device)
        rots[:, 0] = 1
        self.rotation = nn.Parameter(rots)

        # Opacity is stored as a logit; the initial opacity is 0.3.
        self.opacity = nn.Parameter(inverse_sigmoid(0.3 * torch.ones((xyz.shape[0], 1), device=xyz.device)))

        self.exp_color_mlp = MLP(cfg['exp_color_mlp'], last_op=None)
        self.pose_color_mlp = MLP(cfg['pose_color_mlp'], last_op=None)
        self.exp_attributes_mlp = MLP(cfg['exp_attributes_mlp'], last_op=None)
        self.pose_attributes_mlp = MLP(cfg['pose_attributes_mlp'], last_op=None)
        self.exp_deform_mlp = MLP(cfg['exp_deform_mlp'], last_op=nn.Tanh())
        self.pose_deform_mlp = MLP(cfg['pose_deform_mlp'], last_op=nn.Tanh())

        self.pos_embedding, _ = get_embedder(cfg['pos_freq'])
        self.pts_embedding = PointCloudEncoder(max_length=4000000, num_out_scale=4)

        self.dist_threshold_near = cfg['dist_threshold_near']
        self.dist_threshold_far = cfg['dist_threshold_far']
        self.deform_scale = cfg['deform_scale']
        self.attributes_scale = cfg['attributes_scale']

        self.render = CameraModule()

    @staticmethod
    def _teeth_arc(center, max_dist, count, device):
        """Return ``count`` points on a half-elliptical arc around ``center``.

        The arc angle runs linearly from -pi/2 to +pi/2.  x follows
        ``cos(angle) * 0.8 * max_dist[0]``, z follows
        ``sin(angle) * 0.2 * max_dist[2]`` and y stays at the centre's y.
        A single point is placed at the start of the arc (angle -pi/2) instead
        of dividing by zero; ``count == 0`` yields an empty ``(0, 3)`` tensor.

        Args:
            center: ``(1, 3)`` arc centre.
            max_dist: ``(3,)`` per-axis extent used to size the arc.
            count: Number of points to generate.
            device: Device of the returned tensor.

        Returns:
            ``(count, 3)`` tensor in the default dtype on ``device``.
        """
        points = torch.zeros([count, 3], device=device)
        if count == 0:
            return points

        # max(..., 1) keeps the single-point case well defined.
        denom = max(count - 1, 1)
        angles = torch.tensor([np.pi * (i / denom - 0.5) for i in range(count)],
                              device=max_dist.device)
        # NOTE(review): with cos on x the arc only extends towards +x of the
        # centre; kept as in the original initialisation - confirm intent.
        points[:, 0] = center[0, 0] + torch.cos(angles) * max_dist[0] * 0.8
        points[:, 1] = center[0, 1]
        points[:, 2] = center[0, 2] + torch.sin(angles) * max_dist[2] * 0.2
        return points

    def generate(self, data):
        """Compute the posed Gaussian attributes for a batch and store them in ``data``.

        Reads ``data['exp_coeff']``, ``data['ear']``, ``data['pose']`` and
        ``data['scale']``.  The first three entries of each pose row are passed
        to ``so3_exponential_map`` and the remaining entries are added as a
        translation.

        Writes ``xyz``, ``color``, ``scales``, ``rotation``, ``opacity`` (all
        batched) and ``exp_deform`` into ``data`` and returns the same dict.

        Note:
            ``data['exp_deform']`` is the expression deformation of the *last*
            batch element only (un-weighted, un-scaled MLP output).
        """
        B = data['exp_coeff'].shape[0]

        xyz = self.xyz.unsqueeze(0).repeat(B, 1, 1)
        feature = torch.tanh(self.feature).unsqueeze(0).repeat(B, 1, 1)

        # Distance of every Gaussian to its nearest neutral landmark decides how
        # much it is driven by expression (near) versus pose (far).
        dists, _, _ = knn_points(xyz, self.landmarks_3d_neutral.unsqueeze(0).repeat(B, 1, 1))
        exp_weights = torch.clamp((self.dist_threshold_far - dists) / (self.dist_threshold_far - self.dist_threshold_near), 0.0, 1.0)
        pose_weights = 1 - exp_weights
        # The two masks overlap between the near and far thresholds; there the
        # two branches are blended with complementary weights.
        exp_controlled = (dists < self.dist_threshold_far).squeeze(-1)
        pose_controlled = (dists > self.dist_threshold_near).squeeze(-1)

        color = torch.zeros([B, xyz.shape[1], self.exp_color_mlp.dims[-1]], device=xyz.device)
        delta_xyz = torch.zeros_like(xyz)
        delta_attributes = torch.zeros([B, xyz.shape[1], self.scales.shape[1] + self.rotation.shape[1] + self.opacity.shape[1]], device=xyz.device)

        for b in range(B):
            exp_mask = exp_controlled[b]
            pose_mask = pose_controlled[b]
            exp_w = exp_weights[b, exp_mask, :]
            pose_w = pose_weights[b, pose_mask, :]

            feature_exp_controlled = feature[b, exp_mask, :]
            feature_pose_controlled = feature[b, pose_mask, :]
            num_exp = feature_exp_controlled.shape[0]
            num_pose = feature_pose_controlled.shape[0]

            # Conditioning vectors, tiled to one column per controlled Gaussian.
            # The ear signal is repeated 64 times before being appended to the
            # expression coefficients.
            ear_params = data['ear'][b].repeat(64)
            exp_cond = torch.cat([data['exp_coeff'][b], ear_params]).unsqueeze(-1).repeat(1, num_exp)
            pose_cond = self.pos_embedding(data['pose'][b]).unsqueeze(-1).repeat(1, num_pose)

            # Colour.
            exp_color_input = torch.cat([feature_exp_controlled.t(), exp_cond], 0)[None]
            exp_color = self.exp_color_mlp(exp_color_input)[0].t()
            color[b, exp_mask, :] += exp_color * exp_w

            pose_color_input = torch.cat([feature_pose_controlled.t(), pose_cond], 0)[None]
            pose_color = self.pose_color_mlp(pose_color_input)[0].t()
            color[b, pose_mask, :] += pose_color * pose_w

            # Scale / rotation / opacity offsets share the colour MLP inputs.
            exp_delta_attributes = self.exp_attributes_mlp(exp_color_input)[0].t()
            delta_attributes[b, exp_mask, :] += exp_delta_attributes * exp_w

            pose_attributes = self.pose_attributes_mlp(pose_color_input)[0].t()
            delta_attributes[b, pose_mask, :] += pose_attributes * pose_w

            # Position offsets are predicted from embeddings of the positions.
            xyz_exp_controlled = xyz[b, exp_mask, :].unsqueeze(0)
            exp_deform_value = torch.cat([self.pos_embedding(xyz_exp_controlled).permute(0, 2, 1),
                                          self.pts_embedding(xyz_exp_controlled).permute(0, 2, 1)], 1)
            exp_deform_input = torch.cat([exp_deform_value.squeeze(0), exp_cond], 0)[None]
            exp_deform = self.exp_deform_mlp(exp_deform_input)[0].t()
            delta_xyz[b, exp_mask, :] += exp_deform * exp_w

            xyz_pose_controlled = xyz[b, pose_mask, :].unsqueeze(0)
            pose_deform_value = torch.cat([self.pos_embedding(xyz_pose_controlled).permute(0, 2, 1),
                                           self.pts_embedding(xyz_pose_controlled).permute(0, 2, 1)], 1)
            pose_deform_input = torch.cat([pose_deform_value.squeeze(0), pose_cond], 0)[None]
            pose_deform = self.pose_deform_mlp(pose_deform_input)[0].t()
            delta_xyz[b, pose_mask, :] += pose_deform * pose_w

        xyz = xyz + delta_xyz * self.deform_scale

        # delta_attributes layout: [0:3] scales, [3:7] rotation, [7:8] opacity.
        delta_scales = delta_attributes[:, :, 0:3]
        scales = self.scales.unsqueeze(0).repeat(B, 1, 1) + delta_scales * self.attributes_scale
        scales = torch.exp(scales)

        delta_rotation = delta_attributes[:, :, 3:7]
        rotation = self.rotation.unsqueeze(0).repeat(B, 1, 1) + delta_rotation * self.attributes_scale
        rotation = torch.nn.functional.normalize(rotation, dim=2)

        delta_opacity = delta_attributes[:, :, 7:8]
        opacity = self.opacity.unsqueeze(0).repeat(B, 1, 1) + delta_opacity * self.attributes_scale
        opacity = torch.sigmoid(opacity)

        # NOTE: data['pose'] is already required by the loop above, so this
        # guard is always true whenever execution reaches it with B > 0.
        if 'pose' in data:
            # Apply the global rigid transform: scale, rotate, then translate.
            R = so3_exponential_map(data['pose'][:, :3])
            T = data['pose'][:, None, 3:]
            S = data['scale'][:, :, None]
            xyz = torch.bmm(xyz * S, R.permute(0, 2, 1)) + T

            # Compose the head rotation with every Gaussian's own rotation.
            rotation_matrix = quaternion_to_matrix(rotation)
            rotation_matrix = rearrange(rotation_matrix, 'b n x y -> (b n) x y')
            R = rearrange(R.unsqueeze(1).repeat(1, rotation.shape[1], 1, 1), 'b n x y -> (b n) x y')
            rotation_matrix = rearrange(torch.bmm(R, rotation_matrix), '(b n) x y -> b n x y', b=B)
            rotation = matrix_to_quaternion(rotation_matrix)

            scales = scales * S

        # Only the last batch element's expression deformation is exposed.
        data['exp_deform'] = exp_deform
        data['xyz'] = xyz
        data['color'] = color
        data['scales'] = scales
        data['rotation'] = rotation
        data['opacity'] = opacity

        return data

    def forward(self, data, resolution):
        """Generate the Gaussian attributes and render them at ``resolution``.

        Returns the ``data`` dict as returned by
        ``CameraModule.render_gaussian``.
        """
        data = self.generate(data)
        data = self.render.render_gaussian(data, resolution)
        return data