import copy
from collections import defaultdict

import numpy as np
from omegaconf import OmegaConf
from torch_geometric.loader import DataLoader
from tqdm import tqdm

from ltsgns_mp.envs.train_iterator.abstract_train_iterator import AbstractTrainIterator, AbstractTrainBatch
from ltsgns_mp.envs.util.convert_to_single_data_trajectory import convert_traj_to_data, \
    compute_point_cloud_padding_size, build_idx_config
from ltsgns_mp.util import keys
from ltsgns_mp.util.graph_input_output_util import remove_edge_distances, add_distances_from_positions, \
    add_gaussian_noise, add_and_update_node_features


class TrajectoryTrainBatch(AbstractTrainBatch):
    """Train batch that simply carries one batch object.

    In this file it is created by ``TrajectoryTrainIterator.__next__`` around
    the batch drawn from one of its dataloaders.
    """

    def __init__(self, batch):
        """Store ``batch`` unchanged on the instance as ``self.batch``."""
        super().__init__()
        self.batch = batch


def _get_context_type(context_type):
    """Resolve a configured context type to the one actually used.

    Any value other than ``"mixed"`` is handed back unchanged. For ``"mixed"``
    a new random draw between ``keys.MESH`` and ``keys.POINT_CLOUD`` is made on
    every call, so the decision is taken separately for each graph the caller
    asks about.
    """
    if context_type != "mixed":
        return context_type
    # mixed context: pick one of the two concrete types at random
    return np.random.choice([keys.MESH, keys.POINT_CLOUD])


class TrajectoryTrainIterator(AbstractTrainIterator):
    """Iterator yielding batches of complete training trajectories.

    Every training trajectory is converted ``config.auxiliary_tasks_per_task``
    times into a data trajectory, and each conversion gets its own task id.
    The converted trajectories are grouped by (context size, number of nodes,
    context type). Each group owns a shuffled ``DataLoader``, so a batch only
    ever contains trajectories from a single group.

    One pass over the iterator repeatedly draws a batch from a randomly chosen
    group that still has data. Once every group is exhausted, the per-group
    iterators are re-created and ``StopIteration`` is raised, which makes the
    object reusable for the next pass.
    """

    def __init__(self, config, train_trajs, device):
        """Convert all trajectories and build one dataloader per group.

        Args:
            config: Configuration object. This class reads
                ``auxiliary_tasks_per_task``, ``context_type``,
                ``anchor_index_as_feature``, ``batch_size``,
                ``add_euclidian_distance`` and ``input_mesh_noise`` from it.
            train_trajs: Iterable of training trajectories. The node count of
                a trajectory is taken from ``traj[0][keys.POSITIONS]``.
            device: Forwarded to the base class; presumably exposed there as
                ``self.device``, which is where batches are moved to.
        """
        super().__init__(config, train_trajs, device)
        self._auxiliary_trajectories = defaultdict(list)
        task_id = 0
        point_cloud_padding_size = compute_point_cloud_padding_size(train_trajs)
        print("Preparing Trajectory Train Iterator...")
        for traj in tqdm(train_trajs, desc="Preparing Trajectory Train Iterator", disable=True):
            for _ in range(config.auxiliary_tasks_per_task):
                # with a "mixed" config the context type is re-drawn for every auxiliary task
                context_type = _get_context_type(config.context_type)
                idx_config = build_idx_config(self.config, traj, context_type)
                context_size = self.get_context_size(idx_config, context_type)
                num_nodes = len(traj[0][keys.POSITIONS])
                data_traj = convert_traj_to_data(traj, idx_config, point_cloud_padding_size, context_type=context_type,
                                                 anchor_index_as_feature=config.anchor_index_as_feature)
                # compute the relative positional features already here, as they don't change in this setup
                data_traj = remove_edge_distances(data_traj)
                data_traj = add_distances_from_positions(data_traj, self.config.add_euclidian_distance)
                # every (trajectory, auxiliary task) pair gets a unique, consecutive task id
                data_traj[keys.TASK_INDICES] = task_id
                group_key = f"context_size_{context_size}_num_nodes_{num_nodes}_context_type_{context_type}"
                self._auxiliary_trajectories[group_key].append(data_traj)
                task_id += 1
        self._num_tasks = task_id
        # one dataloader, one live iterator and one "has data left" flag per group
        self._dataloader = {}
        self._iterator = {}
        self.iter_contains_data = {}
        for key, traj_list in self._auxiliary_trajectories.items():
            self._dataloader[key] = DataLoader(traj_list, batch_size=config.batch_size, shuffle=True)
            self._iterator[key] = iter(self._dataloader[key])
            self.iter_contains_data[key] = True

    def __iter__(self):
        """Return ``self``; the object is its own iterator."""
        return self

    def __next__(self):
        """Return the next batch from a randomly chosen non-exhausted group.

        Returns:
            TrajectoryTrainBatch wrapping the prepared batch.

        Raises:
            StopIteration: When every group is exhausted. All per-group
                iterators are refreshed before raising.
        """
        while True:
            active_keys = [key for key, has_data in self.iter_contains_data.items() if has_data]
            if not active_keys:
                # pass finished: re-arm everything so the next pass can start right away
                self.refresh_iterator()
                raise StopIteration
            key = np.random.choice(active_keys)
            # Only the `next()` call is guarded, so that a StopIteration raised while
            # preparing the batch is never mistaken for an exhausted dataloader.
            try:
                batch = next(self._iterator[key])
            except StopIteration:
                # this group has no batches left in the current pass; draw another group
                self.iter_contains_data[key] = False
                continue
            return TrajectoryTrainBatch(self._prepare_batch(batch))

    def _prepare_batch(self, batch):
        """Copy a freshly drawn batch, move it to the device and optionally add input noise."""
        # not sure if copy is needed, but safety first
        batch = copy.deepcopy(batch)
        # Relative positions were already added in __init__. The return value of `.to`
        # is discarded, so this relies on the batch being moved in place.
        batch.to(self.device)
        if self.config.input_mesh_noise > 0.0:
            batch = add_gaussian_noise(batch,
                                       node_type_description=keys.MESH,
                                       sigma=self.config.input_mesh_noise,
                                       device=self.device)
            # the batch was perturbed, so the distance features from __init__ are rebuilt
            batch = remove_edge_distances(batch)
            batch = add_distances_from_positions(batch, self.config.add_euclidian_distance)
            batch = add_and_update_node_features(batch, second_order_dynamics=False)
        return batch

    def __len__(self):
        """Return the number of batches per pass, i.e. the summed length of all dataloaders."""
        return sum(len(dataloader) for dataloader in self._dataloader.values())

    @property
    def num_tasks(self):
        """Total number of task ids assigned in ``__init__``."""
        return self._num_tasks

    def get_context_size(self, idx_config, context_type: str):
        """Return the number of context indices for the given context type.

        Args:
            idx_config: Index configuration with ``mesh.indices`` and
                ``point_cloud.indices``.
            context_type: Either ``keys.MESH`` or ``keys.POINT_CLOUD``.

        Raises:
            ValueError: For any other context type.
        """
        if context_type == keys.MESH:
            return len(idx_config.mesh.indices)
        if context_type == keys.POINT_CLOUD:
            return len(idx_config.point_cloud.indices)
        raise ValueError(f"Unknown context type {context_type}")

    def refresh_iterator(self):
        """Mark every group as containing data again and re-create its dataloader iterator."""
        for key in self.iter_contains_data:
            self.iter_contains_data[key] = True
            self._iterator[key] = iter(self._dataloader[key])