from time import sleep
from typing import Any
import torch
import torch.nn as nn
import torch.nn.functional as TF
import numpy as np
import torchvision
from tqdm import tqdm

from simple_knn._C import group2DistCUDA2, group2IndexCUDA, indexCUDA, distCUDA2
from scene.gaussian_model import GaussianModel
from scene.cameras import CameraProvider
from utils.general_utils import torch_percentile,get_minimum_axis, flip_align_view
from utils.graphics_utils import transform_xyz
from utils.loss_utils import l1_loss, l1_loss_sum, l2_loss, l2_loss_sum, l2_loss_weighted
from utils.sh_rotation_utils import rotate_sh_by_quaternion
from utils.sh_utils import eval_sh

class torch_yes_grad(object):
    r"""
    Context manager that switches autograd ON for the enclosed block and puts
    the caller's previous grad mode back on exit (the mirror of ``torch.no_grad``).

    ``with torch_yes_grad() as x:`` binds ``x`` to None.
    """

    def __enter__(self):
        # Remember the caller's mode before overriding it.
        self.prev = torch.is_grad_enabled()
        torch.set_grad_enabled(mode=True)
        return None

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        # Always restore; returning None lets any exception propagate.
        torch.set_grad_enabled(mode=self.prev)

class FreqEncoder_torch(nn.Module):
    r"""
    Sinusoidal (Fourier-feature) encoding of the last dimension of a tensor.

    ``forward`` maps ``(..., input_dim)`` to ``(..., output_dim)``. It concatenates
    the raw input (when ``include_input``) followed by ``fn(input * f)`` for every
    frequency ``f`` and every ``fn`` in ``periodic_fns``, frequency-major.

    Frequencies are ``2 ** linspace(0, max_freq_log2, N_freqs)`` when
    ``log_sampling`` is set, otherwise ``linspace(1, 2 ** max_freq_log2, N_freqs)``.
    """

    def __init__(self, input_dim, max_freq_log2, N_freqs,
                 log_sampling=True, include_input=True,
                 periodic_fns=(torch.sin, torch.cos)):
        super().__init__()

        self.input_dim = input_dim
        self.include_input = include_input
        self.periodic_fns = periodic_fns

        # Identity term (optional) plus one input_dim-wide block per (frequency, fn) pair.
        passthrough_dim = input_dim if include_input else 0
        self.output_dim = passthrough_dim + input_dim * N_freqs * len(periodic_fns)

        if log_sampling:
            bands = 2. ** torch.linspace(0., max_freq_log2, N_freqs)
        else:
            bands = torch.linspace(2. ** 0., 2. ** max_freq_log2, N_freqs)
        # Kept as plain Python floats, not a tensor/buffer.
        self.freq_bands = bands.tolist()

    def forward(self, input, **kwargs):
        features = [input] if self.include_input else []
        features.extend(
            fn(input * freq)
            for freq in self.freq_bands
            for fn in self.periodic_fns
        )
        return torch.cat(features, dim=-1)

def global_feature_detection(gs,cam_sample_fn,render_fn):
    r"""
    Detect the dominant colors of ``gs`` by clustering colors seen from many views.

    Each view's colors either match an existing center or, when far from all of
    them, seed N_BIN new centers via k-means. Centers are then nudged towards the
    mean of their votes. Finally, the least-voted 10% of centers are pruned.

    Params:
        gs: Gaussian model to analyse.
        cam_sample_fn: zero-argument callable returning a camera to render from.
        render_fn: ``render_fn(cam, gs)`` -> dict with 'render' and 'alpha'
            images (channel-first, judging by the ``permute(1,2,0)`` below).
    Return:
        cluster_centers: (M,3) tensor of surviving color centers.
        k_bin_percent: numpy array (M,), each center's share of votes (sums to 1).
        k_bin_loss_weight: list of M per-center loss weights (currently equal to
            ``k_bin_percent``).
    Side effects: writes debug swatches to ``./.cache/``.
    """
    import os
    from kmeans_pytorch import kmeans
    from simple_cuda_utils._C import vote_to_bin_CUDA

    def log_color_table(postfix:str):
        # Debug swatch: one horizontal stripe per current center. Rows left over
        # when 400 is not a multiple of the center count keep their random init.
        cluster_centers_img = torch.randn((3,400,50),device=cluster_centers.device)
        GAP = int(400/cluster_centers.shape[0])
        for i,c in enumerate(cluster_centers):
            cluster_centers_img[:,i*GAP:(i+1)*GAP,:] = c[:,None,None].repeat(1,GAP,50)
        torchvision.utils.save_image(cluster_centers_img,f"./.cache/cluster_centers_img_{postfix}.png")

    # save_image fails if the debug directory does not exist yet.
    os.makedirs("./.cache", exist_ok=True)

    B = gs.get_xyz.shape[0]
    N_VIEW = 200
    SCREEN_NOSPATIAL=True # True: cluster rendered opaque pixels; False: cluster SH colors of random points
    N_BIN = 3 # number of centers spawned by each k-means call
    cluster_centers = None # (M,3), grows as new colors show up
    count_vote_number = [] # per-center vote totals, grows alongside cluster_centers
    for vid in tqdm(range(N_VIEW),desc="collection global feature"):
        mini_cam = cam_sample_fn()
        if SCREEN_NOSPATIAL:
            out = render_fn(mini_cam,gs)
            out_rgb = out['render'].permute(1,2,0)
            out_mask = (out['alpha']>0.9).permute(1,2,0)
            rgb = out_rgb[out_mask.repeat(1,1,3)].view(-1,3) #(H,W,3)->(?,3), opaque pixels only
        else:
            # `size` must be a tuple; `(5000)` is just the int 5000.
            m = torch.randint(0,B,(5000,),device=gs.get_xyz.device)
            fea = gs.get_features[m]
            xyz = gs.get_xyz[m]
            viewdir = xyz - mini_cam.camera_center[None,:]
            sh2rgb = eval_sh(gs.active_sh_degree, fea.transpose(1, 2), viewdir/viewdir.norm(dim=1, keepdim=True))
            rgb = torch.clamp_min(sh2rgb + 0.5, 0.0) # (5000,3)

        if vid == 0:
            _, cluster_centers = kmeans(
                X=rgb, num_clusters=N_BIN, distance='euclidean', device=torch.device('cuda:0'),
                tqdm_flag=False
            )#(N_BIN,3)
            log_color_table(vid)
            continue

        offset = ((rgb[:,None,:]-cluster_centers[None,:,:])**2).sum(dim=-1) # (?,M) squared distances
        min_offset,_ = offset.min(dim=-1) # (?,)
        # Colors farther than this squared distance from every center seed new centers.
        new_rgb = rgb[min_offset>3*(0.2*0.2)]
        if new_rgb.shape[0]>N_BIN: # need more samples than requested clusters
            _, new_center = kmeans(
                X=new_rgb, num_clusters=N_BIN, distance='euclidean', device=torch.device('cuda:0'),
                tqdm_flag=False
            )#(N_BIN,3)
            cluster_centers = torch.cat([cluster_centers,new_center],dim=0)
            offset = ((rgb[:,None,:]-cluster_centers[None,:,:])**2).sum(dim=-1) # (?,M)
        near_cluster_id = torch.argmin(offset,dim=-1).type(torch.int32) # (?,)
        cluster_centers_sum, count = vote_to_bin_CUDA(rgb,cluster_centers,near_cluster_id)
        # `count` is presumably (M,1): the former `count.repeat(1, D)` only
        # broadcasts against (M,D) for that shape. A center with no votes would
        # become 0/0 = NaN and poison every later distance, so it keeps its
        # previous position instead.
        new_cluster_centers = torch.where(
            count > 0,
            cluster_centers_sum / count.clamp_min(1),
            cluster_centers,
        )
        cluster_centers = 0.5*cluster_centers + 0.5*new_cluster_centers
        for i,a in enumerate(count):
            if i>=len(count_vote_number):
                count_vote_number.append(a)
            else:
                count_vote_number[i]+=a
    log_color_table(vid)
    print("before prune:")
    count_vote_number = torch.cat(count_vote_number)
    print(f"[debug]:cluster_centers:{cluster_centers.shape}")
    print("pruning:")
    min_count = torch_percentile(count_vote_number,[10])
    vm = count_vote_number>min_count
    count_vote_number = count_vote_number[vm]
    cluster_centers = cluster_centers[vm]
    print("after prune:")
    print(f"[debug]:cluster_centers:{cluster_centers.shape}")
    log_color_table("final")

    k_bin_percent = count_vote_number.cpu().numpy()
    if not np.issubdtype(k_bin_percent.dtype, np.floating):
        # In-place true division is rejected by numpy for integer arrays.
        k_bin_percent = k_bin_percent.astype(np.float32)
    # Out-of-place, so the array never writes back into a shared tensor buffer.
    k_bin_percent = k_bin_percent / k_bin_percent.sum()
    print(f"[debug]:k_bin_percent:{k_bin_percent}")
    # Weights currently equal the vote share (alternative tried: np.exp(-20*percent)).
    k_bin_loss_weight = list(k_bin_percent)
    print(f"[debug]:k_bin_loss_weight:{k_bin_loss_weight}")

    return cluster_centers, k_bin_percent, k_bin_loss_weight

class SeamlessModel:
    r"""
    Harmonize the colors of a target Gaussian model with a source model so the
    seam between them looks continuous.

    The stitching can actually be done in canonical space; the only
    transformations to take care of are:
     1) searching the intersection border;
     2) SH rotation for view-dependent effects.

    Only the target model's SH features (``f_dc`` / ``f_rest``) are optimized.
    """
    def __init__(self,
                 gs1:GaussianModel,gs2:GaussianModel,info1,info2,
                 render_fn,local_cam_list,
                 global_camera_handler:CameraProvider,
                 local_camera_source_handler:CameraProvider) -> None:
        r"""
        Params:
            gs1: source model; provides the reference colors. NOTE: its SH
                coefficients are rotated in place by ``_prepare``.
            gs2: target model; its SH features are optimized.
            info1, info2: dicts with 'sc', 'rot', 'trans' entries passed to
                ``transform_xyz``.
            render_fn: ``render_fn(cam, gs)`` -> dict with at least 'render' and 'alpha'.
            local_cam_list: non-empty list of cameras exposing ``origin_rgb``.
            global_camera_handler: callable; ``handler()['minicam'][0]`` is a camera
                used to sample view directions for SH color comparison.
            local_camera_source_handler: same interface; used to sample views of
                the source model for global color detection.
        """
        assert len(local_cam_list)>0, "local_cam_list must contain at least one camera"

        self.source_gs = gs1
        self.target_gs = gs2
        self.source_info = info1
        self.target_info = info2
        self.xyz1 = transform_xyz(gs1.get_xyz,info1['sc'],info1['rot'],info1['trans'])
        self.xyz2 = transform_xyz(gs2.get_xyz,info2['sc'],info2['rot'],info2['trans'])
        self.source_boarder_mask, self.source_boarder_graph = self._search_boarder(self.xyz1,self.xyz2,gs1.get_opacity.squeeze(dim=-1))
        self.target_boarder_mask, self.target_boarder_graph = self._search_boarder(self.xyz2,self.xyz1,gs2.get_opacity.squeeze(dim=-1))
        print(f"[debug]:total:{self.source_boarder_mask.shape};self.source_boarder_graph:{self.source_boarder_graph.shape}")
        print(f"[debug]:total:{self.target_boarder_mask.shape};self.target_boarder_graph:{self.target_boarder_graph.shape}")
        self.source_gs.set_opt_decay_mask(self.source_boarder_mask)
        self.target_gs.set_opt_decay_mask(self.target_boarder_mask)

        self.render_fn=render_fn
        self.local_cam_list=local_cam_list
        self.global_camera_handler = global_camera_handler
        self.local_camera_source_handler = local_camera_source_handler

        # (deprecated alternative: self._propagate_knnGradDecentOpt(), grad between knn)
        self._prepare()


    # run a whole iteration without GUI
    def start(self):
        """Deprecated: non-interactive full training is no longer supported."""
        raise NotImplementedError("API deprecated")

    # run one iter with GUI
    def step(self):
        """Run one training iteration (used by the interactive GUI loop)."""
        self._train_step_fast(sample_sh_insteadOf_coeff=True,shuffle_sh_ref=False)

    @torch.no_grad()
    def _prepare(self):
        r"""
        Build the resources used by the training step. Members created here
        (``boarder_sh``, ``boarder_xyz_ref``, ``sample_graph``, ``global_detect_fn``,
        Sobel/Laplace kernels, optimizer, progress bar) are only used by the
        train functions.
        """
        import os

        #############################################################
        # Rotate SH
        # ! WARNING: in-place operation on the source model !
        # The source SH coefficients are rotated into the current transform.
        # If SH rotation is ever done outside (it is not real-time, hence not
        # done interactively there), remove this code.
        band_dc,band_rest = self.source_gs.get_features_inplace
        sh_bands = torch.cat([band_dc,band_rest],dim=1)
        sh_bands = rotate_sh_by_quaternion(sh_bands,self.source_info['rot'])
        band_dc[:], band_rest[:] = sh_bands[:,0:1,:], sh_bands[:,1:,:]
        #############################################################

        #############################################################
        # Local feature propagate (prepare resources)
        # 1) border color optimize target: mean SH of each target border point's
        #    source neighbours, plus the (transformed) position of those neighbours,
        #    which is constant and so computed once here instead of per step.
        refer_source_fea = self.source_gs.get_features[self.target_boarder_graph]
        self.boarder_sh = refer_source_fea.mean(dim=1)
        self.boarder_xyz_ref = self.xyz1[self.target_boarder_graph].mean(dim=1)
        # 2) knn-graph for random sample
        boarder_xyz = self.target_gs.get_xyz[self.target_boarder_mask]
        sample_G = group2IndexCUDA(self.target_gs.get_xyz,boarder_xyz)
        near_xyz = boarder_xyz[sample_G].mean(dim=1)
        def local_warp(xyz):
            r"""
            Add a random-looking offset so local propagation is less regular
            under complex border conditions. xyz -> xyz(warped)
            """
            d_xyz=xyz-near_xyz
            noise = torch.sin(10*d_xyz) # 10 is suitable for most cases
            # noise = torch.sin(-np.e*d_xyz)+torch.sin(10*d_xyz) # for low freq geo
            # noise = torch.sin(40*d_xyz) # over 40 for some high freq boundary
            return xyz+noise
        self.sample_graph = group2IndexCUDA(local_warp(self.target_gs.get_xyz),boarder_xyz)
        #############################################################

        #############################################################
        # global feature detection in canonical space (prepare resources)
        # k-means on source colors --> k color bins with vote percentages
        self.global_detect_fn = self._setup_global_feature_detection()
        ##############################################################

        # Training set up
        fea_dc,fea_rest = self.target_gs.get_features_inplace
        l=[
            {'params': [fea_dc], 'lr': 0.01, "name": "f_dc"},
            {'params': [fea_rest], 'lr': 0.01/20.0, "name": "f_rest"}
        ]
        # 1st grad kernel
        # group==3 conv can prevent feature propagate, thus we use group=1 to apply on conv(rgb->a single intensity)
        self.k_sobel_x = torch.tensor([[-1,0,1],
                                       [-2,0,2],
                                       [-1,0,1]],dtype=torch.float32,device=fea_dc.device,requires_grad=False).repeat(1,3,1,1)
        self.k_sobel_y = torch.tensor([[-1,-2,-1],
                                       [ 0, 0, 0],
                                       [ 1, 2, 1]],dtype=torch.float32,device=fea_dc.device,requires_grad=False).repeat(1,3,1,1)
        self.k_laplace = torch.tensor([[-1,-1,-1],
                                       [-1, 8,-1],
                                       [-1,-1,-1]],dtype=torch.float32,device=fea_dc.device,requires_grad=False).repeat(1,3,1,1)
        self.optimizer = torch.optim.AdamW(l, lr=0.0, eps=1e-15)
        self.iter=0
        self.max_iter = 6000
        self.global_from_iter = (4.5/6)*self.max_iter
        self.tqdm_bar = tqdm(range(self.max_iter))

        # Logging: log to check camlist (save_image fails if the directory is missing)
        os.makedirs("./.cache", exist_ok=True)
        rgb_BCHW=torch.stack([p_cam.origin_rgb for p_cam in self.local_cam_list],dim=0)
        torchvision.utils.save_image(rgb_BCHW,"./.cache/scene_example_rgb.jpg")

    def _setup_global_feature_detection(self):
        r"""
        Run ``global_feature_detection`` on the source model and return a classifier
        ``fn(init_color (B,3)) -> (ref_color (B,3), ref_weight (B,1))``.

        A color is assigned to the bin minimizing ``euclidean_dist - bin_percent``,
        so bins with a larger vote share capture colors from further away.
        Shared by ``_prepare`` and the deprecated ``_propagate_knnGradDecentOpt``.
        """
        cluster_centers, k_bin_percent, k_bin_loss_weight = global_feature_detection(
            self.source_gs,
            lambda : self.local_camera_source_handler()['minicam'][0],
            self.render_fn)
        # Cumulative [left, right) range of each bin; only logged.
        k_bin_range = []
        _left=0
        for percent in k_bin_percent:
            k_bin_range.append((_left,_left+percent))
            _left+=percent
        print(f"[debug]:k_bin_range:{k_bin_range}")
        k_bin_percent = torch.from_numpy(k_bin_percent).cuda()
        # Built once instead of on every call.
        k_bin_weight = torch.tensor(k_bin_loss_weight,device=k_bin_percent.device)

        @torch.no_grad()
        def collect_kbin_color_and_weight(init_color):
            r"""
            init_color:(B,3) -> (ref_color (B,3), ref_weight (B,1))
            """
            offsets = ((init_color[:,None,:]-cluster_centers[None,:,:])**2).sum(dim=-1) # (B,N_BIN)
            offsets = torch.sqrt(offsets) # Euclidean (B,N_BIN)
            offsets = offsets-k_bin_percent[None,:] # (B,N_BIN) radius range classifier
            _, bin_id = torch.min(offsets,dim=-1) # (B,)
            ref_color = init_color.detach().clone()
            ref_color[:] = cluster_centers[bin_id]
            ref_weight = k_bin_weight.to(bin_id.device)[bin_id]
            return ref_color, ref_weight[:,None]

        return collect_kbin_color_and_weight

    def _train_step_fast(self,random_strategy=False,sample_sh_insteadOf_coeff=True,shuffle_sh_ref=False):
        r"""
        One optimization step on the target SH features; no-op after ``max_iter`` steps.

        accelerate strategy:
            1) sample a part of certain points(by dist range), for color supervise, in each iteration;
            2) sample cam-xyz direction to tune sh, for view-dependent alignment

        Params:
            random_strategy / shuffle_sh_ref: shuffle the sampled reference colors
                across the batch (either flag enables it).
            sample_sh_insteadOf_coeff: compare SH evaluated to RGB along a sampled
                global-camera view direction (True), or compare raw SH
                coefficients directly (False).
        """
        if self.iter >= self.max_iter:
            return

        # loss grad: pick one local cam and match Sobel gradients of the target
        # render against the camera's original image
        cam = self.local_cam_list[np.random.randint(0,len(self.local_cam_list))]
        rgb_gt = cam.origin_rgb[None,...]
        grad_x_gt = TF.conv2d(rgb_gt,self.k_sobel_x,padding=1,groups=1)
        grad_y_gt = TF.conv2d(rgb_gt,self.k_sobel_y,padding=1,groups=1)
        out = self.render_fn(cam,self.target_gs)
        rgb = out['render'][None,...]
        rgb_grad_x = TF.conv2d(rgb,self.k_sobel_x,padding=1,groups=1)
        rgb_grad_y = TF.conv2d(rgb,self.k_sobel_y,padding=1,groups=1)
        loss_grad = l2_loss(rgb_grad_x,grad_x_gt)+l2_loss(rgb_grad_y,grad_y_gt)

        # loss color: local propagate of the border reference colors
        random_pickmask = torch.randint(0,self.sample_graph.shape[0],(5000,),device=rgb.device)
        random_graph = self.sample_graph[random_pickmask] # (P,K) border neighbours of each picked point
        random_sh_ref = self.boarder_sh[random_graph].mean(dim=1) # (P,n_sh_coeffs,3)
        random_sh = self.target_gs.get_features[random_pickmask]
        if random_strategy or shuffle_sh_ref:
            random_sh_ref = random_sh_ref[torch.randperm(random_sh_ref.shape[0],device=random_sh_ref.device)]
        if sample_sh_insteadOf_coeff:
            # Evaluate both SH sets to RGB from one global camera instead of
            # comparing their coefficients directly.
            mini_cam = self.global_camera_handler()['minicam'][0]
            xyz_ref = self.boarder_xyz_ref[random_graph].mean(dim=1).detach()
            viewdir_ref = xyz_ref - mini_cam.camera_center[None,:]
            viewdir = self.xyz2[random_pickmask].detach() - mini_cam.camera_center[None,:]
            sh2rgb_ref = eval_sh(self.source_gs.active_sh_degree, random_sh_ref.transpose(1, 2), viewdir_ref/viewdir_ref.norm(dim=1, keepdim=True))
            random_sh_ref = torch.clamp_min(sh2rgb_ref + 0.5, 0.0).detach()
            sh2rgb = eval_sh(self.target_gs.active_sh_degree, random_sh.transpose(1, 2), viewdir/viewdir.norm(dim=1, keepdim=True))
            random_sh = torch.clamp_min(sh2rgb + 0.5, 0.0)
        loss_color_random = l2_loss(random_sh_ref,random_sh)
        loss_color_boarder = l2_loss(self.target_gs.get_features[self.target_boarder_mask], self.boarder_sh)
        loss_color = loss_color_random+loss_color_boarder

        in_global_stage = self.iter > self.global_from_iter
        if in_global_stage:
            # loss color: global color tune, in image space (worked better than
            # tuning per-point SH colors).
            # NOTE(review): a cosine warm-up factor 0.5*cos(pi*(ratio-1))+0.5 used
            # to be computed here but was never applied; the weight is a constant 2.
            cam = self.local_cam_list[np.random.randint(0,len(self.local_cam_list))]
            out = self.render_fn(cam,self.target_gs)
            out_rgb = out['render'].permute(1,2,0)
            out_mask = (out['alpha']>0.9).permute(1,2,0)
            sh2rgb = out_rgb[out_mask.repeat(1,1,3)].view(-1,3) #(H,W,3)->(?,3), opaque pixels only

            rgb_ref,l_w = self.global_detect_fn(sh2rgb)
            loss_color_global = l2_loss_weighted(sh2rgb,rgb_ref,l_w)
            loss_color += 2*loss_color_global

        loss = 2*loss_grad+loss_color
        # logging: the label follows the same condition that enables the global loss
        if in_global_stage:
            self.tqdm_bar.set_description("stage2: local+global tune")
        else:
            self.tqdm_bar.set_description("stage1: local propagate")
        self.tqdm_bar.update()

        loss.backward()
        self.optimizer.step()
        self.optimizer.zero_grad()

        self.iter+=1

    @torch.no_grad()
    def _search_boarder(self,p,ref_p,p_opacity):
        """
        Calculate border points of ``p`` against ``ref_p`` and the graph linking
        them to their nearest points in ``ref_p``.

        A point is on the border when it is not a self-outlier (kNN distance below
        the 99th percentile), lies close to ``ref_p`` (distance below ref_p's own
        99th-percentile kNN distance) and is solid (opacity > 0.5).
        Params:
            p(B1,3)
            ref_p(B2,3)
            p_opacity(B1,)
        Return:
            boarder_mask: (B1,) boolean, note N=sum(boarder_mask)
            boarder_graph (N,K) long, note K is k-nearest (K=3, hardcoded in the CUDA kernel)
        """
        d0_2 = distCUDA2(p) # knn dist in self pcd
        t = torch_percentile(d0_2,[99])
        m_non_outlier = d0_2<t
        d1_2 = group2DistCUDA2(p,ref_p) # knn dist from p to ref_p
        m_intersect = d1_2 < torch_percentile(distCUDA2(ref_p),[99])
        m_solid_point = p_opacity>0.5
        boarder_mask = (m_non_outlier&m_intersect&m_solid_point)

        N_p = p[boarder_mask]
        boarder_graph = group2IndexCUDA(N_p,ref_p) # 2 stage calculate to save memory
        return boarder_mask, boarder_graph

    def get_normal(self, gs, info, viewpoint_camera=None):
        r"""
        Per-Gaussian unit normals of ``gs`` as seen from ``viewpoint_camera``.

        The normal is the axis returned by ``get_minimum_axis(scaling, rotation)``.
        It is passed through ``flip_align_view`` with the normalized
        camera->point directions, then re-normalized. Returns an (N,3) tensor.

        Params:
            gs: Gaussian model.
            info: unused; kept for call compatibility.
            viewpoint_camera: camera exposing ``camera_center``. (Previously this
                name was read from an undefined global, so the method always failed.)
        Raises:
            ValueError: if ``viewpoint_camera`` is not given.
        """
        if viewpoint_camera is None:
            raise ValueError("get_normal requires a viewpoint_camera")
        dir_pp = (gs.get_xyz - viewpoint_camera.camera_center.repeat(gs.get_features.shape[0], 1))
        dir_pp_normalized = dir_pp/dir_pp.norm(dim=1, keepdim=True)

        normal_axis = get_minimum_axis(gs.get_scaling, gs.get_rotation)
        normal_axis, _positive = flip_align_view(normal_axis, dir_pp_normalized)

        normal = normal_axis/normal_axis.norm(dim=1, keepdim=True) # (N, 3)
        return normal


    # ----- history useless method -----------------------

    def _propagate_knnGradDecentOpt(self):
        r"""
        deprecated, useless (not called by this class; kept for reference).

        Optimizes the target SH so that border points match the source colors and
        the SH differences to kNN neighbours stay as they were, with the global
        color loss added after ``global_from_iter``.
        """
        with torch.no_grad():
            # step: record the grad (SH difference to kNN neighbours) of the whole target gs
            K=3 # hardcoded in cuda, max 20
            self.target_graph = indexCUDA(self.target_gs.get_xyz,K) # 3 maybe too small
            target_fea_dc,target_fea_rest = self.target_gs.get_features_inplace
            target_fea_dc_grad,target_fea_rest_grad = target_fea_dc[self.target_graph],target_fea_rest[self.target_graph]
            self.target_fea_dc_grad = target_fea_dc[:,None,:,:]-target_fea_dc_grad #(B,3,1,3)
            self.target_fea_rest_grad = target_fea_rest[:,None,:,:]-target_fea_rest_grad #(B,3,15,3)

            # step: record color in target gs border to the color of source gs by 2groupknn
            source_fea_dc, source_fea_rest=self.source_gs.get_features_inplace
            refer_source_fea_dc, refer_source_fea_rest = source_fea_dc[self.target_boarder_graph],source_fea_rest[self.target_boarder_graph]
            self.new_dc, self.new_rest = refer_source_fea_dc.mean(dim=1),refer_source_fea_rest.mean(dim=1)

            # global feature detection in canonical space (prepare resources)
            self.global_detect_fn = self._setup_global_feature_detection()

            fea_dc,fea_rest = self.target_gs.get_features_inplace
            l=[
                {'params': [fea_dc], 'lr': 0.01, "name": "f_dc"},
                {'params': [fea_rest], 'lr': 0.01/20.0, "name": "f_rest"}
            ]
            self.decay_mask = self.target_gs.get_decay_mask
            self.optimizer = torch.optim.AdamW(l, lr=0.0, eps=1e-15)
            self.iter=0
            self.max_iter = 6000
            self.global_from_iter = (4.5/6)*self.max_iter
        progress_bar = tqdm(range(self.max_iter))
        for _ in progress_bar:
            fea_dc,fea_rest = self.target_gs.get_features_inplace
            fea_dc_grad,fea_rest_grad = fea_dc[self.target_graph],fea_rest[self.target_graph]
            fea_dc_grad = fea_dc[:,None,:,:]-fea_dc_grad #(B,3,1,3)
            fea_rest_grad = fea_rest[:,None,:,:]-fea_rest_grad #(B,3,15,3)

            random_pick = torch.randint(0,fea_dc.shape[0],(self.new_dc.shape[0],),device=fea_dc.device)
            loss_rgb = l2_loss(fea_dc[random_pick],self.new_dc)+0.1*l2_loss(fea_rest[random_pick],self.new_rest)
            loss_rgb += l2_loss(fea_dc[self.target_boarder_mask],self.new_dc)+0.1*l2_loss(fea_rest[self.target_boarder_mask],self.new_rest)
            if self.iter>self.global_from_iter:
                # tune in image space
                cam = self.local_cam_list[np.random.randint(0,len(self.local_cam_list))]
                out = self.render_fn(cam,self.target_gs)
                out_rgb = out['render'].permute(1,2,0)
                out_mask = (out['alpha']>0.9).permute(1,2,0)
                sh2rgb = out_rgb[out_mask.repeat(1,1,3)].view(-1,3) #(H,W,3)->(?,3)

                rgb_ref,l_w = self.global_detect_fn(sh2rgb)
                loss_color_global = l2_loss_weighted(sh2rgb,rgb_ref,l_w)
                # was `loss_color += ...`, an undefined name in this method
                loss_rgb += 2*loss_color_global

            random_pick = torch.randint(0,fea_dc.shape[0],(self.new_dc.shape[0],),device=fea_dc.device)
            loss_grad = l2_loss(
                fea_dc_grad[random_pick],
                self.target_fea_dc_grad[random_pick])+0.1*l2_loss(
                fea_rest_grad[random_pick],
                self.target_fea_rest_grad[random_pick])
            loss = loss_rgb+2*loss_grad

            progress_bar.set_postfix({"loss_rgb":loss_rgb.item(),"loss_grad":loss_grad.item()})

            loss.backward()
            self.optimizer.step()
            self.optimizer.zero_grad()
            # without this the global stage above could never be reached
            self.iter+=1


    @torch.no_grad()
    def _grad_propagate_knnwolearning(self):
        r"""
        deprecated, useless

        Non-learned propagation: copy source colors onto the target border, then
        repeatedly assign each not-yet-updated point the mean of
        (neighbour color + recorded neighbour difference) over its already-updated
        kNN neighbours, until 80% of the points are updated or no more progress
        can be made.
        """
        # step: keep the grad of the whole target gs
        K=5 # hardcoded in cuda, max 20
        target_graph = indexCUDA(self.target_gs.get_xyz,K) # 3 maybe too small
        print(f"[debug]target_graph:{target_graph.shape}")
        target_fea_dc,target_fea_rest = self.target_gs.get_features_inplace
        target_fea_dc_grad,target_fea_rest_grad = target_fea_dc[target_graph],target_fea_rest[target_graph]
        target_fea_dc_grad = target_fea_dc[:,None,:,:]-target_fea_dc_grad #(B,K,1,3)
        target_fea_rest_grad = target_fea_rest[:,None,:,:]-target_fea_rest_grad #(B,K,15,3)
        print(f"[debug]target_fea_dc_grad:{target_fea_dc_grad.shape}")
        print(f"[debug]target_fea_rest_grad:{target_fea_rest_grad.shape}")

        # step: set color in target gs border to the color of source gs by 2groupknn
        source_fea_dc, source_fea_rest=self.source_gs.get_features_inplace
        refer_source_fea_dc, refer_source_fea_rest = source_fea_dc[self.target_boarder_graph],source_fea_rest[self.target_boarder_graph]
        new_dc, new_rest = refer_source_fea_dc.mean(dim=1),refer_source_fea_rest.mean(dim=1)
        target_fea_dc, target_fea_rest=self.target_gs.get_features_inplace
        target_fea_dc[self.target_boarder_mask],target_fea_rest[self.target_boarder_mask]=new_dc, new_rest

        # step: update color in target (except border), by grad propagate
        # 1) maintain two sets: S' (already optimized, tracked by oped_mask) and S (not yet);
        # 2) border points start in S';
        # 3) every point in S with at least one neighbour in S' gets its color from
        #    those neighbours plus the recorded grad, then joins S'.
        oped_mask=torch.full((target_graph.shape[0],),fill_value=False,device=target_graph.device)
        oped_mask[self.target_boarder_mask]=True
        n_iter=0
        while oped_mask.sum()<0.8*target_graph.shape[0]:
            pick_mask = oped_mask[target_graph] #(B,K)
            num_neighbor = pick_mask.sum(dim=-1)
            pick_mask = (num_neighbor > 0) & (~oped_mask) #(B,) has neighbor and not in S'
            print(f"[debug while]:pick_mask:{pick_mask.sum()}")
            if not pick_mask.any():
                # remaining points are unreachable through the kNN graph; stop
                # instead of looping forever
                break
            _t_graph = target_graph[pick_mask] # (V,K) number of picked is V
            _t_grad_dc, _t_grad_rest = target_fea_dc_grad[pick_mask], target_fea_rest_grad[pick_mask] # (V,K,1/15,3)
            _t_dc, _t_rest = target_fea_dc[_t_graph], target_fea_rest[_t_graph] # (V,K,1/15,3)
            weight = oped_mask[_t_graph].float()[:,:,None,None] # (V,K,1,1) discard the neighbor in S
            weight_sum = weight.sum(dim=1,keepdim=False) # (V,1,1), >=1 by construction of pick_mask

            new_dc,new_rest = (weight*(_t_dc+_t_grad_dc)).sum(dim=1), (weight*(_t_rest+_t_grad_rest)).sum(dim=1) # (V,1/15,3)
            new_dc,new_rest = new_dc/weight_sum, new_rest/weight_sum # (V,1/15,3)
            target_fea_dc, target_fea_rest=self.target_gs.get_features_inplace
            target_fea_dc[pick_mask],target_fea_rest[pick_mask] = new_dc,new_rest

            oped_mask[pick_mask] = True
            n_iter+=1

    # class Mapping(nn.Module):
    #     def __init__(self,) -> None:
    #         super().__init__()
    #         self.encoder = FreqEncoder_torch(input_dim=3, max_freq_log2=6-1, N_freqs=6, log_sampling=True)
    #         self.layer = nn.Sequential(
    #             torch.nn.Linear(self.encoder.output_dim,32),
    #             nn.ReLU(True),
    #             torch.nn.Linear(32,32),
    #             nn.ReLU(True),
    #             torch.nn.Linear(32,1),
    #             torch.nn.Sigmoid()
    #         )
    #         self.apply(self.weight_init)
    #     def forward(self,x):
    #         x = self.encoder(x)
    #         x = self.layer(x)
    #         return x
    #     def weight_init(self,m):
    #         if isinstance(m, nn.Linear):
    #             nn.init.xavier_normal_(m.weight)
    #             nn.init.constant_(m.bias, 0)
    # mapper = Mapping().cuda()
    # l=[{'params': mapper.parameters(), 'lr': 0.001, "name": "mapper"}]
    # opt = torch.optim.AdamW(l, lr=0.0, eps=1e-15)
    # xyzs = (2.0*torch.rand((1000,3),device="cuda")-1.0)
    # ts = (torch.rand((1000,1),device="cuda"))
    # xyzs = TF.interpolate(xyzs.permute(1,0)[None,...],size=[10000],mode="linear").squeeze(dim=0).permute(1,0)
    # ts = TF.interpolate(ts.permute(1,0)[None,...],size=[10000],mode="linear").squeeze(dim=0).permute(1,0)

    # with torch_yes_grad():
    #     for p in mapper.parameters():
    #         p.requires_grad_(True)
    #     for _ in tqdm(range(10000),desc="shallow mapper"):
    #         pick = torch.randint(0,ts.shape[0],(5000,),device="cuda")
    #         xyz = xyzs[pick]
    #         t = ts[pick]
    #         pred_t = mapper(xyz)
    #         loss = l2_loss_sum(pred_t,t)
    #         loss.backward()

    #         print(f"debug:pred_t.max({pred_t.max()})")
    #         print(f"debug:pred_t.min({pred_t.min()})")
    #         print(f"debug:pred_t.mean({pred_t.mean()})")
    #         print(f"debug:loss({loss.item()})")
    #         opt.step()
    #         opt.zero_grad()
