import torch

# the_distributions = ["Cauchy", "Gaussian", "Exponential", "Uniform", "Chi", "Lognormal", "t-Distribution"]


# # def batchnoise_func(tr_noise, te_noise, ):
# #     pass

# NOISE_SIZE = 256

# class Noise_Gen(object):
#     def __init__(self, distribution):

#         self.distribution = distribution
        
#         if distribution == "Cauchy":
#             self.noise = torch.distributions.cauchy.Cauchy(torch.zeros(NOISE_SIZE), torch.ones(NOISE_SIZE))
#         elif distribution == "Exponential":
#             self.noise = torch.distributions.exponential.Exponential(torch.ones(NOISE_SIZE))
#         elif distribution == "Uniform":
#             self.noise = torch.distributions.uniform.Uniform(-2*torch.ones(NOISE_SIZE), 2*torch.ones(NOISE_SIZE))
#         elif distribution == "Chi":
#             self.noise = torch.distributions.chi2.Chi2(torch.ones(NOISE_SIZE))
#         elif distribution == "Lognormal":
#             self.noise = torch.distributions.log_normal.LogNormal(torch.zeros(NOISE_SIZE), torch.ones(NOISE_SIZE))
#         elif distribution == "t-Distribution":
#             self.noise = torch.distributions.studentT.StudentT(2*torch.ones(NOISE_SIZE))

#     def gen_noise(self, x):
        
#         if self.distribution == "Gaussian":
#             return torch.randn_like(x)
#         else:
#             return self.noise.sample()[:x.shape]


def my_noise(x, distribution):
    """Draw a noise tensor shaped like ``x`` from a named distribution.

    Only the shape, device and (floating-point) dtype of ``x`` are used; its
    values are ignored.

    Args:
        x: Reference tensor.
        distribution: Name of the noise distribution:
            - "Cauchy": Cauchy(loc=0, scale=1)
            - "Gaussian": standard normal (via ``torch.randn_like``)
            - "Exponential": Exponential(rate=1)
            - "Uniform": Uniform(low=-2, high=2)
            - "Chi": Chi-squared with 1 degree of freedom
            - "Lognormal": LogNormal(loc=0, scale=1)
            Any other value (e.g. "t-Distribution") selects Student's t with
            2 degrees of freedom.

    Returns:
        A tensor with shape ``x.shape`` on ``x.device``. Its dtype is
        ``x.dtype`` when ``x`` is floating point; otherwise it is
        ``torch.get_default_dtype()`` (the Gaussian branch defers entirely to
        ``torch.randn_like``).
    """
    if distribution == "Gaussian":
        return torch.randn_like(x)

    # Create the noise on x's device and dtype, consistent with randn_like.
    # Non-float inputs keep the previous behaviour of the default float dtype.
    dtype = x.dtype if x.is_floating_point() else torch.get_default_dtype()
    device = x.device

    def _const(value):
        # Scalar distribution parameter; sample(x.shape) broadcasts it, so no
        # full-size parameter tensors are allocated.
        return torch.tensor(value, dtype=dtype, device=device)

    dists = torch.distributions
    if distribution == "Cauchy":
        dist = dists.cauchy.Cauchy(_const(0.0), _const(1.0))
    elif distribution == "Exponential":
        dist = dists.exponential.Exponential(_const(1.0))
    elif distribution == "Uniform":
        dist = dists.uniform.Uniform(_const(-2.0), _const(2.0))
    elif distribution == "Chi":
        dist = dists.chi2.Chi2(_const(1.0))
    elif distribution == "Lognormal":
        dist = dists.log_normal.LogNormal(_const(0.0), _const(1.0))
    else:
        # Fallback kept for backward compatibility: "t-Distribution" and any
        # unrecognised name yield Student's t with df=2.
        dist = dists.studentT.StudentT(_const(2.0))

    return dist.sample(x.shape)




