import torch
import torch.nn as nn
import numpy as np
from types import SimpleNamespace
from scipy.interpolate import interp1d

class EpsFromScore(nn.Module):
    """Wrap a score model so it can be used as a noise-prediction (eps) model.

    The wrapped ``score_model`` is called with continuous times in ``[0, 1]``
    and must expose ``num_timesteps`` and a ``sigma_fn`` that maps a scalar
    time to a noise level. The epsilon prediction is recovered as
    ``eps = -sigma(t) * score``. It is returned as ``SimpleNamespace(sample=...)``,
    presumably so callers can read ``.sample`` as they would from a
    diffusers-style eps model output (confirm against callers).
    """

    def __init__(self, score_model):
        super().__init__()
        self.score_model = score_model
        self.num_timesteps = score_model.num_timesteps

    def forward(self, x, t):
        """Predict the noise ``eps`` for ``x`` at integer timestep ``t``.

        Args:
            x: Noisy input, passed unchanged to the score model.
            t: Integer timestep, as a Python number or a single-element tensor.
                It is divided by ``num_timesteps`` to obtain a continuous time,
                nominally in ``[0, 1]``.

        Returns:
            ``SimpleNamespace`` whose ``sample`` attribute is ``-sigma * score``.
        """
        # Still a tensor if t was a tensor; the score model receives it as-is.
        t_scaled = t / self.num_timesteps
        score_val = self.score_model(x, t_scaled)
        # sigma_fn is evaluated on a Python scalar; .item() requires t to hold
        # exactly one element (batched timesteps are not supported here).
        t_scalar = t_scaled.item() if isinstance(t_scaled, torch.Tensor) else t_scaled
        # Cast to a plain float so the product stays a torch tensor with
        # score_val's dtype/device instead of mixing a numpy scalar or 0-d array
        # into the torch expression.
        sigma = float(self.score_model.sigma_fn(t_scalar))
        return SimpleNamespace(sample=-sigma * score_val)
    
class ScoreFromEps(nn.Module):
    """Wrap a discrete-time noise-prediction (eps) model as a score model.

    Assumes the variance-preserving (DDPM) forward process
    ``x_t = sqrt(abar_t) * x_0 + sqrt(1 - abar_t) * eps``, where ``abar_t`` is
    taken from ``scheduler.alphas_cumprod``. Under that process the score of the
    perturbation kernel is ``-eps / sqrt(1 - abar_t)``, so ``sigma_t`` is
    ``sqrt(1 - abar_t)``.

    Attributes:
        eps_model: Model called as ``eps_model(x, timesteps)``; its output must
            expose ``.sample``.
        num_timesteps: ``scheduler.config.num_train_timesteps`` (``T``).
        sigma_fn: Linear interpolant of ``sigma`` over continuous time, with
            discrete step ``i`` placed at ``i / T``.
    """

    def __init__(self, eps_model, scheduler):
        super().__init__()
        self.eps_model = eps_model
        self.num_timesteps = scheduler.config.num_train_timesteps
        alpha_bars_np = scheduler.alphas_cumprod.cpu().numpy()
        # Std of the forward-process noise at each discrete step (VP / DDPM).
        sigmas_np = np.sqrt(1 - alpha_bars_np)
        # Place step i at continuous time i / T. This is the same mapping forward()
        # uses (floor(t * T)) and the one EpsFromScore uses (t / T), so
        # sigma_fn(i / T) equals sigmas_np[i]. A linspace(0, 1, N) grid would put
        # step i at i / (N - 1), leaving sigma out of step with the eps index.
        _ts = np.arange(len(alpha_bars_np)) / self.num_timesteps
        self.sigma_fn = interp1d(_ts, sigmas_np, kind="linear", fill_value="extrapolate")

    def forward(self, x, t):
        """Return the score ``-eps / sigma(t)`` for ``x`` at continuous time ``t``.

        Args:
            x: Noisy input; cast to float32 before calling ``eps_model``.
            t: Continuous time, nominally in ``[0, 1]``, as a Python number or a
                single-element tensor. It is mapped to the discrete step
                ``floor(t * T)``, clamped to ``[0, T - 1]``.

        Returns:
            Tensor ``-eps / sigma``, where ``sigma`` is interpolated at ``t``.
        """
        T = self.num_timesteps
        t_float = float(t.item()) if isinstance(t, torch.Tensor) else float(t)
        # The small tolerance absorbs float round-off. For example,
        # 0.29 * 100 == 28.999999999999996, which plain int() truncates to 28,
        # so a time given as i / T would otherwise land on step i - 1.
        step = int(np.floor(t_float * T + 1e-6))
        step = min(max(step, 0), T - 1)
        timesteps = torch.tensor([step], dtype=torch.long, device=x.device)
        eps = self.eps_model(x.float(), timesteps).sample
        sigma = float(self.sigma_fn(t_float))
        return -eps / sigma
