import torch
from pytorch3d.io import load_obj
from pytorch3d.renderer.mesh import TexturesVertex
from pytorch3d.structures import Meshes
from torch import nn
from torch.nn import functional as F
import torch.nn.init as init
from torch_scatter import scatter_min

import itertools
import pickle

import numpy as np
import torchvision

from .texture import UNet, DUNet, HexPlaneField, DUNet1D, Texture_Decoder, DUNet1D_LSTM, DUNet1D_1DConv, MLP1D
from .strand_prior import Decoder

from torchvision.transforms import functional as TF
import sys

import accelerate
from copy import deepcopy
import os
import trimesh
import cv2
import pathlib
from tqdm import tqdm
import math
import time 


sys.path.append(os.path.join(sys.path[0], 'k-diffusion'))
from k_diffusion import config 

from src.utils.util import param_to_buffer, positional_encoding
from src.utils.geometry import barycentric_coordinates_of_projection, face_vertices
from src.utils.sample_points_from_meshes import sample_points_from_meshes, sample_points_from_allmeshes
from src.diffusion_prior.diffusion import make_denoiser_wrapper



def downsample_texture(rect_size, downsample_size, device="cuda"):
    """Pick one texel per patch of a square texture, at a shared random offset.

    The ``rect_size x rect_size`` texture is tiled into non-overlapping
    ``patch_size x patch_size`` patches, with
    ``patch_size = rect_size // downsample_size``.  A single random position
    inside a patch is drawn and the texel at that position is selected in
    every patch.

    Args:
        rect_size: Side length of the full-resolution texture.
        downsample_size: Side length of the downsampled texture; the result
            holds ``downsample_size ** 2`` texels.
        device: Device on which the index tensors are created.  Defaults to
            ``"cuda"``, which was previously hard-coded.

    Returns:
        Tuple ``(x, y)`` of ``long`` tensors of length
        ``downsample_size ** 2``: quotient and remainder of each selected
        flat texel index divided by ``rect_size``.
    """
    patch_size = rect_size // downsample_size
    num_patches = downsample_size ** 2

    # Flat texel index laid out as an image, so unfolding it yields, for each
    # patch, the flat indices of the texels that patch contains.
    flat_index = torch.linspace(
        0, rect_size ** 2 - 1, rect_size ** 2, device=device
    ).reshape(rect_size, rect_size)

    # Unfold has no parameters, so it needs no device placement of its own.
    unfold = torch.nn.Unfold(kernel_size=patch_size, stride=patch_size)
    # Rows: position inside a patch; columns: patch id.
    patches = unfold(flat_index[None, None]).reshape(-1, num_patches)

    # One in-patch position shared by all patches.
    offset = torch.randint(low=0, high=patch_size ** 2, size=(1,), device=device)
    rows = offset.repeat(num_patches)
    cols = torch.arange(num_patches, device=device)
    chosen = patches[rows, cols]

    x = chosen // rect_size
    y = chosen % rect_size
    return x.long(), y.long()


class OptimizableTexturedStrands(nn.Module):
    def __init__(self,
                 path_to_mesh,
                 num_strands,
                 max_num_strands,
                 texture_size,
                 geometry_descriptor_size,
                 appearance_descriptor_size,
                 decoder_checkpoint_path,
                 path_to_scale=None,
                 cut_scalp=None,
                 diffusion_cfg=None,
                 texture_hidden_config=None,
                 data_dir=None,
                 flame_mesh_dir=None,
                 num_guiding_strands=None,
                 start_time_step=-1,
                 num_time_steps=-1,
                 ):
        """Load the per-frame head meshes and build the strand networks.

        For every frame found under ``{flame_mesh_dir}/raw_data/eval_30/mesh``
        the head mesh is loaded, the scalp sub-mesh is extracted and a
        per-face local-to-world basis is computed.  Root points are then
        sampled on the scalp meshes and the feature grid, strand decoder,
        colour decoder and (optionally) the diffusion prior are created.

        Args:
            path_to_mesh: OBJ file used only when ``flame_mesh_dir`` is None.
                NOTE(review): the frame listing reads ``flame_mesh_dir``
                unconditionally, so that fallback looks unreachable in
                practice - confirm.
            num_strands: Number of strands interpolated per ``forward`` call.
            max_num_strands: Number of root points sampled on the scalp.
            texture_size: Side length of the UV grid (``mgrid``) and of the
                texture used for diffusion patching.
            geometry_descriptor_size: Number of leading feature channels fed
                to the strand decoder.
            appearance_descriptor_size: Number of remaining feature channels.
            decoder_checkpoint_path: NOTE(review): currently ignored - the
                path is overwritten with the bundled ``strand_ckpt.pth``.
            path_to_scale: Optional pickle holding ``'translation'`` and
                ``'scale'`` applied to the head vertices.
            cut_scalp: NOTE(review): currently ignored - it is reset to
                ``None`` below, which disables the cut-scalp branch.
            diffusion_cfg: Mapping read for ``'use_diffusion'`` and, when
                enabled, ``'model'``, ``'start_denoise'``, ``'diffuse_bs'``
                and optionally ``'diffuse_mask'``.  Must not be None.
            texture_hidden_config: Object providing ``texture_bounds``,
                ``texture_kplanes_config`` and ``texture_multires`` for the
                ``HexPlaneField``.  Must not be None.
            data_dir: Only tested against None; when set, per-frame mask
                files are looked up under ``flame_mesh_dir``.
            flame_mesh_dir: Root directory of the per-frame data.
            num_guiding_strands: Guiding strands are enabled when this is a
                positive number.
            start_time_step: First frame to use; ``-1`` means 0.
            num_time_steps: Number of frames to use; ``-1`` means all.
        """
        super().__init__()
        file_path = pathlib.Path(__file__).parent.resolve()
        scalp_vert_idx = torch.load(f'{file_path}/../../data/new_scalp_vertex_idx.pth').long().cuda()  # indices of scalp vertices
        scalp_faces = torch.load(f'{file_path}/../../data/new_scalp_faces.pth')[None].cuda()  # faces that form a scalp
        scalp_uvs_init = torch.load(f'{file_path}/../../data/improved_neural_haircut_uvmap.pth').cuda()[None]  # generated in Blender uv map for the scalp

        # Collect the frame names and order them by the trailing integer of
        # the file stem ("<anything>_<index>.obj").
        obj_path = f'{flame_mesh_dir}/raw_data/eval_30/mesh'
        frame_base_name = []
        frame_numbers = []
        for time_name in os.listdir(obj_path):
            if time_name.endswith('.obj'):
                frame_base_name.append(time_name.split('.')[0])
                frame_numbers.append(int(time_name.split('.')[0].split('_')[-1]))
        order = np.argsort(np.array(frame_numbers))
        frame_base_name = [frame_base_name[i] for i in order]

        # NOTE(review): this discards the caller's `cut_scalp`, so the
        # cut-scalp branch in the loop below never runs - confirm intended.
        cut_scalp = None

        self.start_time_step = 0 if start_time_step == -1 else start_time_step
        self.num_time_steps = len(frame_base_name) if num_time_steps == -1 else num_time_steps
        frame_base_name = frame_base_name[self.start_time_step:self.start_time_step + self.num_time_steps]

        verts_all = []
        faces_all = []
        head_mesh_all = []
        self.scale_decoder = []
        self.scalp_mesh = []
        self.scalp_uvs = []
        local2world = []

        # The world --> unit-sphere transform is the same for every frame, so
        # it is read once instead of once per frame.
        # NOTE: pickle.load executes arbitrary code; only use trusted files.
        self.transform = None
        if path_to_scale:
            with open(path_to_scale, 'rb') as f:
                self.transform = pickle.load(f)

        # Reference head extent; the decoder was pretrained on synthetic data
        # with a fixed head scale.
        usc_scale = torch.tensor([[0.2579, 0.4082, 0.2580]]).cuda()

        # Load FLAME head mesh
        print('Loading FLAME head mesh')
        for time_name in tqdm(frame_base_name, total=len(frame_base_name)):
            if flame_mesh_dir is not None:
                verts, faces, _ = load_obj(f'{flame_mesh_dir}/raw_data/eval_30/mesh/{time_name}.obj', device='cuda')
            else:
                verts, faces, _ = load_obj(path_to_mesh, device='cuda')
            # Transform head mesh if it's not in unit sphere (same scale used for world-->unit_sphere transform)
            if path_to_scale:
                verts = (verts - torch.tensor(self.transform['translation'], device=verts.device)) / self.transform['scale']
            head_mesh = Meshes(verts=[(verts)], faces=[faces.verts_idx]).cuda()

            # Cache a point-cloud sampling of the head next to the meshes.
            os.makedirs(f'{flame_mesh_dir}/raw_data/eval_30/point_cloud', exist_ok=True)
            ply_path = f'{flame_mesh_dir}/raw_data/eval_30/point_cloud/{time_name}.ply'
            if not os.path.exists(ply_path):
                num_samples = 200000
                pointcloud = sample_points_from_meshes(head_mesh, num_samples)[0][0].cpu().numpy()
                pc = trimesh.PointCloud(pointcloud)
                pc.export(ply_path)

            # Scaling factor between the reference head and this head.
            head_scale = head_mesh.verts_packed().max(0)[0] - head_mesh.verts_packed().min(0)[0]
            scale_decoder = (usc_scale / head_scale).mean()

            # Extract scalp mesh from head
            scalp_verts = head_mesh.verts_packed()[None, scalp_vert_idx]
            scalp_uvs = scalp_uvs_init
            scalp_mesh = Meshes(verts=scalp_verts, faces=scalp_faces, textures=TexturesVertex(scalp_uvs)).cuda()

            # If we want to use different scalp vertices for scene
            if cut_scalp:
                tqdm.write(f'Loading cut scalp for frame {time_name}')
                if data_dir is not None:
                    time_name = int(time_name.split('_')[-1])+1
                    time_name = f'frame_{time_name:05d}'
                    with open(f'{flame_mesh_dir}/preprocessed_data/scalp_data/cut_scalp_verts/{time_name}.pickle', 'rb') as f:
                        full_scalp_list = sorted(pickle.load(f))
                else:
                    with open(cut_scalp, 'rb') as f:
                        full_scalp_list = sorted(pickle.load(f))

                # Map original vertex ids to their position in the cut list.
                a = np.array(full_scalp_list)
                b = np.arange(a.shape[0])
                d = dict(zip(a, b))

                # Keep only faces whose three vertices all survive the cut.
                faces_masked = []
                for face in scalp_mesh.faces_packed():
                    if face[0] in full_scalp_list and face[1] in full_scalp_list and face[2] in full_scalp_list:
                        faces_masked.append(torch.tensor([d[int(face[0])], d[int(face[1])], d[int(face[2])]]))

                scalp_uvs = scalp_uvs[:, full_scalp_list]
                scalp_mesh = Meshes(verts=scalp_mesh.verts_packed()[None, full_scalp_list].float(), faces=torch.stack(faces_masked)[None].cuda(), textures=TexturesVertex(scalp_uvs)).cuda()

            scalp_mesh.textures = TexturesVertex(scalp_uvs)
            self.scale_decoder.append(scale_decoder)
            self.scalp_mesh.append(scalp_mesh)
            self.scalp_uvs.append(scalp_uvs)
            # init_scalp_basis reads self.scalp_mesh[-1], so it has to run
            # right after the append above.
            local2world.append(self.init_scalp_basis(scalp_uvs))
            verts_all.append(verts)
            faces_all.append(faces.verts_idx)
            head_mesh_all.append(head_mesh)

        # For 3D interpolation
        self.use_guiding_strands = num_guiding_strands is not None and num_guiding_strands > 0
        self.num_guiding_strands = num_guiding_strands if self.use_guiding_strands else 0

        self.num_strands = num_strands
        self.max_num_strands = max_num_strands
        self.geometry_descriptor_size = geometry_descriptor_size
        self.appearance_descriptor_size = appearance_descriptor_size

        # Bookkeeping for per-frame root selection and stored gradients.
        self.uvs_index_select = torch.arange(self.max_num_strands).cuda()
        self.uvs_index_active = torch.empty(0).cuda()
        self.uvs_index_select_sequence = [self.uvs_index_select for i in range(self.num_time_steps)]
        self.uvs_index_select_sequence_used = [False for i in range(self.num_time_steps)]
        self.uvs_index_select_sequence_used[0] = True
        self.uvs_index_select_time_index_last = 0
        self.grad_uvs_p_storage = None
        self.grad_point_p_storage = None

        # Dense UV grid in [-1, 1]^2 plus a constant zero time channel,
        # positionally encoded.
        mgrid = torch.stack(torch.meshgrid([torch.linspace(-1, 1, texture_size)]*2, indexing='xy'))[None].cuda()
        self.register_buffer('mgrid', mgrid)
        time_init = torch.tensor([0.0]).unsqueeze(0).unsqueeze(0).repeat(1,texture_size, texture_size)[None].cuda()
        mgrid_time = torch.cat([mgrid, time_init], dim=1)
        self.register_buffer('encoder_input', positional_encoding(mgrid_time,6))

        # Coarse grid of (res_init + 1)^2 nodes, all marked in the mask.
        res_init = 16
        dynamic_mgrid = torch.stack(torch.meshgrid([torch.linspace(-1, 1, res_init + 1)]*2, indexing='xy')).cuda()
        dynamic_mgrid_mask = torch.ones(res_init + 1, res_init + 1).int().cuda()
        self.register_buffer('res_init', torch.tensor(res_init).int())
        self.register_buffer('dynamic_mgrid', dynamic_mgrid)
        self.register_buffer('dynamic_mgrid_mask', dynamic_mgrid_mask)
        # NOTE(review): 'grid_active_set' is registered with the very same
        # tensor object as 'dynamic_mgrid_mask', so in-place writes to one are
        # visible through the other until one is re-registered - confirm.
        self.register_buffer('grid_active_set', dynamic_mgrid_mask)

        # Feature field over (u, v, t) followed by a network mapping its
        # features to geometry + appearance descriptors.
        self.feature_net_width = geometry_descriptor_size + appearance_descriptor_size
        self.feature_net_depth = 4
        self.grid = HexPlaneField(texture_hidden_config.texture_bounds, texture_hidden_config.texture_kplanes_config, texture_hidden_config.texture_multires)
        self.feature_out = DUNet1D(self.grid.feat_dim, self.feature_net_width, bilinear=True)
        xyz_max = np.array([1,1])
        xyz_min = np.array([-1,-1])
        self.grid.set_aabb(xyz_max, xyz_min)

        self.register_buffer('local2world', torch.stack(local2world, dim=0))

        # Sample fixed origin points
        origins, uvs, face_idx = sample_points_from_allmeshes(self.scalp_mesh, num_samples=max_num_strands, return_textures=True)
        self.register_buffer('origins', origins)
        self.register_buffer('uvs', uvs)
        self.idx = None
        self.uvs_select = None
        # Per-frame caches written by forward().
        self.p_local = [None] * self.num_time_steps
        self.features_dc = [None] * self.num_time_steps
        self.features_rest = [None] * self.num_time_steps
        self.orient_conf = [None] * self.num_time_steps
        self.idx_active_mask = None
        self.idx_mask = None

        # Get transforms for the samples: replace the per-face bases by the
        # basis of the face each sample was drawn from.
        batch_idx = torch.arange(face_idx.shape[0]).view(-1, 1).expand_as(face_idx)
        self.local2world.data = self.local2world[batch_idx, face_idx]

        # For uniform faces selection
        self.N_faces = self.scalp_mesh[0].faces_packed()[None].shape[1]

        # face id -> sample ids drawn from that face (first frame only).
        self.faces_dict = {}
        for idx, f in enumerate(face_idx[0].cpu().numpy()):
            try:
                self.faces_dict[f].append(idx)
            except KeyError:
                self.faces_dict[f] = [idx]
        idxes, counts = face_idx[0].unique(return_counts=True)
        self.faces_count_dict = dict(zip(idxes.cpu().numpy(), counts.cpu().numpy()))

        # Decoder predicts the strands from the embeddings
        # NOTE(review): this overrides the `decoder_checkpoint_path` argument.
        decoder_checkpoint_path = f'{file_path}/../../pretrained_models/strand_prior/strand_ckpt.pth'
        self.strand_length = 100
        self.strand_decoder = Decoder(None, latent_dim=geometry_descriptor_size, length=99).eval()
        self.strand_decoder.load_state_dict(torch.load(decoder_checkpoint_path)['decoder'])
        param_to_buffer(self.strand_decoder)
        self.max_sh_degree = 3
        # Per-point output: SH colour coefficients plus one extra channel
        # (forward() splits it off as the orientation confidence).
        self.color_decoder = Decoder(None, dim_hidden=128, num_layers=2, length=self.strand_length - 1, dim_out=3*(self.max_sh_degree+1)**2 + 1).cuda()

        # Diffusion prior model
        self.use_diffusion = diffusion_cfg['use_diffusion']
        if self.use_diffusion:
            ddp_kwargs = accelerate.DistributedDataParallelKwargs(find_unused_parameters=diffusion_cfg['model']['skip_stages'] > 0)
            self.accelerator = accelerate.Accelerator(kwargs_handlers=[ddp_kwargs], gradient_accumulation_steps=1)

            # Initialize diffusion model
            inner_model = config.make_model(diffusion_cfg)
            model = make_denoiser_wrapper(diffusion_cfg)(inner_model)
            self.model_ema = deepcopy(model).cuda()
            self.model_ema.eval()

            # Upload pretrained on synthetic data checkpoint
            diffusion_checkpoint_path = f'{file_path}/../../pretrained_models/diffusion_prior/wo_bug_blender_uv_00130000.pth'
            ckpt = torch.load(diffusion_checkpoint_path, map_location='cpu')
            self.accelerator.unwrap_model(self.model_ema.inner_model).load_state_dict(ckpt['model_ema'])
            param_to_buffer(self.model_ema)

            self.diffusion_input = diffusion_cfg['model']['input_size'][0]
            self.sample_density = config.make_sample_density(diffusion_cfg['model'])
            self.start_denoise = diffusion_cfg['start_denoise']
            self.diffuse_bs = diffusion_cfg['diffuse_bs']
            self.diffuse_mask = []

            for time_name in frame_base_name:
                # Load scalp mask for hairstyle
                if data_dir is not None:
                    time_name = int(time_name.split('_')[-1])+1
                    time_name = f'frame_{time_name:05d}'
                    diffuse_mask = f'{flame_mesh_dir}/preprocessed_data/scalp_data/dif_mask/{time_name}.png'
                else:
                    diffuse_mask = diffusion_cfg.get('diffuse_mask', None)

                # Check for a configured path first: os.path.exists(None)
                # raises TypeError, so the truthiness test must short-circuit.
                if diffuse_mask and os.path.exists(diffuse_mask):
                    diffuse_mask = torch.tensor(cv2.imread(diffuse_mask) / 255)[:, :, :1].squeeze(-1).cuda()
                else:
                    diffuse_mask = torch.ones(256, 256).cuda()
                self.diffuse_mask.append(diffuse_mask)

            self.rect_size = texture_size
            self.downsample_size = self.diffusion_input

            # Flat texel indices grouped by patch, same layout as used by
            # downsample_texture().
            b = torch.linspace(0, self.rect_size**2 - 1, self.rect_size**2, device="cuda").reshape(self.rect_size, self.rect_size)
            self.patch_size = self.rect_size // self.downsample_size
            unf = torch.nn.Unfold(
                kernel_size=self.patch_size,
                stride=self.patch_size).cuda()
            self.unfo = unf(b[None, None]).reshape(-1, self.downsample_size**2)  # all unfolds
        K = 4
        self.register_buffer('triu_indices', torch.triu_indices(K, K, offset=1))
    
    def load_state_dict(self, state_dict, strict=True):
        """Load a checkpoint whose dynamic-grid buffers may have a new shape.

        The dynamic-grid buffers can have a different shape in the checkpoint
        than in a freshly built module.  Any of them whose shape differs from
        the saved tensor is re-registered with the saved shape and dtype
        before the regular loading runs.

        Args:
            state_dict: Mapping of parameter/buffer names to tensors.
            strict: Forwarded to ``nn.Module.load_state_dict``.

        Returns:
            The result of ``nn.Module.load_state_dict`` (missing and
            unexpected keys), which was previously discarded.
        """
        resizable = ('res_init', 'dynamic_mgrid', 'dynamic_mgrid_mask', 'grid_active_set')
        for name in resizable:
            if name not in state_dict:
                continue
            saved = state_dict[name]
            current = getattr(self, name, None)
            if current is not None and current.shape == saved.shape:
                continue
            print(f"Resizing buffer '{name}' to shape {saved.shape}")
            # Keep the checkpoint's dtype: a default (float) placeholder would
            # make the integer masks come back as float after copy_.
            device = current.device if current is not None else saved.device
            self.register_buffer(name, torch.empty(saved.shape, dtype=saved.dtype, device=device))
        return super().load_state_dict(state_dict, strict=strict)
    def save_uvs_grad(self, grad):
        """Tensor hook: remember the latest gradient of the UV grid nodes.

        ``forward`` registers this on ``uvs_grid``.  Nothing is returned, so
        the gradient itself is passed on unmodified.
        """
        setattr(self, 'grad_uvs_p_storage', grad)
    def save_point_grad(self, grad):
        self.grad_point_p_storage = grad  
        # self.grad_uvs_p_storage = grad  
    def int_para_precomp(self,time_index):
        """Precompute bilinear interpolation data for every sampled root.

        For each root UV the enclosing cell is located on every resolution
        level ``res_init * 2**i``.  The finest level whose four corners are
        all set in ``dynamic_mgrid_mask`` is chosen, and the corners are
        translated into indices into the list of masked grid nodes.

        Results are stored on the module (nothing is returned):
            self.knn_idx: four node indices per root (N x 4).
            self.weights: matching bilinear weights, shaped N x 4 x 1 x 1
                (assuming ``self.uvs[0]`` is N x 2 - TODO confirm).

        NOTE(review): ``time_index`` is not used; the UVs of time step 0 are
        always taken.
        """
        with torch.no_grad():
            # All sampled roots, in order.
            idx_int = torch.arange(0, self.max_num_strands, device="cuda", dtype=torch.long)
            # UV coordinates of the masked grid nodes, one row per node.
            uvs_gdn = self.dynamic_mgrid[:,self.dynamic_mgrid_mask.bool()].reshape(2,-1).transpose(0,1)
            uvs_int = self.uvs[0][idx_int]
            uvs_mgrid = self.dynamic_mgrid.permute(1,2,0) # (res + 1) x (res + 1) x 2 after the permute
            uvs_mask = self.dynamic_mgrid_mask # (res + 1) x (res + 1)
            res_final = uvs_mgrid.shape[1] - 1
            res_init = self.res_init.item()
            # Number of levels from res_init up to res_final by doubling.
            M = math.ceil(math.log(res_final / res_init, 2)) + 1
            N = self.max_num_strands
            multi_res = [res_init * (2 ** i) for i in range(M)]
            # Map UVs from [-1, 1] to [0, 1], then to cell units on each level.
            uvs_select_ad = ((uvs_int + 1) * 0.5).unsqueeze(0)
            multi_res_ad = torch.tensor(multi_res).cuda()[:,None,None]
            uvs_select_ad =  uvs_select_ad * multi_res_ad
            multi_res_ad_grid_four = multi_res_ad[:,:,None,:].repeat(1,1,4,1)
            uvs_select_ad_grid = torch.floor(uvs_select_ad).int().cuda() # M x N x 2
            # Four corners of the enclosing cell: (0,0), (1,0), (0,1), (1,1).
            offsets = torch.tensor([[0, 0], [1, 0], [0, 1], [1, 1]], dtype=torch.int32).view(1, 1, 4, 2).cuda()
            uvs_select_ad_grid_four = uvs_select_ad_grid.unsqueeze(2) + offsets
            # Clamp corners that would fall past the last node of the level.
            uvs_select_ad_grid_four = torch.minimum(uvs_select_ad_grid_four, multi_res_ad_grid_four)
            # Convert per-level corner indices to indices on the finest grid.
            scale_factors = torch.tensor([res_final // res for res in multi_res],
                                device=uvs_select_ad_grid_four.device,
                                dtype=torch.int32)[:, None, None, None]
            uvs_select_ad_grid_four_high_res = (uvs_select_ad_grid_four * scale_factors).int()
            # NOTE(review): the mask is indexed as [component 0, component 1],
            # while the flattened index further down is
            # component 0 + component 1 * (res_final + 1), i.e. the opposite
            # axis order.  The two agree only for a symmetric mask - confirm.
            mask_vals = uvs_mask[uvs_select_ad_grid_four_high_res[..., 0],
                        uvs_select_ad_grid_four_high_res[..., 1]]  # shape: (M, N, 4)
            # 1 where all four corners of the cell are set on that level.
            uvs_select_ad_grid_four_high_res_mask = torch.all(mask_vals, dim=-1).transpose(0, 1).int()
            # uvs_select_ad_grid_four_high_res_mask : N x M; uvs_res_select: N
            # argmax on the flipped levels picks the finest qualifying level.
            uvs_res_select = M-1-torch.argmax(uvs_select_ad_grid_four_high_res_mask.flip(dims=[1]), dim=1)
            uvs_select_ad_grid_four_high_res_select = uvs_select_ad_grid_four_high_res[uvs_res_select, torch.arange(N).cuda()]
            # Flatten the 2D corner indices on the finest grid.
            uvs_select_four_select = (uvs_select_ad_grid_four_high_res_select[:, :, 0] +
                            uvs_select_ad_grid_four_high_res_select[:, :, 1] * (res_final + 1)).int().cuda()
            # Flat grid index -> position among the masked nodes; -1 elsewhere.
            # NOTE(review): a -1 entry would index the last node in the
            # lookups below - presumably the selected corners are always
            # masked; verify.
            uvs_mask = uvs_mask.view(-1) 
            selected_idx = torch.nonzero(uvs_mask).cuda().reshape(-1)
            mapping = torch.full(uvs_mask.shape, -1).cuda()
            mapping[selected_idx] = torch.arange(len(selected_idx)).cuda()
            knn_idx = mapping[uvs_select_four_select] # N x 4
            # Bilinear weights in [0, 1] UV space; corner 0 supplies the
            # minima and corner 3 the maxima of the cell.
            uvs_int_norm = (uvs_int + 1) * 0.5
            uvs_gdn_norm = (uvs_gdn + 1) * 0.5
            u = uvs_int_norm[:, 0]
            v = uvs_int_norm[:, 1]
            uv_max_min = uvs_gdn_norm[knn_idx]
            u_min = uv_max_min[:,0,0]
            u_max = uv_max_min[:,3,0]
            v_min = uv_max_min[:,0,1]
            v_max = uv_max_min[:,3,1]
            # The epsilon keeps the division finite for zero-area cells.
            denom = (u_max - u_min) * (v_max - v_min) + 1e-8
            w11 = torch.clamp((u_max - u) * (v_max - v) / denom, 0, 1).unsqueeze(-1).unsqueeze(-1)
            w21 = torch.clamp((u - u_min) * (v_max - v) / denom, 0, 1).unsqueeze(-1).unsqueeze(-1)
            w12 = torch.clamp((u_max - u) * (v - v_min) / denom, 0, 1).unsqueeze(-1).unsqueeze(-1)
            w22 = torch.clamp((u - u_min) * (v - v_min) / denom, 0, 1).unsqueeze(-1).unsqueeze(-1)
            # Same corner order as `offsets`.
            self.weights = torch.cat([w11, w21, w12, w22], dim=1).unsqueeze(-1)
            self.knn_idx = knn_idx
        
        
    def init_scalp_basis(self, scalp_uvs):
        """Build a local-to-world rotation for every face of the latest scalp.

        The basis of each face consists of a tangent (the 3D direction that
        corresponds to a small step along the first UV coordinate), a
        bitangent (normal x tangent) and the face normal.  The mesh used is
        ``self.scalp_mesh[-1]``, so this must be called right after the mesh
        of the current frame has been appended.

        Args:
            scalp_uvs: Per-vertex UV coordinates; ``scalp_uvs[0]`` is indexed
                with the face vertex indices.

        Returns:
            Tensor with one 3x3 matrix per face: the inverse of the matrix
            whose rows are (tangent, bitangent, normal).
        """
        scalp_mesh = self.scalp_mesh[-1]
        scalp_verts = scalp_mesh.verts_packed()[None]
        scalp_faces = scalp_mesh.faces_packed()[None]

        # Define normal axis.  Normalise out of place: the tensor returned by
        # faces_normals_packed() belongs to the mesh and must not be modified.
        origin_n = scalp_mesh.faces_normals_packed()
        origin_n = origin_n / origin_n.norm(dim=-1, keepdim=True)

        # Define tangent axis: move the UV centroid of each face slightly
        # along the first UV coordinate, express the moved point in
        # barycentric coordinates and carry those over to the 3D triangle.
        full_uvs = scalp_uvs[0][scalp_faces[0]]
        bs = full_uvs.shape[0]
        concat_full_uvs = torch.cat((full_uvs, torch.zeros(bs, full_uvs.shape[1], 1, device=full_uvs.device)), -1)
        new_point = concat_full_uvs.mean(1).clone()
        new_point[:, 0] += 0.001
        bary_coords = barycentric_coordinates_of_projection(new_point, concat_full_uvs).unsqueeze(1)
        full_verts = scalp_verts[0][scalp_faces[0]]
        origin_t = (bary_coords @ full_verts).squeeze(1) - full_verts.mean(1)
        origin_t = origin_t / origin_t.norm(dim=-1, keepdim=True)

        # The shifted point must stay strictly inside every triangle.
        assert torch.where((bary_coords.reshape(-1, 3) > 0).sum(-1) != 3)[0].shape[0] == 0

        # Define bitangent axis
        origin_b = torch.cross(origin_n, origin_t, dim=-1)
        origin_b = origin_b / origin_b.norm(dim=-1, keepdim=True)

        # Construct transform from global to local (for each point)
        R = torch.stack([origin_t, origin_b, origin_n], dim=1)

        # local to global
        R_inv = torch.linalg.inv(R)

        return R_inv
        
    def forward(self, it=None, time_index=0, training=True):
        """Decode guiding strands on the grid and interpolate sampled strands.

        Features are queried at the masked nodes of the dynamic UV grid for
        the given time step and decoded into local strand points and
        per-point appearance.  A random subset of ``num_strands`` roots is
        then interpolated from its four guiding nodes, using ``self.knn_idx``
        and ``self.weights`` (set by ``int_para_precomp``), and moved to
        world space.

        Side effects: stores detached copies of the decoded guiding data in
        ``self.p_local`` / ``features_dc`` / ``features_rest`` /
        ``orient_conf`` at ``time_index``, sets ``self.idx_mask``, and
        registers gradient hooks on the grid UVs and the local points.

        Args:
            it: Iteration counter, only used for the optional timing
                printout; may be None.
            time_index: Index of the time step to evaluate.
            training: NOTE(review): currently ignored - it is overwritten
                with True below, so the non-training branch never runs.

        Returns:
            Tuple ``(p, uvs, local2world, p_local, features_dc,
            features_rest, orient_conf, diffusion_dict)`` for the sampled
            roots; ``diffusion_dict`` is always empty here.

        Raises:
            NotImplementedError: If guiding strands are disabled; the
                original code failed with an unbound-local error there.
        """
        # Flip to True to print a per-stage timing breakdown every 10 calls.
        profile_timing = False

        diffusion_dict = {}
        num_strands = self.num_guiding_strands if self.use_guiding_strands else self.num_strands
        time1 = time.time()
        # UVs of the masked grid nodes, one row per node.
        uvs_grid = self.dynamic_mgrid[:, self.dynamic_mgrid_mask.bool()].reshape(2, -1).transpose(0, 1)
        # Create index tensors on the grid's device instead of a fixed "cuda".
        device = uvs_grid.device
        uvs_grid.requires_grad_(True)
        uvs_grid.register_hook(self.save_uvs_grad)
        time_val = time_index / self.num_time_steps
        training = True
        if training:
            grid_time = uvs_grid.new_full((uvs_grid.shape[0], 1), time_val)
            grid_feature = self.grid(uvs_grid, grid_time)
            z = self.feature_out(grid_feature)
        else:
            grid_time = self.uvs[time_index].new_full((self.uvs[time_index].shape[0], 1), time_val)
            grid_feature = self.grid(self.uvs[time_index], grid_time)
            idx = torch.randperm(self.max_num_strands, device=device)[:num_strands]
            z_full = self.feature_out(grid_feature)
            z = z_full[idx]

        # Split the descriptor into geometry and appearance channels.
        z_geom = z[:, :self.geometry_descriptor_size]
        if self.appearance_descriptor_size:
            z_app = z[:, self.geometry_descriptor_size:]
        else:
            z_app = None
        time2 = time.time()

        # Decode strands: the decoder output is accumulated into points,
        # starting from the origin.
        v = self.strand_decoder(z_geom) / self.scale_decoder[time_index]  # [num_strands, strand_length - 1, 3]
        p_local = torch.cat([
                torch.zeros_like(v[:, -1:, :]),
                torch.cumsum(v, dim=1)
            ],
            dim=1
        )
        p_local.requires_grad_(True)
        p_local.register_hook(self.save_point_grad)
        time3 = time.time()

        # The first appearance channel is not passed to the colour decoder.
        features_dc, features_rest, orient_conf = self.color_decoder(z_app[:, 1:]).split([3, 3 * ((self.max_sh_degree + 1) ** 2 - 1), 1], dim=-1)
        time4 = time.time()

        # Cache detached copies for this time step.
        self.p_local[time_index] = p_local.detach()
        self.features_dc[time_index] = features_dc.detach()
        self.features_rest[time_index] = features_rest.detach()
        self.orient_conf[time_index] = orient_conf.detach()

        if not self.use_guiding_strands:
            raise NotImplementedError(
                "forward() requires guiding strands (num_guiding_strands > 0); "
                "the path without guiding strands is not implemented."
            )

        # Pick a random subset of roots and gather their interpolation data.
        idx_int = torch.randperm(self.max_num_strands, device=device)[:self.num_strands]
        knn_idx = self.knn_idx[idx_int]
        weights = self.weights[idx_int]
        origins_int = self.origins[time_index][idx_int]
        uvs_int = self.uvs[time_index][idx_int]
        local2world_int = self.local2world[time_index][idx_int]

        # Weighted sum over the four guiding nodes of every root.
        p_local_int = (weights * p_local[knn_idx]).sum(dim=1)
        features_dc_int = (weights * features_dc[knn_idx]).sum(dim=1)
        features_rest_int = (weights * features_rest[knn_idx]).sum(dim=1)
        orient_conf_int = (weights * orient_conf[knn_idx]).sum(dim=1)

        # Grid nodes that contributed to at least one sampled root.
        idx_active = torch.arange(uvs_grid.shape[0], device=device)
        self.idx_mask = torch.isin(idx_active, knn_idx)

        # Rotate into world space and translate to the root position.
        p = (local2world_int[:, None] @ p_local_int[..., None])[:, :, :3, 0] + origins_int[:, None]  # [num_strands, strand_length, 3]

        time5 = time.time()
        # `it` is tested for None first: the previous `it % 10` was evaluated
        # even though printing was disabled and failed for the default it=None.
        if profile_timing and it is not None and it % 10 == 0:
            time_mlp = time2 - time1
            time_strand_decoder = time3 - time2
            time_color_decoder = time4 - time3
            time_int = time5 - time4
            time_total = time5 - time1
            print("coarse")
            print(f"time_total: {time_total}")
            print(f"time_mlp: {time_mlp}, time_strand_decoder: {time_strand_decoder}, time_color_decoder: {time_color_decoder}, time_int: {time_int}")
            print(f"time_mlp_scale: {time_mlp/time_total}, time_strand_decoder_scale: {time_strand_decoder/time_total}, time_color_decoder_scale: {time_color_decoder/time_total}, time_int_scale: {time_int/time_total}")

        return p, uvs_int, local2world_int, p_local_int, features_dc_int, features_rest_int, orient_conf_int, diffusion_dict
    def forward_sparse(self, it=None, time_index=0, training=True):
        """Re-decode only the active grid cells and interpolate sampled strands.

        Cells selected by ``self.grid_active_set`` are decoded from the grid
        features at ``time_index``.  The remaining cells of
        ``self.dynamic_mgrid_mask`` reuse the values cached for ``time_index``
        in ``self.p_local``, ``self.features_dc``, ``self.features_rest`` and
        ``self.orient_conf``.  The merged per-cell tensors are written back to
        those caches and then blended with ``self.knn_idx`` / ``self.weights``
        for a random subset of ``self.num_strands`` strand roots.

        Side effects: updates the four caches listed above at ``time_index``
        and sets ``self.idx_active_mask`` (True for every sampled strand that
        has at least one freshly decoded neighbour).

        Args:
            it: Unused; kept so existing callers can keep passing the
                iteration counter.  ``None`` is accepted.
            time_index: Index of the time step; it is divided by
                ``self.num_time_steps`` to build the time input of the grid.
            training: Unused in this method.

        Returns:
            Tuple ``(p, uvs_int, local2world_int, p_local_int,
            features_dc_int, features_rest_int, orient_conf_int,
            diffusion_dict)``.  ``diffusion_dict`` is always an empty dict.

        Raises:
            NotImplementedError: If ``self.use_guiding_strands`` is falsy;
                this method only implements the guiding-strand interpolation.
        """
        diffusion_dict = {}

        uvs_mask = self.dynamic_mgrid_mask
        mask_dynamic = uvs_mask.bool()
        mask_active = self.grid_active_set.bool()
        # With at most one active cell, refresh every dynamic cell instead.
        if mask_active.sum() <= 1:
            mask_active = mask_dynamic

        # mapping[r, c] is the row of grid cell (r, c) inside the packed
        # per-cell tensors (row-major order of the dynamic mask), -1 elsewhere.
        selected_idx = torch.nonzero(uvs_mask).cuda()
        mapping = torch.full(uvs_mask.shape, -1).cuda()
        mapping[selected_idx[:, 0], selected_idx[:, 1]] = torch.arange(selected_idx.shape[0]).cuda()
        idx_active = mapping[mask_dynamic & mask_active]
        idx_wo_active = mapping[mask_dynamic & ~mask_active]

        # NOTE(review): the UVs are gathered with mask_active alone, while
        # idx_active uses mask_dynamic & mask_active.  index_copy_ below needs
        # both to have the same length, so this presumably relies on the active
        # set being a subset of the dynamic mask -- confirm.
        uvs_grid = self.dynamic_mgrid[:, mask_active].reshape(2, -1).transpose(0, 1)
        time_val = time_index / self.num_time_steps
        grid_time = uvs_grid.new_full((uvs_grid.shape[0], 1), time_val)
        grid_feature = self.grid(uvs_grid, grid_time)
        z = self.feature_out(grid_feature)
        z_geom = z[:, :self.geometry_descriptor_size]
        if self.appearance_descriptor_size:
            z_app = z[:, self.geometry_descriptor_size:]
        else:
            # NOTE(review): z_app is subscripted right below, so a falsy
            # appearance_descriptor_size fails with a TypeError -- confirm that
            # this configuration is never used with this method.
            z_app = None

        # Decode strands: the decoder output is treated as per-segment offsets
        # that are accumulated along dim 1, with a zero point prepended.
        v = self.strand_decoder(z_geom) / self.scale_decoder[time_index]
        p_local_active = torch.cat(
            [
                torch.zeros_like(v[:, -1:, :]),
                torch.cumsum(v, dim=1),
            ],
            dim=1,
        )
        features_dc_active, features_rest_active, orient_conf_active = self.color_decoder(z_app[:, 1:]).split(
            [3, 3 * ((self.max_sh_degree + 1) ** 2 - 1), 1], dim=-1
        )

        # Scatter the fresh values into full-size tensors covering every
        # dynamic cell; rows of non-active cells are filled from the caches.
        num_dynamic_grid = int(self.dynamic_mgrid_mask.sum().item())
        device = p_local_active.device
        merged = []
        for active_vals in (p_local_active, features_dc_active, features_rest_active, orient_conf_active):
            full = torch.empty((num_dynamic_grid, *active_vals.shape[1:]), device=device)
            full.index_copy_(0, idx_active, active_vals)
            merged.append(full)
        p_local, features_dc, features_rest, orient_conf = merged

        caches = (self.p_local, self.features_dc, self.features_rest, self.orient_conf)
        with torch.no_grad():
            for full, cache in zip(merged, caches):
                full.index_copy_(0, idx_wo_active, cache[time_index][idx_wo_active])
                cache[time_index] = full

        if self.use_guiding_strands:
            idx_int = torch.randperm(self.max_num_strands, device='cuda')[:self.num_strands]
            knn_idx = self.knn_idx[idx_int]
            weights = self.weights[idx_int]
            origins_int = self.origins[time_index][idx_int]
            uvs_int = self.uvs[time_index][idx_int]
            local2world_int = self.local2world[time_index][idx_int]
            # Weighted sum over the neighbour axis (dim 1) of each sampled strand.
            p_local_int = (weights * p_local[knn_idx]).sum(dim=1)
            features_dc_int = (weights * features_dc[knn_idx]).sum(dim=1)
            features_rest_int = (weights * features_rest[knn_idx]).sum(dim=1)
            orient_conf_int = (weights * orient_conf[knn_idx]).sum(dim=1)
            self.idx_active_mask = torch.any(torch.isin(knn_idx, idx_active), dim=-1)
        else:
            # Without guiding strands none of the *_int tensors exist; fail
            # with a clear message instead of an UnboundLocalError below.
            raise NotImplementedError(
                "forward_sparse requires use_guiding_strands to be enabled"
            )

        # Local strand points -> world space, then translate by the strand root.
        p = (local2world_int[:, None] @ p_local_int[..., None])[:, :, :3, 0] + origins_int[:, None]
        return p, uvs_int, local2world_int, p_local_int, features_dc_int, features_rest_int, orient_conf_int, diffusion_dict

    def forward_inference(self, num_strands, time_index=0):
        """Decode all dynamic grid cells and bilinearly interpolate strands.

        Every cell of ``self.dynamic_mgrid_mask`` is decoded at ``time_index``.
        For each sampled strand root the finest grid level whose four
        surrounding corners all lie inside the mask is selected, and the
        decoded corner values are blended with clamped bilinear weights.

        Side effects: reseeds the global torch RNG with ``0`` on every call,
        caches the sampled root indices in ``self.idx`` on the first call, and
        stores the sampled UVs in ``self.uvs_select``.

        Args:
            num_strands: Number of strand roots to sample.  Only used when
                ``self.idx`` is still ``None``; afterwards the cached indices
                are reused and their length decides how many strands are
                produced.
            time_index: Index of the time step; it is divided by
                ``self.num_time_steps`` to build the time input of the grid.

        Returns:
            Tuple ``(p, uvs_int, local2world_int, p_local_int,
            features_dc_int, features_rest_int, orient_conf_int, z_app)``.
        """
        # Sample from the fixed origins (same permutation on every first call).
        torch.manual_seed(0)
        if self.idx is None:
            idx = torch.randperm(self.max_num_strands, device='cuda')[:num_strands]
            self.idx = idx
        else:
            idx = self.idx

        uvs_grid = self.dynamic_mgrid[:, self.dynamic_mgrid_mask.bool()].reshape(2, -1).transpose(0, 1)
        time_val = time_index / self.num_time_steps
        grid_time = uvs_grid.new_full((uvs_grid.shape[0], 1), time_val)
        grid_feature = self.grid(uvs_grid, grid_time)
        z = self.feature_out(grid_feature)

        z_geom = z[:, :self.geometry_descriptor_size]
        # Decoder output is treated as per-segment offsets that are accumulated
        # along dim 1, with a zero point prepended.
        seg = self.strand_decoder(z_geom) / self.scale_decoder[time_index]
        p_local = torch.cat(
            [
                torch.zeros_like(seg[:, -1:, :]),
                torch.cumsum(seg, dim=1),
            ],
            dim=1,
        )

        if self.appearance_descriptor_size:
            z_app = z[:, self.geometry_descriptor_size:]
        else:
            # NOTE(review): z_app is subscripted right below, so a falsy
            # appearance_descriptor_size fails with a TypeError -- confirm that
            # this configuration is never used with this method.
            z_app = None
        features_dc, features_rest, orient_conf = self.color_decoder(z_app[:, 1:]).split(
            [3, 3 * ((self.max_sh_degree + 1) ** 2 - 1), 1], dim=-1
        )

        uvs_gdn = uvs_grid
        origins_int = self.origins[time_index][idx]
        uvs_int = self.uvs[time_index][idx]
        local2world_int = self.local2world[time_index][idx]
        self.uvs_select = uvs_int

        uvs_mask = self.dynamic_mgrid_mask  # (res + 1) x (res + 1)
        res_final = self.dynamic_mgrid.shape[2] - 1
        res_init = self.res_init.item()
        # Levels res_init * 2**i, up to (at least) the final resolution.
        M = math.ceil(math.log(res_final / res_init, 2)) + 1
        # Use the number of roots actually sampled: the cached self.idx may
        # have a different length than the num_strands passed to this call.
        N = idx.shape[0]
        multi_res = [res_init * (2 ** i) for i in range(M)]

        # Scale the UVs from [-1, 1] to [0, res] at every level.
        uvs_select_ad = ((uvs_int + 1) * 0.5).unsqueeze(0)
        multi_res_ad = torch.tensor(multi_res).cuda()[:, None, None]
        uvs_select_ad = uvs_select_ad * multi_res_ad
        multi_res_ad_grid_four = multi_res_ad[:, :, None, :].repeat(1, 1, 4, 1)
        uvs_select_ad_grid = torch.floor(uvs_select_ad).int().cuda()  # M x N x 2

        # Four corners of the enclosing cell, clamped to the level resolution.
        offsets = torch.tensor([[0, 0], [1, 0], [0, 1], [1, 1]], dtype=torch.int32).view(1, 1, 4, 2).cuda()
        uvs_select_ad_grid_four = uvs_select_ad_grid.unsqueeze(2) + offsets  # M x N x 4 x 2
        uvs_select_ad_grid_four = torch.minimum(uvs_select_ad_grid_four, multi_res_ad_grid_four)

        # Express the corner coordinates of every level on the finest grid.
        scale_factors = torch.tensor(
            [res_final // res for res in multi_res],
            device=uvs_select_ad_grid_four.device,
            dtype=torch.int32,
        )[:, None, None, None]
        corners_high_res = (uvs_select_ad_grid_four * scale_factors).int()
        mask_vals = uvs_mask[corners_high_res[..., 0], corners_high_res[..., 1]]  # M x N x 4
        level_valid = torch.all(mask_vals, dim=-1).transpose(0, 1).int()  # N x M

        # Finest level whose four corners are all in the mask.  argmax returns
        # the first maximum, so on the flipped axis it finds the highest valid
        # level; when no level is valid it returns 0, i.e. level M - 1.
        uvs_res_select = M - 1 - torch.argmax(level_valid.flip(dims=[1]), dim=1)  # N
        corners_select = corners_high_res[uvs_res_select, torch.arange(N).cuda()]  # N x 4 x 2

        # NOTE(review): the mask lookup above uses component 0 as the first
        # (row) index, whereas this flattened index uses component 1 as the
        # row -- confirm the axis convention of dynamic_mgrid.
        flat_corner_idx = (corners_select[:, :, 0] + corners_select[:, :, 1] * (res_final + 1)).int().cuda()

        # Map flattened grid positions to rows of the packed per-cell tensors.
        # NOTE(review): positions outside the mask map to -1, which would
        # silently index the last row -- presumably never hit; confirm.
        uvs_mask = uvs_mask.view(-1)
        selected_idx = torch.nonzero(uvs_mask).cuda().reshape(-1)
        mapping = torch.full(uvs_mask.shape, -1).cuda()
        mapping[selected_idx] = torch.arange(len(selected_idx)).cuda()
        knn_idx = mapping[flat_corner_idx]  # N x 4

        # Clamped bilinear weights inside the cell spanned by corner 0 (min)
        # and corner 3 (max); the epsilon guards degenerate cells.
        uvs_int_norm = (uvs_int + 1) * 0.5
        uvs_gdn_norm = (uvs_gdn + 1) * 0.5
        u = uvs_int_norm[:, 0]
        v = uvs_int_norm[:, 1]
        uv_max_min = uvs_gdn_norm[knn_idx]
        u_min = uv_max_min[:, 0, 0]
        u_max = uv_max_min[:, 3, 0]
        v_min = uv_max_min[:, 0, 1]
        v_max = uv_max_min[:, 3, 1]
        denom = (u_max - u_min) * (v_max - v_min) + 1e-8
        w11 = torch.clamp((u_max - u) * (v_max - v) / denom, 0, 1).unsqueeze(-1).unsqueeze(-1)
        w21 = torch.clamp((u - u_min) * (v_max - v) / denom, 0, 1).unsqueeze(-1).unsqueeze(-1)
        w12 = torch.clamp((u_max - u) * (v - v_min) / denom, 0, 1).unsqueeze(-1).unsqueeze(-1)
        w22 = torch.clamp((u - u_min) * (v - v_min) / denom, 0, 1).unsqueeze(-1).unsqueeze(-1)
        # Corner order matches `offsets`: (0,0), (1,0), (0,1), (1,1).
        weights = torch.cat([w11, w21, w12, w22], dim=1).unsqueeze(-1)

        p_local_int = (weights * p_local[knn_idx]).sum(dim=1)
        features_dc_int = (weights * features_dc[knn_idx]).sum(dim=1)
        features_rest_int = (weights * features_rest[knn_idx]).sum(dim=1)
        orient_conf_int = (weights * orient_conf[knn_idx]).sum(dim=1)

        # Local strand points -> world space, then translate by the strand root.
        p = (local2world_int[:, None] @ p_local_int[..., None])[:, :, :3, 0] + origins_int[:, None]
        return p, uvs_int, local2world_int, p_local_int, features_dc_int, features_rest_int, orient_conf_int, z_app