import torch

from ltsgns_mp.util import keys
from ltsgns_mp.util.graph_input_output_util import node_type_mask


def _update_external_state(batch_index, current_step, data, last_mesh_pos):
    """Write the externally driven / freshly predicted state into ``data`` in place.

    Depending on which features the graph declares, this:

    1. Overwrites the collider node positions with the ground-truth collider
       positions for ``(batch_index, current_step)``.
    2. Copies the current x/y/z node positions into the matching
       ``<mesh>_{x,y,z}_pos`` columns of the node feature matrix ``data.x``.
    3. Recomputes velocity features as the displacement of the mesh nodes since
       ``last_mesh_pos``. Non-mesh nodes get zero velocity.

    Args:
        batch_index: Index into the first axis of
            ``data[keys.CONTEXT_COLLIDER_POSITIONS]``.
        current_step: Index into the second axis of
            ``data[keys.CONTEXT_COLLIDER_POSITIONS]``.
        data: Graph object exposing ``node_type_description``, ``x_description``,
            ``x`` and item access for ``keys.POSITIONS`` /
            ``keys.CONTEXT_COLLIDER_POSITIONS``. It is modified in place.
        last_mesh_pos: Mesh node positions from the previous step. They are
            subtracted from the current mesh positions (presumably shaped
            ``(num_mesh_nodes, dim)`` — TODO confirm). Required only when
            ``keys.VELOCITIES`` appears in ``data.x_description``.

    Returns:
        The same ``data`` object, for convenience.

    Raises:
        ValueError: If velocity features are requested but ``last_mesh_pos`` is None.
    """
    if keys.COLLIDER in data.node_type_description:
        # Colliders are not predicted; insert their ground-truth positions.
        collider_mask = node_type_mask(data, keys.COLLIDER)
        current_collider_positions = data[keys.CONTEXT_COLLIDER_POSITIONS][batch_index, current_step]
        data[keys.POSITIONS][collider_mask] = current_collider_positions

    # Refresh the per-axis position features from the (updated) positions.
    # Each "<mesh>_<axis>_pos" column is filled for every node row.
    for axis, axis_name in enumerate(("x", "y", "z")):
        feature_name = keys.MESH + f"_{axis_name}_pos"
        if feature_name in data.x_description:
            feature_index = data.x_description.index(feature_name)
            data.x[:, feature_index] = data[keys.POSITIONS][:, axis]

    if keys.VELOCITIES in data.x_description:
        if last_mesh_pos is None:
            raise ValueError("last_mesh_pos is required when velocity features are present")
        mesh_mask = node_type_mask(data, keys.MESH)
        # Every column labelled as a velocity feature, in declaration order.
        velocities_indices = [index for index, value in enumerate(data.x_description) if value == keys.VELOCITIES]
        mesh_vel = data[keys.POSITIONS][mesh_mask] - last_mesh_pos
        # data.x has one row per node (all node types). Scatter the mesh
        # velocities into a full-size tensor so non-mesh rows get zero velocity
        # and the row count matches data.x.
        all_vel = torch.zeros_like(data[keys.POSITIONS])
        all_vel[mesh_mask] = mesh_vel
        data.x[:, velocities_indices] = all_vel

    return data