import torch
import numpy as np
import collections.abc
from torch.utils.data._utils.collate import default_collate
import dill
from copy import deepcopy

# Short alias for collections.abc; collate() uses it for its Sequence / Mapping checks.
container_abcs = collections.abc


def restore(data):
    """
    Undo the dill serialization applied to structures shared between worker processes.

    Only objects whose exact type is ``bytes`` are treated as dilled payloads; anything
    else is assumed never to have been dilled and is handed back untouched.

    :param data: Either a dill payload (``bytes``) or an already-usable object
    :return: The deserialized object, or ``data`` itself when it is not ``bytes``
    """
    if type(data) is not bytes:
        return data
    return dill.loads(data)


def dict_collate(batch):
    """
    Collate samples produced by ``get_node_timestep_data`` and expose the fields by name.

    :param batch: List of 12-element sample tuples
    :return: dict with the keys "index", "obs", "gt", "obs_st", "gt_st", "neighbors_st",
             "neighbors_gt_st", "neighbors_edge", "robot_traj_st", "map", "dt" and
             "first_history_index", each holding the corresponding collated field
    """
    (
        first_hist_idx,
        obs,
        gt,
        obs_st,
        gt_st,
        nbr_hist_st,
        nbr_fut_st,
        nbr_edge,
        robot_st,
        map_field,
        step_dt,
        sample_index,
    ) = collate(batch)

    return {
        "index": sample_index,
        "obs": obs,
        "gt": gt,
        "obs_st": obs_st,
        "gt_st": gt_st,
        "neighbors_st": nbr_hist_st,
        "neighbors_gt_st": nbr_fut_st,
        "neighbors_edge": nbr_edge,
        "robot_traj_st": robot_st,
        "map": map_field,
        "dt": step_dt,
        "first_history_index": first_hist_idx,
    }


def collate(batch):
    """
    Recursively collate a list of samples into a batch.

    The first element ``elem`` of ``batch`` decides how the batch is handled:

    * empty ``batch``: returned unchanged;
    * ``elem is None``: returns ``None``;
    * ``elem`` is a 4-element sequence: treated as a map tuple
      ``(scene_map, map_point, heading_angle, patch_size)``; the maps are cropped with
      ``scene_map[0].get_cropped_maps_from_scene_map_batch`` and that result is returned;
    * any other sequence (``str``/``bytes`` excluded): transposed into per-field sequences;
      fields whose first entry is a tuple are returned as the raw tuple of samples, every
      other field is collated recursively; the result is a list;
    * ``elem`` is a mapping: returns ``{key: [sample[key] for sample in batch]}``, serialized
      with ``dill.dumps`` when running inside a DataLoader worker (see ``restore``);
    * anything else (tensors, numbers, strings, ...): delegated to ``default_collate``.

    :param batch: Sequence of samples
    :return: Collated batch
    """
    if len(batch) == 0:
        return batch
    elem = batch[0]
    if elem is None:
        return None
    # str/bytes are Sequences too, but transposing them character by character recurses
    # until RecursionError (and 4-character strings would be mistaken for map tuples).
    # Leave them to default_collate instead.
    elif isinstance(elem, container_abcs.Sequence) and not isinstance(elem, (str, bytes)):
        if len(elem) == 4:  # We assume those are the maps, map points, headings and patch_size
            scene_map, scene_pts, heading_angle, patch_size = zip(*batch)
            if heading_angle[0] is None:
                heading_angle = None
            else:
                heading_angle = torch.Tensor(heading_angle)
            cropped_maps = scene_map[0].get_cropped_maps_from_scene_map_batch(
                scene_map,
                scene_pts=torch.Tensor(scene_pts),
                patch_size=patch_size[0],
                rotation=heading_angle,
            )
            return cropped_maps

        # NOTE(review): tuple-valued fields are passed through uncollated. This keeps the
        # (scene.name, t, node) index tuples intact. It also means the 4-tuple map_tuple built
        # by get_node_timestep_data is returned raw instead of reaching the map-cropping
        # branch above when a batch of full samples is collated. Confirm this is intended.
        return [
            samples if isinstance(samples[0], tuple) else collate(samples)
            for samples in zip(*batch)
        ]

    elif isinstance(elem, container_abcs.Mapping):
        # We have to dill the neighbors structures. Otherwise each tensor is put into
        # shared memory separately -> slow, file pointer overhead
        # we only do this in multiprocessing
        neighbor_dict = {key: [d[key] for d in batch] for key in elem}
        return (
            dill.dumps(neighbor_dict)
            if torch.utils.data.get_worker_info()
            else neighbor_dict
        )

    return default_collate(batch)


def get_relative_robot_traj(env, state, node_traj, robot_traj, node_type, robot_type):
    """
    Standardize a robot trajectory relative to a node's trajectory.

    The robot's standardization std is fetched from ``env``. Its first two entries are
    replaced by the attention radius of the ``(node_type, robot_type)`` edge, and
    ``node_traj`` is used as the mean.

    :param env: Environment providing ``get_standardize_params``, ``attention_radius`` and ``standardize``
    :param state: Mapping from node type to its state specification
    :param node_traj: Trajectory of the node, used as the standardization mean
    :param robot_traj: Robot trajectory to standardize
    :param node_type: Type of the node the trajectory is expressed relative to
    :param robot_type: Type of the robot node
    :return: Standardized robot trajectory as a float32 tensor
    """
    # TODO: We will have to make this more generic if robot_type != node_type
    robot_state_spec = state[robot_type]
    _, robot_std = env.get_standardize_params(robot_state_spec, node_type=robot_type)
    robot_std[0:2] = env.attention_radius[(node_type, robot_type)]
    relative_traj = env.standardize(
        robot_traj,
        robot_state_spec,
        node_type=robot_type,
        mean=node_traj,
        std=robot_std,
    )
    return torch.tensor(relative_traj, dtype=torch.float)


# def get_node_timestep_data(env, scene, t, node, state, pred_state,
#                            edge_types, max_ht, max_ft, hyperparams,
#                            scene_graph=None, normalize_direction=False):
#     """
#     Pre-processes the data for a single batch element: node state over time for a specific time in a specific scene
#     as well as the neighbour data for it.

#     :param env: Environment
#     :param scene: Scene
#     :param t: Timestep in scene
#     :param node: Node
#     :param state: Specification of the node state
#     :param pred_state: Specification of the prediction state
#     :param edge_types: List of all Edge Types for which neighbours are pre-processed
#     :param max_ht: Maximum history timesteps
#     :param max_ft: Maximum future timesteps (prediction horizon)
#     :param hyperparams: Model hyperparameters
#     :param scene_graph: If scene graph was already computed for this scene and time you can pass it here
#     :return: Batch Element
#     """

#     # Node
#     timestep_range_x = np.array([t - max_ht, t])
#     timestep_range_y = np.array([t + 1, t + max_ft])

#     x = node.get(timestep_range_x, state[node.type])            # (8, 6)
#     y = node.get(timestep_range_y, pred_state[node.type])       # (8, 6)
#     first_history_index = (max_ht - node.history_points_at(t)).clip(0)

#     _, std = env.get_standardize_params(state[node.type], node.type)
#     std[0:2] = env.attention_radius[(node.type, node.type)]
#     rel_state = np.zeros_like(x[0])
#     rel_state[0:2] = np.array(x)[-1, 0:2]
#     x_st = env.standardize(x, state[node.type], node.type, mean=rel_state, std=std)
#     if list(pred_state[node.type].keys())[0] == 'position':  # If we predict position we do it relative to current pos
#         y_st = env.standardize(y, pred_state[node.type], node.type, mean=rel_state[0:2], std=std[0:2])
#     else:
#         y_st = env.standardize(y, pred_state[node.type], node.type)
    
#     x_t = torch.tensor(x, dtype=torch.float)
#     y_t = torch.tensor(y, dtype=torch.float)
#     x_st_t = torch.tensor(x_st, dtype=torch.float)
#     y_st_t = torch.tensor(y_st, dtype=torch.float)

#     # Neighbors
#     neighbors_data_st = None
#     neighbors_edge_value = None
#     if hyperparams['edge_encoding']:
#         # Scene Graph
#         scene_graph = scene.get_scene_graph(t,
#                                             env.attention_radius,
#                                             hyperparams['edge_addition_filter'],
#                                             hyperparams['edge_removal_filter']) if scene_graph is None else scene_graph
#         neighbors_data_st = dict()
#         neighbors_gt_st = dict()
#         neighbors_edge_value = dict()
#         for edge_type in edge_types:
#             neighbors_data_st[edge_type] = list()
#             neighbors_gt_st[edge_type] = list()
#             # We get all nodes which are connected to the current node for the current timestep
#             connected_nodes = scene_graph.get_neighbors(node, edge_type[1])

#             if hyperparams['dynamic_edges'] == 'yes':
#                 # We get the edge masks for the current node at the current timestep
#                 edge_masks = torch.tensor(scene_graph.get_edge_scaling(node), dtype=torch.float)
#                 neighbors_edge_value[edge_type] = edge_masks

#             for connected_node in connected_nodes:
#                 neighbor_state_np = connected_node.get(timestep_range_x,
#                                                        state[connected_node.type],
#                                                        padding=0.0)
#                 neighbor_gt_np = connected_node.get(timestep_range_y,
#                                                      pred_state[connected_node.type],
#                                                      padding=0.0)

#                 # Make State relative to node where neighbor and node have same state
#                 _, std = env.get_standardize_params(state[connected_node.type], node_type=connected_node.type)
#                 std[0:2] = env.attention_radius[edge_type]
#                 equal_dims = np.min((neighbor_state_np.shape[-1], x.shape[-1]))
#                 rel_state = np.zeros_like(neighbor_state_np)
#                 rel_state[:, ..., :equal_dims] = x[-1, ..., :equal_dims]
#                 neighbor_state_np_st = env.standardize(neighbor_state_np,
#                                                        state[connected_node.type],
#                                                        node_type=connected_node.type,
#                                                        mean=rel_state,
#                                                        std=std)
#                 _, std = env.get_standardize_params(pred_state[connected_node.type], node_type=connected_node.type)
#                 std[0:2] = env.attention_radius[edge_type]
#                 equal_dims = np.min((neighbor_gt_np.shape[-1], x.shape[-1]))
#                 rel_state = np.zeros_like(neighbor_gt_np)
#                 rel_state[:, ..., :equal_dims] = x[-1, ..., :equal_dims]
#                 neighbor_gt_np_st = env.standardize(neighbor_gt_np,
#                                                     pred_state[connected_node.type],
#                                                     node_type=connected_node.type,
#                                                     mean=rel_state)

#                 neighbor_state = torch.tensor(neighbor_state_np_st, dtype=torch.float)
#                 neighbor_gt = torch.tensor(neighbor_gt_np_st, dtype=torch.float)
#                 neighbors_data_st[edge_type].append(neighbor_state)
#                 neighbors_gt_st[edge_type].append(neighbor_gt)

#     # Robot
#     robot_traj_st_t = None
#     timestep_range_r = np.array([t, t + max_ft])
#     if hyperparams['incl_robot_node']:
#         x_node = node.get(timestep_range_r, state[node.type])
#         if scene.non_aug_scene is not None:
#             robot = scene.get_node_by_id(scene.non_aug_scene.robot.id)
#         else:
#             robot = scene.robot
#         robot_type = robot.type
#         robot_traj = robot.get(timestep_range_r, state[robot_type], padding=0.0)
#         robot_traj_st_t = get_relative_robot_traj(env, state, x_node, robot_traj, node.type, robot_type)

#     # Map
#     map_tuple = None
#     if hyperparams['use_map_encoding']:
#         if node.type in hyperparams['map_encoder']:
#             if node.non_aug_node is not None:
#                 x = node.non_aug_node.get(np.array([t]), state[node.type])
#             me_hyp = hyperparams['map_encoder'][node.type]
#             if 'heading_state_index' in me_hyp:
#                 heading_state_index = me_hyp['heading_state_index']
#                 # We have to rotate the map in the opposite direction of the agent to match them
#                 if type(heading_state_index) is list:  # infer from velocity or heading vector
#                     heading_angle = -np.arctan2(x[-1, heading_state_index[1]],
#                                                 x[-1, heading_state_index[0]]) * 180 / np.pi
#                 else:
#                     heading_angle = -x[-1, heading_state_index] * 180 / np.pi
#             else:
#                 heading_angle = None

#             scene_map = scene.map[node.type]
#             map_point = x[-1, :2]


#             patch_size = hyperparams['map_encoder'][node.type]['patch_size']
#             map_tuple = (scene_map, map_point, heading_angle, patch_size)

#     return (first_history_index, x_t, y_t, x_st_t, y_st_t, neighbors_data_st, neighbors_gt_st,
#             neighbors_edge_value, robot_traj_st_t, map_tuple, scene.dt, (scene.name, t, '/'.join([node.type.name, node.id])))
    

def _rotate_xy(vectors, rotation):
    """
    Apply a 2x2 rotation matrix to every 2-vector stored in the last axis of ``vectors``.

    :param vectors: Tensor of shape (..., 2)
    :param rotation: Tensor of shape (2, 2)
    :return: Tensor with the same shape as ``vectors``
    """
    # (2, 2) @ (..., 2, 1) broadcasts over the leading axes, so no fixed number of
    # timesteps is assumed.
    return torch.matmul(rotation, vectors.unsqueeze(-1)).squeeze(-1)


def _rotate_history(history, rotation):
    """
    Rotate a (T, D) history tensor: positions about the current position, other pairs about the origin.

    Dims 0:2 are rotated about ``history[-1, 0:2]``. The 2-vectors in dims 2:4 (the
    velocity used to pick the rotation) and 4:6 (presumably acceleration) are rotated
    about the origin when present. Any further dims are copied unchanged.

    :param history: Tensor of shape (T, D) with D >= 2
    :param rotation: Tensor of shape (2, 2)
    :return: New rotated tensor; ``history`` is not modified
    """
    rotated = history.clone()
    anchor = history[-1, 0:2]
    rotated[:, 0:2] = _rotate_xy(history[:, 0:2] - anchor, rotation) + anchor
    for lo in range(2, min(history.shape[-1], 6) - 1, 2):
        rotated[:, lo:lo + 2] = _rotate_xy(history[:, lo:lo + 2], rotation)
    return rotated


def _rotate_sample_to_heading(x_t, y_t, x_st_t, y_st_t, neighbors_gt_st):
    """
    Rotate one sample so that the node's current velocity ``x_t[-1, 2:4]`` points along +x.

    ``y_t``, ``y_st_t`` and every neighbour future in ``neighbors_gt_st`` (all edge types)
    are rotated about the origin. They must have a last dimension of 2.

    NOTE(review): neighbour histories, the robot trajectory and the map are NOT rotated.
    ``y_t`` is rotated about the origin, which is only right if it holds velocities rather
    than absolute positions. Confirm both are intended.

    :return: Rotated ``(x_t, y_t, x_st_t, y_st_t, neighbors_gt_st)``; inputs are not modified
    """
    current_vel = x_t[-1, 2:4]
    angle = -torch.arctan2(current_vel[1], current_vel[0])
    cos_a, sin_a = torch.cos(angle), torch.sin(angle)
    # [[cos, -sin], [sin, cos]]
    rotation = torch.stack((torch.stack((cos_a, -sin_a)), torch.stack((sin_a, cos_a))))

    x_t = _rotate_history(x_t, rotation)
    x_st_t = _rotate_history(x_st_t, rotation)
    y_t = _rotate_xy(y_t, rotation)
    y_st_t = _rotate_xy(y_st_t, rotation)
    if neighbors_gt_st is not None:
        neighbors_gt_st = {
            edge_type: [_rotate_xy(future, rotation) for future in futures]
            for edge_type, futures in neighbors_gt_st.items()
        }
    return x_t, y_t, x_st_t, y_st_t, neighbors_gt_st


def get_node_timestep_data(
    env,
    scene,
    t,
    node,
    state,
    pred_state,
    edge_types,
    max_ht,
    max_ft,
    hyperparams,
    scene_graph=None,
    normalize_direction=False,
):
    """
    Pre-processes the data for a single batch element: node state over time for a specific time in a specific scene
    as well as the neighbour data for it.

    :param env: Environment
    :param scene: Scene
    :param t: Timestep in scene
    :param node: Node
    :param state: Specification of the node state
    :param pred_state: Specification of the prediction state
    :param edge_types: List of all Edge Types for which neighbours are pre-processed
    :param max_ht: Maximum history timesteps
    :param max_ft: Maximum future timesteps (prediction horizon)
    :param hyperparams: Model hyperparameters
    :param scene_graph: If scene graph was already computed for this scene and time you can pass it here
    :param normalize_direction: If True, rotate the node history/future and the neighbours'
        standardized futures so that the node's current velocity points along +x
    :return: Batch Element, the tuple (first_history_index, x_t, y_t, x_st_t, y_st_t,
        neighbors_data_st, neighbors_gt_st, neighbors_edge_value, robot_traj_st_t, map_tuple,
        scene.dt, (scene.name, t, "<node type name>/<node id>")).
        - The three neighbour entries are None when hyperparams["edge_encoding"] is falsy.
        - robot_traj_st_t is None unless hyperparams["incl_robot_node"] is set.
        - map_tuple is None unless map encoding is enabled for this node type.
    """

    # Node
    timestep_range_x = np.array([t - max_ht, t])
    timestep_range_y = np.array([t + 1, t + max_ft])

    x = node.get(timestep_range_x, state[node.type])
    y = node.get(timestep_range_y, pred_state[node.type])
    first_history_index = (max_ht - node.history_points_at(t)).clip(0)

    # Standardize the history relative to the current position (first two dims only),
    # scaled by the self-edge attention radius.
    _, std = env.get_standardize_params(state[node.type], node.type)
    std[0:2] = env.attention_radius[(node.type, node.type)]
    rel_state = np.zeros_like(x[0])
    rel_state[0:2] = np.array(x)[-1, 0:2]
    x_st = env.standardize(x, state[node.type], node.type, mean=rel_state, std=std)
    if (
        list(pred_state[node.type].keys())[0] == "position"
    ):  # If we predict position we do it relative to current pos
        y_st = env.standardize(
            y, pred_state[node.type], node.type, mean=rel_state[0:2], std=std[0:2]
        )
    else:
        y_st = env.standardize(y, pred_state[node.type], node.type)

    x_t = torch.tensor(x, dtype=torch.float)
    y_t = torch.tensor(y, dtype=torch.float)
    x_st_t = torch.tensor(x_st, dtype=torch.float)
    y_st_t = torch.tensor(y_st, dtype=torch.float)

    # Neighbors
    # All three are initialised here so they are returned as None when edge encoding is off.
    neighbors_data_st = None
    neighbors_gt_st = None
    neighbors_edge_value = None
    if hyperparams["edge_encoding"]:
        # Scene Graph
        scene_graph = (
            scene.get_scene_graph(
                t,
                env.attention_radius,
                hyperparams["edge_addition_filter"],
                hyperparams["edge_removal_filter"],
            )
            if scene_graph is None
            else scene_graph
        )

        neighbors_data_st = dict()
        neighbors_gt_st = dict()
        neighbors_edge_value = dict()
        for edge_type in edge_types:
            neighbors_data_st[edge_type] = list()
            neighbors_gt_st[edge_type] = list()
            # We get all nodes which are connected to the current node for the current timestep
            connected_nodes = scene_graph.get_neighbors(node, edge_type[1])

            if hyperparams["dynamic_edges"] == "yes":
                # We get the edge masks for the current node at the current timestep
                edge_masks = torch.tensor(
                    scene_graph.get_edge_scaling(node), dtype=torch.float
                )
                neighbors_edge_value[edge_type] = edge_masks

            for connected_node in connected_nodes:
                neighbor_state_np = connected_node.get(
                    timestep_range_x, state[connected_node.type], padding=0.0
                )
                neighbor_gt_np = connected_node.get(
                    timestep_range_y, pred_state[connected_node.type], padding=0.0
                )

                # Make State relative to node where neighbor and node have same state
                _, std = env.get_standardize_params(
                    state[connected_node.type], node_type=connected_node.type
                )
                std[0:2] = env.attention_radius[edge_type]
                equal_dims = np.min((neighbor_state_np.shape[-1], x.shape[-1]))
                rel_state = np.zeros_like(neighbor_state_np)
                rel_state[:, ..., :equal_dims] = x[-1, ..., :equal_dims]
                neighbor_state_np_st = env.standardize(
                    neighbor_state_np,
                    state[connected_node.type],
                    node_type=connected_node.type,
                    mean=rel_state,
                    std=std,
                )
                # NOTE(review): the std computed just below is not passed to standardize(),
                # so the neighbour future uses standardize()'s own default scaling. Also the
                # mean subtracts the node's current *state* dims even when pred_state is not
                # position. Confirm both are intended.
                _, std = env.get_standardize_params(
                    pred_state[connected_node.type], node_type=connected_node.type
                )
                std[0:2] = env.attention_radius[edge_type]
                equal_dims = np.min((neighbor_gt_np.shape[-1], x.shape[-1]))
                rel_state = np.zeros_like(neighbor_gt_np)
                rel_state[:, ..., :equal_dims] = x[-1, ..., :equal_dims]
                neighbor_gt_np_st = env.standardize(
                    neighbor_gt_np,
                    pred_state[connected_node.type],
                    node_type=connected_node.type,
                    mean=rel_state,
                )

                neighbor_state = torch.tensor(neighbor_state_np_st, dtype=torch.float)
                neighbor_gt = torch.tensor(neighbor_gt_np_st, dtype=torch.float)
                neighbors_data_st[edge_type].append(neighbor_state)
                neighbors_gt_st[edge_type].append(neighbor_gt)

    # Robot
    robot_traj_st_t = None
    timestep_range_r = np.array([t, t + max_ft])
    if hyperparams["incl_robot_node"]:
        x_node = node.get(timestep_range_r, state[node.type])
        if scene.non_aug_scene is not None:
            robot = scene.get_node_by_id(scene.non_aug_scene.robot.id)
        else:
            robot = scene.robot
        robot_type = robot.type
        robot_traj = robot.get(timestep_range_r, state[robot_type], padding=0.0)
        robot_traj_st_t = get_relative_robot_traj(
            env, state, x_node, robot_traj, node.type, robot_type
        )

    # Map
    map_tuple = None
    if hyperparams["use_map_encoding"]:
        if node.type in hyperparams["map_encoder"]:
            if node.non_aug_node is not None:
                x = node.non_aug_node.get(np.array([t]), state[node.type])
            me_hyp = hyperparams["map_encoder"][node.type]
            if "heading_state_index" in me_hyp:
                heading_state_index = me_hyp["heading_state_index"]
                # We have to rotate the map in the opposite direction of the agent to match them
                if (
                    type(heading_state_index) is list
                ):  # infer from velocity or heading vector
                    heading_angle = (
                        -np.arctan2(
                            x[-1, heading_state_index[1]], x[-1, heading_state_index[0]]
                        )
                        * 180
                        / np.pi
                    )
                else:
                    heading_angle = -x[-1, heading_state_index] * 180 / np.pi
            else:
                heading_angle = None

            scene_map = scene.map[node.type]
            map_point = x[-1, :2]

            patch_size = hyperparams["map_encoder"][node.type]["patch_size"]
            map_tuple = (scene_map, map_point, heading_angle, patch_size)

    if normalize_direction:
        x_t, y_t, x_st_t, y_st_t, neighbors_gt_st = _rotate_sample_to_heading(
            x_t, y_t, x_st_t, y_st_t, neighbors_gt_st
        )

    return (
        first_history_index,
        x_t,
        y_t,
        x_st_t,
        y_st_t,
        neighbors_data_st,
        neighbors_gt_st,
        neighbors_edge_value,
        robot_traj_st_t,
        map_tuple,
        scene.dt,
        (scene.name, t, "/".join([node.type.name, node.id])),
    )

def get_timesteps_data(env, scene, t, node_type, state, pred_state,
                       edge_types, min_ht, max_ht, min_ft, max_ft, hyperparams):
    """
    Puts together the inputs for ALL nodes of a given type present in a scene at the given timestep(s).

    :param env: Environment
    :param scene: Scene
    :param t: Timestep(s) in scene, passed to ``scene.present_nodes`` (whose result is keyed by timestep)
    :param node_type: Node Type of nodes for which the data shall be pre-processed
    :param state: Specification of the node state
    :param pred_state: Specification of the prediction state
    :param edge_types: List of all Edge Types for which neighbors are pre-processed
    :param min_ht: Passed to ``scene.present_nodes`` as ``min_history_timesteps``
    :param max_ht: Maximum history timesteps
    :param min_ft: Currently unused; ``max_ft`` is passed as ``min_future_timesteps`` instead
    :param max_ft: Maximum future timesteps (prediction horizon)
    :param hyperparams: Model hyperparameters
    :return: None if no node qualifies, otherwise ``(collated_batch, nodes, out_timesteps)``
             where ``nodes[i]`` and ``out_timesteps[i]`` identify sample ``i`` of the batch
    """
    nodes_per_ts = scene.present_nodes(
        t,
        type=node_type,
        min_history_timesteps=min_ht,
        min_future_timesteps=max_ft,
        return_robot=not hyperparams['incl_robot_node'],
    )

    samples, nodes, out_timesteps = [], [], []
    for timestep, present_nodes in nodes_per_ts.items():
        # One scene graph per timestep, shared by every node present at it.
        graph = scene.get_scene_graph(
            timestep,
            env.attention_radius,
            hyperparams['edge_addition_filter'],
            hyperparams['edge_removal_filter'],
        )
        for node in present_nodes:
            nodes.append(node)
            out_timesteps.append(timestep)
            samples.append(
                get_node_timestep_data(
                    env, scene, timestep, node, state, pred_state,
                    edge_types, max_ht, max_ft, hyperparams,
                    scene_graph=graph,
                )
            )

    if not out_timesteps:
        return None
    return collate(samples), nodes, out_timesteps


def data_dict_to_next_step(data_dict, time_step):
    """
    Return a copy of ``data_dict`` whose observation window is advanced by ``time_step`` steps.

    How the observation window is updated:
    - In "obs" and "obs_st", rows ``time_step:`` of the input window move to the front.
    - The first ``d_g`` features of the final ``time_step`` rows are filled from the first
      ``time_step`` rows of "gt" / "gt_st" (``d_g`` is the last dimension of "gt").
    - The remaining features of those rows keep the values of the input's last rows.

    What is not shifted:
    - "gt", "gt_st" and the neighbour entries are left as they are.
    - "neighbors_st" in the result is the very same object as in the input (not copied).

    :param data_dict: dict with at least "obs", "obs_st" (3-D, (batch, n_obs, d_o)),
        "gt", "gt_st" (3-D, (batch, n_fut, d_g)) and "neighbors_st",
        presumably as produced by ``dict_collate``
    :param time_step: Number of steps to advance; 0 returns an unmodified copy
    :return: New dict (deep copy of ``data_dict`` apart from "neighbors_st")
    :raises ValueError: if ``time_step`` is negative
    """
    if time_step < 0:
        raise ValueError(f"time_step must be non-negative, got {time_step}")

    data_dict_ = deepcopy(data_dict)
    # With time_step == 0 the slices below would be [:, :-0] == [:, :0] (empty) and the
    # assignment would fail with a shape mismatch, so only shift for positive steps.
    if time_step > 0:
        obs = data_dict["obs"]
        obs_st = data_dict["obs_st"]
        gt = data_dict["gt"]
        gt_st = data_dict["gt_st"]
        _, _, d_g = gt.shape
        # NOTE(review): gt's d_g features are written into the first d_g observation
        # features; confirm both describe the same quantity (e.g. both positions).
        data_dict_["obs"][:, :-time_step] = obs[:, time_step:]
        data_dict_["obs"][:, -time_step:, :d_g] = gt[:, :time_step]
        data_dict_["obs_st"][:, :-time_step] = obs_st[:, time_step:]
        data_dict_["obs_st"][:, -time_step:, :d_g] = gt_st[:, :time_step]

    # TODO: neighbour histories are not advanced yet; the input's "neighbors_st" object is
    # passed through unchanged.
    data_dict_["neighbors_st"] = data_dict["neighbors_st"]

    return data_dict_
