import copy
import os
import pickle
from typing import List, Dict, Callable

import numpy as np
import torch
from torch_geometric.data import Data, Batch

from ltsgns_mp.envs.data_loader.abstract_dataloader import AbstractDataloader
from ltsgns_mp.envs.util.data_loader_util import get_one_hot_features_and_types, add_second_order_dynamics

from ltsgns_mp.util.graph_input_output_util import build_edges_from_data_dict
from ltsgns_mp.util import keys
from ltsgns_mp.util.own_types import ConfigDict, ValueDict


class ToyTaskDataloader(AbstractDataloader):
    """
    Dataloader for the toy task.

    A raw trajectory is a dict with a ``"positions"`` entry (one position per timestep for a single node), and
    optionally ``"params"`` (task parameters) and ``"pc_positions"`` (point cloud positions). The toy task has no
    mesh edges; graph edges are built by ``build_edges_from_data_dict`` according to the connectivity setting.
    """

    def __init__(self, config: ConfigDict):
        super().__init__(config)

    ###########################################
    ####### Interfaces for data loading #######
    ###########################################

    def _get_rollout_length(self, raw_traj: ValueDict) -> int:
        """
        Returns the rollout length of the task, i.e., the number of usable timesteps of the trajectory.

        This is one less than the number of stored positions. ``_build_data_dict`` reads timestep ``t + 1`` as the
        target of timestep ``t``. The ``- 1`` also mirrors the original loader, whose loop over timesteps ran to
        ``rollout_length - 2`` for this task (vs. ``rollout_length - 1`` for the other tasks).

        Args:
            raw_traj: Raw trajectory as loaded by ``_load_raw_data``.

        Returns:
            ``len(raw_traj["positions"]) - 1``

        Raises:
            ValueError: If the trajectory has no ``"positions"`` entry.
        """
        if "positions" not in raw_traj:
            raise ValueError("Unknown dataset format. Cannot determine rollout")
        return len(raw_traj["positions"]) - 1

    def _load_raw_data(self, split: str) -> List[ValueDict]:
        """
        Loads the raw trajectories of a split from
        ``<path_to_datasets>/<dataset_name>/<dataset_name>_<split>.npy``.

        Args:
            split: Name of the dataset split; it is appended to the file name.

        Returns:
            The array stored in the ``.npy`` file. Despite the annotation, this is a numpy array (presumably an object
            array of per-trajectory dicts — TODO confirm), which supports list-like indexing and iteration.
        """
        dataset_name = self.config.dataset_name
        path = os.path.join(self.config.path_to_datasets, dataset_name, f"{dataset_name}_{split}.npy")
        # NOTE(security): allow_pickle=True unpickles arbitrary objects, which can execute code while loading.
        # Only load dataset files from trusted sources.
        return np.load(path, allow_pickle=True)

    def _select_and_normalize_attributes(self, raw_traj: ValueDict) -> ValueDict:
        """
        Selects the relevant entries of a raw trajectory and converts them to float32 tensors.

        No normalization is applied here; the values are only cast.

        Args:
            raw_traj: Raw trajectory with a ``"positions"`` entry and optional ``"params"`` and ``"pc_positions"``
                entries. The entries may be torch tensors or numpy arrays.

        Returns:
            Dict with the following entries:

            - ``keys.MESH``: the positions with an added num-nodes axis of size 1.
            - ``keys.PARAMETERS``: only if the trajectory has ``"params"``.
            - ``keys.POINT_CLOUD``: only if it has ``"pc_positions"``.

        Raises:
            ValueError: If the trajectory has no ``"positions"`` entry.
        """
        if "positions" not in raw_traj:
            raise ValueError("Unknown task type")
        # torch.as_tensor accepts tensors and numpy arrays; float32 tensors are returned as-is (no copy),
        # matching the previous Tensor.type(torch.float32) behavior.
        positions = torch.as_tensor(raw_traj["positions"], dtype=torch.float32)
        traj: ValueDict = {keys.MESH: positions[:, None, :]}  # add num nodes dimension

        # Parameters are optional: _build_data_dict and _build_graph both handle their absence.
        if "params" in raw_traj:
            traj[keys.PARAMETERS] = torch.tensor(raw_traj["params"], dtype=torch.float32)
        if "pc_positions" in raw_traj:
            # Add point cloud positions
            traj[keys.POINT_CLOUD] = torch.as_tensor(raw_traj["pc_positions"], dtype=torch.float32)
        return traj

    def _build_data_dict(self, raw_traj: ValueDict, timestep: int) -> ValueDict:
        """
        Collects the data of a single timestep of a prepared trajectory.

        Args:
            raw_traj: Trajectory dict keyed by ``keys.*`` (as produced by ``_select_and_normalize_attributes``),
                holding the data of all timesteps.
            timestep: Timestep to get the data for. ``timestep + 1`` must exist as well, since it is the target.

        Returns:
            Dict with the following entries:

            - The current mesh positions and the next mesh positions.
            - An empty ``(2, 0)`` mesh edge index (the toy task has no mesh edges).
            - The entries added by ``add_second_order_dynamics``, if enabled.
            - The task parameters, if present.
            - The point cloud of this timestep, if point clouds are enabled and present.
        """
        data_dict = {keys.MESH: raw_traj[keys.MESH][timestep],
                     keys.NEXT_MESH_POS: raw_traj[keys.MESH][timestep + 1],
                     keys.MESH_EDGE_INDEX: torch.zeros((2, 0), dtype=torch.long)
                     }

        if self.config.second_order_dynamics:
            data_dict = add_second_order_dynamics(data_dict, timestep, raw_traj)

        if keys.PARAMETERS in raw_traj:
            data_dict[keys.PARAMETERS] = raw_traj[keys.PARAMETERS]

        if self.config.use_point_cloud and keys.POINT_CLOUD in raw_traj:
            data_dict[keys.POINT_CLOUD] = raw_traj[keys.POINT_CLOUD][timestep]

        return data_dict

    def _build_graph(self, data_dict: ValueDict) -> Data:
        """
        Builds a torch_geometric graph from the data dict of a single timestep.

        The node features ``x`` are concatenated in this order:

        1. The one-hot node type from ``get_one_hot_features_and_types`` (presumably one column per entry of
           ``pos_keys``).
        2. The task parameters repeated for every node, if ``config.use_parameters_as_node_feature`` is set.
        3. The absolute node positions.

        ``x_description`` names these columns in the same order.

        Args:
            data_dict: Data of one timestep, as produced by ``_build_data_dict``.

        Returns:
            The graph as a ``Data`` object. ``current_mesh_pos`` and ``prev_mesh_pos`` are ``None`` unless
            second-order dynamics are enabled; ``point_cloud`` is ``None`` if the data has no point cloud.

        Raises:
            ValueError: If ``config.use_parameters_as_node_feature`` is set but the data has no task parameters.
        """
        use_parameters_as_node_feature = self.config.use_parameters_as_node_feature
        connectivity_setting = self.config.connectivity_setting
        use_canonic_mesh_positions = self.config.use_canonic_mesh_positions
        second_order_dynamics = self.config.second_order_dynamics

        # Node groups: the mesh and, if configured, the point cloud as additional graph nodes.
        pos_keys = [keys.MESH]
        if keys.POINT_CLOUD in data_dict and self.config.use_point_cloud_as_graph:
            pos_keys.append(keys.POINT_CLOUD)
        # Separate list so that extending the feature description below does not modify pos_keys.
        x_description = list(pos_keys)
        num_nodes = [data_dict[pos_key].shape[0] for pos_key in pos_keys]

        # build node features (one hot node type)
        x, node_type = get_one_hot_features_and_types(input_list=num_nodes)

        if second_order_dynamics:
            prev_pos = data_dict[keys.PREV_MESH_POS]
            current_mesh_pos = data_dict[keys.MESH]
        else:
            prev_pos = None
            current_mesh_pos = None

        pos = torch.cat(tuple(data_dict[pos_key] for pos_key in pos_keys), dim=0)

        # point cloud data (stored on the graph even if it is not used as graph nodes)
        point_cloud = data_dict[keys.POINT_CLOUD] if keys.POINT_CLOUD in data_dict else None

        if keys.PARAMETERS in data_dict:
            task_properties = data_dict[keys.PARAMETERS].reshape(1, -1)
            # One name per parameter so the description always matches the feature columns.
            task_properties_description = self._parameter_descriptions(task_properties.shape[1])
        else:
            task_properties = torch.zeros(size=(1, 0))
            task_properties_description = []

        if use_parameters_as_node_feature:
            if keys.PARAMETERS not in data_dict:
                raise ValueError("Cannot use the task parameters as node feature if it is not present in the data")
            # add parameters to node feature
            x = torch.cat((x, task_properties.repeat(x.shape[0], 1)), dim=1)
            x_description += task_properties_description

        # in the toy task, we save the positions as absolute features in the x
        x = torch.cat((x, pos), dim=1)
        x_description += self._position_descriptions(keys.MESH, pos.shape[1])

        data = Data(x=x,
                    x_description=x_description,
                    pos=pos,
                    next_mesh_pos=data_dict[keys.NEXT_MESH_POS],
                    current_mesh_pos=current_mesh_pos,
                    prev_mesh_pos=prev_pos,
                    node_type=node_type,
                    node_type_description=pos_keys,
                    task_properties=task_properties,
                    task_properties_description=task_properties_description,
                    point_cloud=point_cloud,
                    )

        # edge features: edge_attr is just one-hot edge type, all other features are created after preprocessing
        edge_attr, edge_type, edge_index, edge_type_description = build_edges_from_data_dict(data_dict,
                                                                                             pos_keys,
                                                                                             num_nodes,
                                                                                             connectivity_setting,
                                                                                             use_canonic_mesh_positions)

        data.edge_attr = edge_attr
        data.edge_type = edge_type
        data.edge_index = edge_index
        data.edge_type_description = edge_type_description

        return data

    ###########################################
    ############ Description helpers ##########
    ###########################################

    @staticmethod
    def _parameter_descriptions(num_parameters: int) -> List[str]:
        """Names of the task-parameter feature columns: "a", "b", "c", ..., "z", then "param_<i>"."""
        letters = [chr(ord("a") + index) for index in range(min(num_parameters, 26))]
        return letters + [f"param_{index}" for index in range(26, num_parameters)]

    @staticmethod
    def _position_descriptions(prefix: str, num_dims: int) -> List[str]:
        """Names of the position feature columns: "<prefix>_x_pos", "<prefix>_y_pos", "<prefix>_z_pos", then
        "<prefix>_<i>_pos" for further dimensions."""
        axis_names = ["x", "y", "z"] + [str(dim) for dim in range(3, num_dims)]
        return [f"{prefix}_{axis_name}_pos" for axis_name in axis_names[:num_dims]]
