import torch
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
from ldm.util import instantiate_from_config
from ldm.models.diffusion.gaussian_smoothing import GaussianSmoothing
import numpy as np
import random
from torch.utils.tensorboard import SummaryWriter
import os 
import shutil
import torchvision
from convert_ckpt import add_additional_channels
import math
from torch.nn.parallel import DistributedDataParallel as DDP
from tqdm import tqdm
from distributed import get_rank, synchronize, get_world_size
from transformers import get_cosine_schedule_with_warmup, get_constant_schedule_with_warmup
from copy import deepcopy
from inpaint_mask_func import draw_masks_from_boxes
from dataset.dataset_my_quick import Test10
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from transformers import CLIPImageProcessor, CLIPModel, CLIPTextModel, CLIPTokenizer
try:
    from apex import amp
except:
    pass  
import random
import numpy as np

# Seed every random-number source used by this script so runs are repeatable.
# Python's built-in RNG:
random.seed(42)
# NumPy's global RNG:
np.random.seed(42)
# PyTorch's CPU RNG:
torch.manual_seed(42)
# PyTorch's RNG for the current CUDA device (safe to call without CUDA):
torch.cuda.manual_seed(42)


# = = = = = = = = = = = = = = = = = = useful functions = = = = = = = = = = = = = = = = = #
class LLMResize(nn.Module):
    """Linear projection from an LLM hidden width to the SD text-embedding width.

    Args:
        llm_size: width of the incoming features (last dimension of the input).
        sd_text_size: width of the projected output.
    """

    def __init__(self, llm_size=6656, sd_text_size=768):
        super().__init__()
        self.fc = nn.Linear(llm_size, sd_text_size)

    def forward(self, x):
        """Map the last dimension of ``x`` from ``llm_size`` to ``sd_text_size``."""
        projected = self.fc(x)
        return projected


class ImageCaptionSaver:
    """Write generated/reference image grids and their captions into ``base_path``.

    Each call writes ``<seen>.png`` (generated samples), ``<seen>_real.png``
    (reference images), optionally ``<seen>_mased_real.png`` (masked inputs,
    inpainting mode only), and appends the captions to ``captions.txt``.
    ``<seen>`` is the step counter zero-padded to 8 digits.

    Args:
        base_path: directory the files are written into (must already exist).
        nrow: number of images per row in each saved grid.
        normalize: whether generated/masked grids are rescaled using ``range``.
        scale_each: whether each image is rescaled separately.
        range: expected value range of the generated/masked tensors.
    """

    def __init__(self, base_path, nrow=8, normalize=True, scale_each=True, range=(-1,1) ):
        self.base_path = base_path
        self.nrow = nrow
        self.normalize = normalize
        self.scale_each = scale_each
        self.range = range

    def _grid_kwargs(self):
        """Build keyword arguments for the normalised grids.

        Newer torchvision renamed the grid keyword ``range`` to ``value_range``
        and removed the old name, so the keyword is picked from the signature
        of the installed ``make_grid``.
        """
        import inspect

        params = inspect.signature(torchvision.utils.make_grid).parameters
        range_key = "value_range" if "value_range" in params else "range"
        return {
            "nrow": self.nrow,
            "normalize": self.normalize,
            "scale_each": self.scale_each,
            range_key: self.range,
        }

    def __call__(self, images, real, masked_real, captions, seen):
        """Save the grids for step ``seen`` and append its captions.

        Args:
            images: generated images in ``self.range``; one per caption.
            real: reference images, saved without normalisation (so presumably
                already in [0, 1] — TODO confirm with callers).
            masked_real: masked input images (inpainting mode) or ``None``.
            captions: one caption string per generated image.
            seen: step counter used as the file-name prefix.

        Raises:
            AssertionError: if the number of captions differs from the number of
                generated images. This is checked before anything is written.
        """
        assert images.shape[0] == len(captions)

        prefix = str(seen).zfill(8)
        grid_kwargs = self._grid_kwargs()

        torchvision.utils.save_image(images, os.path.join(self.base_path, prefix + '.png'), **grid_kwargs)
        torchvision.utils.save_image(real, os.path.join(self.base_path, prefix + '_real.png'), nrow=self.nrow)

        if masked_real is not None:
            # Only produced in inpainting mode. The '_mased_real' file name is kept
            # unchanged for compatibility with existing outputs.
            torchvision.utils.save_image(masked_real, os.path.join(self.base_path, prefix + '_mased_real.png'), **grid_kwargs)

        with open(os.path.join(self.base_path, 'captions.txt'), "a") as f:
            f.write(prefix + ':\n')
            f.writelines(cap + '\n' for cap in captions)
            f.write('\n')

def set_requires_grad(model, value):
    """Toggle gradient tracking for every parameter of ``model``.

    Args:
        model: any ``nn.Module``.
        value: ``True`` to make the parameters trainable, ``False`` to freeze them.
    """
    weights = model.parameters()
    for weight in weights:
        weight.requires_grad = value


def read_official_ckpt(ckpt_path):
    """Load a converted SD checkpoint and add an empty ``"unexpected"`` entry.

    The file is expected (not checked here) to hold a dict with the
    ``"model"``, ``"autoencoder"``, ``"text_encoder"`` and ``"diffusion"`` state
    dicts that ``Trainer.__init__`` reads from the result.

    Tensors are mapped to CPU, like the other ``torch.load`` calls in this file.
    Every consumer copies them into modules with ``load_state_dict``, so there
    is no reason to materialise them on the GPU they were saved from.

    NOTE: ``torch.load`` unpickles the file; only use it with trusted checkpoints.

    Args:
        ckpt_path: path to a file readable by ``torch.load``.

    Returns:
        The loaded checkpoint dict, with ``"unexpected"`` set to ``{}``.
    """
    saved_ckpt = torch.load(ckpt_path, map_location="cpu")
    saved_ckpt["unexpected"] = {}
    return saved_ckpt


def batch_to_device(batch, device):
    """Move every tensor value of ``batch`` to ``device``, in place.

    Non-tensor entries (e.g. caption strings) are left untouched. The same dict
    object is returned for convenience.
    """
    tensor_keys = [key for key, value in batch.items() if isinstance(value, torch.Tensor)]
    for key in tensor_keys:
        batch[key] = batch[key].to(device)
    return batch


def sub_batch(batch, num=1):
    """Truncate every value of ``batch`` to its first ``num`` items, in place.

    Values of ``num`` below 1 are treated as 1. Every value must support slicing.
    Returns the same dict object.
    """
    keep = max(num, 1)
    for key in batch:
        batch[key] = batch[key][:keep]
    return batch


def wrap_loader(loader):
    """Yield items from ``loader`` forever, restarting it each time it is exhausted.

    Every restart iterates ``loader`` afresh, so a ``DataLoader`` created with
    ``shuffle=True`` draws a new ordering on each pass.
    """
    while True:
        yield from loader


def disable_grads(model):
    """Freeze ``model`` by marking all of its parameters as not requiring gradients."""
    for weight in model.parameters():
        weight.requires_grad_(False)


def count_params(params):
    """Print the total number of elements across ``params``.

    Args:
        params: iterable of tensors (here, the optimiser's parameter list).
    """
    total = sum(p.numel() for p in params)
    print("total_trainable_params_count is: ", total)


def update_ema(target_params, source_params, rate=0.99):
    """In-place exponential moving average: ``target = rate * target + (1 - rate) * source``.

    Parameters are paired positionally; pairing stops at the shorter sequence.
    """
    source_weight = 1 - rate
    for ema_param, live_param in zip(target_params, source_params):
        ema_param.detach().mul_(rate).add_(live_param, alpha=source_weight)

           
def create_expt_folder_with_auto_resuming(OUTPUT_ROOT, name):
    """Create a fresh ``OUTPUT_ROOT/name/tagXX`` run folder and find a checkpoint to resume from.

    Existing sub-folders whose names start with ``tag`` are scanned from the
    lexicographically last to the first. The first one containing
    ``checkpoint_latest.pth`` is returned for auto-resuming. The new run always
    gets its own, not-yet-existing tag folder.

    Only rank 0 touches the filesystem (creates the folders and the TensorBoard
    ``SummaryWriter``); other ranks get ``writer=None``.

    Returns:
        (run_folder, writer, checkpoint_path_or_None)
    """
    name = os.path.join(OUTPUT_ROOT, name)
    writer = None
    checkpoint = None

    if os.path.exists(name):
        existing_tags = sorted(
            (tag for tag in os.listdir(name) if tag.startswith('tag')),
            reverse=True,
        )
        for previous_tag in existing_tags:
            potential_ckpt = os.path.join(name, previous_tag, 'checkpoint_latest.pth')
            if os.path.exists(potential_ckpt):
                checkpoint = potential_ckpt
                if get_rank() == 0:
                    print('auto-resuming ckpt found ' + potential_ckpt)
                break
        # Start at the number of existing tags (tag00..tagNN -> tag(N+1)). Probe
        # forward when that name is already taken, e.g. tag00 and tag02 exist
        # after tag01 was deleted; otherwise os.makedirs below would fail.
        tag_idx = len(existing_tags)
        while os.path.exists(os.path.join(name, 'tag' + str(tag_idx).zfill(2))):
            tag_idx += 1
        name = os.path.join(name, 'tag' + str(tag_idx).zfill(2))  # output/name/tagxx
    else:
        name = os.path.join(name, 'tag00')  # output/name/tag00

    if get_rank() == 0:
        os.makedirs(name)
        os.makedirs(os.path.join(name, 'Log'))
        writer = SummaryWriter(os.path.join(name, 'Log'))

    return name, writer, checkpoint

# = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = # 
# = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = # 
# = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = # 
def remove_numbers(text):
    """Return ``text`` without the characters for which ``str.isdigit()`` is true."""
    return ''.join(filter(lambda ch: not ch.isdigit(), text))

def process_box_phrase(names, bboxes):
    """Group boxes by the individual words of their phrases.

    Each phrase has ``_`` replaced by spaces and is split on single spaces.
    Every resulting word, with digits removed, maps to the boxes whose phrase
    contained it. Each occurrence appends a fresh ``np.array`` copy, in input order.

    Args:
        names: phrase strings, parallel to ``bboxes``.
        bboxes: box coordinates indexed like ``names``.

    Returns:
        dict word -> list of ``np.ndarray`` boxes, with keys in first-seen order.
    """
    groups = {}
    for phrase_idx, phrase in enumerate(names):
        words = phrase.replace('_', ' ').split(' ')
        for word in words:
            key = ''.join(ch for ch in word if not ch.isdigit())
            groups.setdefault(key, []).append(np.array(bboxes[phrase_idx]))
    return groups

def Pharse2idx_2(prompt, name_box):
    """Locate each object name of ``name_box`` among the words of ``prompt``.

    ``prompt`` has '.' and ',' removed and is split on single spaces. For each
    space-separated word of an object key, the first occurrence of the word
    itself is looked up, else of ``word + 's'``, else of ``word + 'es'``. The
    recorded position is that 0-based word index plus one. This presumably
    skips a start token and assumes one token per word — TODO confirm.

    Returns:
        (object_positions, bbox_to_self_att): parallel lists holding, for each
        key that matched at least one word, its list of positions and
        ``np.array(name_box[key])``.
    """
    words = prompt.replace('.', '').replace(',', '').split(' ')
    object_positions = []
    bbox_to_self_att = []
    for obj, boxes in name_box.items():
        positions = []
        for word in obj.split(' '):
            for candidate in (word, word + 's', word + 'es'):
                if candidate in words:
                    positions.append(words.index(candidate) + 1)
                    break
        if positions:
            bbox_to_self_att.append(np.array(boxes))
            object_positions.append(positions)
    return object_positions, bbox_to_self_att


def get_all_attention(attn_maps_mid, attn_maps_up , attn_maps_down, res):
    """Average every collected attention map whose spatial size is ``res x res``.

    Each map is a 3-D tensor; its second dimension is the number of pixels, and
    a map is kept when ``int(sqrt(pixels)) == res``. Nesting as read here:

    * ``attn_maps_up`` / ``attn_maps_down``: the map is ``item[0][0]``; items
      equal to ``[]`` are skipped, and in the down list so is a ``[0][0]``
      equal to ``[]``.
    * ``attn_maps_mid``: the map is ``item[0]``.

    Returns:
        Tensor of shape (res, res, tokens): the mean over all kept maps and
        their leading dimension. ``torch.cat`` raises if nothing matches.
    """
    def _keep_if_matching(map_tensor, bucket):
        _, n_pixels, n_tokens = map_tensor.shape
        if int(math.sqrt(n_pixels)) == res:
            bucket.append(map_tensor.reshape(-1, res, res, n_tokens))

    matching = []
    for entry in attn_maps_up:
        if entry == []:
            continue
        _keep_if_matching(entry[0][0], matching)
    for entry in attn_maps_mid:
        _keep_if_matching(entry[0], matching)
    for entry in attn_maps_down:
        if entry == []:
            continue
        map_tensor = entry[0][0]
        if map_tensor == []:
            continue
        _keep_if_matching(map_tensor, matching)

    stacked = torch.cat(matching, dim=0)
    return stacked.sum(0) / stacked.shape[0]


def attention_loss(attn_maps_mid, attn_maps_up, attn_maps_down, bboxes, object_positions, box_att , t, res=16, smooth_att = True,sigma=0.5,kernel_size=3 ):
    """Box-grounding loss on the averaged ``res x res`` cross-attention maps.

    For every sample ``obj_idx`` and every grounded token position of that
    sample, the token's attention map (optionally Gaussian-smoothed) is scored
    against all boxes of the sample:

    * Boxes listed for the token in ``box_att`` (its "own" boxes) add a penalty
      on ``1 - peak attention inside the box``. The weight is 6 when the peak is
      below 0.1, 1 when below 0.2, and 1 otherwise only while ``t > 600``.
    * Every other box adds a penalty on the peak attention inside it. The
      weight is 4 above 0.15, 1 above 0.1, and 1 otherwise only while
      ``t > 600``. Own boxes earlier in the box list are zeroed out first.
    * The peak attention left outside all boxes adds ``len(boxes) * peak / 2``.

    Args:
        attn_maps_mid, attn_maps_up, attn_maps_down: attention maps in the
            nested layout read by ``get_all_attention``.
        bboxes: per sample, normalised boxes ``(x0, y0, x1, y1)``.
        object_positions: per sample, token positions into the full token axis
            of the attention maps; 0 means "no token" and is skipped.
        box_att: per sample and per token position, the boxes owned by that
            token. A box with ``x0 == 0 and x1 == 0`` ends the list (presumably
            padding).
        t: per-sample diffusion timesteps.
        res: spatial resolution of the attention maps to use.
        smooth_att, sigma, kernel_size: optional Gaussian smoothing of each map.

    Returns:
        (loss averaged over samples, smallest inside-peak seen (1000 if none),
        largest outside-peak seen (0 if none)).
    """
    attn = get_all_attention(attn_maps_mid, attn_maps_up, attn_maps_down, res)
    # Build helpers on the device the attention lives on instead of assuming
    # "cuda", which is not necessarily this rank's device.
    device = attn.device
    obj_number = len(bboxes)
    total_loss = 0

    # Drop the first and last token columns, sharpen by 100 and renormalise over
    # the remaining tokens. Computed out of place so ``attn`` is not mutated.
    attn_text = torch.nn.functional.softmax(attn[:, :, 1:-1] * 100, dim=-1)
    H = W = attn.shape[0]

    min_all_inside = 1000
    max_outside = 0

    # The smoothing kernel is identical for every token, so build it once. Pad by
    # kernel_size // 2 so an odd kernel keeps the map at res x res (a 1-pixel pad
    # for the default kernel_size=3). This assumes GaussianSmoothing applies an
    # unpadded convolution — TODO confirm.
    smoothing = None
    pad = kernel_size // 2
    if smooth_att:
        smoothing = GaussianSmoothing(channels=1, kernel_size=kernel_size, sigma=sigma, dim=2).to(device)

    for obj_idx in range(obj_number):  # one entry per sample in the batch
        for pos_idx, obj_position in enumerate(object_positions[obj_idx]):  # one entry per grounded token
            if obj_position == 0:
                continue
            # attn_text has already dropped token 0, so shift the position by one.
            att_map_obj = attn_text[:, :, obj_position - 1]
            if smoothing is not None:
                padded = F.pad(att_map_obj.unsqueeze(0).unsqueeze(0), (pad, pad, pad, pad), mode='reflect')
                att_map_obj = smoothing(padded).squeeze(0).squeeze(0)
            # Detached copies: 'other' has own boxes zeroed (used for the
            # outside-box peaks); 'att_copy' has every box zeroed (background).
            other_att_map_obj = att_map_obj.detach().clone()
            att_copy = att_map_obj.detach().clone()

            for cur_bbox in bboxes[obj_idx]:
                # Does this box belong to the current token?
                flag = False
                for obj_box in box_att[obj_idx][pos_idx]:
                    if obj_box[0] == 0 and obj_box[2] == 0:
                        break
                    if cur_bbox[0] == obj_box[0] and cur_bbox[1] == obj_box[1] and cur_bbox[2] == obj_box[2] and cur_bbox[3] == obj_box[3]:
                        flag = True
                        break

                x_min, y_min, x_max, y_max = int(cur_bbox[0] * W), \
                    int(cur_bbox[1] * H), int(cur_bbox[2] * W), int(cur_bbox[3] * H)

                if flag:
                    # The token's own box: its attention should peak inside.
                    region = att_map_obj[y_min: y_max, x_min: x_max]
                    if region.numel() == 0:
                        # Box covers no attention cell at this resolution: treat
                        # it as satisfied, so it adds no inside loss.
                        max_inside = 1.
                    else:
                        # The map is smoothed, so its max is a robust strength measure.
                        max_inside = region.max()
                    if max_inside < 0.1:
                        total_loss += 6 * (1. - max_inside)  # barely expressed: weight heavily
                    elif max_inside < 0.2:
                        total_loss += 1. - max_inside  # still weak
                    elif t[obj_idx] > 600:
                        total_loss += 1. - max_inside  # early (noisy) timesteps: keep pushing

                    if max_inside < min_all_inside:
                        min_all_inside = max_inside

                    att_copy[y_min: y_max, x_min: x_max] = 0.
                    other_att_map_obj[y_min: y_max, x_min: x_max] = 0.
                else:
                    # Another object's box: this token should not attend there.
                    region = other_att_map_obj[y_min: y_max, x_min: x_max]
                    if region.numel() == 0:
                        max_outside_one = 0
                    else:
                        max_outside_one = region.max()

                    att_copy[y_min: y_max, x_min: x_max] = 0.

                    if max_outside_one > 0.15:
                        total_loss += 4 * max_outside_one  # strong leakage: weight heavily
                    elif max_outside_one > 0.1:
                        total_loss += max_outside_one  # some leakage
                    elif t[obj_idx] > 600:
                        total_loss += max_outside_one  # early timesteps: still penalise

                    if max_outside_one > max_outside:
                        max_outside = max_outside_one

            # Attention remaining outside every box counts as background
            # leakage, at a reduced weight.
            max_background = att_copy.max()
            total_loss += len(bboxes[obj_idx]) * max_background / 2.

    return total_loss / obj_number, min_all_inside, max_outside

def spherical_dist_loss(x, y):
    """Half the squared angle between ``x`` and ``y`` along the last dimension.

    Both inputs are L2-normalised. The chord length ``c = |x - y|`` gives the
    angle ``theta = 2 * arcsin(c / 2)``, and the result is
    ``2 * arcsin(c / 2) ** 2 == theta ** 2 / 2``. Only the last dimension is reduced.
    """
    x_unit = F.normalize(x, dim=-1)
    y_unit = F.normalize(y, dim=-1)
    half_chord = (x_unit - y_unit).norm(dim=-1) / 2
    return 2 * torch.arcsin(half_chord) ** 2

def clip_guided(image, prompt, device, feature_extractor, clip_model, tokenizer, text_encoder,clip_guidance_scale = 1):
    """CLIP guidance loss: spherical distance between an image's and a prompt's CLIP embeddings.

    Args:
        image: a single image tensor. The callers in this file pass decoded
            samples rescaled to [0, 1]; a batch dimension is added here.
        prompt: text (or list of texts) to compare against.
        device: device the token ids are moved to.
        feature_extractor: CLIP image processor. Only ``image_mean``,
            ``image_std`` and ``size`` are used, to resize and normalise.
        clip_model: a ``CLIPModel`` providing image and text features.
        tokenizer: tokenizer producing ``input_ids`` for ``clip_model``.
        text_encoder: accepted for API compatibility but unused. Only CLIP
            text features enter the loss, so no extra encoder pass is needed.
        clip_guidance_scale: multiplier applied to the mean distance.

    Returns:
        Scalar tensor ``mean(spherical_dist_loss(image_emb, text_emb)) * clip_guidance_scale``.
    """
    normalize = transforms.Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)
    # Depending on the processor version, ``size`` is an int or a dict with "shortest_edge".
    size = feature_extractor.size
    edge = size if isinstance(size, int) else size["shortest_edge"]

    # Resize straight to an edge x edge square (no cropping), then normalise and
    # match the CLIP model's parameter dtype (fp16 as loaded by Trainer).
    pixels = transforms.Resize([edge, edge])(image.unsqueeze(0))
    pixels = normalize(pixels).to(getattr(clip_model, "dtype", torch.float16))

    image_embeddings_clip = clip_model.get_image_features(pixels)
    image_embeddings_clip = image_embeddings_clip / image_embeddings_clip.norm(p=2, dim=-1, keepdim=True)

    text_input = tokenizer(
        prompt,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    text_embeddings_clip = clip_model.get_text_features(text_input.input_ids.to(device))
    text_embeddings_clip = text_embeddings_clip / text_embeddings_clip.norm(p=2, dim=-1, keepdim=True)

    return spherical_dist_loss(image_embeddings_clip, text_embeddings_clip).mean() * clip_guidance_scale

def cal_layout_semantic_loss(idx, sample, batch, device, feature_extractor, clip_model, tokenizer, text_encoder,clip_guidance_scale = 1):
    """Average CLIP guidance loss over the box crops of one decoded sample.

    Every box with ``batch['masks'][idx][k] == 1`` is cropped from ``sample``
    and scored against its name with ``clip_guided``. Boxes are skipped when
    they are degenerate (zero width or height), too elongated (aspect ratio
    > 2.5) or too small (< 1% of the image area).

    Args:
        idx: index of the sample in ``batch``. Only 0 is supported, because
            ``batch['boxes_name']`` is not indexed per sample.
        sample: decoded image tensor indexed as (C, H, W).
        batch: dict with 'masks', 'boxes' (normalised ``x0, y0, x1, y1``) and
            'boxes_name'.
        device, feature_extractor, clip_model, tokenizer, text_encoder,
            clip_guidance_scale: forwarded to ``clip_guided``.

    Returns:
        Mean loss over the scored boxes, or 0 when no box was scored.
    """
    assert idx == 0
    H = sample.size(1)
    W = sample.size(2)
    cnt_obj = 0
    loss_layout = 0

    for cur_item_idx, mask in enumerate(batch['masks'][idx]):
        if mask != 1:
            continue
        norm_bbox = batch['boxes'][idx][cur_item_idx]
        # x coordinates scale with the width, y coordinates with the height.
        x0 = int(norm_bbox[0] * W)
        y0 = int(norm_bbox[1] * H)
        x1 = int(norm_bbox[2] * W)
        y1 = int(norm_bbox[3] * H)

        box_h = y1 - y0
        box_w = x1 - x0
        max_edge = max(box_h, box_w)
        min_edge = min(box_h, box_w)
        if min_edge <= 0:
            continue  # empty/degenerate box: nothing to crop (and avoids division by zero)
        area_rate = (max_edge * min_edge) / (H * W)
        if max_edge / min_edge > 2.5 or area_rate < 0.01:  # filter unsuitable layouts
            continue

        cur_obj = sample[:, y0:y1, x0:x1]
        cur_name = batch['boxes_name'][cur_item_idx]
        cnt_obj += 1
        loss_layout += clip_guided(cur_obj, cur_name, device, feature_extractor, clip_model, tokenizer, text_encoder, clip_guidance_scale=clip_guidance_scale)

    if cnt_obj == 0:
        return 0
    return loss_layout / cnt_obj


class Trainer:
    """Grounded SD fine-tuning with attention-box and CLIP semantic guidance losses.

    The constructor does the following:
    * builds the UNet, autoencoder, text encoder and diffusion modules;
    * loads the base checkpoint and selects the trainable parameters;
    * creates the optimiser, scheduler and data loader;
    * restores the latest checkpoint of the experiment folder when one exists.

    ``start_training`` then runs the optimisation loop.
    """

    # Default location of the converted base checkpoint. It can be overridden
    # with ``config.official_ckpt_path``.
    DEFAULT_OFFICIAL_CKPT = "/mnt/data1/yaqili/code/GLIGEN/pretrained_model/diffusion_pytorch_model.bin"

    def __init__(self, config):
        """Build every training component from ``config``.

        ``config`` must provide the attributes read below: module configs,
        inpainting / EMA / scheduler settings, dataset paths, batch size,
        distributed settings and so on.
        """
        self.config = config
        self.device = torch.device("cuda")

        self.l_simple_weight = 1
        self.name, self.writer, checkpoint = create_expt_folder_with_auto_resuming(config.OUTPUT_ROOT, config.name)
        if get_rank() == 0:
            shutil.copyfile(config.yaml_file, os.path.join(self.name, "train_config_file.yaml"))
            self.config_dict = vars(config)
            torch.save(self.config_dict, os.path.join(self.name, "config_dict.pth"))

        # = = = = = = = = = = = = = = = = = create model and diffusion = = = = = = = = = = = = = = = = = #
        self.model = instantiate_from_config(config.model).to(self.device)
        self.autoencoder = instantiate_from_config(config.autoencoder).to(self.device)
        self.text_encoder = instantiate_from_config(config.text_encoder).to(self.device)
        self.diffusion = instantiate_from_config(config.diffusion).to(self.device)

        ckpt_path = getattr(config, "official_ckpt_path", None) or self.DEFAULT_OFFICIAL_CKPT
        state_dict = read_official_ckpt(ckpt_path)

        # Widen the UNet input conv when extra input channels are used
        # (grounding downsampler output and/or inpainting inputs).
        additional_channels = self.model.additional_channel_from_downsampler
        if self.config.inpaint_mode:
            additional_channels += 5  # 5 = 4 (latent) + 1 (mask)
        add_additional_channels(state_dict["model"], additional_channels)
        self.input_conv_train = additional_channels > 0

        # Load the base SD weights (input conv possibly widened). New grounding
        # layers may be missing from the checkpoint, but no key may be unexpected.
        missing_keys, unexpected_keys = self.model.load_state_dict(state_dict["model"], strict=False)
        assert unexpected_keys == []
        original_params_names = set(state_dict["model"].keys())  # for the sanity check below

        self.autoencoder.load_state_dict(state_dict["autoencoder"])
        # Dropped before strict loading, presumably because the installed text
        # encoder does not store position_ids in its state dict — TODO confirm.
        if 'transformer.text_model.embeddings.position_ids' in state_dict["text_encoder"]:
            del state_dict["text_encoder"]['transformer.text_model.embeddings.position_ids']
        self.text_encoder.load_state_dict(state_dict["text_encoder"])
        self.diffusion.load_state_dict(state_dict["diffusion"])

        # The autoencoder and text encoder are frozen.
        self.autoencoder.eval()
        self.text_encoder.eval()
        disable_grads(self.autoencoder)
        disable_grads(self.text_encoder)

        # Frozen CLIP components, used only by the semantic guidance loss.
        self.clip_guide_feature_extractor = CLIPImageProcessor.from_pretrained("/mnt/data1/yaqili/model_weights/laion/CLIP-ViT-B-32-laion2B-s34B-b79K")
        self.clip_guide_clip_model = CLIPModel.from_pretrained("/mnt/data1/yaqili/model_weights/laion/CLIP-ViT-B-32-laion2B-s34B-b79K", torch_dtype=torch.float16)
        self.clip_guide_clip_model.to(self.device)

        self.clip_guide_tokenizer = CLIPTokenizer.from_pretrained("/mnt/data1/yaqili/model_weights/stable_diffusion_v15", subfolder="tokenizer")
        self.clip_guide_text_encoder = CLIPTextModel.from_pretrained("/mnt/data1/yaqili/model_weights/stable_diffusion_v15", subfolder="text_encoder")
        self.clip_guide_text_encoder.to(self.device)
        set_requires_grad(self.clip_guide_text_encoder, False)
        set_requires_grad(self.clip_guide_clip_model, False)

        # = = = = = = = = = = = = = load from ckpt: (usually for inpainting training) = = = = = = = = = = = = = #
        if self.config.ckpt is not None:
            first_stage_ckpt = torch.load(self.config.ckpt, map_location="cpu")
            self.model.load_state_dict(first_stage_ckpt["model"])

        # = = = = = = = = = = = = = = = = = create opt = = = = = = = = = = = = = = = = = #
        params = []
        for name, p in self.model.named_parameters():
            in_transformer_block = "transformer_blocks" in name
            is_trainable = (
                # Gated fuser plus self/cross-attention inside transformer blocks.
                (in_transformer_block and ("fuser" in name or "attn1" in name or "attn2" in name))
                # Grounding token processing network.
                or "position_net" in name
                # Grounding downsample network (used at the input).
                or "downsample_net" in name
                # First conv, trained only when it was widened above.
                or (self.input_conv_train and "input_blocks.0.0.weight" in name)
            )
            if is_trainable:
                params.append(p)
            else:
                # Every frozen parameter must come from the base checkpoint. A
                # new layer ending up here means the rules above miss it.
                assert name in original_params_names, name

        self.opt = torch.optim.AdamW(params, lr=config.base_learning_rate, weight_decay=config.weight_decay)
        count_params(params)

        #  = = = = = EMA... It is worse than normal model in early experiments, thus never enabled later = = = = = = = = = #
        if config.enable_ema:
            self.master_params = list(self.model.parameters())
            self.ema = deepcopy(self.model)
            self.ema_params = list(self.ema.parameters())
            self.ema.eval()

        # = = = = = = = = = = = = = = = = = = = = create scheduler = = = = = = = = = = = = = = = = = = = = #
        if config.scheduler_type == "cosine":
            self.scheduler = get_cosine_schedule_with_warmup(self.opt, num_warmup_steps=config.warmup_steps, num_training_steps=config.total_iters)
        elif config.scheduler_type == "constant":
            self.scheduler = get_constant_schedule_with_warmup(self.opt, num_warmup_steps=config.warmup_steps)
        else:
            raise ValueError(f"unknown scheduler_type: {config.scheduler_type!r}")

        # = = = = = = = = = = = = = = = = = = = = create data = = = = = = = = = = = = = = = = = = = = #
        pt_path = config.pt_path
        dataset_path = config.dataset
        dataset_train = Test10(dataset_path, pt_path)

        train_dataloader = torch.utils.data.DataLoader(
            dataset_train,
            shuffle=True,
            batch_size=config.batch_size,
            num_workers=0,
            drop_last=True,
        )

        self.dataset_train = dataset_train
        self.loader_train = wrap_loader(train_dataloader)

        if get_rank() == 0:
            total_image = dataset_train.total_images()
            print("Total training images: ", total_image)

        # = = = = = = = = = = = = = = = = = = = = load from autoresuming ckpt = = = = = = = = = = = = = = = = = = = = #
        self.starting_iter = 0
        if checkpoint is not None:
            checkpoint = torch.load(checkpoint, map_location="cpu")
            self.model.load_state_dict(checkpoint["model"])
            if config.enable_ema:
                self.ema.load_state_dict(checkpoint["ema"])
            self.opt.load_state_dict(checkpoint["opt"])
            self.scheduler.load_state_dict(checkpoint["scheduler"])
            self.starting_iter = checkpoint["iters"]
            if self.starting_iter >= config.total_iters:
                synchronize()
                print("Training finished. Start exiting")
                exit()

        # = = = = = = = = = = = = = = = = = = = = misc and ddp = = = = = = = = = = = = = = = = = = = =#
        # Callable that returns the input for the grounding tokenizer.
        self.grounding_tokenizer_input = instantiate_from_config(config.grounding_tokenizer_input)
        self.model.grounding_tokenizer_input = self.grounding_tokenizer_input

        # Callable that returns the input for the grounding downsampler (optional).
        self.grounding_downsampler_input = None
        if 'grounding_downsampler_input' in config:
            self.grounding_downsampler_input = instantiate_from_config(config.grounding_downsampler_input)

        if get_rank() == 0:
            self.image_caption_saver = ImageCaptionSaver(self.name)

        if config.distributed:
            self.model = DDP(self.model, device_ids=[config.local_rank], output_device=config.local_rank, broadcast_buffers=False)

    @torch.no_grad()
    def get_input(self, batch):
        """Encode ``batch`` into diffusion-model inputs, without gradients.

        Returns:
            (z, t, context, inpainting_extra_input, grounding_extra_input), where:
            * ``z`` is the latent;
            * ``t`` is a per-sample timestep, floor(uniform * 1000), capped at 999;
            * ``context`` is the caption embedding;
            * ``inpainting_extra_input`` is the masked latent plus mask, or None;
            * ``grounding_extra_input`` is the grounding-downsampler input, or None.
        """
        z = self.autoencoder.encode(batch["image"])

        context = self.text_encoder.encode(batch["caption"])

        _t = torch.rand(z.shape[0]).to(z.device)
        t = (_t * 1000).long()
        t = torch.where(t != 1000, t, 999)  # if 1000, then replace it with 999

        inpainting_extra_input = None
        if self.config.inpaint_mode:
            # Extra input for the inpainting model.
            inpainting_mask = draw_masks_from_boxes(batch['boxes'], 64, randomize_fg_mask=self.config.randomize_fg_mask, random_add_bg_mask=self.config.random_add_bg_mask).cuda()
            masked_z = z * inpainting_mask
            inpainting_extra_input = torch.cat([masked_z, inpainting_mask], dim=1)

        grounding_extra_input = None
        if self.grounding_downsampler_input != None:
            grounding_extra_input = self.grounding_downsampler_input.prepare(batch)

        return z, t, context, inpainting_extra_input, grounding_extra_input

    def run_one_step(self, batch):
        """Compute the total training loss for one batch and record it in ``loss_dict``.

        The total loss is the sum of three terms:
        * ``MSE(noise prediction, noise) * l_simple_weight``;
        * ``attention_loss * config.attention_weight``;
        * ``semantic loss * config.semantic_weight``.

        For the semantic loss, the predicted clean latent is decoded. The whole
        image is scored against its caption (scale 100), and each box crop is
        scored against its name. Only batch size 1 is supported
        (``assert idx == 0``).
        """
        x_start, _, context, inpainting_extra_input, grounding_extra_input = self.get_input(batch)
        grounding_input = self.grounding_tokenizer_input.prepare(batch)
        noise = torch.randn_like(x_start)

        model_wo_wrapper = self.model.module if self.config.distributed else self.model

        # The PLMS sampler provides the timestep table and the x0 prediction.
        plms_sampler = PLMSSampler(self.diffusion, model_wo_wrapper)
        plms_sampler.make_schedule2(ddim_num_steps=1000)
        plms_sampler.ddim_alphas_prev = torch.from_numpy(plms_sampler.ddim_alphas_prev)
        plms_sampler.ddim_timesteps = torch.from_numpy(plms_sampler.ddim_timesteps)

        index = torch.randint(0, 1000, (x_start.shape[0],))
        t = plms_sampler.ddim_timesteps[index].to(self.device).long()
        x_noisy = self.diffusion.q_sample(x_start=x_start, t=t, noise=noise)

        model_input = dict(x=x_noisy,
                           timesteps=t,
                           context=context,
                           inpainting_extra_input=inpainting_extra_input,
                           grounding_extra_input=grounding_extra_input,
                           grounding_input=grounding_input)
        model_output, att_first, att_second, att_third = self.model(model_input)

        pred_x0 = plms_sampler.p_sample_pred_clean(x_noisy, model_output, index)

        samples = self.autoencoder.decode(pred_x0)  # the autoencoder is never DDP-wrapped

        loss_semantic = 0
        for idx, sample in enumerate(samples):
            sample = torch.clamp(sample, min=-1, max=1) * 0.5 + 0.5  # [-1, 1] -> [0, 1]

            loss_semantic += cal_layout_semantic_loss(idx, sample, batch, self.device, self.clip_guide_feature_extractor, self.clip_guide_clip_model, self.clip_guide_tokenizer, self.clip_guide_text_encoder)

            prompt_str = batch["caption"][idx]
            loss_semantic += clip_guided(sample, prompt_str, self.device, self.clip_guide_feature_extractor, self.clip_guide_clip_model, self.clip_guide_tokenizer, self.clip_guide_text_encoder, clip_guidance_scale=100)

        loss_semantic = loss_semantic / len(samples)
        assert idx == 0

        # The model's attention outputs are passed as (mid, up, down) = (second, first, third).
        loss_attention, min_inside, max_outside = attention_loss(att_second, att_first, att_third, bboxes=batch['boxes'],
                                                                 object_positions=batch['position'], box_att=batch['box_att'], t=t)

        loss = torch.nn.functional.mse_loss(model_output, noise) * self.l_simple_weight
        loss = loss + loss_attention * self.config.attention_weight + loss_semantic * self.config.semantic_weight

        self.loss_dict = {"loss": loss.item()}

        return loss

    def start_training(self):
        """Optimise from ``starting_iter`` up to ``config.total_iters``, then exit the process."""
        iterator = tqdm(range(self.starting_iter, self.config.total_iters), desc='Training progress', disable=get_rank() != 0)
        self.model.train()

        for iter_idx in iterator:  # starts above 0 when resuming
            self.iter_idx = iter_idx

            self.opt.zero_grad()
            batch = next(self.loader_train)
            batch_to_device(batch, self.device)

            loss = self.run_one_step(batch)
            loss.backward()
            self.opt.step()
            self.scheduler.step()
            if self.config.enable_ema:
                update_ema(self.ema_params, self.master_params, self.config.ema_rate)

            if get_rank() == 0:
                if iter_idx % 10 == 0:
                    self.log_loss()
                # Save every save_every_iters iterations and at the last one,
                # but never at iteration 0.
                if iter_idx > 0 and (iter_idx % self.config.save_every_iters == 0 or iter_idx == self.config.total_iters - 1):
                    self.save_ckpt_and_result(iter_idx)
            synchronize()

        synchronize()
        print("Training finished. Start exiting")
        exit()

    def log_loss(self):
        """Write the latest ``loss_dict`` to TensorBoard at step ``iter_idx + 1``."""
        for k, v in self.loss_dict.items():
            self.writer.add_scalar(k, v, self.iter_idx + 1)

    @torch.no_grad()
    def save_ckpt_and_result(self, iter_idx):
        """Save a sample visualisation and, from ``config.start_save`` on, a checkpoint.

        Output files are named after ``iter_idx + 1``. This uses the argument
        rather than ``self.iter_idx``, so the method also works outside the
        training loop.

        Unless ``config.disable_inference_in_training`` is set, a fresh
        training batch is sampled with 50 PLMS steps (guidance scale 5) and
        written through ``image_caption_saver``. Checkpoints are written to
        ``checkpoint_<iter>.pth`` and ``checkpoint_latest.pth``.
        """
        model_wo_wrapper = self.model.module if self.config.distributed else self.model

        iter_name = iter_idx + 1

        if not self.config.disable_inference_in_training:
            # Do an inference on one training batch.
            batch_here = self.config.batch_size
            batch = sub_batch(next(self.loader_train), batch_here)
            batch_to_device(batch, self.device)

            if "boxes" in batch:
                # Real images with boxes drawn on them, for easier visual inspection.
                real_images_with_box_drawing = []
                for i in range(batch_here):
                    temp_data = {"image": batch["image"][i], "boxes": batch["boxes"][i]}
                    im = self.dataset_train.vis_getitem_data(out=temp_data, return_tensor=True, print_caption=False)
                    real_images_with_box_drawing.append(im)
                real_images_with_box_drawing = torch.stack(real_images_with_box_drawing)
            else:
                # Keypoint case.
                real_images_with_box_drawing = batch["image"] * 0.5 + 0.5

            uc = self.text_encoder.encode(batch_here * [""])
            context = self.text_encoder.encode(batch["caption"])

            plms_sampler = PLMSSampler(self.diffusion, model_wo_wrapper)
            shape = (batch_here, model_wo_wrapper.in_channels, model_wo_wrapper.image_size, model_wo_wrapper.image_size)

            # Extra input for inpainting.
            inpainting_extra_input = None
            if self.config.inpaint_mode:
                z = self.autoencoder.encode(batch["image"])
                inpainting_mask = draw_masks_from_boxes(batch['boxes'], 64, randomize_fg_mask=self.config.randomize_fg_mask, random_add_bg_mask=self.config.random_add_bg_mask).cuda()
                masked_z = z * inpainting_mask
                inpainting_extra_input = torch.cat([masked_z, inpainting_mask], dim=1)

            grounding_extra_input = None
            if self.grounding_downsampler_input != None:
                grounding_extra_input = self.grounding_downsampler_input.prepare(batch)

            grounding_input = self.grounding_tokenizer_input.prepare(batch)
            model_input = dict(x=None,
                               timesteps=None,
                               context=context,
                               inpainting_extra_input=inpainting_extra_input,
                               grounding_extra_input=grounding_extra_input,
                               grounding_input=grounding_input)
            samples = plms_sampler.sample(S=50, shape=shape, input=model_input, uc=uc, guidance_scale=5)

            samples = self.autoencoder.decode(samples).cpu()  # the autoencoder is never DDP-wrapped
            samples = torch.clamp(samples, min=-1, max=1)

            masked_real_image = batch["image"] * torch.nn.functional.interpolate(inpainting_mask, size=(512, 512)) if self.config.inpaint_mode else None
            self.image_caption_saver(samples, real_images_with_box_drawing, masked_real_image, batch["caption"], iter_name)

        if iter_idx < self.config.start_save:
            return
        ckpt = dict(model=model_wo_wrapper.state_dict(),
                    text_encoder=self.text_encoder.state_dict(),
                    autoencoder=self.autoencoder.state_dict(),
                    diffusion=self.diffusion.state_dict(),
                    opt=self.opt.state_dict(),
                    scheduler=self.scheduler.state_dict(),
                    iters=iter_name,
                    config_dict=self.config_dict,
                    )
        if self.config.enable_ema:
            ckpt["ema"] = self.ema.state_dict()
        torch.save(ckpt, os.path.join(self.name, "checkpoint_" + str(iter_name).zfill(8) + ".pth"))
        torch.save(ckpt, os.path.join(self.name, "checkpoint_latest.pth"))


