import math
import random

from scipy.stats import truncnorm

from .generic import AbstractDelayGenerator

class Fixed(AbstractDelayGenerator):
    """Delay generator that always yields the same, preconfigured delay."""

    def __init__(self, length):
        """Remember the constant delay.

        Args:
            length: The value handed back by every call to ``generate``.
        """
        self._length = length

    def generate(self, obs):
        """Return the configured delay; ``obs`` is accepted but not used."""
        constant_delay = self._length
        return constant_delay


class TruncatedGaussian(AbstractDelayGenerator):
    """Delay generator drawing integer delays from a truncated normal.

    Samples come from a normal distribution with mean ``mu`` and standard
    deviation ``scale`` restricted to ``[minimum, maximum + 1)``; each sample
    is floored to an integer so that ``maximum`` itself remains reachable.
    """

    def __init__(self, mu, scale, minimum, maximum):
        """Build the frozen truncated-normal distribution.

        Args:
            mu: Mean of the underlying (untruncated) normal distribution.
            scale: Standard deviation of the underlying normal distribution.
            minimum: Smallest delay to generate (inclusive).
            maximum: Largest delay to generate (inclusive).
        """
        # Widen the support by one so that flooring in generate() can still
        # produce `maximum`.
        self._upper = maximum + 1
        # truncnorm expects its clip points in standard-deviation units
        # relative to `loc`.
        self._density = truncnorm((minimum - mu) / scale,
                                  (self._upper - mu) / scale,
                                  loc=mu,
                                  scale=scale)

    def generate(self, obs):
        """Draw one delay; ``obs`` is accepted but not used.

        Returns:
            int: The floored sample.
        """
        sample = self._density.rvs(1)[0]
        # Floor rather than truncate toward zero: int() would map negative
        # samples one bin too high (e.g. -2.5 -> -2) and skew the result.
        delay = math.floor(sample)
        # A draw landing exactly on the widened upper bound would otherwise
        # come out one above the requested maximum.
        if delay >= self._upper:
            delay -= 1
        return delay
    
class Uniform(AbstractDelayGenerator):
    """Delay generator picking an integer uniformly from a closed range."""

    def __init__(self, minimum, maximum):
        """Store the inclusive bounds of the range to draw from.

        Args:
            minimum: Lower bound of the generated delay (inclusive).
            maximum: Upper bound of the generated delay (inclusive).
        """
        self.minimum, self.maximum = minimum, maximum

    def generate(self, obs):
        """Return a random integer in ``[minimum, maximum]``; ``obs`` is unused."""
        return random.randint(self.minimum, self.maximum)
    
class UniformlyIncreasing(AbstractDelayGenerator):
    """Delay generator whose delay grows linearly with ``obs['time']``.

    The delay starts at ``minimum`` when the time is 0, rises linearly, and
    is capped at ``maximum`` once the time reaches ``horizon``.
    """

    def __init__(self, minimum, maximum, horizon=500):
        """Store the delay range and the length of the ramp.

        Args:
            minimum: Delay produced at time 0.
            maximum: Upper cap on the delay, reached at time ``horizon``.
            horizon: Time value at which the ramp reaches ``maximum``.
                Defaults to 500, the value that used to be hard-coded.

        Raises:
            ValueError: If ``horizon`` is not strictly positive.
        """
        if horizon <= 0:
            raise ValueError("horizon must be positive")
        self.minimum = minimum
        self.maximum = maximum
        self.horizon = horizon

    def generate(self, obs):
        """Return the delay for the time found in ``obs['time']``.

        Args:
            obs: Mapping that provides the current time under the ``'time'`` key.

        Returns:
            The interpolated delay truncated with ``int``, never above
            ``maximum``.
        """
        progress = obs['time'] / self.horizon
        span = self.maximum - self.minimum
        # Only the upper end is capped: past the horizon the delay stays at
        # `maximum`.
        return min(int(self.minimum + progress * span), self.maximum)

class UniformlyDecreasing(AbstractDelayGenerator):
    """Delay generator whose delay shrinks linearly with ``obs['time']``.

    The delay starts at ``maximum`` when the time is 0, falls linearly, and
    is held at ``minimum`` once the time reaches ``horizon``.
    """

    def __init__(self, minimum, maximum, horizon=500):
        """Store the delay range and the length of the ramp.

        Args:
            minimum: Lower floor on the delay, reached at time ``horizon``.
            maximum: Delay produced at time 0.
            horizon: Time value at which the ramp reaches ``minimum``.
                Defaults to 500, the value that used to be hard-coded.

        Raises:
            ValueError: If ``horizon`` is not strictly positive.
        """
        if horizon <= 0:
            raise ValueError("horizon must be positive")
        self.minimum = minimum
        self.maximum = maximum
        self.horizon = horizon

    def generate(self, obs):
        """Return the delay for the time found in ``obs['time']``.

        Args:
            obs: Mapping that provides the current time under the ``'time'`` key.

        Returns:
            The interpolated delay truncated with ``int``, never below
            ``minimum``.
        """
        progress = obs['time'] / self.horizon
        span = self.maximum - self.minimum
        # Only the lower end is floored: past the horizon the delay stays at
        # `minimum`.
        return max(int(self.maximum - progress * span), self.minimum)

