import copy
import os
import pickle
from typing import List, Dict, Callable

import numpy as np
import torch
from torch_geometric.data import Data, Batch

from ltsgns_mp.envs.data_loader.abstract_dataloader import AbstractDataloader
from ltsgns_mp.envs.util.data_loader_util import get_one_hot_features_and_types, add_second_order_dynamics

from ltsgns_mp.util.graph_input_output_util import build_edges_from_data_dict
from ltsgns_mp.util import keys
from ltsgns_mp.util.own_types import ConfigDict, ValueDict


class HydraulicPressLowResDataloader(AbstractDataloader):
    """
    Dataloader for the low-resolution hydraulic press dataset.

    Loads pickled rollouts, min-max normalizes the per-trajectory process parameters with
    fixed bounds, and converts single timesteps into ``torch_geometric`` ``Data`` graphs that
    contain the deformable mesh and (if present) the collider.
    """

    # Raw trajectory keys of the process parameters, in the order in which they are stacked
    # into ``keys.PARAMETERS``. ``temp_ply`` is deliberately not included because it is always
    # the same.
    _PARAMETER_KEYS = ("thickness", "orientation", "temp_collider")
    # Human-readable names of the entries of ``keys.PARAMETERS``. Must stay aligned with
    # ``_PARAMETER_KEYS`` so that ``x_description`` matches the columns of ``x`` when the
    # parameters are used as node features.
    _PARAMETER_DESCRIPTIONS = ("Laminate Thickness", "Laminate Orientation", "Temp Collider")
    # Fixed bounds for the min-max normalization of the parameters above (same order).
    _MIN_PARAMETERS = np.array([0.5, 0.0, 80.0])
    _MAX_PARAMETERS = np.array([1.5, 90.0, 120.0])

    def __init__(self, config: ConfigDict):
        super().__init__(config)

    ###########################################
    ####### Interfaces for data loading #######
    ###########################################

    def _get_rollout_length(self, raw_traj: ValueDict) -> int:
        """
        Returns the rollout length of the task, i.e. the number of timesteps a sample can start at.

        Every sample built by ``_build_data_dict`` also reads the frame at ``timestep + 1``
        (next mesh positions, collider velocity), so the last stored frame cannot start a sample.

        Args:
            raw_traj: Raw trajectory as loaded from disk; must contain ``"mesh_nodes"``.

        Returns:
            The number of stored mesh frames minus one.
        """
        return len(raw_traj["mesh_nodes"]) - 1

    def _load_raw_data(self, split: str) -> List[ValueDict]:
        """
        Loads all raw trajectories of one dataset split.

        The file is read from ``<path_to_datasets>/<dataset_name>/<dataset_name>_<split>.pkl``.

        Args:
            split: Name of the split, used as the file-name suffix.

        Returns:
            The unpickled file content, expected to be a list of raw trajectory dicts.
        """
        path_to_datasets = self.config.path_to_datasets
        dataset_name = self.config.dataset_name
        file_path = os.path.join(path_to_datasets, dataset_name, f"{dataset_name}_{split}.pkl")
        # NOTE(security): unpickling can execute arbitrary code; only load trusted dataset files.
        with open(file_path, "rb") as file:
            rollout_data = pickle.load(file)
        return rollout_data

    def _select_and_normalize_attributes(self, raw_traj: ValueDict) -> ValueDict:
        """
        Maps the raw trajectory fields to the internal keys and normalizes the process parameters.

        Args:
            raw_traj: Raw trajectory dict as loaded by ``_load_raw_data``.

        Returns:
            A new dict holding the mesh/collider positions, edges and faces under the
            corresponding ``keys.*`` entries, and the min-max normalized parameters (ordered as
            ``_PARAMETER_KEYS``) under ``keys.PARAMETERS``.
        """
        traj: ValueDict = {keys.MESH: raw_traj["mesh_nodes"],
                           keys.MESH_EDGE_INDEX: raw_traj["mesh_edges"],
                           keys.MESH_FACES: raw_traj["mesh_faces"],
                           keys.COLLIDER: raw_traj["collider_nodes"],
                           keys.COLLIDER_EDGE_INDEX: raw_traj["collider_edges"],
                           keys.COLLIDER_FACES: raw_traj["collider_faces"],
                           }

        parameters = np.array([raw_traj[parameter_key] for parameter_key in self._PARAMETER_KEYS])
        parameter_range = self._MAX_PARAMETERS - self._MIN_PARAMETERS
        traj[keys.PARAMETERS] = (parameters - self._MIN_PARAMETERS) / parameter_range
        return traj

    def _build_data_dict(self, raw_traj: ValueDict, timestep: int) -> ValueDict:
        """
        Converts a single timestep of a prepared trajectory into torch tensors.

        Args:
            raw_traj: Trajectory as returned by ``_select_and_normalize_attributes``, holding the
                data for all timesteps.
            timestep: Timestep to extract. ``timestep + 1`` must exist as well, since it provides
                the target mesh positions and the collider velocity.

        Returns:
            Dict with the current, next and initial mesh positions and the mesh connectivity,
            plus (when present in ``raw_traj``) collider positions, velocity and connectivity,
            and the task parameters. Second-order dynamics data is added if configured.
        """
        data_dict = {keys.MESH: torch.tensor(raw_traj[keys.MESH][timestep], dtype=torch.float32),
                     keys.NEXT_MESH_POS: torch.tensor(raw_traj[keys.MESH][timestep + 1], dtype=torch.float32),
                     keys.INITIAL_MESH_POSITIONS: torch.tensor(raw_traj[keys.MESH][0], dtype=torch.float32),
                     # transposed, presumably from (num_edges, 2) to PyG's (2, num_edges) layout
                     keys.MESH_EDGE_INDEX: torch.tensor(raw_traj[keys.MESH_EDGE_INDEX].T, dtype=torch.long),
                     keys.MESH_FACES: torch.tensor(raw_traj[keys.MESH_FACES], dtype=torch.long)}

        if self.config.second_order_dynamics:
            data_dict = add_second_order_dynamics(data_dict, timestep, raw_traj)

        if keys.COLLIDER in raw_traj:  # add information about the collider mesh
            collider_positions = raw_traj[keys.COLLIDER]
            data_dict |= {
                keys.COLLIDER: torch.tensor(collider_positions[timestep], dtype=torch.float32),
                # displacement of the collider between this and the next stored frame
                keys.COLLIDER_VELOCITY: torch.tensor(
                    collider_positions[timestep + 1] - collider_positions[timestep],
                    dtype=torch.float32),
            }
        if keys.COLLIDER_EDGE_INDEX in raw_traj:  # Collider has more than a single node
            data_dict |= {
                keys.COLLIDER_EDGE_INDEX: torch.tensor(raw_traj[keys.COLLIDER_EDGE_INDEX].T, dtype=torch.long),
                keys.COLLIDER_FACES: torch.tensor(raw_traj[keys.COLLIDER_FACES], dtype=torch.long),
            }

        if keys.PARAMETERS in raw_traj:
            data_dict[keys.PARAMETERS] = torch.tensor(raw_traj[keys.PARAMETERS], dtype=torch.float32)

        return data_dict

    def _build_graph(self, data_dict: ValueDict) -> Data:
        """
        Builds a ``torch_geometric`` graph from the tensors of a single timestep.

        Node features ``x`` are the one-hot node type (mesh / collider), optionally followed by
        the collider velocity (zero for mesh nodes) and the task parameters. Node positions of
        the mesh and the collider are concatenated (mesh first) into ``pos``. Edges are created
        by ``build_edges_from_data_dict``.

        Args:
            data_dict: Output of ``_build_data_dict``.

        Returns:
            The assembled ``Data`` object.

        Raises:
            ValueError: If parameters should be used as node features but are not in ``data_dict``.
        """
        use_collider_velocities = self.config.use_collider_velocities
        use_canonic_mesh_positions = self.config.use_canonic_mesh_positions
        use_parameters_as_node_feature = self.config.use_parameters_as_node_feature
        second_order_dynamics = self.config.second_order_dynamics
        connectivity_setting = self.config.connectivity_setting

        # build nodes features (one hot node type); the order of pos_keys defines the node types
        pos_keys = [keys.MESH]
        if keys.COLLIDER in data_dict:
            pos_keys.append(keys.COLLIDER)
        x_description = copy.deepcopy(pos_keys)

        num_nodes = [data_dict[pos_key].shape[0] for pos_key in pos_keys]

        x, node_type = get_one_hot_features_and_types(input_list=num_nodes)
        if use_collider_velocities:
            # zero-padded so that only collider nodes carry a velocity feature
            collider_vel = data_dict[keys.COLLIDER_VELOCITY]
            padded_collider_vel = torch.zeros(size=(x.shape[0], collider_vel.shape[1]))
            padded_collider_vel[node_type == pos_keys.index(keys.COLLIDER)] = collider_vel
            x = torch.cat((x, padded_collider_vel), dim=1)
            x_description += [keys.COLLIDER_VELOCITY] * collider_vel.shape[1]

        if second_order_dynamics:
            prev_pos = data_dict[keys.PREV_MESH_POS]
            current_mesh_pos = data_dict[keys.MESH]
        else:
            prev_pos = current_mesh_pos = None

        pos = torch.cat(tuple(data_dict[pos_key] for pos_key in pos_keys), dim=0)

        # the (normalized) process parameters are stored as task properties of the graph
        if keys.PARAMETERS in data_dict:
            task_properties = data_dict[keys.PARAMETERS].reshape(1, -1)
            # one name per parameter, aligned with the order of _PARAMETER_KEYS
            task_properties_description = list(self._PARAMETER_DESCRIPTIONS)
        else:
            task_properties = torch.zeros(size=(1, 0))
            task_properties_description = []

        if use_parameters_as_node_feature:
            if keys.PARAMETERS not in data_dict:
                raise ValueError("Cannot use parameters as node feature if they are not present in the data")
            # add parameters to node feature (same values for every node)
            x = torch.cat((x, task_properties.repeat(x.shape[0], 1)), dim=1)
            x_description += task_properties_description

        collider_faces = None
        if keys.COLLIDER_FACES in data_dict:
            # Collider nodes are appended after the mesh nodes in ``pos``, so shift the collider
            # face indices by the number of mesh nodes for proper visualization.
            collider_faces = data_dict[keys.COLLIDER_FACES] + data_dict[keys.MESH].shape[0]
        # collider vertices are not stored separately; their positions are part of ``pos``
        collider_vertices = None

        data = Data(x=x,
                    x_description=x_description,
                    pos=pos,
                    next_mesh_pos=data_dict[keys.NEXT_MESH_POS],
                    current_mesh_pos=current_mesh_pos,
                    prev_mesh_pos=prev_pos,
                    node_type=node_type,
                    node_type_description=pos_keys,
                    task_properties=task_properties,
                    task_properties_description=task_properties_description,
                    mesh_faces=data_dict[keys.MESH_FACES],
                    collider_faces=collider_faces,
                    collider_vertices=collider_vertices,
                    )

        # edge features: edge_attr is just one-hot edge type, all other features are created after preprocessing
        edge_attr, edge_type, edge_index, edge_type_description = build_edges_from_data_dict(data_dict,
                                                                                             pos_keys,
                                                                                             num_nodes,
                                                                                             connectivity_setting,
                                                                                             use_canonic_mesh_positions)

        data.edge_attr = edge_attr
        data.edge_type = edge_type
        data.edge_index = edge_index
        data.edge_type_description = edge_type_description

        return data
