import warp as wp

from mpm_solver_warp.warp_utils import *
import numpy as np
import torch

@wp.func
def normalize(n: wp.vec3) -> wp.vec3:
    # Scale `n` to unit length by dividing it by its Euclidean norm.
    # No zero-length guard: callers must pass a non-zero vector.
    magnitude = wp.length(n)
    return n / magnitude

@wp.func
def column_sdf(pos: wp.vec3,
               grid_pos: wp.vec3,
               l: float,
               r: float):
    # Distance-like value of `grid_pos` w.r.t. a z-aligned cylinder of radius
    # `r` whose top cap is centred at `pos` and which extends `l` downwards
    # (local z in [-l, 0]).  The result is non-positive exactly when the point
    # lies inside the cylinder.
    offset = grid_pos - pos

    # Radial distance to the side wall (negative inside the radius).
    radial_sq = offset[0] * offset[0] + offset[1] * offset[1]
    radial_dist = wp.sqrt(radial_sq) - r

    # Axial distance to the nearer cap: positive above the top (z > 0)
    # or below the bottom (z < -l).
    below_bottom = -(offset[2] + l)
    axial_dist = wp.max(below_bottom, offset[2])

    return wp.max(radial_dist, axial_dist)

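# Disabled alternative, kept for reference: distance to an infinitely tall
# column (only the horizontal distance is used; the length `l` is ignored).
# @wp.func
# def column_sdf(pos: wp.vec3,
#                grid_pos: wp.vec3,
#                l: float,
#                r: float):
#     p = grid_pos - pos
#     dist = wp.sqrt(p[0] * p[0] + p[1] * p[1]) - r
#     return dist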

@wp.func
def column_normal(pos: wp.vec3,
                  grid_pos: wp.vec3,
                  l: float,
                  r: float):
    # Contact normal of the z-aligned cylinder (top cap at `pos`, length `l`)
    # evaluated at `grid_pos`.  `r` is accepted for signature symmetry with
    # `column_sdf` but is not used here.
    eps = 1e-7
    offset = grid_pos - pos
    height = offset[2]

    # Above the top cap: normal points straight up.
    if height >= eps:
        return wp.vec3(0.0, 0.0, 1.0)

    # At or below the bottom cap: normal points straight down.
    if height <= -l + eps:
        return wp.vec3(0.0, 0.0, -1.0)

    # Otherwise use the radial direction.  `eps` is added to every component
    # so that a point exactly on the axis does not produce a zero-length vector.
    radial = wp.vec3(offset[0], offset[1], 0.0)
    return normalize(radial + wp.vec3(eps))

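# Disabled alternative, kept for reference: purely horizontal normal
# (the z component of the offset is zeroed before normalising).
# @wp.func
# def column_normal(pos: wp.vec3,
#                   grid_pos: wp.vec3,
#                   l: float,
#                   r: float):
#     p = grid_pos - pos
#     p[2] = 0.0
#     return normalize(p)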


@wp.kernel
def collide_kernel(
        state: MPMStateStruct,
        model: MPMModelStructMulti,
        pos: wp.vec3,
        vel: wp.vec3,
        r: float,
        l: float,
        softness: float,
        friction: float,
        dt: float,
        is_move: wp.array(dtype=int, ndim=3)
):
    # Apply a moving cylindrical collider to the grid velocities.
    #
    # One thread per grid node (3D launch).  For nodes inside the collider, or
    # within its soft influence region, `state.grid_v_out` is rewritten so that
    # the velocity relative to the collider has no component into the surface
    # and its tangential part is reduced by friction.
    #
    #   pos, vel  : collider reference point (top-cap centre) and its velocity
    #   r, l      : cylinder radius and length
    #   softness  : decay rate of the influence outside the surface;
    #               0 gives influence == 1 everywhere, but then only nodes
    #               with dist <= 0 are processed
    #   friction  : friction coefficient applied to the tangential velocity
    #   dt        : not used by this kernel
    #   is_move   : output mask, set to 1 for nodes inside the collider
    grid_x, grid_y, grid_z = wp.tid()
    grid_pos = wp.vec3(
        float(grid_x) * model.dx,
        float(grid_y) * model.dx,
        float(grid_z) * model.dx
    )

    # Distance to the collider surface (non-positive inside).
    dist = column_sdf(pos, grid_pos, l, r)
    if dist <= 0.0:
        is_move[grid_x, grid_y, grid_z] = 1

    # Blend weight in (0, 1]: 1 inside the collider, exponentially decaying outside.
    influence = wp.min(wp.exp(-dist * softness), 1.0)

    if (softness > 0.0 and influence > 0.1) or dist <= 0.0:
        D = column_normal(pos, grid_pos, l, r)
        collider_v_at_grid = vel

        # Work in the collider's frame of reference.
        input_v = state.grid_v_out[grid_x, grid_y, grid_z] - collider_v_at_grid
        normal_component = wp.dot(input_v, D)

        # Remove only the part of the velocity that points into the surface;
        # separating motion (normal_component >= 0) is left untouched.
        grid_v_t = input_v - wp.min(normal_component, 0.0) * D
        grid_v_t_norm = wp.length(grid_v_t)

        # Friction is applied only when the node is moving into the collider
        # and has a non-negligible tangential velocity.
        # The previous code used wp.select(cond, 1.0, 0.0), but wp.select
        # returns its *second* argument when `cond` is False, which inverted
        # the switch: friction was skipped on contact and the friction formula
        # (which then adds speed, since normal_component >= 0) was applied to
        # separating nodes.  A plain branch avoids that ambiguity.
        contact_v = grid_v_t
        if normal_component < 0.0 and grid_v_t_norm > 1e-30:
            contact_v = grid_v_t / grid_v_t_norm * wp.max(
                0.0, grid_v_t_norm + normal_component * friction
            )

        # Blend the unconstrained and the constrained relative velocity by
        # `influence`, then transform back to the world frame.
        new_vel = collider_v_at_grid + input_v * (1.0 - influence) + contact_v * influence

        state.grid_v_out[grid_x, grid_y, grid_z] = new_vel


class ColumnPrimitive:
    """Kinematic cylindrical collider that follows a per-frame waypoint trajectory.

    The cylinder is z-aligned; ``pos_now`` is the centre of its top cap and the
    body extends ``length`` downwards with radius ``radius``.  Between two
    consecutive waypoints the collider moves with constant velocity
    ``(eef_pos[f + 1] - eef_pos[f]) / frame_dt``; each frame is split into
    ``substep_per_frame`` calls to :meth:`move`.

    A point cloud sampled inside the cylinder is kept in ``visual_3dgs``
    together with constant per-point attributes (opacity, scale, rotation,
    colour, spherical-harmonic features) and is translated along with the
    collider.

    Args:
        eef_pos: ``(num_frames, D)`` torch tensor of waypoints (presumably
            end-effector positions; confirm with the caller).  NOTE: its last
            column is overwritten **in place** with ``1.1``.
        frame_dt: Time between two consecutive waypoints.
        start_frame: Index of the first waypoint to use.
        grid_size: Launch dimensions passed to ``collide_kernel``.
        substep_per_frame: Number of :meth:`move` calls per waypoint interval.
        radius: Cylinder radius.
        length: Cylinder length.
        softness: Softness forwarded to ``collide_kernel``.
        friction: Friction coefficient forwarded to ``collide_kernel``.
        device: Device for the Warp launch and the visualisation tensors.
        end_frame: Stored on the instance; not used by this class.
    """

    # Number of points sampled for the visualisation point cloud.
    _NUM_VISUAL_POINTS = 2000

    def __init__(self, eef_pos, frame_dt, start_frame, grid_size, substep_per_frame,
                 radius=0.025, length=0.2, softness=0, friction=0.01, device='cuda:0', end_frame=0):
        self.eef_pos = eef_pos
        # Pin the last coordinate of every waypoint to a fixed value.
        # This writes into the tensor supplied by the caller.
        self.eef_pos[:, -1] = 1.1
        self.frame_dt = frame_dt
        self.start_frame = start_frame
        self.end_frame = end_frame
        self.current_frame = self.start_frame
        self.grid_size = grid_size
        self.step_per_frame = substep_per_frame
        self.radius = radius
        self.length = length
        self.softness = softness
        self.friction = friction
        self.device = device
        self.current_step = 0
        self.pos_now = self.eef_pos[start_frame].clone()
        self.update_velocity()
        self.visual_3dgs = self.generate_visual_3dgs(self._NUM_VISUAL_POINTS)

    def reset(self):
        """Rewind to ``start_frame`` and resample the visualisation point cloud."""
        self.current_frame = self.start_frame
        self.current_step = 0
        self.pos_now = self.eef_pos[self.current_frame].clone()
        self.update_velocity()
        self.visual_3dgs = self.generate_visual_3dgs(self._NUM_VISUAL_POINTS)

    def update_velocity(self):
        """Set ``vel_now`` from the current waypoint and the next one.

        When there is no next waypoint (the trajectory has been exhausted) the
        velocity becomes zero so the collider holds its position instead of
        raising ``IndexError``.
        """
        next_frame = self.current_frame + 1
        if next_frame < self.eef_pos.shape[0]:
            delta = self.eef_pos[next_frame] - self.eef_pos[self.current_frame]
            self.vel_now = delta / self.frame_dt
        else:
            self.vel_now = torch.zeros_like(self.pos_now)

    def move(self, dt):
        """Advance the collider (and its visualisation points) by one substep ``dt``."""
        # After `step_per_frame` substeps, switch to the next waypoint interval.
        if self.current_step >= self.step_per_frame:
            self.current_step = 0
            self.current_frame += 1
            self.update_velocity()

        displacement = self.vel_now * dt
        self.pos_now += displacement

        # The visualisation tensors live on `self.device`, which may differ
        # from the device of `eef_pos`; align before the in-place update.
        xyz = self.visual_3dgs[0]
        xyz += displacement.to(xyz.device)

        self.current_step += 1

    def collide(self, state: "MPMStateStruct", model: "MPMModelStructMulti", dt: float):
        """Apply the collider to ``state.grid_v_out`` and then advance it by ``dt``.

        Args:
            state: MPM state whose grid velocities are modified by the kernel.
            model: MPM model; provides ``n_grid`` (mask shape) and ``dx``.
            dt: Substep size, forwarded to the kernel and to :meth:`move`.
        """
        # One host transfer per vector instead of one per component.
        px, py, pz = self.pos_now[:3].tolist()
        vx, vy, vz = self.vel_now[:3].tolist()
        pos_now = wp.vec3(px, py, pz)
        vel_now = wp.vec3(vx, vy, vz)

        # Mask of grid nodes inside the collider; filled by the kernel and
        # currently only useful for debugging.
        is_move = wp.zeros(
            shape=(model.n_grid, model.n_grid, model.n_grid),
            dtype=int,
            device=self.device, requires_grad=False
        )
        wp.launch(
            collide_kernel, dim=self.grid_size,
            inputs=[state, model, pos_now, vel_now, self.radius, self.length,
                    self.softness, self.friction, dt, is_move],
            device=self.device,
        )
        self.move(dt)

    def generate_visual_3dgs(self, n):
        """Build the per-point attribute tensors used to render the collider.

        Returns:
            A list ``[xyz, opacities, scales, rotations, colors, features_dc,
            features_extra]`` of float32 tensors on ``self.device`` with ``n``
            rows each.  All attributes except ``xyz`` are constants; the
            opacity/scale values are presumably pre-activation parameters of
            the renderer (confirm against the consumer of this list).
        """
        # 1 / sqrt(4*pi) equals the degree-0 spherical-harmonic constant.
        normalization = 1.0 / np.sqrt(4 * np.pi)
        max_sh_degree = 3
        num_coeffs = (max_sh_degree + 1) ** 2
        tensor_kwargs = dict(device=self.device, dtype=torch.float32)

        xyz_ee = torch.tensor(self.generate_visualization_pcd(n), requires_grad=False, **tensor_kwargs)
        opacities_ee = torch.full((n, 1), 15.0, **tensor_kwargs)
        scales_ee = torch.full((n, 3), -9.0, **tensor_kwargs)
        rots_ee = torch.tensor([0.71, -0.01, -0.30, -0.007], **tensor_kwargs).repeat(n, 1)

        # Colour math is done in float64 with numpy before the float32 cast.
        rgb = np.tile([0.0, 0.0, 0.9], (n, 1))
        features_dc = ((rgb - 0.5) / normalization)[:, :, None]
        precomp_colors_ee = torch.tensor(rgb, **tensor_kwargs)
        features_dc_ee = torch.tensor(features_dc, **tensor_kwargs)
        features_extra_ee = torch.zeros((n, 3, num_coeffs - 1), **tensor_kwargs)

        return [xyz_ee,
                opacities_ee,
                scales_ee,
                rots_ee,
                precomp_colors_ee,
                features_dc_ee,
                features_extra_ee]

    def generate_visualization_pcd(self, n):
        """Sample ``n`` points uniformly inside the cylinder at its current position.

        Returns:
            ``(n, 3)`` numpy array.  Uses numpy's global random generator.
        """
        center = self.pos_now.detach().cpu().numpy()

        theta = 2 * np.pi * np.random.rand(n)
        # sqrt of a uniform variate gives a uniform density over the disc area.
        radial = self.radius * np.sqrt(np.random.rand(n))
        # The body extends from the reference point downwards by `length`.
        z = -self.length * np.random.rand(n)

        x = radial * np.cos(theta)
        y = radial * np.sin(theta)
        return np.column_stack((x, y, z)) + np.array(center)

    @property
    def get_current_pos(self):
        """Current reference position of the collider (the live tensor, not a copy)."""
        return self.pos_now

    @property
    def get_radius(self):
        """Cylinder radius."""
        return self.radius

    @property
    def get_ee_3dgs(self):
        """Copies of the visualisation tensors, safe for the caller to modify."""
        return [p.clone() for p in self.visual_3dgs]
