import copy

import torch
from torch_geometric.data import Batch, Data

from ltsgns_mp.algorithms.mgn import MGN
from ltsgns_mp.envs.train_iterator.step_train_iterator import StepTrainBatch
from ltsgns_mp.util import keys
from ltsgns_mp.util.graph_input_output_util import node_type_mask


def _modify_train_batch(batch: Batch, imputation_rate: float = 0.5) -> Batch:
    """
    Randomly remove the point-cloud nodes (and every edge touching them) from a subset of the graphs in a batch,
    for an imputation-based training approach.

    Args:
        batch (Batch): A torch geometric batch. This function reads and re-assigns its `batch`, `x`, `pos`,
            `node_type`, `edge_index`, `edge_attr`, `edge_type` and `ptr` attributes.
        imputation_rate (float): Probability with which each graph is independently selected for point-cloud
            removal. Defaults to 0.5.

    Returns:
        Batch: A shallow copy of `batch` in which the point-cloud nodes of the selected graphs, and all edges
        incident to them, have been removed. The caller's batch object is not modified.

    Notes:
        Point-cloud nodes are identified with `node_type_mask(batch, keys.POINT_CLOUD)`.
        Only the attributes listed above are filtered. Any other node- or edge-level attribute, the
        `batch.point_cloud` attribute and PyG's internal slicing bookkeeping (`_slice_dict` / `_inc_dict`) are left
        untouched, so e.g. `to_data_list()` on the result should presumably not be relied upon.
    """
    # Work on a shallow copy: below we only re-assign attributes (no tensor is modified in place), so the caller's
    # batch keeps its point clouds, e.g. in case the same batch object is reused later.
    batch = copy.copy(batch)
    device = batch.x.device
    num_graphs = len(batch.ptr) - 1

    # Select each graph for imputation independently with probability `imputation_rate`
    imputation_mask = (torch.rand(num_graphs) < imputation_rate).to(device)

    # Broadcast the per-graph selection to all nodes of that graph and restrict it to point-cloud nodes
    num_nodes_per_graph = batch.ptr[1:] - batch.ptr[:-1]
    graph_imputation_mask = torch.repeat_interleave(imputation_mask, num_nodes_per_graph)
    point_cloud_mask = node_type_mask(batch, keys.POINT_CLOUD)
    keep_node_mask = ~(graph_imputation_mask & point_cloud_mask)

    # Only keep an edge if both its source and destination nodes are kept
    keep_edge_mask = keep_node_mask[batch.edge_index[0]] & keep_node_mask[batch.edge_index[1]]

    # Removing nodes shifts the indices of all later nodes, so map the surviving edges from old to new node indices.
    # Entries of removed nodes stay 0 but are never looked up, because their edges were dropped above.
    node_idx_mapping = torch.zeros(len(batch.x), dtype=torch.long, device=device)
    node_idx_mapping[keep_node_mask] = torch.arange(keep_node_mask.sum(), device=device)
    new_edge_index = node_idx_mapping[batch.edge_index[:, keep_edge_mask]]

    # update nodes
    batch.batch = batch.batch[keep_node_mask]
    batch.x = batch.x[keep_node_mask]
    batch.pos = batch.pos[keep_node_mask]
    batch.node_type = batch.node_type[keep_node_mask]

    # update edges
    batch.edge_index = new_edge_index
    batch.edge_attr = batch.edge_attr[keep_edge_mask]
    batch.edge_type = batch.edge_type[keep_edge_mask]

    # Recompute ptr from the number of kept nodes per graph. `minlength` keeps graphs that end up with 0 nodes.
    # NOTE(review): bincount is done on CPU as in the original; possibly to avoid CUDA determinism restrictions.
    num_nodes_kept_per_graph = torch.bincount(batch.batch.cpu(), minlength=num_graphs).to(device)
    batch.ptr = torch.cat([num_nodes_kept_per_graph.new_zeros(1), num_nodes_kept_per_graph.cumsum(0)])

    # Note that we do not update the batch.point_cloud attribute, as this is never used in the training process
    return batch


class GGNS(MGN):
    """
    MGN variant whose graphs may additionally contain point-cloud nodes.

    Training: point clouds are removed from a random fraction (`config.imputation_rate`) of the graphs of every
    train batch via `_modify_train_batch`.
    Rollout: before delegating to MGN, the point cloud of the previous step is removed and, if one is available for
    the current step, the current step's point cloud is inserted as new nodes.
    """

    def _single_train_step(self, batch: StepTrainBatch) -> torch.Tensor:
        """Drop point clouds from a random subset of the graphs, then run the regular MGN train step."""
        modified_batch = _modify_train_batch(batch.batch, self.config.imputation_rate)
        return super()._single_train_step(StepTrainBatch(modified_batch))

    def _update_external_state(self, batch_index: int, current_step: int, data: Data, last_mesh_pos: torch.Tensor | None = None) -> Data:
        """
        Update the external state of the graph by swapping out its point cloud, then delegate to MGN.

        Point-cloud nodes from the previous step are always removed. A new point cloud is added only if
        `data.point_cloud_indices[batch_index, current_step]` is set, i.e. (per the config) if the current step is a
        multiple of the k parameter that is specified in the config as
        point_cloud:
            start_idx: 0   # which index to start the evaluation at, usually 0 for next-step GNS'
            stop_idx: null  # null means that we evaluate until the end
            step: 5  # evaluate every 5th step, here we have k=5

        Args:
            batch_index: Index of the batch to evaluate. Currently only supports a single batch with index 0
            current_step: Current step in the trajectory
            data: Data object containing the anchor graph and auxiliary information about the external trajectory
                state. Its `x`, `pos` and `node_type` are re-assigned. `data[keys.POINT_CLOUD]` is overwritten when a
                new point cloud is added and left as is otherwise.
            last_mesh_pos: Forwarded unchanged to the parent's `_update_external_state`.

        Returns:
            The Data object returned by the parent's `_update_external_state`, called on the updated `data`.
        """
        # first, remove the old point cloud (if any)
        old_point_cloud_mask = node_type_mask(data, keys.POINT_CLOUD)
        if old_point_cloud_mask.any():  # tensor reduction instead of a Python-level, per-element `sum`
            keep_mask = ~old_point_cloud_mask
            data.x = data.x[keep_mask]
            data.pos = data.pos[keep_mask]
            data.node_type = data.node_type[keep_mask]

        if data.point_cloud_indices[batch_index, current_step]:  # we have a new point cloud to add, so do that
            # find the point cloud positions of this step and drop NaN-padded (invalid) points
            positions = data[keys.CONTEXT_POINT_CLOUD_POSITIONS][batch_index, current_step]
            positions = positions[~torch.isnan(positions).any(dim=1)]
            num_points = positions.shape[0]

            # Point-cloud node features are a one-hot encoding at the point-cloud node type's index.
            # Create them directly with the dtype/device of the tensors they are concatenated to.
            point_cloud_type_index = data.node_type_description.index(keys.POINT_CLOUD)
            point_cloud_x = torch.zeros((num_points, data.x.shape[1]), dtype=data.x.dtype, device=data.x.device)
            point_cloud_x[:, point_cloud_type_index] = 1
            point_cloud_type = torch.full((num_points,), point_cloud_type_index,
                                          dtype=torch.long, device=data.node_type.device)

            # add new point cloud to data, appended after all existing nodes
            data[keys.POINT_CLOUD] = positions
            data.x = torch.cat([data.x, point_cloud_x])
            data.pos = torch.cat([data.pos, positions])
            data.node_type = torch.cat([data.node_type, point_cloud_type])

        return super()._update_external_state(batch_index=batch_index,
                                              current_step=current_step,
                                              data=data,
                                              last_mesh_pos=last_mesh_pos)  # adapts collider position
