import warp as wp
import numpy as np
import os
from base_primitive_kernel import *

class Primitive:
    """Base class for collision primitives evaluated against grid nodes.

    This class reads, but does not define, the following members. They must
    be supplied by a subclass (or set elsewhere):
    ``position``, ``rotation`` (indexed by frame), ``softness``, ``friction``
    (read as ``x[None]``, presumably scalar fields — TODO confirm),
    and the methods ``_sdf``, ``normal`` and ``collider_v``.

    NOTE(review): ``ti``, ``inv_trans`` and ``length`` are not imported
    explicitly here; they presumably come from
    ``from base_primitive_kernel import *`` — confirm.
    """

    def __init__(self, cfg):
        """Store the configuration object ``cfg`` on the instance."""
        self.cfg = cfg

    def sdf(self, f, grid_pos):
        """Return the signed distance from ``grid_pos`` to the primitive at frame ``f``.

        ``grid_pos`` is first passed through ``inv_trans`` with the pose
        ``self.position[f]`` / ``self.rotation[f]`` (presumably mapping it
        into the primitive's local frame — TODO confirm). The subclass hook
        ``_sdf`` then evaluates the distance for the transformed point.
        """
        grid_pos = inv_trans(grid_pos, self.position[f], self.rotation[f])
        return self._sdf(f, grid_pos)

    def collide(self, f, grid_pos, v_out, dt):
        """Resolve contact between a grid node and the primitive at frame ``f``.

        Args:
            f: frame index used to look up the primitive's pose.
            grid_pos: position of the grid node.
            v_out: current velocity of the grid node (vector supporting
                ``.dot``, ``+``, ``-`` and scalar ``*``).
            dt: time step, forwarded to ``collider_v``.

        Returns:
            The grid-node velocity after collision handling. ``v_out`` is
            returned unchanged when ``dist > 0`` and either ``softness <= 0``
            or the influence weight is at most 0.1.

        Inside the contact region, the velocity relative to the collider has
        its inward normal component removed. When the node is moving into the
        surface with a non-zero tangential velocity, that tangential part is
        also scaled by Coulomb-style friction. The result is blended with the
        unmodified relative velocity using
        ``influence = min(exp(-dist * softness), 1)``.
        """
        dist = self.sdf(f, grid_pos)
        softness = self.softness[None]
        influence = min(ti.exp(-dist * softness), 1)
        if (softness > 0 and influence > 0.1) or dist <= 0:
            D = self.normal(f, grid_pos)
            collider_v_at_grid = self.collider_v(f, grid_pos, dt)

            # Velocity of the node relative to the (possibly moving) collider.
            input_v = v_out - collider_v_at_grid
            normal_component = input_v.dot(D)

            # Remove the normal component only when it is negative
            # (assuming D is the outward unit normal: motion into the primitive).
            # Separating motion is kept as-is.
            grid_v_t = input_v - min(normal_component, 0) * D

            # Friction applies only when approaching the surface with a
            # non-zero tangential velocity. The division is done inside the
            # branch on purpose: a branchless blend
            # (friction * flag + v * (1 - flag)) would still propagate the
            # NaN produced by 0/0, because NaN * 0 is NaN.
            if normal_component < 0 and ti.sqrt(grid_v_t.dot(grid_v_t)) > 1e-30:
                grid_v_t_norm = length(grid_v_t)
                grid_v_t = grid_v_t / grid_v_t_norm * max(
                    0, grid_v_t_norm + normal_component * self.friction[None]
                )

            v_out = collider_v_at_grid + input_v * (1 - influence) + grid_v_t * influence

        return v_out

