from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, UniPCMultistepScheduler
from diffusers.utils import load_image
import torch
import os
import numpy as np
from PIL import Image 
import json 


def image_grid(imgs, rows=2, cols=2):
    """Tile ``imgs`` row-major onto one RGB canvas of ``rows`` x ``cols`` cells.

    Every cell takes the size of ``imgs[0]``. Image ``k`` is pasted with its
    top-left corner at row ``k // cols``, column ``k % cols``.

    Returns:
        The new ``PIL.Image`` canvas.
    """
    tile_w, tile_h = imgs[0].size
    canvas = Image.new("RGB", size=(cols * tile_w, rows * tile_h))

    for idx, tile in enumerate(imgs):
        row, col = divmod(idx, cols)
        canvas.paste(tile, box=(col * tile_w, row * tile_h))
    return canvas


def dilate(image, r: int = 5):
    """Dilate a [0, 1]-scaled image with an ``r`` x ``r`` rectangular kernel.

    The image is scaled to uint8 (``* 255``) and dilated with OpenCV using a
    ``MORPH_RECT`` structuring element of size ``(r, r)``. The result is then
    scaled back to float32 in ``[0, 1]``.

    Args:
        image: array supporting ``* 255``. Presumably a float array in
            ``[0, 1]`` (the ``* 255`` / ``/ 255.0`` round trip implies this;
            confirm with callers). Out-of-range values are clipped.
        r: kernel side length; ``0`` disables dilation.

    Returns:
        ``image`` itself when ``r == 0``; otherwise a new float32 array.

    Raises:
        ValueError: if ``r`` is negative.
    """
    if r == 0:
        return image
    if r < 0:
        raise ValueError("dilation kernel size must be >= 0, got {}".format(r))

    # Local import: cv2 is not among the module-level imports. Importing it
    # lazily also keeps OpenCV optional for callers that only use r == 0.
    import cv2

    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (r, r))
    # Round, then clip, before the uint8 cast. A bare cast truncates
    # (50.9999 -> 50) and is not guaranteed to saturate out-of-range values.
    image_u8 = np.clip(np.rint(image * 255), 0, 255).astype(np.uint8)
    dst = cv2.dilate(image_u8, kernel=kernel)
    return (dst / 255.0).astype(np.float32)


def get_img_cap(file_name, image_root=None, suffix="png", caption_path=None):
    """Look up the control-image path and caption for ``file_name``.

    Args:
        file_name: value of the ``"filename"`` field of the wanted record in
            the caption JSONL file.
        image_root: directory the stored filename is joined onto; defaults to
            ``/mnt/data/users/condition_val/images``.
        suffix: when ``"png"``, a ``.jpg`` extension on the stored filename is
            swapped for ``.png``. Any other value leaves the filename as
            stored.
        caption_path: JSONL file with one JSON object per line, each with
            ``"filename"`` and ``"caption"`` keys; defaults to
            ``/mnt/data/users/caption_val.jsonl``.

    Returns:
        ``(image_path, prompt)``.

    Raises:
        KeyError: if no record has ``"filename" == file_name`` (or a record
            lacks the ``"filename"`` key).
    """
    if caption_path is None:
        caption_path = r"/mnt/data/users/caption_val.jsonl"
    if image_root is None:
        image_root = r"/mnt/data/users/condition_val/images"

    # Stream the file and keep only the matching record. The last match wins,
    # exactly as a dict keyed by filename would behave for duplicates.
    record = None
    with open(caption_path, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue  # tolerate blank / trailing lines in the JSONL file
            cur_data = json.loads(line)
            if cur_data["filename"] == file_name:
                record = cur_data
    if record is None:
        raise KeyError(file_name)

    ### image path
    image_path = record["filename"]
    if suffix == "png":
        # Swap only the extension. str.replace("jpg", ...) would also rewrite
        # "jpg" inside directory names or the file stem.
        stem, ext = os.path.splitext(image_path)
        if ext == ".jpg":
            image_path = stem + "." + suffix
    image_path = os.path.join(image_root, image_path)

    ### prompt
    prompt = record["caption"]

    return image_path, prompt


def controlnet_infer(
        base_model_path, controlnet_path, control_image_path, prompt,
        output_root, save_name, seed=555,
        control_guidance_start=0.0,
        control_guidance_end=1.0,
        do_ratio_condition=False,
        do_predict=False,
        deteriorate_ratio=0,
        epoch_percentage=None,
        num_images_per_prompt=4,
        num_inference_steps=50,
    ):
    """Run ControlNet inference on one control image and save a comparison strip.

    Loads the ControlNet and base Stable Diffusion weights in float16. It
    swaps in a UniPC scheduler, enables model CPU offload, disables the
    safety checker, and samples ``num_images_per_prompt`` images
    conditioned on the control image resized to 512x512.

    The saved file (``output_root/save_name``) is a single-row grid. The
    original control image comes first, followed by each generated image
    resized back to the control image's original size.

    Args:
        base_model_path: path or hub id of the Stable Diffusion base model.
        controlnet_path: path or hub id of the ControlNet weights.
        control_image_path: path/URL accepted by ``diffusers.utils.load_image``.
        prompt: text prompt.
        output_root: output directory, created (recursively) if missing.
        save_name: file name of the saved grid inside ``output_root``.
        seed: value passed to ``torch.manual_seed``.
        control_guidance_start, control_guidance_end: forwarded to the pipeline.
        do_ratio_condition, do_predict, deteriorate_ratio, epoch_percentage:
            forwarded to the pipeline. ``deteriorate_ratio`` is first wrapped in
            a float32 tensor of shape (-1, 1).
        num_images_per_prompt: number of samples to generate.
        num_inference_steps: denoising steps.
    """
    print("==> model_path: {}".format(controlnet_path))

    controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16)
    pipe = StableDiffusionControlNetPipeline.from_pretrained(
        base_model_path, controlnet=controlnet, torch_dtype=torch.float16
    )

    # Faster scheduler plus CPU offload to reduce GPU memory use.
    pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
    pipe.enable_model_cpu_offload()
    pipe.safety_checker = None

    orig_control_image = load_image(control_image_path)
    orig_size = orig_control_image.size  # PIL size: (width, height)
    control_image = orig_control_image.resize((512, 512), Image.BILINEAR)

    # NOTE(review): do_ratio_condition / do_predict / deteriorate_ratio /
    # epoch_percentage are not arguments of the stock diffusers pipeline.
    # Presumably a patched pipeline consumes them; confirm.
    deteriorate_ratio = torch.tensor(deteriorate_ratio, dtype=torch.float32).reshape(-1, 1)
    print("==> prompt: {}".format(prompt))

    # generate image
    print("==> seed: {}".format(seed))
    generator = torch.manual_seed(seed)

    # Settings not passed below use the pipeline defaults (diffusers documents
    # guidance_scale=7.5, eta=0.0, controlnet_conditioning_scale=1.0,
    # guess_mode=False).
    generated = pipe(
            prompt,
            num_inference_steps=num_inference_steps,
            generator=generator,
            image=control_image,
            num_images_per_prompt=num_images_per_prompt,
            control_guidance_start=control_guidance_start,
            control_guidance_end=control_guidance_end,
            do_ratio_condition=do_ratio_condition,
            do_predict=do_predict,
            deteriorate_ratio=deteriorate_ratio,
            epoch_percentage=epoch_percentage,
        ).images

    # makedirs(exist_ok=True) creates missing parents and avoids the
    # check-then-create race of exists() + mkdir().
    os.makedirs(output_root, exist_ok=True)

    panels = [orig_control_image]
    panels.extend(aimg.resize(orig_size, Image.BILINEAR) for aimg in generated)

    # Derive the column count from the panels actually produced. A fixed
    # count silently breaks as soon as num_images_per_prompt changes.
    mix_image = image_grid(panels, rows=1, cols=len(panels))
    save_path = os.path.join(output_root, save_name)
    mix_image.save(save_path)

    print("Congratulations!!!")


def test_predcond(test_filename, demo_id, dilate_ratio, image_root=None, seed=555):
    """Run the ControlNet-seg demo for one caption-file entry at one ratio step.

    Args:
        test_filename: ``"filename"`` key in the caption file (see
            ``get_img_cap``). ``None`` is mapped to ``""``.
        demo_id: output goes to ``/mnt/data/output_dir/demo_<demo_id>``.
        dilate_ratio: integer step. It is used verbatim in the saved file name,
            and ``0.1 * dilate_ratio`` is passed as ``deteriorate_ratio``.
        image_root: forwarded to ``get_img_cap``.
        seed: sampling seed.
    """
    ### controlnet-seg
    base_model_path = "/mnt/data/StableDiffusion/stable-diffusion-v1-5"
    output_root = "/mnt/data/output_dir/demo_{}".format(demo_id)
    os.makedirs(output_root, exist_ok=True)
    controlnet_path = r"/mnt/data/controlnet"

    if test_filename is None:
        # NOTE(review): "" only resolves if the caption file has a record with
        # an empty filename; otherwise get_img_cap raises KeyError.
        test_filename = ""

    control_image_path, prompt = get_img_cap(test_filename, image_root)

    # Use the real file stem. str.split(".")[0] returned "" for "./x.jpg",
    # truncated "a.b.jpg" to "a", and kept directory components that may not
    # exist under output_root.
    stem = os.path.splitext(os.path.basename(test_filename))[0]
    save_name = "{}_{}.png".format(stem, dilate_ratio)
    deteriorate_ratio = 0.1 * dilate_ratio

    # run inference
    controlnet_infer(
            base_model_path,
            controlnet_path,
            control_image_path,
            prompt,
            output_root=output_root,
            save_name=save_name,
            seed=seed,
            do_predict=True,
            do_ratio_condition=True,
            deteriorate_ratio=deteriorate_ratio,
            epoch_percentage=None,
        )


if __name__ == "__main__":
    # Sweep the ratio step for one demo image. test_predcond scales each
    # step by 0.1 before passing it to the pipeline as deteriorate_ratio.
    aid = 5
    afilename = "demo.jpg"
    image_root = r"/mnt/data"
    ratio_steps = [-20, -10, -8, -6, -4, -2, 0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20]

    for r in ratio_steps:
        # Run one inference per sweep value; without this call the loop did nothing.
        test_predcond(afilename, aid, r, image_root=image_root)

    print(":) Congratulations!")
