from numpy.random import default_rng
from scipy.stats import norm


class Interface:
    """
    Minimal base interface: stores a graph, an oracle slot (None here) and a NumPy random generator.

    The context-manager hooks and print() do nothing in this base class, so an instance
    can be used in a `with` statement without any side effects.
    """

    def __init__(self, graph):
        self.graph = graph
        self.oracle = None
        self.seed()

    def seed(self, seed=None):
        """(Re)create self.rng from `seed`; the default None is passed straight to default_rng."""
        self.rng = default_rng(seed)

    def __enter__(self):
        """Enter a `with` block; returns self so that `with interface as x` binds the interface."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Leave a `with` block; nothing to clean up. Returns None, so exceptions propagate."""

    def print(self, _):
        """No-op output hook; the argument is ignored."""


class OracleInterface(Interface):
    """
    Oracle class implementing the five modes of irrationality in the SimTeacher algorithm. From:
        Lee, K., L. Smith, A. Dragan, and P. Abbeel. "B-Pref: Benchmarking Preference-Based Reinforcement Learning." 
        Neural Information Processing Systems (NeurIPS) (2021).

    (1) "Myopic" recency bias with discount factor gamma
    (2) Query skipping if max(ret_i, ret_j) is below d_skip
        - NOTE: This may reduce the effective feedback budget
    (3) Nonzero beta in Bradley-Terry model, or Gaussian noise with standard deviation in Thurstone's model
    (4) Random flipping of P_i with probability epsilon
    (5) Equal preference expression if abs(P_i - 0.5) is at most p_equal

    NOTE: Order of implementation here: (1),(2),(3),(4),(5)
    is different to the original paper: (2),(5),(1),(3),(4).

    Parameters are supplied as a dict P, which must contain "oracle" (a callable taking
    states, actions and next_states and returning per-step rewards) and may override any
    of the class-level defaults below.
    """

    # Defaults (never mutated; each instance gets its own merged copy in __init__)
    P = {"gamma": 1, "beta": 0, "sigma": 0, "d_skip": -float("inf"), "p_equal": 0, "epsilon": 0, "return_P_i": False}

    def __init__(self, graph, P):
        Interface.__init__(self, graph)
        self.oracle = P["oracle"]
        # Merge into a per-instance copy. Updating the class-level dict in place would leak
        # this instance's settings (and its oracle) into the defaults of every other instance.
        merged = dict(type(self).P)
        merged.update(P)
        self.P = merged
        assert not (self.P["beta"] and self.P["sigma"]), "Cannot simultaneously use both " \
                                                         "beta (Bradley-Terry) and sigma (Thurstone)"

    def __call__(self, i, j):
        """
        Query the oracle's preference between the episodes stored at graph nodes i and j.

        Returns a pair (value, info):
            - ("skip", {}) if both (myopic) returns are below d_skip;
            - otherwise info = {"confidence": max(P_i, 1 - P_i)} and value is
              P_i itself if return_P_i is set, 0.5 if abs(P_i - 0.5) <= p_equal,
              else 1.0 with probability P_i and 0.0 otherwise.
        """
        ep_i, ep_j = self.graph.nodes[i], self.graph.nodes[j]
        # (1) Myopic returns
        ret_i = self.myopic_sum(self.oracle(ep_i["states"], ep_i["actions"], ep_i["next_states"]))
        ret_j = self.myopic_sum(self.oracle(ep_j["states"], ep_j["actions"], ep_j["next_states"]))
        # (2) Query skipping
        if max(ret_i, ret_j) < self.P["d_skip"]:
            return "skip", {}
        diff = ret_i - ret_j
        # (3) Probability that i is preferred over j
        if self.P["beta"] > 0:
            # Bradley-Terry. Relies on .exp()/.item(), so presumably the oracle returns
            # torch-like tensors when beta is used -- TODO confirm against callers.
            P_i = (1. / (1. + (-diff / self.P["beta"]).exp())).item()
        elif self.P["sigma"] > 0:
            P_i = norm.cdf(diff / self.P["sigma"])  # Thurstone
        else:
            P_i = 0.5 if diff == 0 else 1. if diff > 0 else 0.  # Deterministic
        # (4) Random flipping. rng.random() lies in [0, 1), so the strict comparison flips with
        # probability exactly epsilon: never when epsilon == 0 and always when epsilon == 1.
        if self.rng.random() < self.P["epsilon"]:
            P_i = 1. - P_i
        info = {"confidence": max(P_i, 1 - P_i)}
        # (5) Equal preference, otherwise sample a hard label from P_i
        if self.P["return_P_i"]:
            return P_i, info
        if abs(P_i - 0.5) <= self.P["p_equal"]:
            return 0.5, info
        if self.rng.random() < P_i:
            return 1.0, info
        return 0.0, info

    def myopic_sum(self, rewards):
        """
        Sum rewards with recency weighting: the last reward has weight gamma**0,
        the one before it gamma**1, and so on. With gamma == 1 this is a plain sum.
        """
        gamma = self.P["gamma"]
        if gamma == 1:
            return sum(rewards)
        return sum(r * (gamma ** t) for t, r in enumerate(reversed(rewards)))
