import abc

import torch
import torch.nn as nn


class Noise(abc.ABC, nn.Module):
    """Abstract base class for noise schedules.

    Calling an instance runs ``forward``, which returns both the total noise
    and the rate of noise at a timestep. Subclasses must implement
    ``total_noise``. ``forward`` also calls ``self.rate_noise(t)``, which this
    base class does not define, so a subclass has to provide it for
    ``forward`` to work.
    """

    def forward(self, t):
        """Return the pair ``(total_noise(t), rate_noise(t))``.

        Time is assumed to go from 0 to 1.
        """
        return self.total_noise(t), self.rate_noise(t)

    @abc.abstractmethod
    def total_noise(self, t):
        # Raw docstring: "\i" is not a valid escape sequence in a normal
        # string literal and triggers a warning on recent Python versions.
        r"""Return the total noise at time ``t``.

        Total noise, i.e. \int_0^t g(s) ds + g(0).
        """


class LogLinearNoise(Noise):
    """Log-linear noise schedule with fixed (non-learnable) parameters.

    ``eps_max`` and ``eps_min`` are stored as given; extra keyword arguments
    are accepted and ignored. The total noise at ``t=1`` and ``t=0`` is
    cached on the instance as ``sigma_max`` and ``sigma_min``.
    """

    def __init__(self, eps_max=1e-3, eps_min=1e-5, **kwargs):
        super().__init__()
        self.eps_max = eps_max
        self.eps_min = eps_min
        # The end points are evaluated through total_noise, which reads the
        # eps values assigned just above.
        t_end = torch.tensor(1.0)
        t_start = torch.tensor(0.0)
        self.sigma_max = self.total_noise(t_end)
        self.sigma_min = self.total_noise(t_start)

    def _affine(self, t):
        # Linear map of t: equals eps_min at t=0 and 1 - eps_max at t=1.
        return (1 - self.eps_max - self.eps_min) * t + self.eps_min

    def k(self):
        """Return the constant exponent 1 as a tensor."""
        return torch.tensor(1)

    def rate_noise(self, t):
        """Return the derivative of ``total_noise`` with respect to ``t``."""
        remaining = 1 - self._affine(t)
        return (1 - self.eps_max - self.eps_min) / remaining

    def total_noise(self, t):
        """Return ``-log(1 - ((1 - eps_max - eps_min) * t + eps_min))``.

        sigma_min = -log(1 - eps_min), when t=0
        sigma_max = -log(eps_max), when t=1
        """
        return -torch.log1p(-self._affine(t))


class PowerMeanNoise(Noise):
    """The noise schedule using the power mean interpolation function.

    This is the schedule used in EDM.

    ``sigma(t) = (sigma_min^(1/rho) + t * (sigma_max^(1/rho) - sigma_min^(1/rho)))^rho``
    so ``sigma(0) = sigma_min`` and ``sigma(1) = sigma_max``. ``rho`` is a
    fixed constant here. Extra keyword arguments are accepted and ignored.
    """

    def __init__(self, sigma_min=0.002, sigma_max=80, rho=7, **kwargs):
        super().__init__()
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        self.raw_rho = rho

    def rho(self):
        """Return the constant ``rho`` wrapped in a tensor (no transform applied)."""
        return torch.tensor(self.raw_rho)

    def _pow_endpoints(self):
        # rho together with sigma_min^(1/rho) and sigma_max^(1/rho), computed once.
        rho = self.rho()
        inv_rho = 1 / rho
        return rho, self.sigma_min ** inv_rho, self.sigma_max ** inv_rho

    def total_noise(self, t):
        """Return ``sigma(t)`` for the power-mean interpolation."""
        rho, low, high = self._pow_endpoints()
        return (low + t * (high - low)).pow(rho)

    def rate_noise(self, t):
        """Return the derivative of ``total_noise`` with respect to ``t``.

        Without this method the inherited ``forward`` fails with an
        AttributeError, because it evaluates ``self.rate_noise(t)``.
        """
        rho, low, high = self._pow_endpoints()
        # d/dt (low + t * (high - low))^rho
        return rho * (high - low) * (low + t * (high - low)).pow(rho - 1)

    def inverse_to_t(self, sigma):
        """Map a noise level ``sigma`` (a tensor) back to the time ``t`` that produces it."""
        rho, low, high = self._pow_endpoints()
        return (sigma.pow(1 / rho) - low) / (high - low)


class PowerMeanNoise_PerColumn(nn.Module):
    """Power-mean noise schedule with one learnable ``rho`` per numerical column.

    ``rho = softplus(rho_raw) + rho_offset``, so every ``rho`` is strictly
    greater than ``rho_offset``. ``rho_raw`` is a parameter of shape
    ``[num_numerical]`` filled with ``rho_init``. Extra keyword arguments are
    accepted and ignored.
    """

    def __init__(
        self,
        num_numerical,
        sigma_min=0.002,
        sigma_max=80,
        rho_init=1,
        rho_offset=2,
        **kwargs,
    ):
        super().__init__()
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        self.num_numerical = num_numerical
        self.rho_offset = rho_offset
        self.rho_raw = nn.Parameter(
            torch.tensor([rho_init] * self.num_numerical, dtype=torch.float32)
        )

    def rho(self):
        """Return the softplus-transformed rho for all num_numerical columns."""
        return nn.functional.softplus(self.rho_raw) + self.rho_offset

    def total_noise(self, t):
        """
        Compute total noise for each t in the batch for all num_numerical rhos.
        t: [batch_size] or [batch_size, 1]
        Returns: [batch_size, num_numerical]

        A 1-D ``t`` is treated as a column. Broadcasting it directly against
        the ``[num_numerical]`` rho vector would either fail or, when
        batch_size happens to equal num_numerical, pair ``t[i]`` with column
        ``i`` instead of producing one row per batch element.
        """
        if t.dim() == 1:
            t = t.unsqueeze(-1)  # [batch_size, 1]

        rho = self.rho().to(t.device)

        sigma_min_pow = self.sigma_min ** (1 / rho)  # Shape: [num_numerical]
        sigma_max_pow = self.sigma_max ** (1 / rho)  # Shape: [num_numerical]

        return (sigma_min_pow + t * (sigma_max_pow - sigma_min_pow)).pow(rho)

    def rate_noise(self, t):
        """No rate is defined for this schedule; always returns None."""
        return None

    def inverse_to_t(self, sigma):
        """
        Inverse function to map sigma back to t, with proper broadcasting support.
        sigma: [batch_size, num_numerical] or [batch_size, 1]
        Returns: t: [batch_size, num_numerical]
        """
        rho = self.rho().to(sigma.device)

        sigma_min_pow = self.sigma_min ** (1 / rho)  # Shape: [num_numerical]
        sigma_max_pow = self.sigma_max ** (1 / rho)  # Shape: [num_numerical]

        # The trailing dimension of sigma broadcasts against the per-column rho.
        return (sigma.pow(1 / rho) - sigma_min_pow) / (sigma_max_pow - sigma_min_pow)


class PowerMeanUnified(nn.Module):
    """Power-mean noise schedule with a single learnable ``rho`` shared by all columns.

    ``rho = softplus(rho_raw) + rho_offset`` where ``rho_raw`` is a scalar
    parameter initialised to ``rho_init``. Results are replicated across
    ``num_numerical`` columns. Extra keyword arguments are accepted and ignored.
    """

    def __init__(
        self,
        num_numerical,
        sigma_min=0.002,
        sigma_max=80,
        rho_init=1,
        rho_offset=2,
        **kwargs,
    ):
        super().__init__()
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        self.num_numerical = num_numerical
        self.rho_offset = rho_offset
        # Single shared parameter for all columns
        self.rho_raw = nn.Parameter(torch.tensor(rho_init, dtype=torch.float32))

    def rho(self):
        """Return the softplus-transformed rho as a scalar tensor."""
        return nn.functional.softplus(self.rho_raw) + self.rho_offset

    def total_noise(self, t):
        """
        Compute total noise for each t in the batch with a shared rho.
        t: [batch_size] or [batch_size, 1]
        Returns: [batch_size, num_numerical]

        A 1-D ``t`` is treated as a column first; repeating a 1-D result
        along the last axis would yield ``[1, batch_size * num_numerical]``
        rather than one row per batch element.
        """
        if t.dim() == 1:
            t = t.unsqueeze(-1)  # [batch_size, 1]

        rho = self.rho().to(t.device)  # Scalar value

        sigma_min_pow = self.sigma_min ** (1 / rho)  # Scalar
        sigma_max_pow = self.sigma_max ** (1 / rho)  # Scalar

        sigma = (sigma_min_pow + t * (sigma_max_pow - sigma_min_pow)).pow(rho)

        # Tile the single column across num_numerical columns.
        return sigma.repeat(1, self.num_numerical)

    def rate_noise(self, t):
        """No rate is defined for this schedule; always returns None."""
        return None

    def inverse_to_t(self, sigma):
        """
        Inverse function to map sigma back to t, with shared parameter.
        sigma: [batch_size, num_numerical] or [batch_size, 1]
        Returns: t: [batch_size, num_numerical]
        """
        rho = self.rho().to(sigma.device)  # Scalar value

        sigma_min_pow = self.sigma_min ** (1 / rho)  # Scalar
        sigma_max_pow = self.sigma_max ** (1 / rho)  # Scalar

        # With a shared rho the columns of sigma are expected to agree, so a
        # multi-column sigma is collapsed to its per-row mean.
        if sigma.dim() > 1 and sigma.size(-1) > 1:
            sigma_unified = sigma.mean(dim=-1, keepdim=True)  # [batch_size, 1]
        else:
            sigma_unified = sigma

        t_unified = (sigma_unified.pow(1 / rho) - sigma_min_pow) / (
            sigma_max_pow - sigma_min_pow
        )  # Shape: [batch_size, 1] or [batch_size]

        # Expand to match num_numerical if needed
        if t_unified.dim() == 1:
            t_unified = t_unified.unsqueeze(-1)
        return t_unified.expand(-1, self.num_numerical)


class LogLinearNoise_PerColumn(nn.Module):
    """Log-linear noise schedule with one learnable exponent ``k`` per category column.

    ``k = softplus(k_raw) + k_offset``, so every ``k`` is strictly greater
    than ``k_offset``. ``k_raw`` is a parameter of shape ``[num_categories]``
    filled with ``k_init``. Extra keyword arguments are accepted and ignored.
    """

    def __init__(
        self,
        num_categories,
        eps_max=1e-3,
        eps_min=1e-5,
        k_init=-6,
        k_offset=1,
        **kwargs,
    ):
        super().__init__()
        self.eps_max = eps_max
        self.eps_min = eps_min
        self.num_categories = num_categories
        self.k_offset = k_offset
        # k() applies softplus to this raw parameter so that k stays above k_offset.
        self.k_raw = nn.Parameter(
            torch.tensor([k_init] * self.num_categories, dtype=torch.float32)
        )

    def k(self):
        """Return the per-column exponent, shape ``[num_categories]``."""
        return torch.nn.functional.softplus(self.k_raw) + self.k_offset

    def rate_noise(self, t, noise_fn=None):
        """
        Compute rate noise for all categories with broadcasting.
        t: [batch_size] or [batch_size, 1]
        Returns: [batch_size, num_categories]

        ``noise_fn`` is accepted but not used. A 1-D ``t`` is treated as a
        column so it broadcasts against the per-column ``k``.
        """
        if t.dim() == 1:
            t = t.unsqueeze(-1)  # [batch_size, 1]
        k = self.k().to(t.device)  # Shape: [num_categories]

        numerator = (1 - self.eps_max - self.eps_min) * k * t.pow(k - 1)
        denominator = 1 - ((1 - self.eps_max - self.eps_min) * t.pow(k) + self.eps_min)
        return numerator / denominator

    def total_noise(self, t, noise_fn=None):
        """
        Compute total noise ``-log(1 - ((1 - eps_max - eps_min) * t^k + eps_min))``.
        t: [batch_size] or [batch_size, 1]
        Returns: [batch_size, num_categories]

        ``noise_fn`` is accepted but not used. A 1-D ``t`` is treated as a
        column so it broadcasts against the per-column ``k``.
        """
        if t.dim() == 1:
            t = t.unsqueeze(-1)  # [batch_size, 1]
        k = self.k().to(t.device)  # Shape: [num_categories]

        return -torch.log1p(
            -((1 - self.eps_max - self.eps_min) * t.pow(k) + self.eps_min)
        )

    def inverse_to_t(self, sigma):
        """
        Inverse function to map sigma back to t, with proper broadcasting support.
        sigma: [batch_size, num_categories] or [batch_size, 1]
        Returns: t: [batch_size, num_categories]

        Entries that evaluate to NaN are returned as 0.
        """
        k = self.k().to(sigma.device)  # Shape: [num_categories]

        t = (
            ((-sigma).exp() - 1 + self.eps_min) / (self.eps_max + self.eps_min - 1)
        ).pow(1 / k)

        # Replace NaN with 0 out of place: writing into the output of pow in
        # place would invalidate the value autograd saved for the gradient
        # with respect to k.
        return t.masked_fill(torch.isnan(t), 0.0)


class LogLinearUnified(nn.Module):
    """Log-linear noise schedule with a single learnable exponent ``k`` shared by all columns.

    ``k = softplus(k_raw) + k_offset`` where ``k_raw`` is a scalar parameter
    initialised to ``k_init``. Results are replicated across
    ``num_categories`` columns. Extra keyword arguments are accepted and ignored.
    """

    def __init__(
        self,
        num_categories,
        eps_max=1e-3,
        eps_min=1e-5,
        k_init=-6,
        k_offset=1,
        **kwargs,
    ):
        super().__init__()
        self.eps_max = eps_max
        self.eps_min = eps_min
        self.num_categories = num_categories
        self.k_offset = k_offset
        # Single shared parameter for all columns
        self.k_raw = nn.Parameter(torch.tensor(k_init, dtype=torch.float32))

    def k(self):
        """Return the shared exponent as a scalar tensor."""
        return torch.nn.functional.softplus(self.k_raw) + self.k_offset

    def rate_noise(self, t, noise_fn=None):
        """
        Compute rate noise for all categories with a shared parameter.
        t: [batch_size] or [batch_size, 1]
        Returns: [batch_size, num_categories]

        ``noise_fn`` is accepted but not used. A 1-D ``t`` is treated as a
        column first; repeating a 1-D result along the last axis would yield
        ``[1, batch_size * num_categories]``.
        """
        if t.dim() == 1:
            t = t.unsqueeze(-1)  # [batch_size, 1]
        k = self.k().to(t.device)  # Scalar value

        numerator = (1 - self.eps_max - self.eps_min) * k * t.pow(k - 1)
        denominator = 1 - ((1 - self.eps_max - self.eps_min) * t.pow(k) + self.eps_min)
        rate = numerator / denominator  # Shape: [batch_size, 1]

        # Tile the single column across num_categories columns.
        return rate.repeat(1, self.num_categories)

    def total_noise(self, t, noise_fn=None):
        """
        Compute total noise ``-log(1 - ((1 - eps_max - eps_min) * t^k + eps_min))``.
        t: [batch_size] or [batch_size, 1]
        Returns: [batch_size, num_categories]

        ``noise_fn`` is accepted but not used. A 1-D ``t`` is treated as a
        column first, for the same reason as in ``rate_noise``.
        """
        if t.dim() == 1:
            t = t.unsqueeze(-1)  # [batch_size, 1]
        k = self.k().to(t.device)  # Scalar value

        total_noise = -torch.log1p(
            -((1 - self.eps_max - self.eps_min) * t.pow(k) + self.eps_min)
        )  # Shape: [batch_size, 1]

        # Tile the single column across num_categories columns.
        return total_noise.repeat(1, self.num_categories)

    def inverse_to_t(self, sigma):
        """
        Inverse function to map sigma back to t, with shared parameter.
        sigma: [batch_size, num_categories] or [batch_size, 1]
        Returns: t: [batch_size, num_categories]

        Entries that evaluate to NaN are returned as 0.
        """
        k = self.k().to(sigma.device)  # Scalar value

        # With a shared k the columns of sigma are expected to agree, so a
        # multi-column sigma is collapsed to its per-row mean.
        if sigma.dim() > 1 and sigma.size(-1) > 1:
            sigma_unified = sigma.mean(dim=-1, keepdim=True)  # [batch_size, 1]
        else:
            sigma_unified = sigma

        t_unified = (
            ((-sigma_unified).exp() - 1 + self.eps_min)
            / (self.eps_max + self.eps_min - 1)
        ).pow(1 / k)  # Shape: [batch_size, 1] or [batch_size]

        # Replace NaN with 0 out of place: writing into the output of pow in
        # place would invalidate the value autograd saved for the gradient
        # with respect to k.
        t_unified = t_unified.masked_fill(torch.isnan(t_unified), 0.0)

        # Expand to match num_categories if needed
        if t_unified.dim() == 1:
            t_unified = t_unified.unsqueeze(-1)
        return t_unified.expand(-1, self.num_categories)


class NoNoise(Noise):
    """Placeholder schedule that always reports zero noise (used for dimension tables)."""

    def __init__(self, dim, **kwargs):
        super().__init__()
        # Clamp to at least one so the zero tensors always have a column.
        self.dim = max(dim, 1)

    def _zeros_like_batch(self, t):
        # One row per entry along t's first axis, self.dim columns, on t's device.
        n_rows = t.size(0)
        return torch.zeros((n_rows, self.dim), device=t.device)

    def total_noise(self, t):
        """Return zeros of shape ``[t.size(0), dim]``."""
        return self._zeros_like_batch(t)

    def rate_noise(self, t):
        """Return zeros of shape ``[t.size(0), dim]``."""
        return self._zeros_like_batch(t)

    def rho(self):
        """Return the constant 0 as a tensor."""
        return torch.tensor(0)

    def k(self):
        """Return the constant 0 as a tensor."""
        return torch.tensor(0)
