import copy
import os
import pickle
import warnings
from typing import List, Dict, Callable

import numpy as np
import torch
from torch_geometric.data import Data, Batch

from ltsgns_mp.envs.data_loader.abstract_dataloader import AbstractDataloader
from ltsgns_mp.envs.util.data_loader_util import get_one_hot_features_and_types, add_second_order_dynamics

from ltsgns_mp.util.graph_input_output_util import build_edges_from_data_dict
from ltsgns_mp.util import keys
from ltsgns_mp.util.own_types import ConfigDict, ValueDict


class IsaacSimDataloader(AbstractDataloader):
    """Dataloader for pickled Isaac Sim rollouts.

    Loads one pickle file per dataset split, maps the raw trajectory keys onto the project-wide
    ``keys`` constants (optionally normalizing material parameters), and converts single timesteps
    into ``torch_geometric`` ``Data`` graphs.
    """

    # Bounds used to map the material parameters linearly onto [-1, 1].
    # NOTE(review): these are fixed dataset ranges; confirm they match the data generator.
    _MIN_YOUNGS_MODULUS = 100000
    _MAX_YOUNGS_MODULUS = 1000000
    _MIN_POISSON_RATIO = 0.0
    _MAX_POISSON_RATIO = 0.5

    def __init__(self, config: ConfigDict):
        super().__init__(config)

    ###########################################
    ####### Interfaces for data loading #######
    ###########################################

    def _get_rollout_length(self, raw_traj: ValueDict) -> int:
        """
        Returns the rollout length of the task, i.e. the number of usable timesteps in the rollout.

        This is one less than the number of stored mesh frames, because building the data for
        timestep ``t`` also reads frame ``t + 1`` as the prediction target.

        Args:
            raw_traj: Raw trajectory dict as loaded from disk.

        Returns:
            Number of timesteps for which a (current, next) pair exists.

        Raises:
            ValueError: If the trajectory format is not recognized.
        """
        if "nodes_grid" in raw_traj:
            # Standard case, for example Sphere Fall
            return len(raw_traj["nodes_grid"]) - 1
        raise ValueError("Unknown dataset format. Cannot determine rollout")

    def _load_raw_data(self, split: str) -> List[ValueDict]:
        """
        Loads all rollouts of one dataset split.

        The file is read from ``<path_to_datasets>/<dataset_name>/<dataset_name>_<split>.pkl``.

        Args:
            split: Name of the split, used as the file-name suffix.

        Returns:
            The unpickled rollout data.
        """
        path_to_datasets = self.config.path_to_datasets
        dataset_name = self.config.dataset_name
        file_path = os.path.join(path_to_datasets, dataset_name, f"{dataset_name}_{split}.pkl")
        # NOTE: unpickling can execute arbitrary code -- only load datasets from trusted sources.
        with open(file_path, "rb") as file:
            rollout_data = pickle.load(file)
        return rollout_data

    @staticmethod
    def _normalize_to_symmetric_unit_range(value, low, high):
        """Linearly maps ``value`` from ``[low, high]`` onto ``[-1, 1]`` (out-of-range values are not clipped)."""
        return (value - low) / (high - low) * 2 - 1

    def _select_and_normalize_attributes(self, raw_traj: ValueDict) -> ValueDict:
        """
        Selects the relevant entries of a raw trajectory and renames them to the project keys.

        Mesh positions, edge indices and faces are always selected. A point cloud is added if
        present. If a Young's modulus is present, it and the Poisson ratio are normalized to
        [-1, 1] and stored together as ``keys.PARAMETERS`` in the order (poisson, youngs).

        Args:
            raw_traj: Raw trajectory dict as loaded from disk.

        Returns:
            Trajectory dict keyed by ``keys`` constants.

        Raises:
            ValueError: If the trajectory format is not recognized.
        """
        if "nodes_grid" not in raw_traj:
            raise ValueError("Unknown task type")

        # Standard case, for example sphere fall. Collider fields are intentionally not selected.
        traj: ValueDict = {keys.MESH: raw_traj["nodes_grid"],
                           keys.MESH_EDGE_INDEX: raw_traj["edge_index_grid"],
                           keys.MESH_FACES: raw_traj["triangles_grid"],
                           }
        if "pcd_points" in raw_traj:
            traj[keys.POINT_CLOUD] = raw_traj["pcd_points"]

        if keys.YOUNGS_MODULUS in raw_traj:
            # always normalize poisson ratio and add it to the task, although it will usually not be used
            youngs_modulus = self._normalize_to_symmetric_unit_range(raw_traj[keys.YOUNGS_MODULUS],
                                                                     self._MIN_YOUNGS_MODULUS,
                                                                     self._MAX_YOUNGS_MODULUS)
            poisson_ratio = self._normalize_to_symmetric_unit_range(raw_traj["poisson_ratios"],
                                                                    self._MIN_POISSON_RATIO,
                                                                    self._MAX_POISSON_RATIO)
            result = np.array([poisson_ratio, youngs_modulus])
            if result.shape == (2, 1000):
                # Known data bug: the values of all episodes were saved for every episode. Taking the first
                # entry only keeps the pipeline running and yields WRONG material parameters.
                result = result[:, 0]
                warnings.warn("Weird bug in the data that all values are saved for every episode -> need to regenerate, in order to make it work take just the first. THIS RESULTS IN THE WRONG MATERIAL PROPERTIES PARAMETER! Don't do experiments with this data")
            traj[keys.PARAMETERS] = result

        return traj

    def _build_data_dict(self, raw_traj: ValueDict, timestep: int) -> ValueDict:
        """
        Extracts the data of a single timestep from a prepared trajectory and converts it to tensors.

        :param raw_traj: Prepared trajectory (output of ``_select_and_normalize_attributes``) holding all timesteps.
        :param timestep: Timestep to get the data for. Frame ``timestep + 1`` is read as the target,
            so ``timestep`` must be smaller than the rollout length.
        :return: Dict containing all the data for a single timestep in torch tensor format.
        """
        data_dict = {keys.MESH: torch.tensor(raw_traj[keys.MESH][timestep], dtype=torch.float32),
                     keys.NEXT_MESH_POS: torch.tensor(raw_traj[keys.MESH][timestep + 1], dtype=torch.float32),
                     keys.INITIAL_MESH_POSITIONS: torch.tensor(raw_traj[keys.MESH][0], dtype=torch.float32),
                     # edge index is stored transposed relative to the layout used here
                     keys.MESH_EDGE_INDEX: torch.tensor(raw_traj[keys.MESH_EDGE_INDEX].T, dtype=torch.long),
                     keys.MESH_FACES: torch.tensor(raw_traj[keys.MESH_FACES], dtype=torch.long)}

        if self.config.second_order_dynamics:
            data_dict = add_second_order_dynamics(data_dict, timestep, raw_traj)

        if keys.COLLIDER in raw_traj:  # add information about the collider mesh
            data_dict |= {keys.COLLIDER: torch.tensor(raw_traj[keys.COLLIDER][timestep], dtype=torch.float32)
                          }
            # collider velocity as the forward difference between this frame and the next
            data_dict |= {
                keys.COLLIDER_VELOCITY: torch.tensor(
                    raw_traj[keys.COLLIDER][timestep + 1] - raw_traj[keys.COLLIDER][timestep],
                    dtype=torch.float32)}
        if keys.COLLIDER_EDGE_INDEX in raw_traj:  # Collider has more than a single node
            data_dict |= {
                keys.COLLIDER_EDGE_INDEX: torch.tensor(raw_traj[keys.COLLIDER_EDGE_INDEX].T, dtype=torch.long),
                keys.COLLIDER_FACES: torch.tensor(raw_traj[keys.COLLIDER_FACES], dtype=torch.long), }

        if keys.VISUAL_COLLDER in raw_traj:
            # add visual information about the collider mesh
            data_dict |= {
                keys.VISUAL_COLLDER: torch.tensor(raw_traj[keys.VISUAL_COLLDER][timestep], dtype=torch.float32),
                keys.VISUAL_COLLIDER_FACES: torch.tensor(raw_traj[keys.VISUAL_COLLIDER_FACES],
                                                         dtype=torch.long), }

        if keys.PARAMETERS in raw_traj:
            data_dict[keys.PARAMETERS] = torch.tensor(raw_traj[keys.PARAMETERS], dtype=torch.float32)

        if self.config.use_point_cloud and keys.POINT_CLOUD in raw_traj:
            data_dict[keys.POINT_CLOUD] = torch.tensor(raw_traj[keys.POINT_CLOUD][timestep],
                                                       dtype=torch.float32)
        if self.config.use_point_cloud and keys.POINT_CLOUD_COLORS in raw_traj:
            data_dict[keys.POINT_CLOUD_COLORS] = torch.tensor(raw_traj[keys.POINT_CLOUD_COLORS][timestep],
                                                              dtype=torch.float32)

        return data_dict

    def _build_graph(self, data_dict: ValueDict) -> Data:
        """
        Assembles a ``torch_geometric`` graph from the data of a single timestep.

        Node features ``x`` consist, in column order, of: the node-type features returned by
        ``get_one_hot_features_and_types``, optionally the collider velocity (zero for non-collider
        nodes), the mesh z-coordinate if no collider is present, and optionally the task parameters
        repeated for every node. ``x_description`` names these columns in the same order.

        Args:
            data_dict: Output of ``_build_data_dict`` for one timestep.

        Returns:
            A ``Data`` object with node features, positions, targets, task properties, faces and edges.

        Raises:
            ValueError: If collider velocities or task parameters are requested as node features but
                are not present in ``data_dict``.
        """
        use_collider_velocities = self.config.use_collider_velocities
        use_canonic_mesh_positions = self.config.use_canonic_mesh_positions
        use_parameters_as_node_feature = self.config.use_parameters_as_node_feature
        second_order_dynamics = self.config.second_order_dynamics
        connectivity_setting = self.config.connectivity_setting

        # node types: mesh nodes first, followed by collider nodes if a collider exists
        pos_keys = [keys.MESH]
        if keys.COLLIDER in data_dict:
            pos_keys.append(keys.COLLIDER)
        # Build the description only once pos_keys is final, so a collider gets its own entry.
        # Assumes get_one_hot_features_and_types emits one column per entry of num_nodes.
        x_description = list(pos_keys)
        num_nodes = [data_dict[pos_key].shape[0] for pos_key in pos_keys]
        # TODO: add point cloud as node type
        x, node_type = get_one_hot_features_and_types(input_list=num_nodes)

        if use_collider_velocities:
            if keys.COLLIDER not in data_dict:
                raise ValueError("Cannot use the collider velocities as node feature if no collider is present in the data")
            collider_vel = data_dict[keys.COLLIDER_VELOCITY]
            # only collider nodes carry a velocity; all other nodes get zeros
            padded_collider_vel = torch.zeros(size=(x.shape[0], collider_vel.shape[1]))
            padded_collider_vel[node_type == pos_keys.index(keys.COLLIDER)] = collider_vel
            x = torch.cat((x, padded_collider_vel), dim=1)
            x_description += [keys.COLLIDER_VELOCITY] * collider_vel.shape[1]

        if second_order_dynamics:
            prev_pos = data_dict[keys.PREV_MESH_POS]
            current_mesh_pos = data_dict[keys.MESH]
        else:
            prev_pos = None
            current_mesh_pos = None

        if keys.COLLIDER not in data_dict:
            # add the z coordinate as a feature to the nodes
            z = data_dict[keys.MESH][:, 2].reshape(-1, 1)
            x = torch.cat((x, z), dim=1)
            x_description.append(keys.MESH + "_z_pos")

        pos = torch.cat(tuple(data_dict[pos_key] for pos_key in pos_keys), dim=0)

        # we save the youngs modulus and the poisson ratio as task property
        if keys.PARAMETERS in data_dict:
            task_properties = data_dict[keys.PARAMETERS].reshape(1, 2)
            task_properties_description = ["Poisson Ratio", "Youngs Modulus"]
        else:
            task_properties = torch.zeros(size=(1, 0))
            task_properties_description = []

        if use_parameters_as_node_feature:
            if keys.PARAMETERS not in data_dict:
                raise ValueError("Cannot use the task parameters as node feature if it is not present in the data")
            # add parameters to node feature
            x = torch.cat((x, task_properties.repeat(x.shape[0], 1)), dim=1)
            x_description += task_properties_description

        # Collider nodes come after the mesh nodes in ``pos``, so their face indices are shifted by the
        # number of mesh nodes for proper visualization.
        if keys.COLLIDER_FACES in data_dict:
            collider_faces = data_dict[keys.COLLIDER_FACES] + data_dict[keys.MESH].shape[0]
        else:
            collider_faces = None
        collider_vertices = None  # collider vertex positions are already part of ``pos``

        data = Data(x=x,
                    x_description=x_description,
                    pos=pos,
                    next_mesh_pos=data_dict[keys.NEXT_MESH_POS],
                    current_mesh_pos=current_mesh_pos,
                    prev_mesh_pos=prev_pos,
                    node_type=node_type,
                    node_type_description=pos_keys,
                    task_properties=task_properties,
                    task_properties_description=task_properties_description,
                    mesh_faces=data_dict[keys.MESH_FACES],
                    collider_faces=collider_faces,
                    collider_vertices=collider_vertices,
                    )

        # edge features: edge_attr is just one-hot edge type, all other features are created after preprocessing
        edge_attr, edge_type, edge_index, edge_type_description = build_edges_from_data_dict(data_dict,
                                                                                             pos_keys,
                                                                                             num_nodes,
                                                                                             connectivity_setting,
                                                                                             use_canonic_mesh_positions)

        data.edge_attr = edge_attr
        data.edge_type = edge_type
        data.edge_index = edge_index
        data.edge_type_description = edge_type_description

        return data
