from collections import OrderedDict
from gymnasium import spaces
import numpy as np

class KissatSpace(spaces.Dict):
    """Observation space describing a Kissat solver state as a literal/clause graph.

    The space is a ``spaces.Dict`` whose fields are stacked ``spaces.Sequence``
    spaces, i.e. variable-length 1-D arrays:

    - ``literal_candidates``: per literal, whether it can be branched on (bool).
    - ``literal_values``: per literal, its value, bounded to ``[-1, 1]`` (int8).
    - ``clause_refs``: per clause, its ref (uint32).
    - ``clause_glues``: per clause, its glue (LBD) (uint32).
    - ``literal_indices`` / ``clause_indices``: the edge indices; entry ``e`` of
      each array is an index into the literal fields and the clause fields
      respectively (uint32).

    On top of per-field membership, :meth:`contains` enforces the cross-field
    invariants:

    - both literal fields have the same length;
    - both clause fields have the same length;
    - both edge-index arrays have the same length;
    - every edge index is in range for the fields it indexes into.
    """

    def __init__(self, seed=None, **spaces_kwargs):
        """Build the per-field spaces and initialise the underlying ``spaces.Dict``.

        Args:
            seed: forwarded to ``spaces.Dict`` as its ``seed`` argument.
            **spaces_kwargs: forwarded unchanged to ``spaces.Dict``.
        """
        # element spaces
        bool_space = spaces.Box(low=0, high=1, dtype=np.bool_)
        value_space = spaces.Box(low=-1, high=1, dtype=np.int8)
        uint32_space = spaces.Box(low=0, high=2**32 - 1, dtype=np.uint32)

        # Dict fields; stack=True makes each Sequence an array rather than a tuple.
        # NOTE(review): recent gymnasium versions sort Dict keys, so this
        # insertion order may not survive — confirm if ordering matters.
        dict_spaces = OrderedDict()
        dict_spaces["literal_candidates"] = spaces.Sequence(bool_space, stack=True)
        dict_spaces["literal_values"] = spaces.Sequence(value_space, stack=True)
        dict_spaces["clause_refs"] = spaces.Sequence(uint32_space, stack=True)
        dict_spaces["clause_glues"] = spaces.Sequence(uint32_space, stack=True)
        dict_spaces["literal_indices"] = spaces.Sequence(uint32_space, stack=True)
        dict_spaces["clause_indices"] = spaces.Sequence(uint32_space, stack=True)
        super().__init__(dict_spaces, seed, **spaces_kwargs)

    def contains(self, x: dict) -> bool:
        """Return whether ``x`` is a valid, internally consistent observation.

        First, ``spaces.Dict.contains`` checks the input type, the key set, and
        the per-field element dtype and bounds. Only then are the cross-field
        invariants checked. This ordering matters: a non-dict or wrong-keyed
        ``x`` yields ``False`` rather than raising ``KeyError``/``TypeError``.

        Args:
            x: candidate observation.

        Returns:
            True if ``x`` is a member of every field space and satisfies the
            length and edge-index invariants; False otherwise.
        """
        if not super().contains(x):
            return False

        # The base check accepts array-likes such as lists. Coerce the edge
        # arrays so the vectorised ``>=`` below works on them too.
        literal_indices = np.asarray(x["literal_indices"])
        clause_indices = np.asarray(x["clause_indices"])

        # literal-indexed fields must agree on the number of literals
        n_literal = len(x["literal_candidates"])
        if len(x["literal_values"]) != n_literal:
            return False

        # clause-indexed fields must agree on the number of clauses
        n_clause = len(x["clause_refs"])
        if len(x["clause_glues"]) != n_clause:
            return False

        # both edge-index arrays must describe the same number of edges
        if len(literal_indices) != len(clause_indices):
            return False

        # every edge must point at an existing literal and clause; negative
        # values are already excluded by the uint32 element spaces
        if np.any(literal_indices >= n_literal):
            return False
        if np.any(clause_indices >= n_clause):
            return False

        return True
