import copy
from typing import Tuple

import torch
from torch_geometric.data import Batch, Data

from ltsgns_mp.algorithms.abstract_algorithm import AbstractAlgorithm
from ltsgns_mp.algorithms.mgn import MGN
from ltsgns_mp.algorithms.util import _update_external_state
from ltsgns_mp.architectures.loss_functions.mse import mse
from ltsgns_mp.envs.train_iterator.cnp_train_iterator import CNPTrainBatch
from ltsgns_mp.envs.train_iterator.step_train_iterator import StepTrainBatch
from ltsgns_mp.util import keys
from ltsgns_mp.util.graph_input_output_util import recompute_external_edges, add_and_update_node_features
from ltsgns_mp.util.graph_input_output_util import add_distances_from_positions, node_type_mask
from ltsgns_mp.util.own_types import ValueDict


class MGNCNPIterator(MGN):
    """MGN variant whose training step accepts CNP-style batches.

    Only the training step is overridden: the incoming batch is converted to a
    ``StepTrainBatch`` and handed to the parent implementation.
    """

    def _single_train_step(self, batch: CNPTrainBatch) -> torch.Tensor:
        """
        Run one training step on a CNP batch by delegating to the parent class.

        Only ``batch.target_batch`` is read here; it is wrapped in a
        ``StepTrainBatch`` so that the parent's ``_single_train_step`` can
        consume it. No other attribute of the CNP batch is used in this method.

        :param batch: CNP training batch; must expose a ``target_batch`` attribute.
        :return: Whatever the parent ``_single_train_step`` returns for the wrapped
            batch (annotated as a ``torch.Tensor``).
        """
        # Re-package the target part of the CNP batch into the step-wise format
        # and let the parent perform the actual training step.
        return super()._single_train_step(StepTrainBatch(batch.target_batch))

