import torch
import argparse
from PIL import Image, ImageDraw
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
from ldm.util import instantiate_from_config
import numpy as np
from transformers import CLIPProcessor, CLIPModel
import os 
from torch.nn.parallel import DistributedDataParallel as DDP
from tqdm import tqdm
import json
try:
    from apex import amp
except:
    pass  
from PIL import Image
# Device that load_ckpt() moves the networks to and that prepare_batch()
# moves the grounding batch to.
device = "cuda"
# = = = = = = = = = = = = = = = = = = useful functions = = = = = = = = = = = = = = = = = #
def batch_to_device(batch, device):
    """Move every tensor value of ``batch`` to ``device``, in place.

    Values that are not ``torch.Tensor`` instances are left untouched.
    The same mapping object that was passed in is returned.
    """
    tensor_keys = [key for key in batch if isinstance(batch[key], torch.Tensor)]
    for key in tensor_keys:
        batch[key] = batch[key].to(device)
    return batch


def wrap_loader(loader):
    """Yield the items of ``loader`` forever, restarting it each time it ends."""
    while True:
        # TODO: it seems each pass yields the same order for every epoch?
        yield from loader

def disable_grads(model):
    """Turn off gradient tracking on every parameter of ``model``."""
    for param in model.parameters():
        param.requires_grad_(False)


def count_params(params):
    """Print the total number of elements across the tensors in ``params``.

    Every item of ``params`` is counted via ``numel()``; no filtering on
    ``requires_grad`` happens here, so pass only the parameters of interest.
    """
    total = sum(p.numel() for p in params)
    print("total_trainable_params_count is: ", total)
           
def complete_mask(has_mask, max_objs):
    """Expand a mask specification into a ``(1, max_objs)`` float tensor.

    Args:
        has_mask: ``None`` (every slot is 1), a scalar ``bool``/``int``/``float``
            (every slot gets that value), or an iterable of per-object values
            written into the leading slots; remaining slots stay 1.
        max_objs: number of object slots in the returned mask.

    Returns:
        A tensor of shape ``(1, max_objs)``.
    """
    mask = torch.ones(1, max_objs)
    if has_mask is None:
        return mask

    # isinstance() also accepts bool, which the exact-type check sent to the
    # iterable branch where it failed with "'bool' object is not iterable".
    if isinstance(has_mask, (bool, int, float)):
        return mask * has_mask

    for idx, value in enumerate(has_mask):
        mask[0, idx] = value
    return mask
def prepare_clip(version):
    """Load the CLIP model and its processor from ``version``.

    Args:
        version: identifier or path passed unchanged to ``from_pretrained``.

    Returns:
        ``(model, processor)``; the model is moved to the module-level
        ``device`` so it sits alongside the networks loaded by ``load_ckpt``.
    """
    # Use the shared ``device`` setting instead of a hard-coded .cuda() so the
    # target device is configured in exactly one place.
    model = CLIPModel.from_pretrained(version).to(device)
    processor = CLIPProcessor.from_pretrained(version)
    return model, processor

@torch.no_grad()
def prepare_batch(meta, model, processor, batch=1, max_objs=30):
    """Build the grounding batch (boxes, masks, CLIP embeddings) for one sample.

    Args:
        meta: dict with "locations" (one 4-number box per object) and at least
            one of "phrases" / "images" (per-object entries, ``None`` allowed).
            Optional "text_mask" / "image_mask" entries are forwarded to
            ``complete_mask``.
        model, processor: forwarded to ``get_clip_feature``.
        batch: how many times the single sample is repeated along dim 0.
        max_objs: number of object slots; unused slots stay zero. The slot
            tensors are indexed directly, so more than ``max_objs`` objects
            is not supported.

    Returns:
        Dict of tensors moved to the module-level ``device``:
        "boxes" ``(batch, max_objs, 4)``; "masks", "text_masks" and
        "image_masks" ``(batch, max_objs)``; "text_embeddings" and
        "image_embeddings" ``(batch, max_objs, 768)``.

    Raises:
        ValueError: if ``meta`` has neither "phrases" nor "images".
    """
    phrases, images = meta.get("phrases"), meta.get("images")
    if phrases is None and images is None:
        # Fail with an explicit message instead of an opaque len(None) error.
        raise ValueError('meta must provide "phrases" and/or "images"')
    if images is None:
        images = [None] * len(phrases)
    if phrases is None:
        phrases = [None] * len(images)

    boxes = torch.zeros(max_objs, 4)
    masks = torch.zeros(max_objs)
    text_masks = torch.zeros(max_objs)
    image_masks = torch.zeros(max_objs)
    text_embeddings = torch.zeros(max_objs, 768)
    image_embeddings = torch.zeros(max_objs, 768)

    # zip() stops at the shortest of locations / phrases / images, so only
    # objects that have an entry in all three lists are filled in.
    for idx, (box, phrase, image) in enumerate(zip(meta['locations'], phrases, images)):
        text_feature = get_clip_feature(model, processor, phrase, is_image=False)
        image_feature = get_clip_feature(model, processor, image, is_image=True)

        boxes[idx] = torch.tensor(box)
        masks[idx] = 1
        if text_feature is not None:
            text_embeddings[idx] = text_feature
            text_masks[idx] = 1
        if image_feature is not None:
            image_embeddings[idx] = image_feature
            image_masks[idx] = 1

    # The (1, max_objs) masks from complete_mask() broadcast over the batch dim.
    out = {
        "boxes": boxes.unsqueeze(0).repeat(batch, 1, 1),
        "masks": masks.unsqueeze(0).repeat(batch, 1),
        "text_masks": text_masks.unsqueeze(0).repeat(batch, 1) * complete_mask(meta.get("text_mask"), max_objs),
        "image_masks": image_masks.unsqueeze(0).repeat(batch, 1) * complete_mask(meta.get("image_mask"), max_objs),
        "text_embeddings": text_embeddings.unsqueeze(0).repeat(batch, 1, 1),
        "image_embeddings": image_embeddings.unsqueeze(0).repeat(batch, 1, 1),
    }

    return batch_to_device(out, device)

def load_ckpt(ckpt_path):
    """Rebuild the networks described in a checkpoint and load their weights.

    Args:
        ckpt_path: path of a checkpoint holding "config_dict", "model",
            "autoencoder", "text_encoder" and "diffusion" entries.

    Returns:
        ``(model, autoencoder, text_encoder, diffusion, config)``. All four
        networks are on the module-level ``device``; the first three are in
        eval mode.
    """
    # Deserialize onto the CPU: load_state_dict() copies the values into the
    # already-placed modules, so this avoids restoring tensors onto whichever
    # GPU index they were saved from (which may not exist in this process).
    saved_ckpt = torch.load(ckpt_path, map_location="cpu")
    config = saved_ckpt["config_dict"]["_content"]

    model = instantiate_from_config(config['model']).to(device).eval()
    autoencoder = instantiate_from_config(config['autoencoder']).to(device).eval()
    text_encoder = instantiate_from_config(config['text_encoder']).to(device).eval()
    diffusion = instantiate_from_config(config['diffusion']).to(device)

    # No need to load the official checkpoint for the model here, since all
    # weights come from our own checkpoint.
    model.load_state_dict(saved_ckpt['model'])
    autoencoder.load_state_dict(saved_ckpt["autoencoder"])

    # Drop the position_ids entry before a strict load.
    # NOTE(review): presumably the installed text encoder no longer registers
    # this buffer -- confirm against the transformers version in use.
    saved_ckpt["text_encoder"].pop('transformer.text_model.embeddings.position_ids', None)
    text_encoder.load_state_dict(saved_ckpt["text_encoder"])
    diffusion.load_state_dict(saved_ckpt["diffusion"])

    return model, autoencoder, text_encoder, diffusion, config

def project(x, projection_matrix):
    """Apply a Linear-layer style projection: ``x @ projection_matrix^T``.

    x is expected to be the penultimate CLIP feature (Batch x 768, before
    projection). projection_matrix is the CLIP projection weight, i.e. the
    ``weight.data`` of a Linear layer stored as (out_dim, in_dim); that is
    why it is transposed before the matrix product.

    Returns the projected CLIP feature without normalization.
    """
    weight_t = projection_matrix.transpose(0, 1)
    return torch.matmul(x, weight_t)


# Reprojection matrices already loaded from disk, keyed by file path.
_PROJECTION_MATRIX_CACHE = {}


def _load_projection_matrix(path='projection_matrix'):
    """Return the transposed matrix stored at ``path``, loading it only once."""
    if path not in _PROJECTION_MATRIX_CACHE:
        _PROJECTION_MATRIX_CACHE[path] = torch.load(path).to(device).T
    return _PROJECTION_MATRIX_CACHE[path]


def get_clip_feature(model, processor, input, is_image=False):
    """Return the CLIP feature of one phrase or one image, or None.

    Args:
        model: called as ``model(**inputs)``; the result must expose
            ``image_embeds`` and ``text_model_output.pooler_output``.
        processor: called with ``text=`` or ``images=`` to build the inputs.
        input: a phrase (``is_image=False``) or something ``PIL.Image.open``
            accepts (``is_image=True``). ``None`` means "no grounding entry".
            The name shadows the builtin but is kept so callers stay valid.
        is_image: selects the image branch instead of the text branch.

    Returns:
        ``None`` when ``input`` is ``None``. Otherwise a tensor: for text,
        the text model's pooled output; for images, ``image_embeds`` passed
        through the matrix stored in the file ``projection_matrix`` (resolved
        against the current working directory) and rescaled to norm 28.7.
    """
    which_layer_text = 'before'
    which_layer_image = 'after_reproject'

    if input is None:
        return None

    if is_image:
        # Close the file handle deterministically once the pixels are loaded.
        with Image.open(input) as raw_image:
            image = raw_image.convert("RGB")
        inputs = processor(images=[image], return_tensors="pt", padding=True)
        # NOTE(review): the original comment here said "we use our own
        # preprocessing without center_crop"; that is not visible in this
        # function -- confirm how the processor is configured.
        inputs['pixel_values'] = inputs['pixel_values'].to(device)
        # Dummy token ids: the model is called with both modalities, so the
        # text side gets a placeholder.
        inputs['input_ids'] = torch.tensor([[0, 1, 2, 3]]).to(device)
        outputs = model(**inputs)
        feature = outputs.image_embeds
        if which_layer_image == 'after_reproject':
            # project() transposes its second argument again, so the net
            # effect is ``feature @ M`` with M the matrix stored on disk.
            feature = project(feature, _load_projection_matrix()).squeeze(0)
            # NOTE(review): 28.7 is the target feature norm; presumably
            # chosen to match the text feature magnitude -- confirm.
            feature = (feature / feature.norm()) * 28.7
            feature = feature.unsqueeze(0)
    else:
        inputs = processor(text=input, return_tensors="pt", padding=True)
        inputs['input_ids'] = inputs['input_ids'].to(device)
        # Dummy image: placeholder for the unused vision side.
        inputs['pixel_values'] = torch.ones(1, 3, 224, 224).to(device)
        inputs['attention_mask'] = inputs['attention_mask'].to(device)
        outputs = model(**inputs)
        if which_layer_text == 'before':
            feature = outputs.text_model_output.pooler_output
    return feature

def inference(args):
    """Generate one image per entry of ``args.input_json`` and index the results.

    Reads ``args.ckpt`` (checkpoint to load), ``args.input_json`` (list of
    per-image meta dicts, each with at least a "prompt") and
    ``args.output_json`` (where the ``{prompt: {"image_path": ...}}`` index is
    written). Entries sharing the same prompt overwrite each other in that
    index, since the prompt is the key.
    """
    model, autoencoder, text_encoder, diffusion, config = load_ckpt(args.ckpt)

    version = "/mnt/data1/yaqili/model_weights/clip-vit-large-patch14"
    clip_model, clip_processor = prepare_clip(version)

    grounding_tokenizer_input = instantiate_from_config(config['grounding_tokenizer_input'])
    model.grounding_tokenizer_input = grounding_tokenizer_input
    # Instantiated for parity with the config but not consumed by the code
    # visible in this file.
    grounding_downsampler_input = None
    if 'grounding_downsampler_input' in config:
        # Subscript access, like every other config lookup in this file;
        # attribute access would fail on a plain dict.
        grounding_downsampler_input = instantiate_from_config(config['grounding_downsampler_input'])

    # JSON is UTF-8 by specification; do not depend on the locale default.
    with open(args.input_json, 'r', encoding='utf-8') as read_f:
        my_meta_list = json.load(read_f)

    full_dict = {}
    for idx, meta in enumerate(tqdm(my_meta_list)):
        image_path = inference_one_image(meta, idx, model, text_encoder, diffusion, autoencoder, grounding_tokenizer_input, clip_model, clip_processor)
        full_dict[meta['prompt']] = {'image_path': image_path}

    # The index may live outside the image folders, so make sure its parent
    # directory exists ('' means the current directory).
    output_dir = os.path.dirname(args.output_json)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    # ensure_ascii=False emits raw non-ASCII prompts, hence the explicit UTF-8.
    with open(args.output_json, 'w', encoding='utf-8') as write_f:
        json.dump(full_dict, write_f, indent=4, ensure_ascii=False)



@torch.no_grad()
def inference_one_image(meta, image_idx, model, text_encoder, diffusion, autoencoder, grounding_tokenizer_input, clip_model, clip_processor, output_root=None):
    """Sample one grounded image for ``meta`` and save it with and without boxes.

    Args:
        meta: dict with "prompt" and "locations" (boxes in pixel coordinates
            of a 512-wide canvas, as implied by the division by 512 below),
            plus whatever ``prepare_batch`` reads. It is not modified.
        image_idx: index used, zero-padded to 6 digits, in the file name.
        model, text_encoder, diffusion, autoencoder: networks from ``load_ckpt``.
        grounding_tokenizer_input: object whose ``prepare(batch)`` builds the
            grounding input for the sampler.
        clip_model, clip_processor: forwarded to ``prepare_batch``.
        output_root: folder receiving the ``raw_image`` and ``vis_layout``
            sub-folders. Defaults to the script-level ``args.folder``.

    Returns:
        Path of the raw image under ``raw_image``. If that file already
        exists, nothing is generated and the path is returned immediately.

    NOTE(review): the prompt is used verbatim in the file name, so prompts
    containing path separators or exceeding file-name limits are not handled.
    """
    if output_root is None:
        # Backward-compatible default: fall back to the global parsed in the
        # ``__main__`` block, as the original code did unconditionally.
        output_root = args.folder

    pixel_scale = 512  # NOTE(review): assumes 512x512 outputs -- confirm
    output_folder_no_layout = os.path.join(output_root, 'raw_image')
    output_folder = os.path.join(output_root, 'vis_layout')
    img_name = meta["prompt"] + '_' + (str(image_idx).zfill(6)) + '.png'
    image_path = os.path.join(output_folder_no_layout, img_name)
    if os.path.exists(image_path):
        return image_path

    # Normalize on a shallow copy: rescaling the caller's dict in place would
    # double-normalize on a second call, and happened only on cache misses.
    norm_meta = dict(meta)
    norm_meta['locations'] = [[num / pixel_scale for num in bbox] for bbox in meta['locations']]
    # Kept for parity with the original; no code visible here reads it.
    norm_meta['alpha_type'] = [0.3, 0.0, 0.7]

    batch_size = 1
    # prepare_batch() already returns the batch on ``device``.
    batch = prepare_batch(norm_meta, clip_model, clip_processor)

    uc = text_encoder.encode(batch_size * [""])
    context = text_encoder.encode(meta["prompt"])

    plms_sampler = PLMSSampler(diffusion, model)
    shape = (batch_size, model.in_channels, model.image_size, model.image_size)

    grounding_input = grounding_tokenizer_input.prepare(batch)
    # No inpainting / extra grounding inputs are used in this script.
    sampler_input = dict(x=None,
                         timesteps=None,
                         context=context,
                         inpainting_extra_input=None,
                         grounding_extra_input=None,
                         grounding_input=grounding_input)
    samples = plms_sampler.sample(S=50, shape=shape, input=sampler_input, uc=uc, guidance_scale=5)
    samples = autoencoder.decode(samples).cpu()

    os.makedirs(output_folder, exist_ok=True)
    os.makedirs(output_folder_no_layout, exist_ok=True)

    # batch_size is 1; with more samples each one would overwrite the same file.
    for sample in samples:
        # Map [-1, 1] to [0, 255] and channels-first to channels-last.
        sample = torch.clamp(sample, min=-1, max=1) * 0.5 + 0.5
        sample = sample.cpu().numpy().transpose(1, 2, 0) * 255
        sample = Image.fromarray(sample.astype(np.uint8))
        sample.save(image_path)

        # Draw the layout on the same image and save it as the visualization.
        draw = ImageDraw.Draw(sample)
        for box in norm_meta['locations']:
            draw.rectangle(
                [(int(box[0] * pixel_scale), int(box[1] * pixel_scale)),
                 (int(box[2] * pixel_scale), int(box[3] * pixel_scale))],
                outline=128, width=2)
        sample.save(os.path.join(output_folder, img_name))
    return image_path

if __name__ == "__main__":

    # Script entry point: parse the command line and run inference().
    # ``args`` stays a module-level global; inference_one_image() reads
    # ``args.folder`` from it.
    parser = argparse.ArgumentParser()
    parser.add_argument("--folder", type=str,  default="/mnt/data1/yaqili/code/GLIGEN/output/inference_output/hrs/test_rela_gate_5000", help="root folder for output")
    # The help text used to be a copy of --folder's; this flag is the checkpoint to load.
    parser.add_argument("--ckpt", type=str,  default="/mnt/data1/yaqili/code/GLIGEN/output/train_output/combine_10468/test_rela_gate/tag00/checkpoint_00005000.pth", help="path to the checkpoint (.pth) to load")

    # NOTE(review): --batch_size, --no_plms, --guidance_scale and
    # --negative_prompt are parsed but not read by the code visible in this
    # file (sampling uses fixed values) -- confirm before relying on them.
    parser.add_argument("--batch_size", type=int, default=1, help="")
    parser.add_argument("--no_plms", action='store_true', help="use DDIM instead. WARNING: I did not test the code yet")
    parser.add_argument("--guidance_scale", type=float,  default=7.5, help="")
    parser.add_argument("--negative_prompt", type=str,  default='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality', help="")
    parser.add_argument("--input_json", type=str,  default='/mnt/data1/yaqili/code/eval/HRS/code/HRS_benchmark-main/uni_hrs_prompts/real_full_prompts/color/501_gligen_input_color.json', help="")
    parser.add_argument("--output_json", type=str,  default='/mnt/data1/yaqili/code/GLIGEN/output/inference_output/hrs/test_rela_gate_5000/output_dict.json', help="")

    args = parser.parse_args()
    inference(args)
    
    # trainer = Trainer()
    # trainer.start_training()

# CUDA_VISIBLE_DEVICES=3 python trainer_org_quick_rela_gate_eval.py --base_learning_rate 5e-5 --name test_rela_gate --dataset=/mnt/data1/yaqili/data/color900_other500_coco1k_no_prefix/color900_other500_coco1k_no_prefix.json --pt_path /mnt/data1/yaqili/data/color900_other500_coco1k_no_prefix/color900_other500_coco1k_no_prefix.pt --yaml_file=/mnt/data1/yaqili/code/GLIGEN/configs/flickr_text.yaml  --DATA_ROOT=../../DATA   --batch_size=2 --save_every_iters 100 --total_iters 6000 --start_save 1000
