import torch
import torch.nn.functional as F
from torch_cluster import grid_cluster
from torch_scatter import scatter
from neural_mpm.util.metaparams import LOW, HIGH


def linear(x, y, r=0.015):
    """Return ``1 - prod(|x - y|) / r`` reduced over the last axis.

    The product of the per-axis absolute offsets is taken along the last
    dimension, so batched inputs of shape (..., D) give weights of shape (...).

    NOTE(review): despite the name this is not a tent kernel. The weight is not
    clamped, so it can go negative. It is exactly 1 whenever the two points
    share any coordinate, however far apart they are — confirm this is intended.
    """
    abs_offset = (x - y).abs()
    return 1 - abs_offset.prod(dim=-1) / r


def exp_intp(x, y, r=0.015):
    """Gaussian weight ``exp(-|x - y|^2 / r^2)``.

    The squared distance is reduced over the last axis only. A single pair of
    points (as seen under ``torch.vmap``) yields a scalar, and batched inputs
    of shape (..., D) yield weights of shape (...) instead of one scalar for
    the whole batch.

    Args:
        x, y: point coordinates with the spatial components on the last axis.
        r: length scale of the Gaussian.

    Returns:
        Tensor of weights in (0, 1].
    """
    sq_dist = ((x - y) ** 2).sum(dim=-1) / r**2
    return torch.exp(-sq_dist)


def w_piecewise(x, y, r=0.015, sigma=1.0):
    """Piecewise quartic (M5 B-spline shaped) kernel of the scaled distance.

    With ``q = |x - y| / r`` (norm over the last axis):

        q < 1/2       : sigma * [(5/2-q)^4 - 5 (3/2-q)^4 + 10 (1/2-q)^4]
        1/2 <= q < 3/2: sigma * [(5/2-q)^4 - 5 (3/2-q)^4]
        3/2 <= q < 5/2: sigma * (5/2-q)^4
        q >= 5/2      : 0

    ``sigma`` is a caller-supplied scale; no normalization is applied here.
    Implemented with ``torch.where`` so it works elementwise on batched inputs
    and under ``torch.vmap``.
    """
    q = torch.linalg.norm(x - y, dim=-1) / r

    outer = (5 / 2 - q) ** 4
    middle = outer - 5 * (3 / 2 - q) ** 4
    # (1/2 - q), not (1 - q): this term must vanish at q = 1/2 so the kernel
    # is continuous between the inner and middle pieces.
    inner = middle + 10 * (1 / 2 - q) ** 4

    w = torch.where(
        q < 1 / 2,
        inner,
        torch.where(
            q < 3 / 2,
            middle,
            torch.where(q < 5 / 2, outer, torch.zeros_like(q)),
        ),
    )
    return sigma * w


def p2g(grid_coords, positions, features, intp, interaction_radius=0.015):
    """Splat particle features onto every grid node as a kernel-weighted mean.

    For each node g the result is ``sum_i w_i f_i / (sum_i w_i + 1e-9)``,
    where ``w_i = intp(g, p_i, r=interaction_radius)``. Every node sees every
    particle, so the cost is O(X * Y * N).

    Args:
        grid_coords: (X, Y, 2) tensor of node coordinates.
        positions: particle positions, mapped over their first axis (N).
        features: (N, M) tensor of per-particle features.
        intp: kernel called as ``intp(node, particle, r=...)``; it must be
            ``torch.vmap``-compatible.
        interaction_radius: forwarded to ``intp`` as ``r``.

    Returns:
        (X, Y, M) tensor of interpolated features.
    """
    nx, ny, _ = grid_coords.shape
    _, num_channels = features.shape
    eps = 10e-10  # == 1e-9; keeps empty neighbourhoods from dividing by zero

    def splat_one_node(node):
        def kernel(particle):
            return intp(node, particle, r=interaction_radius)

        w = torch.vmap(kernel)(positions)  # (N,)
        numerator = (w[:, None] * features).sum(dim=0)  # (M,)
        denominator = w.sum() + eps
        return numerator / denominator

    nodes = grid_coords.reshape(-1, 2)  # (X*Y, 2)
    per_node = torch.vmap(splat_one_node)(nodes)  # (X*Y, M)
    return per_node.reshape(nx, ny, num_channels)


def g2p(grid, coords):
    """Bilinearly sample a channels-last grid at world-space particle positions.

    Args:
        grid: (X, Y, C) or (B, X, Y, C) tensor; a 3-D grid is treated as B = 1.
        coords: (B, N, 2) tensor of positions in world units spanning
            [LOW, HIGH]; points outside sample zeros (``padding_mode="zeros"``).

    Returns:
        (B, N, C) tensor of sampled values, or (N, C) when B == 1.
        Only the real part of complex inputs is used.
    """
    # World coords -> [0, 1] -> [-1, 1], the range grid_sample expects.
    coords = (coords - LOW) / (HIGH - LOW)
    coords = coords * 2 - 1

    if grid.dim() == 3:
        grid = grid.unsqueeze(0)
    grid = grid.permute(0, 3, 1, 2)  # (B, X, Y, C) -> (B, C, X, Y)

    # Treat the N points as an (N, 1) "image" of sample locations.
    # NOTE(review): grid_sample reads coords[..., 0] as the width index, i.e.
    # the Y axis of an (X, Y, C) grid — confirm this matches the caller's layout.
    coords = coords[:, :, None, :]  # (B, N, 2) -> (B, N, 1, 2)

    sampled = F.grid_sample(
        grid.real,
        coords.real,
        mode="bilinear",
        align_corners=False,
        padding_mode="zeros",
    )  # (B, C, N, 1)

    # Drop only the width-1 dim: a blanket squeeze() would also remove C == 1 or
    # N == 1 and scramble (or crash) the transpose below.
    sampled = sampled.squeeze(-1).transpose(-1, -2)  # (B, N, C)
    if sampled.shape[0] == 1:
        # Preserve the historical unbatched (N, C) output when B == 1.
        sampled = sampled.squeeze(0)
    return sampled.real


def create_grid(
    grid_coords, positions, velocities, types, interp_fn, interaction_radius
):
    """Rasterize velocity and two density channels of one particle set onto a grid.

    Particles with a non-zero ``types`` entry are material particles; zero marks
    boundary/wall particles.

    Args:
        grid_coords: grid-node coordinates, passed through to ``p2g``.
        positions: per-particle positions (all particles).
        velocities: per-particle velocities; only the first ``num_particles``
            rows are splatted.
        types: per-particle type ids, assumed 1-D with one entry per particle.
        interp_fn: kernel forwarded to ``p2g``.
        interaction_radius: kernel radius forwarded to ``p2g``.

    Returns:
        Velocity channels concatenated with the [material, wall] density
        channels on the last axis.
    """
    is_particle = types.bool()
    is_wall = ~is_particle
    num_particles = int(is_particle.sum())
    num_wall = int(is_wall.sum())

    # NOTE(review): slicing assumes material particles are stored before walls.
    pos = positions[:num_particles]
    vel = velocities[:num_particles]

    # Two density channels; each group's values sum to 1. max(..., 1) avoids
    # 0/0 = NaN when a group is empty (its mask is all zeros in that case).
    density = torch.cat(
        (
            is_particle.int()[:, None] / max(num_particles, 1),
            is_wall.int()[:, None] / max(num_wall, 1),
        ),
        axis=-1,
    )

    density_grid = p2g(
        grid_coords,
        positions,
        density,
        interp_fn,
        interaction_radius=interaction_radius,
    )
    vel_grid = p2g(
        grid_coords, pos, vel, interp_fn, interaction_radius=interaction_radius
    )

    return torch.cat((vel_grid, density_grid), axis=-1)


def p2g_cluster_batch(
    grid_coords,
    positions,
    features,
    batch,
    intp,
    size,
    interaction_radius=0.015,
    normalize=False,
):
    """Bin batched particle features into voxels and return one grid per sample.

    Each particle is assigned a voxel by ``grid_cluster`` on the concatenated
    ``(x, y, batch)`` coordinates (x/y cells of edge ``size`` over
    [LOW, HIGH], one cell per batch index). Its features are then summed into
    that voxel with ``scatter``.

    Args:
        grid_coords: (X, Y, 2) tensor of grid-node coordinates.
        positions: (N, 2) tensor of particle positions.
        features: (N, M) tensor of per-particle features (flat, not per-batch).
        batch: (N,) tensor of each particle's batch index; it is concatenated
            with ``positions``, so it must live on the same device.
        intp: kernel called as ``intp(grid_point, position, r)`` under vmap.
        size: voxel edge length along x and y.
        interaction_radius: forwarded to ``intp``.
        normalize: if True, divide each voxel's sum by its summed weights.

    Returns:
        List of ``batch.max() + 1`` tensors, each of shape (X, Y, M).
    """
    X, Y, _ = grid_coords.shape
    _, M = features.shape
    # Highest batch index present; the voxel grid's batch axis spans 0..B.
    B = int(batch.max())
    # Build the bounds on the inputs' device so grid_cluster never sees a
    # CPU/CUDA mix (e.g. CPU inputs on a machine where CUDA is available).
    device = positions.device
    s = torch.tensor([size, size, 1], device=device)
    start = torch.tensor([LOW, LOW, 0], device=device)
    end = torch.tensor([HIGH, HIGH, B], device=device)

    vox = grid_cluster(torch.cat((positions, batch[:, None]), axis=-1), s, start, end)
    # One copy of the node coordinates per batch sample: ((B + 1) * X * Y, 2).
    flat_grid_coords = torch.tile(grid_coords.view(-1, 2), (B + 1, 1))
    num_nodes = flat_grid_coords.shape[0]

    lit_grid = flat_grid_coords[vox]
    weights = torch.vmap(intp, in_dims=(0, 0, None))(
        lit_grid, positions, interaction_radius
    )
    # NOTE(review): the kernel weights above are discarded here, so every
    # particle contributes with weight 1 (plain voxel sum, or mean when
    # normalize=True). Kept as-is to preserve behavior — confirm intent.
    weights = torch.ones_like(weights)
    weighted_features = weights[:, None] * features

    # Stop gradients through contributions whose weight is <= 0
    # (a no-op while all weights are one).
    weighted_features = torch.where(
        weights[:, None] > 0, weighted_features, weighted_features.detach()
    )
    feat = scatter(
        weighted_features,
        vox[:, None],
        dim=0,
        dim_size=num_nodes,
        reduce="sum",
    )

    if normalize:
        norm = scatter(weights, vox, dim=0, dim_size=num_nodes) + 1e-10
        norm = norm[:, None]
    else:
        norm = 1.0

    final_grid = feat / norm
    nodes_per_grid = X * Y
    return [
        final_grid[i : i + nodes_per_grid].view(X, Y, M)
        for i in range(0, (B + 1) * nodes_per_grid, nodes_per_grid)
    ]


def _stack_batches(grids, num_batches):
    """Stack per-sample (X, Y, M) grids into (num_batches, X, Y, M), zero-filling missing trailing samples."""
    stacked = torch.stack(grids)
    missing = num_batches - stacked.shape[0]
    if missing > 0:
        # p2g_cluster_batch returns only batch.max() + 1 grids, so samples at the
        # end of the batch with no points of a kind are absent; they are empty.
        pad = stacked.new_zeros((missing, *stacked.shape[1:]))
        stacked = torch.cat((stacked, pad), dim=0)
    return stacked


def create_grid_cluster_batch(
    grid_coords, positions, velocities, types, interp_fn, size, interaction_radius
):
    """Build batched velocity + density grids via voxel binning.

    Entries with ``types > 0`` are material particles and entries with
    ``types <= 0`` are walls. Slicing assumes material particles come first
    in each sample.

    Args:
        grid_coords: grid-node coordinates, passed to ``p2g_cluster_batch``.
        positions: per-sample particle positions, indexed as
            ``positions[b, :n]`` (presumably (B, N, 2) — confirm with callers).
        velocities: per-sample velocities, indexed like ``positions``.
        types: per-sample type ids, ``types[b]`` one entry per particle.
        interp_fn: kernel forwarded to ``p2g_cluster_batch``.
        size: voxel edge length forwarded to ``p2g_cluster_batch``.
        interaction_radius: forwarded to ``p2g_cluster_batch``.

    Returns:
        (B, X, Y, D + 2) tensor: velocity channels, then material density, then
        wall density, where D is the velocity feature width.
    """
    # Allocate helper tensors where the inputs already live rather than
    # assuming CUDA whenever it is available.
    device = positions.device

    B = len(positions)

    num_particles = [int(torch.where(types[b] > 0.0, 1.0, 0.0).sum()) for b in range(B)]
    num_walls = [int(torch.count_nonzero(types[b] <= 0.0)) for b in range(B)]

    # Per-point density: within each sample every group sums to 1.
    particle_mass = torch.cat(
        [torch.ones((n,), device=device) / n for n in num_particles]
    )
    wall_mass = torch.cat([torch.ones((n,), device=device) / n for n in num_walls])

    particle_batch = torch.cat(
        [
            torch.ones_like(positions[b, :n, ..., 0]) * b
            for b, n in enumerate(num_particles)
        ],
        dim=0,
    )
    wall_batch = torch.cat(
        [
            torch.ones_like(positions[b, n:, ..., 0]) * b
            for b, n in enumerate(num_particles)
        ],
        dim=0,
    )

    particle_pos = torch.cat(
        [positions[b, :n] for b, n in enumerate(num_particles)], dim=0
    )
    particle_vel = torch.cat(
        [velocities[b, :n] for b, n in enumerate(num_particles)], dim=0
    )
    wall_pos = torch.cat([positions[b, n:] for b, n in enumerate(num_particles)], dim=0)

    # TODO: multi-material
    particle_density = p2g_cluster_batch(
        grid_coords,
        particle_pos,
        particle_mass[:, None],
        particle_batch,
        interp_fn,
        size,
        interaction_radius=interaction_radius,
        normalize=False,
    )
    wall_density = p2g_cluster_batch(
        grid_coords,
        wall_pos,
        wall_mass[:, None],
        wall_batch,
        interp_fn,
        size,
        interaction_radius=interaction_radius,
        normalize=True,
    )

    density_grid = torch.cat(
        (_stack_batches(particle_density, B), _stack_batches(wall_density, B)),
        axis=-1,
    )

    vel_grid = p2g_cluster_batch(
        grid_coords,
        particle_pos,
        particle_vel,
        particle_batch,
        interp_fn,
        size,
        interaction_radius=interaction_radius,
        normalize=True,
    )

    vel_grid = _stack_batches(vel_grid, B)
    return torch.cat((vel_grid, density_grid), axis=-1)
