import torch as th
import numpy as np

# Natural log of 2*pi: the constant term of the Gaussian log-density
#   log N(x; mu, sigma) = -0.5 * ((x - mu)^2 / sigma^2 + 2 * log(sigma) + log(2*pi)).
# NOTE: despite the name this is log(2*pi), not log(pi); the name is kept because
# other modules may import it. It is a 0-dim tensor of torch's default float dtype.
LOG_PI = th.log(2 * th.tensor(np.pi))


class GaussianDistribution:
    """Gaussian (Normal) distribution helper for sampling and likelihood terms.

    Provides:
    - ``sample``: draw ``mean + std * eps`` samples, optionally from per-batch-element RNGs
    - ``log_prob``: per-batch-element log-likelihood, averaged over non-batch dims
    - ``kl_divergence``: a squared-mean-difference divergence term against another Gaussian

    Dim 0 of ``mean`` is treated as the batch dimension. ``std`` (and ``log_std`` in
    ``log_prob``) get two trailing singleton dims appended before being broadcast
    against ``mean``, so they are expected to have two fewer trailing dims than
    ``mean`` (NOTE(review): e.g. std of shape (B, C) with mean of shape (B, C, H, W)
    -- confirm against callers).

    Attributes:
        mean (torch.Tensor): Mean of the distribution.
        std (torch.Tensor): Standard deviation of the distribution.
        log_std (torch.Tensor): Log of the standard deviation; within this class it is
            only read by ``log_prob(real=True)``.
        generators (list, optional): One ``torch.Generator`` per batch element, used by
            ``sample`` for reproducible per-sample noise.
    """

    def __init__(self, mean, std, log_std, generators=None):
        """Store the distribution parameters.

        Args:
            mean (torch.Tensor): Mean of the distribution (dim 0 is the batch dim).
            std (torch.Tensor): Standard deviation of the distribution.
            log_std (torch.Tensor): Log of the standard deviation. It is stored as given,
                not recomputed from ``std``.
            generators (list, optional): Per-batch-element random number generators
                for ``sample``; ``None`` means use torch's global RNG.
        """
        self.mean = mean
        self.std = std
        self.log_std = log_std
        self.generators = generators

    @staticmethod
    def _expand_std(std):
        """Append two trailing singleton dims so ``std`` can broadcast against ``mean``."""
        return std.unsqueeze(-1).unsqueeze(-1)

    @staticmethod
    def _append_trailing_dims(t, ndim):
        """Append size-1 dims to tensor ``t`` until it has ``ndim`` dims.

        Non-tensors, 0-dim tensors and tensors that already have at least ``ndim``
        dims are returned unchanged, so layouts that already broadcast correctly keep
        their exact previous behavior.
        """
        if not isinstance(t, th.Tensor) or t.dim() == 0 or t.dim() >= ndim:
            return t
        return t.reshape(tuple(t.shape) + (1,) * (ndim - t.dim()))

    def sample(self, size):
        """Draw ``mean + std * eps`` where ``eps`` is standard normal noise of shape ``size``.

        Args:
            size (tuple): Shape of the noise tensor; ``size[0]`` is the batch size.

        Returns:
            torch.Tensor: Samples. The noise is converted to ``mean``'s device and dtype
            before scaling and shifting.

        Raises:
            ValueError: If ``generators`` is set but its length differs from ``size[0]``.
        """
        if self.generators is None:
            noise = th.randn(size)
        elif len(self.generators) == size[0]:
            # One draw per batch element from its own generator, so each sample is
            # reproducible independently of the rest of the batch.
            noise = th.stack(
                [th.randn(size[1:], generator=gen) for gen in self.generators], dim=0
            )
        else:
            raise ValueError("Generators must be of the same size as the batch size.")
        return noise.to(self.mean) * self._expand_std(self.std) + self.mean

    def log_prob(self, x, real=False):
        """Compute the log-likelihood of ``x``, one value per batch element.

        Per-element terms are averaged (not summed) over every non-batch dim of
        ``mean``. A 1e-6 term is added to the variance for numerical stability.

        Args:
            x (torch.Tensor): Samples, broadcastable against ``mean``.
            real (bool): If True, include the normalizing terms
                ``2 * log_std + log(2*pi)``, giving the mean per-element Gaussian
                log-density. If False (default) they are dropped. That is sufficient
                when only differences between log-probs that share the same
                ``log_std`` matter.

        Returns:
            torch.Tensor: Log-probabilities reduced over dims ``1 .. mean.dim() - 1``.
        """
        dims_except_batch = list(range(1, len(self.mean.shape)))
        std = self._expand_std(self.std)
        terms = (x - self.mean) ** 2 / (std ** 2 + 1e-6)
        if real:
            # log_std needs the same trailing singleton dims as std. Without them it
            # broadcasts against the trailing (non-batch) dims of ``terms`` and either
            # raises or silently mixes values across batch elements.
            log_std = self._append_trailing_dims(self.log_std, std.dim())
            terms = terms + (2 * log_std + LOG_PI)
        return -0.5 * th.mean(terms, dim=dims_except_batch)

    def kl_divergence(self, other: "GaussianDistribution"):
        """Compute a squared-mean-difference divergence between this distribution and ``other``.

        NOTE: this is not the full closed-form Gaussian KL. It returns
        ``(self.mean - other.mean) ** 2 / other.std ** 2`` averaged over all non-batch
        dims of ``self.mean`` and ignores the variance / log-std terms. When both
        distributions share the same std, the true per-element KL(self || other) is
        half of this value. Unlike ``log_prob``, no epsilon is added to the variance,
        so a zero ``other.std`` yields inf/nan.

        Args:
            other (GaussianDistribution): Reference distribution whose std scales the
                squared mean difference.

        Returns:
            torch.Tensor: One value per batch element.
        """
        dims_except_batch = list(range(1, len(self.mean.shape)))
        other_var = self._expand_std(other.std) ** 2
        return th.mean((self.mean - other.mean) ** 2 / other_var, dim=dims_except_batch)

class SampleResultsWithBackwardCompatibility(list):
    """A list that also supports accessing additional attributes while maintaining backward compatibility"""
    def __init__(self, samples, **kwargs):
        super().__init__(samples)
        for key, value in kwargs.items():
            setattr(self, key, value)

    def __getitem__(self, key):
        if isinstance(key, str):
            if key == 'samples':
                return list(self)
            elif hasattr(self, key):
                return getattr(self, key)
            else:
                raise KeyError(f"'{key}' not found")
        else:
            return super().__getitem__(key)

    def get(self, key, default=None):
        """Dictionary-style get method"""
        try:
            return self[key]
        except (KeyError, IndexError, TypeError):
            return default
