import copy
import os
import pickle
from typing import List, Dict, Callable

import numpy as np
import torch
from torch_geometric.data import Data, Batch

from ltsgns_mp.envs.data_loader.abstract_dataloader import AbstractDataloader
from ltsgns_mp.envs.util.data_loader_util import get_one_hot_features_and_types, add_second_order_dynamics

from ltsgns_mp.util.graph_input_output_util import build_edges_from_data_dict
from ltsgns_mp.util import keys
from ltsgns_mp.util.own_types import ConfigDict, ValueDict


class PlanarBendingDataloader(AbstractDataloader):
    """
    Dataloader for the planar bending task.

    Loads pickled rollouts, selects the mesh-related attributes, normalizes the Young's modulus and
    converts single timesteps into torch_geometric ``Data`` graphs.
    """

    def __init__(self, config: ConfigDict):
        super().__init__(config)

    ###########################################
    ####### Interfaces for data loading #######
    ###########################################

    def _get_rollout_length(self, raw_traj: ValueDict) -> int:
        """
        Returns the rollout length of the task. This is the number of timesteps in the rollout.
        Args:
            raw_traj: Raw trajectory dictionary. Only the length of its "mesh_nodes" entry is used.

        Returns: Number of entries in "mesh_nodes" minus one.

        """
        # One less than the number of stored states: _build_data_dict reads `timestep + 1` as the target,
        # so the last stored state cannot be used as an input step.
        return len(raw_traj["mesh_nodes"]) - 1

    def _load_raw_data(self, split: str) -> List[ValueDict]:
        """
        Loads the pickled rollouts of the given split.
        Args:
            split: Name of the split. Used as suffix of the file name,
                i.e. "<path_to_datasets>/<dataset_name>/<dataset_name>_<split>.pkl"

        Returns: The unpickled content of the file.

        """
        path_to_datasets = self.config.path_to_datasets
        dataset_name = self.config.dataset_name
        dataset_file = os.path.join(path_to_datasets, dataset_name, dataset_name + "_" + split + ".pkl")
        # NOTE(security): pickle.load executes arbitrary code contained in the file.
        # Only point the config at dataset files from a trusted source.
        with open(dataset_file, "rb") as file:
            rollout_data = pickle.load(file)
        return rollout_data

    def _select_and_normalize_attributes(self, raw_traj: ValueDict) -> ValueDict:
        """
        Selects the attributes used by this task from a raw trajectory and adds the normalized parameters.
        Args:
            raw_traj: Raw trajectory with the entries "mesh_nodes", "mesh_edges", "mesh_faces",
                "boarder_nodes", "force_influence" and "youngs_modulus".

        Returns: Dictionary that maps the project keys to the selected entries. The Young's modulus is stored
            under keys.PARAMETERS after min-max normalization with the fixed bounds 10 and 1000.

        """
        traj: ValueDict = {keys.MESH: raw_traj["mesh_nodes"],
                           keys.MESH_EDGE_INDEX: raw_traj["mesh_edges"],
                           keys.MESH_FACES: raw_traj["mesh_faces"],
                           keys.BOARDER_NODES: raw_traj["boarder_nodes"],
                           keys.FORCE_INFLUENCE: raw_traj["force_influence"],
                           }

        # Deliberately a plain array without a trailing comma: a 1-tuple would only work through numpy's
        # implicit tuple-to-array coercion and would add a spurious leading axis to the parameters.
        parameters = np.array([raw_traj["youngs_modulus"]])

        # NOTE(review): an earlier comment here said "temp_ply is always the same" - presumably all other
        # material parameters are constant over the dataset and therefore not part of the parameters. Confirm.

        min_parameters = np.array([10])
        max_parameters = np.array([1000])
        parameters = (parameters - min_parameters) / (max_parameters - min_parameters)
        traj[keys.PARAMETERS] = parameters
        return traj

    def _build_data_dict(self, raw_traj: ValueDict, timestep: int) -> ValueDict:
        """
        Function to get the correct data and convert to tensors from a single time step of a trajectory of the prepared data
        output from SOFA
        :return: Dict containing all the data for a single timestep in torch tensor format.
        :param raw_traj: ValueDict with all the data for all timesteps. It is indexed with the project keys,
            so it presumably is the output of _select_and_normalize_attributes (verify against the caller).
        :param timestep: timestep to get the data for
        """
        data_dict = {keys.MESH: torch.tensor(raw_traj[keys.MESH][timestep], dtype=torch.float32),
                     keys.NEXT_MESH_POS: torch.tensor(raw_traj[keys.MESH][timestep + 1], dtype=torch.float32),
                     keys.INITIAL_MESH_POSITIONS: torch.tensor(raw_traj[keys.MESH][0], dtype=torch.float32),
                     # edges are transposed before they are stored as edge index
                     keys.MESH_EDGE_INDEX: torch.tensor(raw_traj[keys.MESH_EDGE_INDEX].T, dtype=torch.long),
                     keys.MESH_FACES: torch.tensor(raw_traj[keys.MESH_FACES], dtype=torch.long),
                     keys.BOARDER_NODES: torch.tensor(raw_traj[keys.BOARDER_NODES], dtype=torch.float),
                     keys.FORCE_INFLUENCE: torch.tensor(raw_traj[keys.FORCE_INFLUENCE], dtype=torch.float),
                     }

        if self.config.second_order_dynamics:
            data_dict = add_second_order_dynamics(data_dict, timestep, raw_traj)

        if keys.PARAMETERS in raw_traj:
            data_dict[keys.PARAMETERS] = torch.tensor(raw_traj[keys.PARAMETERS], dtype=torch.float32)

        return data_dict

    def _build_graph(self, data_dict: ValueDict) -> Data:
        """
        Builds the graph of a single timestep from its data dict.
        Args:
            data_dict: Dictionary of tensors for a single timestep, as created by _build_data_dict.

        Returns: Data object with node features, positions, targets, task properties, mesh faces and edges.

        Raises:
            ValueError: If the parameters are requested as node feature but are not part of data_dict.

        """
        use_canonic_mesh_positions = self.config.use_canonic_mesh_positions
        use_parameters_as_node_feature = self.config.use_parameters_as_node_feature
        second_order_dynamics = self.config.second_order_dynamics
        connectivity_setting = self.config.connectivity_setting
        # build nodes features (one hot node type)
        pos_keys = [keys.MESH]
        x_description = copy.deepcopy(pos_keys)

        num_nodes = [data_dict[pos_key].shape[0] for pos_key in pos_keys]

        x, node_type = get_one_hot_features_and_types(input_list=num_nodes)
        # boarder nodes and force influence are appended as one feature column each
        boarder_nodes = data_dict[keys.BOARDER_NODES].reshape(-1, 1)
        force_influence = data_dict[keys.FORCE_INFLUENCE].reshape(-1, 1)
        x = torch.cat((x, boarder_nodes, force_influence), dim=1)
        x_description += ["boarder_nodes", "force_influence"]
        if second_order_dynamics:
            prev_pos = data_dict[keys.PREV_MESH_POS]
            current_mesh_pos = data_dict[keys.MESH]
        else:
            prev_pos = None
            current_mesh_pos = None

        pos = torch.cat(tuple(data_dict[pos_key] for pos_key in pos_keys), dim=0)

        # we save the Young's modulus as task property
        if keys.PARAMETERS in data_dict:
            # reshape to a single row, independent of the shape the parameters were stored in
            task_properties = data_dict[keys.PARAMETERS].reshape(1, -1)
            task_properties_description = ["Youngs Modulus"]
        else:
            task_properties = torch.zeros(size=(1, 0))
            task_properties_description = []

        if use_parameters_as_node_feature:
            if keys.PARAMETERS not in data_dict:
                raise ValueError("Cannot use Young's modulus as node feature if it is not present in the data")
            # add parameters to node feature (same values repeated for every node)
            x = torch.cat((x, task_properties.repeat(x.shape[0], 1)), dim=1)
            x_description += task_properties_description

        data = Data(x=x,
                    x_description=x_description,
                    pos=pos,
                    next_mesh_pos=data_dict[keys.NEXT_MESH_POS],
                    current_mesh_pos=current_mesh_pos,
                    prev_mesh_pos=prev_pos,
                    node_type=node_type,
                    node_type_description=pos_keys,
                    task_properties=task_properties,
                    task_properties_description=task_properties_description,
                    mesh_faces=data_dict[keys.MESH_FACES],
                    collider_faces=None,
                    collider_vertices=None,
                    )

        # edge features: edge_attr is just one-hot edge type, all other features are created after preprocessing
        edge_attr, edge_type, edge_index, edge_type_description = build_edges_from_data_dict(data_dict,
                                                                                             pos_keys,
                                                                                             num_nodes,
                                                                                             connectivity_setting,
                                                                                             use_canonic_mesh_positions)

        data.edge_attr = edge_attr
        data.edge_type = edge_type
        data.edge_index = edge_index
        data.edge_type_description = edge_type_description

        return data
