"""
ControlNet generation with different conditions
"""
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, UniPCMultistepScheduler
from diffusers.utils import load_image
import torch
import os
from PIL import Image 
import json
import numpy as np
from tqdm import tqdm 
import argparse

from test_config import * 
import cv2 

import ipdb


########################################################################
# Functions
########################################################################
def image_grid(imgs, rows=2, cols=2):
    """Tile images into a single RGB canvas, filling it row by row.

    Every cell has the size of the first image; image ``k`` is pasted at
    column ``k % cols`` and row ``k // cols``.

    Args:
        imgs: Non-empty sequence of PIL images.
        rows: Number of grid rows.
        cols: Number of grid columns.

    Returns:
        A new PIL image of size ``(cols * w, rows * h)``.
    """
    cell_w, cell_h = imgs[0].size
    canvas = Image.new("RGB", size=(cols * cell_w, rows * cell_h))

    for position, tile in enumerate(imgs):
        row, col = divmod(position, cols)
        canvas.paste(tile, box=(col * cell_w, row * cell_h))
    return canvas


def dilate(image, r: int = 5):
    """Morphologically dilate a mask with an ``r`` x ``r`` rectangular kernel.

    Args:
        image: Array-like mask. It is multiplied by 255 and cast to uint8
            before dilation (presumably values lie in [0, 1] -- confirm
            against callers).
        r: Side length of the square structuring element. ``0`` disables
            dilation and returns ``image`` untouched.

    Returns:
        ``image`` itself when ``r == 0``; otherwise the dilated mask scaled
        back by 1/255 as a float32 array.
    """
    if r == 0:
        return image

    # r x r rectangular structuring element
    structuring_element = cv2.getStructuringElement(cv2.MORPH_RECT, (r, r))
    mask_u8 = np.array(image * 255, dtype=np.uint8)
    dilated_u8 = cv2.dilate(mask_u8, kernel=structuring_element)
    return (dilated_u8 / 255.0).astype(np.float32)

############################################################################################################################

# Seed for NumPy's global RNG; condition_generation() applies it on entry and
# then draws the per-image torch generator seeds from that RNG.
SEED = 555

def condition_generation(
        condition_mode,
        output_root,
        dilate_radius=0.0,
        do_predict=True,
        do_ratio_condition=True,
        epoch_percentage=0.0,
        max_batches=30,
):
    """Generate ControlNet samples for captioned images and save comparison grids.

    For each caption entry the conditioning mask is binarized, optionally
    dilated, clipped to the bounding-box mask and fed to the pipeline. One
    single-row grid is saved per entry, laid out as
    ``[captioned image, control image, generated images...]``.

    Args:
        condition_mode: Key into ``CONTROLNET_CONFIG``. When it contains
            ``"woprompt"`` the prompts are replaced by empty strings.
        output_root: Existing directory that receives the grids. Each grid is
            saved under the entry's filename with ``".jpg"`` replaced by
            ``".png"``.
        dilate_radius: Kernel side length handed to ``dilate``; ``0`` keeps
            the mask as is.
        do_predict: Forwarded unchanged to the pipeline call.
        do_ratio_condition: Forwarded unchanged to the pipeline call.
        epoch_percentage: Forwarded unchanged to the pipeline call.
        max_batches: Upper bound on the number of batches processed
            (default 30). ``None`` processes every caption entry.

    NOTE(review): ``do_ratio_condition``, ``do_predict``, ``deteriorate_ratio``
    and ``epoch_percentage`` are passed as pipeline keyword arguments; this
    presumably requires a patched ``StableDiffusionControlNetPipeline`` --
    confirm against the installed diffusers.
    """
    ### set random seed
    np.random.seed(SEED)

    ### resolve paths from the test configuration
    config = CONTROLNET_CONFIG[condition_mode]
    base_model_path = BASEMODEL_PATH["sd_v15"]
    controlnet_path = config["model_path"]

    captioned_image_root = CAPTIONED_IMAGE_PATH
    caption_path = CAPTION_PATH
    conditioning_root = config["conditioning_path"]
    conditioning_suffix = config["conditioning_suffix"]

    ### load captions: one JSON object per line, blank lines are skipped
    with open(caption_path, 'r') as f:
        caption_list = [json.loads(aline) for aline in f if aline.strip()]
    num_data = len(caption_list)

    ### prepare models and pipeline
    data_type = torch.float16
    controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=data_type)
    pipe = StableDiffusionControlNetPipeline.from_pretrained(base_model_path, controlnet=controlnet, torch_dtype=data_type)

    # speed up diffusion process with faster scheduler and memory optimization
    pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
    pipe.enable_model_cpu_offload()
    pipe.safety_checker = None

    ### batching
    batch_size = 4  # 32
    num_images_per_prompt = 4
    # ceil division, so no empty trailing batch is produced
    num_rounds = (num_data + batch_size - 1) // batch_size
    if max_batches is not None:
        num_rounds = min(num_rounds, max_batches)

    for idx in range(num_rounds):
        cur_item = caption_list[idx * batch_size:(idx + 1) * batch_size]
        if not cur_item:
            continue

        cur_filename = [aitem["filename"] for aitem in cur_item]

        if "woprompt" in condition_mode:
            cur_caption = ["" for _ in cur_item]
        else:
            cur_caption = [aitem["caption"] for aitem in cur_item]

        cur_conditioning_image_path = [os.path.join(conditioning_root, afilename) for afilename in cur_filename]
        cur_captioned_image_path = [os.path.join(captioned_image_root, afilename) for afilename in cur_filename]

        # load images
        cur_captioned_image = [load_image(apath) for apath in cur_captioned_image_path]

        # dilate control image for test with different radius and get deteriorate_ratio
        cur_control_image = []
        cur_deteriorate_ratio = []
        for afilename, apath in zip(cur_filename, cur_conditioning_image_path):
            # binarize to {0, 1} float masks
            # NOTE(review): the replace acts on the whole path, so a "jpg"
            # inside a directory name would be rewritten as well.
            acontrol_image = (np.array(load_image(apath.replace("jpg", conditioning_suffix))) > 127).astype(np.float32)
            acontrol_bbox_image = (np.array(load_image(os.path.join(BBOX_IMAGE_PATH, afilename))) > 127).astype(np.float32)
            # dilated mask area, restricted to the bounding box
            adilate_control_image = dilate(acontrol_image, r=dilate_radius)
            adilate_control_image = acontrol_bbox_image * adilate_control_image
            # share of the bbox area outside the original mask that the
            # dilation filled in; the epsilon guards against mask == bbox
            area_mask = np.sum(acontrol_image)
            area_bbox = np.sum(acontrol_bbox_image)
            area_dilate_within_bbox = np.sum(adilate_control_image)
            deteriorate_ratio = max(area_dilate_within_bbox - area_mask, 0) / (area_bbox - area_mask + 1e-5)

            cur_control_image.append(Image.fromarray(np.uint8(255 * adilate_control_image)))
            cur_deteriorate_ratio.append(deteriorate_ratio)
        cur_deteriorate_ratio = torch.tensor(cur_deteriorate_ratio, dtype=torch.float32).reshape(-1, 1)

        # remember the original sizes so the outputs can be resized back
        cur_image_shape = [aimg.size for aimg in cur_control_image]
        cur_resize_control_image = [aimg.resize((512, 512), Image.BILINEAR) for aimg in cur_control_image]

        # image generation
        # One generator per generated image. The count follows the actual
        # batch length because the last batch may hold fewer than batch_size
        # prompts, and the generator list must match the effective batch size.
        num_generators = len(cur_item) * num_images_per_prompt
        generator = [torch.Generator("cuda").manual_seed(np.random.randint(0, 10000)) for _ in range(num_generators)]
        # Hyper-parameters, per the original author's notes (everything except
        # num_inference_steps and the scheduler is left at the pipeline default):
        #   num_inference_steps = 50, UniPCMultistepScheduler,
        #   guidance_scale = 7.5, eta = 0, controlnet_condition_scale = 1.0,
        #   guess_mode = False,
        #   control_guidance_start = 0.0, control_guidance_end = 1.0
        output_image_list = pipe(
                cur_caption,
                num_inference_steps=50,
                generator=generator,
                image=cur_resize_control_image,
                num_images_per_prompt=num_images_per_prompt,
                do_ratio_condition=do_ratio_condition,
                do_predict=do_predict,
                deteriorate_ratio=cur_deteriorate_ratio,
                epoch_percentage=epoch_percentage,
            ).images

        ### save batched images
        num_groups = len(output_image_list) // num_images_per_prompt
        for group_id in range(num_groups):
            target_size = cur_image_shape[group_id]
            group_start = group_id * num_images_per_prompt
            group_outputs = output_image_list[group_start:group_start + num_images_per_prompt]

            # resize everything back to the size of this entry's control image;
            # the captioned image is resized once, from the loaded original
            resized_outputs = [aimg.resize(target_size, Image.BILINEAR) for aimg in group_outputs]
            resized_captioned = cur_captioned_image[group_id].resize(target_size, Image.BILINEAR)

            # organize images: captioned | control | generated...
            tiles = [resized_captioned, cur_control_image[group_id]] + resized_outputs
            mix_image = image_grid(tiles, rows=1, cols=len(tiles))
            cur_save_path = os.path.join(output_root, cur_filename[group_id]).replace(".jpg", ".png")
            mix_image.save(cur_save_path)

        # clean cache
        torch.cuda.empty_cache()

    print(":) Congratulations!!!")


def arg_parser():
    """Build the command-line parser and parse the process arguments.

    Returns:
        The ``argparse.Namespace`` with the attributes ``mode``,
        ``save_size_mode``, ``baseline_on``, ``cfg_scale``,
        ``condition_scale``, ``control_guidance_start``,
        ``control_guidance_end``, ``dilate_radius``, ``do_predict``,
        ``do_ratio_condition`` and ``epoch_percentage``.
    """
    parser = argparse.ArgumentParser(description='Generate images for evaluation.')

    # (flag, add_argument keyword arguments), registered in this order
    option_table = (
        ('--mode', dict(type=str, default=None)),
        ('--save_size_mode', dict(type=str, default="raw")),
        ('--baseline_on', dict(action='store_true')),
        ('--cfg_scale', dict(type=float, default=7.5)),
        ('--condition_scale', dict(type=float, default=1.0)),
        ('--control_guidance_start', dict(type=float, default=0.0)),
        ('--control_guidance_end', dict(type=float, default=1.0)),
        ('--dilate_radius', dict(type=int, default=0)),
        ('--do_predict', dict(action='store_true')),
        ('--do_ratio_condition', dict(action='store_true')),
        ('--epoch_percentage', dict(type=float, default=0.0)),
    )
    for flag, options in option_table:
        parser.add_argument(flag, **options)

    return parser.parse_args()


def main():
    """Parse the CLI options and run ``condition_generation`` for the chosen mode.

    The output directory is ``/mnt/data/gen_with_<mode>`` for the ``debug``
    mode and ``/mnt/data/gen_with_<mode>_dilate<radius>`` otherwise; it is
    created if missing.

    Raises:
        SystemExit: If ``--mode`` was not given.
    """
    ### set test mode
    args = arg_parser()
    if args.mode is None:
        # A missing mode would otherwise only fail inside
        # condition_generation(), after the output directory was created.
        raise SystemExit("error: --mode is required")
    condition_mode_list = [args.mode]
    # NOTE: --save_size_mode, --baseline_on, --cfg_scale, --condition_scale,
    # --control_guidance_start and --control_guidance_end are parsed by
    # arg_parser() but are not forwarded to condition_generation() here.

    condition_mode_len = len(condition_mode_list)

    ### loop for test
    for index, condition_mode in enumerate(condition_mode_list):
        # set output path
        if condition_mode == "debug":
            output_root = r"/mnt/data/gen_with_{}".format(condition_mode)
        else:
            output_root = r"/mnt/data/gen_with_{}_dilate{}".format(condition_mode, int(args.dilate_radius))

        # creates missing parents too and tolerates an existing directory
        os.makedirs(output_root, exist_ok=True)

        # print info
        print("[{}|{}] Processing: {} ---------------------------------------------------\n ==> Saved in: {}"
              .format(index, condition_mode_len, condition_mode, output_root))

        print("==> dilate radius = {}".format(args.dilate_radius))

        # image generation
        condition_generation(
                condition_mode,
                output_root,
                dilate_radius=args.dilate_radius,
                do_predict=args.do_predict,
                do_ratio_condition=args.do_ratio_condition,
                epoch_percentage=args.epoch_percentage,
            )
        print("-*"*50 + '\n')


# Run the generation only when this file is executed as a script.
if __name__ == "__main__":
    main()


