import h5py
import torch

# Time step; the loaders below multiply velocity channels / gravity by it.
DT = 0.0025

# Number of leading feature channels the loaders treat as positions.
DIM = 2

# Type id given to the particles of each material (boundary entries get 0).
MATERIALS = dict(
    goop=7,  # Special.
    sand=6,  # Sand.
    water=5,  # Water.
)


def load_monomaterial(path, material):
    """Load a single-material simulation from an HDF5 file.

    Reads the ``"particles"`` and ``"boundary"`` datasets, joins them along
    axis 1 and rebuilds the trailing feature channels as differences between
    consecutive entries along axis 0.

    Args:
        path: Location of the HDF5 file to read.
        material: Key of ``MATERIALS`` naming the particle type.

    Returns:
        A tuple ``(states, types, None)``. ``states`` drops the first entry
        along axis 0 and holds the first ``DIM`` channels followed by their
        finite differences. ``types`` is a 1-D float tensor holding the
        material id for every particle followed by ``0`` for every boundary
        entry.

    Raises:
        KeyError: If ``material`` is not a key of ``MATERIALS``.
    """
    with h5py.File(path, "r") as handle:
        raw_particles = handle["particles"][()]
        raw_boundary = handle["boundary"][()]

    particles = torch.tensor(raw_particles, dtype=torch.float32)
    boundary = torch.tensor(raw_boundary, dtype=torch.float32)

    # Scale the stored trailing channels by the time step.
    particles[..., DIM:].mul_(DT)

    # Boundary entries carry zeros in the trailing channels; when the dataset
    # does not already have 2 * DIM channels, zeros of the same width are
    # appended instead.
    if boundary.shape[-1] == 2 * DIM:
        boundary[..., DIM:].zero_()
    else:
        padding = torch.zeros_like(boundary)
        boundary = torch.cat((boundary, padding), dim=-1)

    merged = torch.cat([particles, boundary], dim=1)

    particle_types = torch.ones(particles.shape[1]) * MATERIALS[material]
    boundary_types = torch.zeros(boundary.shape[1])
    types = torch.cat([particle_types, boundary_types])

    # Only the first DIM channels survive from here on: the trailing channels
    # are replaced by differences of consecutive entries along axis 0.
    positions = merged[..., :DIM]
    velocities = positions[1:] - positions[:-1]
    states = torch.cat((positions[1:], velocities), dim=-1)

    return states, types, None


def load_WBC(path, material=5):
    """Load a simulation together with its gravity data from an HDF5 file.

    Reads the ``"particles"``, ``"boundary"`` and ``"grav"`` datasets, joins
    particles and boundary along axis 1, clamps the first ``DIM`` channels
    into ``[LOW, HIGH]`` and rebuilds the trailing channels as differences
    between consecutive entries along axis 0.

    Args:
        path: Location of the HDF5 file to read.
        material: Either a key of ``MATERIALS`` (e.g. ``"water"``) or one of
            the ids stored in it. The default ``5`` is the id registered for
            ``"water"``.

    Returns:
        A tuple ``(states, types, grav)``. ``states`` drops the first entry
        along axis 0 and holds the clamped first ``DIM`` channels followed by
        their finite differences. ``types`` is a 1-D float tensor holding the
        material id for every particle followed by ``0`` for every boundary
        entry. ``grav`` is the ``"grav"`` dataset multiplied by ``DT``.

    Raises:
        KeyError: If ``material`` is neither a key nor a value of
            ``MATERIALS``.
    """
    LOW = 0.0025
    HIGH = 0.9975

    # The default argument is a numeric id while MATERIALS is keyed by name,
    # so a plain MATERIALS[material] lookup fails for the default. Accept
    # both spellings and keep raising KeyError for anything unknown.
    if material in MATERIALS:
        material_id = MATERIALS[material]
    elif material in MATERIALS.values():
        material_id = material
    else:
        raise KeyError(material)

    with h5py.File(path, "r") as f:
        particles = f["particles"][()]
        boundary = f["boundary"][()]
        grav = f["grav"][()]

    particles = torch.tensor(particles, dtype=torch.float32)
    boundary = torch.tensor(boundary, dtype=torch.float32)
    grav = torch.tensor(grav, dtype=torch.float32)

    # Boundary entries carry zeros in the trailing channels; gravity is
    # scaled by the time step.
    boundary[..., DIM:].zero_()
    grav = grav * DT

    current_sim = torch.cat([particles, boundary], dim=1)
    current_types = torch.cat(
        [
            torch.ones(particles.shape[1]) * material_id,
            torch.zeros(boundary.shape[1]),
        ]
    )

    # Out-of-range values in the first DIM channels are clamped (not
    # rescaled) into [LOW, HIGH].
    current_sim[..., :DIM] = torch.clamp(current_sim[..., :DIM], LOW, HIGH)

    # The differences are taken after clamping, so they describe the clamped
    # values; the stored trailing channels are discarded.
    vel = current_sim[1:, ..., :DIM] - current_sim[:-1, ..., :DIM]
    current_sim = current_sim[1:, ..., :DIM]
    current_sim = torch.cat((current_sim, vel), dim=-1)

    return (
        current_sim,
        current_types,
        grav,
    )



def load_multimaterial(path, material=5):
    """Placeholder for a multi-material loader; currently does nothing.

    Both arguments are accepted and ignored, and ``None`` is returned.
    """
    return None
