import copy
import os
import pickle
from typing import List, Dict, Callable

import torch
from torch_geometric.data import Data, Batch

from ltsgns_mp.envs.data_loader.abstract_dataloader import AbstractDataloader
from ltsgns_mp.envs.util.data_loader_util import get_one_hot_features_and_types, add_second_order_dynamics

from ltsgns_mp.util.graph_input_output_util import build_edges_from_data_dict
from ltsgns_mp.util import keys
from ltsgns_mp.util.own_types import ConfigDict, ValueDict


class SofaDataloader(AbstractDataloader):
    """
    Dataloader for trajectories simulated with SOFA.

    Two raw trajectory formats are recognised by their keys:

    * deformable plate: ``nodes_grid``, ``nodes_collider``, ... (optionally ``pcd_points``)
    * tissue manipulation: ``tissue_mesh_positions``, ``gripper_position``, ... (optionally ``panda_position``
      for the cavity grasping task, ``gripper_mesh_positions`` for a visual-only gripper mesh and
      ``tissue_pcd_points`` for point clouds)

    Raw trajectories are mapped onto the shared ``keys`` vocabulary, converted into per-timestep tensor dicts and
    finally into ``torch_geometric.data.Data`` graphs.
    """

    def __init__(self, config: ConfigDict):
        super().__init__(config)

    ###########################################
    ####### Interfaces for data loading #######
    ###########################################

    def _get_rollout_length(self, raw_traj: ValueDict) -> int:
        """
        Returns the rollout length of the task, i.e. the number of stored mesh frames minus one.

        The -1 is there because the initial for loop went to rollout_length - 2 here and to -1 for all other
        tasks. Presumably this keeps a successor frame available for every timestep handed to
        ``_build_data_dict``, which reads ``timestep + 1``.

        Args:
            raw_traj: A single raw trajectory as loaded by ``_load_raw_data``.

        Returns:
            Number of frames of the mesh position sequence minus one.

        Raises:
            ValueError: If the trajectory has neither ``nodes_grid`` (deforming plate) nor
                ``tissue_mesh_positions`` (tissue manipulation).
        """
        # Checked in this order: deforming plate task first, tissue manipulation task second.
        for position_key in ("nodes_grid", "tissue_mesh_positions"):
            if position_key in raw_traj:
                return len(raw_traj[position_key]) - 1
        raise ValueError("Unknown dataset format. Cannot determine rollout")

    def _load_raw_data(self, split: str) -> List[ValueDict]:
        """
        Loads the raw trajectories of one split.

        The file is read from ``<path_to_datasets>/<dataset_name>/<dataset_name>_<split>.pkl``.

        Args:
            split: Name of the split, inserted verbatim into the file name.

        Returns:
            The unpickled object, expected to be a list of raw trajectory dicts.

        Note:
            ``pickle.load`` can execute arbitrary code. Only point ``path_to_datasets`` at trusted data.
        """
        path_to_datasets = self.config.path_to_datasets
        dataset_name = self.config.dataset_name
        file_path = os.path.join(path_to_datasets, dataset_name, f"{dataset_name}_{split}.pkl")
        with open(file_path, "rb") as file:
            return pickle.load(file)

    def _select_and_normalize_attributes(self, raw_traj: ValueDict) -> ValueDict:
        """
        Maps the dataset-specific raw keys of a trajectory onto the shared ``keys`` vocabulary.

        Values are normalized where needed. ``raw_traj`` itself is never modified. In particular, the
        ``deformable_plate_v2`` rescaling writes into new lists instead of overwriting the raw positions, so
        processing the same raw trajectory twice does not divide the positions by 100 twice.

        Args:
            raw_traj: A single raw trajectory as loaded by ``_load_raw_data``.

        Returns:
            A new dict containing:

            * mesh and collider positions, edge indices and faces (where available);
            * an optional visual collider mesh;
            * an optional point cloud;
            * the normalized Poisson ratio.

        Raises:
            ValueError: If the trajectory matches neither the deformable plate nor the tissue manipulation format.
        """
        if "nodes_grid" in raw_traj:  # deformable plate
            mesh_positions = raw_traj["nodes_grid"]
            collider_positions = raw_traj["nodes_collider"]
            if self.config.name == "deformable_plate_v2":
                # The v2 version is stored unnormalized. Rescale into new lists so raw_traj stays untouched.
                mesh_positions = [frame / 100 for frame in mesh_positions]
                collider_positions = [frame / 100 for frame in collider_positions]

            traj: ValueDict = {keys.MESH: mesh_positions,
                               keys.MESH_EDGE_INDEX: raw_traj["edge_index_grid"],
                               keys.MESH_FACES: raw_traj["triangles_grid"],
                               keys.COLLIDER: collider_positions,
                               keys.COLLIDER_EDGE_INDEX: raw_traj["edge_index_collider"],
                               keys.COLLIDER_FACES: raw_traj["triangles_collider"],
                               }
        elif "tissue_mesh_positions" in raw_traj:  # tissue manipulation
            traj: ValueDict = {keys.MESH: raw_traj["tissue_mesh_positions"],
                               keys.MESH_EDGE_INDEX: raw_traj["tissue_mesh_edges"],
                               keys.MESH_FACES: raw_traj["tissue_mesh_triangles"],
                               keys.COLLIDER: raw_traj["gripper_position"],
                               }
            if "panda_position" in raw_traj:
                # cavity grasping task: the collider is a full mesh, so also add its edges and faces
                traj[keys.COLLIDER_EDGE_INDEX] = raw_traj["gripper_edges"]
                traj[keys.COLLIDER_FACES] = raw_traj["gripper_triangles"]
            elif "gripper_mesh_positions" in raw_traj:
                # visual-only collider mesh (for the test set); the key spelling follows the `keys` module
                traj[keys.VISUAL_COLLDER] = raw_traj["gripper_mesh_positions"]
                traj[keys.VISUAL_COLLIDER_FACES] = raw_traj["gripper_mesh_triangles"]
        else:
            raise ValueError("Unknown task type")

        if "pcd_points" in raw_traj:  # deformable plate
            traj[keys.POINT_CLOUD] = raw_traj["pcd_points"]
        elif "tissue_pcd_points" in raw_traj:  # tissue manipulation
            traj[keys.POINT_CLOUD] = raw_traj["tissue_pcd_points"]
        # otherwise: no point cloud data available

        # Always store the normalized Poisson ratio, even though it is usually not used.
        # The affine map sends [-0.9, 0.49] onto [-1, 1] (centre -0.205, half-width 139/200 = 0.695).
        traj[keys.POISSON_RATIO] = (raw_traj[keys.POISSON_RATIO] + 0.205) * (200 / 139)

        return traj

    def _build_data_dict(self, raw_traj: ValueDict, timestep: int) -> ValueDict:
        """
        Converts a single timestep of a prepared trajectory into a dict of torch tensors.

        Args:
            raw_traj: Trajectory as returned by ``_select_and_normalize_attributes``, covering all timesteps.
            timestep: Timestep to extract. ``timestep + 1`` must exist as well, because the next mesh positions
                and, if a collider is present, the collider velocity are read from it.

        Returns:
            Dict with the current, next and initial mesh positions, mesh edge index and faces. Where available in
            ``raw_traj`` (and enabled in the config), it additionally contains:

            * second-order dynamics data;
            * collider position, velocity, edges and faces;
            * the visual collider mesh;
            * the Poisson ratio;
            * the point cloud and its colors.
        """
        mesh_frames = raw_traj[keys.MESH]
        data_dict: ValueDict = {
            keys.MESH: torch.tensor(mesh_frames[timestep], dtype=torch.float32),
            keys.NEXT_MESH_POS: torch.tensor(mesh_frames[timestep + 1], dtype=torch.float32),
            keys.INITIAL_MESH_POSITIONS: torch.tensor(mesh_frames[0], dtype=torch.float32),
            # transposed, presumably from (num_edges, 2) to the (2, num_edges) layout used by PyG -- TODO confirm
            keys.MESH_EDGE_INDEX: torch.tensor(raw_traj[keys.MESH_EDGE_INDEX].T, dtype=torch.long),
            keys.MESH_FACES: torch.tensor(raw_traj[keys.MESH_FACES], dtype=torch.long),
        }

        if self.config.second_order_dynamics:
            data_dict = add_second_order_dynamics(data_dict, timestep, raw_traj)

        if keys.COLLIDER in raw_traj:  # add information about the collider mesh
            collider_frames = raw_traj[keys.COLLIDER]
            data_dict[keys.COLLIDER] = torch.tensor(collider_frames[timestep], dtype=torch.float32)
            # forward difference: collider displacement from this timestep to the next one
            data_dict[keys.COLLIDER_VELOCITY] = torch.tensor(
                collider_frames[timestep + 1] - collider_frames[timestep], dtype=torch.float32)

        if keys.COLLIDER_EDGE_INDEX in raw_traj:  # collider has more than a single node
            data_dict[keys.COLLIDER_EDGE_INDEX] = torch.tensor(raw_traj[keys.COLLIDER_EDGE_INDEX].T, dtype=torch.long)
            data_dict[keys.COLLIDER_FACES] = torch.tensor(raw_traj[keys.COLLIDER_FACES], dtype=torch.long)

        if keys.VISUAL_COLLDER in raw_traj:  # visual-only information about the collider mesh
            data_dict[keys.VISUAL_COLLDER] = torch.tensor(raw_traj[keys.VISUAL_COLLDER][timestep],
                                                          dtype=torch.float32)
            data_dict[keys.VISUAL_COLLIDER_FACES] = torch.tensor(raw_traj[keys.VISUAL_COLLIDER_FACES],
                                                                 dtype=torch.long)

        if keys.POISSON_RATIO in raw_traj:
            data_dict[keys.POISSON_RATIO] = torch.tensor(raw_traj[keys.POISSON_RATIO], dtype=torch.float32)

        if self.config.use_point_cloud and keys.POINT_CLOUD in raw_traj:
            data_dict[keys.POINT_CLOUD] = torch.tensor(raw_traj[keys.POINT_CLOUD][timestep],
                                                       dtype=torch.float32)
        if self.config.use_point_cloud and keys.POINT_CLOUD_COLORS in raw_traj:
            data_dict[keys.POINT_CLOUD_COLORS] = torch.tensor(raw_traj[keys.POINT_CLOUD_COLORS][timestep],
                                                              dtype=torch.float32)

        return data_dict

    @staticmethod
    def _pad_to_all_nodes(values: torch.Tensor, node_type: torch.Tensor, type_index: int,
                          num_rows: int) -> torch.Tensor:
        """
        Scatters per-type node features into a zero tensor of shape ``(num_rows, values.shape[1])``.

        Rows where ``node_type == type_index`` receive the rows of ``values`` in order. All other rows stay zero.
        """
        padded = torch.zeros(size=(num_rows, values.shape[1]))
        padded[node_type == type_index] = values
        return padded

    def _build_graph(self, data_dict: ValueDict) -> Data:
        """
        Builds a ``torch_geometric.data.Data`` graph from a single-timestep data dict.

        Nodes are ordered as follows in ``pos``:

        1. mesh nodes;
        2. collider nodes, if present;
        3. point cloud points, if present and ``use_point_cloud_as_graph`` is set.

        The node features ``x`` start with the output of ``get_one_hot_features_and_types`` (node type). Depending
        on the config, they are extended by a fixed-mesh flag, the collider velocity and the Poisson ratio.

        Args:
            data_dict: Output of ``_build_data_dict``.

        Returns:
            The graph, with edges and edge features from ``build_edges_from_data_dict``.

        Raises:
            ValueError: If collider velocities or the Poisson ratio are requested as node features but the
                corresponding data is missing.
        """
        use_collider_velocities = self.config.use_collider_velocities
        use_canonic_mesh_positions = self.config.use_canonic_mesh_positions
        use_poisson_ratio_as_node_feature = self.config.use_poisson_ratio_as_node_feature
        second_order_dynamics = self.config.second_order_dynamics
        connectivity_setting = self.config.connectivity_setting

        # node types, in the order in which their positions are concatenated into `pos`
        pos_keys = [keys.MESH]
        if keys.COLLIDER in data_dict:
            pos_keys.append(keys.COLLIDER)
        if keys.POINT_CLOUD in data_dict and self.config.use_point_cloud_as_graph:
            pos_keys.append(keys.POINT_CLOUD)
        x_description = copy.deepcopy(pos_keys)

        num_nodes = [data_dict[pos_key].shape[0] for pos_key in pos_keys]
        x, node_type = get_one_hot_features_and_types(input_list=num_nodes)

        if self.config.fixed_mesh_flags:
            mesh_pos = data_dict[keys.MESH]
            # flag every mesh node whose z value is below min_z + fixed_mesh_epsilon
            min_z = torch.min(mesh_pos[:, 2])
            fixed_nodes = (mesh_pos[:, 2] < min_z + self.config.fixed_mesh_epsilon).to(torch.float).reshape(-1, 1)
            x = torch.cat((x, self._pad_to_all_nodes(fixed_nodes, node_type, pos_keys.index(keys.MESH),
                                                     x.shape[0])), dim=1)
            x_description.append(keys.FIXED_MESH_FLAG)

        if use_collider_velocities:
            if keys.COLLIDER not in data_dict or keys.COLLIDER_VELOCITY not in data_dict:
                raise ValueError("Cannot use collider velocities as node feature if no collider is present in the data")
            collider_vel = data_dict[keys.COLLIDER_VELOCITY]
            x = torch.cat((x, self._pad_to_all_nodes(collider_vel, node_type, pos_keys.index(keys.COLLIDER),
                                                     x.shape[0])), dim=1)
            x_description += [keys.COLLIDER_VELOCITY] * collider_vel.shape[1]

        prev_pos = data_dict[keys.PREV_MESH_POS] if second_order_dynamics else None
        current_mesh_pos = data_dict[keys.MESH] if second_order_dynamics else None

        pos = torch.cat(tuple(data_dict[pos_key] for pos_key in pos_keys), dim=0)

        # the point cloud is always stored as an attribute, independent of whether it is part of the graph
        point_cloud = data_dict[keys.POINT_CLOUD] if keys.POINT_CLOUD in data_dict else None

        # the Poisson ratio is stored as a task property
        if keys.POISSON_RATIO in data_dict:
            task_properties = torch.tensor([data_dict[keys.POISSON_RATIO]]).reshape(1, -1)
            task_properties_description = ["poisson_ratio (normalized)"]
        else:
            task_properties = torch.zeros(size=(1, 0))
            task_properties_description = []

        if use_poisson_ratio_as_node_feature:
            if keys.POISSON_RATIO not in data_dict:
                raise ValueError("Cannot use poisson ratio as node feature if it is not present in the data")
            x = torch.cat((x, task_properties.repeat(x.shape[0], 1)), dim=1)
            x_description += task_properties_description

        collider_faces = None
        collider_vertices = None  # collider positions are already contained in `pos`
        visual_collider_faces = None
        visual_collider_vertices = None
        if keys.COLLIDER_FACES in data_dict:
            # Collider nodes follow the mesh nodes in `pos`, so shift the face indices by the number of mesh nodes.
            collider_faces = data_dict[keys.COLLIDER_FACES] + data_dict[keys.MESH].shape[0]
            visual_collider_faces = data_dict[keys.COLLIDER_FACES]
        if keys.VISUAL_COLLIDER_FACES in data_dict:
            # A separate visual-only collider mesh: its faces index into its own vertices, so store both unshifted.
            visual_collider_faces = data_dict[keys.VISUAL_COLLIDER_FACES]
            visual_collider_vertices = data_dict[keys.VISUAL_COLLDER]

        data = Data(x=x,
                    x_description=x_description,
                    pos=pos,
                    next_mesh_pos=data_dict[keys.NEXT_MESH_POS],
                    current_mesh_pos=current_mesh_pos,
                    prev_mesh_pos=prev_pos,
                    node_type=node_type,
                    node_type_description=pos_keys,
                    task_properties=task_properties,
                    task_properties_description=task_properties_description,
                    mesh_faces=data_dict[keys.MESH_FACES],
                    collider_faces=collider_faces,
                    collider_vertices=collider_vertices,
                    visual_collider_faces=visual_collider_faces,
                    visual_collider_vertices=visual_collider_vertices,
                    point_cloud=point_cloud,
                    )

        # edge features: edge_attr is just the one-hot edge type, all other features are created after preprocessing
        edge_attr, edge_type, edge_index, edge_type_description = build_edges_from_data_dict(data_dict,
                                                                                             pos_keys,
                                                                                             num_nodes,
                                                                                             connectivity_setting,
                                                                                             use_canonic_mesh_positions)
        data.edge_attr = edge_attr
        data.edge_type = edge_type
        data.edge_index = edge_index
        data.edge_type_description = edge_type_description

        return data
