from transformers import CLIPTextModel, CLIPTokenizer, logging
from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler, DDIMScheduler, EulerDiscreteScheduler
from diffusers.utils.import_utils import is_xformers_available
from typing import Union, Optional, Any, List, Dict, Tuple
from diffusers.models.unet_2d_condition import UNet2DConditionOutput

# suppress partial model loading warning: `logging` here is the transformers
# logging module (imported from `transformers` above), not the stdlib one, and
# this lowers transformers' log verbosity to ERROR for the whole process.
logging.set_verbosity_error()

import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.cuda.amp import custom_bwd, custom_fwd 


class SpecifyGradient(torch.autograd.Function):
    """Inject a precomputed gradient into the autograd graph.

    ``forward`` returns a dummy zero "loss" of shape [1] (same dtype/device as
    ``input_tensor``). When that output is backpropagated, ``input_tensor``
    receives ``gt_grad / len(gt_grad)`` -- the supplied gradient divided by the
    size of its leading dimension -- and ``gt_grad`` receives no gradient.

    NOTE(review): ``backward`` ignores the incoming gradient, so any scaling
    applied to the returned loss (e.g. by an AMP GradScaler) is not propagated
    to ``input_tensor`` -- confirm this is intended.
    """

    @staticmethod
    @custom_fwd
    def forward(ctx, input_tensor, gt_grad):
        ctx.save_for_backward(gt_grad)
        # Only the backward pass matters; the value itself is a placeholder.
        dummy_loss = torch.zeros(
            [1],
            dtype=input_tensor.dtype,
            device=input_tensor.device,
        )
        return dummy_loss

    @staticmethod
    @custom_bwd
    def backward(ctx, grad):
        (target_grad,) = ctx.saved_tensors
        leading_size = len(target_grad)
        # One gradient per forward input: (input_tensor, gt_grad).
        return target_grad / leading_size, None


def seed_everything(seed):
    """Seed PyTorch's RNGs via ``torch.manual_seed`` and ``torch.cuda.manual_seed``.

    cuDNN flags (``deterministic`` / ``benchmark``) are deliberately left
    untouched, so GPU kernels may still be non-deterministic.
    """
    for seed_fn in (torch.manual_seed, torch.cuda.manual_seed):
        seed_fn(seed)

def add_noise_and_return_std(scheduler, original_samples, noise, timesteps):
    """Forward-diffuse ``original_samples`` to ``timesteps`` and return the noise std too.

    Adapted from diffusers' ``DDIMScheduler.add_noise``:
    ``x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise``.

    Side effect: ``scheduler.alphas_cumprod`` is replaced by a copy on
    ``original_samples.device`` so later calls avoid a CPU->GPU transfer.

    Returns:
        ``(noisy_samples, sqrt(1 - alpha_bar_t))``; the std has one entry per
        timestep followed by singleton dims, so it broadcasts against
        ``original_samples``.
    """
    target_device = original_samples.device
    scheduler.alphas_cumprod = scheduler.alphas_cumprod.to(device=target_device)
    table = scheduler.alphas_cumprod.to(dtype=original_samples.dtype)
    alpha_bar = table[timesteps.to(target_device)].flatten()

    def as_broadcastable(coeff):
        # (k,) -> (k, 1, ..., 1) matching the rank of original_samples.
        return coeff.reshape(-1, *([1] * (original_samples.dim() - 1)))

    signal_scale = as_broadcastable(alpha_bar ** 0.5)
    noise_std = as_broadcastable((1 - alpha_bar) ** 0.5)

    noisy_samples = signal_scale * original_samples + noise_std * noise
    return noisy_samples, noise_std

### UNet2DConditionModel variant whose forward can also return its encoder (down + mid block) features, e.g. for a discriminator head
class UNet2DConditionModel_EncoderDecoder(UNet2DConditionModel):
    """``UNet2DConditionModel`` whose forward pass can also return encoder features.

    ``forward`` follows the structure of diffusers' ``UNet2DConditionModel.forward``
    (for the diffusers version this file was written against) and additionally
    records a copy of the hidden state after every down block and after the mid
    block ("encoder features"). ``return_flag`` selects the output:

    * ``'decoder'`` (default): only the denoiser output, like the parent class;
    * ``'encoder'``: only the list of encoder features -- the up path is skipped;
    * ``'encoder_decoder'``: both the denoiser output and the feature list.
    """

    def forward(
        self,
        sample: torch.FloatTensor,
        timestep: Union[torch.Tensor, float, int],
        encoder_hidden_states: torch.Tensor,
        class_labels: Optional[torch.Tensor] = None,
        timestep_cond: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
        mid_block_additional_residual: Optional[torch.Tensor] = None,
        return_dict: bool = True,
        return_flag: str = 'decoder',  # ['encoder', 'decoder', 'encoder_decoder']
    ) -> Union[UNet2DConditionOutput, Tuple]:
        r"""
        Args:
            sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
            timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps; a scalar is
                broadcast over the batch.
            encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
            class_labels (`torch.Tensor`, *optional*): required when the model has a class embedding.
            timestep_cond (`torch.Tensor`, *optional*): extra conditioning passed to `self.time_embedding`.
            attention_mask (`torch.Tensor`, *optional*): 1 = attend, 0 = masked; converted here to an
                additive bias of 0 / -10000.
            cross_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
            down_block_additional_residuals (`tuple` of `torch.Tensor`, *optional*): added element-wise to the
                down-block skip connections before the up path.
            mid_block_additional_residual (`torch.Tensor`, *optional*): added to the mid-block output, after that
                output has been recorded as an encoder feature.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
            return_flag (`str`, defaults to `'decoder'`): one of `'encoder'`, `'decoder'`, `'encoder_decoder'`.

        Returns:
            - `'decoder'`: `UNet2DConditionOutput(sample=...)`, or `(sample,)` if `return_dict` is False.
            - `'encoder'`: `UNet2DConditionOutput(sample=encoder_features)` (a list, not a tensor), or
              `(encoder_features,)` if `return_dict` is False.
            - `'encoder_decoder'`: `(UNet2DConditionOutput(sample=...), encoder_features)`, or
              `(sample, encoder_features)` if `return_dict` is False.

            `encoder_features` holds one tensor per down block followed by the mid-block output
            (when the model has a mid block).

        Raises:
            ValueError: if `return_flag` is invalid, or if `class_labels` is missing while the model
                has a class embedding.
        """

        # Explicit check instead of `assert` so it still runs under `python -O`.
        if return_flag not in ('encoder', 'decoder', 'encoder_decoder'):
            raise ValueError(f"Invalid return_flag: {return_flag}")

        # By default samples have to be AT least a multiple of the overall upsampling factor.
        # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
        # However, the upsampling interpolation output size can be forced to fit any upsampling size
        # on the fly if necessary.
        default_overall_up_factor = 2**self.num_upsamplers

        # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
        forward_upsample_size = False
        upsample_size = None

        if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
            # This module defines no `logger` (the original call raised NameError
            # here), so use a stdlib logger imported locally.
            import logging as _std_logging
            _std_logging.getLogger(__name__).info("Forward upsample size to force interpolation output size.")
            forward_upsample_size = True

        # prepare attention_mask: 1 -> 0 bias (attend), 0 -> -10000 bias (masked)
        if attention_mask is not None:
            attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
            attention_mask = attention_mask.unsqueeze(1)

        # 0. center input if necessary
        if self.config.center_input_sample:
            sample = 2 * sample - 1.0

        # 1. time
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
            # This would be a good case for the `match` statement (Python 3.10+)
            is_mps = sample.device.type == "mps"
            if isinstance(timestep, float):
                dtype = torch.float32 if is_mps else torch.float64
            else:
                dtype = torch.int32 if is_mps else torch.int64
            timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
        elif len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)

        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timesteps.expand(sample.shape[0])

        t_emb = self.time_proj(timesteps)

        # timesteps does not contain any weights and will always return f32 tensors
        # but time_embedding might actually be running in fp16. so we need to cast here.
        # there might be better ways to encapsulate this.
        t_emb = t_emb.to(dtype=self.dtype)

        emb = self.time_embedding(t_emb, timestep_cond)

        if self.class_embedding is not None:
            if class_labels is None:
                raise ValueError("class_labels should be provided when num_class_embeds > 0")

            if self.config.class_embed_type == "timestep":
                class_labels = self.time_proj(class_labels)

            class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)

            if self.config.class_embeddings_concat:
                emb = torch.cat([emb, class_emb], dim=-1)
            else:
                emb = emb + class_emb

        if self.time_embed_act is not None:
            emb = self.time_embed_act(emb)

        if self.encoder_hid_proj is not None:
            encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)

        # 2. pre-process
        sample = self.conv_in(sample)

        # Encoder features: output of each down block, then of the mid block.
        encoder_features = []
        # 3. down
        down_block_res_samples = (sample,)
        for downsample_block in self.down_blocks:
            if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
                sample, res_samples = downsample_block(
                    hidden_states=sample,
                    temb=emb,
                    encoder_hidden_states=encoder_hidden_states,
                    attention_mask=attention_mask,
                    cross_attention_kwargs=cross_attention_kwargs,
                )
            else:
                sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
            if return_flag != 'decoder':
                # Defensive copy. Note clone() is differentiable: the stored
                # feature stays in the autograd graph (it is NOT detached).
                encoder_features.append(sample.clone())

            down_block_res_samples += res_samples

        if down_block_additional_residuals is not None:
            new_down_block_res_samples = ()

            for down_block_res_sample, down_block_additional_residual in zip(
                down_block_res_samples, down_block_additional_residuals
            ):
                down_block_res_sample = down_block_res_sample + down_block_additional_residual
                new_down_block_res_samples += (down_block_res_sample,)

            down_block_res_samples = new_down_block_res_samples

        # 4. mid
        if self.mid_block is not None:
            sample = self.mid_block(
                sample,
                emb,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=attention_mask,
                cross_attention_kwargs=cross_attention_kwargs,
            )

            # Record the mid-block output (before any additional residual).
            if return_flag != 'decoder':
                encoder_features.append(sample.clone())

        if mid_block_additional_residual is not None:
            sample = sample + mid_block_additional_residual

        # Encoder-only mode: skip the up path entirely.
        if return_flag == 'encoder':
            if return_dict:
                return UNet2DConditionOutput(sample=encoder_features)
            else:
                return (encoder_features,)

        # 5. up
        for i, upsample_block in enumerate(self.up_blocks):
            is_final_block = i == len(self.up_blocks) - 1

            res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]

            # if we have not reached the final block and need to forward the
            # upsample size, we do it here
            if not is_final_block and forward_upsample_size:
                upsample_size = down_block_res_samples[-1].shape[2:]

            if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
                sample = upsample_block(
                    hidden_states=sample,
                    temb=emb,
                    res_hidden_states_tuple=res_samples,
                    encoder_hidden_states=encoder_hidden_states,
                    cross_attention_kwargs=cross_attention_kwargs,
                    upsample_size=upsample_size,
                    attention_mask=attention_mask,
                )
            else:
                sample = upsample_block(
                    hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
                )

        # 6. post-process
        if self.conv_norm_out:
            sample = self.conv_norm_out(sample)
            sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        # Assemble the output according to return_flag.
        if return_flag == 'encoder_decoder':
            if return_dict:
                return (UNet2DConditionOutput(sample=sample), encoder_features)
            else:
                return (sample, encoder_features)
        else:
            if not return_dict:
                return (sample,)
            return UNet2DConditionOutput(sample=sample)

class StableDiffusion(nn.Module):
    """Frozen Stable Diffusion teacher used for distillation training.

    Loads the VAE, CLIP tokenizer / text encoder, UNet and a DDIM scheduler from
    a diffusers-format checkpoint and freezes every weight. It exposes
    ``train_step``, which builds the distillation loss, plus a small
    text-to-image sampler (``prompt_to_img`` and its helpers).
    """

    def __init__(self, device, sd_version='2.1', hf_key=None, opt=None):
        """
        Args:
            device: device every sub-model (and the noise tables) is moved to.
            sd_version: '2.1', '2.0', '1.5', 'mvdream' or 'imagedream'; picks one
                of the hard-coded local checkpoint paths below.
            hf_key: checkpoint path / hub id that overrides ``sd_version``.
            opt: options namespace. It is required despite the ``None`` default:
                ``__init__`` reads ``opt.t_range`` and overwrites ``opt.v_pred``,
                and ``train_step`` reads many more fields.

        Raises:
            ValueError: if ``opt`` is None, or ``sd_version`` is unsupported
                while ``hf_key`` is None.
        """
        super().__init__()

        # Fail early with a clear message rather than an AttributeError on
        # `opt.v_pred` after every model has already been loaded.
        if opt is None:
            raise ValueError(
                'StableDiffusion requires an options namespace `opt` '
                '(at least `opt.t_range` must be set).'
            )

        self.device = device
        self.sd_version = sd_version
        self.opt = opt

        print(f'[INFO] loading stable diffusion...')
        print('stable diffusion:', sd_version)

        if hf_key is not None:
            print(f'[INFO] using hugging face custom model key: {hf_key}')
            model_key = hf_key
        elif self.sd_version == '2.1':
            model_key = "/gpfs/share/home/2206192113/cvpr_code/sim_3d/model/stable-diffusion-2-1-base"
        elif self.sd_version == '2.0':
            model_key = "/mnt/workspace/weijian/data/downloads/huggingface_cache/stable-diffusion-2-base"
        elif self.sd_version == '1.5':
            model_key = "/mnt/workspace/weijian/data/downloads/huggingface_cache/stable-diffusion-v1-5"
        elif self.sd_version == 'mvdream':
            model_key = "/gpfs/share/home/2301111469/sim_3d_reward/mvdream_diffusers/weights_imagedream"
        elif self.sd_version == 'imagedream':
            # from mvdream_diffusers.pipeline_mvdream import MVDreamPipeline
            model_key = "/gpfs/share/home/2301111469/sim_3d_reward/imagedream"
        else:
            raise ValueError(f'Stable-diffusion version {self.sd_version} not supported.')

        # Create model
        self.vae = AutoencoderKL.from_pretrained(model_key, subfolder="vae").to(self.device)
        self.tokenizer = CLIPTokenizer.from_pretrained(model_key, subfolder="tokenizer")
        self.text_encoder = CLIPTextModel.from_pretrained(model_key, subfolder="text_encoder").to(self.device)
        self.unet = UNet2DConditionModel.from_pretrained(model_key, subfolder="unet").to(self.device)
        # Alternative UNet that can also return its encoder features:
        # self.unet = UNet2DConditionModel_EncoderDecoder.from_pretrained(model_key, subfolder="unet").to(self.device)

        # Freeze the teacher: no parameter of these networks receives gradients.
        for module in (self.vae, self.unet, self.text_encoder):
            module.requires_grad_(False)

        # if is_xformers_available():
        #     self.unet.enable_xformers_memory_efficient_attention()

        print("NOT using v-pred")
        # NOTE: overwrites the caller's option, so train_step's v-prediction
        # conversion branch is never taken.
        opt.v_pred = False

        self.scheduler = DDIMScheduler.from_pretrained(model_key, subfolder="scheduler")
        alphas_cumprod = self.scheduler.alphas_cumprod
        # sigma_t = sqrt((1 - alpha_bar_t) / alpha_bar_t) for every training
        # timestep; t_lognormal snaps sampled sigmas to this table.
        self.sigmas = (((1 - alphas_cumprod) / alphas_cumprod) ** 0.5).to(self.device)

        self.num_train_timesteps = self.scheduler.config.num_train_timesteps
        self.min_step = int(self.num_train_timesteps * opt.t_range[0])
        self.max_step = int(self.num_train_timesteps * opt.t_range[1])
        self.alphas = self.scheduler.alphas_cumprod.to(self.device)  # cumulative alpha_bar_t, kept on device for convenience

        print(f'[INFO] loaded stable diffusion!')

    @staticmethod
    def _expand_to(values, ref):
        """Flatten ``values`` and append singleton dims until it has ``ref.dim()`` dims."""
        values = values.flatten()
        while values.dim() < ref.dim():
            values = values.unsqueeze(-1)
        return values

    def t_lognormal(self, n, P_mean, P_std):
        """Sample ``n`` timesteps from a log-normal distribution over sigma.

        Draws ``sigma = exp(P_mean + P_std * z)`` with ``z ~ N(0, 1)`` and
        returns, for each draw, the index of the closest entry of
        ``self.sigmas`` as a 1-D tensor of length ``n``.
        """
        noise = (torch.randn([n,], device=self.device) * P_std + P_mean).exp()
        # find the closest timestep
        index = torch.cdist(noise.view(1, -1, 1), self.sigmas.view(1, -1, 1)).argmin(2)

        return index.view(-1)

    def _sample_timestep(self, t5):
        """Draw one training timestep ``t`` (a 1-element long tensor) per ``opt.t_dist``.

        Also replaces a negative ``opt.t1_max`` (meaning "use the default")
        with ``self.max_step``; this mutation is visible to the caller.

        Raises:
            ValueError: if ``t5`` is set and ``opt.t_dist`` is not one of
                'all_log_normal', 'log_normal' or 'uniform'.
        """
        if self.opt.t1_max < 0:  # using default value
            self.opt.t1_max = self.max_step

        if self.opt.t_dist == 'all_log_normal':
            return self.t_lognormal(1, self.opt.P_mean, self.opt.P_std)

        if not t5:
            return torch.randint(self.min_step, self.opt.t1_max + 1, [1], dtype=torch.long, device=self.device)

        # Annealed (t5) schedule.
        if self.opt.t_dist == 'log_normal':
            return self.t_lognormal(1, self.opt.P_mean, self.opt.P_std)
        if self.opt.t_dist == 'uniform':
            t2_max = self.opt.t2_max if self.opt.t2_max >= 0 else 500  # negative -> default of 500
            return torch.randint(self.min_step, t2_max + 1, [1], dtype=torch.long, device=self.device)
        raise ValueError(f'Unsupported t_dist for the annealed schedule: {self.opt.t_dist!r}')

    def _encode_text(self, texts):
        """Tokenize ``texts`` (padded and truncated to the tokenizer's max length) and return the text encoder's first output, without gradients."""
        tokens = self.tokenizer(texts, padding='max_length', max_length=self.tokenizer.model_max_length, truncation=True, return_tensors='pt')
        with torch.no_grad():
            return self.text_encoder(tokens.input_ids.to(self.device))[0]

    def get_text_embeds(self, prompt, negative_prompt):
        """Encode prompts with the frozen CLIP text encoder.

        Args:
            prompt, negative_prompt: lists of strings.

        Returns:
            ``torch.cat([uncond, cond])``: negative-prompt embeddings first, then
            prompt embeddings, concatenated along the batch dimension.
        """
        text_embeddings = self._encode_text(prompt)
        # The negative prompt is truncated as well. Without truncation, a long
        # negative prompt would exceed the encoder's maximum length and could
        # not be concatenated with the padded prompt embeddings.
        uncond_embeddings = self._encode_text(negative_prompt)
        return torch.cat([uncond_embeddings, text_embeddings])

    @staticmethod
    def _density_ratio(rep, noise_diff_norm):
        """Gradient-free density-ratio weight derived from q_unet's last encoder feature.

        The channel mean of ``rep`` serves as a logit. It is turned into
        ``loss_gan = 100 * BCEWithLogits(clamp(logit, -10, 10), 1) / noise_diff_norm``,
        and the result is ``-loss_gan / (1 - loss_gan)``, broadcast like
        ``noise_diff_norm``.
        """
        with torch.no_grad():
            logits = rep.mean(dim=1, keepdim=True)
            bce_loss = nn.BCEWithLogitsLoss()
            loss_gan = 100 * bce_loss(logits.clamp(-10, 10), torch.ones_like(logits)).to(torch.float32) / noise_diff_norm
            return -loss_gan / (1 - loss_gan)

    def train_step(self, text_embeddings, pred_rgb, divergence, guidance_scale=100, q_unet = None, pose = None, shading = None, grad_clip = None, as_latent = False, t5 = False):
        """Compute the distillation loss for one batch.

        Steps:
            1. Noise the latents of ``pred_rgb`` to a sampled timestep ``t``.
            2. Query the frozen teacher UNet with classifier-free guidance.
            3. Query ``q_unet``.
            4. Build a surrogate loss from the normalized difference between the
               two predictions, optionally reweighted by a density ratio derived
               from ``q_unet``'s encoder features.

        Args:
            text_embeddings: output of ``get_text_embeds`` ([uncond; cond] along
                the batch dimension).
            pred_rgb: rendered images, or latents when ``opt.latent`` /
                ``as_latent`` is set. Must not contain NaNs.
            divergence: 'Reverse-KL', 'Forward-KL' or 'Jeffrey-KL'.
            guidance_scale: CFG weight for the teacher.
            q_unet: second network, called with ``return_flag="encoder_decoder"``.
                It must return ``(output_with_.sample, encoder_features)``.
                Required.
            pose: passed to ``q_unet`` as ``c``. Required.
            shading: forwarded to ``q_unet``.
            grad_clip: unused; kept for interface compatibility.
            as_latent: treat ``pred_rgb`` as latents, resized to 64x64.
            t5: use the annealed (second-stage) timestep schedule.

        Returns:
            ``(loss, pseudo_loss, latents)``: the scalar surrogate loss, a
            detached copy of it, and the latents that were noised.

        Raises:
            ValueError: unknown ``divergence``, unknown annealed ``opt.t_dist``,
                or ``q_unet`` is None.
            NotImplementedError: ``pose`` is None.
        """
        assert not torch.isnan(pred_rgb).any(), f'pred_rgb contains NaN values: {pred_rgb}'

        # 1. Latents (kept differentiable w.r.t. pred_rgb).
        if as_latent:
            latents = F.interpolate(pred_rgb, (64, 64), mode='bilinear', align_corners=False).contiguous()
        elif self.opt.latent:
            latents = pred_rgb
        else:
            # interp to 512x512 to be fed into vae; encoding requires grad!
            pred_rgb_512 = F.interpolate(pred_rgb, (512, 512), mode='bilinear', align_corners=False).contiguous()
            latents = self.encode_imgs(pred_rgb_512)

        # 2. Timestep and forward diffusion.
        t = self._sample_timestep(t5)
        noise = torch.randn_like(latents)
        latents_noisy = self.scheduler.add_noise(latents, noise, t)

        alpha_bar_t = self.scheduler.alphas_cumprod.to(self.device)[t]

        # 3. Teacher prediction. The weights are frozen, but no no_grad is used,
        #    so the graph back to the latents is kept.
        temb_uncond, temb_cond = text_embeddings.chunk(2)
        noise_pred_uncond = self.unet(latents_noisy, t, encoder_hidden_states=temb_uncond).sample
        noise_pred_text = self.unet(latents_noisy, t, encoder_hidden_states=temb_cond).sample
        # Guidance form used throughout this file; it equals
        # uncond + (guidance_scale + 1) * (text - uncond).
        noise_pred = noise_pred_text + guidance_scale * (noise_pred_text - noise_pred_uncond)

        # NOTE(review): unconditionally overrides the caller's option, so the
        # q_unet path below always runs. Preserved from the original code.
        self.opt.sds = False

        # 4. q_unet prediction and its encoder features.
        if q_unet is None:
            raise ValueError('train_step requires `q_unet`; it was None.')
        if pose is None:
            raise NotImplementedError()
        noise_pred_q, classifier_logits = q_unet(latents_noisy, t, c = pose, shading = shading, return_flag="encoder_decoder")
        rep = classifier_logits[-1].float()
        noise_pred_q = noise_pred_q.sample

        if self.opt.v_pred:
            # Convert a v-prediction to an epsilon-prediction:
            # eps = sqrt(alpha_bar_t) * v + sqrt(1 - alpha_bar_t) * x_t.
            sqrt_alpha_prod = self._expand_to(alpha_bar_t ** 0.5, latents_noisy)
            sqrt_one_minus_alpha_prod = self._expand_to((1 - alpha_bar_t) ** 0.5, latents_noisy)
            noise_pred_q = sqrt_alpha_prod * noise_pred_q + sqrt_one_minus_alpha_prod * latents_noisy

        # NOTE(review): a previous formulation of this loss multiplied it by a
        # time weight w(t) = (1 - alpha_bar_t) [opt.wgt_type == 'orig'] or 1
        # ['nowgt'], times sqrt(1 - alpha_bar_t) / (sqrt(alpha_bar_t) + 1e-4).
        # The current loss never applied that weight, so its (dead)
        # computation was removed. Confirm that no time weighting is intended.

        # 5. Surrogate loss. `scale` = -sqrt(1 - alpha_bar_t) / sqrt(alpha_bar_t)
        #    per sample, broadcast over the latent dims.
        scale = self._expand_to(-((1 - alpha_bar_t) ** 0.5) / (alpha_bar_t ** 0.5), latents_noisy)
        noise_diff = (noise_pred - noise_pred_q) * scale
        noise_diff_norm = noise_diff.square().sum([1, 2, 3], keepdim=True).sqrt()
        sim_loss = noise_diff / noise_diff_norm * (noise_pred_q - noise) * scale

        if divergence == "Reverse-KL":
            loss = sim_loss
        elif divergence in ("Forward-KL", "Jeffrey-KL"):
            density_ratio = self._density_ratio(rep, noise_diff_norm)
            if divergence == "Forward-KL":
                loss = density_ratio * sim_loss
            else:
                loss = 0.4 * density_ratio * sim_loss + 0.6 * sim_loss
        else:
            raise ValueError(f'Unsupported divergence: {divergence!r}')

        loss = loss.sum([1, 2, 3]).mean()
        pseudo_loss = loss.detach().clone()

        return loss, pseudo_loss, latents

    def produce_latents(self, text_embeddings, height=512, width=512, num_inference_steps=50, guidance_scale=7.5, latents=None):
        """Sample latents with the DDIM scheduler and classifier-free guidance.

        Args:
            text_embeddings: [uncond; cond] embeddings from ``get_text_embeds``.
            height, width: image size in pixels; the latents are 1/8 of that.
            num_inference_steps: number of scheduler steps.
            guidance_scale: CFG weight (same guidance form as ``train_step``).
            latents: optional initial latents; drawn with ``torch.randn`` when None.

        Returns:
            Denoised latents with the same shape as the initial latents.
        """
        if latents is None:
            latents = torch.randn((text_embeddings.shape[0] // 2, self.unet.config.in_channels, height // 8, width // 8), device=self.device)

        self.scheduler.set_timesteps(num_inference_steps)

        with torch.autocast('cuda'):
            for t in self.scheduler.timesteps:
                # Duplicate the latents so one UNet call covers both CFG halves.
                latent_model_input = torch.cat([latents] * 2)

                # predict the noise residual
                with torch.no_grad():
                    noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)['sample']

                # perform guidance
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_text + guidance_scale * (noise_pred_text - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents)['prev_sample']

        return latents

    def decode_latents(self, latents):
        """Decode latents to images in [0, 1] with the VAE (no gradients).

        The ``1 / 0.18215`` factor undoes the scaling applied in ``encode_imgs``.
        """
        latents = 1 / 0.18215 * latents

        with torch.no_grad():
            imgs = self.vae.decode(latents).sample

        imgs = (imgs / 2 + 0.5).clamp(0, 1)

        return imgs

    def encode_imgs(self, imgs):
        """Encode images to scaled VAE latents (differentiable w.r.t. ``imgs``).

        Args:
            imgs: [B, 3, H, W]; presumably in [0, 1], since they are mapped to
                [-1, 1] via ``2 * imgs - 1``.

        Returns:
            A sample from the VAE posterior, multiplied by 0.18215.
        """
        imgs = 2 * imgs - 1

        posterior = self.vae.encode(imgs).latent_dist
        latents = posterior.sample() * 0.18215

        return latents

    def prompt_to_img(self, prompts, negative_prompts='', height=512, width=512, num_inference_steps=50, guidance_scale=7.5, latents=None):
        """Generate images for ``prompts``.

        Returns:
            A uint8 numpy array of shape [B, height, width, 3].
        """
        if isinstance(prompts, str):
            prompts = [prompts]

        if isinstance(negative_prompts, str):
            negative_prompts = [negative_prompts]

        # Prompts -> text embeds: [uncond; cond] along the batch dimension
        text_embeds = self.get_text_embeds(prompts, negative_prompts)

        # Text embeds -> img latents
        latents = self.produce_latents(text_embeds, height=height, width=width, latents=latents, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale)

        # Img latents -> imgs in [0, 1], [B, 3, height, width]
        imgs = self.decode_latents(latents)

        # Img to Numpy (channels-last, uint8)
        imgs = imgs.detach().cpu().permute(0, 2, 3, 1).numpy()
        imgs = (imgs * 255).round().astype('uint8')

        return imgs


if __name__ == '__main__':

    import argparse
    import matplotlib.pyplot as plt

    parser = argparse.ArgumentParser()
    parser.add_argument('prompt', type=str)
    parser.add_argument('--negative', default='', type=str)
    parser.add_argument('--sd_version', type=str, default='2.1', choices=['1.5', '2.0', '2.1'], help="stable diffusion version")
    parser.add_argument('--hf_key', type=str, default=None, help="hugging face Stable diffusion model key")
    parser.add_argument('-H', type=int, default=512)
    parser.add_argument('-W', type=int, default=512)
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--steps', type=int, default=50)
    # StableDiffusion.__init__ needs opt.t_range. It only sets the training
    # timestep bounds (min_step / max_step), which prompt_to_img does not use.
    parser.add_argument('--t_range', type=float, nargs=2, default=[0.02, 0.98], help="training timestep range as fractions of num_train_timesteps")
    opt = parser.parse_args()

    seed_everything(opt.seed)

    device = torch.device('cuda')

    # Pass `opt`: the constructor reads opt.t_range and sets opt.v_pred, and it
    # crashed when `opt` was omitted.
    sd = StableDiffusion(device, opt.sd_version, opt.hf_key, opt=opt)

    imgs = sd.prompt_to_img(opt.prompt, opt.negative, opt.H, opt.W, opt.steps)

    # visualize image
    plt.imshow(imgs[0])
    plt.show()
