from options.TerminationFunctions.TerminationFunction import TerminationFunction
from options.SharedModels.NNSharedTerminationOpenGrid import NNSharedTerminationOpenGrid

class NNTerminationFunction(TerminationFunction):
    """Termination function for a single option that delegates to a shared model.

    This class holds no parameters of its own. Every probability query is
    forwarded to ``shared_model``, tagged with this instance's ``option_idx``.
    Presumably several options share one model instance and are told apart
    by that index; confirm against the callers that construct these objects.
    """

    def __init__(self, shared_model: NNSharedTerminationOpenGrid, option_idx: int):
        # Store the shared model by reference (no copy) and remember which
        # option index this wrapper should pass along with each query.
        self.shared_model, self.option_idx = shared_model, option_idx

    def get_termination_probability(self, state: int, target_state: int) -> float:
        """Return the termination probability for ``state`` given ``target_state``.

        Args:
            state: State the query is made for.
            target_state: Target state forwarded unchanged to the shared model.

        Returns:
            Whatever ``shared_model.get_termination_prob`` returns for
            ``(state, target_state, self.option_idx)``.
        """
        model = self.shared_model
        probability = model.get_termination_prob(state, target_state, self.option_idx)
        return probability