import copy

import torch
from torch_geometric.data import Data, Batch
from torch_geometric.loader import DataLoader
from tqdm import tqdm

from ltsgns_mp.envs.train_iterator.abstract_train_iterator import AbstractTrainIterator, AbstractTrainBatch
from ltsgns_mp.envs.train_iterator.step_train_iterator import StepTrainIterator, StepTrainBatch
from ltsgns_mp.util import keys
from ltsgns_mp.envs.trajectory import Trajectory
from ltsgns_mp.util.graph_input_output_util import node_type_mask, add_gaussian_noise, remove_edge_distances, \
    add_distances_from_positions, add_label, add_and_update_node_features, add_history_gaussian_noise, \
    add_history_vel_feature


class HistoryStepTrainIterator(StepTrainIterator):
    """Step-wise train iterator whose samples carry a window of past node positions.

    Every step of every training trajectory except the first and the last becomes one graph
    sample. Besides the fields stored in the trajectory, each sample gets a ``history_pos``
    tensor holding the node positions of the ``config.history_length`` preceding steps.
    Position noise, velocity features, labels and edge distances are (re)computed per batch
    in ``__next__``.
    """

    def __init__(self, config, train_trajs, device):
        super().__init__(config, train_trajs, device)

    def set_dataloader(self, config, train_trajs):
        """Collect all training steps, attach their position history and build the DataLoader.

        Args:
            config: Iterator config; this method reads ``anchor_index_as_feature``,
                ``history_length`` and ``batch_size``.
            train_trajs: Iterable of trajectories; each must support ``len()`` and integer
                indexing returning a graph with a ``pos`` attribute.

        Raises:
            NotImplementedError: If ``config.anchor_index_as_feature`` is set.
        """
        if config.anchor_index_as_feature:
            raise NotImplementedError("Anchor index as feature is not implemented for step iterator.")
        # build a list of all steps
        self._steps = []
        for traj in tqdm(train_trajs, desc="Preparing Step Train Iterator"):
            # ignore first and last step
            for idx in range(1, len(traj) - 1):
                graph = traj[idx]
                # shape of history_pos: [num_nodes, history_length, world_dim], oldest to newest
                graph["history_pos"] = self._build_history_pos(traj, idx, graph.pos, config.history_length)
                # this graph has no velocity features and no label (y) yet; __next__ adds them per batch
                self._steps.append(graph)
        # build a Dataloader from the steps list
        self._dataloader = DataLoader(self._steps, batch_size=config.batch_size, shuffle=True)
        self._iterator = iter(self._dataloader)

    @staticmethod
    def _build_history_pos(traj, idx, pos, history_length):
        """Stack the positions of the ``history_length`` steps before ``idx``, oldest first.

        Step indices before the start of the trajectory are clamped to 0, so early steps repeat
        the initial positions.

        Args:
            traj: Trajectory indexable by step, each step exposing ``pos``.
            idx: Index of the current step; the history covers ``idx - history_length .. idx - 1``.
            pos: Node positions of the current step, used for shape/dtype/device of an empty history.
            history_length: Number of past steps to include.

        Returns:
            Tensor of shape [num_nodes, history_length, world_dim]. For ``history_length <= 0``
            this is an empty [num_nodes, 0, world_dim] tensor with the dtype and device of ``pos``.
        """
        if history_length <= 0:
            # match pos' dtype/device instead of defaulting to float32 on the training device;
            # __next__ moves the whole batch to the training device anyway
            return pos.new_zeros((pos.shape[0], 0, pos.shape[1]))
        history = [traj[max(hist_idx, 0)].pos.clone() for hist_idx in range(idx - history_length, idx)]
        return torch.stack(history, dim=1)

    def __next__(self) -> StepTrainBatch:
        """Return the next training batch with noisy history, velocity features, labels and edge distances.

        Raises:
            StopIteration: When the DataLoader is exhausted; the underlying iterator is refreshed
                first so that the next epoch can start.
        """
        # keep the try minimal: only exhaustion of the DataLoader marks the end of an epoch,
        # a StopIteration from any processing helper must not be mistaken for it
        try:
            batch = next(self._iterator)
        except StopIteration:
            # If the DataLoader's iterator is exhausted, reset it and raise StopIteration
            self.refresh_iterator()
            raise StopIteration from None

        # NOTE(review): deep copy kept from the original; presumably guards the stored steps
        # against in-place modification by the helpers below — verify before removing
        batch = copy.deepcopy(batch)
        batch.to(self.device)

        # wrong_label_computation keeps the old ordering in which the label is computed
        # before the position noise is applied
        label_before_noise = self.config.wrong_label_computation
        if label_before_noise:
            batch = add_label(batch, self.config.second_order_dynamics)
        # add noise to the positions and compute the velocity features from the noisy history
        batch = add_history_gaussian_noise(batch,
                                           node_type_description=keys.MESH,
                                           sigma=self.config.input_mesh_noise,
                                           device=self.device)
        batch = add_history_vel_feature(batch)
        if not label_before_noise:
            batch = add_label(batch, self.config.second_order_dynamics)

        # recompute the edge features from the (noisy) positions
        batch = remove_edge_distances(batch)
        batch = add_distances_from_positions(batch, self.config.add_euclidian_distance)
        batch = add_and_update_node_features(batch, self.config.second_order_dynamics)
        return StepTrainBatch(batch)
