import warp as wp

from mpm_solver_warp.warp_utils import *
import numpy as np
import torch

# Return `n` scaled to unit length.
# A zero-length input returns the zero vector instead of dividing by zero,
# so callers never receive NaN/Inf components (e.g. when a grid node lies
# exactly on the point the direction is measured from).
@wp.func
def normalize(n: wp.vec3) -> wp.vec3:
    length = wp.length(n)
    # Single return with a pre-initialised result: avoids an early return
    # inside a conditional in generated kernel code.
    result = wp.vec3(0.0, 0.0, 0.0)
    if length > 0.0:
        result = n / length
    return result

# Signed distance from `grid_pos` to the surface of a sphere centred at `pos`:
# negative inside, zero on the surface, positive outside.
# NOTE(review): the radius is the hard-coded constant 0.015; it is not a
# parameter of this function.
@wp.func
def sphere_sdf(pos: wp.vec3, grid_pos: wp.vec3):
    center_to_node = grid_pos - pos
    return wp.length(center_to_node) - 0.015

# Direction from the sphere centre `pos` towards `grid_pos`, passed through
# `normalize` (i.e. the outward sphere normal at that point).
@wp.func
def sphere_normal(pos: wp.vec3, grid_pos: wp.vec3):
    outward = grid_pos - pos
    return normalize(outward)


# Collision response of grid nodes against a moving sphere.
#
# One thread per grid node (3D launch). For every node that is inside the
# sphere (signed distance <= 0), or close enough for the soft-contact
# influence to exceed 0.1 when `softness > 0`, the node velocity in
# `state.grid_v_out` is rewritten so that, relative to the sphere:
#   * the component moving into the sphere is removed, and
#   * the remaining tangential part is reduced by Coulomb-style friction
#     proportional to the removed normal component.
# The corrected velocity is blended with the unmodified one by `influence`.
#
# Parameters:
#   state    - provides `grid_v_out`, read and written at [x, y, z].
#   model    - provides `dx`, the spacing used to turn indices into positions.
#   pos, vel - sphere centre and sphere velocity.
#   softness - decay rate of the influence with distance; 0 means hard contact.
#   friction - friction coefficient applied to approaching nodes.
#   dt       - not used by this kernel (kept for the launch signature).
#   is_move  - output mask; set to 1 for nodes strictly inside/on the sphere.
@wp.kernel
def collide_kernel(
        state: MPMStateStruct,
        model: MPMModelStructMulti,
        pos: wp.vec3,
        vel: wp.vec3,
        softness: float,
        friction: float,
        dt: float,
        is_move: wp.array(dtype=int, ndim=3)
):
    grid_x, grid_y, grid_z = wp.tid()
    # Node position = integer index * grid spacing.
    grid_pos = wp.vec3(
        float(grid_x) * model.dx,
        float(grid_y) * model.dx,
        float(grid_z) * model.dx
    )
    dist = sphere_sdf(pos, grid_pos)
    if dist <= 0.0:
        is_move[grid_x, grid_y, grid_z] = 1
    # exp(-dist * softness) is >= 1 inside the sphere, so clamp to 1 there.
    influence = wp.min(wp.exp(-dist * softness), 1.0)

    # collide
    if (softness > 0.0 and influence > 0.1) or dist <= 0.0:
        D = sphere_normal(pos, grid_pos)
        collider_v_at_grid = vel
        # Work in the sphere's frame of reference.
        input_v = state.grid_v_out[grid_x, grid_y, grid_z] - collider_v_at_grid
        normal_component = wp.dot(input_v, D)

        # Remove only the inward (negative) normal component; a separating
        # node keeps its full relative velocity.
        grid_v_t = input_v - wp.min(normal_component, 0.0) * D
        grid_v_t_norm = wp.length(grid_v_t)

        # Friction applies only when the node is approaching the sphere and
        # has a non-degenerate tangential velocity.
        # (The previous `wp.select(cond, 1.0, 0.0)` returned 1.0 when the
        # condition was FALSE - wp.select takes the false-value first - which
        # applied friction to separating nodes and none to approaching ones.)
        contact_v = grid_v_t
        if normal_component < 0.0 and grid_v_t_norm > 1e-30:
            # normal_component < 0 here, so this shrinks the tangential speed
            # by friction * |normal_component|, clamped at zero.
            contact_v = grid_v_t / grid_v_t_norm * wp.max(
                0.0, grid_v_t_norm + normal_component * friction
            )

        new_vel = collider_v_at_grid + input_v * (1.0 - influence) + \
                  contact_v * influence

        state.grid_v_out[grid_x, grid_y, grid_z] = new_vel


class SpherePrimitive:
    """Kinematic sphere that follows a per-frame trajectory and collides with an MPM grid.

    The sphere centre is advanced by explicit integration of a piecewise
    constant velocity derived from consecutive trajectory frames. A cloud of
    points inside the sphere, with constant per-point attributes, is kept in
    ``visual_3dgs`` for rendering and is translated together with the sphere.

    Args:
        eef_pos: Trajectory indexed by frame. Each entry must be a torch
            tensor with at least three components (x, y, z).
            NOTE(review): presumably shape (num_frames, 3) - confirm with caller.
        frame_dt: Time between two trajectory frames; divisor for the velocity.
        start_frame: Index of the frame the sphere starts at.
        grid_size: Launch dimensions passed to ``wp.launch`` in ``collide``.
        substep_per_frame: Number of ``move`` calls per trajectory frame.
        radius: Radius used for the visualisation point cloud.
            NOTE(review): not forwarded to the collision kernel, which is
            launched without a radius argument.
        softness: Soft-contact decay passed to the collision kernel.
        friction: Friction coefficient passed to the collision kernel.
        device: Device for the Warp launch and the visualisation tensors.
    """

    def __init__(self, eef_pos, frame_dt, start_frame, grid_size, substep_per_frame, radius=0.015, softness=0.0, friction=0.1, device='cuda:0'):
        self.eef_pos = eef_pos
        self.frame_dt = frame_dt
        self.current_frame = start_frame
        self.grid_size = grid_size
        self.step_per_frame = substep_per_frame
        self.radius = radius
        self.softness = softness
        self.friction = friction
        self.device = device
        # Clone: indexing returns a view, and `move` updates `pos_now` in
        # place - without the copy the caller's trajectory would be modified.
        self.pos_now = self.eef_pos[start_frame].clone()
        self.update_velocity()
        self.current_step = 0
        self.visual_3dgs = self.generate_visual_3dgs(5000)

    def update_velocity(self):
        """Set ``vel_now`` to the finite-difference velocity of the current frame.

        When there is no following frame (end of the trajectory) the velocity
        is set to zero so the sphere holds its position instead of raising
        ``IndexError``.
        """
        next_frame = self.current_frame + 1
        if next_frame >= len(self.eef_pos):
            self.vel_now = torch.zeros_like(self.eef_pos[-1])
            return
        self.vel_now = (self.eef_pos[next_frame] - self.eef_pos[self.current_frame]) / self.frame_dt

    def move(self, dt):
        """Advance the sphere and its visualisation points by one substep of length ``dt``.

        After ``step_per_frame`` substeps the next trajectory frame becomes
        current and the velocity is recomputed before integrating.
        """
        if self.current_step >= self.step_per_frame:
            self.current_step = 0
            self.current_frame += 1
            self.update_velocity()
        displacement = self.vel_now * dt
        self.pos_now += displacement
        # Element 0 of `visual_3dgs` holds the point positions.
        self.visual_3dgs[0] += displacement
        self.current_step += 1

    def collide(self, state: MPMStateStruct, model: MPMModelStructMulti, dt: float):
        """Apply the sphere collision to the grid velocities, then move the sphere by ``dt``.

        Launches ``collide_kernel`` over ``self.grid_size`` with the current
        sphere position and velocity.
        """
        # One host transfer per vector instead of one per component.
        px, py, pz = self.pos_now.detach().tolist()[:3]
        vx, vy, vz = self.vel_now.detach().tolist()[:3]
        pos_now = wp.vec3(px, py, pz)
        vel_now = wp.vec3(vx, vy, vz)
        # Scratch mask the kernel writes to; only inspected for debugging.
        is_move = wp.zeros(
            shape=(model.n_grid, model.n_grid, model.n_grid),
            dtype=int,
            device=self.device, requires_grad=False
        )
        wp.launch(
            collide_kernel, dim=self.grid_size,
            inputs=[state, model, pos_now, vel_now, self.softness, self.friction, dt, is_move],
            device=self.device,
        )
        # if self.current_step % 200 == 0:
        #     print(f'is_move sum: {is_move.numpy().sum()}')
        self.move(dt)

    def _to_tensor(self, values):
        """Copy ``values`` to a float32, non-differentiable tensor on ``self.device``."""
        return torch.tensor(values, device=self.device, requires_grad=False, dtype=torch.float32)

    def generate_visual_3dgs(self, n):
        """Build the per-point attributes of ``n`` visualisation points.

        Returns:
            A list ``[xyz, opacities, scales, rots, precomp_colors,
            features_dc, features_extra]`` of float32 tensors on
            ``self.device`` with shapes ``(n, 3)``, ``(n, 1)``, ``(n, 3)``,
            ``(n, 4)``, ``(n, 3)``, ``(n, 3, 1)`` and ``(n, 3, 15)``.
            Everything except ``xyz`` is the same constant for every point.
            NOTE(review): the layout presumably matches what the 3D Gaussian
            splatting renderer expects - confirm against the consumer.
        """
        # 1 / sqrt(4*pi): constant value of the degree-0 spherical-harmonic basis.
        normalization = 1.0 / np.sqrt(4 * np.pi)
        max_sh_degree = 3
        num_coeffs = (max_sh_degree + 1) ** 2

        xyz_ee = self._to_tensor(self.generate_visualization_pcd(n))
        opacities_ee = self._to_tensor(np.tile([15.0], (n, 1)))
        scales_ee = self._to_tensor(np.tile([-9.0, -9.0, -9.0], (n, 1)))
        rots_ee = self._to_tensor(np.tile([0.71, -0.01, -0.30, -0.007], (n, 1)))

        colors = np.tile([0.0, 0.0, 0.9], (n, 1))
        # Degree-0 coefficients derived from the colour, computed in float64
        # before the float32 conversion.
        features_dc = np.zeros((n, 3, 1))
        features_dc[:, :, 0] = (colors - 0.5) / normalization
        precomp_colors_ee = self._to_tensor(colors)
        features_dc_ee = self._to_tensor(features_dc)
        # Higher-degree coefficients are all zero.
        features_extra_ee = self._to_tensor(np.zeros((n, 3, num_coeffs - 1)))

        return [xyz_ee,
                opacities_ee,
                scales_ee,
                rots_ee,
                precomp_colors_ee,
                features_dc_ee,
                features_extra_ee]

    def generate_visualization_pcd(self, n):
        """Sample ``n`` points uniformly inside the sphere at the current position.

        Returns:
            A float64 NumPy array of shape ``(n, 3)``.
        """
        center = self.pos_now.detach().cpu().numpy()
        # Uniform direction: uniform azimuth and uniform cos(polar angle).
        theta = np.random.uniform(0, 2 * np.pi, n)
        phi = np.arccos(np.random.uniform(-1, 1, n))
        # Cube root makes the radial density uniform in volume.
        r = self.radius * np.cbrt(np.random.uniform(0, 1, n))

        sin_phi = np.sin(phi)
        x = r * sin_phi * np.cos(theta)
        y = r * sin_phi * np.sin(theta)
        z = r * np.cos(phi)

        return np.vstack([x, y, z]).T + np.array(center)

    @property
    def get_current_pos(self):
        """Current sphere centre (the live tensor, not a copy)."""
        return self.pos_now

    @property
    def get_radius(self):
        """Radius given at construction."""
        return self.radius

    @property
    def get_ee_3dgs(self):
        """List of visualisation tensors produced by ``generate_visual_3dgs``."""
        return self.visual_3dgs
