# Modified from OpenAI's diffusion repos
#     GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
#     ADM:   https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
#     IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py


import math

import numpy as np
import torch as th
import enum

# from . import dist_util
from .diffusion_utils import normal_kl, continuous_gaussian_log_likelihood, discretized_gaussian_log_likelihood


def mean_flat(tensor):
    """Average ``tensor`` over every axis except the leading (batch) axis."""
    reduce_dims = list(range(1, tensor.ndim))
    return tensor.mean(dim=reduce_dims)


class ModelMeanType(enum.Enum):
    """Which quantity the network output is interpreted as."""

    PREVIOUS_X = enum.auto()  # the posterior mean of x_{t-1}, used directly as the model mean
    START_X = enum.auto()  # the clean sample x_0
    EPSILON = enum.auto()  # the noise that was mixed into x_t


class ModelVarType(enum.Enum):
    """How the reverse-process variance is obtained."""

    # Extra output channels are read directly as the log-variance.
    LEARNED = enum.auto()
    # Fixed to the true posterior variance of q(x_{t-1} | x_t, x_0).
    FIXED_SMALL = enum.auto()
    # Fixed to the betas (step 0 borrows the step-1 posterior variance).
    FIXED_LARGE = enum.auto()
    # Extra output channels in [-1, 1] interpolate between the log posterior
    # variance and log beta.
    LEARNED_RANGE = enum.auto()


class LossType(enum.Enum):
    """Training objective selector."""

    MSE = enum.auto()  # MSE on the prediction target (plus a VB term when the variance is learned)
    RESCALED_MSE = enum.auto()  # as MSE, with the VB term multiplied by num_timesteps / 1000
    KL = enum.auto()  # variational bound; not supported by multimodal_training_losses
    RESCALED_KL = enum.auto()  # rescaled variational bound; also unsupported there

    def is_vb(self):
        """Return True when this objective is one of the variational-bound (KL) losses."""
        return self in (LossType.KL, LossType.RESCALED_KL)


def _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, warmup_frac):
    betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
    warmup_time = int(num_diffusion_timesteps * warmup_frac)
    betas[:warmup_time] = np.linspace(beta_start, beta_end, warmup_time, dtype=np.float64)
    return betas


def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps):
    """Build a beta schedule by name (deprecated helper).

    :param beta_schedule: one of "quad", "linear", "warmup10", "warmup50",
        "const" or "jsd".
    :param beta_start: first beta (unused by "const" and "jsd").
    :param beta_end: final / plateau beta (unused by "jsd").
    :param num_diffusion_timesteps: number of steps, i.e. the output length.
    :return: float64 numpy array of shape ``(num_diffusion_timesteps,)``.
    :raises NotImplementedError: when ``beta_schedule`` is not recognised.
    """
    steps = num_diffusion_timesteps
    # Ordered (name, builder) pairs; builders are lazy so only the chosen one runs.
    builders = (
        ("quad", lambda: np.linspace(beta_start ** 0.5, beta_end ** 0.5, steps, dtype=np.float64) ** 2),
        ("linear", lambda: np.linspace(beta_start, beta_end, steps, dtype=np.float64)),
        ("warmup10", lambda: _warmup_beta(beta_start, beta_end, steps, 0.1)),
        ("warmup50", lambda: _warmup_beta(beta_start, beta_end, steps, 0.5)),
        ("const", lambda: beta_end * np.ones(steps, dtype=np.float64)),
        # 1/T, 1/(T-1), 1/(T-2), ..., 1
        ("jsd", lambda: 1.0 / np.linspace(steps, 1, steps, dtype=np.float64)),
    )
    for name, build in builders:
        if beta_schedule == name:
            betas = build()
            break
    else:
        raise NotImplementedError(beta_schedule)
    assert betas.shape == (steps,)
    return betas


def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
    """Return a predefined beta schedule with ``num_diffusion_timesteps`` entries.

    :param schedule_name: "linear" (the Ho et al. 1e-4..0.02 range rescaled by
        ``1000 / num_diffusion_timesteps``) or "squaredcos_cap_v2" (cosine
        alpha-bar discretised by ``betas_for_alpha_bar``).
    :param num_diffusion_timesteps: number of diffusion steps.
    :raises NotImplementedError: for any other name.
    """
    if schedule_name == "linear":
        # Rescale so that step counts other than 1000 span a comparable range.
        scale = 1000 / num_diffusion_timesteps
        return get_beta_schedule(
            "linear",
            beta_start=scale * 0.0001,
            beta_end=scale * 0.02,
            num_diffusion_timesteps=num_diffusion_timesteps,
        )

    if schedule_name == "squaredcos_cap_v2":
        def cosine_alpha_bar(t):
            return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

        return betas_for_alpha_bar(num_diffusion_timesteps, cosine_alpha_bar)

    raise NotImplementedError(f"unknown beta schedule: {schedule_name}")


def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
    """Turn a continuous cumulative-alpha curve into per-step betas.

    Step ``i`` covers ``t`` from ``i / T`` to ``(i + 1) / T``. Its beta is
    ``1 - alpha_bar(end) / alpha_bar(start)``, capped at ``max_beta``.

    :param num_diffusion_timesteps: number of steps T.
    :param alpha_bar: callable mapping t in [0, 1] to the cumulative alpha product.
    :param max_beta: upper bound applied to every beta.
    :return: numpy array of T betas.
    """
    def step_beta(i):
        start = i / num_diffusion_timesteps
        end = (i + 1) / num_diffusion_timesteps
        return min(1 - alpha_bar(end) / alpha_bar(start), max_beta)

    return np.array([step_beta(i) for i in range(num_diffusion_timesteps)])


class MultimodalGaussianDiffusion:
    """Diffusion helpers for paired "architecture" and "weight" tensors.

    Per-sample quantities are passed around as dicts keyed by
    ``"architecture"`` and ``"weight"``. Each modality has its own timestep
    tensor (``t_arch`` / ``t_weight``). This lets one modality be held at
    t = 0 while the other is noised or denoised; see
    ``multimodal_training_losses`` and ``monl_conditional_p_sample_loop``.

    :param betas: 1-D sequence of per-step noise variances, each in (0, 1].
    :param model_mean_type: a ModelMeanType selecting what the network predicts.
    :param model_var_type: a ModelVarType selecting how the variance is obtained.
    :param loss_type: a LossType selecting the training objective.
    """

    def __init__(
        self,
        *,
        betas,
        model_mean_type,
        model_var_type,
        loss_type
    ):

        self.model_mean_type = model_mean_type
        self.model_var_type = model_var_type
        self.loss_type = loss_type

        # Use float64 for accuracy; tables are cast to float32 only when
        # gathered by _extract_into_tensor.
        betas = np.array(betas, dtype=np.float64)
        self.betas = betas
        assert len(betas.shape) == 1, "betas must be 1-D"
        assert (betas > 0).all() and (betas <= 1).all()

        self.num_timesteps = int(betas.shape[0])

        alphas = 1.0 - betas
        self.alphas_cumprod = np.cumprod(alphas, axis=0)
        self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
        self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0)
        assert self.alphas_cumprod_prev.shape == (self.num_timesteps,)

        # Tables for the forward process q(x_t | x_0) and for inverting it.
        self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)
        self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod)
        self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
        self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)

        # Tables for the posterior q(x_{t-1} | x_t, x_0).
        self.posterior_variance = (
            betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
        )
        # The posterior variance is 0 at t = 0, so its log would be -inf; the
        # t = 0 entry is replaced with the t = 1 value before taking the log.
        self.posterior_log_variance_clipped = np.log(
            np.append(self.posterior_variance[1], self.posterior_variance[1:])
        ) if len(self.posterior_variance) > 1 else np.array([])

        self.posterior_mean_coef1 = (
            betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
        )
        self.posterior_mean_coef2 = (
            (1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod)
        )

    def q_mean_variance(self, x_start, t):
        """Return the mean, variance and log-variance of q(x_t | x_0).

        :param x_start: clean tensor x_0; ``t`` indexes its leading dimension.
        :param t: 1-D tensor of timestep indices, one per leading-dim element.
        :return: ``(mean, variance, log_variance)``, each broadcast to ``x_start.shape``.
        """
        mean = _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
        variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
        log_variance = _extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)
        return mean, variance, log_variance

    def q_sample(self, x_start, t, noise=None):
        """Sample x_t ~ q(x_t | x_0) as sqrt(abar_t) * x_0 + sqrt(1 - abar_t) * noise.

        :param x_start: clean tensor x_0.
        :param t: 1-D tensor of timestep indices.
        :param noise: Gaussian noise shaped like ``x_start``; drawn when None.
        :return: the noised tensor x_t.
        """
        if noise is None:
            noise = th.randn_like(x_start)
        assert noise.shape == x_start.shape
        return (
            _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
            + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
        )

    def q_posterior_mean_variance(self, x_start, x_t, t):
        """Return mean, variance and clipped log-variance of q(x_{t-1} | x_t, x_0).

        :param x_start: clean tensor x_0.
        :param x_t: noised tensor at step ``t`` (same shape as ``x_start``).
        :param t: 1-D tensor of timestep indices.
        :return: ``(posterior_mean, posterior_variance, posterior_log_variance_clipped)``.
        """
        assert x_start.shape == x_t.shape
        posterior_mean = (
            _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
            + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
        )
        posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape)
        posterior_log_variance_clipped = _extract_into_tensor(
            self.posterior_log_variance_clipped, t, x_t.shape
        )
        assert (
            posterior_mean.shape[0]
            == posterior_variance.shape[0]
            == posterior_log_variance_clipped.shape[0]
            == x_start.shape[0]
        )
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def p_mean_variance(self, model, x, t_arch, t_weight, clip_denoised=False, denoised_fn=None, model_kwargs=None):
        """Run the model once and derive p(x_{t-1} | x_t) statistics for both modalities.

        :param model: callable ``model(x_arch, x_weight, t_arch, t_weight, **model_kwargs)``
            returning ``(architecture_output, weight_output)``. For learned
            variances, each output is twice the size of its input along dim 2.
        :param x: dict with ``"architecture"`` and ``"weight"`` tensors at their timesteps.
        :param t_arch: 1-D architecture timesteps of length B.
        :param t_weight: 1-D weight timesteps of length B.
        :param clip_denoised: if True, clamp the predicted x_0 to [-1, 1].
        :param denoised_fn: optional callable applied to the predicted x_0 before clamping.
        :param model_kwargs: extra keyword arguments forwarded to ``model``.
        :return: dict with keys ``"mean"``, ``"variance"``, ``"log_variance"``,
            ``"pred_xstart"`` and ``"model_predict"``. Each value is an
            ``{"architecture": ..., "weight": ...}`` dict. ``"model_predict"``
            holds the raw, unsplit model outputs.
        """
        if model_kwargs is None:
            model_kwargs = {}

        B = x["architecture"].shape[0]
        assert t_weight.shape == t_arch.shape == (B,)

        architecture_output, weight_output = model(
            x["architecture"], x["weight"], t_arch, t_weight, **model_kwargs
        )

        def process_xstart(x_0):
            if denoised_fn is not None:
                x_0 = denoised_fn(x_0)
            if clip_denoised:
                return x_0.clamp(-1, 1)
            return x_0

        def modality_stats(model_output, x_t, t_x):
            """Return ``(mean, variance, log_variance, pred_xstart)`` for one modality."""
            split_dim = 2
            if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]:
                # The second half of dim 2 carries the variance prediction.
                assert model_output.shape[split_dim] == x_t.shape[split_dim] * 2
                model_output, model_var_values = th.split(model_output, x_t.shape[split_dim], dim=split_dim)
                if self.model_var_type == ModelVarType.LEARNED:
                    model_log_variance = model_var_values
                    model_variance = th.exp(model_log_variance)
                else:
                    min_log = _extract_into_tensor(self.posterior_log_variance_clipped, t_x, x_t.shape)
                    max_log = _extract_into_tensor(np.log(self.betas), t_x, x_t.shape)
                    # model_var_values in [-1, 1] interpolates linearly between min_log and max_log.
                    frac = (model_var_values + 1) / 2
                    model_log_variance = frac * max_log + (1 - frac) * min_log
                    model_variance = th.exp(model_log_variance)
            else:
                model_variance, model_log_variance = {
                    ModelVarType.FIXED_LARGE: (
                        np.append(self.posterior_variance[1], self.betas[1:]),
                        np.log(np.append(self.posterior_variance[1], self.betas[1:])),
                    ),
                    ModelVarType.FIXED_SMALL: (
                        self.posterior_variance,
                        self.posterior_log_variance_clipped,
                    ),
                }[self.model_var_type]
                model_variance = _extract_into_tensor(model_variance, t_x, x_t.shape)
                model_log_variance = _extract_into_tensor(model_log_variance, t_x, x_t.shape)

            if self.model_mean_type == ModelMeanType.PREVIOUS_X:
                pred_xstart = process_xstart(self._predict_xstart_from_xprev(x_t=x_t, t=t_x, xprev=model_output))
                model_mean = model_output
            elif self.model_mean_type == ModelMeanType.START_X:
                pred_xstart = process_xstart(model_output)
                model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x_t, t=t_x)
            else:  # ModelMeanType.EPSILON
                pred_xstart = process_xstart(
                    self._predict_xstart_from_eps(x_t=x_t, t=t_x, eps=model_output)
                )
                model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x_t, t=t_x)

            return model_mean, model_variance, model_log_variance, pred_xstart

        architecture_mean, architecture_variance, architecture_log_variance, architecture_pred_xstart = modality_stats(
            architecture_output, x["architecture"], t_arch
        )
        weight_mean, weight_variance, weight_log_variance, weight_pred_xstart = modality_stats(
            weight_output, x["weight"], t_weight
        )

        return {
            "mean": {"architecture": architecture_mean, "weight": weight_mean},
            "variance": {"architecture": architecture_variance, "weight": weight_variance},
            "log_variance": {"architecture": architecture_log_variance, "weight": weight_log_variance},
            "pred_xstart": {"architecture": architecture_pred_xstart, "weight": weight_pred_xstart},
            "model_predict": {"architecture": architecture_output, "weight": weight_output},
        }

    def _predict_xstart_from_eps(self, x_t, t, eps):
        """Invert q(x_t | x_0): x_0 = x_t / sqrt(abar_t) - sqrt(1 / abar_t - 1) * eps."""
        assert x_t.shape == eps.shape
        return (
            _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
            - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
        )

    def _predict_xstart_from_xprev(self, x_t, t, xprev):
        """Recover x_0 from a posterior mean: x_0 = (xprev - coef2 * x_t) / coef1."""
        assert x_t.shape == xprev.shape
        return (
            _extract_into_tensor(1.0 / self.posterior_mean_coef1, t, x_t.shape) * xprev
            - _extract_into_tensor(
                self.posterior_mean_coef2 / self.posterior_mean_coef1, t, x_t.shape
            )
            * x_t
        )

    def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
        """Recover the noise implied by x_t and a predicted x_0."""
        return (
            _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart
        ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)

    def p_sample(
        self,
        model,
        x,
        t_arch,
        t_weight,
        clip_denoised=False,
        denoised_fn=None,
        cond_fn=None,
        model_kwargs=None,
        noise=None,
    ):
        """Draw x_{t-1} for each modality from the model at the given timesteps.

        :param model: callable with the contract described in ``p_mean_variance``.
        :param x: dict with the current ``"architecture"`` and ``"weight"`` tensors.
        :param t_arch: 1-D architecture timesteps; if None, ``t_weight`` is used.
        :param t_weight: 1-D weight timesteps.
        :param clip_denoised: clamp the predicted x_0 to [-1, 1].
        :param denoised_fn: optional callable applied to the predicted x_0.
        :param cond_fn: not supported yet; passing one only emits a warning.
        :param model_kwargs: extra keyword arguments forwarded to ``model``.
        :param noise: optional dict with ``"architecture"`` and/or ``"weight"``
            tensors used as this step's Gaussian noise. Missing entries, or all
            of them when ``noise`` is None, are drawn with ``th.randn_like``.
        :return: dict with ``"sample"`` and ``"pred_xstart"``, each an
            ``{"architecture": ..., "weight": ...}`` dict. No noise is added to
            elements whose timestep is 0.
        """
        t_arch = t_weight if t_arch is None else t_arch

        out = self.p_mean_variance(
            model,
            x,
            t_arch,
            t_weight,
            clip_denoised=clip_denoised,
            denoised_fn=denoised_fn,
            model_kwargs=model_kwargs,
        )

        if noise is None:
            noise = {}
        # Fresh noise is drawn architecture-first, then weight, so RNG consumption order is fixed.
        architecture_noise = noise.get("architecture")
        if architecture_noise is None:
            architecture_noise = th.randn_like(x["architecture"])
        weight_noise = noise.get("weight")
        if weight_noise is None:
            weight_noise = th.randn_like(x["weight"])

        # Masks are 0 where t == 0 (final step: return the mean, no noise).
        architecture_nonzero_mask = (
            (t_arch != 0).float().view(-1, *([1] * (len(x["architecture"].shape) - 1)))
        )
        weight_nonzero_mask = (
            (t_weight != 0).float().view(-1, *([1] * (len(x["weight"].shape) - 1)))
        )

        if cond_fn is not None:
            import warnings

            warnings.warn("Conditioned sampling not implemented yet for NiT diffusion.", stacklevel=2)

        architecture_sample = (
            out["mean"]["architecture"]
            + architecture_nonzero_mask * th.exp(0.5 * out["log_variance"]["architecture"]) * architecture_noise
        )
        weight_sample = (
            out["mean"]["weight"]
            + weight_nonzero_mask * th.exp(0.5 * out["log_variance"]["weight"]) * weight_noise
        )

        return {
            "sample": {"architecture": architecture_sample, "weight": weight_sample},
            "pred_xstart": {
                "architecture": out["pred_xstart"]["architecture"],
                "weight": out["pred_xstart"]["weight"],
            },
        }

    def _vb_terms_bpd(
            self, model, x_start, x_t, t_arch, t_weight, clip_denoised=False, model_kwargs=None
    ):
        """Compute per-modality variational-bound terms, divided by log(2).

        For each modality the term depends on that modality's timestep. Where
        it is nonzero, the term is the KL between the true posterior
        q(x_{t-1} | x_t, x_0) and the model's p(x_{t-1} | x_t). Where it is 0,
        the term is the decoder negative log-likelihood of x_0. Both are
        averaged over non-batch dimensions first.

        :param model: callable with the contract described in ``p_mean_variance``.
        :param x_start: dict of clean tensors.
        :param x_t: dict of noised tensors.
        :param t_arch: 1-D architecture timesteps.
        :param t_weight: 1-D weight timesteps.
        :param clip_denoised: forwarded to ``p_mean_variance``.
        :param model_kwargs: forwarded to ``p_mean_variance``.
        :return: dict with ``"output"`` (per-modality 1-D loss tensors) and
            ``"pred_xstart"`` (per-modality predicted x_0).
        """
        timesteps = {"architecture": t_arch, "weight": t_weight}

        true_mean = {}
        true_log_variance_clipped = {}
        for key in ("architecture", "weight"):
            true_mean[key], _, true_log_variance_clipped[key] = self.q_posterior_mean_variance(
                x_start=x_start[key], x_t=x_t[key], t=timesteps[key]
            )

        out = self.p_mean_variance(
            model, x_t, t_arch, t_weight, clip_denoised=clip_denoised, model_kwargs=model_kwargs
        )

        output = {}
        for key in ("architecture", "weight"):
            kl = normal_kl(
                true_mean[key], true_log_variance_clipped[key], out["mean"][key], out["log_variance"][key]
            )
            kl = mean_flat(kl) / np.log(2.0)

            log_scales = 0.5 * out["log_variance"][key]
            if key == "architecture":
                # NOTE(review): bin_width=1/5 presumably matches the spacing of the
                # discretised architecture encoding — confirm against the data pipeline.
                decoder_nll = -discretized_gaussian_log_likelihood(
                    x_start[key], means=out["mean"][key], log_scales=log_scales, bin_width=1.0 / 5.0
                )
            else:
                decoder_nll = -continuous_gaussian_log_likelihood(
                    x_start[key], means=out["mean"][key], log_scales=log_scales
                )
            assert decoder_nll.shape == x_start[key].shape
            decoder_nll = mean_flat(decoder_nll) / np.log(2.0)

            # t == 0 is the decoder step; every other step uses the KL term.
            output[key] = th.where(timesteps[key] == 0, decoder_nll, kl)

        return {"output": output, "pred_xstart": out["pred_xstart"]}

    def multimodal_training_losses(self, model, x_start, model_kwargs=None, noise=None):
        """Compute a single-step MoNL training loss mixing joint and conditional cases.

        Each batch element is assigned, with probability 0.5, to one of two cases:
          * conditional: the architecture is the exact clean ``x_start`` at
            t = 0 and its target noise is zero;
          * joint: both modalities share one random timestep.

        :param model: callable ``model(x_arch, x_weight, t_arch, t_weight, **model_kwargs)``
            returning ``(architecture_output, weight_output)``.
        :param x_start: dict with clean ``"architecture"`` and ``"weight"`` tensors.
        :param model_kwargs: extra keyword arguments forwarded to ``model``.
        :param noise: optional dict with ``"architecture"`` and ``"weight"`` noise
            tensors. Drawn with ``th.randn_like`` when None. The caller's dict is
            never modified.
        :return: dict with per-sample ``"loss"`` (sum of all terms) plus
            ``"mse_architecture"`` / ``"mse_weight"``. Learned variances also add
            ``"vb_architecture"`` / ``"vb_weight"``.
        :raises NotImplementedError: for KL / RESCALED_KL or any other unsupported loss type.
        """
        # Reject unsupported objectives before consuming RNG or running the model.
        # An explicit raise (not ``assert``) still fires under ``python -O``.
        if self.loss_type in (LossType.KL, LossType.RESCALED_KL):
            raise NotImplementedError("Currently not implemented")
        if self.loss_type not in (LossType.MSE, LossType.RESCALED_MSE):
            raise NotImplementedError(self.loss_type)

        if model_kwargs is None:
            model_kwargs = {}

        B = x_start["weight"].shape[0]
        device = x_start["weight"].device

        # Per-sample coin flip: conditional samples hold the architecture at t = 0.
        is_conditional = th.rand(B, device=device) < 0.5
        t_ref = th.randint(0, self.num_timesteps, (B,), device=device, dtype=th.long)
        t_weight = t_ref
        t_arch = th.where(is_conditional, th.zeros_like(t_ref), t_ref)

        if noise is None:
            noise = {
                "architecture": th.randn_like(x_start["architecture"]),
                "weight": th.randn_like(x_start["weight"]),
            }
        else:
            # Shallow copy: the conditional zeroing below must not leak into the caller's dict.
            noise = dict(noise)

        architecture_x_start = x_start["architecture"]
        weight_x_start = x_start["weight"]

        architecture_x_t = self.q_sample(architecture_x_start, t=t_arch, noise=noise["architecture"])
        weight_x_t = self.q_sample(weight_x_start, t=t_weight, noise=noise["weight"])

        if is_conditional.any():
            conditional_mask = is_conditional.reshape(B, *([1] * (architecture_x_start.dim() - 1)))
            # Conditional samples see the exact clean architecture and have a zero noise target.
            architecture_x_t = th.where(conditional_mask, architecture_x_start, architecture_x_t)
            noise["architecture"] = th.where(
                conditional_mask, th.zeros_like(noise["architecture"]), noise["architecture"]
            )

        architecture_output, weight_output = model(architecture_x_t, weight_x_t, t_arch, t_weight, **model_kwargs)

        architecture_loss = {}
        weight_loss = {}

        if self.model_var_type in [
            ModelVarType.LEARNED,
            ModelVarType.LEARNED_RANGE,
        ]:
            # Outputs carry [prediction | variance values] along dim 2.
            architecture_model_output, architecture_model_var_values = th.split(
                architecture_output, architecture_x_start.shape[2], dim=2
            )
            weight_model_output, weight_model_var_values = th.split(
                weight_output, weight_x_start.shape[2], dim=2
            )

            # Detach the prediction half so the VB term's gradient reaches only the variance half.
            architecture_frozen_out = th.cat([architecture_model_output.detach(), architecture_model_var_values], dim=2)
            weight_frozen_out = th.cat([weight_model_output.detach(), weight_model_var_values], dim=2)
            frozen_out = {"architecture": architecture_frozen_out, "weight": weight_frozen_out}
            x_t = {"architecture": architecture_x_t, "weight": weight_x_t}

            vb_loss = self._vb_terms_bpd(
                model=lambda *args, r=frozen_out, **kwargs: [r["architecture"], r["weight"]],
                x_start=x_start,
                x_t=x_t,
                t_arch=t_arch,
                t_weight=t_weight,
                clip_denoised=False,
                model_kwargs=model_kwargs,
            )["output"]

            architecture_loss["vb"] = vb_loss["architecture"]
            weight_loss["vb"] = vb_loss["weight"]

            if self.loss_type == LossType.RESCALED_MSE:
                architecture_loss["vb"] *= self.num_timesteps / 1000.0
                weight_loss["vb"] *= self.num_timesteps / 1000.0
        else:
            weight_model_output = weight_output
            architecture_model_output = architecture_output

        architecture_target = {
            ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance(
                x_start=architecture_x_start, x_t=architecture_x_t, t=t_arch
            )[0],
            ModelMeanType.START_X: architecture_x_start,
            ModelMeanType.EPSILON: noise["architecture"],
        }[self.model_mean_type]

        weight_target = {
            ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance(
                x_start=weight_x_start, x_t=weight_x_t, t=t_weight
            )[0],
            ModelMeanType.START_X: weight_x_start,
            ModelMeanType.EPSILON: noise["weight"],
        }[self.model_mean_type]

        architecture_loss["mse"] = mean_flat((architecture_target - architecture_model_output) ** 2)
        weight_loss["mse"] = mean_flat((weight_target - weight_model_output) ** 2)

        term = {"loss": 0}
        for key in weight_loss.keys():
            term[f"{key}_architecture"] = architecture_loss[key]
            term[f"{key}_weight"] = weight_loss[key]
            term["loss"] += term[f"{key}_weight"] + term[f"{key}_architecture"]

        return term

    def monl_joint_p_sample_loop(
        self,
        model,
        shape,
        noise=None,
        clip_denoised=False,
        denoised_fn=None,
        cond_fn=None,
        model_kwargs=None,
        device=None,
        progress=False
    ):
        """Jointly sample both modalities, stepping them with matched timesteps.

        :param model: callable with the contract described in ``p_mean_variance``.
        :param shape: dict mapping ``"architecture"`` / ``"weight"`` to full tensor shapes.
        :param noise: optional dict of starting tensors (x_T) per modality; missing
            entries are drawn with ``th.randn``.
        :param clip_denoised: clamp predicted x_0 to [-1, 1] at every step.
        :param denoised_fn: optional callable applied to predicted x_0 at every step.
        :param cond_fn: not supported yet; forwarded to ``p_sample``, which warns.
        :param model_kwargs: extra keyword arguments for ``model``.
        :param device: target device; defaults to the device of the model's first parameter.
        :param progress: show a tqdm progress bar.
        :return: dict with the final ``"architecture"`` and ``"weight"`` samples.
        """
        return self._monl_joint_sampling(
            model, shape, model_kwargs, progress, device,
            noise=noise, clip_denoised=clip_denoised, denoised_fn=denoised_fn, cond_fn=cond_fn,
        )

    def monl_conditional_p_sample_loop(
        self,
        model,
        architecture_given,
        shape,
        noise=None,
        clip_denoised=False,
        denoised_fn=None,
        cond_fn=None,
        model_kwargs=None,
        device=None,
        progress=False
    ):
        """Sample weights conditioned on a clean, fixed architecture tensor.

        :param model: callable with the contract described in ``p_mean_variance``.
        :param architecture_given: clean architecture tensor, held fixed at t = 0.
        :param shape: dict whose ``"weight"`` entry is the full weight tensor shape.
        :param noise: optional dict whose ``"weight"`` entry is the starting x_T
            for the weights; drawn with ``th.randn`` when absent.
        :param clip_denoised: clamp predicted x_0 to [-1, 1] at every step.
        :param denoised_fn: optional callable applied to predicted x_0 at every step.
        :param cond_fn: not supported yet; forwarded to ``p_sample``, which warns.
        :param model_kwargs: extra keyword arguments for ``model``.
        :param device: target device; defaults to the device of the model's first parameter.
        :param progress: show a tqdm progress bar.
        :return: dict with ``"architecture"`` (``architecture_given``) and the sampled ``"weight"``.
        """
        return self._monl_conditional_sampling(
            model, architecture_given, shape, model_kwargs, progress, device,
            noise=noise, clip_denoised=clip_denoised, denoised_fn=denoised_fn, cond_fn=cond_fn,
        )

    def _monl_joint_sampling(self, model, shape, model_kwargs=None,
                             progress=False, device=None, noise=None,
                             clip_denoised=False, denoised_fn=None, cond_fn=None):
        """Run the joint reverse loop from T-1 down to 0 with t_arch == t_weight."""
        if device is None:
            device = next(model.parameters()).device

        if model_kwargs is None:
            model_kwargs = {}

        # Starting point x_T: caller-supplied where given, else standard normal
        # (architecture drawn before weight).
        if noise is None:
            noise = {}
        architecture = noise.get("architecture")
        if architecture is None:
            architecture = th.randn(*shape["architecture"], device=device)
        weight = noise.get("weight")
        if weight is None:
            weight = th.randn(*shape["weight"], device=device)
        x = {"architecture": architecture, "weight": weight}

        indices = list(range(self.num_timesteps))[::-1]  # T-1 down to 0

        if progress:
            from tqdm.auto import tqdm
            indices = tqdm(indices)

        batch_size = shape["weight"][0]

        for i in indices:
            t_arch = th.full((batch_size,), i, device=device, dtype=th.long)
            t_weight = th.full((batch_size,), i, device=device, dtype=th.long)

            with th.no_grad():
                out = self.p_sample(
                    model,
                    x,
                    t_arch,
                    t_weight,
                    clip_denoised=clip_denoised,
                    denoised_fn=denoised_fn,
                    cond_fn=cond_fn,
                    model_kwargs=model_kwargs,
                )
                x = out["sample"]

        return x

    def _monl_conditional_sampling(self, model, architecture_given, shape,
                                   model_kwargs=None, progress=False, device=None, noise=None,
                                   clip_denoised=False, denoised_fn=None, cond_fn=None):
        """Run the reverse loop over weights only, holding the architecture at t = 0."""
        if device is None:
            device = next(model.parameters()).device

        if model_kwargs is None:
            model_kwargs = {}

        weight = None if noise is None else noise.get("weight")
        if weight is None:
            weight = th.randn(*shape["weight"], device=device)
        x = {"architecture": architecture_given, "weight": weight}

        indices = list(range(self.num_timesteps))[::-1]  # T-1 down to 0

        if progress:
            from tqdm.auto import tqdm
            indices = tqdm(indices)

        batch_size = shape["weight"][0]

        for i in indices:
            t_arch = th.zeros(batch_size, device=device, dtype=th.long)
            t_weight = th.full((batch_size,), i, device=device, dtype=th.long)

            with th.no_grad():
                out = self.p_sample(
                    model,
                    x,
                    t_arch,
                    t_weight,
                    clip_denoised=clip_denoised,
                    denoised_fn=denoised_fn,
                    cond_fn=cond_fn,
                    model_kwargs=model_kwargs,
                )

                # Re-pin the clean architecture; only the weights advance.
                x = {"architecture": architecture_given, "weight": out["sample"]["weight"]}
        return x


def _extract_into_tensor(arr, timesteps, broadcast_shape):
    """Gather array values for each timestep and broadcast to the desired shape."""
    res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
    while len(res.shape) < len(broadcast_shape):
        res = res[..., None]
    return res + th.zeros(broadcast_shape, device=timesteps.device)
