from __future__ import annotations

from collections import OrderedDict, defaultdict
from dataclasses import dataclass
from itertools import product

import networkx as nx
import numpy as np

from minigrid.core.constants import COLOR_TO_IDX, IDX_TO_OBJECT, OBJECT_TO_IDX
from minigrid.minigrid_env import MiniGridEnv


@dataclass
class EdgeDescriptor:
    """Configuration of one edge layer of a dense graph.

    Attributes:
        between: Names of the node attributes whose nodes the layer connects.
        structure: How the selected nodes are connected. ``None`` links the
            nodes of the listed attributes through their Cartesian product,
            ``"grid"`` keeps only grid adjacency between the selected nodes
            (see ``GraphTransforms.get_edge_layers``).
    """

    between: tuple[str, str] | tuple[str]
    structure: str | None = None


# NOTE: this class is probably general enough to be moved into a shared utils module.
class GraphTransforms:
    """Static conversions between MiniGrid layouts, node features and graphs.

    A "dense graph" is an ``nx.Graph`` with one node per grid cell. Nodes are
    keyed by the ``(i, j)`` index of the cell in the layout and carry one
    value per requested node attribute.
    """

    # Object name -> (navigability attribute, object-type attribute) that is
    # set to 1 on the nodes of cells holding that object.
    OBJECT_TO_DENSE_GRAPH_ATTRIBUTE = {
        "empty": ("navigable", "empty"),
        "start": ("navigable", "start"),
        "agent": ("navigable", "start"),
        "goal": ("navigable", "goal"),
        "moss": ("navigable", "moss"),
        "wall": ("non_navigable", "wall"),
        "lava": ("non_navigable", "lava"),
    }

    # Node attribute -> object drawn for it when rebuilding a grid. ``None``
    # marks attributes that do not correspond to a single object.
    DENSE_GRAPH_ATTRIBUTE_TO_OBJECT = {
        "empty": "empty",
        "start": "start",
        "goal": "goal",
        "moss": "moss",
        "wall": "wall",
        "lava": "lava",
        "navigable": None,
        "non_navigable": None,
    }

    # Object name -> colour name used in the colour channel of the encoding.
    MINIGRID_COLOR_CONFIG = {
        "empty": None,
        "wall": "grey",
        "agent": "blue",
        "goal": "green",
        "lava": "red",
        "moss": "purple",
    }

    @staticmethod
    def _first_object_xy(grid_layout, object_idx, object_name, grid_id):
        """Return ``[axis-1 index, axis-0 index]`` of the first matching cell.

        Cells are scanned in C order, so the first match is the one with the
        smallest axis-0 index, ties broken by the smallest axis-1 index.

        Raises:
            ValueError: If ``grid_layout`` holds no cell equal to ``object_idx``.
        """
        rows, cols = np.nonzero(grid_layout == object_idx)
        if rows.size == 0:
            raise ValueError(f"Grid {grid_id} contains no {object_name} cell.")
        return np.array([cols[0], rows[0]])

    @staticmethod
    def _object_encoding(obj_type):
        """Return the ``[object index, colour index, state]`` list for an object.

        ``"empty"`` has no colour (0). ``"start"`` is drawn as the agent
        object using the agent colour. Every other object uses its own index
        and its colour from ``MINIGRID_COLOR_CONFIG``.
        """
        if obj_type == "empty":
            return [OBJECT_TO_IDX["empty"], 0, 0]
        # A start cell is rendered as the agent.
        drawn_as = "agent" if obj_type == "start" else obj_type
        color_str = GraphTransforms.MINIGRID_COLOR_CONFIG[drawn_as]
        return [OBJECT_TO_IDX[drawn_as], COLOR_TO_IDX[color_str], 0]

    @staticmethod
    def minigrid_to_bitmap(grids):
        """Extract wall bitmaps and start/goal positions from encoded grids.

        Args:
            grids: Array whose channel 0 (last axis) holds object indices and
                whose first axis enumerates the grids.

        Returns:
            A tuple ``(bitmap, start_pos, goal_pos)`` of lists with one entry
            per grid: the wall mask (1 on walls) with the outermost ring of
            cells removed, and the agent and goal positions as
            ``[axis-2 index, axis-1 index]`` arrays. Positions are indices
            into the full layout, i.e. they still include the border that is
            stripped from the bitmap.

        Raises:
            ValueError: If a grid has no agent or no goal cell.
        """
        # Hard-coded object indices kept from the original implementation;
        # presumably these equal OBJECT_TO_IDX["wall" / "goal" / "agent"] —
        # TODO confirm against minigrid.core.constants.
        wall_idx, goal_idx, agent_idx = 2, 8, 10

        layout = grids[..., 0]
        walls = np.zeros_like(layout)
        walls[layout == wall_idx] = 1

        bitmap = []
        start_pos = []
        goal_pos = []
        # Positions are looked up grid by grid: indexing one batch-wide
        # np.where result by grid number shifts every later grid as soon as
        # one grid has zero or several matching cells.
        for grid_id, grid_layout in enumerate(layout):
            bitmap.append(walls[grid_id][1:-1, 1:-1])
            start_pos.append(
                GraphTransforms._first_object_xy(
                    grid_layout, agent_idx, "agent", grid_id
                )
            )
            goal_pos.append(
                GraphTransforms._first_object_xy(
                    grid_layout, goal_idx, "goal", grid_id
                )
            )

        return bitmap, start_pos, goal_pos

    @staticmethod
    def minigrid_to_dense_graph(
        minigrids: np.ndarray | list[MiniGridEnv],
        node_attr=None,
        edge_config=None,
    ) -> list[nx.Graph]:
        """Convert encoded grids or environments to dense graphs.

        Args:
            minigrids: Either encoded grid arrays (object index in channel 0)
                or ``MiniGridEnv`` instances. The type of the first element
                decides how the whole collection is interpreted.
            node_attr: Node attributes to compute, forwarded to
                ``minigrid_layout_to_dense_graph``.
            edge_config: Edge layer configuration, forwarded as well.

        Returns:
            One graph per grid, built with the outer border removed.

        Raises:
            TypeError: If the first element is neither an array nor a
                ``MiniGridEnv``.
        """
        first = minigrids[0]
        if isinstance(first, np.ndarray):
            layouts = np.array(minigrids)[..., 0]
        elif isinstance(first, MiniGridEnv):
            layouts = []
            for env in minigrids:
                layout = env.grid.encode()[..., 0]
                # The agent position is written into the layout explicitly.
                layout[tuple(env.agent_pos)] = OBJECT_TO_IDX["agent"]
                layouts.append(layout)
            layouts = np.array(layouts)
        else:
            raise TypeError(
                f"minigrids must be of type List[np.ndarray], List[MiniGridEnv], "
                f"List[MultiGridEnv], not {type(minigrids[0])}"
            )
        graphs, _ = GraphTransforms.minigrid_layout_to_dense_graph(
            layouts, remove_border=True, node_attr=node_attr, edge_config=edge_config
        )
        return graphs

    @staticmethod
    def minigrid_layout_to_dense_graph(
        layouts: np.ndarray, remove_border=True, node_attr=None, edge_config=None
    ) -> tuple[list[nx.Graph], dict[str, list[nx.Graph]]]:
        """Build dense graphs from a batch of object-index layouts.

        Args:
            layouts: 3-dimensional array of object indices; the first axis
                enumerates the grids.
            remove_border: Drop the outermost ring of cells of every grid
                before building the graphs.
            node_attr: Names of the node attributes to compute. Attributes
                that no object in the batch maps to are all zeros. ``None``
                means no attributes.
            edge_config: Optional mapping of edge layer name to
                ``EdgeDescriptor``.

        Returns:
            ``(graphs, edge_graphs)`` as returned by
            ``features_to_dense_graph``.
        """
        assert (
            layouts.ndim == 3
        ), f"Wrong dimensions for minigrid layout, expected 3 dimensions, got {layouts.ndim}."

        node_attr = [] if node_attr is None else node_attr

        # Remove borders
        if remove_border:
            layouts = layouts[:, 1:-1, 1:-1]  # remove edges
        dim_grid = layouts.shape[1:]
        n_grids = layouts.shape[0]

        # Get the objects present in the layout
        objects_idx = np.unique(layouts)
        object_instances = [IDX_TO_OBJECT[obj] for obj in objects_idx]
        assert set(object_instances).issubset(
            {"empty", "wall", "start", "goal", "agent", "lava", "moss"}
        ), (
            f"Unsupported object(s) in minigrid layout. Supported objects are: "
            f"empty, wall, start, goal, agent, lava, moss. Got {object_instances}."
        )

        # Create one-hot graph feature tensors, one per requested attribute.
        # Keys are inserted in the order the attributes are first met, because
        # downstream code derives the node attribute order from this dict.
        graph_feats = {}
        object_to_attr = GraphTransforms.OBJECT_TO_DENSE_GRAPH_ATTRIBUTE
        for obj in object_instances:
            occupied = layouts == OBJECT_TO_IDX[obj]
            for attr in object_to_attr[obj]:
                # Attributes the caller did not request are skipped entirely;
                # writing to them would hit a missing key.
                if attr not in node_attr:
                    continue
                if attr not in graph_feats:
                    graph_feats[attr] = np.zeros(layouts.shape)
                graph_feats[attr][occupied] = 1
        for attr in node_attr:
            if attr not in graph_feats:
                graph_feats[attr] = np.zeros(layouts.shape)
            # Flatten every grid to one value per node.
            graph_feats[attr] = graph_feats[attr].reshape(n_grids, -1)

        graphs, edge_graphs = GraphTransforms.features_to_dense_graph(
            graph_feats, dim_grid, edge_config
        )

        return graphs, edge_graphs

    @staticmethod
    def features_to_dense_graph(
        features: dict[str, np.ndarray],
        dim_grid: tuple,
        edge_config: dict[str, EdgeDescriptor] = None,
    ) -> tuple[list[nx.Graph], dict[str, list[nx.Graph]]]:
        """Build one graph per sample from flattened node features.

        Args:
            features: Mapping of attribute name to an array whose first axis
                enumerates the samples and whose rows hold one value per grid
                node, in sorted node order.
            dim_grid: Grid dimensions passed to ``nx.grid_2d_graph``.
            edge_config: Optional mapping of edge layer name to
                ``EdgeDescriptor``. Without it the graphs have no edges.

        Returns:
            ``(graphs, edge_graphs)``: the list of graphs (edges labelled
            with their layer name) and, per layer name, the list of
            single-layer graphs.

        Raises:
            ValueError: If ``features`` is empty, since the number of samples
                cannot be determined.
        """
        if not features:
            raise ValueError(
                "At least one node feature is required to build dense graphs."
            )

        n_graphs = next(iter(features.values())).shape[0]
        node_attr = list(features)
        # The node set is identical for every sample, so build it once.
        grid_nodes = sorted(nx.grid_2d_graph(*dim_grid).nodes(data=True))

        graphs = []
        edge_graphs = defaultdict(list)
        for m in range(n_graphs):
            g = nx.Graph()
            g.add_nodes_from(grid_nodes)
            for attr, values in features.items():
                nx.set_node_attributes(g, dict(zip(g.nodes, values[m].tolist())), attr)
            if edge_config is not None:
                edge_layers = GraphTransforms.get_edge_layers(
                    g, edge_config, node_attr, dim_grid
                )
                for edge_n, edge_g in edge_layers.items():
                    g.add_edges_from(edge_g.edges(data=True), label=edge_n)
                    edge_graphs[edge_n].append(edge_g)
            graphs.append(g)

        return graphs, edge_graphs

    @staticmethod
    def graph_features_to_minigrid(
        graph_features: dict[str, np.ndarray], shape: tuple[int, int], padding=1
    ) -> np.ndarray:
        """Rebuild an encoded grid from per-node feature arrays.

        Args:
            graph_features: Mapping of attribute name to an array with one
                value per interior cell. Every name must be a key of
                ``DENSE_GRAPH_ATTRIBUTE_TO_OBJECT``.
            shape: Shape of the returned grid including the border.
            padding: Width of the wall border added around the interior.

        Returns:
            ``uint8`` array of shape ``(shape[0], shape[1], 3)`` holding
            ``[object index, colour index, state]`` per cell.

        When ``"wall"`` is not among the attributes but ``"navigable"`` is,
        cells with ``navigable == 0`` are drawn as walls first and object
        attributes are applied on top. Object attributes are applied in
        dictionary order, so later attributes overwrite earlier ones.
        """
        interior_shape = (shape[0] - 2 * padding, shape[1] - 2 * padding)
        features = {
            attr: np.asarray(values).reshape(interior_shape)
            for attr, values in graph_features.items()
        }

        # The wall encoding is needed for the border and for inferred walls
        # even when "wall" is not an input attribute.
        wall_encoding = np.array(
            GraphTransforms._object_encoding("wall"), dtype=np.uint8
        )

        # Every channel of a default cell holds the "empty" object index
        # (not [empty, 0, 0]); kept from the original implementation —
        # NOTE(review): confirm this is intended.
        grids = np.full(interior_shape + (3,), OBJECT_TO_IDX["empty"], dtype=np.uint8)

        if "wall" not in features and "navigable" in features:
            grids[features["navigable"] == 0] = wall_encoding

        for attr, values in features.items():
            obj_type = GraphTransforms.DENSE_GRAPH_ATTRIBUTE_TO_OBJECT[attr]
            if obj_type is None:
                # "navigable" / "non_navigable" have no object of their own.
                continue
            grids[values == 1] = np.array(
                GraphTransforms._object_encoding(obj_type), dtype=np.uint8
            )

        # Surround the interior with walls.
        padded_grid = np.empty((shape[0], shape[1], 3), dtype=np.uint8)
        padded_grid[...] = wall_encoding
        padded_grid[padding : shape[0] - padding, padding : shape[1] - padding] = grids
        return padded_grid

    @staticmethod
    def get_node_features(
        graph: nx.Graph, pattern_shape, node_attributes: list[str] = None, reshape=True
    ) -> tuple[np.ndarray, list[str]]:
        """Collect node attributes of a graph into a feature array.

        Args:
            graph: Graph whose nodes are index tuples into an array of shape
                ``pattern_shape``.
            pattern_shape: Shape of the per-attribute array.
            node_attributes: Attributes to read. ``None`` uses the attribute
                names of the first node of the graph.
            reshape: Flatten each per-attribute array before stacking.

        Returns:
            ``(Fx, node_attributes)`` where ``Fx`` stacks the per-attribute
            arrays along a new last axis, in the order of
            ``node_attributes``.

        Raises:
            ValueError: If ``node_attributes`` is ``None`` and the graph has
                no nodes to infer the names from.
        """
        if node_attributes is None:
            # Get node attributes from some node
            first_node = next(iter(graph.nodes.data()), None)
            if first_node is None:
                raise ValueError(
                    "Cannot infer node attributes from a graph without nodes."
                )
            node_attributes = list(first_node[1].keys())

        # Get node features
        Fx = []
        for attr in node_attributes:
            # The graph we are getting is only the navigable nodes so those that
            # are not present should be assumed to be walls and non-navigable
            absent_value = 1.0 if attr in ("non_navigable", "wall") else 0.0
            f = np.full(pattern_shape, absent_value)
            for node, data in graph.nodes.data(attr):
                f[node] = data
            if reshape:
                f = f.ravel()
            Fx.append(f)
        Fx = np.stack(Fx, axis=-1)

        return Fx, node_attributes

    @staticmethod
    def dense_graph_to_minigrid(
        graph: nx.Graph, shape: tuple[int, int], padding=1
    ) -> np.ndarray:
        """Convert a dense graph back to an encoded grid.

        Args:
            graph: Graph with binary node attributes.
            shape: Shape of the returned grid including the border.
            padding: Width of the wall border around the interior.

        Returns:
            The grid produced by ``graph_features_to_minigrid``.
        """
        pattern_shape = (shape[0] - 2 * padding, shape[1] - 2 * padding)
        features, node_attributes = GraphTransforms.get_node_features(
            graph, pattern_shape, node_attributes=None
        )

        is_binary = (features == 0.0) | (features == 1.0)
        assert is_binary.all(), "Graph features should be binary"

        features_dict = {
            key: features[..., i] for i, key in enumerate(node_attributes)
        }
        return GraphTransforms.graph_features_to_minigrid(
            features_dict, shape=shape, padding=padding
        )

    @staticmethod
    def get_edge_layers(
        graph: nx.Graph,
        edge_config: dict[str, EdgeDescriptor],
        node_attr: list[str],
        dim_grid: tuple[int, int],
    ) -> dict[str, nx.Graph]:
        """Build one edge-layer graph per compatible entry of ``edge_config``.

        Args:
            graph: Graph whose nodes are 2-tuples carrying the node
                attributes.
            edge_config: Mapping of layer name to ``EdgeDescriptor``. It is
                not modified.
            node_attr: Names of the attributes present on the nodes.
            dim_grid: Grid dimensions passed to ``nx.grid_2d_graph``.

        Returns:
            Mapping of layer name to a graph holding all nodes of ``graph``
            (with data) and the edges of that layer. Layers whose
            ``between`` attributes are not all in ``node_attr`` are skipped,
            except ``"navigable"`` / ``"non_navigable"``, which fall back to
            the object-type attributes when the aggregate attribute is
            missing.

        Raises:
            NotImplementedError: If a descriptor's ``structure`` is neither
                ``None`` nor ``"grid"``.
        """
        navigable_nodes = ["empty", "start", "goal", "moss"]
        non_navigable_nodes = ["wall", "lava"]
        assert all(isinstance(n, tuple) for n in graph.nodes)
        assert all(len(n) == 2 for n in graph.nodes)

        def has_type(attrs, n_type):
            # Fallback node types may be absent from the node data; treat a
            # missing attribute as "not of this type".
            return attrs.get(n_type, 0) >= 1.0

        def partial_grid(selected):
            # Grid adjacency restricted to the selected nodes.
            excluded = [n for n in graph.nodes if n not in selected]
            g_temp = nx.grid_2d_graph(*dim_grid)
            g_temp.remove_nodes_from(excluded)
            g = nx.Graph()
            g.add_nodes_from(graph.nodes(data=True))
            g.add_edges_from(g_temp.edges)
            return g

        def pair_edges(node_types):
            # Connect every combination of one node per listed type.
            # NOTE(review): add_edges_from presumably needs exactly two node
            # types here to receive 2-tuples — confirm for other lengths.
            nodes_per_type = [
                [n for n, a in graph.nodes.items() if has_type(a, n_type)]
                for n_type in node_types
            ]
            edged_graph = nx.create_empty_copy(graph, with_data=True)
            edged_graph.add_edges_from(product(*nodes_per_type))
            return edged_graph

        edge_graphs = {}
        for edge_name, descriptor in edge_config.items():
            # Resolve the node types locally instead of overwriting
            # ``descriptor.between`` on the caller's configuration.
            if edge_name == "navigable" and "navigable" not in node_attr:
                between = navigable_nodes
            elif edge_name == "non_navigable" and "non_navigable" not in node_attr:
                between = non_navigable_nodes
            elif set(descriptor.between).issubset(node_attr):
                between = descriptor.between
            else:
                # Layer not compatible with the available node attributes.
                continue

            if descriptor.structure is None:
                edge_graphs[edge_name] = pair_edges(between)
            elif descriptor.structure == "grid":
                selected = {
                    n
                    for n, a in graph.nodes.items()
                    if any(has_type(a, n_type) for n_type in between)
                }
                edge_graphs[edge_name] = partial_grid(selected)
            else:
                raise NotImplementedError(
                    f"Edge structure {descriptor.structure} not supported."
                )

        return edge_graphs
