from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional, Union, Sequence, Any

import warp as wp
import numpy as np
from torch import Tensor
from omegaconf import DictConfig
import trimesh

from .abstract import State, Model, ModelBuilder, StateInitializer, StaticsInitializer
from modules.nclaw.warp import Tape, CondTape

from modules.d3gs.scene.gaussian_model import GaussianModel


@wp.struct
class MPMStatics(object):
    """Per-particle quantities that are set up on the host and read by kernels.

    All four arrays are allocated without gradients and are indexed by
    particle id. The ``update_*`` helpers fill them group by group: particle
    groups are laid out back to back, and ``sections`` gives the number of
    particles in each group.
    """

    vol: wp.array(dtype=float)
    rho: wp.array(dtype=float)
    clip_bound: wp.array(dtype=float)
    enabled: wp.array(dtype=int)

    def init(self, shape: Union[Sequence[int], int], device: wp.context.Devicelike = None) -> None:
        """Allocate zero-filled, non-differentiable arrays of the given shape."""
        alloc = dict(shape=shape, device=device, requires_grad=False)
        self.vol = wp.zeros(dtype=float, **alloc)
        self.rho = wp.zeros(dtype=float, **alloc)
        self.clip_bound = wp.zeros(dtype=float, **alloc)
        self.enabled = wp.zeros(dtype=int, **alloc)

    # Kernel: write `value` into x[start:end]; threads outside the range do nothing.
    @staticmethod
    @wp.kernel
    def set_int(x: wp.array(dtype=int), start: int, end: int, value: int) -> None:
        p = wp.tid()
        if start <= p and p < end:
            x[p] = value

    # Kernel: float counterpart of `set_int`.
    @staticmethod
    @wp.kernel
    def set_float(x: wp.array(dtype=float), start: int, end: int, value: float) -> None:
        p = wp.tid()
        if start <= p and p < end:
            x[p] = value

    # Kernel: copy value[start:end] into x[start:end] element by element.
    @staticmethod
    @wp.kernel
    def set_float_array(x: wp.array(dtype=float), start: int, end: int, value: wp.array(dtype=float)) -> None:
        p = wp.tid()
        if start <= p and p < end:
            x[p] = value[p]

    def update_vol(self, sections: list[int], vols: list[float]) -> None:
        """Assign one volume value to every particle of each group."""
        begin = 0
        for count, value in zip(sections, vols):
            finish = begin + count
            wp.launch(
                self.set_float,
                dim=self.vol.shape,
                inputs=[self.vol, begin, finish, value],
                device=self.vol.device,
            )
            begin = finish

    def update_rho(self, sections: list[int], rhos: list[float]) -> None:
        """Assign one density value to every particle of each group."""
        begin = 0
        for count, value in zip(sections, rhos):
            finish = begin + count
            wp.launch(
                self.set_float,
                dim=self.rho.shape,
                inputs=[self.rho, begin, finish, value],
                device=self.rho.device,
            )
            begin = finish

    def update_clip_bound(self, sections: list[int], clip_bounds: list[float]) -> None:
        """Assign one clip-bound value to every particle of each group."""
        begin = 0
        for count, value in zip(sections, clip_bounds):
            finish = begin + count
            wp.launch(
                self.set_float,
                dim=self.clip_bound.shape,
                inputs=[self.clip_bound, begin, finish, value],
                device=self.clip_bound.device,
            )
            begin = finish

    def update_enabled(self, sections: list[int], spans: list[tuple[int, int]], step: int = 0) -> None:
        """Enable each group only while ``span[0] <= step < span[1]``."""
        begin = 0
        for count, span in zip(sections, spans):
            finish = begin + count
            # 1 while the step lies inside the half-open span, 0 otherwise.
            flag = int(span[0] <= step < span[1])
            wp.launch(
                self.set_int,
                dim=self.enabled.shape,
                inputs=[self.enabled, begin, finish, flag],
                device=self.enabled.device,
            )
            begin = finish


@wp.struct
class MPMParticleData(object):
    """Per-particle dynamic state: ``x``, ``v``, ``C``, ``F`` and ``stress``.

    ``x`` and ``v`` are vec3 arrays; ``C``, ``F`` and ``stress`` are mat33
    arrays. ``F`` is reset to the identity matrix, everything else to zero.
    """

    x: wp.array(dtype=wp.vec3)
    v: wp.array(dtype=wp.vec3)
    C: wp.array(dtype=wp.mat33)
    F: wp.array(dtype=wp.mat33)
    stress: wp.array(dtype=wp.mat33)

    def init(self, shape: Union[Sequence[int], int], device: wp.context.Devicelike = None, requires_grad: bool = False) -> None:
        """Allocate all arrays and set ``F`` to identity."""
        alloc = dict(shape=shape, device=device, requires_grad=requires_grad)

        self.x = wp.zeros(dtype=wp.vec3, **alloc)
        self.v = wp.zeros(dtype=wp.vec3, **alloc)
        self.C = wp.zeros(dtype=wp.mat33, **alloc)
        # F is left uninitialised here because the kernel below overwrites it.
        self.F = wp.empty(dtype=wp.mat33, **alloc)
        self.stress = wp.zeros(dtype=wp.mat33, **alloc)

        wp.launch(self.init_F, dim=self.F.shape, inputs=[self.F], device=self.F.device)

    def clear(self) -> None:
        """Zero ``x``, ``v``, ``C`` and ``stress`` and reset ``F`` to identity."""
        for buffer in (self.x, self.v, self.C, self.stress):
            buffer.zero_()

        wp.launch(self.init_F, dim=self.F.shape, inputs=[self.F], device=self.F.device)

    # Kernel: write the 3x3 identity into every entry of F.
    @staticmethod
    @wp.kernel
    def init_F(F: wp.array(dtype=wp.mat33)) -> None:

        p = wp.tid()

        # to avoid the not implemented adj_mat33 for mat33 plain constructor
        I33_1 = wp.vec3(1.0, 0.0, 0.0)
        I33_2 = wp.vec3(0.0, 1.0, 0.0)
        I33_3 = wp.vec3(0.0, 0.0, 1.0)
        I33 = wp.mat33(I33_1, I33_2, I33_3)

        F[p] = I33

    def zero_grad(self) -> None:
        """Zero the gradient of every array that tracks one."""
        for buffer in (self.x, self.v, self.C, self.F, self.stress):
            if buffer.requires_grad:
                buffer.grad.zero_()


@wp.struct
class MPMGridData(object):
    """Dense 3D grid buffers: velocity ``v``, momentum ``mv`` and mass ``m``."""

    v: wp.array(dtype=wp.vec3, ndim=3)
    mv: wp.array(dtype=wp.vec3, ndim=3)
    m: wp.array(dtype=float, ndim=3)

    def init(self, shape: Union[Sequence[int], int], device: wp.context.Devicelike = None, requires_grad: bool = False) -> None:
        """Allocate zero-filled 3D arrays for all three grid quantities."""
        alloc = dict(shape=shape, ndim=3, device=device, requires_grad=requires_grad)

        self.v = wp.zeros(dtype=wp.vec3, **alloc)
        self.mv = wp.zeros(dtype=wp.vec3, **alloc)
        self.m = wp.zeros(dtype=float, **alloc)

    def clear(self) -> None:
        """Reset velocity, momentum and mass to zero."""
        for buffer in (self.v, self.mv, self.m):
            buffer.zero_()

    def zero_grad(self) -> None:
        """Zero the gradient of every grid array that tracks one."""
        for buffer in (self.v, self.mv, self.m):
            if buffer.requires_grad:
                buffer.grad.zero_()


@wp.struct
class MPMConstant(object):
    """Scalar simulation parameters passed by value to every MPM kernel.

    In this file the fields are filled in by ``MPMModelBuilder.build_constant``.
    """

    # Number of grid cells along each axis; the grid is num_grids^3.
    num_grids: int
    # Time step used by the kernels.
    dt: float
    # Thickness, in grid cells, of the boundary layer handled by grid_op_*.
    bound: int
    # Gravity vector; the grid kernels add gravity * dt to the grid velocity.
    gravity: wp.vec3
    # Grid spacing; the builder sets it to 1 / num_grids.
    dx: float
    # Reciprocal grid spacing; the builder sets it to float(num_grids).
    inv_dx: float
    # Added to the grid mass before dividing momentum by it.
    eps: float


class MPMState(State):
    """Simulation state wrapping one ``MPMParticleData`` instance.

    The torch conversion helpers always handle the particle fields in the
    order ``x, v, C, F, stress``.
    """

    def __init__(
            self,
            shape: int,
            device: wp.context.Devicelike = None,
            requires_grad: bool = False) -> None:
        """Allocate particle storage for ``shape`` particles."""

        super().__init__(shape, device, requires_grad)

        data = MPMParticleData()
        data.init(shape, device, requires_grad)
        self.particle = data

    def zero_grad(self) -> None:
        """Zero the gradients of all particle arrays."""
        self.particle.zero_grad()

    def clear(self) -> None:
        """Reset all particle arrays to their initial values."""
        self.particle.clear()

    def to_torch(self) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
        """Return torch views of the particle arrays, keeping each requires_grad flag."""
        data = self.particle
        return tuple(
            wp.to_torch(arr).requires_grad_(arr.requires_grad)
            for arr in (data.x, data.v, data.C, data.F, data.stress)
        )

    def to_torch_grad(self) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
        """Return torch views of the gradients; an entry is None when no gradient exists."""
        data = self.particle
        return tuple(
            None if arr.grad is None else wp.to_torch(arr.grad)
            for arr in (data.x, data.v, data.C, data.F, data.stress)
        )

    def from_torch(
            self,
            x: Optional[Tensor] = None,
            v: Optional[Tensor] = None,
            C: Optional[Tensor] = None,
            F: Optional[Tensor] = None,
            stress: Optional[Tensor] = None) -> None:
        """Replace particle arrays with Warp views of the given tensors.

        Arguments left as None keep the corresponding array unchanged.
        """

        incoming = (
            ('x', x, wp.vec3),
            ('v', v, wp.vec3),
            ('C', C, wp.mat33),
            ('F', F, wp.mat33),
            ('stress', stress, wp.mat33),
        )
        for attr, tensor, dtype in incoming:
            if tensor is None:
                continue
            setattr(self.particle, attr, wp.from_torch(tensor.contiguous(), dtype=dtype))

    def from_torch_grad(
            self,
            grad_x: Optional[Tensor] = None,
            grad_v: Optional[Tensor] = None,
            grad_C: Optional[Tensor] = None,
            grad_F: Optional[Tensor] = None,
            grad_stress: Optional[Tensor] = None) -> None:
        """Replace particle gradients with Warp views of the given tensors.

        Arguments left as None keep the corresponding gradient unchanged.
        """

        incoming = (
            ('x', grad_x, wp.vec3),
            ('v', grad_v, wp.vec3),
            ('C', grad_C, wp.mat33),
            ('F', grad_F, wp.mat33),
            ('stress', grad_stress, wp.mat33),
        )
        for attr, tensor, dtype in incoming:
            if tensor is None:
                continue
            getattr(self.particle, attr).grad = wp.from_torch(tensor.contiguous(), dtype=dtype)


class MPMModel(Model):
    """Material Point Method time stepper built from Warp kernels.

    One step scatters particle mass and momentum onto a dense
    ``num_grids``^3 grid (``p2g``), converts momentum to velocity and applies
    gravity and boundary handling (``grid_op``), then gathers grid velocities
    back to the particles (``g2p``).

    ``grid_op`` is not defined on this class. In this file
    ``MPMModelBuilder.finalize`` assigns it to either ``grid_op_freeslip`` or
    ``grid_op_noslip``, so it must be set before any step method is called.
    """

    ConstantType = MPMConstant
    StaticsType = MPMStatics
    StateType = MPMState

    def __init__(self, constant: ConstantType, device: wp.context.Devicelike = None, requires_grad: bool = False) -> None:
        """Store the constants and allocate the grid buffers on ``device``."""
        super().__init__(constant, device)
        self.requires_grad = requires_grad

        # Cubic grid: the same cell count along every axis.
        shape = (self.constant.num_grids, self.constant.num_grids, self.constant.num_grids)
        grid = MPMGridData()
        grid.init(shape, device, requires_grad)
        self.grid = grid

    def forward_extra(self, statics: MPMStatics, state: MPMState, statics_extra: MPMStatics, state_extra: MPMState) -> None:
        """Advance ``state_extra`` using the grid built from ``state``.

        The grid is populated from ``state``'s particles only. ``g2p`` then
        runs over ``state_extra``'s particles with the same particle struct
        passed as both input and output, so ``state_extra`` is updated in
        place. No tape is involved.
        """

        device = self.device
        constant = self.constant
        particle = state.particle
        particle_extra = state_extra.particle
        grid = self.grid

        num_grids = constant.num_grids
        num_particles = particle.x.shape[0]
        num_particles_extra = particle_extra.x.shape[0]

        grid.clear()
        grid.zero_grad()

        wp.launch(self.p2g, dim=num_particles, inputs=[constant, statics, particle, grid], device=device)
        wp.launch(self.grid_op, dim=[num_grids] * 3, inputs=[constant, grid], device=device)
        # particle_extra is both the source and the destination of the gather.
        wp.launch(self.g2p, dim=num_particles_extra, inputs=[constant, statics_extra, particle_extra, particle_extra, grid], device=device)

    def forward(self, statics: MPMStatics, state_curr: MPMState, state_next: MPMState, tape: Optional[Tape] = None) -> None:
        """Run one step from ``state_curr`` and write the result to ``state_next``.

        ``p2g`` and ``grid_op`` are launched outside the tape context; only
        ``g2p`` is launched inside ``CondTape(tape, self.requires_grad)``.
        """

        device = self.device
        constant = self.constant
        particle_curr = state_curr.particle
        particle_next = state_next.particle
        grid = self.grid

        num_grids = constant.num_grids
        num_particles = particle_curr.x.shape[0]

        grid.clear()
        grid.zero_grad()

        wp.launch(self.p2g, dim=num_particles, inputs=[constant, statics, particle_curr, grid], device=device)
        wp.launch(self.grid_op, dim=[num_grids] * 3, inputs=[constant, grid], device=device)

        # NOTE(review): CondTape is defined elsewhere; presumably it records
        # onto `tape` only when requires_grad is true -- confirm.
        with CondTape(tape, self.requires_grad):
            wp.launch(self.g2p, dim=num_particles, inputs=[constant, statics, particle_curr, particle_next, grid], device=device)

    def backward(self, statics: MPMStatics, state_curr: MPMState, state_next: MPMState, tape: Tape) -> None:
        """Back-propagate one step.

        The grid stage (``p2g`` + ``grid_op``) is recomputed from
        ``state_curr`` under a fresh local tape. ``tape.backward()`` runs
        first, then ``local_tape.backward()``. ``state_next`` is not read here.
        """

        device = self.device
        constant = self.constant
        particle_curr = state_curr.particle
        grid = self.grid

        num_grids = constant.num_grids
        num_particles = particle_curr.x.shape[0]

        grid.clear()
        grid.zero_grad()

        # Re-run the grid stage so that it is recorded on a tape of its own.
        local_tape = Tape()
        with local_tape:
            wp.launch(self.p2g, dim=num_particles, inputs=[constant, statics, particle_curr, grid], device=device)
            wp.launch(self.grid_op, dim=[num_grids] * 3, inputs=[constant, grid], device=device)

        # NOTE(review): presumably `tape` holds the g2p launch from forward(),
        # so its backward pass fills the grid gradients consumed by the local
        # tape below -- confirm against the caller.
        tape.backward()

        local_tape.backward()

    # Kernel (one thread per particle): scatter mass and momentum to the grid.
    @staticmethod
    @wp.kernel
    def p2g(
            constant: ConstantType,
            statics: StaticsType,
            particle_curr: MPMParticleData,
            grid: MPMGridData) -> None:

        p = wp.tid()

        # Disabled particles contribute nothing to the grid.
        if statics.enabled[p] == 0:
            return

        p_mass = statics.vol[p] * statics.rho[p]

        # Particle position in grid units, the base cell of its 3x3x3 stencil,
        # and the offset of the particle from that base cell.
        p_x = particle_curr.x[p] * constant.inv_dx
        base_x = int(p_x[0] - 0.5)
        base_y = int(p_x[1] - 0.5)
        base_z = int(p_x[2] - 0.5)
        f_x = p_x - wp.vec3(
            float(base_x),
            float(base_y),
            float(base_z))

        # quadratic kernels  [http://mpm.graphics   Eqn. 123, with x=fx,fx-1,fx-2]
        wa = wp.vec3(1.5) - f_x
        wb = f_x - wp.vec3(1.0)
        wc = f_x - wp.vec3(0.5)

        # wp.mat33(col_vec, col_vec, col_vec)
        w = wp.mat33(
            wp.cw_mul(wa, wa) * 0.5,
            wp.vec3(0.75) - wp.cw_mul(wb, wb),
            wp.cw_mul(wc, wc) * 0.5,
        )

        # Stress term scaled by -dt * vol * 4 * inv_dx^2, plus mass * C.
        stress = (-constant.dt * statics.vol[p] * 4.0 * constant.inv_dx * constant.inv_dx) * particle_curr.stress[p]
        affine = stress + p_mass * particle_curr.C[p]

        for i in range(3):
            for j in range(3):
                for k in range(3):

                    offset = wp.vec3(float(i), float(j), float(k))
                    # Vector from the particle to this grid node, scaled by dx.
                    dpos = (offset - f_x) * constant.dx
                    weight = w[0, i] * w[1, j] * w[2, k]
                    mv = weight * (p_mass * particle_curr.v[p] + affine * dpos)
                    m = weight * p_mass

                    # Several particles can hit the same node, hence atomics.
                    wp.atomic_add(grid.mv, base_x + i, base_y + j, base_z + k, mv)
                    wp.atomic_add(grid.m, base_x + i, base_y + j, base_z + k, m)

    # Kernel (one thread per grid node): momentum -> velocity, add gravity,
    # then zero only the velocity component that points out of the domain
    # within `bound` cells of each face.
    @staticmethod
    @wp.kernel
    def grid_op_freeslip(
            constant: ConstantType,
            grid: MPMGridData) -> None:

        px, py, pz = wp.tid()

        v = wp.vec3(0.0)
        if grid.m[px, py, pz] > 0.0:
            # eps is added to the mass in the denominator.
            v = grid.mv[px, py, pz] / (grid.m[px, py, pz] + constant.eps) + constant.gravity * constant.dt
        else:
            v = constant.gravity * constant.dt

        if px < constant.bound and v[0] < 0.0:
            v = wp.vec3(0.0, v[1], v[2])
        if py < constant.bound and v[1] < 0.0:
            v = wp.vec3(v[0], 0.0, v[2])
        if pz < constant.bound and v[2] < 0.0:
            v = wp.vec3(v[0], v[1], 0.0)
        if px >= constant.num_grids - constant.bound and v[0] > 0.0:
            v = wp.vec3(0.0, v[1], v[2])
        if py >= constant.num_grids - constant.bound and v[1] > 0.0:
            v = wp.vec3(v[0], 0.0, v[2])
        if pz >= constant.num_grids - constant.bound and v[2] > 0.0:
            v = wp.vec3(v[0], v[1], 0.0)

        grid.v[px, py, pz] = v

    # Kernel (one thread per grid node): same as grid_op_freeslip, except that
    # the whole velocity vector is zeroed when a component points out of the
    # domain inside the boundary layer.
    @staticmethod
    @wp.kernel
    def grid_op_noslip(
            constant: ConstantType,
            grid: MPMGridData) -> None:

        px, py, pz = wp.tid()

        v = wp.vec3(0.0)
        if grid.m[px, py, pz] > 0.0:
            # eps is added to the mass in the denominator.
            v = grid.mv[px, py, pz] / (grid.m[px, py, pz] + constant.eps) + constant.gravity * constant.dt
        else:
            v = constant.gravity * constant.dt

        if px < constant.bound and v[0] < 0.0:
            v = wp.vec3(0.0)
        if py < constant.bound and v[1] < 0.0:
            v = wp.vec3(0.0)
        if pz < constant.bound and v[2] < 0.0:
            v = wp.vec3(0.0)
        if px >= constant.num_grids - constant.bound and v[0] > 0.0:
            v = wp.vec3(0.0)
        if py >= constant.num_grids - constant.bound and v[1] > 0.0:
            v = wp.vec3(0.0)
        if pz >= constant.num_grids - constant.bound and v[2] > 0.0:
            v = wp.vec3(0.0)

        grid.v[px, py, pz] = v


    # Kernel (one thread per particle): gather grid velocities, then write
    # v, C, F and x of particle_next from particle_curr and the grid.
    @staticmethod
    @wp.kernel
    def g2p(
            constant: ConstantType,
            statics: StaticsType,
            particle_curr: MPMParticleData,
            particle_next: MPMParticleData,
            grid: MPMGridData) -> None:

        p = wp.tid()

        # Disabled particles are skipped; particle_next[p] is left untouched.
        if statics.enabled[p] == 0:
            return

        # Same stencil set-up as in p2g.
        p_x = particle_curr.x[p] * constant.inv_dx
        base_x = int(p_x[0] - 0.5)
        base_y = int(p_x[1] - 0.5)
        base_z = int(p_x[2] - 0.5)
        f_x = p_x - wp.vec3(
            float(base_x),
            float(base_y),
            float(base_z))

        # quadratic kernels  [http://mpm.graphics   Eqn. 123, with x=fx,fx-1,fx-2]
        wa = wp.vec3(1.5) - f_x
        wb = f_x - wp.vec3(1.0)
        wc = f_x - wp.vec3(0.5)

        # wp.mat33(col_vec, col_vec, col_vec)
        w = wp.mat33(
            wp.cw_mul(wa, wa) * 0.5,
            wp.vec3(0.75) - wp.cw_mul(wb, wb),
            wp.cw_mul(wc, wc) * 0.5,
        )

        # Accumulators, both starting at zero.
        new_v = wp.vec3(0.0)
        new_C = wp.mat33(new_v, new_v, new_v)

        for i in range(3):
            for j in range(3):
                for k in range(3):

                    offset = wp.vec3(float(i), float(j), float(k))
                    dpos = (offset - f_x) * constant.dx
                    weight = w[0, i] * w[1, j] * w[2, k]

                    v = grid.v[base_x + i, base_y + j, base_z + k]
                    new_v = new_v + weight * v
                    new_C = new_C + (4.0 * weight * constant.inv_dx * constant.inv_dx) * wp.outer(v, dpos)

        # to avoid the not implemented adj_mat33 for mat33 plain constructor
        I33_1 = wp.vec3(1.0, 0.0, 0.0)
        I33_2 = wp.vec3(0.0, 1.0, 0.0)
        I33_3 = wp.vec3(0.0, 0.0, 1.0)
        I33 = wp.mat33(I33_1, I33_2, I33_3)
        particle_next.v[p] = new_v
        particle_next.C[p] = new_C
        # F_next = (I + dt * C_next) * F_curr
        particle_next.F[p] = (I33 + constant.dt * new_C) * particle_curr.F[p]

        # Advance the position, then clamp every coordinate to
        # [bound, 1 - bound] with bound = clip_bound * dx.
        bound = statics.clip_bound[p] * constant.dx
        new_x = particle_curr.x[p] + constant.dt * new_v
        new_x = wp.vec3(
            wp.clamp(new_x[0], 0.0 + bound, 1.0 - bound),
            wp.clamp(new_x[1], 0.0 + bound, 1.0 - bound),
            wp.clamp(new_x[2], 0.0 + bound, 1.0 - bound),
        )
        particle_next.x[p] = new_x


class MPMModelBuilder(ModelBuilder):
    """Builds an ``MPMModel`` from a config: parse, fill constants, pick the grid kernel."""

    StateType = MPMState
    ConstantType = MPMConstant
    ModelType = MPMModel

    def parse_cfg(self, cfg: DictConfig) -> 'MPMModelBuilder':
        """Copy simulation settings from ``cfg`` into ``self.config`` and return ``self``.

        ``dx`` and ``inv_dx`` are derived from ``num_grids``.
        """

        grid_count: int = cfg.num_grids
        step_size: float = cfg.dt
        border: int = cfg.bound
        gravity_vec: np.ndarray = np.array(cfg.gravity, dtype=np.float32)
        boundary: str = cfg.bc
        epsilon: float = cfg.eps

        parsed = {
            'num_grids': grid_count,
            'dt': step_size,
            'bound': border,
            'gravity': gravity_vec,
            'dx': 1 / grid_count,
            'inv_dx': float(grid_count),
            'bc': boundary,
            'eps': epsilon,
        }
        for key, value in parsed.items():
            self.config[key] = value

        return self

    def build_constant(self) -> ConstantType:
        """Create the constant struct and fill it from ``self.config``."""

        constant = super().build_constant()
        for name in ('num_grids', 'dt', 'bound'):
            setattr(constant, name, self.config[name])
        # Gravity is stored as an array in the config and unpacked into a vec3.
        constant.gravity = wp.vec3(*self.config['gravity'])
        for name in ('dx', 'inv_dx', 'eps'):
            setattr(constant, name, self.config[name])

        return constant

    def finalize(self, device: wp.context.Devicelike = None, requires_grad: bool = False) -> ModelType:
        """Build the model and attach the grid kernel selected by ``bc``.

        Raises:
            ValueError: if ``bc`` is neither ``'freeslip'`` nor ``'noslip'``.
        """
        model = super().finalize(device, requires_grad)
        boundary = self.config['bc']
        if boundary not in ('freeslip', 'noslip'):
            raise ValueError('invalid boundary condition: {}'.format(boundary))
        model.grid_op = model.grid_op_freeslip if boundary == 'freeslip' else model.grid_op_noslip
        return model


@dataclass
class MPMInitData(object):
    """Initial data for one group of MPM particles loaded from a point cloud.

    ``get`` / ``get_pcd`` read particle positions from ``<name>.ply`` (or a
    cached ``<name>.npz``), map them from ``ori_bounds`` to ``sim_bounds`` and
    return the fields needed to construct an instance.
    """

    rho: float
    clip_bound: float
    span: tuple[int, int]

    num_particles: int
    vol: float

    pos: np.ndarray
    # default_factory gives every instance its own zero vector. A bare
    # ``np.zeros(3)`` default is shared by all instances, and Python 3.11+
    # rejects it with ValueError because ndarray is unhashable.
    lin_vel: np.ndarray = field(default_factory=lambda: np.zeros(3))
    ang_vel: np.ndarray = field(default_factory=lambda: np.zeros(3))
    center: Optional[np.ndarray] = None
    ind_vel: Optional[np.ndarray] = None
    bounds: Optional[np.ndarray] = None
    size: Optional[np.ndarray] = None

    def __post_init__(self) -> None:
        # Fall back to the mean particle position when no centre is supplied.
        if self.center is None:
            self.center = self.pos.mean(0)

    @staticmethod
    def alignment(
        min_bound_1: np.ndarray,
        max_bound_1: np.ndarray,
        min_bound_2: np.ndarray,
        max_bound_2: np.ndarray
    ):
        """Calculate the translation and scale factor to transform bound1 to bound2.

        Returns ``(scale_factor, translation)`` such that
        ``p * scale_factor + translation`` maps box 1 onto box 2 per axis.
        """
        # calculate the center of the bounding box
        center_1 = (min_bound_1 + max_bound_1) / 2
        center_2 = (min_bound_2 + max_bound_2) / 2

        # calculate the scale factor along each axis
        scale_factor = (max_bound_2 - min_bound_2) / (max_bound_1 - min_bound_1)

        # calculate the translation
        translation = center_2 - center_1 * scale_factor

        return scale_factor, translation

    @classmethod
    def get(cls, cfg: DictConfig) -> 'MPMInitData':
        """Build an instance from a config with ``shape``, ``rho``, ``clip_bound`` and ``span``."""
        kwargs = cls.get_pcd(
            cfg.shape.name,
            cfg.shape.asset_root,
            cfg.shape.sort,
            cfg.shape.get('ori_bounds'),
            cfg.shape.get('sim_bounds'),
        )
        return cls(rho=cfg.rho, clip_bound=cfg.clip_bound, span=cfg.span, **kwargs)

    @classmethod
    def get_pcd(
            cls,
            name: str,
            asset_root: Optional[str],
            sort: Optional[int] = None,
            ori_bounds: Optional[list] = None,
            sim_bounds: Optional[list] = None
        ) -> dict[str, Any]:
        """Load a point cloud and return constructor keyword arguments.

        Args:
            name: asset name; ``<name>.ply`` and ``<name>.npz`` are looked up
                under ``asset_root``.
            asset_root: asset directory; when None a default directory
                relative to this file is used.
            sort: if given, points are ordered by descending coordinate along
                this axis (only when the cache file does not exist yet).
            ori_bounds: ``[min, max]`` bounds of the source data (required).
            sim_bounds: ``[min, max]`` target bounds (required).

        Returns:
            dict with ``num_particles``, ``vol``, ``pos``, ``center`` and ``size``.

        Raises:
            AssertionError: if a bound is missing or a transformed point lies
                outside ``[0, 1]`` on any axis.
        """

        if ori_bounds is not None:
            ori_bounds = np.array(ori_bounds)
        if sim_bounds is not None:
            sim_bounds = np.array(sim_bounds)

        assert ori_bounds is not None, "ori_bounds must be provided for pcd shape."
        assert sim_bounds is not None, "sim_bounds must be provided for pcd shape."

        if asset_root is None:
            asset_root = Path(__file__).resolve().parent.parent.parent.parent / "experiments" / 'assets'
        else:
            asset_root = Path(asset_root)
        precompute_name = f'{name}.npz'

        if (asset_root / precompute_name).is_file():
            # Reuse the cached (already sorted) points and per-particle volume.
            file = np.load(asset_root / precompute_name)
            p_x = file['p_x']
            vol = file['vol']
        else:
            import trimesh

            pcd_path = (asset_root / f'{name}.ply').as_posix()
            pcd: trimesh.PointCloud = trimesh.load(pcd_path)
            p_x = np.array(pcd.vertices).copy()

            if sort is not None:
                # Descending order along axis `sort`. A stable sort on the
                # negated column keeps tied points in their original order,
                # exactly like sorted(range(n), key=..., reverse=True).
                indices = np.argsort(-p_x[:, sort], kind='stable')
                p_x = p_x[indices]

            mesh_path = list(Path(pcd_path).parent.glob("mesh.*"))
            if len(mesh_path) == 1:
                # NOTE(review): `force_mask=True` is forwarded to the loader as
                # an extra keyword; it looks like a typo for `force='mesh'` --
                # confirm before changing, since it affects the returned type.
                mesh: trimesh.Trimesh = trimesh.load(mesh_path[0].as_posix(), force_mask=True)
                vol = mesh.volume / p_x.shape[0]
            else:
                vol = pcd.convex_hull.volume / p_x.shape[0]
                print(f'  WARNING: mesh file not found, using convex hull volume.')
            np.savez(asset_root / precompute_name, p_x=p_x, vol=vol)

        bbmin = ori_bounds[0]
        bbmax = ori_bounds[1]
        sim_bbmin = sim_bounds[0]
        sim_bbmax = sim_bounds[1]
        # NOTE(review): `center` here is the alignment translation, not the
        # geometric centre of the particles, yet it is returned as `center` --
        # confirm that this is what consumers of `center` expect.
        size, center = cls.alignment(bbmin, bbmax, sim_bbmin, sim_bbmax)

        # Volume scales with the product of the per-axis scale factors.
        vol = vol * np.prod(size)
        p_x = p_x * size + center
        p_x = np.ascontiguousarray(p_x.reshape(-1, 3))
        x_min, x_max = p_x[:, 0].min(), p_x[:, 0].max()
        y_min, y_max = p_x[:, 1].min(), p_x[:, 1].max()
        z_min, z_max = p_x[:, 2].min(), p_x[:, 2].max()
        print(f'  [pcd] | num_points: {p_x.shape[0]}')
        print(f"  x: [{x_min}, {x_max}]")
        print(f"  y: [{y_min}, {y_max}]")
        print(f"  z: [{z_min}, {z_max}]")

        # All particles must lie inside the unit cube.
        assert x_min >= 0.0 and x_max <= 1.0
        assert y_min >= 0.0 and y_max <= 1.0
        assert z_min >= 0.0 and z_max <= 1.0

        return dict(num_particles=p_x.shape[0], vol=vol, pos=p_x, center=center, size=size)

    def set_lin_vel(self, value: Union[list, np.ndarray]) -> None:
        """Replace the linear velocity with a new array built from ``value``."""
        self.lin_vel = np.array(value)

    def zero_lin_vel(self) -> None:
        """Reset the linear velocity to zeros of the same shape."""
        self.set_lin_vel(np.zeros_like(self.lin_vel))

    def set_ang_vel(self, value: Union[list, np.ndarray]) -> None:
        """Replace the angular velocity with a new array built from ``value``."""
        self.ang_vel = np.array(value)

    def zero_ang_vel(self) -> None:
        """Reset the angular velocity to zeros of the same shape."""
        self.set_ang_vel(np.zeros_like(self.ang_vel))

    def set_ind_vel(self, ind_vel: np.ndarray) -> None:
        """Set explicit velocities; stored as a new array built from ``ind_vel``."""
        self.ind_vel = np.array(ind_vel)


class MPMStateInitializer(StateInitializer):
    """Collects particle groups and builds the initial ``MPMState`` from them."""

    StateType = MPMState
    ModelType = MPMModel

    def __init__(self, model: ModelType) -> None:
        super().__init__(model)
        self.groups: list[MPMInitData] = []

    def add_group(self, group: MPMInitData) -> None:
        """Register one particle group; groups are concatenated in insertion order."""
        self.groups.append(group)

    def finalize(self) -> tuple[StateType, list[int]]:
        """Create the initial state.

        Returns:
            The state with positions and velocities assigned, and the list of
            particle counts per group.
        """

        positions = []
        velocities = []
        sections = []

        for group in self.groups:
            pos = group.pos.copy()

            if group.ind_vel is not None:
                # Explicit velocities take precedence over the rigid motion.
                vel = group.ind_vel.copy()
            else:
                # Rigid motion: linear part plus angular part about the centre.
                spin = np.cross(group.ang_vel.copy(), pos - group.center)
                vel = group.lin_vel.copy() + spin

            positions.append(pos)
            velocities.append(vel)
            sections.append(group.num_particles)

        all_pos = np.concatenate(positions, axis=0)
        all_vel = np.concatenate(velocities, axis=0)

        state_0 = super().finalize(shape=all_pos.shape[0], requires_grad=False)

        state_0.particle.x.assign(all_pos)
        state_0.particle.v.assign(all_vel)

        return state_0, sections


class MPMStaticsInitializer(StaticsInitializer):
    """Collects particle groups and builds the ``MPMStatics`` arrays from them."""

    StaticsType = MPMStatics
    ModelType = MPMModel

    def __init__(self, model: ModelType) -> None:
        super().__init__(model)
        self.groups: list[MPMInitData] = []

        # Per-group values, filled in by finalize() in group order.
        self.sections: list[int] = []
        self.vols: list[float] = []
        self.rhos: list[float] = []
        self.clip_bounds: list[float] = []
        self.spans: list[tuple[int, int]] = []

    def update(self, statics: StaticsType, step: int = 0) -> None:
        """Refresh the per-particle ``enabled`` flags for the given step."""

        statics.update_enabled(self.sections, self.spans, step=step)

    def add_group(self, group: MPMInitData) -> None:
        """Register one particle group; it takes effect on the next finalize()."""
        self.groups.append(group)

    def finalize(self) -> StaticsType:
        """Allocate and fill the statics for all registered groups.

        The per-group lists are rebuilt on every call. Previously they were
        only appended to, so calling finalize() twice duplicated every group
        and allocated an array twice as long as the particle count.
        """

        for values in (self.sections, self.vols, self.rhos, self.clip_bounds, self.spans):
            values.clear()

        for group in self.groups:
            self.sections.append(group.num_particles)

            self.vols.append(group.vol)
            self.rhos.append(group.rho)
            self.clip_bounds.append(group.clip_bound)
            self.spans.append(group.span)

        statics = super().finalize(shape=sum(self.sections))
        statics.update_vol(self.sections, self.vols)
        statics.update_rho(self.sections, self.rhos)
        statics.update_clip_bound(self.sections, self.clip_bounds)
        self.update(statics, step=0)

        return statics


@dataclass
class MPMInitDataGaussianOld(object):  # TODO(review): legacy loader; original marker read "TODO!! 526"
    """Legacy initial data for a particle group taken from Gaussian kernels.

    Particle positions come from ``<object>/kernels.ply`` loaded through
    ``GaussianModel``; ``<name>.ply`` is loaded only for its convex hull.
    The per-particle volume is currently hard-coded (see ``get_pcd``).
    """

    rho: float
    clip_bound: float
    span: tuple[int, int]

    num_particles: int
    vol: float

    pos: np.ndarray
    # default_factory gives every instance its own zero vector. A bare
    # ``np.zeros(3)`` default is shared by all instances, and Python 3.11+
    # rejects it with ValueError because ndarray is unhashable.
    lin_vel: np.ndarray = field(default_factory=lambda: np.zeros(3))
    ang_vel: np.ndarray = field(default_factory=lambda: np.zeros(3))
    center: Optional[np.ndarray] = None
    ind_vel: Optional[np.ndarray] = None
    bounds: Optional[np.ndarray] = None
    size: Optional[np.ndarray] = None

    def __post_init__(self) -> None:
        # Fall back to the mean particle position when no centre is supplied.
        if self.center is None:
            self.center = self.pos.mean(0)

    @staticmethod
    def alignment(
        min_bound_1: np.ndarray,
        max_bound_1: np.ndarray,
        min_bound_2: np.ndarray,
        max_bound_2: np.ndarray
    ):
        """Calculate the translation and scale factor to transform bound1 to bound2.

        Returns ``(scale_factor, translation)`` such that
        ``p * scale_factor + translation`` maps box 1 onto box 2 per axis.
        """
        # calculate the center of the bounding box
        center_1 = (min_bound_1 + max_bound_1) / 2
        center_2 = (min_bound_2 + max_bound_2) / 2

        # calculate the scale factor along each axis
        scale_factor = (max_bound_2 - min_bound_2) / (max_bound_1 - min_bound_1)

        # calculate the translation
        translation = center_2 - center_1 * scale_factor

        return scale_factor, translation

    @classmethod
    def get(cls, cfg: DictConfig) -> 'MPMInitDataGaussianOld':
        """Build an instance from a config with ``shape``, ``rho``, ``clip_bound`` and ``span``."""
        kwargs = cls.get_pcd(
            cfg.shape.name,
            cfg.shape.asset_root,
            cfg.shape.sort,
            cfg.shape.get('ori_bounds'),
            cfg.shape.get('sim_bounds'),
        )
        return cls(rho=cfg.rho, clip_bound=cfg.clip_bound, span=cfg.span, **kwargs)

    @classmethod
    def get_pcd(
            cls,
            name: str,
            asset_root: Optional[str],
            sort: Optional[int] = None,
            ori_bounds: Optional[list] = None,
            sim_bounds: Optional[list] = None
        ) -> dict[str, Any]:
        """Load Gaussian kernel centres and return constructor keyword arguments.

        Unlike ``MPMInitData.get_pcd`` this never reads the ``.npz`` cache; it
        always reloads the assets and overwrites ``<name>.npz``.

        Args:
            name: asset name; the part before the first ``/`` selects the
                directory holding ``kernels.ply``.
            asset_root: asset directory; when None a default directory
                relative to this file is used.
            sort: if given, points are ordered by descending coordinate along
                this axis.
            ori_bounds: ``[min, max]`` bounds of the source data (required).
            sim_bounds: ``[min, max]`` target bounds (required).

        Returns:
            dict with ``num_particles``, ``vol``, ``pos``, ``center`` and ``size``.

        Raises:
            AssertionError: if a bound is missing or a transformed point lies
                outside ``[0, 1]`` on any axis.
        """

        if ori_bounds is not None:
            ori_bounds = np.array(ori_bounds)
        if sim_bounds is not None:
            sim_bounds = np.array(sim_bounds)

        assert ori_bounds is not None, "ori_bounds must be provided for pcd shape."
        assert sim_bounds is not None, "sim_bounds must be provided for pcd shape."

        if asset_root is None:
            asset_root = Path(__file__).resolve().parent.parent.parent.parent / "experiments" / 'assets'
        else:
            asset_root = Path(asset_root)
        precompute_name = f'{name}.npz'

        import trimesh

        pcd_path = (asset_root / f'{name}.ply').as_posix()
        pcd: trimesh.PointCloud = trimesh.load(pcd_path)

        # Particle positions come from the Gaussian kernels, not from `pcd`.
        object_name = name.split('/')[0]
        gaussians_path = (asset_root / object_name / 'kernels.ply').as_posix()
        gaussians = GaussianModel(3)
        gaussians.load_ply(gaussians_path, requires_grad=False)

        p_x = gaussians.get_xyz.cpu().numpy()
        if sort is not None:
            # Descending order along axis `sort`. A stable sort on the negated
            # column keeps tied points in their original order, exactly like
            # sorted(range(n), key=..., reverse=True).
            indices = np.argsort(-p_x[:, sort], kind='stable')
            p_x = p_x[indices]

        # NOTE(review): the convex-hull estimate is computed and immediately
        # replaced by a hard-coded per-particle volume, so the warning printed
        # below does not describe what is actually used. Kept as-is to
        # preserve existing results -- confirm which value is intended.
        vol = pcd.convex_hull.volume / p_x.shape[0]

        vol = 1e-6

        print("vol", vol)
        print(f'  WARNING: mesh file not found, using convex hull volume.')

        np.savez(asset_root / precompute_name, p_x=p_x, vol=vol)

        bbmin = ori_bounds[0]
        bbmax = ori_bounds[1]
        sim_bbmin = sim_bounds[0]
        sim_bbmax = sim_bounds[1]
        # NOTE(review): `center` here is the alignment translation, not the
        # geometric centre of the particles -- confirm that this is intended.
        size, center = cls.alignment(bbmin, bbmax, sim_bbmin, sim_bbmax)

        print("size", size)
        print("center", center)

        # Volume scales with the product of the per-axis scale factors.
        vol = vol * np.prod(size)
        print("vol_new", vol)

        p_x = p_x * size + center
        p_x = np.ascontiguousarray(p_x.reshape(-1, 3))
        x_min, x_max = p_x[:, 0].min(), p_x[:, 0].max()
        y_min, y_max = p_x[:, 1].min(), p_x[:, 1].max()
        z_min, z_max = p_x[:, 2].min(), p_x[:, 2].max()
        print(f'  [pcd] | num_points: {p_x.shape[0]}')
        print(f"  x: [{x_min}, {x_max}]")
        print(f"  y: [{y_min}, {y_max}]")
        print(f"  z: [{z_min}, {z_max}]")

        # All particles must lie inside the unit cube.
        assert x_min >= 0.0 and x_max <= 1.0
        assert y_min >= 0.0 and y_max <= 1.0
        assert z_min >= 0.0 and z_max <= 1.0

        return dict(num_particles=p_x.shape[0], vol=vol, pos=p_x, center=center, size=size)

    def set_lin_vel(self, value: Union[list, np.ndarray]) -> None:
        """Replace the linear velocity with a new array built from ``value``."""
        self.lin_vel = np.array(value)

    def zero_lin_vel(self) -> None:
        """Reset the linear velocity to zeros of the same shape."""
        self.set_lin_vel(np.zeros_like(self.lin_vel))

    def set_ang_vel(self, value: Union[list, np.ndarray]) -> None:
        """Replace the angular velocity with a new array built from ``value``."""
        self.ang_vel = np.array(value)

    def zero_ang_vel(self) -> None:
        """Reset the angular velocity to zeros of the same shape."""
        self.set_ang_vel(np.zeros_like(self.ang_vel))

    def set_ind_vel(self, ind_vel: np.ndarray) -> None:
        """Set explicit velocities; stored as a new array built from ``ind_vel``."""
        self.ind_vel = np.array(ind_vel)



@dataclass
class MPMInitDataGaussian(object):
    """Initial particle data for one MPM group taken from a Gaussian model.

    Attributes:
        rho: Per-group value the statics initializer passes to ``update_rho``.
        clip_bound: Per-group value the statics initializer passes to
            ``update_clip_bound``.
        span: Per-group pair the statics initializer passes to
            ``update_enabled`` (presumably the step interval in which the
            group is active -- confirm against ``MPMStatics``).
        num_particles: Number of rows in ``pos``.
        vol: Volume assigned to every particle of the group.
        pos: Particle positions, one row per particle.
        lin_vel: Linear velocity of the group; defaults to ``np.zeros(3)``.
        ang_vel: Angular velocity of the group; defaults to ``np.zeros(3)``.
        center: Reference point; defaults to the mean of ``pos``.
        size: Per-axis scale factor used to map the particles, if any.
        ind_vel: Optional explicit velocity array; when set, the state
            initializer uses it instead of ``lin_vel`` / ``ang_vel``.
    """

    rho: float
    clip_bound: float
    span: tuple[int, int]
    num_particles: int
    vol: float
    pos: np.ndarray
    # ``None`` is replaced by a fresh ``np.zeros(3)`` in ``__post_init__``.
    # A literal ``np.zeros(3)`` default would be one array shared by every
    # instance, and Python 3.11+ dataclasses reject it (unhashable default).
    lin_vel: Optional[np.ndarray] = None
    ang_vel: Optional[np.ndarray] = None
    center: Optional[np.ndarray] = None
    size: Optional[np.ndarray] = None
    ind_vel: Optional[np.ndarray] = None

    def __post_init__(self) -> None:
        """Fill in per-instance defaults for fields left as ``None``."""
        if self.lin_vel is None:
            self.lin_vel = np.zeros(3)
        if self.ang_vel is None:
            self.ang_vel = np.zeros(3)
        if self.center is None:
            self.center = self.pos.mean(0)

    @staticmethod
    def alignment(
        min_bound_1: np.ndarray,
        max_bound_1: np.ndarray,
        min_bound_2: np.ndarray,
        max_bound_2: np.ndarray
    ) -> tuple[np.ndarray, np.ndarray]:
        """Compute the per-axis affine map that sends box 1 onto box 2.

        Args:
            min_bound_1: Minimum corner of the source box.
            max_bound_1: Maximum corner of the source box.
            min_bound_2: Minimum corner of the target box.
            max_bound_2: Maximum corner of the target box.

        Returns:
            ``(scale_factor, translation)`` such that
            ``x * scale_factor + translation`` maps the source box onto the
            target box. A source box with zero extent on any axis yields
            non-finite values (NumPy division by zero).
        """
        center_1 = (min_bound_1 + max_bound_1) / 2
        center_2 = (min_bound_2 + max_bound_2) / 2
        scale_factor = (max_bound_2 - min_bound_2) / (max_bound_1 - min_bound_1)
        translation = center_2 - center_1 * scale_factor
        return scale_factor, translation

    @classmethod
    def get_from_gaussians(cls, cfg: "DictConfig", gaussians: "GaussianModel") -> 'MPMInitDataGaussian':
        """Build the init data from a Gaussian model using bounds from ``cfg``.

        Args:
            cfg: Config providing ``shape.ori_bounds`` and ``shape.sim_bounds``
                (each a ``[min, max]`` pair of corners) plus ``rho``,
                ``clip_bound`` and ``span``.
            gaussians: Object whose ``get_xyz`` attribute is a tensor of
                particle positions.

        Returns:
            A new instance whose positions are mapped from ``ori_bounds`` to
            ``sim_bounds`` and clipped to ``sim_bounds``.
        """
        # detach() first: Tensor.numpy() refuses tensors that require grad.
        p_x = np.ascontiguousarray(gaussians.get_xyz.detach().cpu().numpy())

        ori_bounds = np.array(cfg.shape.ori_bounds)
        sim_bounds = np.array(cfg.shape.sim_bounds)

        scale, translation = cls.alignment(
            ori_bounds[0], ori_bounds[1],
            sim_bounds[0], sim_bounds[1]
        )

        # Map into simulation space, then clamp stragglers onto the bounds.
        p_x = p_x * scale + translation
        p_x = np.clip(p_x, sim_bounds[0], sim_bounds[1])

        # Placeholder per-particle volume, rescaled by the volume change of
        # the alignment transform.
        vol_per_particle = 1e-6 * np.prod(scale)

        # NOTE(review): ``center`` is the alignment translation, not the
        # particle centroid, yet MPMStateInitializerGaussian.finalize uses
        # ``group.center`` as the pivot for ``ang_vel`` -- confirm intended.
        kwargs = {
            "num_particles": p_x.shape[0],
            "vol": vol_per_particle,
            "pos": p_x,
            "center": translation,
            "size": scale
        }

        print(f'  [pcd] | Initialized {kwargs["num_particles"]} particles using explicit bounds.')

        return cls(rho=cfg.rho, clip_bound=cfg.clip_bound, span=cfg.span, **kwargs)

    def set_lin_vel(self, value: Union[list, np.ndarray]) -> None:
        """Store ``value`` as the linear velocity (copied into a new array)."""
        self.lin_vel = np.array(value)

    def zero_lin_vel(self) -> None:
        """Reset the linear velocity to zeros of the same shape."""
        self.set_lin_vel(np.zeros_like(self.lin_vel))

    def set_ang_vel(self, value: Union[list, np.ndarray]) -> None:
        """Store ``value`` as the angular velocity (copied into a new array)."""
        self.ang_vel = np.array(value)

    def zero_ang_vel(self) -> None:
        """Reset the angular velocity to zeros of the same shape."""
        self.set_ang_vel(np.zeros_like(self.ang_vel))

    def set_ind_vel(self, ind_vel: np.ndarray) -> None:
        """Store an explicit velocity array (copied into a new array)."""
        self.ind_vel = np.array(ind_vel)



  
        
class MPMStateInitializerGaussian(StateInitializer):
    """Builds the initial particle state from ``MPMInitDataGaussian`` groups."""

    StateType = MPMState
    ModelType = MPMModel

    def __init__(self, model: ModelType) -> None:
        super().__init__(model)
        self.groups: list[MPMInitDataGaussian] = []

    def add_group(self, group: MPMInitDataGaussian) -> None:
        """Register a particle group; groups are laid out in insertion order."""
        self.groups.append(group)

    def finalize(self) -> tuple[StateType, list[int]]:
        """Concatenate all groups into one state.

        Returns:
            ``(state, sections)`` where ``state`` holds the stacked positions
            and velocities of every group and ``sections`` lists each group's
            ``num_particles`` in the same order.
        """
        positions = []
        velocities = []
        counts = []

        for grp in self.groups:
            points = grp.pos.copy()

            if grp.ind_vel is not None:
                # An explicit velocity array overrides the rigid-motion field.
                velocity = grp.ind_vel.copy()
            else:
                # Velocity field: linear part plus ang_vel x (pos - center).
                linear = grp.lin_vel.copy()
                angular = grp.ang_vel.copy()
                velocity = linear + np.cross(angular, points - grp.center)

            positions.append(points)
            velocities.append(velocity)
            counts.append(grp.num_particles)

        stacked_pos = np.concatenate(positions, axis=0)
        stacked_vel = np.concatenate(velocities, axis=0)

        state = super().finalize(shape=stacked_pos.shape[0], requires_grad=False)
        state.particle.x.assign(stacked_pos)
        state.particle.v.assign(stacked_vel)

        return state, counts


class MPMStaticsInitializerGaussian(StaticsInitializer):
    """Builds ``MPMStatics`` from a list of ``MPMInitDataGaussian`` groups.

    The per-group lists (``sections``, ``vols``, ``rhos``, ``clip_bounds``,
    ``spans``) are derived from ``groups`` when ``finalize`` runs and are
    reused by ``update`` afterwards.
    """

    StaticsType = MPMStatics
    ModelType = MPMModel

    def __init__(self, model: ModelType) -> None:
        super().__init__(model)
        self.groups: list[MPMInitDataGaussian] = []

        self.sections: list[int] = []
        self.vols: list[float] = []
        self.rhos: list[float] = []
        self.clip_bounds: list[float] = []
        self.spans: list[tuple[int, int]] = []

    def update(self, statics: StaticsType, step: int = 0) -> None:
        """Refresh the ``enabled`` data of ``statics`` for the given ``step``."""
        statics.update_enabled(self.sections, self.spans, step=step)

    def add_group(self, group: MPMInitDataGaussian) -> None:
        """Register a particle group; groups are laid out in insertion order."""
        self.groups.append(group)

    def finalize(self) -> StaticsType:
        """Allocate the statics and fill them from the registered groups.

        Returns:
            A statics object sized for the total particle count, with volume,
            density and clip bound written per group and ``enabled`` updated
            for step 0.
        """
        # Rebuild the per-group lists from scratch on every call. Appending
        # to the existing lists would double-count groups (and double the
        # allocated shape) if finalize() were invoked more than once.
        self.sections = [group.num_particles for group in self.groups]
        self.vols = [group.vol for group in self.groups]
        self.rhos = [group.rho for group in self.groups]
        self.clip_bounds = [group.clip_bound for group in self.groups]
        self.spans = [group.span for group in self.groups]

        statics = super().finalize(shape=sum(self.sections))
        statics.update_vol(self.sections, self.vols)
        statics.update_rho(self.sections, self.rhos)
        statics.update_clip_bound(self.sections, self.clip_bounds)
        self.update(statics, step=0)

        return statics
