"""GNS rollout function adapted from
https://github.com/wu375/simple-physics-simulator-pytorch-geometry.
"""

import torch


def eval_single_rollout(simulator, features, num_steps, device, kinematic_particle_type=3):
    """Roll out ``simulator`` autoregressively for ``num_steps`` steps on one example.

    At each step the simulator predicts the next position of every node from the
    current window of positions. The window is then shifted by one step and the
    prediction is appended. Nodes whose ``particle_type`` equals
    ``kinematic_particle_type`` are not simulated. Their prediction is replaced
    by the ground-truth position for that step.

    Args:
        simulator: Object exposing
            ``predict_positions(positions, n_particles_per_example=..., particle_types=...)``
            that returns the next position of every node as ``(n_nodes, dim)``.
        features: Mapping with these keys:
            - ``"enc_pos"``: the initial position window, used here as
              ``(n_nodes, window, dim)``.
            - ``"target_pos"``: ground-truth future positions, indexed as
              ``(n_nodes, time, dim)`` with ``time >= num_steps``.
            - ``"n_particles_per_example"``: forwarded to the simulator.
            - ``"particle_type"``: a tensor with one type id per node.
        num_steps: Number of autoregressive steps to roll out.
        device: Device the kinematic mask is moved to.
        kinematic_particle_type: Type id of particles that follow the prescribed
            ground-truth trajectory. Defaults to 3, the value previously
            hard-coded here.

    Returns:
        A tuple ``(output_dict, loss)``.

        ``output_dict`` holds numpy arrays:
            - ``"initial_positions"``: ``(window, n_nodes, dim)``.
            - ``"predicted_rollout"``: ``(num_steps, n_nodes, dim)``.
            - ``"ground_truth_rollout"``: ``(num_steps, n_nodes, dim)``.
            - ``"particle_types"``.

        ``loss`` is the element-wise squared error tensor, of shape
        ``(num_steps, n_nodes, dim)``.
    """
    initial_positions = features["enc_pos"]
    # Compare only the steps that are actually rolled out. Without this slice, a
    # longer target either fails to subtract or (num_steps == 1) silently
    # broadcasts one prediction against every ground-truth step.
    ground_truth_positions = features["target_pos"][:, :num_steps]
    particle_types = features["particle_type"]
    # The mask is identical at every step, so build it once. Shape (n_nodes, 1)
    # broadcasts over the coordinate axis in torch.where.
    kinematic_mask = (particle_types == kinematic_particle_type).to(device)[:, None]

    current_positions = initial_positions
    predictions = []
    for step in range(num_steps):
        next_position = simulator.predict_positions(
            current_positions,
            n_particles_per_example=features["n_particles_per_example"],
            particle_types=particle_types,
        )  # (n_nodes, dim)
        # Kinematic particles follow their prescribed (ground-truth) trajectory.
        next_position = torch.where(
            kinematic_mask, ground_truth_positions[:, step], next_position
        )
        predictions.append(next_position)
        # Slide the window: drop the oldest step, append the new prediction.
        current_positions = torch.cat(
            [current_positions[:, 1:], next_position[:, None, :]], dim=1
        )
    predictions = torch.stack(predictions)  # (num_steps, n_nodes, dim)
    ground_truth_positions = ground_truth_positions.permute(1, 0, 2)  # (num_steps, n_nodes, dim)
    loss = (predictions - ground_truth_positions) ** 2
    output_dict = {
        "initial_positions": initial_positions.permute(1, 0, 2).cpu().numpy(),
        "predicted_rollout": predictions.cpu().numpy(),
        "ground_truth_rollout": ground_truth_positions.cpu().numpy(),
        "particle_types": particle_types.cpu().numpy(),
    }
    return output_dict, loss


def eval_rollout(ds, simulator, num_steps, num_eval_steps=1, device="cuda"):
    """Evaluate rollout error of ``simulator`` on the first examples of ``ds``.

    The simulator is put in eval mode and run under ``torch.no_grad()``. Its
    previous train/eval mode is restored afterwards, even if a rollout raises.

    Args:
        ds: Iterable yielding ``(features, labels)`` pairs. ``labels`` is
            ignored. ``features`` is updated in place: its ``"position"``,
            ``"n_particles_per_example"`` and ``"particle_type"`` entries are
            converted to tensors on ``device``.
        simulator: Module passed to ``eval_single_rollout``.
        num_steps: Rollout length per example.
        num_eval_steps: Maximum number of examples to evaluate. At least one
            example is always evaluated, even when this is < 1.
        device: Device the feature tensors are placed on.

    Returns:
        The per-example squared-error tensors from ``eval_single_rollout``,
        stacked and averaged over examples (dim 0). This requires every
        evaluated example to yield a loss of the same shape.

    Raises:
        RuntimeError: From ``torch.stack`` if ``ds`` yields no examples.
    """
    eval_loss = []
    was_training = simulator.training
    simulator.eval()
    try:
        with torch.no_grad():
            for example_i, (features, _labels) in enumerate(ds):
                # NOTE(review): only these entries are moved to ``device``.
                # eval_single_rollout also reads "enc_pos" and "target_pos",
                # which are assumed to already be on ``device`` — confirm
                # against the dataset.
                for key in ("position", "n_particles_per_example", "particle_type"):
                    features[key] = torch.as_tensor(features[key], device=device)
                _, loss = eval_single_rollout(simulator, features, num_steps, device)
                eval_loss.append(loss)
                print("Example: ", example_i)
                if example_i + 1 >= num_eval_steps:
                    break
    finally:
        # Restore the caller's mode rather than unconditionally switching to train.
        simulator.train(was_training)
    return torch.stack(eval_loss).mean(0)
