import torch
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
from ldm.util import instantiate_from_config
import numpy as np
import os 
import torchvision
from einops import rearrange
from PIL import Image, ImageDraw, ImageFont
import json 
import glob
from diffusers import UnCLIPPriorPipeline
import clip
from functools import partial

from transformers import CLIPTextModel, CLIPTokenizer
from diffusers.models import AutoencoderKL
from models.unet_3d_condition import UNet3DConditionModel


# = = = = = = = = = = = = = = = = = = useful functions = = = = = = = = = = = = = = = = = #

# Set at import time, before any wandb usage in the process.
# NOTE(review): presumably raises the wandb service start-up wait to 300 s —
# confirm against the wandb documentation for the installed version.
os.environ["WANDB__SERVICE_WAIT"] = "300"


class ImageCaptionSaver:
    """Save a batch of images as one grid image and log its captions.

    Every call writes ``<base_path>/<seen zero-padded to 8 digits>.png`` and
    appends one caption section to ``<base_path>/captions.txt``.
    """

    def __init__(self, base_path, nrow=8, normalize=True, scale_each=True, range=(-1,1) ):
        # ``range`` shadows the builtin, but it is part of the public
        # signature and is therefore kept as-is.
        self.base_path = base_path
        self.nrow = nrow
        self.normalize = normalize
        self.scale_each = scale_each
        self.range = range

    def __call__(self, images, captions, seen):
        """Write the image grid and append the captions.

        Args:
            images: batch passed straight to ``torchvision.utils.save_image``;
                its first dimension must equal ``len(captions)``.
            captions: iterable of caption strings, one per image.
            seen: counter used (zero-padded to 8 digits) as the file stem and
                as the section header in ``captions.txt``.

        Raises:
            AssertionError: if the number of images and captions differ.
        """
        # Validate before touching the disk so a mismatched batch cannot
        # leave an image on disk without its caption section.
        assert images.shape[0] == len(captions)

        tag = str(seen).zfill(8)
        image_path = os.path.join(self.base_path, tag + '.png')
        # NOTE(review): newer torchvision releases name this keyword
        # ``value_range`` — confirm against the installed version.
        torchvision.utils.save_image(
            images,
            image_path,
            nrow=self.nrow,
            normalize=self.normalize,
            scale_each=self.scale_each,
            range=self.range,
        )

        caption_path = os.path.join(self.base_path, 'captions.txt')
        with open(caption_path, "a") as f:
            f.write(tag + ':\n')
            for cap in captions:
                f.write(cap + '\n')
            f.write('\n')


def read_official_ckpt(ckpt_path):
    """Read an official pretrained SD checkpoint and convert it to this project's layout.

    The checkpoint is loaded on CPU. Entries whose key starts with
    ``model.diffusion_model``, ``cond_stage_model`` or ``first_stage_model``,
    and the two EMA bookkeeping entries, are dropped; everything else is
    returned under the ``"diffusion"`` key.

    Args:
        ckpt_path: path of the checkpoint file handed to ``torch.load``.

    Returns:
        dict: ``{"diffusion": {key: tensor, ...}}``. A dict is always returned,
        including for checkpoints that are not wrapped in a ``"state_dict"``
        entry (the loaded mapping is then treated as the state dict itself).
    """
    skipped_prefixes = ('model.diffusion_model', 'cond_stage_model', 'first_stage_model')
    skipped_keys = ("model_ema.decay", "model_ema.num_updates")

    # NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
    state_dict = torch.load(ckpt_path, map_location="cpu")
    if 'state_dict' in state_dict:
        state_dict = state_dict["state_dict"]

    diffusion = {
        k: v
        for k, v in state_dict.items()
        if not k.startswith(skipped_prefixes) and k not in skipped_keys
    }
    return {"diffusion": diffusion}


def batch_to_device(batch, device):
    """Move every tensor value of ``batch`` to ``device`` in place.

    Non-tensor values are left untouched. The same dict object is returned.
    """
    for key, value in batch.items():
        if not isinstance(value, torch.Tensor):
            continue
        # Only existing keys are reassigned, so the dict size never changes mid-iteration.
        batch[key] = value.to(device)
    return batch


def disable_grads(model):
    """Turn off gradient tracking for every parameter of ``model``."""
    for param in model.parameters():
        param.requires_grad_(False)


def tensor_to_vae_latent(t, vae):
    """Encode a ``(b, f, c, h, w)`` video tensor into scaled VAE latents.

    Frames are folded into the batch axis, encoded with ``vae.encode``, one
    sample is drawn from the returned ``latent_dist``, and the result is
    rearranged to ``(b, c, f, h, w)`` and multiplied by 0.18215.
    """
    n_frames = t.shape[1]
    frames = rearrange(t, "b f c h w -> (b f) c h w")
    posterior = vae.encode(frames).latent_dist
    sampled = rearrange(posterior.sample(), "(b f) c h w -> b c f h w", f=n_frames)
    return sampled * 0.18215


def decode_latents(latents, vae):
    """Decode ``(b, c, f, h, w)`` latents into a float32 video tensor.

    The latents are multiplied by ``1 / 0.18215``, the frame axis is folded
    into the batch axis for ``vae.decode``, and the decoded ``.sample`` is
    unfolded back to ``(b, channels, f, H, W)``.
    """
    batch_size, channels, num_frames, height, width = latents.shape
    scaled = (1 / 0.18215) * latents

    # (b, c, f, h, w) -> (b * f, c, h, w) so the VAE sees ordinary images.
    per_frame = scaled.permute(0, 2, 1, 3, 4).reshape(
        batch_size * num_frames, channels, height, width
    )
    image = vae.decode(per_frame).sample

    # (b * f, C, H, W) -> (b, f, C, H, W) -> (b, C, f, H, W)
    unfolded = image.reshape(batch_size, num_frames, -1, *image.shape[2:])
    return unfolded.permute(0, 2, 1, 3, 4).float()


def draw_box(img, boxes, texts=None):
    """Draw outlined boxes, optionally with labels, onto ``img`` and return it.

    Args:
        img: PIL image, modified in place.
        boxes: iterable of ``(x0, y0, x1, y1)`` pixel coordinates; each value
            is truncated with ``int``.
        texts: optional iterable of label strings paired with ``boxes`` via
            ``zip`` (extra boxes or labels are ignored). Only the part of a
            label before the first ``"("`` is drawn. When ``None`` the boxes
            are drawn without labels.

    Returns:
        The same ``img`` object.
    """
    # The default ``texts=None`` used to crash inside zip(); treat it as
    # "no labels" so every box is still outlined.
    labels = texts if texts is not None else [None] * len(boxes)

    draw = ImageDraw.Draw(img)
    colors = ["red", "olive", "blue", "green", "orange", "brown", "cyan", "purple"]
    font = ImageFont.load_default()

    for idx, (box, label) in enumerate(zip(boxes, labels)):
        x0, y0, x1, y1 = box
        x0, y0, x1, y1 = int(x0), int(y0), int(x1), int(y1)
        # Outline colours cycle through the fixed palette.
        draw.rectangle([x0, y0, x1, y1], outline=colors[idx % len(colors)], width=6)
        if label is None:
            continue

        # Random background colour for the label patch.
        color = tuple(np.random.randint(0, 255, size=3).tolist())
        caption = str(label.split("(")[0])
        # Probe the method that is actually called; older Pillow only has textsize.
        if hasattr(draw, "textbbox"):
            bbox = draw.textbbox((x0, y0), caption, font)
        else:
            w, h = draw.textsize(caption, font)
            bbox = (x0, y0, w + x0, y0 + h)
        draw.rectangle(bbox, fill=color)
        draw.text((x0, y0), caption, fill="white")
    return img


def vis_getitem_data(index=None, out=None, return_tensor=False, name="res.jpg", print_caption=False):
    """Render one sample with its boxes drawn on top.

    Args:
        index: unused; kept for interface compatibility.
        out: mapping with ``"image"`` (tensor, mapped through ``*0.5+0.5``
            before conversion to PIL), ``"boxes"`` (iterable of
            ``(x0, y0, x1, y1)`` scaled here by the image width/height),
            optionally ``"texts"`` (labels forwarded to ``draw_box``) and,
            when ``print_caption`` is set, ``"caption"``.
        return_tensor: return the drawn image as a tensor instead of saving it.
        name: output path used when ``return_tensor`` is false.
        print_caption: print ``out["caption"]`` followed by a blank-ish line.

    Returns:
        The drawn image as a tensor when ``return_tensor`` is true, else ``None``.
    """
    img = torchvision.transforms.functional.to_pil_image(out["image"]*0.5+0.5)
    W, H = img.size

    if print_caption:
        print(out["caption"])
        print(" ")

    # Scale box coordinates by the image width and height.
    boxes = [
        [float(x0*W), float(y0*H), float(x1*W), float(y1*H)]
        for x0, y0, x1, y1 in out["boxes"]
    ]

    if 'texts' in out:
        img = draw_box(img, boxes, out['texts'])
    else:
        img = draw_box(img, boxes)

    if return_tensor:
        return torchvision.transforms.functional.to_tensor(img)
    img.save(name)


def flatten_extend(matrix):
    """Flatten one nesting level: concatenate the rows of ``matrix`` into a new list."""
    return [item for row in matrix for item in row]



def alpha_generator(length, type=None):
    """Build the per-step alpha schedule used during sampling.

    Args:
        length: total number of sampling timesteps.
        type: three fractions summing to 1 giving the share of the schedule
            spent in each stage: alpha = 1, linear decay, alpha = 0.
            Defaults to ``[1, 0, 0]`` (alpha is 1 for every step).

    Example: with ``length=1000`` and ``type=[0.8, 0.1, 0.1]`` alpha is 1 for
    the first 800 steps, decays linearly towards 0 over the next 100 steps and
    is 0 for the last 100 steps.

    Returns:
        list of ``length`` alpha values.

    Raises:
        AssertionError: if ``type`` does not have three entries or the entries
            do not sum to 1 (within floating point tolerance).
    """
    if type is None:
        type = [1, 0, 0]
    assert len(type) == 3
    # Compare with a tolerance: e.g. 0.7 + 0.2 + 0.1 is 0.9999999999999999 in
    # floating point and must not be rejected.
    assert abs(type[0] + type[1] + type[2] - 1) < 1e-6

    stage0_length = int(type[0]*length)
    stage1_length = int(type[1]*length)
    # The last stage absorbs whatever the two truncations above left over.
    stage2_length = length - stage0_length - stage1_length

    if stage1_length != 0:
        # Integer arange guarantees exactly stage1_length values; a float-step
        # arange can produce one element too many because of rounding.
        decay_alphas = np.arange(stage1_length) * (1 / stage1_length)
        decay_alphas = list(decay_alphas[::-1])
    else:
        decay_alphas = []

    alphas = [1]*stage0_length + decay_alphas + [0]*stage2_length
    assert len(alphas) == length
    return alphas


def set_alpha_scale(model, alpha_scale):
    """Assign ``alpha_scale`` to ``.scale`` of every gated attention module in ``model``.

    Only modules whose exact type is ``GatedCrossAttentionDense`` or
    ``GatedSelfAttentionDense`` are touched (subclasses are not matched).
    """
    from ldm.modules.attention import GatedCrossAttentionDense, GatedSelfAttentionDense

    gated_types = (GatedCrossAttentionDense, GatedSelfAttentionDense)
    for layer in model.modules():
        if type(layer) in gated_types:
            layer.scale = alpha_scale


def scale_a_box(box, scale_factor=2):
    """Scale ``box`` about its centre and clamp the result to the unit square.

    ``box`` is ``(x0, y0, x1, y1)``; a new ``[x0, y0, x1, y1]`` list is returned.
    """
    x0, y0, x1, y1 = box
    center_x = (x0 + x1) / 2
    center_y = (y0 + y1) / 2
    # Half extents after scaling, measured from the centre.
    reach_x = scale_factor * (x1 - center_x)
    reach_y = scale_factor * (y1 - center_y)
    return [
        max(0.0, center_x - reach_x),
        max(0.0, center_y - reach_y),
        min(1.0, center_x + reach_x),
        min(1.0, center_y + reach_y),
    ]


def scale_boxes(boxes, total_area_threshold=0.5):
    """Repair and enlarge normalised boxes until they cover enough area.

    Steps, applied in place to ``boxes``:
      1. An inverted axis (``x1 < x0`` or ``y1 < y0``) is repaired by placing
         the far edge 0.1 past the near edge, capped at 1.0.
      2. A box whose area is at most 0.01 is enlarged with ``scale_a_box``
         (default factor).
      3. All boxes are grown by a factor of 1.02 per round until the summed
         area reaches ``total_area_threshold`` or 100 rounds have run.

    Args:
        boxes: mutable list of ``(x0, y0, x1, y1)`` boxes.
        total_area_threshold: summed area at which growing stops.

    Returns:
        The same ``boxes`` list; every entry is replaced by a list.
    """
    def box_area(box):
        bx0, by0, bx1, by1 = box
        return (bx1 - bx0) * (by1 - by0)

    for i in range(len(boxes)):
        x0, y0, x1, y1 = boxes[i]
        if x1 < x0:
            x1 = min(1.0, x0+0.1)
        if y1 < y0:
            y1 = min(1.0, y0+0.1)
        repaired = [x0, y0, x1, y1]
        if box_area(repaired) <= 0.01:
            repaired = scale_a_box(repaired)
        # Always store the repaired box: previously a repaired but
        # not-tiny box was left inverted in the output.
        boxes[i] = repaired

    # Measure the boxes as they are now, i.e. after any enlargement above.
    total_area = sum(box_area(b) for b in boxes)

    count = 0
    # ``boxes and`` avoids 100 no-op rounds for an empty list.
    while boxes and total_area < total_area_threshold and count < 100:
        for i in range(len(boxes)):
            boxes[i] = scale_a_box(boxes[i], scale_factor=1.02)
        total_area = sum(box_area(b) for b in boxes)
        count += 1
    return boxes
            
            
class Inference:
    def __init__(self, config):
        """Load every model needed for inference and put them on ``config.device``.

        Reads from ``config``: ``device``, ``karlo_v1_alpha_path``,
        ``pretrained_modelscope_path``, ``diffusion``, ``image_embedding_dim``,
        ``text_embedding_dim``, ``position_net_mid_dim``, ``position_net_out_dim``,
        ``position_net_point_or_box``, ``enable_fuser``,
        ``multi_scene_cross_frame_attn``, ``n_sample_frames``,
        ``pretrained_sd14_path``, ``pretrained_gligen_modelscope_path``,
        ``grounding_tokenizer_input`` and, optionally, ``scale_boxes``.
        """
        self.config = config
        self.device = torch.device(self.config.device)
        #### Karlo unCLIP prior pipeline (float16); it is called with entity
        #### names in pre_process_raw_layouts_for_inference.
        self.pipe = UnCLIPPriorPipeline.from_pretrained(self.config.karlo_v1_alpha_path, torch_dtype=torch.float16)
        self.pipe = self.pipe.to(self.device)
        # OpenAI CLIP ViT-L/14; its encode_text is used for entity text embeddings.
        self.CLIPModelv1, self.preprocess = clip.load("ViT-L/14", device=self.device)

        # Tokenizer, text encoder, VAE and 3D UNet all come from sub-folders of
        # the same pretrained directory.
        pretrained_modelscope_path = self.config.pretrained_modelscope_path
        self.tokenizer = CLIPTokenizer.from_pretrained(pretrained_modelscope_path, subfolder="tokenizer")
        self.text_encoder = CLIPTextModel.from_pretrained(pretrained_modelscope_path, subfolder="text_encoder").to(self.device)
        self.autoencoder = AutoencoderKL.from_pretrained(pretrained_modelscope_path, subfolder="vae").to(self.device)
        self.diffusion = instantiate_from_config(config.diffusion).to(self.device)
        self.model = UNet3DConditionModel.from_pretrained(pretrained_modelscope_path, 
                                                            subfolder="unet", 
                                                            low_cpu_mem_usage=False,
                                                            device_map=None,
                                                            ).to(self.device)

        # Grounding-specific modules are configured on the UNet before the
        # fine-tuned checkpoint is loaded further below.
        self.model.reset_position_net(img_dim=self.config.image_embedding_dim, text_dim=self.config.text_embedding_dim, 
                                        mid_dim=self.config.position_net_mid_dim, out_dim=self.config.position_net_out_dim, 
                                        position_net_point_or_box=self.config.position_net_point_or_box)
        if self.config.enable_fuser == True:
            self.model.add_gatedSA(key_dim = self.model.position_net.out_dim)
        
        self.model.set_embedding_dim(self.config)
            
        if self.config.multi_scene_cross_frame_attn == True: 
            self.model.set_cross_frame_attn(cross_frame_attn=True, n_frames=self.config.n_sample_frames)

        # this is SD1.4 state_dict
        state_dict = read_official_ckpt( self.config.pretrained_sd14_path )
        self.diffusion.load_state_dict( state_dict["diffusion"]  )

        # Inference only: eval mode for the VAE and text encoder, and no
        # gradients for the VAE, text encoder and UNet.
        self.autoencoder.eval()
        self.text_encoder.eval()
        disable_grads(self.autoencoder)
        disable_grads(self.text_encoder)
        disable_grads(self.model)

        # = = = = = = = = = = = = = = = = = = = = load from autoresuming ckpt = = = = = = = = = = = = = = = = = = = = #
        # Optional: skipped when the configured path is None. The file must
        # hold the UNet weights under the "model" key.
        checkpoint = self.config.pretrained_gligen_modelscope_path
        if checkpoint is not None:
            checkpoint = torch.load(checkpoint, map_location="cpu")
            self.model.load_state_dict(checkpoint["model"])

            print("===============================")
            print("auto resumed checkpoint loaded!")
            print("===============================")

        # = = = = = = = = = = = = = = = = = = = = misc and ddp = = = = = = = = = = = = = = = = = = = =#    
        # func return input for grounding tokenizer 
        self.grounding_tokenizer_input = instantiate_from_config(config.grounding_tokenizer_input)
        self.model.grounding_tokenizer_input = self.grounding_tokenizer_input
        self.grounding_downsampler_input = None

        print("########## finished model initialization")
        if "scale_boxes" in self.config:
            print(f"############# use scaled boxes = {self.config.scale_boxes}  ###############")


    @torch.no_grad()
    def get_prompt_ids(self, prompt):
        return self.tokenizer(
            prompt,
            truncation=True,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        ).input_ids
    
    
    @torch.no_grad()
    def pre_process_raw_layouts_for_inference(self, raw_layouts):
        """Turn raw per-frame layouts into the padded grounding batch.

        Args:
            raw_layouts: sequence with one entry per frame; each entry is a
                sequence of objects where ``obj[0]`` is the entity name and
                ``obj[1]`` its box (four values, stored as a row of
                ``batch['boxes']``).

        Returns:
            dict with ``use_image_embedding`` / ``use_text_embedding`` flags,
            ``boxes`` of shape (n_frames, 30, 4), ``masks`` of shape
            (n_frames, 30) and, depending on the flags, ``text_embeddings``
            and/or ``cropped_image_embeddings`` of shape (n_frames, 30, 768).
            The tensors are created on the CPU here.
        """
        n_bf = len(raw_layouts)
        batch = {}
        batch['use_image_embedding'] = self.config.image_embedding_dim > 0 
        batch['use_text_embedding'] = self.config.text_embedding_dim > 0 
        # Fixed capacity of 30 object slots per frame; unused slots stay zero
        # and are marked 0 in the mask.
        batch['boxes'] = torch.zeros(n_bf, 30, 4)#.to(self.device)
        batch['masks'] = torch.zeros(n_bf, 30)#.to(self.device)
        if batch['use_text_embedding']:
            batch['text_embeddings'] = torch.zeros(n_bf, 30, 768)#.to(self.device)
        if batch['use_image_embedding']:
            batch['cropped_image_embeddings'] = torch.zeros(n_bf, 30, 768)#.to(self.device)

        # Collect every entity name across all frames so each unique name is
        # sent through the prior pipeline only once.
        all_names = []
        for i in range(n_bf):
            n_objs = len(raw_layouts[i])
            names = [e[0] for e in raw_layouts[i]]
            all_names += names

        all_names = list(set(all_names))
        cropped_img_emb_dict = {}
        # NOTE(review): the pipeline runs for every name even when
        # use_image_embedding is False, and its return value is stored as-is —
        # confirm what self.pipe(name) returns.
        for name in all_names:
            cropped_img_emb_dict[name] = self.pipe(name)

        for i in range(n_bf):
            n_objs = len(raw_layouts[i])
            boxes = [e[1] for e in raw_layouts[i]]
            boxes = scale_boxes(boxes) if ("scale_boxes" in self.config and self.config.scale_boxes) else boxes 
            names = [e[0] for e in raw_layouts[i]]
            batch['masks'][i][:n_objs] = 1.0
            batch['boxes'][i][:n_objs] = torch.Tensor(boxes)
            ### check x0<=x1, y0<=y1
            # NOTE(review): the two comparisons below are evaluated and then
            # discarded; nothing is asserted or raised for an inverted box.
            for j in range(n_objs):
                (batch['boxes'][i][j][0] <= batch['boxes'][i][j][2]) == True
                (batch['boxes'][i][j][1] <= batch['boxes'][i][j][3]) == True
            if batch['use_image_embedding']:
                for j in range(n_objs):
                    name = names[j]
                    batch['cropped_image_embeddings'][i][j] = cropped_img_emb_dict[name] #img_emb
            if batch['use_text_embedding']:
                # All 30 slots are encoded; rows past n_objs are all-zero token ids.
                # NOTE(review): ids come from self.tokenizer but are encoded
                # with self.CLIPModelv1 — presumably the vocabularies match; confirm.
                token_ids = torch.zeros(30, 77, dtype = torch.int32).to(self.device)
                batch_entity_ids = self.get_prompt_ids(names).to(self.device)
                token_ids[:n_objs] = batch_entity_ids
                context = self.CLIPModelv1.encode_text(token_ids)
                batch['text_embeddings'][i] = context
           
        return batch


    @torch.no_grad()
    def inference_single_scene(self):
        """Generate one video per prompt folder under ``config.INFERENCE_DATA_ROOT``.

        Folders are processed in sorted order between ``config.range_start``
        and ``config.range_end`` (``-1`` means "to the end") in batches of up
        to ``config.batch_size``. The folder name is used as the caption and
        the layouts are read from the JSON file(s) inside the folder. Samples
        are drawn with 50 steps and guidance scale 5 using the DDIM sampler
        when ``config.sampler_type == 'DDIM'`` and the PLMS sampler otherwise.

        Outputs, all rooted at ``config.INFERENCE_DATA_OUTPUT``:
          * ``<root>_frames/<name>/frame<k>.jpg`` - generated frames
          * ``<root>_bbox_frames/<name>/frame<k>.jpg`` - frames with boxes drawn
          * ``<root>_gifs/<name>.gif`` and ``<root>_bbox_gifs/<name>.gif``

        Optional config entries: ``alpha_type``, ``sampler_type``,
        ``gpt_alpha`` (derive the alpha schedule from the layout's ``score``)
        and ``seed`` (forwarded to the sampler when present).
        """
        inference_path = self.config.INFERENCE_DATA_ROOT
        n_frames = self.config.n_sample_frames

        model_wo_wrapper = self.model.module if self.config.distributed else self.model
        model_wo_wrapper.eval()
        model_wo_wrapper.to(self.device)
        self.text_encoder.to(self.device)

        inference_filenames = sorted(os.listdir(inference_path))
        inference_files = [os.path.join(inference_path, f) for f in inference_filenames]

        range_start = self.config.range_start
        range_end = self.config.range_end
        if range_end == -1:
            range_end = len(inference_filenames)
            print("range end: ", range_end)

        if 'alpha_type' in self.config:
            alpha_type = self.config.alpha_type
        else:
            alpha_type = [0.2, 0, 0.8]

        if 'sampler_type' in self.config and self.config['sampler_type'] == 'DDIM':
            sampler_cls = DDIMSampler
        else:
            sampler_cls = PLMSSampler

        def build_sampler(stage_fractions):
            # One place to construct the sampler for a given alpha schedule.
            return sampler_cls(
                self.diffusion,
                model_wo_wrapper,
                alpha_generator_func=partial(alpha_generator, type=stage_fractions),
                set_alpha_scale=set_alpha_scale,
            )

        def frame_order(path):
            # Sort "frame10.jpg" after "frame9.jpg": order by the number in
            # the file stem, falling back to the path for ties.
            stem = os.path.splitext(os.path.basename(path))[0]
            digits = ''.join(ch for ch in stem if ch.isdigit())
            return (int(digits) if digits else 0, path)

        sampler = build_sampler(alpha_type)

        # ``seed`` used to be an undefined name here. Forward it only when the
        # config provides one; otherwise the sampler's own default applies.
        sample_kwargs = {}
        if 'seed' in self.config:
            sample_kwargs['seed'] = self.config.seed

        frames_root = self.config.INFERENCE_DATA_OUTPUT + '_frames'
        bbox_frames_root = self.config.INFERENCE_DATA_OUTPUT + '_bbox_frames'
        filenames_unique_gifs = self.config.INFERENCE_DATA_OUTPUT + '_gifs'
        filenames_unique_bbox_gifs = self.config.INFERENCE_DATA_OUTPUT + '_bbox_gifs'

        ii = range_start

        while ii < range_end:
            batch_boxes = []
            batch_masks = []
            batch_text_embeddings = []
            batch_img_embeddings = []
            context_list = []
            entity_list = []
            uc_list = []
            filename_list = []
            batch_count = 0

            while batch_count < self.config.batch_size and ii < range_end:

                allfiles = sorted(glob.glob(f"{inference_files[ii]}/*.json"))
                if len(allfiles) == 0: # remove empty folders without gpt layouts
                    ii += 1
                    continue

                print(allfiles)
                file = os.path.join(inference_path, inference_filenames[ii])
                filename = inference_filenames[ii]
                caption = filename
                if len(allfiles) == 1: # non UCF
                    layout_path = allfiles[0]
                else:
                    layout_path = os.path.join(file, f'{batch_count}.json')

                with open(layout_path) as f:
                    data = json.load(f)
                raw_layouts = data['layouts']

                if 'gpt_alpha' in self.config and self.config.gpt_alpha:
                    # Clamp the layout score into [0, 0.3] and use it as the
                    # fraction of steps with alpha = 1.
                    gpt_alpha = max(min(0.3, float(data['score'])), 0)
                    alpha_type = [gpt_alpha, 0, 1-gpt_alpha]
                    sampler = build_sampler(alpha_type)

                if len(raw_layouts) != n_frames:
                    print("len(raw_layouts)", len(raw_layouts))
                    if "UCF" in inference_path:
                        batch_count += 1
                    else:
                        ii += 1
                    continue

                try: # in case the layout cannot be correctly parsed
                    batch = self.pre_process_raw_layouts_for_inference(raw_layouts)
                    entity_names = [[entity[0] for entity in frame] for frame in raw_layouts]
                    batch_to_device(batch, self.device)
                    ### context & uc
                    batch_caption_ids = self.get_prompt_ids(1*[""]).to(self.device)
                    uc = self.text_encoder(batch_caption_ids)[0]
                    batch_caption_ids = self.get_prompt_ids(caption).to(self.device)
                    context = self.text_encoder( batch_caption_ids )[0]
                except Exception as err:
                    print(f"skipping {filename}: {err!r}")
                    ii += 1
                    continue

                # Append only after everything succeeded so the per-sample
                # lists cannot get out of step with each other.
                entity_list.append(entity_names)
                context_list.append(context)
                uc_list.append(uc)
                batch_masks.append(batch['masks'])
                batch_boxes.append(batch['boxes'])
                ### entity text embeddings & cropped image embeddings
                if batch['use_text_embedding']:
                    batch_text_embeddings.append(batch['text_embeddings'])
                if batch['use_image_embedding']:
                    batch_img_embeddings.append(batch['cropped_image_embeddings'])
                filename_list.append(filename)

                batch_count += 1
                ii += 1

            if len(batch_masks) == 0:
                continue

            shape = (len(filename_list), model_wo_wrapper.in_channels, n_frames,
                self.config.new_image_size // 8, self.config.new_image_size // 8)

            batch_masks = torch.concat(batch_masks)
            batch_boxes = torch.concat(batch_boxes)
            context = torch.concat(context_list)
            uc = torch.concat(uc_list)

            if batch['use_text_embedding']:
                batch_text_embeddings = torch.concat(batch_text_embeddings)
            if batch['use_image_embedding']:
                batch_img_embeddings = torch.concat(batch_img_embeddings)

            full_batch = {'masks': batch_masks,
                        'boxes': batch_boxes,
                        'text_embeddings': batch_text_embeddings,
                        "cropped_image_embeddings": batch_img_embeddings,
                        "use_text_embedding": batch['use_text_embedding'],
                        "use_image_embedding": batch['use_image_embedding'],
                        }
            if self.config.position_net_point_or_box == 'point':
                # Box centres: ((x0 + x1) / 2, (y0 + y1) / 2)
                full_batch['centers'] = torch.concat( [ ((full_batch['boxes'][:,:,0] + full_batch['boxes'][:,:,2]) * 0.5).unsqueeze(-1),
                                            ((full_batch['boxes'][:,:,1] + full_batch['boxes'][:,:,3]) * 0.5).unsqueeze(-1)], axis = -1)

            full_batch['position_net_point_or_box'] = self.config.position_net_point_or_box

            grounding_input = self.grounding_tokenizer_input.prepare(full_batch)
            sampler_input = dict( x=None,
                            timesteps=None,
                            context=context,
                            inpainting_extra_input=None,
                            grounding_extra_input=None,
                            grounding_input=grounding_input )
            sampler_input['use_step_caption'] = False

            samples = sampler.sample(S=50, shape=shape, input=sampler_input, uc=uc, guidance_scale=5, **sample_kwargs)
            autoencoder_wo_wrapper = self.autoencoder # Note itself is without wrapper since we do not train that.
            self.autoencoder.enable_slicing()
            samples = decode_latents(samples, autoencoder_wo_wrapper)
            samples = rearrange(samples, "b c f h w -> (b f) c h w")
            samples = torch.clamp(samples, min=-1, max=1)

            flattened_entity_list = flatten_extend(entity_list)
            gen_images_with_box_drawing = []
            for i in range(len(samples)):
                temp_data = {"image": samples[i], "boxes":full_batch["boxes"][i], "texts":flattened_entity_list[i]}
                im = vis_getitem_data(out=temp_data, return_tensor=True, print_caption=False)
                gen_images_with_box_drawing.append(im)

            gen_images_with_box_drawing = torch.stack(gen_images_with_box_drawing)
            gen_images_without_box_drawing = 0.5*samples + 0.5

            # One output directory name per generated frame, in sample order.
            per_frame_names = [name for name in filename_list for _ in range(n_frames)]
            filenames_frames = [os.path.join(frames_root, f) for f in per_frame_names]
            filenames_bbox_frames = [os.path.join(bbox_frames_root, f) for f in per_frame_names]

            filenames_unique_frames = [os.path.join(frames_root, f) for f in filename_list]
            filenames_unique_bbox_frames = [os.path.join(bbox_frames_root, f) for f in filename_list]

            os.makedirs(filenames_unique_gifs, exist_ok = True)
            os.makedirs(filenames_unique_bbox_gifs, exist_ok = True)
            for i in range(len(filenames_unique_frames)):
                os.makedirs(filenames_unique_frames[i], exist_ok = True)
                os.makedirs(filenames_unique_bbox_frames[i], exist_ok = True)

            ##### generate images #####
            for i in range(len(gen_images_with_box_drawing)):
                torchvision.utils.save_image(gen_images_without_box_drawing[i], os.path.join(filenames_frames[i], f"frame{i%self.config.n_sample_frames+1}.jpg"))
                torchvision.utils.save_image(gen_images_with_box_drawing[i], os.path.join(filenames_bbox_frames[i], f"frame{i%self.config.n_sample_frames+1}.jpg"))

            ##### generate gif #####
            for i in range(len(filename_list)):
                gif_jobs = (
                    (filenames_unique_frames[i], filenames_unique_gifs),            # without bbox
                    (filenames_unique_bbox_frames[i], filenames_unique_bbox_gifs),  # with bbox
                )
                for frames_dir, gif_dir in gif_jobs:
                    frame_paths = sorted(glob.glob(f"{frames_dir}/*.jpg"), key=frame_order)
                    frames = [Image.open(image) for image in frame_paths]
                    # append_images holds the frames *after* the first one;
                    # passing the full list would show the first frame twice.
                    frames[0].save(os.path.join(gif_dir, f"{filename_list[i]}.gif"), format="GIF", append_images=frames[1:], save_all=True, duration=200, loop=0)


    @torch.no_grad()
    def inference_multi_scene(self):
        """Generate multi-scene videos for each prompt folder under ``config.INFERENCE_DATA_ROOT``.

        Every prompt folder holds one sub-folder per scene, each with a JSON
        layout file; the JSON file name (without extension) is the scene's
        caption and GIF name. All usable scenes of one prompt folder are
        sampled together as one batch with the PLMS sampler (50 steps,
        guidance scale 5). ``config.batch_size`` is overwritten with the
        number of scenes in the current folder.

        Outputs, all rooted at ``config.INFERENCE_DATA_OUTPUT``:
          * ``<root>_frames/<prompt>/<scene idx>/frame<k>.jpg``
          * ``<root>_bbox_frames/<prompt>/<scene idx>/frame<k>.jpg``
          * ``<root>_gifs/<prompt>/<scene idx>/<caption>.gif`` and the same
            under ``<root>_bbox_gifs``.

        With ``config.gpt_alpha`` enabled, scenes whose JSON has no ``score``
        are skipped and the alpha schedule uses the mean clamped score of the
        remaining scenes. A folder with no usable scene is skipped.
        """
        inference_path = self.config.INFERENCE_DATA_ROOT
        n_frames = self.config.n_sample_frames

        model_wo_wrapper = self.model.module if self.config.distributed else self.model
        model_wo_wrapper.eval()
        model_wo_wrapper.to(self.device)
        self.text_encoder.to(self.device)

        inference_files = sorted(glob.glob(inference_path + "/*"))

        range_start = self.config.range_start
        range_end = self.config.range_end
        if range_end == -1:
            range_end = len(inference_files)

        if 'alpha_type' in self.config:
            alpha_type = self.config.alpha_type
        else:
            alpha_type = [0.2, 0, 0.8]

        def build_sampler(stage_fractions):
            # One place to construct the sampler for a given alpha schedule.
            return PLMSSampler(
                self.diffusion,
                model_wo_wrapper,
                alpha_generator_func=partial(alpha_generator, type=stage_fractions),
                set_alpha_scale=set_alpha_scale,
            )

        def frame_order(path):
            # Sort "frame10.jpg" after "frame9.jpg": order by the number in
            # the file stem, falling back to the path for ties.
            stem = os.path.splitext(os.path.basename(path))[0]
            digits = ''.join(ch for ch in stem if ch.isdigit())
            return (int(digits) if digits else 0, path)

        sampler = build_sampler(alpha_type)
        use_gpt_alpha = 'gpt_alpha' in self.config and self.config.gpt_alpha

        frames_root = self.config.INFERENCE_DATA_OUTPUT + '_frames'
        bbox_frames_root = self.config.INFERENCE_DATA_OUTPUT + '_bbox_frames'
        filenames_unique_gifs = self.config.INFERENCE_DATA_OUTPUT + '_gifs'
        filenames_unique_bbox_gifs = self.config.INFERENCE_DATA_OUTPUT + '_bbox_gifs'

        ii = range_start

        while ii < range_end:

            file = inference_files[ii]
            filename = file.split("/")[-1]

            subfiles = glob.glob(f"{file}/*")

            raw_layouts = []
            uc_list = []
            context_list = []
            entity_list = []
            gpt_alpha_list = []
            scene_captions = []  # caption / gif name of every scene actually used
            for subfile in subfiles:
                json_filename = glob.glob(f"{subfile}/*.json")[0]
                with open(json_filename) as f:
                    data = json.load(f)

                if use_gpt_alpha:
                    if 'score' not in data.keys():
                        continue
                    gpt_alpha = max(min(0.3, float(data['score'])), 0)
                    gpt_alpha_list.append(gpt_alpha)
                    print(f"************* gpt_alpha = {gpt_alpha} *************")

                scene_caption = json_filename.split("/")[-1].replace(".json", "")
                raw_layouts += data['layouts']
                batch_caption_ids = self.get_prompt_ids(1*[""]).to(self.device)
                uc = self.text_encoder(batch_caption_ids)[0]
                batch_caption_ids = self.get_prompt_ids(scene_caption).to(self.device)
                context = self.text_encoder( batch_caption_ids )[0]
                context_list.append(context)
                uc_list.append(uc)
                entity_names = [[entity[0] for entity in frame] for frame in data['layouts']]
                entity_list.append(entity_names)
                scene_captions.append(scene_caption)

            # Size the batch by the scenes that were actually loaded; scenes
            # skipped above must not count, or the sample shape and the
            # per-scene output names would not match the conditioning.
            n_scenes = len(scene_captions)
            if n_scenes == 0:
                print(f"no usable scene layouts under {file}, skipping")
                ii += 1
                continue
            self.config.batch_size = n_scenes

            shape = (self.config.batch_size, model_wo_wrapper.in_channels, n_frames,
                        self.config.new_image_size // 8, self.config.new_image_size // 8)

            if use_gpt_alpha:
                gpt_alpha = np.array(gpt_alpha_list).mean()
                if not (gpt_alpha<=0.3 and gpt_alpha>=0):
                    gpt_alpha = 0.2
                alpha_type = [gpt_alpha, 0, 1-gpt_alpha]
                sampler = build_sampler(alpha_type)
                print(f"************* alpha_type = {alpha_type} *************")

            context = torch.concat(context_list)
            uc = torch.concat(uc_list)

            batch = self.pre_process_raw_layouts_for_inference(raw_layouts)
            batch_to_device(batch, self.device)
            batch['position_net_point_or_box'] = self.config.position_net_point_or_box

            grounding_input = self.grounding_tokenizer_input.prepare(batch)
            sampler_input = dict( x=None,
                            timesteps=None,
                            context=context,
                            inpainting_extra_input=None,
                            grounding_extra_input=None,
                            grounding_input=grounding_input )

            samples = sampler.sample(S=50, shape=shape, input=sampler_input, uc=uc, guidance_scale=5)

            autoencoder_wo_wrapper = self.autoencoder # Note itself is without wrapper since we do not train that.

            self.autoencoder.enable_slicing()

            samples = decode_latents(samples, autoencoder_wo_wrapper)
            samples = rearrange(samples, "b c f h w -> (b f) c h w")
            samples = torch.clamp(samples, min=-1, max=1)

            flattened_entity_list = flatten_extend(entity_list)
            gen_images_with_box_drawing = []
            for i in range(len(samples)):
                temp_data = {"image": samples[i], "boxes":batch["boxes"][i], "texts":flattened_entity_list[i]}
                im = vis_getitem_data(out=temp_data, return_tensor=True, print_caption=False)
                gen_images_with_box_drawing.append(im)

            gen_images_with_box_drawing = torch.stack(gen_images_with_box_drawing)
            gen_images_without_box_drawing = 0.5*samples + 0.5

            filename_list = [f"{filename}/{i}" for i in range(self.config.batch_size)]
            # One output directory name per generated frame, in sample order.
            per_frame_names = [name for name in filename_list for _ in range(n_frames)]
            filenames_frames = [os.path.join(frames_root, f) for f in per_frame_names]
            filenames_bbox_frames = [os.path.join(bbox_frames_root, f) for f in per_frame_names]

            filenames_unique_frames = [os.path.join(frames_root, f) for f in filename_list]
            filenames_unique_bbox_frames = [os.path.join(bbox_frames_root, f) for f in filename_list]

            os.makedirs(filenames_unique_gifs, exist_ok = True)
            os.makedirs(filenames_unique_bbox_gifs, exist_ok = True)
            for i in range(len(filenames_unique_frames)):
                os.makedirs(filenames_unique_frames[i], exist_ok = True)
                os.makedirs(filenames_unique_bbox_frames[i], exist_ok = True)

            ##### generate images #####
            for i in range(len(gen_images_with_box_drawing)):
                torchvision.utils.save_image(gen_images_without_box_drawing[i], os.path.join(filenames_frames[i], f"frame{i%self.config.n_sample_frames+1}.jpg"))
                torchvision.utils.save_image(gen_images_with_box_drawing[i], os.path.join(filenames_bbox_frames[i], f"frame{i%self.config.n_sample_frames+1}.jpg"))

            ##### generate gif #####
            for i in range(len(filename_list)):
                gifname = scene_captions[i]
                gif_jobs = (
                    (filenames_unique_frames[i], filenames_unique_gifs),            # without bbox
                    (filenames_unique_bbox_frames[i], filenames_unique_bbox_gifs),  # with bbox
                )
                for frames_dir, gif_root in gif_jobs:
                    frame_paths = sorted(glob.glob(f"{frames_dir}/*.jpg"), key=frame_order)
                    frames = [Image.open(image) for image in frame_paths]
                    os.makedirs(os.path.join(gif_root, f"{filename_list[i]}"), exist_ok=True)
                    # append_images holds the frames *after* the first one;
                    # passing the full list would show the first frame twice.
                    frames[0].save(os.path.join(gif_root, f"{filename_list[i]}/{gifname}.gif"), format="GIF", append_images=frames[1:], save_all=True, duration=200, loop=0)

            ii += 1









