import torch
from torch import nn
import numpy as np
import kaolin
import tqdm
from pytorch3d.ops.knn import knn_gather, knn_points
from pytorch3d.transforms import so3_exponential_map

from modules.mlp import MLP
from modules.camera import CameraModule
from modules.pos_embedding import get_embedder
from modules.pts_embedding import PointCloudEncoder

from utils.dmtet_utils import marching_tetrahedra

class MeshHeadModule(nn.Module):
    """Neural head model: an SDF-on-tet-grid geometry network plus
    expression/pose conditioned color and deformation MLPs.

    The neutral surface is extracted from a tetrahedral grid with
    ``marching_tetrahedra``. Each extracted vertex is then colored and deformed
    by an expression branch and a pose branch. The two branches are blended by
    the vertex's distance to the nearest learnable neutral 3D landmark.
    """

    def __init__(self, cfg, init_landmarks_3d_neutral):
        super().__init__()

        self.geo_mlp = MLP(cfg['geo_mlp'], last_op=nn.Tanh())
        self.exp_color_mlp = MLP(cfg['exp_color_mlp'], last_op=nn.Sigmoid())
        self.pose_color_mlp = MLP(cfg['pose_color_mlp'], last_op=nn.Sigmoid())
        self.exp_deform_mlp = MLP(cfg['exp_deform_mlp'], last_op=nn.Tanh())
        self.pose_deform_mlp = MLP(cfg['pose_deform_mlp'], last_op=nn.Tanh())

        # Learnable neutral landmarks; also used by subdivide() to locate the mouth.
        self.landmarks_3d_neutral = nn.Parameter(init_landmarks_3d_neutral)

        self.pos_embedding, _ = get_embedder(cfg['pos_freq'])

        self.pts_embedding = PointCloudEncoder(max_length=4000000, num_out_scale=4)

        self.model_bbox = cfg['model_bbox']
        self.dist_threshold_near = cfg['dist_threshold_near']
        self.dist_threshold_far = cfg['dist_threshold_far']
        self.deform_scale = cfg['deform_scale']

        # NpzFile keeps the archive open; the context manager closes it once the
        # arrays have been read into memory.
        with np.load('assets/tets_data.npz') as tets_data:
            self.register_buffer('tet_verts', torch.from_numpy(tets_data['tet_verts']))
            self.register_buffer('tets', torch.from_numpy(tets_data['tets']))
        # Divides the SDF-network vertex offsets; doubled by each subdivide().
        # Presumably the resolution of the grid stored in assets — TODO confirm.
        self.grid_res = 128

        if cfg['subdivide']:
            self.subdivide()

        self.render = CameraModule()

    # ------------------------------------------------------------------ MLP heads
    def geometry(self, geo_input):
        """Run the geometry MLP; channel 0 is SDF, 1:4 vertex offset, 4: features."""
        return self.geo_mlp(geo_input)

    def exp_color(self, color_input):
        """Expression-conditioned per-vertex color (sigmoid output)."""
        return self.exp_color_mlp(color_input)

    def pose_color(self, color_input):
        """Pose-conditioned per-vertex color (sigmoid output)."""
        return self.pose_color_mlp(color_input)

    def exp_deform(self, deform_input):
        """Expression-conditioned per-vertex displacement (tanh output)."""
        return self.exp_deform_mlp(deform_input)

    def pose_deform(self, deform_input):
        """Pose-conditioned per-vertex displacement (tanh output)."""
        return self.pose_deform_mlp(deform_input)

    def get_landmarks(self):
        """Return the learnable neutral 3D landmarks parameter."""
        return self.landmarks_3d_neutral

    # ------------------------------------------------------------ shared helpers
    def _point_embedding(self, pts):
        """Concatenate point-cloud and positional embeddings of ``pts``.

        ``pts`` is indexed as (B, N, 3); the result is channel-first (B, C, N)
        as expected by the MLPs.
        """
        return torch.cat([self.pts_embedding(pts).permute(0, 2, 1),
                          self.pos_embedding(pts).permute(0, 2, 1)], 1)

    def _query_geometry(self, pts):
        """Evaluate the geometry MLP at ``pts``; returns the raw (B, C, N) prediction."""
        return self.geometry(self._point_embedding(pts))

    def _extract_surface(self, query_pts):
        """Run marching tetrahedra on the tet grid sampled at ``query_pts``.

        Returns ``(verts_list, features_list, faces_list)`` as produced by
        ``marching_tetrahedra``.
        """
        pred = self._query_geometry(query_pts)
        sdf = pred[:, :1, :].permute(0, 2, 1)
        offset = pred[:, 1:4, :].permute(0, 2, 1)
        features = pred[:, 4:, :].permute(0, 2, 1)
        # NOTE(review): geo_mlp already ends in Tanh, so the offset is squashed
        # twice; kept as-is because trained weights depend on it.
        verts_deformed = query_pts + torch.tanh(offset) / self.grid_res
        return marching_tetrahedra(verts_deformed, features, self.tets, sdf)

    @staticmethod
    def _expression_code(data):
        """Expression coefficients followed by the ``ear`` code tiled 64x along dim 1.

        The tiling width must match the input size configured for the
        expression MLPs.
        """
        return torch.cat([data['exp_coeff'], data['ear'].repeat(1, 64)], 1)

    @staticmethod
    def _append_code(feats, code):
        """Append a per-sample code (B, K) to every point of ``feats`` (B, C, N) -> (B, C+K, N)."""
        return torch.cat([feats, code.unsqueeze(-1).expand(-1, -1, feats.shape[-1])], 1)

    @staticmethod
    def _remap_subdivided_tets(region_vert_ids, sub_tets, num_local_verts, num_global_verts):
        """Map tets of a locally subdivided region back to global vertex ids.

        The subdivided region's vertex array is assumed to list the region's
        original vertices first (local ids ``0..R-1``, with
        ``R = len(region_vert_ids)``), followed by the new midpoint vertices
        (local ids ``R..num_local_verts-1``). This is the same layout that
        ``subdivide`` relies on when it appends ``sub_verts[R:]``.

        Original vertices map back through ``region_vert_ids``. Midpoints map to
        ``num_global_verts, num_global_verts + 1, ...``, i.e. they are assumed
        to be appended to the global vertex array in order.

        Args:
            region_vert_ids: global ids of the region's original vertices, in local order.
            sub_tets: integer tets indexing the subdivided local vertex array.
            num_local_verts: size of the subdivided local vertex array.
            num_global_verts: global vertex count *before* appending midpoints.

        Returns:
            A long tensor shaped like ``sub_tets`` holding global vertex ids.
        """
        region_vert_ids = region_vert_ids.long()
        new_ids = torch.arange(num_global_verts,
                               num_global_verts + num_local_verts - region_vert_ids.shape[0],
                               dtype=torch.long, device=region_vert_ids.device)
        local_to_global = torch.cat([region_vert_ids, new_ids])
        return local_to_global[sub_tets.long()]

    # ------------------------------------------------------------- tet topology
    def subdivide(self, focus_on_mouth=True):
        """Subdivide the whole tet grid once, optionally refining the mouth again.

        With ``focus_on_mouth`` the tets that touch an axis-aligned box around
        the mouth landmarks are subdivided a second time.

        NOTE(review): neighbouring tets that share a face with a refined tet are
        not split, so the boundary of the refined region has hanging vertices.
        """
        new_tet_verts, new_tets = kaolin.ops.mesh.subdivide_tetmesh(self.tet_verts.unsqueeze(0), self.tets)
        self.tet_verts = new_tet_verts[0]
        self.tets = new_tets
        self.grid_res *= 2

        if not focus_on_mouth:
            return

        # Pure geometric selection: no gradient should flow into the landmarks.
        with torch.no_grad():
            # Presumably the mouth points of a 68-point landmark layout — TODO confirm.
            mouth_keypoints = self.landmarks_3d_neutral[48:66]
            mouth_center = torch.mean(mouth_keypoints, dim=0, keepdim=True)
            mouth_center[:, 2] = mouth_keypoints[:, 2].min()

            max_dist = (mouth_keypoints - mouth_center).abs().max(0)[0] * 1.2
            min_coords = mouth_center - max_dist
            max_coords = mouth_center + max_dist

            in_mouth_region = ((self.tet_verts >= min_coords) & (self.tet_verts <= max_coords)).all(dim=1)
            # A tet is refined if any of its corners lies inside the box.
            mouth_tets_mask = torch.any(in_mouth_region[self.tets], dim=1)

        mouth_tets = self.tets[mouth_tets_mask]
        if mouth_tets.shape[0] == 0:
            return

        # Extract the region as a standalone tet mesh with compact local indices.
        region_vert_ids = torch.unique(mouth_tets)
        device = self.tet_verts.device
        global_to_local = torch.zeros(self.tet_verts.shape[0], dtype=torch.long, device=device)
        global_to_local[region_vert_ids] = torch.arange(region_vert_ids.shape[0], device=device)
        region_verts = self.tet_verts[region_vert_ids]
        region_tets = global_to_local[mouth_tets]

        sub_verts, sub_tets = kaolin.ops.mesh.subdivide_tetmesh(region_verts.unsqueeze(0), region_tets)
        sub_verts = sub_verts[0]

        # Map local ids back to global ids *before* growing the vertex array:
        # original corners reuse their global ids, midpoints are appended.
        refined_tets = self._remap_subdivided_tets(
            region_vert_ids, sub_tets, sub_verts.shape[0], self.tet_verts.shape[0])

        self.tet_verts = torch.cat([self.tet_verts, sub_verts[region_verts.shape[0]:]])
        self.tets = torch.cat([self.tets[~mouth_tets_mask], refined_tets.to(self.tets.dtype)])

    # ----------------------------------------------------------- reconstruction
    def reconstruct(self, data):
        """Extract the neutral mesh, then color, deform and rigidly pose it.

        Reads ``data['exp_coeff']``, ``data['ear']``, ``data['pose']``,
        ``data['scale']`` and ``data['landmarks_3d_neutral']``. Writes
        ``verts0_list``, ``faces_list``, ``exp_deform``, ``pose_deform``,
        ``verts_list`` and ``verts_color_list`` into ``data`` and returns it.
        """
        batch_size = data['exp_coeff'].shape[0]

        query_pts = self.tet_verts.unsqueeze(0).repeat(batch_size, 1, 1)
        verts_list, features_list, faces_list = self._extract_surface(query_pts)

        data['verts0_list'] = verts_list
        data['faces_list'] = faces_list

        # Zero-pad per-sample meshes to a common vertex count. The padding rows
        # are processed like real vertices but sliced away at the end.
        verts_batch = nn.utils.rnn.pad_sequence(verts_list, batch_first=True)
        verts_features_batch = nn.utils.rnn.pad_sequence(features_list, batch_first=True)

        # Blend weight: 1 near the landmarks (expression branch), 0 far away (pose branch).
        # NOTE(review): pytorch3d knn_points reports *squared* distances, so the
        # thresholds are compared against squared distances.
        dists, _, _ = knn_points(verts_batch, data['landmarks_3d_neutral'])
        exp_weights = torch.clamp((self.dist_threshold_far - dists)
                                  / (self.dist_threshold_far - self.dist_threshold_near), 0.0, 1.0)
        pose_weights = 1 - exp_weights

        exp_code = self._expression_code(data)
        pose_code = self.pos_embedding(data['pose'])

        features_cf = verts_features_batch.permute(0, 2, 1)
        exp_color = self.exp_color(self._append_code(features_cf, exp_code)).permute(0, 2, 1)
        pose_color = self.pose_color(self._append_code(features_cf, pose_code)).permute(0, 2, 1)
        verts_color_batch = exp_color * exp_weights + pose_color * pose_weights

        # Both deformation branches are conditioned on the *undeformed* vertices.
        deform_embed = self._point_embedding(verts_batch)
        exp_deform = self.exp_deform(self._append_code(deform_embed, exp_code)).permute(0, 2, 1)
        verts_batch = verts_batch + exp_deform * exp_weights * self.deform_scale

        pose_deform = self.pose_deform(self._append_code(deform_embed, pose_code)).permute(0, 2, 1)
        verts_batch = verts_batch + pose_deform * pose_weights * self.deform_scale

        # Rigid transform: axis-angle pose[:, :3], translation pose[:, 3:], per-sample scale.
        R = so3_exponential_map(data['pose'][:, :3])
        T = data['pose'][:, None, 3:]
        S = data['scale'][:, :, None]
        verts_batch = torch.bmm(verts_batch * S, R.permute(0, 2, 1)) + T

        data['exp_deform'] = exp_deform
        data['pose_deform'] = pose_deform
        data['verts_list'] = [verts_batch[b, :verts_list[b].shape[0], :] for b in range(batch_size)]
        data['verts_color_list'] = [verts_color_batch[b, :verts_list[b].shape[0], :] for b in range(batch_size)]
        return data

    def reconstruct_neutral(self):
        """Extract the neutral mesh with its per-vertex features (batch of one)."""
        verts_list, features_list, faces_list = self._extract_surface(self.tet_verts.unsqueeze(0))
        return {
            'verts': verts_list[0],
            'faces': faces_list[0],
            'verts_feature': features_list[0],
        }

    def query_sdf(self, data):
        """Evaluate the SDF at ``data['query_pts']`` and store it as ``data['sdf']`` (B, N, 1)."""
        sdf = self._query_geometry(data['query_pts'])[:, :1, :]
        data['sdf'] = sdf.permute(0, 2, 1)
        return data

    def deform(self, data):
        """Apply the geometry offset and the full expression deformation to ``data['query_pts']``.

        Unlike ``reconstruct``, no landmark-distance weighting is applied. The
        result is stored as ``data['deformed_pts']``.
        """
        query_pts = data['query_pts']
        offset = self._query_geometry(query_pts)[:, 1:4, :].permute(0, 2, 1)
        query_pts = query_pts + torch.tanh(offset) / self.grid_res

        exp_deform_input = self._append_code(self._point_embedding(query_pts), self._expression_code(data))
        exp_deform = self.exp_deform(exp_deform_input).permute(0, 2, 1)

        data['deformed_pts'] = query_pts + exp_deform * self.deform_scale
        return data

    def pre_train_sphere(self, iter, device):
        """Fit the SDF head to a sphere of radius 1 for ``iter`` Adam steps.

        Each step samples 8x1024 uniform points in [-1.5, 1.5]^3.
        """
        loss_fn = torch.nn.MSELoss()
        optimizer = torch.optim.Adam(list(self.parameters()), lr=1e-3)

        for _ in tqdm.tqdm(range(iter)):
            query_pts = torch.rand((8, 1024, 3), device=device) * 3 - 1.5
            ref_value = torch.sqrt((query_pts ** 2).sum(-1)) - 1.0
            sdf = self.query_sdf({'query_pts': query_pts})['sdf']
            loss = loss_fn(sdf[:, :, 0], ref_value)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    def forward(self, data):
        """Full training pass.

        Returns ``(pred_landmarks_3d_can, sdf_landmarks_3d, render_images,
        render_soft_masks, exp_deform, pose_deform, verts_list, faces_list)``.
        """
        batch_size = data['landmarks_3d'].shape[0]
        data['landmarks_3d_neutral'] = self.get_landmarks()[None].repeat(batch_size, 1, 1)

        deform_data = {
            'exp_coeff': data['exp_coeff'],
            'ear': data['ear'],
            'query_pts': data['landmarks_3d_neutral'],
        }
        deform_data = self.deform(deform_data)
        pred_landmarks_3d_can = deform_data['deformed_pts']

        # query_pts is still the neutral landmarks: deform() only adds 'deformed_pts'.
        deform_data = self.query_sdf(deform_data)
        sdf_landmarks_3d = deform_data['sdf']

        data = self.reconstruct(data)

        # Size of dim 3 of the 5-D images tensor (the original read it via a
        # permute); presumably image resolution — TODO confirm layout.
        data = self.render.render_mesh(data, data['images'].shape[3])

        return (pred_landmarks_3d_can, sdf_landmarks_3d, data['render_images'], data['render_soft_masks'],
                data['exp_deform'], data['pose_deform'], data['verts_list'], data['faces_list'])
