import torch

import neural_mpm.util.interpolation as interp
from neural_mpm.util.metaparams import LOW, HIGH, DIM


def get_step_fn(coords, interp_fct, size, interaction_radius):
    """Create the particle-advection step function used by the rollout.

    Args:
        coords, interp_fct, size, interaction_radius: Forwarded unchanged to
            ``interp.create_grid_cluster_batch`` when the next input grid is
            rebuilt from the advected particles.

    Returns:
        ``euler_step(grid_velocities, old_pos, types) -> (particles, next_input)``.
    """

    def euler_step(grid_velocities, old_pos, types):
        """Advect particles through successive grid-velocity sub-steps.

        ``grid_velocities`` must be 5-D; its axis 3 is iterated as the
        sequence of sub-steps (presumably (batch, H, W, substeps, DIM) —
        TODO confirm). For every sub-step the slice is interpolated to the
        current particle positions with ``interp.g2p``, and the positions are
        advanced by that velocity with no time-step factor (pos += vel).

        Particles with ``types <= 0`` keep their original position and get a
        zero velocity. Positions are clamped to ``[LOW, HIGH]`` only after
        all sub-steps (intermediate sub-steps are not clamped).

        Args:
            grid_velocities: 5-D tensor of per-sub-step grid velocities.
            old_pos: Particle state; only the first ``DIM`` features of the
                last axis are used as positions.
            types: 2-D per-particle type tensor, indexed as (batch, particle).

        Returns:
            A pair ``(particles, next_input)``. ``particles`` is the 4-D
            per-sub-step trajectory with positions and velocities
            concatenated on the last axis. ``next_input`` is the grid rebuilt
            from the final sub-step by ``interp.create_grid_cluster_batch``.
        """
        current = old_pos[..., :DIM]
        positions = []
        velocities = []

        for vel_grid in grid_velocities.permute(3, 0, 1, 2, 4):
            particle_vel = interp.g2p(vel_grid, current)
            if particle_vel.ndim == 2:
                # g2p can return an unbatched result; restore a leading axis so
                # the stacked trajectory below is 4-D.
                particle_vel = particle_vel.unsqueeze(0)
            current = current + particle_vel
            positions.append(current)
            velocities.append(particle_vel)

        # Stack on a new leading axis, then move it behind the batch axis. This
        # form (rather than stacking directly on dim=1) keeps the memory layout
        # that downstream code receives unchanged.
        trajectory = torch.stack(positions).permute(1, 0, 2, 3)
        traj_vel = torch.stack(velocities).permute(1, 0, 2, 3)

        # Broadcast the per-particle mask over the sub-step and feature axes.
        movable = types[:, None, :, None] > 0.0
        trajectory = torch.where(movable, trajectory, old_pos[:, None, :, :DIM])
        traj_vel = torch.where(movable, traj_vel, 0.0)

        trajectory = torch.clamp(trajectory, min=LOW, max=HIGH)

        final_pos = trajectory[:, -1]
        final_vel = traj_vel[:, -1]
        next_input = interp.create_grid_cluster_batch(
            coords,
            final_pos,
            final_vel,
            types,
            interp_fct,
            size,
            interaction_radius,
        )

        return torch.cat((trajectory, traj_vel), dim=-1), next_input

    return euler_step


def unroll(
    model,
    init_state,
    coords,
    num_calls,
    gmean,
    gstd,
    types,
    size,
    interaction_radius=0.015,
    interp_fn=interp.linear,
    dim=2,
    requires_grad=False,
):
    """Autoregressively roll out ``model`` for ``num_calls`` calls.

    Each call:

    1. normalizes the current grid with ``gmean`` / ``gstd``;
    2. runs ``model`` on it;
    3. de-normalizes the prediction with the first ``dim`` channels of the
       statistics to get grid velocities;
    4. advects the particles with the step function from ``get_step_fn``,
       which also rebuilds the next input grid.

    Args:
        model: Callable mapping a normalized grid to normalized predictions.
        init_state: ``(grid, particles)`` pair used as the initial carry.
        coords, size, interaction_radius, interp_fn: Forwarded to
            ``get_step_fn``.
        num_calls: Number of model calls; must be a positive integer.
        gmean, gstd: Grid normalization statistics. They must broadcast
            against the grid. Their first ``dim`` channels also de-normalize
            the model output.
        types: Per-particle types forwarded to the step function.
        dim: Number of channels used to de-normalize the model output.
            NOTE(review): the step function slices positions with the
            module-level ``DIM``, not this argument — keep them in sync.
        requires_grad: Whether autograd records the whole rollout.

    Returns:
        ``(full_grids, full_particles, input_grids)``, all batch-first:

        - ``full_grids``: model predictions per call, stacked and permuted
          ``(1, 0, -2, 2, 3, -1)``, then reshaped so the call axis is merged
          with the prediction's second-to-last axis. Presumably
          (B, num_calls * substeps, H, W, dim) — TODO confirm.
        - ``full_particles``: step-function trajectories (positions and
          velocities concatenated) with the call and sub-step axes merged.
        - ``input_grids``: the rebuilt grids per call, before any extra
          channels from the initial grid are re-appended.

    Raises:
        ValueError: if ``num_calls`` is smaller than 1, since there would be
            nothing to stack.
    """
    if num_calls < 1:
        raise ValueError(f"num_calls must be a positive integer, got {num_calls}")

    euler_step = get_step_fn(
        coords, interp_fn, size, interaction_radius=interaction_radius
    )

    with torch.set_grad_enabled(requires_grad):
        grid, particles = init_state

        # Loop-invariant de-normalization statistics for the velocity channels.
        vel_std = gstd[None, ..., :dim]
        vel_mean = gmean[None, ..., :dim]

        # HACK (original TODO): when the initial grid has exactly 6 channels,
        # its last two channels are re-appended to every rebuilt grid.
        # Presumably these are static features the rebuilt grid lacks — confirm.
        # NOTE(review): the hard-coded 6 may only be right for dim == 2.
        init_grid = init_state[0]
        extra_channels = init_grid[..., -2:] if init_grid.shape[-1] == 6 else None

        pred_grids = []
        trajectories = []
        next_inputs = []
        for _ in range(num_calls):
            normalized = (grid - gmean) / gstd
            grid_preds = model(normalized)
            grid_velocities = grid_preds * vel_std + vel_mean

            step_particles, next_input = euler_step(grid_velocities, particles, types)
            # The last sub-step's state (positions and velocities) seeds the next call.
            particles = step_particles[:, -1]

            if extra_channels is None:
                grid = next_input
            else:
                grid = torch.cat((next_input, extra_channels), dim=-1)

            pred_grids.append(grid_preds)
            trajectories.append(step_particles)
            next_inputs.append(next_input)

        full_grids = torch.stack(pred_grids)
        full_particles = torch.stack(trajectories).permute(1, 0, 2, 3, 4)
        input_grids = torch.stack(next_inputs).permute(1, 0, 2, 3, 4)

        bsize = init_grid.shape[0]

        # Batch axis first, then merge the call axis with the prediction's
        # second-to-last axis.
        full_grids = full_grids.permute(1, 0, -2, 2, 3, -1)
        full_grids = full_grids.reshape(bsize, -1, *full_grids.shape[-3:])
        full_particles = full_particles.reshape(bsize, -1, *full_particles.shape[-2:])

    return full_grids, full_particles, input_grids
