"""Custom Gymnasium space for heterogeneous graph edges: the source and target nodes may come from different spaces (optional) and each edge may carry a feature (optional).
author: Anonymous
"""
from __future__ import annotations

from typing import Any, NamedTuple, Sequence

import numpy as np
from numpy.typing import NDArray

import gymnasium as gym
from gymnasium.spaces.box import Box
from gymnasium.spaces.discrete import Discrete
from gymnasium.spaces.multi_discrete import MultiDiscrete
from gymnasium.spaces.space import Space

class EdgeInstance(NamedTuple):
    """A sample of the :class:`Edge` space describing ``m`` edges.

    * source_nodes (np.ndarray): one source node per edge; the first axis has length m.
    * target_nodes (np.ndarray): one target node per edge; the first axis has length m.
    * edge_features (Optional[np.ndarray]): an (m x ...) sized array holding one feature per
      edge, where (...) matches the shape of the edge space, or ``None`` when the space has
      no edge features.
    """

    source_nodes: NDArray[Any]
    target_nodes: NDArray[Any]
    edge_features: NDArray[Any] | None


class Edge(Space[EdgeInstance]):
    r"""A space of edges whose endpoints (and optional features) come from separate spaces.

    A sample is an :class:`EdgeInstance` describing ``m`` edges: ``source_nodes[i]`` is
    drawn from ``source_space``, ``target_nodes[i]`` from ``target_space`` and, when an
    ``edge_space`` is configured, ``edge_features[i]`` from ``edge_space``.

    Example:
        >>> from gymnasium.spaces import Box, Discrete
        >>> observation_space = Edge(
        ...     source_space=Discrete(3),
        ...     target_space=Discrete(4),
        ...     edge_space=Box(0, 1, shape=(5,)),
        ...     seed=42,
        ... )
        >>> sample = observation_space.sample(num_edges=10)
        >>> sample.source_nodes.shape, sample.target_nodes.shape, sample.edge_features.shape
        ((10,), (10,), (10, 5))
        >>> observation_space.contains(sample)
        True
    """

    def __init__(
        self,
        source_space: Discrete,
        target_space: None | Discrete,
        edge_space: None | Box | Discrete = None,
        seed: int | np.random.Generator | None = None,
    ):
        r"""Constructor of :class:`Edge`.

        Args:
            source_space (Discrete): space of the source nodes.
            target_space (Optional[Discrete]): space of the target nodes. If None, the
                source_space is reused for the target nodes.
            edge_space (Optional[Box, Discrete]): space of a single edge feature, or None
                for edges without features. Sampling supports Box and Discrete spaces.
            seed: optional seed or generator for this space's RNG. The per-sample batch
                spaces built in :meth:`sample` share this RNG.
        """
        self.source_space = source_space
        # Without an explicit target space, both ends of every edge share one node space.
        self.target_space = source_space if target_space is None else target_space
        self.edge_space = edge_space

        super().__init__(None, None, seed)

    @property
    def is_np_flattenable(self):
        """Checks whether this space can be flattened to a :class:`spaces.Box`."""
        return False

    def _generate_sample_space(
        self, base_space: None | Box | Discrete, num: int
    ) -> Box | MultiDiscrete | None:
        """Build a space whose samples are ``num`` stacked samples of ``base_space``.

        Returns ``None`` when there is nothing to sample (``num == 0`` or no base space).
        The returned space shares this space's RNG. For a :class:`Discrete` base, the
        MultiDiscrete batch ignores ``base_space.start``; :meth:`_sample_batch` applies it.

        Raises:
            TypeError: if ``base_space`` is neither a Box nor a Discrete space.
        """
        if num == 0 or base_space is None:
            return None

        if isinstance(base_space, Box):
            return Box(
                low=np.array(num * [base_space.low]),
                high=np.array(num * [base_space.high]),
                shape=(num,) + base_space.shape,
                dtype=base_space.dtype,
                seed=self.np_random,
            )
        if isinstance(base_space, Discrete):
            return MultiDiscrete(nvec=[base_space.n] * num, seed=self.np_random)
        raise TypeError(
            f"Expects base space to be Box or Discrete, actual space: {type(base_space)}."
        )

    def _sample_batch(
        self,
        base_space: Box | Discrete,
        num: int,
        mask: NDArray[Any] | tuple[Any, ...] | None = None,
    ) -> NDArray[Any]:
        """Draw ``num`` samples of ``base_space`` stacked along a new leading axis.

        Args:
            base_space: the space of a single item (node or edge feature).
            num: number of items to draw; ``0`` yields an empty array.
            mask: Discrete only. Either one mask array that is applied to every item, or a
                tuple with one mask per item, forwarded to ``MultiDiscrete.sample``.

        Raises:
            TypeError: if ``base_space`` is neither a Box nor a Discrete space.
        """
        if not isinstance(base_space, (Box, Discrete)):
            raise TypeError(
                f"Expects base space to be Box or Discrete, actual space: {type(base_space)}."
            )
        if num == 0:
            # _generate_sample_space has no zero-length space; build the empty batch directly.
            return np.zeros((0,) + tuple(base_space.shape), dtype=base_space.dtype)

        batch_space = self._generate_sample_space(base_space, num)
        if isinstance(mask, np.ndarray):
            # A single node mask applies to every edge in the batch.
            mask = (mask,) * num
        batch = batch_space.sample(mask=mask)
        if isinstance(base_space, Discrete):
            # MultiDiscrete draws values in [0, n); shift them into [start, start + n).
            batch = batch + base_space.start
        return batch

    def sample(
        self,
        mask: None
        | (
            tuple[
                NDArray[Any] | tuple[Any, ...] | None,  # source mask
                NDArray[Any] | tuple[Any, ...] | None,  # target mask
            ]
        ) = None,
        num_edges: int | None = None,
    ) -> EdgeInstance:
        """Generate a random :class:`EdgeInstance` from this space.

        Args:
            mask: An optional ``(source_mask, target_mask)`` pair; either entry may be None.
                Masks are only possible with Discrete node spaces (Box spaces don't support
                sample masks). An array mask is applied to every edge; a tuple must hold one
                mask per edge.
            num_edges: The number of edges to sample. If None, a random count between 1 and
                10 (inclusive) is drawn from this space's RNG.

        Returns:
            An :class:`EdgeInstance` with ``num_edges`` source nodes and target nodes, plus
            edge features when an ``edge_space`` is set (otherwise ``edge_features`` is None).

        Raises:
            ValueError: if ``num_edges`` is negative.
        """
        if num_edges is None:
            num_edges = int(self.np_random.integers(1, 11))
        elif num_edges < 0:
            raise ValueError(
                f"Expects a non-negative number of edges, actual: {num_edges}."
            )

        source_mask, target_mask = (None, None) if mask is None else mask

        # All draws share self.np_random, so this order fixes the sequence for a given seed.
        source_nodes = self._sample_batch(self.source_space, num_edges, source_mask)
        target_nodes = self._sample_batch(self.target_space, num_edges, target_mask)
        edge_features = (
            None
            if self.edge_space is None
            else self._sample_batch(self.edge_space, num_edges)
        )
        return EdgeInstance(
            source_nodes=source_nodes,
            target_nodes=target_nodes,
            edge_features=edge_features,
        )

    def contains(self, x: EdgeInstance) -> bool:
        """Return boolean specifying if x is a valid member of this space.

        A member is an :class:`EdgeInstance` whose source and target arrays are 1-D arrays
        of equal length with every entry in the respective node space. If an
        ``edge_space`` is set, ``edge_features`` must be an array with one entry per edge,
        each inside ``edge_space``. Otherwise ``edge_features`` must be None.
        """
        if not isinstance(x, EdgeInstance):
            return False
        if not (
            isinstance(x.source_nodes, np.ndarray)
            and isinstance(x.target_nodes, np.ndarray)
        ):
            return False
        # Node arrays must be 1-D and aligned: entry i of each describes edge i.
        if x.source_nodes.ndim != 1 or x.target_nodes.shape != x.source_nodes.shape:
            return False
        num_edges = x.source_nodes.shape[0]
        if not all(node in self.source_space for node in x.source_nodes):
            return False
        if not all(node in self.target_space for node in x.target_nodes):
            return False

        if self.edge_space is None:
            return x.edge_features is None
        return (
            isinstance(x.edge_features, np.ndarray)
            and x.edge_features.shape[:1] == (num_edges,)
            and all(edge in self.edge_space for edge in x.edge_features)
        )

    def __repr__(self) -> str:
        """A string representation of this space.

        The representation includes source_space, target_space and edge_space.

        Returns:
            A representation of the space
        """
        return f"Edge({self.source_space}, {self.target_space}, {self.edge_space})"

    def __eq__(self, other: Any) -> bool:
        """Check whether `other` is equivalent to this instance."""
        return (
            isinstance(other, Edge)
            and self.source_space == other.source_space
            and self.target_space == other.target_space
            and self.edge_space == other.edge_space
        )

    def to_jsonable(
        self, sample_n: Sequence[EdgeInstance]
    ) -> list[dict[str, list[int | float]]]:
        """Convert a batch of samples from this space to a JSONable data type.

        Each sample becomes a dict with ``"source"`` and ``"target"`` lists, plus an
        ``"edge_features"`` list when the sample carries features.
        """
        ret_n = []
        for sample in sample_n:
            ret = {
                "source": sample.source_nodes.tolist(),
                "target": sample.target_nodes.tolist(),
            }
            if sample.edge_features is not None:
                ret["edge_features"] = sample.edge_features.tolist()
            ret_n.append(ret)
        return ret_n

    def from_jsonable(
        self, sample_n: Sequence[dict[str, list[list[int] | list[float]]]]
    ) -> list[EdgeInstance]:
        """Convert a JSONable data type to a batch of samples from this space.

        Arrays are rebuilt with the dtype of the corresponding sub-space.

        Raises:
            ValueError: if a sample has ``"edge_features"`` but this space has no edge_space.
        """
        ret: list[EdgeInstance] = []
        for sample in sample_n:
            if "edge_features" in sample:
                # Explicit check instead of `assert`, which is stripped under `python -O`.
                if self.edge_space is None:
                    raise ValueError(
                        "Sample contains edge_features but this Edge space has no edge_space."
                    )
                edge_features = np.asarray(
                    sample["edge_features"], dtype=self.edge_space.dtype
                )
            else:
                edge_features = None
            ret.append(
                EdgeInstance(
                    np.asarray(sample["source"], dtype=self.source_space.dtype),
                    np.asarray(sample["target"], dtype=self.target_space.dtype),
                    edge_features,
                )
            )
        return ret
    
import unittest
import numpy as np

class TestEdge(unittest.TestCase):
    """Unit tests for the :class:`Edge` space (Discrete nodes, Box edge features)."""

    def setUp(self):
        self.source_space = Discrete(3)
        self.target_space = Discrete(4)
        self.edge_space = Box(0, 1, shape=(5,))
        self.edge = Edge(self.source_space, self.target_space, self.edge_space, seed=42)

    def test_sample(self):
        sample = self.edge.sample(num_edges=2)
        self.assertIsInstance(sample, EdgeInstance)
        self.assertEqual(sample.source_nodes.shape, (2,))
        self.assertEqual(sample.target_nodes.shape, (2,))
        self.assertEqual(sample.edge_features.shape, (2, 5))
        # Every component must come from its own sub-space.
        self.assertTrue(all(node in self.source_space for node in sample.source_nodes))
        self.assertTrue(all(node in self.target_space for node in sample.target_nodes))
        self.assertTrue(all(feature in self.edge_space for feature in sample.edge_features))

    def test_contains(self):
        sample = self.edge.sample(num_edges=2)
        self.assertTrue(self.edge.contains(sample))

    def test_contains_rejects_invalid_source(self):
        invalid_sample = EdgeInstance(
            source_nodes=np.array([3, 4]),  # out of range for Discrete(3)
            target_nodes=np.array([0, 1]),
            edge_features=np.full((2, 5), 0.5, dtype=np.float32),  # valid features
        )
        self.assertFalse(self.edge.contains(invalid_sample))

    def test_contains_rejects_invalid_edge_features(self):
        invalid_sample = EdgeInstance(
            source_nodes=np.array([0, 1]),
            target_nodes=np.array([0, 1]),
            edge_features=np.full((2, 5), 1.5, dtype=np.float32),  # above Box high of 1
        )
        self.assertFalse(self.edge.contains(invalid_sample))

    def test_contains_rejects_missing_edge_features(self):
        invalid_sample = EdgeInstance(np.array([0, 1]), np.array([0, 1]), None)
        self.assertFalse(self.edge.contains(invalid_sample))

    def test_jsonable_round_trip(self):
        samples = [self.edge.sample(num_edges=3) for _ in range(2)]
        restored = self.edge.from_jsonable(self.edge.to_jsonable(samples))
        self.assertEqual(len(restored), len(samples))
        for original, restored_sample in zip(samples, restored):
            np.testing.assert_array_equal(original.source_nodes, restored_sample.source_nodes)
            np.testing.assert_array_equal(original.target_nodes, restored_sample.target_nodes)
            np.testing.assert_array_equal(original.edge_features, restored_sample.edge_features)

    def test_equality(self):
        self.assertEqual(self.edge, Edge(Discrete(3), Discrete(4), Box(0, 1, shape=(5,))))
        self.assertNotEqual(self.edge, Edge(Discrete(3), Discrete(5), Box(0, 1, shape=(5,))))

    def test_print(self):
        sample = self.edge.sample(num_edges=10)
        self.assertTrue(repr(sample).startswith("EdgeInstance("))

if __name__ == "__main__":
    # Executed as a script: collect and run the TestCase classes defined in this module.
    unittest.main(module=__name__)