import math

import numpy as np
import torch

from lambda_ac.rl_types import ExplorationScheduler


class LinearExplorationSchedule(ExplorationScheduler):
    """Gaussian exploration noise whose scale is annealed linearly.

    The noise standard deviation starts at ``max`` and moves linearly to
    ``min`` as the internal step counter goes from 0 to ``steps``; it stays
    at ``min`` afterwards. The counter only changes through :meth:`step` and
    :meth:`set`.

    Args:
        min: Standard deviation once the schedule has finished.
        max: Standard deviation at step 0.
        steps: Number of steps over which the standard deviation is annealed.
            A non-positive value means "already finished" (always ``min``).
    """

    # NOTE: the parameter names `min`/`max` shadow the builtins inside
    # __init__ only; they are kept because callers may pass them by keyword.
    def __init__(self, min, max, steps):
        self.min = min
        self.max = max
        self.steps = steps

        self.current_steps = 0

    def __call__(self, action):
        """Return ``action`` plus scaled Gaussian noise, clamped to [-1, 1]."""
        noise = torch.randn_like(action) * self._std()
        return torch.clamp(action + noise, -1, 1)

    def step(self):
        """Advance the schedule by one step."""
        self.current_steps += 1

    def set(self, value):
        """Overwrite the internal step counter with ``value``."""
        self.current_steps = value

    def _std(self):
        """Return the noise standard deviation for the current step count."""
        # A zero-length (or negative) horizon would divide by zero or
        # extrapolate; treat it as a schedule that has already completed.
        if self.steps <= 0:
            return self.min
        # Clamp progress to [0, 1] on both sides so a negative counter passed
        # to `set` cannot push the std above `max`.
        progress = min(1, max(0, self.current_steps / self.steps))
        return self.max - (self.max - self.min) * progress


class FixedExplorationSchedule(ExplorationScheduler):
    """Exploration noise with a constant scale.

    Adds Gaussian noise multiplied by ``value`` to every action. The result
    is not clamped, and ``step``/``set`` have no effect on the scale.
    """

    def __init__(self, value):
        self.value = value

    def __call__(self, action):
        """Return ``action`` perturbed by Gaussian noise scaled by ``value``."""
        noise = torch.randn_like(action) * self.value
        return action + noise

    def step(self):
        """Do nothing; the noise scale never changes."""
        return None

    def set(self, value):
        """Do nothing; the argument is ignored."""
        return None


class NoExplorationSchedule(ExplorationScheduler):
    """Scheduler that adds no exploration noise at all."""

    def __call__(self, action):
        """Return ``action`` unchanged (the same object, not a copy)."""
        return action

    def step(self):
        """Do nothing; there is no state to advance."""
        return None

    def set(self, value):
        """Do nothing; the argument is ignored."""
        return None


class LinearSchedule:
    """Stateless linear interpolation from ``start`` to ``final``.

    Calling the schedule with a step index returns ``start`` for
    ``step <= 0``, ``final`` for ``step >= duration``, and the linear blend
    of the two in between.

    Args:
        start: Value returned at step 0 and before.
        final: Value returned at ``duration`` and after.
        duration: Number of steps the interpolation spans. A non-positive
            duration is treated as an already finished schedule.
    """

    def __init__(self, start, final, duration):
        self.start = start
        self.final = final
        self.duration = duration

    def __call__(self, step):
        """Return the scheduled value at ``step``."""
        if self.duration > 0:
            mix = np.clip(step / self.duration, 0.0, 1.0)
        else:
            # Dividing by a zero duration would raise (or yield nan/inf for
            # numpy inputs); a zero-length schedule is simply complete.
            # ones_like keeps the result shaped like `step`.
            mix = np.ones_like(step, dtype=float)
        return (1.0 - mix) * self.start + mix * self.final


class LinearIncrease:
    """Stateful value that ramps linearly from ``min`` up to ``max``.

    The value stays at ``min`` until the internal step counter reaches
    ``start``, then rises linearly and reaches ``max`` when the counter
    equals ``steps``; it stays at ``max`` afterwards. The counter only
    changes through :meth:`step` and :meth:`set`.

    Args:
        start: Counter value at which the ramp begins.
        min: Value before and at the beginning of the ramp.
        max: Value at the end of the ramp and beyond.
        steps: Counter value at which the ramp ends. The ramp length stored
            in ``self.steps`` is ``steps - start``.
    """

    # NOTE: the parameter names `min`/`max` shadow the builtins inside
    # __init__ only; they are kept because callers may pass them by keyword.
    def __init__(self, start, min, max, steps):
        self.start = start
        self.min = min
        self.max = max
        self.steps = steps - start

        self.current_steps = 0

    def __call__(self):
        """Return the value for the current step count."""
        return self._value()

    def step(self):
        """Advance the internal counter by one step."""
        self.current_steps += 1

    def set(self, value):
        """Overwrite the internal step counter with ``value``."""
        self.current_steps = value

    def _value(self):
        """Compute the ramped value for the current step count."""
        if self.steps <= 0:
            # A zero-length ramp would divide by zero, and a negative one
            # would push the value below `min`; treat both as a step change
            # that happens at `start`.
            return self.max if self.current_steps >= self.start else self.min
        elapsed = max(0, self.current_steps - self.start)
        fraction = min(1, elapsed / self.steps)
        return self.min + (self.max - self.min) * fraction
