# -*- coding: utf-8 -*-


class Runner:
    def __init__(self, args, conf_path, dataname, timestamp, mode='train'):
    
    def train(self):
            """Compute ``loss_sdf`` from a three-stage projection of ``samples``.

            Stage k (k = 0, 1, 2) queries ``self.sdf_network`` with
            ``index=k`` for a distance value and its gradient at the current
            points. It normalizes the gradient along dim 1 and moves the
            points to ``p - normalize(grad) * distance``. Each stage starts
            from the previous stage's moved points.

            The loss is the sum of:
              * ``sdf_surf_loss``: 1e-3 * mean squared ``index=2`` distance
                evaluated at the stage-2 points;
              * ``grad_sim_loss``: 1e-3 * (1 - the smaller of the mean cosine
                similarities grad1/grad2 and grad1/grad3);
              * ``focal_l2_loss``: ``self.focal_loss`` on the stage-1 and
                stage-2 points, plus the mean L2 distance from ``points`` to
                the stage-3 points.

            NOTE(review): this is an excerpt. ``samples``, ``points``, ``F``
            and ``torch`` come from code not shown here. The 12-space body
            indent suggests an enclosing loop was elided. Nothing in the
            visible lines consumes ``loss_sdf``, so confirm the backward /
            optimizer step against the full file.
            """
            # Stage 1 (index=0): project the raw samples.
            # NOTE(review): .squeeze() drops every size-1 dim. Presumably
            # samples is (N, 3) with N > 1; with N == 1 the batch dim would
            # vanish and F.normalize(..., dim=1) would fail. Confirm.
            gradients_sample = self.sdf_network.gradient(samples,index=0).squeeze()
            sdf_sample = self.sdf_network.sdf(samples,index=0)
            grad_norm1 = F.normalize(gradients_sample, dim=1)
            st1_sample_moved = samples - grad_norm1 * sdf_sample  # step along -grad by the predicted distance

            # Stage 2 (index=1): project the stage-1 points again.
            gradients_sample = self.sdf_network.gradient(st1_sample_moved,index=1).squeeze()
            udf_sample = self.sdf_network.sdf(st1_sample_moved,index=1)
            grad_norm2 = F.normalize(gradients_sample, dim=1)
            st2_sample_moved = st1_sample_moved - grad_norm2 * udf_sample  # second projection step

            # Stage 3 (index=2): final projection of the stage-2 points.
            gradients_sample = self.sdf_network.gradient(st2_sample_moved,index=2).squeeze()
            st3_udf_sample = self.sdf_network.sdf(st2_sample_moved,index=2)
            grad_norm3 = F.normalize(gradients_sample, dim=1)
            st3_sample_moved = st2_sample_moved - grad_norm3 * st3_udf_sample
            ### loss terms ###
            # Push the index=2 distance at the stage-2 points toward zero
            # (the trailing "- 0" is a no-op).
            sdf_surf_loss=1e-3*torch.mean(st3_udf_sample**2-0)
            # Encourage the three stages' gradient directions to agree. The
            # builtin min() compares the two 0-d tensors via bool(), so only
            # the selected term receives gradient. On a GPU, that comparison
            # reads the values back to the host.
            grad_sim_loss=1e-3*(1-min(F.cosine_similarity(grad_norm1, grad_norm2, dim=1).mean(),F.cosine_similarity(grad_norm1, grad_norm3, dim=1).mean()))
            # Weighted stage-1/stage-2 distances to `points` plus the mean L2
            # distance from `points` to the stage-3 points.
            focal_l2_loss=self.focal_loss(points,st1_sample_moved,st2_sample_moved)+torch.linalg.norm((points-st3_sample_moved), ord=2, dim=-1).mean()
            loss_sdf=focal_l2_loss+sdf_surf_loss+grad_sim_loss
    def focal_loss(self,st1_points,st1_sample_moved,st2_sample_moved,gamma=2):
        far_dis=torch.linalg.norm((st1_points-st1_sample_moved), ord=2, dim=-1).mean()
        near_dis=torch.linalg.norm((st1_points-st2_sample_moved), ord=2, dim=-1).mean()
        weight_matrix=torch.tensor([far_dis, near_dis], dtype=torch.float32)
        weight_matrix=torch.nn.functional.softmax(weight_matrix,dim=0)
        dynamic_alpha=torch.clamp(weight_matrix[0],max=1/self.maxiter)
        dynamic_beta=torch.clamp((1-weight_matrix[0])**gamma,max=1/self.maxiter)
        print('matrix:',weight_matrix)
        focal_loss=dynamic_alpha*far_dis+dynamic_beta*near_dis
        return focal_loss