from diffusers import StableDiffusionPipeline, AutoencoderKL, DDIMScheduler
from ip_adapter import IPAdapterFull, IPAdapterPlus
from ip_adapter.ip_adapter_faceid import IPAdapterFaceID
import torch
from PIL import Image
from transformers import OwlViTProcessor, OwlViTForObjectDetection
import cv2
from segment_anything import sam_model_registry, SamPredictor
import numpy as np
from transformers import CLIPModel, CLIPProcessor
from facenet_pytorch import MTCNN, InceptionResnetV1
from torch.nn import functional as F
import os
import pickle

# Prepositions that presumably split a prompt into an action phrase and a place
# phrase (e.g. "a man running" | " in " | "the park"); each entry is padded with
# spaces so only whole words match.
# NOTE(review): no active code in this section reads this list — confirm before removing.
pos_prep = [" in ", " at ", " on ", " through ", " over ", " by ", " along ", " from ", " off "]
def compute_iou(boxes1, boxes2):
    """Return the pairwise intersection-over-union of two sets of boxes.

    Boxes are corner-format ``(x1, y1, x2, y2)``: columns 0/2 are read as
    left/right and columns 1/3 as top/bottom.

    Args:
        boxes1: Array-like of shape ``(N, 4)``.
        boxes2: Array-like of shape ``(M, 4)``.

    Returns:
        An ``(N, M)`` array whose ``[i, j]`` entry is the IoU of ``boxes1[i]``
        and ``boxes2[j]``. Pairs whose union has zero area get 0.
    """
    boxes1 = np.asarray(boxes1)
    boxes2 = np.asarray(boxes2)

    # Corners of the overlap rectangle for every (i, j) pair via broadcasting.
    left = np.maximum(boxes1[:, None, 0], boxes2[None, :, 0])
    top = np.maximum(boxes1[:, None, 1], boxes2[None, :, 1])
    right = np.minimum(boxes1[:, None, 2], boxes2[None, :, 2])
    bottom = np.minimum(boxes1[:, None, 3], boxes2[None, :, 3])

    # Non-overlapping pairs yield negative extents; clip them to zero area.
    width = np.clip(right - left, 0, None)
    height = np.clip(bottom - top, 0, None)
    inter = width * height

    area1 = (boxes1[:, 2] - boxes1[:, 0]) * (boxes1[:, 3] - boxes1[:, 1])
    area2 = (boxes2[:, 2] - boxes2[:, 0]) * (boxes2[:, 3] - boxes2[:, 1])
    # |A ∪ B| = |A| + |B| - |A ∩ B|. Without subtracting the overlap, identical
    # boxes would score 0.5 instead of 1.0.
    union = area1[:, None] + area2[None, :] - inter

    # Degenerate (zero-area) pairs would give 0/0; report them as IoU 0.
    with np.errstate(divide="ignore", invalid="ignore"):
        iou = np.where(union > 0, inter / union, 0.0)
    return iou

def _clip_text_similarity(img_list, prompt, clip_model, clip_processor, device, batch_size=4):
    """Return the CLIP cosine similarity of every image in ``img_list`` to ``prompt``.

    Images are embedded in chunks of ``batch_size``. The result has one entry
    per image and stays on ``device``.
    """
    img_embeddings = []
    for start in range(0, len(img_list), batch_size):
        pixel_values = clip_processor.image_processor(
            img_list[start:start + batch_size], return_tensors="pt"
        )["pixel_values"]
        img_embeddings.append(clip_model.get_image_features(pixel_values.to(device)))
    img_embeddings = F.normalize(torch.cat(img_embeddings), dim=-1)

    input_ids = clip_processor.tokenizer(
        prompt,
        padding="max_length",
        truncation=True,
        max_length=clip_processor.tokenizer.model_max_length,
        return_tensors="pt",
    ).input_ids
    text_embeddings = F.normalize(clip_model.get_text_features(input_ids.to(device)), dim=-1, p=2)
    return (img_embeddings @ text_embeddings.T).squeeze(dim=-1)


def _extract_best_face(mtcnn, img):
    """Return MTCNN's crop of the selected face in ``img``, or ``None`` if no face is found."""
    boxes, probs, points = mtcnn.detect(img, landmarks=True)
    if boxes is None:
        return None
    # select_boxes must see the same image the boxes were detected in.
    boxes, probs, points = mtcnn.select_boxes(boxes, probs, points, img, mtcnn.selection_method)
    if boxes is None:
        return None
    best = np.argmax(probs)
    return mtcnn.extract(img, boxes[best:best + 1, :], save_path=None)


@torch.no_grad()
def data_filtering(img_list,
                   prompt,
                   clip_model,
                   clip_processor,
                   mtcnn,
                   face_model,
                   src_img: "Image.Image",
                   clip_thres,
                   face_thres,
                   select_top_number=5,
                   device="cuda"):
    """Keep the generated images that match both the prompt and the source identity.

    An image survives when all of the following hold:

    * its CLIP image/text cosine similarity to ``prompt`` exceeds ``clip_thres``;
    * MTCNN finds a face in it;
    * that face's embedding, dotted with the source face embedding, exceeds
      ``face_thres``.

    Survivors are ranked by ``sqrt(clip_sim * face_sim)``, the geometric mean of
    both scores. At most ``select_top_number`` of them are returned, best first.

    Args:
        img_list: Candidate images, e.g. the output of ``generate_data``.
        prompt: Text prompt the images were generated from.
        clip_model: CLIP model providing ``get_image_features``/``get_text_features``.
        clip_processor: Matching CLIP processor (image processor + tokenizer).
        mtcnn: Face detector exposing ``detect``/``select_boxes``/``extract``.
        face_model: Face-embedding network applied to MTCNN crops.
        src_img: Reference image containing the identity to preserve.
        clip_thres: Minimum CLIP cosine similarity.
        face_thres: Minimum face-embedding similarity.
        select_top_number: Maximum number of images to return.
        device: Device the CLIP and face models run on.

    Returns:
        A list of at most ``select_top_number`` images taken from ``img_list``.

    Raises:
        ValueError: If no face is detected in ``src_img``.
    """
    if len(img_list) == 0:
        return []

    # 1) Prompt fidelity.
    cosine_sim = _clip_text_similarity(img_list, prompt, clip_model, clip_processor, device)
    clip_binary_mask = (cosine_sim > clip_thres).cpu()

    # 2) Identity fidelity: every target face against the source face.
    src_face_img = _extract_best_face(mtcnn, src_img)
    if src_face_img is None:
        raise ValueError("no face detected in src_img")
    face_crops = [src_face_img]
    has_face = []
    for img in img_list:
        crop = _extract_best_face(mtcnn, img)
        has_face.append(crop is not None)
        # A zero placeholder keeps embeddings aligned with img_list. Such
        # entries are excluded through has_face below.
        face_crops.append(torch.zeros_like(src_face_img) if crop is None else crop)
    has_face = torch.tensor(has_face)

    # Batch over ALL crops (source + targets). Counting batches from
    # len(img_list) alone drops the last crop whenever len(img_list) % 3 == 0.
    face_batch_size = 3
    face_embeddings = torch.cat([
        face_model(torch.stack(face_crops[start:start + face_batch_size]).to(device))
        for start in range(0, len(face_crops), face_batch_size)
    ])
    face_sim = (face_embeddings[1:] @ face_embeddings[:1].T).squeeze(dim=-1)
    face_binary_mask = (face_sim > face_thres).cpu()

    # 3) Rank the survivors by the geometric mean of both similarities.
    valid = (has_face & clip_binary_mask & face_binary_mask).numpy()
    similarity = torch.sqrt(cosine_sim * face_sim).cpu().numpy()
    kept = np.flatnonzero(valid)
    ranked = kept[np.argsort(similarity[kept])[::-1][:select_top_number]]
    return [img_list[i] for i in ranked]

def generate_data(ip_model, prompt, cond_img_dir, num_samples=16, num_samples_per_iter=4, seed=42):
    """Render ``num_samples`` images for ``prompt`` conditioned on a reference image.

    Images are produced in batches of at most ``num_samples_per_iter``. The
    seed starts at ``seed`` and increases by one per batch, so repeated calls
    with the same arguments reuse the same seeds.

    Args:
        ip_model: IP-Adapter wrapper exposing ``generate(...)``.
        prompt: Text prompt.
        cond_img_dir: Path to the conditioning (reference) image file.
        num_samples: Total number of images to generate.
        num_samples_per_iter: Maximum number of images per ``generate`` call.
        seed: Seed for the first batch.

    Returns:
        The list of generated images, concatenated across batches.
    """
    # Load the pixels and release the file handle right away.
    with Image.open(cond_img_dir) as img_file:
        cond_image = img_file.copy()
    neg_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality, non nudity, bad hands, unnatural hands, disfigured hands"

    img_list = []
    # Ceiling division: a remainder batch is generated instead of silently
    # dropped (floor division returned fewer than num_samples images).
    num_iters = -(-num_samples // num_samples_per_iter)
    for it in range(num_iters):
        batch = min(num_samples_per_iter, num_samples - it * num_samples_per_iter)
        images = ip_model.generate(
            pil_image=cond_image,
            num_samples=batch,
            prompt=prompt,
            negative_prompt=neg_prompt,
            scale=0.6, width=512, height=512,
            num_inference_steps=30, seed=seed + it
        )
        img_list += images
    return img_list

def _select_face_index(img, ref_img, face_boxes, mtcnn, face_model, device):
    """Return the index into ``face_boxes`` of the face most similar to the face in ``ref_img``.

    Falls back to 0 when there is at most one candidate or no face is found in ``ref_img``.
    """
    # NOTE(review): select_boxes appears to keep a single box per image, in which
    # case this always returns 0 — confirm against the facenet_pytorch version in use.
    if len(face_boxes) <= 1:
        return 0
    ref_boxes, _ = mtcnn.detect(ref_img, landmarks=False)
    if ref_boxes is None:
        return 0
    face_imgs = mtcnn.extract(img, face_boxes, save_path=None)
    # The reference crop must come from ref_img (it was previously cut out of img).
    ref_imgs = mtcnn.extract(ref_img, ref_boxes, save_path=None)
    # Presumably extract() yields (3, H, W) for a single crop and (N, 3, H, W)
    # when the detector keeps all faces; normalize both to a batch.
    if face_imgs.dim() == 3:
        face_imgs = face_imgs.unsqueeze(0)
    if ref_imgs.dim() == 3:
        ref_imgs = ref_imgs.unsqueeze(0)
    # Only the first reference face is used, so rows 1: align with face_boxes.
    embeddings = face_model(torch.cat([ref_imgs[:1], face_imgs]).to(device))
    face_sim = (embeddings[1:] @ embeddings[:1].T)[:, 0].cpu().numpy()
    return int(np.argmax(face_sim))


@torch.no_grad()
def gen_seg_mask(img,
                 ref_img,
                 owl_processor,
                 owl_model,
                 segmented_prompt,
                 sam_predictor,
                 mtcnn,
                 face_model,
                 device="cuda"):
    """Segment the identity's face region plus one instance of every other queried class.

    OWL-ViT detects the text queries in ``segmented_prompt``. Query id 0
    (presumably the person/face query — confirm with callers) is matched to the
    MTCNN face by IoU. For every other query id, the highest-scoring box is
    kept. Each selected box prompts SAM, and the resulting masks are OR-ed.

    Returns:
        ``None`` when OWL-ViT finds no query-0 box or MTCNN finds no face in
        ``img``. Otherwise a tuple ``(mask, boxes, labels)``:

        * ``mask``: int64 0/1 union of all SAM masks;
        * ``boxes``: the stacked SAM prompt boxes, face box first;
        * ``labels``: the matching query ids, starting with 0.
    """
    inputs = owl_processor(text=segmented_prompt, images=img, return_tensors="pt")
    for key in ("input_ids", "attention_mask", "pixel_values"):
        inputs[key] = inputs[key].to(device)
    outputs = owl_model(**inputs)
    # PIL size is (W, H); post-processing expects (H, W).
    target_sizes = torch.Tensor([img.size[::-1]])
    results = owl_processor.post_process_object_detection(outputs=outputs, threshold=0.2, target_sizes=target_sizes)
    boxes = results[0]["boxes"].cpu().numpy().astype(np.float32)
    det_scores = results[0]["scores"].cpu().numpy().astype(np.float32)
    labels = results[0]["labels"].cpu().numpy().astype(np.int64)
    if 0 not in labels:
        return None
    f_boxes = boxes[labels == 0, :]

    face_boxes, face_probs, face_points = mtcnn.detect(img, landmarks=True)
    if face_boxes is None:
        return None
    face_boxes, face_probs, face_points = mtcnn.select_boxes(
        face_boxes, face_probs, face_points, img, method=mtcnn.selection_method
    )
    if face_boxes is None:
        return None
    face_boxes = face_boxes.astype(np.float32)
    # For each MTCNN face, the query-0 OWL box that overlaps it most.
    best_owl_box = np.argmax(compute_iou(face_boxes, f_boxes), axis=-1)
    selected_face_id = _select_face_index(img, ref_img, face_boxes, mtcnn, face_model, device)
    face_box = f_boxes[best_owl_box[selected_face_id]]

    # SAM prompts: the face box first, then the top-scoring box of each other class.
    prompts = [(face_box, 0)]
    for cls_id in np.unique(labels):
        if cls_id != 0:
            cls_mask = labels == cls_id
            prompts.append((boxes[cls_mask][np.argmax(det_scores[cls_mask])], cls_id))

    sam_predictor.set_image(np.array(img))
    all_masks, return_boxes, return_labels = [], [], []
    for box, label in prompts:
        masks, _, _ = sam_predictor.predict(box=box, multimask_output=False)
        all_masks.append(masks)
        return_boxes.append(box)
        return_labels.append(label)

    # Pixel is foreground if any SAM mask covers it.
    mask = np.any(np.concatenate(all_masks, axis=0), axis=0)
    return mask.astype(np.int64), np.stack(return_boxes), return_labels

if __name__ == "__main__":
    # One prompt per line. Blank lines (e.g. a trailing newline) are skipped so
    # they can never be sampled as empty prompts.
    with open("prompt_list2.txt", "r") as f:
        prompt_list = [line for line in f.read().split("\n") if line.strip()]
    # with open("seg_list2.txt", "r") as f:
    #     seg_list = f.read()
    # seg_list = seg_list.split("\n")
    # for i in range(len(seg_list)):
    #     if "," in seg_list[i]:
    #         seg_list[i] = seg_list[i].split(",")
    #     else:
    #         seg_list[i] = [seg_list[i]]
    base_model_path = "SG161222/Realistic_Vision_V4.0_noVAE"
    vae_model_path = "stabilityai/sd-vae-ft-mse"
    image_encoder_path = "models/image_encoder/"
    ip_ckpt = "models/ip-adapter-plus-face_sd15.bin"
    device = "cuda"
    vae = AutoencoderKL.from_pretrained(vae_model_path).to(dtype=torch.float16)
    noise_scheduler = DDIMScheduler(
        num_train_timesteps=1000,
        beta_start=0.00085,
        beta_end=0.012,
        beta_schedule="scaled_linear",
        clip_sample=False,
        set_alpha_to_one=False,
        steps_offset=1,
    )
    pipe = StableDiffusionPipeline.from_pretrained(
        base_model_path,
        torch_dtype=torch.float16,
        scheduler=noise_scheduler,
        vae=vae,
        feature_extractor=None,
        safety_checker=None
    )
    ip_model = IPAdapterPlus(pipe, image_encoder_path, ip_ckpt, device, num_tokens=16)
    # OWL-ViT and SAM are only needed by the (currently disabled) segmentation pass.
    owl_processor = OwlViTProcessor.from_pretrained("google/owlvit-large-patch14")
    owl_model = OwlViTForObjectDetection.from_pretrained("google/owlvit-large-patch14")
    owl_model.to("cuda")
    sam_checkpoint = "pretrained/sam_checkpoint/sam_vit_h_4b8939.pth"
    model_type = "vit_h"
    sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
    sam.to(device="cuda")
    sam_predictor = SamPredictor(sam)
    clip_model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
    clip_model.to("cuda")
    clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
    mtcnn = MTCNN(image_size=160,
                  margin=0,
                  min_face_size=20,
                  thresholds=[0.6, 0.7, 0.7],
                  factor=0.709,
                  post_process=True,
                  device="cuda")
    face_model = InceptionResnetV1(pretrained="vggface2").eval().to("cuda")

    # Each run renders one shard of 2000 CelebA-HQ identities, capped at index 25000.
    save_idx = 9
    start_idx = 2000 * save_idx
    end_idx = min(2000 * (save_idx + 1), 25000)
    img_save_dir = f"rendered_imgs/{save_idx}"
    os.makedirs(os.path.join(img_save_dir, "img"), exist_ok=True)
    os.makedirs(os.path.join(img_save_dir, "mask"), exist_ok=True)
    img_idx = 0
    annotations = []

    def dump_annotations(last_idx):
        """Pickle the pending annotations to annotations_{last_idx}.pkl and clear them."""
        with open(os.path.join(img_save_dir, f"annotations_{last_idx}.pkl"), "wb") as f:
            pickle.dump(annotations, f)
        annotations.clear()

    for i in range(start_idx, end_idx):
        print(f"Generate for image {i}")
        src_path = f"data/CelebAMask-HQ/CelebA-HQ-img/{i}.jpg"
        with Image.open(src_path) as src_file:
            src_img = src_file.copy()
        selected_prompt_indices = np.random.permutation(len(prompt_list))[:5]
        selected_prompt_list = [prompt_list[k] for k in selected_prompt_indices]
        # selected_seg_list = [seg_list[k] for k in selected_prompt_indices]
        for prompt in selected_prompt_list:
            # texts = selected_seg_list[j]
            # `texts` only exists when the segmentation pass is enabled; printing
            # it unconditionally raised NameError on the very first prompt.
            print(f"Generate with prompt {prompt}")
            img_list = generate_data(ip_model, prompt, src_path,
                                     num_samples=4, num_samples_per_iter=4)
            try:
                img_list = data_filtering(img_list,
                                          prompt,
                                          clip_model,
                                          clip_processor,
                                          mtcnn,
                                          face_model,
                                          src_img,
                                          clip_thres=0.21,
                                          face_thres=0.5,
                                          select_top_number=1)
            except Exception as err:  # not bare: keep Ctrl-C / SystemExit working
                print(f"Filtering failed for image {i}: {err}")
                continue
            for img in img_list:
                img_path = os.path.join(img_save_dir, "img", f"{img_idx}.png")
                try:
                    # mask, boxes, labels = gen_seg_mask(img,
                    #                                 src_img,
                    #                                 owl_processor,
                    #                                 owl_model,
                    #                                 texts,
                    #                                 sam_predictor,
                    #                                 mtcnn,
                    #                                 face_model)
                    # mask = (255*mask).astype(np.uint8)
                    # mask = np.stack([mask, mask, mask], axis=-1)
                    # Image.fromarray(mask).save(os.path.join(img_save_dir, "mask", f"{img_idx}.png"))
                    # The annotation records this path, so the image must be written.
                    img.save(img_path)
                except Exception as err:
                    print(f"Saving {img_path} failed: {err}")
                    continue
                # With segmentation enabled, also record "mask_img", "seg_list",
                # "box" and "label" here.
                annotations.append({
                    "ref_img": src_path,
                    "rendered_img": img_path,
                    "prompt": prompt,
                })
                img_idx += 1
        if (i + 1) % 500 == 0:
            dump_annotations(i)
    # Flush any tail that did not end on a multiple of 500.
    if annotations:
        dump_annotations(end_idx - 1)