import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from scipy.interpolate import interp1d
import math

class SimpleLoss(nn.Module):
    """Per-sample epsilon-prediction (denoising) loss at a single continuous time ``t``.

    For each input ``z`` the loss uses ``num_gaussian_samples`` noise tensors ``eps``.
    It forms ``x_t = sqrt(alpha_bar(t)) * z + sqrt(1 - alpha_bar(t)) * eps`` and asks
    ``eps_predictor`` to recover ``eps`` from ``x_t``. It returns the squared error
    averaged over the noise draws and all non-batch dimensions.

    Args:
        eps_predictor: Callable invoked as ``eps_predictor(x_t, timestep_index)``.
            Its return value must expose the predicted noise as ``.sample``
            (presumably a diffusers-style UNet; confirm against callers).
        scheduler: Object exposing ``config.num_train_timesteps`` and an
            ``alphas_cumprod`` tensor (the cumulative-product alpha table).
        num_gaussian_samples: Number ``S`` of noise draws averaged per input.
    """

    def __init__(self, eps_predictor, scheduler, num_gaussian_samples=3):
        super().__init__()
        self.eps_predictor = eps_predictor
        self.num_gaussian_samples = num_gaussian_samples
        self.num_timesteps = scheduler.config.num_train_timesteps

        # Map continuous t in [0, 1] onto the discrete alpha_bar table.
        # Table entry i sits at t = i / (len - 1), so t=0 -> first entry and
        # t=1 -> last entry; values in between are linearly interpolated.
        alpha_bars_np = scheduler.alphas_cumprod.cpu().numpy()
        _ts = np.linspace(0, 1, len(alpha_bars_np))
        self.alpha_bar_fn = interp1d(_ts, alpha_bars_np, kind="linear", fill_value="extrapolate")

    def forward(self, z, t, epsilon=None):
        """Compute the per-sample denoising loss.

        Args:
            z: Tensor of shape ``(B, *)``; both images ``(B, C, H, W)`` and flat
                vectors ``(B, D)`` or ``(B,)`` are supported.
            t: Scalar time in ``[0, 1]`` (a Python float, NumPy scalar or 0-dim tensor).
            epsilon: Optional noise of shape ``(S, B, *)`` with
                ``S = num_gaussian_samples``. It is sampled with ``torch.randn_like``
                when omitted. It may be non-contiguous.

        Returns:
            Tensor of shape ``(B,)``: the MSE averaged over the ``S`` noise draws and
            all sample dimensions.

        Raises:
            ValueError: if ``t`` lies outside ``[0, 1]`` or ``epsilon`` has the
                wrong shape.
        """
        t = float(t)
        # Reject out-of-range t instead of silently extrapolating alpha_bar
        # (which can leave [0, 1] and crash math.sqrt with an opaque error).
        if not 0.0 <= t <= 1.0:
            raise ValueError(f"t must lie in [0, 1], got {t}")

        B = z.shape[0]
        S = self.num_gaussian_samples
        sample_shape = z.shape[1:]  # (C, H, W), (D,), or () for flat inputs

        alpha_bar = float(self.alpha_bar_fn(t))
        sqrt_alpha_bar = math.sqrt(alpha_bar)
        sqrt_one_minus_alpha_bar = math.sqrt(1.0 - alpha_bar)

        # Broadcast z across the S noise draws without copying: (S, B, ...).
        z_expand = z.unsqueeze(0).expand(S, *z.shape)
        if epsilon is None:
            epsilon = torch.randn_like(z_expand)
        elif epsilon.shape != z_expand.shape:
            # Explicit raise rather than assert: asserts vanish under `python -O`.
            raise ValueError(
                f"Shape of epsilon must match shape of z expanded to "
                f"{tuple(z_expand.shape)}, got {tuple(epsilon.shape)}"
            )
        x_t = sqrt_alpha_bar * z_expand + sqrt_one_minus_alpha_bar * epsilon

        # reshape (not view): a caller-supplied epsilon, and hence x_t, may be
        # non-contiguous, which makes .view raise.
        x_t = x_t.reshape(S * B, *sample_shape)
        epsilon = epsilon.reshape(S * B, *sample_shape)

        # Use the same grid as alpha_bar_fn: the noise level above corresponds
        # to table index t * (T - 1), so pass the nearest integer index.
        # (int(t * T) disagreed with that grid and yielded the invalid index T
        # at t == 1.) Assumes len(alphas_cumprod) == num_train_timesteps.
        t_index = round(t * (self.num_timesteps - 1))

        epsilon_pred = self.eps_predictor(x_t, t_index).sample

        # Squared error per element, then average over every sample dim.
        # Flattening to (S, B, numel) also covers flat (B,) inputs, where
        # sample_shape is empty and numel() == 1.
        loss = F.mse_loss(epsilon_pred, epsilon, reduction="none")
        loss = loss.reshape(S, B, sample_shape.numel()).mean(dim=2)

        return loss.mean(dim=0)  # average over the S noise draws -> (B,)