import copy
from collections import defaultdict
from typing import List

import numpy as np
import torch
from omegaconf import OmegaConf
from torch_geometric.data import Batch
from torch_geometric.loader import DataLoader
from tqdm import tqdm

from ltsgns_mp.envs.train_iterator.abstract_train_iterator import AbstractTrainIterator, AbstractTrainBatch
from ltsgns_mp.envs.util.convert_to_single_data_trajectory import convert_traj_to_data, \
    compute_point_cloud_padding_size, build_idx_config, get_mesh_context_indices
from ltsgns_mp.util import keys
from ltsgns_mp.util.graph_input_output_util import remove_edge_distances, add_distances_from_positions, \
    add_gaussian_noise, add_and_update_node_features, add_label, node_type_mask


class CNPTrainBatch(AbstractTrainBatch):
    """Pair of batched graphs used for one CNP training step.

    Attributes are plain references to the batches passed in; nothing is
    copied or moved.
    """

    def __init__(self, context_batch: Batch, target_batch: Batch):
        super().__init__()
        # context observations and the graphs to predict, stored as given
        self.context_batch, self.target_batch = context_batch, target_batch


class TaskIndexIterator:
    """Single pass over shuffled task indices, delivered in fixed-size batches.

    On construction the indices ``0 .. num_tasks - 1`` are shuffled in place
    with ``np.random.shuffle``. Each ``next()`` returns the next
    ``batch_size`` indices as a list. A trailing remainder smaller than
    ``batch_size`` is dropped, so ``len()`` equals ``num_tasks // batch_size``.

    Raises:
        ValueError: if ``batch_size`` is smaller than 1. A zero batch size
            would otherwise yield empty batches forever (and ``len()`` would
            divide by zero).
    """

    def __init__(self, num_tasks, batch_size=1):
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        self.num_tasks = num_tasks
        self.task_indices = list(range(num_tasks))
        # shuffle once; the order is fixed for the lifetime of this iterator
        np.random.shuffle(self.task_indices)
        self.batch_size = batch_size

    def __iter__(self):
        return self

    def __next__(self):
        # drop_last semantics: a final partial batch is never returned
        if len(self.task_indices) < self.batch_size:
            raise StopIteration
        indices = self.task_indices[:self.batch_size]
        self.task_indices = self.task_indices[self.batch_size:]
        return indices

    def __len__(self):
        """Number of full batches produced by a complete pass."""
        return self.num_tasks // self.batch_size


class CNPTrainIterator(AbstractTrainIterator):
    """Training iterator that yields one ``CNPTrainBatch`` per task (trajectory).

    Each call to ``__next__`` takes the next task from a shuffled
    ``TaskIndexIterator``. It draws context step indices with
    ``get_mesh_context_indices`` and builds two batches from that trajectory:
    a context batch (steps with their velocities appended to ``x``) and a
    target batch. Once every task has been used, ``StopIteration`` is raised
    and a freshly shuffled task order is prepared for the next epoch.
    """

    def __init__(self, config, train_trajs, device):
        super().__init__(config, train_trajs, device)
        self.train_trajs = train_trajs
        self.task_index_iterator = self.get_task_index_iterator()
        # upper bound on the number of context steps and of target steps per batch
        self.batch_size = config.batch_size

    def get_task_index_iterator(self):
        """Return a newly shuffled iterator that yields one task index at a time."""
        return TaskIndexIterator(len(self.train_trajs), 1)

    def __iter__(self):
        return self

    def __next__(self):
        """Build the context and target batch for the next task.

        Returns:
            CNPTrainBatch: context batch (carrying ``anchor_index``) and target
            batch, both moved to ``self.device``.

        Raises:
            StopIteration: once every task has been visited in this epoch. The
                task order is reshuffled before raising so iteration can restart.
        """
        # fetch the task first so an exhausted epoch ends before any random sampling
        try:
            task_index = next(self.task_index_iterator)[0]
        except StopIteration:
            self.refresh_iterator()
            raise
        traj = self.train_trajs[task_index]

        # draw a context size and the context step indices for this trajectory
        _, _, context_indices = get_mesh_context_indices(self.config, traj)
        # the anchor is the step right after the full context window; it is taken
        # before subsampling so it does not depend on which steps are kept
        anchor_index = context_indices[-1] + 1
        if len(context_indices) > self.batch_size:
            context_indices = sorted(np.random.choice(context_indices, self.batch_size, replace=False))

        context_batch = self._build_context_batch(traj, context_indices, anchor_index)

        if self.config.trajectory_targets:
            target_batch = self._build_trajectory_target_batch(traj, anchor_index)
        else:
            target_batch = self._build_step_target_batch(traj)
        # compute the edge features again from the (possibly noisy) positions
        target_batch = remove_edge_distances(target_batch)
        target_batch = add_distances_from_positions(target_batch, self.config.add_euclidian_distance)
        target_batch = add_and_update_node_features(target_batch, second_order_dynamics=False)
        return CNPTrainBatch(context_batch, target_batch)

    def _last_collider_positions(self, traj):
        """Return the collider node positions of the trajectory's final step.

        Returns None when ``config.last_collider_as_feature`` is disabled or the
        trajectory has no collider node type.
        """
        if not (self.config.last_collider_as_feature and keys.COLLIDER in traj[0].node_type_description):
            return None
        # the last step is only read here, so it does not need to be cloned
        last_step = traj[-1]
        return last_step.pos[node_type_mask(last_step, keys.COLLIDER)]

    @staticmethod
    def _append_rel_last_collider_feature(context_data, last_collider):
        """Append (final collider position - current collider position) to ``x``.

        Non-collider nodes get zeros for these feature columns.
        """
        collider_mask = node_type_mask(context_data, keys.COLLIDER)
        rel_last_collider = last_collider - context_data.pos[collider_mask]
        # allocate on the same device as x so the concatenation below cannot fail
        collider_features = torch.zeros((context_data.x.shape[0], rel_last_collider.shape[1]),
                                        device=context_data.x.device)
        collider_features[collider_mask] = rel_last_collider
        context_data.x = torch.cat([context_data.x, collider_features], dim=1)
        # build a new list: clone() copies tensors, but the description list may be
        # shared with the stored trajectory, and extending it in place would grow
        # the stored description on every call
        context_data.x_description = (list(context_data.x_description)
                                      + ["rel_last_collider"] * collider_features.shape[1])

    def _build_context_batch(self, traj, context_indices, anchor_index):
        """Batch the context steps, append their velocities to ``x`` and recompute features."""
        last_collider = self._last_collider_positions(traj)
        context_data_list = []
        for context_idx in context_indices:
            context_data = traj[context_idx].clone()
            if last_collider is not None:
                self._append_rel_last_collider_feature(context_data, last_collider)
            context_data_list.append(context_data)
        context_batch = Batch.from_data_list(context_data_list)

        # velocity of each context step = position difference to its next step
        next_pos = torch.cat([traj[context_idx + 1].pos for context_idx in context_indices], dim=0)
        vel = next_pos - context_batch.pos
        context_batch.x = torch.cat([context_batch.x, vel], dim=1)
        # new per-graph description lists instead of in-place extension (may be shared
        # with the trajectory data, see _append_rel_last_collider_feature)
        context_batch.x_description = [list(desc) + [keys.VELOCITIES] * vel.shape[1]
                                       for desc in context_batch.x_description]

        context_batch = context_batch.to(self.device)
        context_batch = remove_edge_distances(context_batch)
        context_batch = add_distances_from_positions(context_batch, self.config.add_euclidian_distance)
        context_batch = add_and_update_node_features(context_batch, second_order_dynamics=False)
        context_batch.anchor_index = torch.tensor(anchor_index, device=self.device)
        return context_batch

    def _build_trajectory_target_batch(self, traj, anchor_index):
        """Convert the whole trajectory into a single target graph anchored at ``anchor_index``."""
        idx_config = OmegaConf.create({
            keys.MESH: {"indices": []},
            keys.POINT_CLOUD: {"indices": []},
            keys.EVALUATION: {"indices": []},
            "anchor_idx": anchor_index,
        })
        target_data = convert_traj_to_data(traj, idx_config,
                                           point_cloud_padding_size=None,
                                           context_type="mesh",
                                           anchor_index_as_feature=self.config.anchor_index_as_feature,
                                           last_collider_as_feature=self.config.last_collider_as_feature)
        target_batch = Batch.from_data_list([target_data])
        target_batch = target_batch.to(self.device)
        if self.config.input_mesh_noise > 0.0:
            target_batch = add_gaussian_noise(target_batch,
                                              node_type_description=keys.MESH,
                                              sigma=self.config.input_mesh_noise,
                                              device=self.device)
        return target_batch

    def _build_step_target_batch(self, traj):
        """Batch up to ``batch_size`` random steps of ``traj`` with noise and labels."""
        traj_len = len(traj)
        if traj_len < self.batch_size:
            target_indices = range(traj_len)
        else:
            # index by this trajectory's own length, not the first trajectory's
            target_indices = np.random.choice(traj_len, self.batch_size, replace=False)
        target_batch = Batch.from_data_list([traj[target_idx].clone() for target_idx in target_indices])
        target_batch = target_batch.to(self.device)
        if self.config.wrong_label_computation:
            # labels are computed from the positions BEFORE noise is added
            target_batch = add_label(target_batch, self.config.second_order_dynamics)
            target_batch = add_gaussian_noise(target_batch,
                                              node_type_description=keys.MESH,
                                              sigma=self.config.input_mesh_noise,
                                              device=self.device)
        else:
            # add noise to the positions first, then compute the labels
            target_batch = add_gaussian_noise(target_batch,
                                              node_type_description=keys.MESH,
                                              sigma=self.config.input_mesh_noise,
                                              device=self.device)
            target_batch = add_label(target_batch, self.config.second_order_dynamics)
        return target_batch

    def __len__(self):
        """Number of batches per epoch (one task per batch)."""
        return len(self.task_index_iterator)

    @property
    def num_tasks(self):
        return len(self.task_index_iterator)

    def refresh_iterator(self):
        """Start a new epoch with a freshly shuffled task order."""
        self.task_index_iterator = self.get_task_index_iterator()


if __name__ == "__main__":
    # manual smoke check: 10 tasks in batches of 3; the leftover task is dropped
    for task_batch in TaskIndexIterator(10, 3):
        print(task_batch)
