import copy

import torch
from torch_geometric.data import Data, Batch
from torch_geometric.loader import DataLoader
from tqdm import tqdm

from ltsgns_mp.envs.train_iterator.abstract_train_iterator import AbstractTrainIterator, AbstractTrainBatch
from ltsgns_mp.util import keys
from ltsgns_mp.envs.trajectory import Trajectory
from ltsgns_mp.util.graph_input_output_util import node_type_mask, add_gaussian_noise, remove_edge_distances, \
    add_distances_from_positions, add_label, add_and_update_node_features


class StepTrainBatch(AbstractTrainBatch):
    """Container handed to the trainer for one step-wise training batch.

    Args:
        batch: The collated (and, in ``StepTrainIterator``, augmented) graph batch.
    """

    def __init__(self, batch):
        # Let the abstract base initialise itself before the payload is attached.
        super(StepTrainBatch, self).__init__()
        # The collated graph batch produced by the iterator.
        self.batch = batch


class StepTrainIterator(AbstractTrainIterator):
    """Iterate over shuffled mini-batches of individual trajectory steps.

    Every graph of every training trajectory except the first and the last one
    becomes an independent sample. Batches come from a shuffled torch_geometric
    ``DataLoader``. Each batch is deep-copied and moved to ``self.device``. It then
    gets Gaussian noise on the mesh nodes and a label, and its edge distances and
    node features are recomputed before it is wrapped in a ``StepTrainBatch``.

    The object is its own iterator, so one ``for`` pass covers one epoch. When the
    underlying loader is exhausted, the iterator is reset and ``StopIteration`` is
    raised. The next pass therefore starts a new epoch.
    """

    def __init__(self, config, train_trajs, device):
        # NOTE(review): ``self.config`` and ``self.device`` used below are presumably
        # set by AbstractTrainIterator.__init__ — confirm.
        super().__init__(config, train_trajs, device)
        self._dataloader = None
        self._iterator = None
        self.set_dataloader(config, train_trajs)

    def set_dataloader(self, config, train_trajs):
        """(Re)build the flat list of step graphs, the shuffled DataLoader and its iterator.

        Args:
            config: Training config. Reads ``anchor_index_as_feature`` and ``batch_size``.
            train_trajs: Iterable of trajectories. Each must support ``traj[1:-1]``.

        Raises:
            NotImplementedError: If ``config.anchor_index_as_feature`` is set.
        """
        if config.anchor_index_as_feature:
            raise NotImplementedError("Anchor index as feature is not implemented for step iterator.")
        # Flatten all trajectories into one list of step graphs, dropping the first and
        # last step of each trajectory.
        # NOTE(review): presumably those steps lack the neighbouring step needed for
        # velocity features / labels — confirm.
        # The stored graphs have no velocity features and no label (y) yet. The label
        # is attached per batch in __next__ (add_label), where node features are also updated.
        self._steps = []
        for traj in tqdm(train_trajs, desc="Preparing Step Train Iterator"):
            self._steps.extend(traj[1:-1])
        self._dataloader = DataLoader(self._steps, batch_size=config.batch_size, shuffle=True)
        self.refresh_iterator()

    def __iter__(self):
        """Return ``self``; the object is its own iterator."""
        return self

    def __next__(self) -> StepTrainBatch:
        """Return the next augmented batch.

        Raises:
            StopIteration: When the current epoch is exhausted. The iterator is reset
                first, so iterating again starts a new epoch.
        """
        # Keep the try minimal: only exhaustion of the loader may end the epoch. A
        # StopIteration escaping one of the helpers below must not be mistaken for it.
        try:
            raw_batch = next(self._iterator)
        except StopIteration:
            self.refresh_iterator()
            raise

        # Defensive deep copy, so the augmentations below cannot touch data that the
        # loaded batch might share with the stored step graphs.
        batch = copy.deepcopy(raw_batch)
        batch = batch.to(self.device)

        batch = self._add_noise_and_label(batch)

        # Positions changed due to the noise, so recompute the edge distances.
        batch = remove_edge_distances(batch)
        batch = add_distances_from_positions(batch, self.config.add_euclidian_distance)
        batch = add_and_update_node_features(batch, self.config.second_order_dynamics)
        return StepTrainBatch(batch)

    def _add_noise_and_label(self, batch):
        """Add Gaussian noise to the mesh nodes and attach the training label.

        The relative order is controlled by ``config.wrong_label_computation``:
        - If it is set, the label is computed before the noise is added.
        - Otherwise, the label is computed after the noise is added.

        Args:
            batch: Collated graph batch already moved to ``self.device``.

        Returns:
            The batch with noise and label applied.
        """
        noise_kwargs = dict(
            node_type_description=keys.MESH,
            sigma=self.config.input_mesh_noise,
            device=self.device,
        )
        if self.config.wrong_label_computation:
            batch = add_label(batch, self.config.second_order_dynamics)
            return add_gaussian_noise(batch, **noise_kwargs)
        batch = add_gaussian_noise(batch, **noise_kwargs)
        return add_label(batch, self.config.second_order_dynamics)

    def __len__(self):
        """Return the number of batches per epoch (the DataLoader's length)."""
        return len(self._dataloader)

    def refresh_iterator(self):
        """Start a fresh pass over the DataLoader; with ``shuffle=True`` the order is redrawn."""
        self._iterator = iter(self._dataloader)