import numpy as np
import torch
from torch.distributions.gamma import Gamma
from .abstract import Distribution


class Gaussian_torch:
    """Normal-distribution sampler backed by torch.

    Samples are computed as ``scale * z + loc`` where ``z`` is drawn with
    ``torch.randn``.

    Args:
        loc: Value added to every draw (the mean).
        scale: Multiplier applied to the standard-normal draw.
    """

    def __init__(self, loc=0, scale=1.0):
        self.loc = loc
        self.scale = scale

    def __call__(self, rng, shape=None):
        """Draw normal samples.

        Args:
            rng: ``torch.Generator`` passed to ``torch.randn`` (``None`` uses
                torch's default generator).
            shape: Output shape. ``None`` (the default) produces a 0-d tensor,
                the same as the ``torch.Size()`` default used by the other
                torch samplers.

        Returns:
            Tensor of samples with the requested shape.
        """
        # torch.randn does not accept None as a size; treat it as a scalar.
        if shape is None:
            shape = torch.Size()
        return self.scale * torch.randn(shape, generator=rng) + self.loc
      

class Gaussian(Distribution):
    """Normal-distribution sampler using a NumPy ``Generator``.

    Args:
        loc: Value added to every draw (the mean).
        scale: Multiplier applied to standard-normal draws.
    """

    def __init__(self, loc=0, scale=1.0):
        self.loc, self.scale = loc, scale

    def __call__(self, rng, shape=None):
        """Return ``scale * N(0, 1) + loc`` samples with the given ``shape``."""
        standard = rng.normal(size=shape)
        return self.scale * standard + self.loc


class InverseGamma_torch():
    """Torch sampler returning the square root of inverse-gamma draws.

    Mirrors the NumPy ``InverseGamma`` sampler. It draws
    ``g ~ Gamma(alpha, rate=beta)`` and returns ``sqrt(1 / g)``.

    Args:
        shape: Shape parameter ``alpha``.
        scale: Scale parameter ``beta`` of the inverse-gamma distribution.
            This is the *rate* of the underlying gamma distribution.
    """

    def __init__(self, shape, scale=1.0):
        self.shape = shape  # shape parameter (alpha)
        self.scale = scale  # scale parameter (beta)
        # 1/g ~ InvGamma(alpha, beta) requires g ~ Gamma(alpha, rate=beta).
        # This matches numpy's rng.gamma(alpha, scale=1/beta) in InverseGamma.
        self.gamma_dist = Gamma(concentration=shape, rate=scale)

    def __call__(self, seed, shape=torch.Size([1])):
        """Draw samples.

        Args:
            seed: Value passed to ``torch.manual_seed``. NOTE: this reseeds
                torch's global RNG, since ``Gamma.sample`` takes no
                ``torch.Generator``.
            shape: Sample shape forwarded to ``Gamma.sample``.

        Returns:
            Tensor of ``sqrt(1 / g)`` values.
        """
        torch.manual_seed(seed)
        gamma_samples = self.gamma_dist.sample(shape)
        inverse_gamma_samples = 1.0 / gamma_samples
        return torch.sqrt(inverse_gamma_samples)


class InverseGamma(Distribution):
    """NumPy sampler returning the square root of inverse-gamma draws.

    Args:
        shape: Shape parameter ``alpha``.
        scale: Scale parameter ``beta``. The underlying gamma draw uses NumPy
            scale ``1 / beta``.
    """

    def __init__(self, shape, scale=1.0):
        self.shape = shape  # alpha
        self.scale = scale  # beta

    def __call__(self, rng, shape=None):
        """Return ``sqrt(1 / g)`` with ``g ~ Gamma(alpha, scale=1/beta)``."""
        gamma_scale = 1.0 / self.scale
        draws = rng.gamma(self.shape, gamma_scale, size=shape)
        inverse = 1.0 / draws
        return np.sqrt(inverse)


class Laplace(Distribution):
    """Laplace sampler: standard Laplace draws multiplied by ``scale``."""

    def __init__(self, scale=1):
        self.scale = scale

    def __call__(self, rng, shape=None):
        """Return ``scale`` times ``rng.laplace`` draws of the given shape."""
        draws = rng.laplace(size=shape)
        return self.scale * draws


class Cauchy(Distribution):
    """Cauchy sampler: standard Cauchy draws multiplied by ``scale``."""

    def __init__(self, scale=1.0):
        self.scale = scale

    def __call__(self, rng, shape=None):
        """Return ``scale`` times ``rng.standard_cauchy`` draws."""
        draws = rng.standard_cauchy(size=shape)
        return self.scale * draws


class Uniform(Distribution):
    """Uniform sampler between ``low`` and ``high`` using a NumPy ``Generator``."""

    def __init__(self, low, high):
        self.low = low
        self.high = high

    def __call__(self, rng, shape=None):
        """Return ``rng.uniform`` draws bounded by ``low`` and ``high``."""
        return rng.uniform(self.low, self.high, shape)


class Uniform_torch():
    """Torch uniform sampler: ``torch.rand`` draws rescaled to ``[low, high)``."""

    def __init__(self, low, high):
        self.low = low
        self.high = high

    def __call__(self, rng, shape=torch.Size()):
        """Return ``low + (high - low) * U`` with ``U`` from ``torch.rand``.

        ``rng`` is forwarded to ``torch.rand`` as its ``generator``.
        """
        width = self.high - self.low
        unit = torch.rand(shape, generator=rng)
        return self.low + width * unit


class SignedUniform(Distribution):
    """Uniform magnitudes in ``[low, high)`` with a random ``+1``/``-1`` sign."""

    def __init__(self, low, high):
        self.low = low
        self.high = high

    def __call__(self, rng, shape=None):
        """Return sign * magnitude samples; signs are drawn before magnitudes."""
        signs = rng.choice([-1, 1], size=shape)
        magnitudes = rng.uniform(self.low, self.high, shape)
        return signs * magnitudes


class SignedUniform_torch():
    """Torch sampler: uniform magnitudes in ``[low, high)`` times a random sign."""

    def __init__(self, low, high):
        self.low = low
        self.high = high

    def __call__(self, rng, shape=torch.Size()):
        """Return signed uniform samples of the given shape.

        Signs come from fair Bernoulli draws mapped to ``{-1, 1}``. They are
        drawn from ``rng`` first, followed by the uniform magnitudes.
        """
        coin = torch.bernoulli(torch.full(shape, 0.5), generator=rng)
        signs = coin * 2 - 1
        span = self.high - self.low
        magnitudes = self.low + span * torch.rand(shape, generator=rng)
        return signs * magnitudes


class RandInt(Distribution):
    """Integer sampler delegating to ``rng.integers``.

    Args:
        low: Lowest value that can be drawn.
        high: Upper bound; inclusive when ``endpoint`` is True (the default).
        endpoint: Forwarded to ``rng.integers``.
    """

    def __init__(self, low, high, endpoint=True):
        self.low, self.high, self.endpoint = low, high, endpoint

    def __call__(self, rng, shape=None):
        """Return integer draws between ``low`` and ``high``."""
        return rng.integers(self.low, self.high, size=shape, endpoint=self.endpoint)


class Beta(Distribution):
    """Beta-distribution sampler with shape parameters ``a`` and ``b``."""

    def __init__(self, a, b):
        self.a, self.b = a, b

    def __call__(self, rng, shape=None):
        """Return ``rng.beta(a, b)`` draws of the given shape."""
        return rng.beta(a=self.a, b=self.b, size=shape)