import numpy as np
import torch
import trimesh
import torch.nn as nn
import torch.nn.functional as F

class RASF(nn.Module):
    """Learnable dense feature grid used to embed point clouds, meshes and voxels.

    The grid ``self.field`` has shape ``(field_dim, *resolution)`` and is
    initialised uniformly in ``[0, 1)``. Features are read from it with
    ``F.grid_sample`` (bilinear / trilinear interpolation, zero padding outside
    ``[-1, 1]``, ``align_corners=False``). Note that grid_sample reads the last
    axis of a sampling grid as ``(x, y, z)`` -> ``(W, H, D)``, i.e. reversed
    relative to ``resolution``.

    Args:
        resolution: spatial size of the grid. ``sample`` accepts a 2-D or 3-D
            grid; every other method expects a 3-D grid.
        field_dim: number of feature channels ``C``.
        num_local_points: neighbourhood size used by ``batch_samples``.
    """

    # Side length of the cubic lattice sampled by the voxel methods.
    _VOXEL_KERNEL_SIZE = 7

    def __init__(self, resolution=(32, 32, 32), field_dim=32, num_local_points=32):
        super().__init__()
        self.field = nn.Parameter(torch.rand(field_dim, *resolution), requires_grad=True)
        self.k = len(resolution)  # number of spatial dimensions of the grid
        self.num_local_points = num_local_points

    def _to_field_device(self, data):
        """Return ``data`` as a tensor on the field's device.

        Non-tensor input (numpy arrays, nested lists) is converted with
        ``torch.Tensor`` (the default tensor type, float32 unless changed);
        tensors keep their dtype.
        """
        if not torch.is_tensor(data):
            data = torch.Tensor(data)
        return data.to(self.field.device)

    def sample(self, index):
        """Sample the field at arbitrary normalised coordinates.

        Args:
            index: tensor of shape ``(num_occupancy, k)`` where ``k`` is the number
                of grid dimensions (2 or 3), with values in ``[-1, 1]``.

        Returns:
            Tensor of shape ``(C, num_occupancy)``.
        """
        index = index.to(dtype=self.field.dtype, device=self.field.device)
        # grid_sample needs a grid of rank k + 2: (1, 1, [1,] num_occupancy, k).
        grid = index.reshape((1,) * self.k + tuple(index.shape))
        out = F.grid_sample(self.field.unsqueeze(0), grid, align_corners=False)
        # out: (1, C, 1, [1,] num_occupancy); reshape rather than squeeze so a
        # single query point keeps its axis.
        return out.reshape(self.field.shape[0], -1)

    def batch_samples(self, batch_points):
        """Sample the RASF feature for a batch of points.

        For every point, its ``num_local_points`` nearest neighbours (L1 distance,
        the point itself included) are taken, expressed relative to the point,
        scaled per cloud so the largest absolute offset is 1, used to sample the
        field, and max-pooled over the neighbourhood.

        Args:
            batch_points: tensor of shape ``(B, num_p, 3)``.

        Returns:
            Tensor of shape ``(B, C, num_p)``.
        """
        batch_points = batch_points.to(dtype=self.field.dtype, device=self.field.device)
        B, num_p, dim = batch_points.shape
        # topk fails when asked for more neighbours than there are points.
        k = min(self.num_local_points, num_p)

        # dist[b, i, j] = L1 distance between point i and point j.
        dist = (batch_points.unsqueeze(1) - batch_points.unsqueeze(2)).norm(p=1, dim=-1)
        _, indices = dist.topk(k, dim=-1, largest=False)  # B, num_p, k

        # Gather neighbour coordinates from a broadcast view (no num_p**2 copy).
        candidates = batch_points.unsqueeze(1).expand(B, num_p, num_p, dim)
        local_points = candidates.gather(2, indices.unsqueeze(-1).expand(-1, -1, -1, dim))
        relative_local_points = local_points - batch_points.unsqueeze(2)  # B, num_p, k, dim

        # Scale each cloud by its largest absolute offset (max of |max| and |min|)
        # so every offset lies in [-1, 1]. The clamp only affects a degenerate
        # cloud whose offsets are all zero and prevents 0 / 0.
        zoom_factor = relative_local_points.reshape(B, -1).abs().max(dim=-1).values
        zoom_factor = zoom_factor.clamp_min(torch.finfo(zoom_factor.dtype).tiny)
        relative_local_points = relative_local_points / zoom_factor.view(-1, 1, 1, 1)

        field = self.field.unsqueeze(0).expand(B, *self.field.shape)
        # Grid (B, 1, num_p, k, 3) -> output (B, C, 1, num_p, k).
        out = F.grid_sample(field, relative_local_points.unsqueeze(1), align_corners=False)
        out = out.squeeze(2)  # B, C, num_p, k
        # Max-pool over the neighbourhood.
        return out.max(dim=-1).values  # B, C, num_p

    def voxel_samples(self, batch_points):
        """Sample the field directly at normalised coordinates (no rescaling).

        Args:
            batch_points: tensor of shape ``(B, num_p, 3)`` with values in ``[-1, 1]``,
                e.g. the flattened coordinates of a small voxel lattice.

        Returns:
            Tensor of shape ``(B, C, 1, num_p)``.
        """
        batch_points = batch_points.to(dtype=self.field.dtype, device=self.field.device)
        B = batch_points.shape[0]
        field = self.field.unsqueeze(0).expand(B, *self.field.shape)
        # Grid (B, 1, 1, num_p, 3) -> output (B, C, 1, 1, num_p).
        grid = batch_points.unsqueeze(1).unsqueeze(1)
        return F.grid_sample(field, grid, align_corners=False).squeeze(2)

    def point_clouds_inference(self, point_clouds):
        """Infer the shape embedding of a single point cloud.

        Args:
            point_clouds: tensor, numpy array or nested list of shape ``(num_p, 3)``.

        Returns:
            Tensor of shape ``(num_p, C + 3)``: the first ``C`` columns are the RASF
            feature, the last 3 are the input coordinates.
        """
        point_clouds = self._to_field_device(point_clouds)
        return self.point_clouds_batch_inference(point_clouds.unsqueeze(0))[0]

    def point_clouds_batch_inference(self, point_clouds):
        """Infer the shape embedding of a batch of point clouds.

        Args:
            point_clouds: tensor, numpy array or nested list of shape ``(B, num_p, 3)``.

        Returns:
            Tensor of shape ``(B, num_p, C + 3)``: RASF feature followed by the
            input coordinates along the last axis.
        """
        point_clouds = self._to_field_device(point_clouds)
        rasf_feature = self.batch_samples(point_clouds).transpose(1, 2)  # B, num_p, C
        return torch.cat((rasf_feature, point_clouds), 2)

    def mesh_inference(self, mesh, num_samples=1000):
        """Infer the shape embedding of a mesh.

        ``num_samples`` points are sampled uniformly at random from the mesh
        surface (via trimesh, so results are not deterministic). They are queried
        together with the vertices, so each vertex's neighbourhood can include
        surface samples; only the vertex features are kept.

        Args:
            mesh: dict with keys ``'V'`` (N x 3 vertices), ``'E'`` (N_e x 2 edges)
                and ``'F'`` (N_f x 3 faces), each a tensor or numpy array. Numpy
                entries are converted with ``torch.Tensor`` (floating point).
            num_samples: number of surface points to sample.

        Returns:
            New dict with the same keys: ``'V'`` becomes ``(N, 3 + C)`` (input
            vertices followed by RASF feature); ``'E'`` and ``'F'`` are copies.
            All entries live on the field's device.
        """
        def to_numpy(x):
            return x.detach().cpu().numpy() if torch.is_tensor(x) else x

        recon_mesh = trimesh.Trimesh(vertices=to_numpy(mesh['V']), faces=to_numpy(mesh['F']))
        points, _ = trimesh.sample.sample_surface(recon_mesh, num_samples)
        points = torch.Tensor(points).to(self.field.device)

        # Copy every entry so the caller's mesh is never aliased.
        RASF_mesh = {}
        for key in ('V', 'E', 'F'):
            value = mesh[key]
            value = value.clone() if torch.is_tensor(value) else torch.Tensor(value).clone()
            RASF_mesh[key] = value.to(self.field.device)

        vertices = RASF_mesh['V']
        query = torch.cat((points, vertices.to(points.dtype)), 0).unsqueeze(0)
        feature = self.batch_samples(query)[0].transpose(0, 1)  # num_samples + N, C

        # Keep only the rows belonging to the vertices.
        rasf_feature = feature[points.shape[0]:]
        RASF_mesh['V'] = torch.cat((vertices, rasf_feature), 1)
        return RASF_mesh

    def mesh_batch_inference(self, mesh_batch, num_samples=1000):
        """Apply ``mesh_inference`` to every mesh of an iterable.

        Args:
            mesh_batch: iterable of mesh dicts (see ``mesh_inference``).
            num_samples: surface samples per mesh, forwarded to ``mesh_inference``.

        Returns:
            List of embedded mesh dicts, in input order.
        """
        return [self.mesh_inference(mesh, num_samples=num_samples) for mesh in mesh_batch]

    def _voxel_kernel_extrema(self):
        """Sample the field on a cubic lattice spanning [-1, 1]^3.

        Returns:
            ``(kernel_max, kernel_min)``, each of shape ``(C,)``: the per-channel
            max and min of the sampled features over all lattice positions.
        """
        size = self._VOXEL_KERNEL_SIZE
        half = size // 2
        axis = (torch.arange(size) - half) / half  # size values evenly spanning [-1, 1]
        lattice = torch.cartesian_prod(axis, axis, axis).unsqueeze(0)  # 1, size**3, 3
        kernel = self.voxel_samples(lattice)[0, :, 0, :]  # C, size**3
        return kernel.max(dim=1).values, kernel.min(dim=1).values

    @staticmethod
    def _voxel_features(voxels, kernel_max, kernel_min):
        """Per-voxel feature: max over lattice positions of ``value * kernel_feature``.

        Every lattice position is scaled by the same voxel value, so that max is
        ``value * kernel_max`` for ``value >= 0`` and ``value * kernel_min``
        otherwise. This avoids materialising a ``(*shape, size**3, C)`` tensor.
        NOTE(review): the lattice never looks at neighbouring voxels; confirm this
        is the intended operator.

        Returns:
            Tensor of shape ``(*voxels.shape, C)``.
        """
        values = voxels.to(kernel_max.dtype).unsqueeze(-1)
        return torch.where(values >= 0, values * kernel_max, values * kernel_min)

    def voxels_inference(self, voxels):
        """Infer the shape embedding of a voxel grid.

        Args:
            voxels: tensor, numpy array or nested list of shape ``(X, Y, Z)``.

        Returns:
            Tensor of shape ``(C + 1, X, Y, Z)``: RASF feature channels followed by
            the input voxel values.
        """
        voxels = self._to_field_device(voxels)
        kernel_max, kernel_min = self._voxel_kernel_extrema()
        feature = self._voxel_features(voxels, kernel_max, kernel_min)  # X, Y, Z, C
        feature_vox = feature.permute(3, 0, 1, 2)  # C, X, Y, Z
        return torch.cat((feature_vox, voxels.unsqueeze(0)), 0)

    def voxels_batch_inference(self, voxels):
        """Infer the shape embedding of a batch of voxel grids.

        Args:
            voxels: tensor, numpy array or nested list of shape ``(B, X, Y, Z)``.

        Returns:
            Tensor of shape ``(B, C + 1, X, Y, Z)``: RASF feature channels followed
            by the input voxel values.
        """
        voxels = self._to_field_device(voxels)
        kernel_max, kernel_min = self._voxel_kernel_extrema()
        feature = self._voxel_features(voxels, kernel_max, kernel_min)  # B, X, Y, Z, C
        feature_vox = feature.permute(0, 4, 1, 2, 3)  # B, C, X, Y, Z
        return torch.cat((feature_vox, voxels.unsqueeze(1)), 1)
