import torch
from ldm.models.diffusion.plms import PLMSSampler
from ldm.util import instantiate_from_config
import numpy as np
from torch.utils.data.distributed import  DistributedSampler
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import os 
import shutil
import torchvision
from convert_ckpt import add_additional_channels
from torch.nn.parallel import DistributedDataParallel as DDP
from tqdm import tqdm
from distributed import get_rank, synchronize
from transformers import get_cosine_schedule_with_warmup, get_constant_schedule_with_warmup
from copy import deepcopy

try:
    from apex import amp
except:
    pass  
# = = = = = = = = = = = = = = = = = = useful functions = = = = = = = = = = = = = = = = = #
from dataset.ai2d_dataset import Ai2D_Dataset
import clip 
from torchvision import transforms
from PIL import Image, ImageDraw
from einops import rearrange
from functools import partial




def alpha_generator(length, type=None):
    """
    Build the per-timestep alpha schedule used during sampling.

    length is the total number of timesteps needed for sampling.
    type should be a list containing three fractions whose sum is 1
    (defaults to [1, 0, 0]). They are the shares of three stages:
    alpha=1 stage
    linear decay stage
    alpha=0 stage.
    For example if length=1000, type=[0.8,0.1,0.1]
    then for the first 800 steps alpha is 1, it then decays linearly down to 0
    over the next 100 steps, and it is 0 for the last 100 steps.

    Returns a list with exactly `length` entries.
    Raises ValueError if type does not hold three values summing to 1.
    """
    if type is None:
        type = [1, 0, 0]
    if len(type) != 3:
        raise ValueError("type must contain exactly three fractions, got %d" % len(type))
    # Compare with a tolerance: fractions such as [0.7, 0.2, 0.1] do not add up
    # to exactly 1.0 in floating point, so an exact equality check rejects them.
    if abs(type[0] + type[1] + type[2] - 1) > 1e-6:
        raise ValueError("the three fractions in type must sum to 1, got %r" % (list(type),))

    stage0_length = int(type[0]*length)
    stage1_length = int(type[1]*length)
    stage2_length = length - stage0_length - stage1_length  # absorbs the rounding of the two int() calls

    if stage1_length != 0:
        # linspace always yields exactly stage1_length values, whereas arange
        # with a fractional step can return one element too many.
        decay_alphas = list(np.linspace(0, 1, stage1_length, endpoint=False)[::-1])
    else:
        decay_alphas = []

    alphas = [1]*stage0_length + decay_alphas + [0]*stage2_length
    assert len(alphas) == length
    return alphas




def set_alpha_scale(model, alpha_scale):
    """Write alpha_scale into the `scale` attribute of every GatedSelfAttentionDense module in model."""
    # Imported inside the function, so the dependency is only resolved when this is called.
    from ldm.modules.attention import GatedSelfAttentionDense

    # Exact type match (subclasses are not touched).
    gated_layers = [layer for layer in model.modules() if type(layer) == GatedSelfAttentionDense]
    for layer in gated_layers:
        layer.scale = alpha_scale





def project(x, projection_matrix):
    """
    Apply the CLIP projection to penultimate features.

    x (Batch*768) should be the penultimate feature of CLIP (before projection).
    projection_matrix (768*768) is the CLIP projection matrix, i.e. the weight.data of the
    Linear layer defined in CLIP, stored as (out_dim, in_dim) - hence the transpose below.
    Returns the CLIP feature (without normalization).
    """
    weight_t = torch.transpose(projection_matrix, 0, 1)
    return torch.matmul(x, weight_t)


def inv_project(y, projection_matrix):
    """
    Undo the CLIP projection.

    y (Batch*768) should be the CLIP feature (after projection).
    projection_matrix (768*768) is the CLIP projection matrix, i.e. the weight.data of the
    Linear layer defined in CLIP, stored as (out_dim, in_dim).
    Returns the CLIP penultimate feature.

    Note: to recover the correct penultimate feature, the input y must not be normalized.
    If it is normalized, the result is scaled by the (unknown) CLIP feature norm.
    """
    inverse = torch.linalg.inv(projection_matrix)
    inverse_t = torch.transpose(inverse, 0, 1)
    return torch.matmul(y, inverse_t)






def read_official_ckpt(ckpt_path):
    """
    Read an official pretrained SD checkpoint and regroup its weights.

    The "state_dict" entry of the checkpoint is split into five dicts, returned under the keys
    "model", "text_encoder", "autoencoder", "unexpected" and "diffusion". Keys of the first
    three groups have their module prefix stripped; the two EMA bookkeeping entries go to
    "unexpected"; everything else goes to "diffusion" unchanged.
    """
    state_dict = torch.load(ckpt_path, map_location="cpu")["state_dict"]

    out = {group: {} for group in ("model", "text_encoder", "autoencoder", "unexpected", "diffusion")}

    # (key prefix, destination group), checked in this order.
    routes = (
        ("model.diffusion_model", "model"),
        ("cond_stage_model", "text_encoder"),
        ("first_stage_model", "autoencoder"),
    )
    ema_bookkeeping = ("model_ema.decay", "model_ema.num_updates")

    for key, value in state_dict.items():
        for prefix, group in routes:
            if key.startswith(prefix):
                out[group][key.replace(prefix + ".", "")] = value
                break
        else:
            # No module prefix matched.
            destination = "unexpected" if key in ema_bookkeeping else "diffusion"
            out[destination][key] = value

    return out


def batch_to_device(batch, device):
    """Move every tensor value of batch to device (in place) and return the same dict; other values are left alone."""
    for key, value in batch.items():
        if not isinstance(value, torch.Tensor):
            continue
        batch[key] = value.to(device)
    return batch


def sub_batch(batch, num=1):
    """Keep only the first num items of every entry in batch (in place); num is raised to 1 when it is not above 1."""
    keep = max(1, num)
    for key, value in batch.items():
        batch[key] = value[0:keep]
    return batch


def wrap_loader(loader):
    """
    Turn a finite loader into an endless stream of batches.

    The loader is iterated over and over again. If it exposes a sampler with a
    `set_epoch` method (as torch's DistributedSampler does), that method is called
    with an increasing epoch number before every pass; without it such a sampler
    reuses the same ordering for every epoch. Loaders without a sampler (or with a
    sampler lacking `set_epoch`) are simply re-iterated.
    """
    sampler = getattr(loader, "sampler", None)
    set_epoch = getattr(sampler, "set_epoch", None)

    epoch = 0
    while True:
        if callable(set_epoch):
            set_epoch(epoch)
        for batch in loader:
            yield batch
        epoch += 1


def disable_grads(model):
    """Freeze model: switch off requires_grad on each of its parameters."""
    for param in model.parameters():
        param.requires_grad_(False)


def count_params(params):
    """Print the total number of elements across the given parameters."""
    total = sum(p.numel() for p in params)
    print("total_trainable_params_count is: ", total)


def update_ema(target_params, source_params, rate=0.99):
    """
    Exponential-moving-average update, applied in place to the targets:
    target <- rate * target + (1 - rate) * source, pairwise over the two sequences.
    """
    for ema_param, live_param in zip(target_params, source_params):
        # detach() shares storage with ema_param, so the in-place ops update it directly.
        ema_param.detach().mul_(rate).add_(live_param, alpha=1 - rate)

           
def create_expt_folder_with_auto_resuming(OUTPUT_ROOT, name):
    """
    Pick the output folder for this run and look for a checkpoint to resume from.

    Runs live in OUTPUT_ROOT/name/tagXX. If OUTPUT_ROOT/name already exists, the existing
    'tag*' folders are scanned from the last (in sorted order) to the first, and the first
    'checkpoint_latest.pth' found is reported for resuming; the new folder is numbered after
    the count of existing tag folders. Otherwise the run goes to 'tag00'.

    Only rank 0 creates the folder (plus its 'Log' subfolder) and a SummaryWriter.

    Returns (run folder path, writer or None, checkpoint path or None).
    """
    expt_root = os.path.join(OUTPUT_ROOT, name)
    checkpoint = None

    if not os.path.exists(expt_root):
        run_dir = os.path.join(expt_root, 'tag00')  # output/name/tag00
    else:
        previous_tags = sorted(
            (entry for entry in os.listdir(expt_root) if entry.startswith('tag')),
            reverse=True,
        )
        for tag in previous_tags:
            candidate = os.path.join(expt_root, tag, 'checkpoint_latest.pth')
            if not os.path.exists(candidate):
                continue
            checkpoint = candidate
            if get_rank() == 0:
                print('auto-resuming ckpt found '+ candidate)
            break
        next_tag = 'tag' + str(len(previous_tags)).zfill(2)
        run_dir = os.path.join(expt_root, next_tag)  # output/name/tagxx

    writer = None
    if get_rank() == 0:
        log_dir = os.path.join(run_dir, 'Log')
        os.makedirs(run_dir)
        os.makedirs(log_dir)
        writer = SummaryWriter(log_dir)

    return run_dir, writer, checkpoint


def vis_getitem_data(index=None, out=None, return_tensor=False, name="res.jpg", print_caption=True):
    """
    Visualize one dataset item by drawing its boxes on its image.

    out is read as a dict with the entries "image", "boxes" and (when print_caption is set)
    "caption". The image is mapped with *0.5+0.5 before conversion to PIL
    (presumably it is normalized to [-1, 1] - confirm against the dataset).
    Each box is (x0, y0, x1, y1) and is scaled by the image width/height before drawing.

    index is accepted for call compatibility but is not used.
    If return_tensor is True the annotated image is returned as a tensor,
    otherwise it is saved to `name` and nothing is returned.
    """
    img = torchvision.transforms.functional.to_pil_image( out["image"]*0.5+0.5 )
    W, H = img.size

    if print_caption:
        print(out["caption"])
        print(" ")

    boxes = []
    for box in out["boxes"]:
        x0, y0, x1, y1 = box
        print(x0, y0, x1, y1)
        boxes.append( [float(x0*W), float(y0*H), float(x1*W), float(y1*H)] )
    img = draw_box(img, boxes)

    if return_tensor:
        return torchvision.transforms.functional.to_tensor(img)
    img.save(name)


def draw_box(img, boxes):
    """
    Outline every box on img (in place) and return img.

    Each box is indexed as [x0, y0, x1, y1]. Outline colors cycle through a fixed palette.
    """
    palette = ["red", "olive", "blue", "green", "orange", "brown", "cyan", "purple"]
    pen = ImageDraw.Draw(img)
    for box_id, box in enumerate(boxes):
        corners = [box[0], box[1], box[2], box[3]]  # x0 y0 x1 y1
        pen.rectangle(corners, outline=palette[box_id % len(palette)], width=4)
    return img

# = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = # 
# = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = # 
# = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = # 






class Diagrammer:
    def __init__(self, config):
        """
        Build everything needed for training / inference from config.

        In order: create the experiment folder (with auto-resume lookup), load CLIP and the
        projection matrix, instantiate the UNet / autoencoder / text encoder / diffusion and
        load the official SD weights into them, optionally load `config.ckpt`, create the
        optimizer, optional EMA copy and LR scheduler, build the training loader, restore the
        auto-resume checkpoint if one was found, and finally wrap the model in DDP when
        `config.distributed` is set.

        Raises ValueError for an unknown `config.scheduler_type`.
        Exits the process (SystemExit) if the resumed checkpoint already reached `config.total_iters`.
        """

        self.config = config
        self.device = torch.device("cuda")

        self.l_simple_weight = 1
        self.name, self.writer, checkpoint = create_expt_folder_with_auto_resuming(config.OUTPUT_ROOT, config.name)
        # Kept on every rank so that checkpoint saving never depends on which rank calls it.
        self.config_dict = vars(config)
        if get_rank() == 0:
            shutil.copyfile(config.yaml_file, os.path.join(self.name, "train_config_file.yaml")  )
            torch.save(  self.config_dict,  os.path.join(self.name, "config_dict.pth")     )


        self.clipmodel, self.preprocess = clip.load("ViT-L/14", device=self.device)
        self.transform_image = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize( mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711] )])
        # NOTE: 'projection_matrix' is resolved relative to the current working directory.
        self.projection_matrix = torch.load('projection_matrix').to(self.device)

        # Despite the attribute names, the placeholder image is WHITE (255, 255, 255).
        self.black_PIL_img = Image.new('RGB',(224,224),(255,255,255))
        self.black_img_tensor = self.preprocess(self.black_PIL_img).unsqueeze(0).to(self.device) # 1, 3, 224, 224
        with torch.no_grad():
            self.black_img_emb = self.clipmodel.encode_image(self.black_img_tensor)


        # = = = = = = = = = = = = = = = = = create model and diffusion = = = = = = = = = = = = = = = = = #
        self.model = instantiate_from_config(config.model).to(self.device)
        self.autoencoder = instantiate_from_config(config.autoencoder).to(self.device)
        self.text_encoder = instantiate_from_config(config.text_encoder).to(self.device)
        self.diffusion = instantiate_from_config(config.diffusion).to(self.device)


        state_dict = read_official_ckpt(  os.path.join(config.sd14_ckpt_path, config.official_ckpt_name)   )

        # modify the input conv for SD if necessary (grounding as unet input; inpaint)
        additional_channels = self.model.additional_channel_from_downsampler
        if self.config.inpaint_mode:
            additional_channels += 5 # 5 = 4(latent) + 1(mask)
        add_additional_channels(state_dict["model"], additional_channels)
        self.input_conv_train = additional_channels > 0

        # load original SD ckpt (the input conv may have been modified above);
        # keys missing from the official ckpt are tolerated, unexpected ones are not.
        _missing_keys, unexpected_keys = self.model.load_state_dict( state_dict["model"], strict=False  )
        assert unexpected_keys == []

        self.autoencoder.load_state_dict( state_dict["autoencoder"]  )
        self.text_encoder.load_state_dict( state_dict["text_encoder"] , strict=False )
        self.diffusion.load_state_dict( state_dict["diffusion"]  )

        # autoencoder and text encoder are frozen
        self.autoencoder.eval()
        self.text_encoder.eval()
        disable_grads(self.autoencoder)
        disable_grads(self.text_encoder)


        # CLIP text embedding of a batch of empty prompts, reused for entity slots without text
        self.blank_list = ['' for i in range(self.config.batch_size)]
        self.blank_prompt_ids = self.get_prompt_ids(self.blank_list).to(self.device)
        with torch.no_grad():
            self.blank_text_embeddings = self.clipmodel.encode_text(self.blank_prompt_ids)


        # = = = = = = = = = = = = = load from ckpt: (usually for inpainting training) = = = = = = = = = = = = = #
        if self.config.ckpt is not None:
            print("")
            print("self.config.ckpt = ", self.config.ckpt)
            first_stage_ckpt = torch.load(self.config.ckpt, map_location="cpu")
            self.model.load_state_dict(first_stage_ckpt["model"])


        # = = = = = = = = = = = = = = = = = create opt = = = = = = = = = = = = = = = = = #
        # every parameter of the model is trainable
        params = list(self.model.parameters())

        self.opt = torch.optim.AdamW(params, lr=config.base_learning_rate, weight_decay=config.weight_decay)
        count_params(params)


        #  = = = = = EMA... It is worse than normal model in early experiments, thus never enabled later = = = = = = = = = #
        if config.enable_ema:
            self.master_params = list(self.model.parameters())
            self.ema = deepcopy(self.model)
            self.ema_params = list(self.ema.parameters())
            self.ema.eval()


        # = = = = = = = = = = = = = = = = = = = = create scheduler = = = = = = = = = = = = = = = = = = = = #
        if config.scheduler_type == "cosine":
            self.scheduler = get_cosine_schedule_with_warmup(self.opt, num_warmup_steps=config.warmup_steps, num_training_steps=config.total_iters)
        elif config.scheduler_type == "constant":
            self.scheduler = get_constant_schedule_with_warmup(self.opt, num_warmup_steps=config.warmup_steps)
        else:
            raise ValueError("unknown scheduler_type %r (expected 'cosine' or 'constant')" % (config.scheduler_type,))


        ##### train #####
        dataset_train = Ai2D_Dataset(ROOT=self.config.DIAGRAM_DATA_ROOT, image_size = (self.config.new_image_size, self.config.new_image_size))
        sampler = DistributedSampler(dataset_train, seed=config.seed) if config.distributed else None
        loader_train = DataLoader( dataset_train,  batch_size=config.batch_size,
                                                    shuffle=(sampler is None),
                                                    num_workers=config.workers,
                                                    pin_memory=True,
                                                    sampler=sampler)
        self.dataset_train = dataset_train
        self.loader_train = wrap_loader(loader_train)

        if get_rank() == 0:
            total_image = dataset_train.total_images()
            print("Total training images: ", total_image)


        # = = = = = = = = = = = = = = = = = = = = load from autoresuming ckpt = = = = = = = = = = = = = = = = = = = = #
        self.starting_iter = 0
        if checkpoint is not None:
            resume_state = torch.load(checkpoint, map_location="cpu")
            self.model.load_state_dict(resume_state["model"])
            if config.enable_ema:
                self.ema.load_state_dict(resume_state["ema"])
            self.opt.load_state_dict(resume_state["opt"])
            self.scheduler.load_state_dict(resume_state["scheduler"])
            self.starting_iter = resume_state["iters"]
            if self.starting_iter >= config.total_iters:
                synchronize()
                print("Training finished. Start exiting")
                # same effect as exit(), without relying on the site-provided builtin
                raise SystemExit
            print("########## auto resumed from ckpt ###########")
            print("save every iter: ", self.config.save_every_iters)

        # = = = = = = = = = = = = = = = = = = = = misc and ddp = = = = = = = = = = = = = = = = = = = =#

        # func return input for grounding tokenizer
        self.grounding_tokenizer_input = instantiate_from_config(config.grounding_tokenizer_input)
        self.model.grounding_tokenizer_input = self.grounding_tokenizer_input

        # func return input for grounding downsampler
        self.grounding_downsampler_input = None
        if 'grounding_downsampler_input' in config:
            self.grounding_downsampler_input = instantiate_from_config(config.grounding_downsampler_input)

        if config.distributed:
            self.model = DDP( self.model, device_ids=[config.local_rank], output_device=config.local_rank, broadcast_buffers=False )





    @torch.no_grad()
    def get_prompt_ids(self, prompt):
        """Tokenize prompt with clip.tokenize and return the resulting tokens."""
        return clip.tokenize(prompt)
    
    
    
    @torch.no_grad()
    def get_input(self, batch):
        if len(batch["image"].shape) == 5:
            batch["image"] = batch["image"].squeeze(1)
        
        z = self.autoencoder.encode( batch["image"] )

        context = self.text_encoder.encode( batch["caption"]  )

        _t = torch.rand(z.shape[0]).to(z.device)
        t = (torch.pow(_t, 1) * 1000).long()
        t = torch.where(t!=1000, t, 999) # if 1000, then replace it with 999
        
        inpainting_extra_input = None
        grounding_extra_input = None
        if self.grounding_downsampler_input != None:
            grounding_extra_input = self.grounding_downsampler_input.prepare(batch)

        return z, t, context, inpainting_extra_input, grounding_extra_input 


    def run_one_step(self, batch):
        x_start, t, context, inpainting_extra_input, grounding_extra_input = self.get_input(batch)
        noise = torch.randn_like(x_start)
        x_noisy = self.diffusion.q_sample(x_start=x_start, t=t, noise=noise)

        grounding_input = self.grounding_tokenizer_input.prepare(batch)
        input = dict(x=x_noisy, 
                    timesteps=t, 
                    context=context, 
                    inpainting_extra_input=inpainting_extra_input,
                    grounding_extra_input=grounding_extra_input,
                    grounding_input=grounding_input)
        model_output = self.model(input)
        
        loss = torch.nn.functional.mse_loss(model_output, noise) * self.l_simple_weight

        self.loss_dict = {"loss": loss.item()}

        return loss 
        



    def process_text_embedding(self, batch, name):
        batch_tokens_positive_text_embeddings_list = []
        for i in range(30):
            if (tuple(self.blank_list) == batch[name][i]) or (list(self.blank_list) == batch[name][i]):
                #print("yes")
                batch_tokens_positive_text_embeddings = self.blank_text_embeddings
            else:
                token_ids = self.get_prompt_ids(batch[name][i]).to(self.device)
                with torch.no_grad():
                    batch_tokens_positive_text_embeddings = self.clipmodel.encode_text(token_ids)
            batch_tokens_positive_text_embeddings_list.append(batch_tokens_positive_text_embeddings)
        
        return torch.stack(batch_tokens_positive_text_embeddings_list).permute(1, 0, 2)

    

    def process_image_embedding(self, batch, n_bf, b, f, o, box_name='boxes', mask_name='masks', bbox_format='xywh'):
        """
        Encode the image region under each valid box with CLIP and map it into text space.

        For each of the n_bf images, the first n boxes of batch[box_name] (n = sum of the
        image's mask, capped at 30) are cropped out of batch['image'], resized to 224x224 and
        normalized with self.transform_image. Crops sitting at non-zero mask positions are
        encoded with self.clipmodel in mini-batches; every other slot keeps
        self.black_img_emb. The embeddings are then passed through project() and
        inv_project() with self.projection_matrix.

        bbox_format is 'xywh' or 'xyxy'; box values are multiplied by the image width/height.
        Raises ValueError for any other bbox_format.

        NOTE(review): crops are written to slots 0..n-1 while embeddings are scattered to the
        non-zero mask positions, so this presumably relies on each mask being a contiguous
        prefix of ones - confirm against the dataset.
        NOTE(review): batch['image'][i][0] is read for i < n_bf, which looks like it assumes
        f == 1 - confirm.
        """
        if bbox_format not in ('xywh', 'xyxy'):
            raise ValueError("bbox_format must be 'xywh' or 'xyxy', got %r" % (bbox_format,))

        cropped_images = torch.zeros(n_bf, 30, 3, 224, 224)

        for i in range(n_bf):
            obj_boxes = batch[box_name][i]
            n_objs = min(int(sum(batch[mask_name][i])), 30)

            # Clamp BEFORE the uint8 cast: once cast, out-of-range values have already
            # wrapped around and can no longer be detected.
            pixels = ((batch['image'][i][0]/2+0.5)*255).clamp(0, 255).to(torch.uint8)
            pil_img_array = pixels.permute(1, 2, 0).cpu().numpy()
            pil_img_idx = Image.fromarray(pil_img_array, 'RGB')
            width, height = pil_img_idx.size

            for obj_idx in range(n_objs):
                box = obj_boxes[obj_idx]
                left = int(box[0] * width)
                top = int(box[1] * height)
                if bbox_format == 'xywh':
                    right = int((box[0]+box[2]) * width)
                    bottom = int((box[1]+box[3]) * height)
                else:  # 'xyxy'
                    right = int(box[2] * width)
                    bottom = int(box[3] * height)

                if left < right and top < bottom:
                    pil_img_idx_cropped = pil_img_idx.crop((left, top, right, bottom))
                    im_cropped = pil_img_idx_cropped.resize((224, 224), Image.BICUBIC)
                else:
                    # degenerate box: fall back to PIL's default (black) image
                    im_cropped = Image.new(mode="RGB", size=(224, 224))

                # cropped_images lives on the CPU; crops are moved to the device per mini-batch below
                cropped_images[i][obj_idx] = self.transform_image(im_cropped) # 3, 224, 224

        cropped_images = rearrange(cropped_images, "bf o c h w -> (bf o) c h w")
        flattened_mask = rearrange(batch[mask_name], "bf o -> (bf o)")
        # flatten() (not squeeze()) keeps the index 1-D even when exactly one mask entry is set
        nonzero_indices = torch.nonzero(flattened_mask).flatten().to(cropped_images.device)
        selected_cropped_images = torch.index_select(cropped_images, 0, nonzero_indices)
        cropped_image_embeddings = self.black_img_emb.repeat(n_bf * 30, 1)

        mini_batch_size = 20
        if selected_cropped_images.shape[0] > 0:
            encoded_chunks = []
            with torch.no_grad():
                for start_idx in range(0, len(selected_cropped_images), mini_batch_size):
                    chunk = selected_cropped_images[start_idx:(start_idx+mini_batch_size)].to(self.device)
                    encoded_chunks.append(self.clipmodel.encode_image(chunk))
            scatter_indices = nonzero_indices.to(cropped_image_embeddings.device)
            cropped_image_embeddings[scatter_indices] = torch.cat(encoded_chunks)

        image_embeddings = rearrange(cropped_image_embeddings, "(b f o) d -> (b f) o d", o=o, b=b, f=f)
        CLIP_feature = project(image_embeddings, self.projection_matrix.half())
        image_embeddings_in_text_space = inv_project(CLIP_feature.float(), self.projection_matrix)

        return image_embeddings_in_text_space
            
    



    def pre_process_batch_for_img_imgtextemd_data(self, batch):
        
        if type(batch['image']) is not list:
            if len(batch['image'].shape) == 4:
                # change from (b c h w) to (b 1 c h w)
                batch['image'] = batch['image'].unsqueeze(1)

            n_bf = batch['image'].shape[0] * batch['image'].shape[1]
            b, f, o = batch['image'].shape[0], batch['image'].shape[1], 30
        else:
            n_bf, b, f, o = 1, 1, 1, 30
            
        if batch['use_text_embedding']:
            batch['text_embeddings'] = self.process_text_embedding(batch, name='entity_image_tokens_positive')
            del batch['entity_image_tokens_positive']
            
        if batch['use_arrow_embedding']:
            batch['arrow_embeddings'] = self.process_text_embedding(batch, name='entity_arrow_tokens_positive')
            
        if batch['use_text_crop_embedding']:
            batch['text_crop_embeddings'] = self.process_image_embedding(batch, n_bf, b, f, o, box_name='entity_text_boxes', mask_name='entity_text_masks', bbox_format='xywh')
        
        if 'grounding_input' in batch.keys():
            del batch['grounding_input']
        if 'text_embeddings' in batch.keys() and not "use_text_embedding" in batch.keys():
            del batch['text_embeddings']
            
        batch['position_net_point_or_box'] = 'box'
        
        if 'text_crop_embeddings' in batch:
            batch['text_embeddings'] = torch.concat([batch['arrow_embeddings'], batch['text_embeddings'], batch['text_crop_embeddings']], axis = 1)
            batch['masks'] = torch.concat([batch['entity_arrow_masks'], batch['entity_image_masks'], batch['entity_text_masks']], axis = 1)
            batch['boxes'] = torch.concat([batch['entity_arrow_boxes'], batch['entity_image_boxes'], batch['entity_text_boxes']], axis = 1)
            del batch['text_crop_embeddings'], batch['arrow_embeddings']
        elif batch['use_arrow_embedding'] == True:
            batch['text_embeddings'] = torch.concat([batch['arrow_embeddings'], batch['text_embeddings']], axis = 1)
            batch['masks'] = torch.concat([batch['entity_arrow_masks'], batch['entity_image_masks']], axis = 1)
            batch['boxes'] = torch.concat([batch['entity_arrow_boxes'], batch['entity_image_boxes']], axis = 1)
            del batch['arrow_embeddings']
        else:
            batch['masks'] = batch['entity_image_masks']
            batch['boxes'] = batch['entity_image_boxes']
            
        
        batch['boxes'][:,:,2] += batch['boxes'][:,:,0]
        batch['boxes'][:,:,3] += batch['boxes'][:,:,1]
        batch['boxes'][batch['boxes']>1.0] = 1.0 
        batch['boxes'][batch['boxes']<0.0] = 0.0

            
        del batch['entity_image_boxes'], batch['entity_image_masks'] 
        del batch['entity_text_boxes'], batch['entity_text_masks'] 
        del batch['entity_arrow_tokens_positive'], batch['entity_arrow_boxes'], batch['entity_arrow_masks']
        

        return batch
    
    
    

    def start_training(self):
        """
        Run the training loop from self.starting_iter up to config.total_iters.

        Each iteration fetches and pre-processes a batch (with all three embedding kinds
        enabled), runs one optimization step, steps the scheduler and, if enabled, the EMA.
        Rank 0 logs the loss every 10 iterations and saves a checkpoint on the first
        iteration, every `save_every_iters` iterations and on the last one.
        The process exits (SystemExit) when training is done.

        A batch that fails to load / pre-process is reported and skipped. Training aborts
        with the last error after too many consecutive failures, and with RuntimeError if
        the batch iterator stops producing data altogether.
        """
        # Upper bound on back-to-back bad batches before giving up instead of spinning forever.
        max_consecutive_failures = 100

        iterator = tqdm(range(self.starting_iter, self.config.total_iters), desc='Training progress',  disable=get_rank() != 0 )
        self.model.train()
        for iter_idx in iterator: # note: iter_idx is not from 0 if resume training
            self.iter_idx = iter_idx
            self.opt.zero_grad()

            failures = 0
            while True:
                try:
                    batch = next(self.loader_train)
                    batch['use_text_embedding'] = True
                    batch['use_text_crop_embedding'] = True
                    batch['use_arrow_embedding'] = True
                    batch = self.pre_process_batch_for_img_imgtextemd_data(batch)
                    break
                except StopIteration as err:
                    # The endless generator only stops after an error escaped from it;
                    # retrying would loop forever without ever getting data.
                    raise RuntimeError("training data iterator stopped producing batches") from err
                except Exception as err:
                    failures += 1
                    print(f"[rank {get_rank()}] skipping a batch that failed to load/pre-process "
                          f"({failures}/{max_consecutive_failures}): {err!r}")
                    if failures >= max_consecutive_failures:
                        raise

            batch_to_device(batch, self.device)

            loss = self.run_one_step(batch)
            loss.backward()
            self.opt.step()
            self.scheduler.step()
            if self.config.enable_ema:
                update_ema(self.ema_params, self.master_params, self.config.ema_rate)

            if (get_rank() == 0):
                if (iter_idx % 10 == 0):
                    self.log_loss()
                if (iter_idx == 0)  or  ( iter_idx % self.config.save_every_iters == 0 )  or  (iter_idx == self.config.total_iters-1):
                    self.save_ckpt_and_result()
            synchronize()

        synchronize()
        print("Training finished. Start exiting")
        # same effect as exit(), without relying on the site-provided builtin
        raise SystemExit


    def log_loss(self):
        """Send every entry of self.loss_dict to the summary writer, tagged with the 1-based iteration."""
        step = self.iter_idx + 1  # we add 1 as the actual name
        for tag, value in self.loss_dict.items():
            self.writer.add_scalar(tag, value, step)
    

    @torch.no_grad()
    def save_ckpt_and_result(self):

        model_wo_wrapper = self.model.module if self.config.distributed else self.model
        iter_name = self.iter_idx + 1     # we add 1 as the actual name
       
        ckpt = dict(model = model_wo_wrapper.state_dict(),
                    text_encoder = self.text_encoder.state_dict(),
                    autoencoder = self.autoencoder.state_dict(),
                    diffusion = self.diffusion.state_dict(),
                    opt = self.opt.state_dict(),
                    scheduler= self.scheduler.state_dict(),
                    iters = self.iter_idx+1,
                    config_dict=self.config_dict,
        )
        if self.config.enable_ema:
            ckpt["ema"] = self.ema.state_dict()
        torch.save( ckpt, os.path.join(self.name, "checkpoint_"+str(iter_name).zfill(8)+".pth") )
        torch.save( ckpt, os.path.join(self.name, "checkpoint_latest.pth") )




    @torch.no_grad()
    def start_inference(self, output_folder, customized):
        """
        Generate one image per test item and save it under OUTPUT_ROOT/output_folder.

        The test split of Ai2D_Dataset is iterated once with batch size 1. For each item
        only the entity text embeddings are used as grounding (arrow and text-crop
        embeddings are disabled), the caption gets ", white background" appended, and an
        image is sampled with PLMS (50 steps, guidance scale 5), decoded, clamped to [-1, 1]
        and written as '<id>.png'.

        The model is switched to eval mode for sampling and is left in that mode.
        """
        model_wo_wrapper = self.model.module if self.config.distributed else self.model
        # A freshly built module is in training mode; sampling should run in eval mode.
        model_wo_wrapper.eval()

        ##### dataset test #####
        self.num_images_val = 1
        dataset_test = Ai2D_Dataset(ROOT=self.config.DIAGRAM_DATA_ROOT, test=True, customized=customized)
        self.num_images_test = self.config.num_images_val
        loader_test = DataLoader( dataset_test,  batch_size=self.num_images_val, shuffle=False, num_workers=self.config.workers, pin_memory=True, sampler=None)
        self.dataset_test = dataset_test
        self.loader_test = wrap_loader(loader_test)
        batch_here = self.num_images_val

        if 'alpha_type' in self.config:
            alpha_type = self.config.alpha_type
        else:
            alpha_type = [1.0, 0, 0.0]

        print(f"************* alpha_type = {alpha_type} *************")

        alpha_generator_func = partial(alpha_generator, type=alpha_type)
        plms_sampler = PLMSSampler(self.diffusion, model_wo_wrapper, alpha_generator_func=alpha_generator_func, set_alpha_scale=set_alpha_scale)

        shape = (batch_here, model_wo_wrapper.in_channels, model_wo_wrapper.image_size, model_wo_wrapper.image_size)

        # no extra inputs for inpainting / grounding downsampler at inference time
        inpainting_extra_input = None
        grounding_extra_input = None

        # Loop-invariant work: output folder and the unconditional (empty prompt) embedding.
        save_path = os.path.join(self.config.OUTPUT_ROOT, output_folder)
        os.makedirs(save_path, exist_ok=True)
        uc = self.text_encoder.encode( batch_here*[""]  )

        autoencoder_wo_wrapper = self.autoencoder # Note itself is without wrapper since we do not train that.

        for _ in tqdm(range(len(self.dataset_test))):

            batch = next(self.loader_test)
            batch['use_text_crop_embedding'] = False
            batch['use_text_embedding'] = True
            batch['use_arrow_embedding'] = False

            batch = self.pre_process_batch_for_img_imgtextemd_data(batch)
            batch_to_device(batch, self.device)

            batch_caption_with_background = [batch["caption"][i] + ", white background"  for i in range(batch_here)]
            context = self.text_encoder.encode( batch_caption_with_background  )

            grounding_input = self.grounding_tokenizer_input.prepare(batch)
            sampler_input = dict( x=None,
                                  timesteps=None,
                                  context=context,
                                  inpainting_extra_input=inpainting_extra_input,
                                  grounding_extra_input=grounding_extra_input,
                                  grounding_input=grounding_input )

            samples = plms_sampler.sample(S=50, shape=shape, input=sampler_input, uc=uc, guidance_scale=5)

            samples = autoencoder_wo_wrapper.decode(samples).cpu()
            samples = torch.clamp(samples, min=-1, max=1)

            torchvision.utils.save_image(samples[0], os.path.join(save_path, batch['id'][0]+".png"))
            

          
