import torch


class OUNoise(object):
    """Ornstein-Uhlenbeck-style exploration noise stored as a torch tensor.

    The state starts at ``mu`` and each call to :meth:`noise` moves it by a
    mean-reverting drift toward ``mu`` plus a Gaussian perturbation.

    Args:
        dim_action: Shape of the noise tensor (an int or a tuple of ints);
            passed straight to ``torch.ones``.
        mu: Value the drift term pulls the state toward, and the value the
            state is (re)initialised to.
        theta: Strength of the drift toward ``mu``.
        sigma: Multiplier of the standard-normal perturbation.
        device: Device on which the state tensor is allocated.
    """

    def __init__(self, dim_action, mu=0.0, theta=0.15, sigma=0.2, device="cpu"):
        self.dim_action = dim_action
        self.mu = mu
        self.theta = theta
        self.sigma = sigma
        self._device = device
        self.state = torch.ones(self.dim_action, device=device) * self.mu

    def noise(self, scale):
        """Advance the process by one step and return the new state.

        Args:
            scale: Multiplier applied to the whole increment (drift and
                random term together) before it is added to the state.

        Returns:
            The updated state tensor. This is the same object stored in
            ``self.state``, not a copy, so a later in-place ``reset(_id)``
            is visible through it.
        """
        x = self.state
        # Drift toward mu plus a standard-normal kick with x's shape/device.
        dx = (self.theta * (self.mu - x) + self.sigma * torch.randn_like(x)) * scale
        self.state = x + dx
        return self.state

    def reset(self, _id=None):
        """Reset (part of) the state to ``mu``.

        Args:
            _id: Optional index (anything valid for tensor indexing). When
                given, only ``self.state[_id]`` is overwritten in place;
                when ``None``, the state is replaced by a fresh tensor
                filled with ``mu``.
        """
        if _id is None:
            self.state = torch.ones(self.dim_action, device=self._device) * self.mu
        else:
            # Scalar assignment fills just the selected entries; there is no
            # need to allocate a full-size tensor only to slice it.
            self.state[_id] = self.mu


class GaussianNoise(object):
    """Memoryless Gaussian exploration noise, drawn fresh on every call.

    Args:
        dim_action: Shape of each noise sample (an int or a tuple of ints).
        mu: Mean of the distribution.
        sigma: Standard deviation of the distribution.
        device: Device on which samples are allocated.
        **kwargs: Ignored. Presumably accepted so that OUNoise-only options
            (e.g. ``theta``) can be passed interchangeably; confirm against
            callers.
    """

    def __init__(self, dim_action, mu=0.0, sigma=0.2, device="cpu", **kwargs):
        self.dim_action = dim_action
        self.mu = mu
        self.sigma = sigma
        self._device = device

    def noise(self, scale):
        """Draw one sample from N(mu, sigma**2) and multiply it by ``scale``.

        Args:
            scale: Multiplier applied to the sample.

        Returns:
            A new tensor of shape ``dim_action`` on the configured device.
        """
        # torch.normal(mean, std, size=...) requires ``size`` to be a sequence
        # and raises TypeError for a plain int. torch.randn accepts both an
        # int and a tuple, just like torch.ones does for OUNoise.
        sample = torch.randn(self.dim_action, device=self._device)
        return (self.mu + self.sigma * sample) * scale

    def reset(self, _id=None):
        """No-op: this noise keeps no state between calls."""
        pass
