import math

import torch

from experiments.models import stochastic_model

# A Gaussian with variance sigma^2 has differential entropy
# 0.5 * log(2 * pi * e * sigma^2); picking sigma^2 = 1 / (2 * pi * e) makes the
# log argument exactly 1, so the entropy is zero (in any log base).
_TWO_PI_E = 2 * math.pi * math.e

# Standard deviation sigma of the zero-entropy Gaussian noise.
zero_entropy_noise_stddev = 1 / math.sqrt(_TWO_PI_E)
# Variance sigma^2 of the zero-entropy Gaussian noise.
zero_entropy_noise_var = 1 / _TWO_PI_E


def _inject_zero_entropy_noise(output: torch.Tensor) -> torch.Tensor:
    """Return ``output`` plus i.i.d. Gaussian noise with zero differential entropy.

    Standard-normal samples shaped like ``output`` (via ``torch.randn_like``)
    are scaled by ``zero_entropy_noise_stddev`` and added element-wise. The
    addition is out-of-place, so a new tensor is returned and ``output`` is
    left unmodified.
    """
    scaled_noise = zero_entropy_noise_stddev * torch.randn_like(output)
    return output + scaled_noise


class StochasticInjectZeroEntropyNoise(stochastic_model.StochasticModel):
    """``StochasticModel`` whose stochastic step adds zero-entropy Gaussian noise.

    NOTE(review): when ``stochastic_forward_impl`` is called is decided by
    ``stochastic_model.StochasticModel`` (not visible here); presumably the
    base class's forward pass delegates to it — confirm there.
    """

    def stochastic_forward_impl(self, intermediate_xk_: torch.Tensor):
        """Return ``intermediate_xk_`` perturbed by freshly sampled zero-entropy noise."""
        perturbed = _inject_zero_entropy_noise(intermediate_xk_)
        return perturbed


class InjectZeroEntropyNoise(torch.nn.Module):
    """``nn.Module`` that adds zero-entropy Gaussian noise to its input.

    It registers no parameters. New noise is sampled on every call, and
    ``forward`` never reads ``self.training``, so the noise is also applied
    in eval mode.
    """

    def forward(self, x: torch.Tensor):
        perturbed = _inject_zero_entropy_noise(x)
        return perturbed
