import jax

class Model_Checker:

    """Common base for all model checking algorithms.

    Input attributes:
        device: JAX device on which to do most of the computation
        formula: temporal formula to check

    Subclasses are expected to override ``check``; the base
    implementation does nothing and returns ``None``.
    """

    def __init__(self, device, formula):
        # Store the constructor arguments verbatim (device first, then formula).
        self.device, self.formula = device, formula

    def check(self):
        """Placeholder for the checking algorithm; returns ``None``."""
        return None

class PCTL_Model_Checker(Model_Checker):

    """Shared parent for the PCTL model checking algorithms.

    Input attributes:
        device: JAX device on which to do most of the computation
        formula: the PCTL formula to check
        labelling_fn: vectorized labelling function used for model checking;
            it is transferred to ``device`` via ``jax.device_put``, so it must
            be something ``jax.device_put`` accepts (NOTE(review): presumably
            an array / pytree of arrays rather than a Python callable — confirm)
        atomic_predicate_map: dictionary mapping each atomic predicate to an
            index of the labelling_fn

    ``check`` is a stub here and returns ``None``; concrete algorithms are
    expected to override it.
    """

    def __init__(self, device, formula, labelling_fn, atomic_predicate_map):
        # Let the base class record device and formula first, so that
        # self.device is available for the transfer below.
        super().__init__(device, formula)
        target_device = self.device
        self.labelling_fn = jax.device_put(labelling_fn, device=target_device)
        self.atomic_predicate_map = atomic_predicate_map

    def check(self):
        """Placeholder for the PCTL checking algorithm; returns ``None``."""
        return None


    