import torch
import numpy as np


def _broadcast_tensor(a, broadcast_shape):
    """
    Right-pad a tensor with singleton axes and expand it to a target shape.
    :param a: the tensor to broadcast.
    :param broadcast_shape: the target shape; trailing size-1 axes are
                            appended to ``a`` until its rank matches.
    :return: ``a`` expanded (as a view) to ``broadcast_shape``.
    """
    missing_axes = len(broadcast_shape) - a.dim()
    # A non-positive count leaves the tensor untouched; expand() then
    # validates the shape on its own.
    for _ in range(missing_axes):
        a = a.unsqueeze(-1)
    return a.expand(broadcast_shape)

def _extract_into_tensor(arr, timesteps, broadcast_shape):
    """
    Look up per-sample values in a numpy table and broadcast them.
    :param arr: the numpy array to index (indexed along its first axis).
    :param timesteps: a tensor of indices into ``arr``; its device decides
                      where the result lives.
    :param broadcast_shape: the shape the gathered values are broadcast
                            against, with trailing singleton axes appended
                            to the gathered values as needed.
    :return: a float32 tensor holding the gathered values broadcast against
             ``broadcast_shape`` (materialized, not a view).
    """
    target_device = timesteps.device
    table = torch.from_numpy(arr).to(device=target_device)
    gathered = table[timesteps].float()
    # Append trailing axes until the rank matches the requested shape.
    for _ in range(len(broadcast_shape) - gathered.dim()):
        gathered = gathered.unsqueeze(-1)
    # Adding zeros of the full shape materializes the broadcast result.
    filler = torch.zeros(broadcast_shape, device=target_device)
    return gathered + filler


class _WrappedModel_BAGEL:
    """
    Callable wrapper that maps a step index to a timestep value and injects
    the ``cfg_text_scale`` / ``cfg_img_scale`` keyword arguments before
    delegating to the wrapped model.
    """

    def __init__(self, model, timesteps, num_timesteps, cfg_text_scale, cfg_img_scale, cfg_interval):
        """
        :param model: callable invoked as
                      ``model(x_t=..., timestep=..., **kwargs)``.
        :param timesteps: 1-D tensor of timestep values. A zero is appended
                          and the order is reversed, so index 0 of the stored
                          table is that zero.
        :param num_timesteps: accepted for interface compatibility; not used
                              (the stored timesteps are not normalized).
        :param cfg_text_scale: value passed as ``cfg_text_scale`` while the
                               current timestep is inside ``cfg_interval``.
        :param cfg_img_scale: value passed as ``cfg_img_scale`` while the
                              current timestep is inside ``cfg_interval``.
        :param cfg_interval: indexable pair ``(low, high)``; the scales apply
                             when ``low < timestep <= high``.
        """
        self.model = model
        terminal = torch.zeros_like(timesteps[0]).view(1)
        fm_steps = torch.cat([timesteps, terminal])
        self.time_steps = torch.flip(fm_steps, dims=[0])
        # Kept as an alias of time_steps (no division by num_timesteps).
        self.fm_steps = self.time_steps

        self.cfg_text_scale = cfg_text_scale
        self.cfg_img_scale = cfg_img_scale
        self.cfg_interval = cfg_interval

    def __call__(self, x, t, y, kwargs):
        """
        Run the wrapped model for step index ``t``.
        :param x: input tensor; its first dimension sets the length of the
                  ``timestep`` tensor and its device is used for it.
        :param t: index into ``self.time_steps``.
        :param y: unused; present to keep the call signature.
        :param kwargs: mapping of extra keyword arguments for the model.
                       It is copied, so the caller's mapping is left
                       unmodified; ``None`` is treated as empty.
        :return: whatever the wrapped model returns.
        """
        t = self.time_steps[t]
        timestep = torch.tensor([t] * x.shape[0], device=x.device)

        # Outside the half-open interval (low, high] both scales fall back
        # to 1.0.
        if t > self.cfg_interval[0] and t <= self.cfg_interval[1]:
            cfg_text_scale_ = self.cfg_text_scale
            cfg_img_scale_ = self.cfg_img_scale
        else:
            cfg_text_scale_ = 1.0
            cfg_img_scale_ = 1.0

        # Work on a copy: writing the scales into the caller's dict would
        # leak per-step state back to the caller.
        model_kwargs = dict(kwargs) if kwargs is not None else {}
        model_kwargs["cfg_text_scale"] = cfg_text_scale_
        model_kwargs["cfg_img_scale"] = cfg_img_scale_
        return self.model(x_t=x, timestep=timestep, **model_kwargs)
