import math
from typing import Type

import torch
from torch import Tensor


class BaseDistribution:
    """Common interface for the parametric distributions in this module.

    Subclasses store their parameters as tensors of shape ``[*param_shape]`` and
    override whichever of the methods below they support. Every method here
    raises ``NotImplementedError``.
    """

    @property
    def param_shape(self) -> tuple[int, ...]:
        """Shape of the parameter tensors.

        Input: None Output: [*param_shape]
        """
        raise NotImplementedError("param_shape method not implemented")

    def log_prob(self, x: Tensor) -> Tensor:
        """Log-density evaluated element-wise at ``x``.

        Input: [*batch_shape] Output: [*batch_shape]
        """
        raise NotImplementedError("log_prob method not implemented")

    def sample(self, sample_shape: tuple[int, ...]) -> Tensor:
        """Draw random samples.

        Input: tuple[int, ...] Output: [*sample_shape, *param_shape]
        """
        raise NotImplementedError("sample method not implemented")

    def cdf(self, x: Tensor) -> Tensor:
        """Cumulative distribution function evaluated element-wise at ``x``.

        Input: [*batch_shape] Output: [*batch_shape]
        """
        raise NotImplementedError("CDF method not implemented")

    def percentile(self, p: Tensor, **kwargs) -> Tensor:
        """Inverse CDF at probabilities ``p``.

        Input: [*param_shape, *p_shape] or [*p_shape] Output: [*param_shape, *p_shape]
        """
        raise NotImplementedError("Percentile method not implemented")

    def derivative(self, x: Tensor) -> Tensor:
        """First derivative of the density with respect to ``x``.

        Input: [*batch_shape] Output: [*batch_shape]
        """
        raise NotImplementedError("Derivative method not implemented")

    def second_derivative(self, x: Tensor) -> Tensor:
        """Second derivative of the density with respect to ``x``.

        Input: [*batch_shape] Output: [*batch_shape]
        """
        raise NotImplementedError("Second derivative method not implemented")

    def inflection_points(self) -> Tensor:
        """Points where the second derivative of the density vanishes.

        Input: None Output: [*param_shape, num_points]
        """
        raise NotImplementedError("Inflection points method not implemented")

    def mode(self) -> Tensor:
        """Location of the density's maximum.

        Input: None Output: [*param_shape]
        """
        raise NotImplementedError("Mode method not implemented")

    def __repr__(self):
        # Only tensor-valued attributes are listed, and only by shape.
        attrs = ', '.join(
            f'{name}={value.shape}'
            for name, value in vars(self).items()
            if isinstance(value, Tensor)
        )
        return f"{self.__class__.__name__}({attrs})"


class ExponentialDistribution(BaseDistribution):
    """Exponential distribution with rate ``lambda_param``: f(x) = λ exp(-λx), x >= 0.

    ``param_shape`` is the shape of ``lambda_param``; element-wise methods
    broadcast their input against it.
    """

    def __init__(self, lambda_param: Tensor) -> None:
        self.lambda_param = lambda_param

    @property
    def param_shape(self) -> tuple[int, ...]:
        return self.lambda_param.shape

    def log_prob(self, x: Tensor) -> Tensor:
        """log λ - λx. Not masked outside the support (x < 0 is not sent to -inf).

        Input: [*batch_shape] Output: [*batch_shape]
        """
        return torch.log(self.lambda_param) - self.lambda_param * x

    def sample(self, sample_shape: tuple[int, ...]) -> Tensor:
        """Draw samples by inverse CDF: -log(1 - U) / λ with U ~ Uniform[0, 1).

        Args:
            sample_shape: an int or a sequence of ints, prepended to ``param_shape``.

        Returns:
            Tensor of shape ``[*sample_shape, *param_shape]`` on the rate's device.
        """
        if isinstance(sample_shape, int):
            sample_shape = (sample_shape,)
        rate = self.lambda_param
        shape = torch.Size(sample_shape) + rate.shape
        # Draw directly on the rate's device (and in its dtype when floating)
        # rather than generating float32 CPU noise and copying it over.
        dtype = rate.dtype if rate.is_floating_point() else None
        u = torch.rand(shape, dtype=dtype, device=rate.device)
        # torch.rand may return exactly 0, where -log(u) would be +inf; 1 - u lies
        # in (0, 1], so log1p(-u) is always finite.
        return -torch.log1p(-u) / rate

    def derivative(self, x: Tensor) -> Tensor:
        """f'(x) = -λ² exp(-λx)."""
        return -self.lambda_param**2 * torch.exp(-self.lambda_param * x)

    def second_derivative(self, x: Tensor) -> Tensor:
        """f''(x) = λ³ exp(-λx)."""
        return self.lambda_param**3 * torch.exp(-self.lambda_param * x)

    def cdf(self, x: Tensor) -> Tensor:
        """F(x) = 1 - exp(-λx)."""
        return 1 - torch.exp(-self.lambda_param * x)

    def percentile(self, p: Tensor, **kwargs) -> Tensor:
        """Inverse CDF: -log(1 - p) / λ.

        A 1-D ``p`` of length n is broadcast against every rate, giving
        ``[*param_shape, n]``. Other shapes must broadcast with
        ``lambda_param.unsqueeze(-1)``.
        """
        rate = self.lambda_param
        if p.ndim == 1:
            p = p.reshape(*([1] * rate.ndim), -1)
        p = p.to(rate.device)
        # log1p(-p) keeps precision for small p where 1 - p rounds toward 1.
        return -torch.log1p(-p) / rate.unsqueeze(-1)

    def mode(self) -> Tensor:
        """The density λ exp(-λx) is maximal at x = 0 for every rate.

        Output: [*param_shape]
        """
        return torch.zeros_like(self.lambda_param)

    def inflection_points(self) -> Tensor:
        """None exist: f''(x) = λ³ exp(-λx) never vanishes.

        Output: an empty tensor of shape [*param_shape, 0].
        """
        return self.lambda_param.new_zeros((*self.lambda_param.shape, 0))


class GammaDistribution(BaseDistribution):
    """Gamma distribution with shape ``alpha`` and rate ``beta``.

    Density: f(x) = β^α / Γ(α) · x^(α-1) · exp(-βx) for x > 0.
    ``param_shape`` is the shape of ``alpha``. ``log(beta)`` and
    ``lgamma(alpha)`` are precomputed once because every density evaluation
    uses them.
    """

    def __init__(self, alpha: Tensor, beta: Tensor) -> None:
        self.alpha = alpha
        self.beta = beta
        self.log_beta = torch.log(beta)
        self.log_gamma_alpha = torch.lgamma(alpha)

    @property
    def param_shape(self) -> tuple[int]:
        shape_param = self.alpha
        return shape_param.shape

    def log_prob(self, x: Tensor) -> Tensor:
        """(α-1) log x - βx - log Γ(α) + α log β.

        Input: [*batch_shape] Output: [*batch_shape]
        """
        alpha, beta = self.alpha, self.beta
        log_kernel = (alpha - 1) * torch.log(x) - beta * x
        return log_kernel - self.log_gamma_alpha + alpha * self.log_beta

    def sample(self, sample_shape: tuple[int]) -> Tensor:
        """Draw samples of shape [*sample_shape, *param_shape]; an int means a 1-tuple."""
        dims = (sample_shape,) if isinstance(sample_shape, int) else sample_shape
        full_shape = torch.Size(dims) + self.alpha.shape

        # Unit-rate Gamma(α, 1) draws from torch's built-in sampler, rescaled by the rate.
        unit_draws = torch._standard_gamma(self.alpha.expand(full_shape))
        draws = unit_draws / self.beta.expand(full_shape)

        # Keep every sample strictly positive so log-space evaluations stay finite.
        tiny = torch.finfo(draws.dtype).tiny
        return draws.clamp(min=tiny)

    def cdf(self, x: Tensor) -> Tensor:
        """Regularised lower incomplete gamma P(α, βx)."""
        scaled = self.beta * x
        return torch.igamma(self.alpha, scaled)

    def derivative(self, x: Tensor) -> Tensor:
        """f'(x) = f(x) · ((α-1)/x - β)."""
        density = self.log_prob(x).exp()
        dlog_density = (self.alpha - 1) / x - self.beta
        return density * dlog_density

    def second_derivative(self, x: Tensor) -> Tensor:
        """f'' = f' · g + f · g', with g = (α-1)/x - β and g' = -(α-1)/x²."""
        density = self.log_prob(x).exp()
        first = self.derivative(x)
        dlog_density = (self.alpha - 1) / x - self.beta
        curvature = -((self.alpha - 1) / x**2)
        return first * dlog_density + density * curvature

    def inflection_points(self) -> Tensor:
        """Roots of f'' = 0: x = (α - 1 ∓ sqrt(α - 1)) / β.

        Output: [*param_shape, 2], smaller root first. Entries are NaN where
        α < 1, and the smaller root is negative when 1 < α < 2.
        """
        centre = self.alpha - 1
        half_width = torch.sqrt(self.alpha - 1)
        return torch.stack(
            [(centre - half_width) / self.beta, (centre + half_width) / self.beta],
            -1,
        )

    def mode(self) -> Tensor:
        """(α - 1) / β where α > 1, otherwise 0."""
        peak = (self.alpha - 1) / self.beta
        return torch.where(self.alpha > 1, peak, 0)

    def percentile(self, p: Tensor, **kwargs) -> Tensor:
        # Inverse CDF is not provided for the Gamma distribution.
        raise NotImplementedError


class LogNormalDistribution(BaseDistribution):
    """Log-normal distribution: log X ~ Normal(mean, std).

    ``mean`` and ``std`` parameterise the underlying normal in log-space, not the
    moments of X. ``param_shape`` is the shape of ``mean``.
    """

    def __init__(self, mean: Tensor, std: Tensor) -> None:
        self.mean = mean
        self.std = std
        self.sqrt_2 = math.sqrt(2)
        self.sqrt_2pi = math.sqrt(2 * torch.pi)

    @property
    def param_shape(self) -> tuple[int, ...]:
        return self.mean.shape

    def log_prob(self, x: Tensor) -> Tensor:
        """log f(x) = -log x - ((log x - μ)/σ)²/2 - log(σ√(2π)).

        Input: [*batch_shape] Output: [*batch_shape]
        """
        return -torch.log(x) - 0.5 * ((torch.log(x) - self.mean) / self.std) ** 2 - torch.log(self.std * self.sqrt_2pi)

    def sample(self, sample_shape: tuple[int, ...]) -> Tensor:
        """Draw exp(μ + σ·Z) with Z ~ Normal(0, 1).

        Args:
            sample_shape: an int or a sequence of ints, prepended to ``param_shape``.

        Returns:
            Tensor of shape ``[*sample_shape, *param_shape]`` on ``mean``'s device.
        """
        if isinstance(sample_shape, int):
            sample_shape = (sample_shape,)
        shape = torch.Size(sample_shape) + self.mean.shape
        # Generate noise where (and, for floating parameters, in the precision)
        # the parameters live instead of float32 on the CPU followed by a copy.
        dtype = self.mean.dtype if self.mean.is_floating_point() else None
        noise = torch.randn(shape, dtype=dtype, device=self.mean.device)
        return torch.exp(self.mean + self.std * noise)

    def cdf(self, x: Tensor) -> Tensor:
        """Φ((log x - μ)/σ), expressed through erf."""
        return 0.5 * (1 + torch.erf((torch.log(x) - self.mean) / (self.std * self.sqrt_2)))

    def percentile(self, p: Tensor, **kwargs) -> Tensor:
        """Inverse CDF: exp(μ + σ√2 · erfinv(2p - 1)).

        A 1-D ``p`` of length n is broadcast against every parameter, giving
        ``[*param_shape, n]``.
        """
        if p.ndim == 1:
            p = p.reshape(*([1] * self.mean.ndim), -1)
        p = p.to(self.mean.device)
        return torch.exp(self.mean.unsqueeze(-1) + self.std.unsqueeze(-1) * self.sqrt_2 * torch.erfinv(2 * p - 1))

    def derivative(self, x: Tensor) -> Tensor:
        """f'(x) = -f(x)/x · (1 + (log x - μ)/σ²)."""
        pdf = self.log_prob(x).exp()
        return pdf * (-1 / x) * (1 + ((torch.log(x) - self.mean) / (self.std ** 2)))

    def second_derivative(self, x: Tensor) -> Tensor:
        """f''(x) = f(x)/x² · [(1 + e)² + (1 - 1/σ²) + e], with e = (log x - μ)/σ²."""
        pdf = self.log_prob(x).exp()
        log_x = torch.log(x)
        exp_term = (log_x - self.mean) / (self.std**2)
        term = (1 + exp_term)**2 + (1 - 1 / self.std**2) + exp_term
        return pdf * term / x**2

    def inflection_points(self) -> Tensor:
        """Roots of f'' = 0.

        With w = 1 + e the bracket in ``second_derivative`` becomes
        w² + w - 1/σ² = 0, giving log x = μ + (σ²/2)(-3 ∓ sqrt(1 + 4/σ²)).

        Output: [*param_shape, 2], smaller point first.
        """
        sqrt_term = torch.sqrt(1 + 4 / (self.std**2))

        return torch.stack([
            self.mean + (self.std**2 / 2) * (-3 - sqrt_term),
            self.mean + (self.std**2 / 2) * (-3 + sqrt_term),
        ], -1).exp()

    def mode(self) -> Tensor:
        """exp(μ - σ²)."""
        return torch.exp(self.mean - self.std**2)


class WeibullDistribution(BaseDistribution):
    """Weibull distribution with scale λ (``scale``) and shape k (``shape``).

    Density: f(x) = (k/λ)(x/λ)^(k-1) exp(-(x/λ)^k) for x >= 0.
    ``param_shape`` is the shape of ``scale``.
    """

    def __init__(self, scale: Tensor, shape: Tensor) -> None:
        self.scale = scale
        self.shape = shape

    @property
    def param_shape(self) -> tuple[int, ...]:
        return self.scale.shape

    def log_prob(self, x: Tensor) -> Tensor:
        """log k - k log λ + (k-1) log x - (x/λ)^k.

        Input: [*batch_shape] Output: [*batch_shape]
        """
        return torch.log(self.shape) - self.shape * torch.log(self.scale) + (self.shape - 1) * torch.log(x) - (x / self.scale) ** self.shape

    def sample(self, sample_shape: tuple[int, ...]) -> Tensor:
        """Draw samples by inverse CDF: λ(-log(1 - U))^(1/k) with U ~ Uniform[0, 1).

        Args:
            sample_shape: an int or a sequence of ints, prepended to ``param_shape``.

        Returns:
            Tensor of shape ``[*sample_shape, *param_shape]`` on ``scale``'s device.
        """
        if isinstance(sample_shape, int):
            sample_shape = (sample_shape,)
        out_shape = torch.Size(sample_shape) + self.scale.shape
        # Draw on the parameter's device (and dtype, when floating) directly.
        dtype = self.scale.dtype if self.scale.is_floating_point() else None
        u = torch.rand(out_shape, dtype=dtype, device=self.scale.device)
        # torch.rand may return exactly 0; -log(u) would then be +inf, while
        # log1p(-u) = log(1 - u) with 1 - u in (0, 1] is always finite.
        return self.scale * (-torch.log1p(-u)) ** (1 / self.shape)

    def cdf(self, x: Tensor) -> Tensor:
        """F(x) = 1 - exp(-(x/λ)^k)."""
        return 1 - torch.exp(-(x / self.scale) ** self.shape)

    def derivative(self, x: Tensor) -> Tensor:
        """f'(x) = f(x) · [(k-1)/x - (k/λ)(x/λ)^(k-1)]."""
        pdf = self.log_prob(x).exp()
        return pdf * ((self.shape - 1) / x - (self.shape / self.scale) * (x / self.scale) ** (self.shape - 1))

    def second_derivative(self, x: Tensor) -> Tensor:
        """f'' = f' · g + f · g', where g = f'/f = term1 - term2 (see ``derivative``)."""
        k = self.shape
        lam = self.scale

        pdf = self.log_prob(x).exp()
        first_derivative = self.derivative(x)

        term1 = (k - 1) / x
        term2 = (k / lam) * (x / lam) ** (k - 1)
        term1_derivative = -(k - 1) / x**2
        # Negated derivative of term2, so that g' = term1_derivative + neg_term2_derivative.
        neg_term2_derivative = -(k * (k - 1)) / lam**2 * (x / lam) ** (k - 2)

        return first_derivative * (term1 - term2) + pdf * (term1_derivative + neg_term2_derivative)

    def inflection_points(self) -> Tensor:
        """Roots of f'' = 0.

        With t = (x/λ)^k, f'' = 0 reduces to k²t² - 3k(k-1)t + (k-1)(k-2) = 0,
        so t = [3k(k-1) ∓ k·sqrt(5k² - 6k + 1)] / (2k²) and x = λ t^(1/k).

        Output: [*param_shape, 2]. Entries may be NaN for shape values where a
        root is negative or complex.
        """
        k = self.shape
        lam = self.scale

        # 2^(-1/k) folds the 1/2 of the quadratic formula into the k-th root.
        term1 = 2 ** (-1 / k) * lam
        denominator = k**2
        numerator1 = -((3 - 3 * k) * k)
        numerator2 = k * torch.sqrt(1 - 6 * k + 5 * k**2)
        return torch.stack([
            term1 * ((numerator1 - numerator2) / denominator) ** (1 / k),
            term1 * ((numerator1 + numerator2) / denominator) ** (1 / k),
        ], -1)

    def percentile(self, p: Tensor, **kwargs) -> Tensor:
        """Inverse CDF: λ(-log(1 - p))^(1/k).

        A 1-D ``p`` of length n is broadcast against every parameter, giving
        ``[*param_shape, n]``.
        """
        if p.ndim == 1:
            p = p.reshape(*([1] * self.scale.ndim), -1)
        p = p.to(self.scale.device)
        return self.scale.unsqueeze(-1) * (-torch.log(1 - p)) ** (1 / self.shape.unsqueeze(-1))

    def mode(self) -> Tensor:
        """λ((k-1)/k)^(1/k) where k > 1, otherwise 0."""
        return torch.where(
            self.shape > 1,
            self.scale * ((self.shape - 1) / self.shape) ** (1 / self.shape),
            0
        )


class MixtureDistribution(BaseDistribution):
    """Finite mixture over the last parameter axis of a single distribution.

    ``component_distribution`` is a distribution *instance* whose parameters
    have shape ``[*batch_shape, k]``. Each slice along the last axis is one
    component. ``logits`` (same shape) are unnormalised log mixing weights
    over that axis.
    """

    def __init__(self, component_distribution: BaseDistribution, logits: Tensor) -> None:
        self.component_distribution = component_distribution
        assert self.component_distribution.param_shape == logits.shape, \
            f"Component distribution shape {self.component_distribution.param_shape} does not match logits shape {logits.shape}"

        # Normalised log-weights: logsumexp over the component axis is 0.
        self.logits = logits - torch.logsumexp(logits, dim=-1, keepdim=True)
        self.dim = logits.shape[-1]
        self._weights = torch.softmax(self.logits, dim=-1)

    @property
    def param_shape(self) -> tuple[int, ...]:
        return self.component_distribution.param_shape

    @property
    def weights(self) -> Tensor:
        """Mixing probabilities, shape [*batch_shape, k]."""
        return self._weights

    def log_prob(self, x: Tensor) -> Tensor:
        """log Σ_i w_i f_i(x).

        ``x`` gets a trailing component axis so that it broadcasts against every
        component. The result drops that axis.
        """
        x = x.unsqueeze(-1)
        component_log_probs = self.component_distribution.log_prob(x)
        weighted_log_probs = component_log_probs + self.logits
        return torch.logsumexp(weighted_log_probs, dim=-1)

    def sample(self, sample_shape: tuple[int, ...]) -> Tensor:
        """Pick a component per draw from the mixing weights, then take that component's sample.

        Args:
            sample_shape: an int or a sequence of ints.

        Returns:
            Tensor of shape ``[*sample_shape, *batch_shape]`` (the component axis
            is consumed by the selection).
        """
        if isinstance(sample_shape, int):
            sample_shape = (sample_shape,)
        sample_shape = torch.Size(sample_shape)
        batch_shape = self.logits.shape[:-1]
        num_draws = sample_shape.numel()

        comp_samples = self.component_distribution.sample(sample_shape)  # [*sample_shape, *batch_shape, k]
        if num_draws == 0:
            # torch.multinomial rejects num_samples == 0; the result is empty anyway.
            return comp_samples[..., 0]

        # multinomial needs a 2-D [rows, k] probability matrix: one row per batch element.
        mix_probs = self._weights.reshape(-1, self.dim)  # [prod(batch_shape), k]
        mix_idx = torch.multinomial(mix_probs, num_draws, replacement=True).T  # [num_draws, prod(batch_shape)]
        mix_idx = mix_idx.reshape(sample_shape + batch_shape).to(comp_samples.device)

        return torch.gather(comp_samples, -1, mix_idx.unsqueeze(-1)).squeeze(-1)

    def mode(self) -> Tensor:
        """Per-component modes, shape [*param_shape] (candidates, not the mixture's own mode)."""
        return self.component_distribution.mode()

    def inflection_points(self) -> Tensor:
        """Per-component inflection points with the component and point axes merged.

        Output: [*batch_shape, k * num_points]. These are the components'
        points, not those of the mixture density.
        """
        return self.component_distribution.inflection_points().flatten(-2, -1)

    def percentile(self, p: Tensor, exact: bool = True) -> Tensor:
        """Percentiles derived from the component percentiles.

        Args:
            p: probabilities, in any form accepted by the component's ``percentile``.
            exact: if True, return the smallest component percentile where
                p < 0.5 and the largest otherwise. For continuous components the
                true mixture quantile lies between those two values, so this is
                an outer bound rather than the exact quantile. If False, return
                the mixing-weight-weighted average of the component percentiles.

        Returns:
            For a 1-D ``p`` of length n, a tensor of shape ``[*batch_shape, n]``.
        """
        component_percentiles = self.component_distribution.percentile(p)  # [*batch_shape, k, n]

        if exact:
            min_values = component_percentiles.min(dim=-2).values
            max_values = component_percentiles.max(dim=-2).values
            return torch.where(p.to(min_values.device) < 0.5, min_values, max_values)

        weights = self.weights.unsqueeze(-1)
        return (weights * component_percentiles).sum(dim=-2)

    def __repr__(self):
        return f"{self.__class__.__name__}(component_distribution={self.component_distribution}, logits={self.logits.shape})"
