import numpy as np
from model.generative_model import Generative_Model

class Hard_MDP_Unichain(Generative_Model):
    """Two-state, two-action MDP parameterised by a single probability ``p``.

    States are ``0`` and ``1``; every state offers actions ``0`` and ``1``
    (see ``sa_pairs``). The transition and reward tables built in
    ``__init__`` are identical for both actions of a given state.

    Args:
        p: Probability used to build the transition rows. Both pairs of
            state 0 get ``[1 - p, p]``; both pairs of state 1 get
            ``[p, 1 - p]``. Must lie in ``[0, 1]`` so that every row is a
            valid probability vector.

    Raises:
        ValueError: If ``p`` is outside ``[0, 1]`` or is NaN.
    """

    def __init__(self, p: float):
        # Outside [0, 1] either ``p`` or ``1 - p`` is negative, so the rows
        # below would not be probability distributions. The ``not (...)``
        # form also rejects NaN, for which every comparison is False.
        if not 0.0 <= p <= 1.0:
            raise ValueError(f"p must be in [0, 1], got {p!r}")

        # ``sa_pairs`` is a property that builds a new list on every access;
        # read it once instead of eight times.
        pairs = self.sa_pairs

        # Each (state, action) pair gets its own array object (no sharing),
        # so mutating one entry can never silently change another.
        # NOTE(review): presumably rows are indexed by next state (aligned
        # with ``states``) — confirm against Generative_Model.
        self.transition_map = {pairs[0]: np.array([1 - p, p]),
                               pairs[1]: np.array([1 - p, p]),
                               pairs[2]: np.array([p, 1 - p]),
                               pairs[3]: np.array([p, 1 - p])}
        # NOTE(review): unclear from this file whether these length-2 arrays
        # are indexed by next state or are distributions over ``rewards`` —
        # confirm against Generative_Model before relying on either reading.
        self.reward_map = {pairs[0]: np.array([0, 1]),
                           pairs[1]: np.array([0, 1]),
                           pairs[2]: np.array([1, 0]),
                           pairs[3]: np.array([1, 0])}

    @property
    def is_mdp(self):
        """Always ``True``: this model is an MDP."""
        return True

    @property
    def states(self):
        """Array of state labels, ``[0, 1]``."""
        return np.array([0, 1])

    @property
    def rewards(self):
        """Array of reward values, ``[0, 1]``."""
        return np.array([0, 1])

    @property
    def sa_pairs(self):
        """List of ``(state, action)`` tuples, built fresh on each access."""
        return [(0, 0), (0, 1), (1, 0), (1, 1)]

    @property
    def action_at_state(self):
        """Result of ``self.generate_action_at_state()``.

        ``generate_action_at_state`` is not defined in this class; presumably
        it is inherited from ``Generative_Model`` — verify there.
        """
        return self.generate_action_at_state()