import torch
from ..configs.exponential_scale_scheduler import ExponentialDecayScale


class ExponentialDecayScaleGenerator:
    """Precompute an exponentially decaying schedule and serve it step by step.

    The schedule is built once at construction time from the config's
    ``start_value``, ``end_value``, ``damping_factor`` and ``n_samples``
    attributes. It holds ``n_samples`` curve values followed by one extra
    entry that is exactly ``end_value``.
    """

    def __init__(self, config: ExponentialDecayScale):
        """Build the sample list and reset the read position.

        Args:
            config: Object providing ``start_value``, ``end_value``,
                ``damping_factor`` and ``n_samples``.
        """
        self.config = config
        self.samples = self.generate().tolist()
        # The generated curve only approaches end_value (see generate()), so
        # the exact target is appended as the final, sticky sample.
        self.samples.append(config.end_value)
        self.step = 0

    def generate(self) -> torch.Tensor:
        """Return the decay curve as a 1-D tensor of ``n_samples`` values.

        With ``decay(t) = exp(-damping_factor * t)`` evaluated on ``t`` evenly
        spaced over ``[0, 1]``, the curve is::

            start_value * decay(t)
                + (end_value - start_value * decay(1)) * (1 - decay(t))

        It starts at ``start_value``. The blend weight ``1 - decay(t)`` only
        reaches ``1 - exp(-damping_factor)`` at ``t = 1``, so the last value is
        close to, but in general not exactly, ``end_value``.

        Returns:
            The curve values; an empty tensor when ``n_samples`` is 0.
        """
        cfg = self.config
        t = torch.linspace(0, 1, cfg.n_samples)
        if t.numel() == 0:
            # Nothing to sample; indexing the last element below would raise.
            return t

        # Evaluate the exponential once and reuse it for both terms.
        decay = torch.exp(-cfg.damping_factor * t)
        decayed_values = cfg.start_value * decay
        end_offset = cfg.end_value - decayed_values[-1]
        return decayed_values + end_offset * (1 - decay)

    def get_value(self, step: bool = True) -> float:
        """Return the current sample and optionally advance the schedule.

        Once the schedule is exhausted, the last sample (``end_value``) is
        returned on every subsequent call.

        Args:
            step: When truthy, move on to the next sample after reading.
                When falsy, peek without advancing.

        Returns:
            The sample at the current position.
        """
        index = min(self.step, len(self.samples) - 1)
        value = self.samples[index]

        # Plain truthiness rather than ``is True`` so that truthy non-bool
        # flags (e.g. 1 or numpy.bool_) also advance the schedule.
        if step:
            self.step += 1

        return value
