import functools
from typing import Protocol

import jax

import rlax
from rlax._src import distributions


class Policy(Protocol):
    """Structural interface for a policy that samples actions.

    Implementations must provide ``sample``; ``sample_and_split`` is a
    convenience built on top of it.
    """

    def sample(self, key, observations):
        """Draw a sample for ``observations`` using PRNG ``key``.

        Raises:
            NotImplementedError: always; concrete policies override this.
        """
        raise NotImplementedError()

    def sample_and_split(self, key, observations):
        """Split ``key``, sample with one half and hand back the other.

        Args:
            key: PRNG key, passed to ``jax.random.split``.
            observations: forwarded unchanged to ``self.sample``.

        Returns:
            A ``(next_key, sample)`` pair. ``next_key`` is the first key
            produced by the split. ``sample`` is the result of calling
            ``self.sample`` with the second key.
        """
        next_key, sample_key = jax.random.split(key)
        action = self.sample(sample_key, observations)
        return next_key, action


class DiscretePolicy(Policy, Protocol):
    """Protocol for policies over a discrete action set.

    Extends ``Policy`` with an ``action_distribution`` attribute and a
    ``probs`` query.
    """

    # Discrete action distribution held by concrete policies.
    action_distribution: distributions.DiscreteDistribution

    def probs(self, observations):
        """Return action probabilities for ``observations``.

        Raises:
            NotImplementedError: always; concrete policies override this.
        """
        raise NotImplementedError()


class EpsilonGreedyPolicy(DiscretePolicy):
    """Discrete policy that acts epsilon-greedily on ``preferences``.

    Subclasses supply ``preferences(observations)``. Sampling and
    probabilities are delegated to the distribution returned by
    ``rlax.epsilon_greedy(epsilon)``.
    """

    def __init__(self, epsilon):
        """Build the epsilon-greedy action distribution.

        Args:
            epsilon: forwarded to ``rlax.epsilon_greedy``. Presumably the
                exploration probability; TODO confirm against the rlax docs.
        """
        self.action_distribution = rlax.epsilon_greedy(epsilon)

    def preferences(self, observations):
        """Return per-action preferences for ``observations``.

        NOTE(review): presumably Q-values or similar scores fed to the
        epsilon-greedy distribution; confirm against subclasses.

        Raises:
            NotImplementedError: always; subclasses override this.
        """
        raise NotImplementedError()

    def sample(self, key, observations):
        """Sample an action from the epsilon-greedy distribution."""
        dist = self.action_distribution
        return dist.sample(key, self.preferences(observations))

    def probs(self, observations):
        """Return the epsilon-greedy action probabilities."""
        dist = self.action_distribution
        return dist.probs(self.preferences(observations))

    # NOTE: ``self`` is a static argument to ``jax.jit``, so it must be
    # hashable. Values read from ``self`` while tracing are captured in the
    # compiled computation, so later mutations may not be picked up.
    @functools.partial(jax.jit, static_argnums=(0,))
    def sample_and_split(self, key, observations):
        """JIT-compiled wrapper around the inherited ``sample_and_split``."""
        return super().sample_and_split(key, observations)

