import copy
import random

import torch
from torch_geometric.data import Data, Batch
from torch_geometric.loader import DataLoader
from tqdm import tqdm

from ltsgns_mp.envs.train_iterator.abstract_train_iterator import AbstractTrainIterator, AbstractTrainBatch
from ltsgns_mp.envs.util.convert_to_single_data_trajectory import compute_point_cloud_padding_size, build_idx_config
from ltsgns_mp.util import keys
from ltsgns_mp.envs.trajectory import Trajectory
from ltsgns_mp.util.graph_input_output_util import node_type_mask, add_gaussian_noise, remove_edge_distances, \
    add_distances_from_positions, add_label, add_and_update_node_features


class LTSGNSStepTrainBatch(AbstractTrainBatch):
    """One training item yielded by ``LTSGNSStepTrainIterator``.

    Holds the task's context batch and its trajectory batch exactly as passed in;
    nothing is copied, moved or validated here.
    """

    def __init__(self, context_batch, trajectory_batch):
        super().__init__()
        # store both batches as-is, context first, then trajectory
        self.context_batch, self.trajectory_batch = context_batch, trajectory_batch


class LTSGNSStepTrainIterator(AbstractTrainIterator):
    """Step-based training iterator for LTSGNS.

    Every training trajectory is expanded into ``config.auxiliary_tasks_per_task`` tasks.
    A task pairs a labelled context batch (mesh steps chosen by ``build_idx_config``) with a
    batch of the whole trajectory, both tagged with the task's index under
    ``keys.TASK_INDICES``. Iterating yields one ``LTSGNSStepTrainBatch`` per task in shuffled
    order; once all tasks are consumed the list is reshuffled and ``StopIteration`` is raised.
    """

    def __init__(self, config, train_trajs, device):
        super().__init__(config, train_trajs, device)
        if config.anchor_index_as_feature:
            raise NotImplementedError("Anchor index as feature is not implemented for step iterator.")
        self._tasks = []
        task_id = 0
        for traj in tqdm(train_trajs, desc="Preparing Step Train Iterator"):
            for _ in range(config.auxiliary_tasks_per_task):
                context_batch = self._build_context_batch(traj, task_id)
                # NOTE: this is highly inefficient (the full trajectory is re-batched for every
                # auxiliary task), but for the deformable plate ablation it should be fine
                trajectory_batch = Batch.from_data_list(traj.data_list)
                trajectory_batch[keys.TASK_INDICES] = [task_id]
                self._tasks.append((context_batch, trajectory_batch))
                task_id += 1
        self._num_tasks = task_id
        # shuffle the task list and create the first epoch's iterator
        self.refresh_iterator()

    def _build_context_batch(self, traj, task_id):
        """Build the labelled context batch for one task of ``traj``.

        Each selected step is deep-copied (so ``traj`` is not modified). If the step has
        collider nodes, it receives ``next_collider_pos`` from the following step. Its edge
        distances are then recomputed from positions and its node features are updated.
        The steps are batched, tagged with ``task_id`` and passed through ``add_label``.
        """
        idx_config = build_idx_config(self.config, traj, context_type=keys.MESH)
        context_graphs = []
        for context_idx in idx_config[keys.MESH]["indices"]:
            context_graph = copy.deepcopy(traj[context_idx])
            if keys.COLLIDER in context_graph.node_type_description:
                # NOTE(review): reads step context_idx + 1; assumes build_idx_config never
                # selects the final step of the trajectory — confirm.
                next_graph = traj[context_idx + 1]
                context_graph["next_collider_pos"] = next_graph["pos"][node_type_mask(next_graph, key=keys.COLLIDER)]
            # don't include point clouds: PC context is not really possible with step-based LTSGNS
            context_graph = remove_edge_distances(context_graph)
            context_graph = add_distances_from_positions(context_graph, self.config.add_euclidian_distance)
            context_graph = add_and_update_node_features(context_graph, self.config.second_order_dynamics)
            context_graphs.append(context_graph)
        context_batch = Batch.from_data_list(context_graphs)
        context_batch[keys.TASK_INDICES] = [task_id]
        # the label is next_mesh_pos and next_collider_pos
        context_batch = add_label(context_batch, self.config.second_order_dynamics)
        return context_batch

    def _add_noise_and_label(self, trajectory_batch):
        """Add Gaussian mesh noise and the label to ``trajectory_batch`` and return it.

        With ``config.wrong_label_computation`` set, ``add_label`` runs *before*
        ``add_gaussian_noise``; otherwise the noise is added first and ``add_label`` runs on
        the noised batch. The flag name suggests the former ordering is the deliberately
        incorrect variant.
        """
        if self.config.wrong_label_computation:
            trajectory_batch = add_label(trajectory_batch, self.config.second_order_dynamics)
        trajectory_batch = add_gaussian_noise(trajectory_batch,
                                              node_type_description=keys.MESH,
                                              sigma=self.config.input_mesh_noise,
                                              device=self.device)
        if not self.config.wrong_label_computation:
            trajectory_batch = add_label(trajectory_batch, self.config.second_order_dynamics)
        return trajectory_batch

    def __iter__(self):
        return self

    def __next__(self):
        """Return the next task as an ``LTSGNSStepTrainBatch`` on ``self.device``.

        The cached trajectory batch is deep-copied before noise injection, so the stored
        task stays unchanged across epochs.

        Raises:
            StopIteration: when every task has been consumed. The task list is reshuffled
                first, so the next pass starts a fresh epoch.
        """
        # Only the task fetch may signal exhaustion. A StopIteration leaking from the
        # preprocessing helpers must not be mistaken for the end of the epoch.
        try:
            context_batch, trajectory_batch = next(self._iterator)
        except StopIteration:
            # the task list iterator is exhausted: reshuffle for the next epoch, then stop
            self.refresh_iterator()
            raise StopIteration from None

        context_batch = context_batch.to(self.device)
        trajectory_batch = copy.deepcopy(trajectory_batch)
        trajectory_batch = trajectory_batch.to(self.device)
        # add noise to the mesh positions, add the label, then recompute the edge features
        trajectory_batch = self._add_noise_and_label(trajectory_batch)
        trajectory_batch = remove_edge_distances(trajectory_batch)
        trajectory_batch = add_distances_from_positions(trajectory_batch, self.config.add_euclidian_distance)
        trajectory_batch = add_and_update_node_features(trajectory_batch, self.config.second_order_dynamics)
        return LTSGNSStepTrainBatch(context_batch, trajectory_batch)

    def __len__(self):
        """Number of tasks, i.e. items yielded per epoch."""
        return len(self._tasks)

    @property
    def num_tasks(self):
        """Total number of tasks built in ``__init__``."""
        return self._num_tasks

    def refresh_iterator(self):
        """Shuffle the task list in place and restart iteration from its beginning."""
        random.shuffle(self._tasks)
        self._iterator = iter(self._tasks)
