from diffusers import StableDiffusionInpaintPipeline, AutoPipelineForInpainting, StableDiffusionImg2ImgPipeline
import torch
from PIL import Image
import numpy as np
import cv2
import random
import os
from glob import glob
from tqdm import tqdm
import albumentations as A
import gc
import json

def cv2_jpg(img, compress_val):
    """Round-trip an image through in-memory JPEG compression.

    The channel axis is reversed before encoding and after decoding, so the
    returned array has the same channel order as ``img`` (presumably an RGB
    H x W x 3 array, since OpenCV itself works in BGR -- confirm at call sites).

    Args:
        img: H x W x C image array.
        compress_val: JPEG quality passed as ``cv2.IMWRITE_JPEG_QUALITY``.

    Returns:
        The JPEG-compressed-then-decoded image, channel order as the input.

    Raises:
        ValueError: if OpenCV fails to encode the image.
    """
    img_bgr = img[:, :, ::-1]
    # OpenCV's parameter parser wants plain Python ints, not e.g. numpy ints.
    encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), int(compress_val)]
    ok, encimg = cv2.imencode('.jpg', img_bgr, encode_param)
    if not ok:
        # imencode reports failure via its flag; decoding the result would
        # otherwise fail later with a much less helpful error.
        raise ValueError("cv2.imencode failed to JPEG-encode the image")
    decimg = cv2.imdecode(encimg, cv2.IMREAD_COLOR)
    return decimg[:, :, ::-1]

def cv2_scale(img, scale):
    """Resize ``img`` by a uniform factor using OpenCV's default interpolation.

    Each output side is ``int(side * scale)`` (truncated), clamped to at least
    one pixel so that tiny images or small factors do not produce a zero-sized
    target, which ``cv2.resize`` rejects.

    Args:
        img: image array whose first two axes are height and width.
        scale: positive scale factor applied to both axes.

    Returns:
        The resized image.

    Raises:
        ValueError: if ``scale`` is not positive.
    """
    if scale <= 0:
        raise ValueError(f"scale must be positive, got {scale!r}")
    h, w = img.shape[:2]
    new_w = max(1, int(w * scale))
    new_h = max(1, int(h * scale))
    # cv2.resize takes dsize as (width, height).
    return cv2.resize(img, (new_w, new_h))

def create_crop_transforms(height=224, width=224):
    """Build an albumentations pipeline producing exactly ``height`` x ``width``.

    Images smaller than the target are first padded with a constant 0 border,
    then every image is center-cropped to the target size.

    NOTE(review): newer albumentations releases renamed ``PadIfNeeded``'s
    ``value`` argument to ``fill``; confirm against the pinned version.
    """
    pad_step = A.PadIfNeeded(
        min_height=height,
        min_width=width,
        border_mode=cv2.BORDER_CONSTANT,
        value=0,
    )
    crop_step = A.CenterCrop(height=height, width=width)
    return A.Compose([pad_step, crop_step])


def set_seed(seed: int):
    """Seed torch (CPU and all CUDA devices), NumPy and ``random`` with ``seed``.

    Also forces cuDNN into deterministic mode and disables its autotuner, so
    repeated runs with the same seed select the same kernels.
    """
    for seed_fn in (torch.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)
    cudnn_backend = torch.backends.cudnn
    cudnn_backend.deterministic = True
    cudnn_backend.benchmark = False
    np.random.seed(seed)
    random.seed(seed)


def find_nearest_multiple(a, multiple=8):
    """Round ``a`` *up* to the next multiple of ``multiple``.

    Despite the name, this is a ceiling, not a nearest-rounding: values that
    are already multiples are returned unchanged, anything else goes up.

    >>> find_nearest_multiple(17)
    24
    >>> find_nearest_multiple(16)
    16
    """
    quotient, leftover = divmod(a, multiple)
    return a if leftover == 0 else (quotient + 1) * multiple


def stable_diffusion_img2img(pipe, image, prompt, steps=50, seed=2023, strength=0.6, guidance_scale=7.5):
    """Run one seeded img2img pass and return the first generated image.

    Args:
        pipe: an img2img pipeline, called with ``prompt``, ``image``,
            ``strength``, ``guidance_scale`` and ``num_inference_steps``.
        image: image array, converted with ``PIL.Image.fromarray``.
        prompt: text prompt.
        steps: number of inference steps.
        seed: seed passed (as ``int``) to ``set_seed`` before the call, which
            seeds the global RNGs the pipeline draws from.
        strength: img2img strength.
        guidance_scale: classifier-free guidance scale.

    Returns:
        ``pipe(...).images[0]``.
    """
    set_seed(int(seed))
    output = pipe(
        prompt=prompt,
        image=Image.fromarray(image),
        strength=strength,
        guidance_scale=guidance_scale,
        num_inference_steps=steps,
    )
    return output.images[0]


def read_image(dataset_name, image_path, prompt_type, max_size=1024, min_size=512, fix_size=False):
    """Load an image, normalise its size, and load its caption.

    Processing steps:
      1. Read ``image_path`` with OpenCV and convert BGR -> RGB.
      2. If the longer side is shorter than ``min_size``, or ``fix_size`` is
         truthy, rescale bilinearly (aspect ratio preserved) so the longer
         side becomes ``min_size``.
      3. Center-crop each side to at most ``max_size``.
      4. Zero-pad on the bottom/right so both sides are multiples of 8
         (presumably because the Stable Diffusion VAE needs sizes divisible
         by 8 -- confirm).
      5. Load the ``'caption'`` entry from a JSON file whose path is
         ``image_path`` with ``"{dataset_name}/{dataset_name}"`` replaced by
         ``"{dataset_name}/{dataset_name}_caption/{prompt_type}"`` and the
         extension swapped for ``.json``.

    Args:
        dataset_name: dataset directory name used in the caption-path rewrite.
        image_path: path of the image to load.
        prompt_type: caption sub-directory (e.g. ``"BLIP"``).
        max_size: maximum height/width after cropping.
        min_size: target length of the longer side when rescaling.
        fix_size: if truthy, always rescale the longer side to ``min_size``.

    Returns:
        Tuple ``(padded_image, original_shape, caption)``: the zero-padded RGB
        array, the shape of the image before padding (used by callers to
        crop the padding back off), and the JSON ``'caption'`` value.

    Raises:
        OSError: if the image cannot be read or the caption file cannot be
            opened.
        KeyError: if the caption JSON has no ``'caption'`` key.
    """
    image = cv2.imread(image_path)
    if image is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise OSError(f"cv2.imread could not read image: {image_path}")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    # Upscale small images; with fix_size, rescale every image to min_size.
    longer_side = max(image.shape[:2])
    if longer_side < min_size or fix_size:
        scale = min_size / longer_side
        new_height = round(image.shape[0] * scale)
        new_width = round(image.shape[1] * scale)
        image = cv2.resize(image, (new_width, new_height), interpolation=cv2.INTER_LINEAR)

    # Center-crop so that neither side exceeds max_size.
    height, width = image.shape[:2]
    transform = create_crop_transforms(height=min(height, max_size), width=min(width, max_size))
    image = transform(image=image)["image"]

    # Zero-pad bottom/right up to multiples of 8; callers undo this using
    # original_shape.
    original_shape = image.shape
    padded_height = find_nearest_multiple(original_shape[0], multiple=8)
    padded_width = find_nearest_multiple(original_shape[1], multiple=8)
    new_image = np.zeros(shape=(padded_height, padded_width, 3), dtype=image.dtype)
    new_image[:original_shape[0], :original_shape[1]] = image

    # Caption lives in a parallel "<dataset>_caption/<prompt_type>" tree.
    json_path = image_path.replace(f'{dataset_name}/{dataset_name}', f'{dataset_name}/{dataset_name}_caption/{prompt_type}')
    json_path = os.path.splitext(json_path)[0] + ".json"
    # JSON text is UTF-8 by spec; don't depend on the platform locale.
    with open(json_path, 'r', encoding='utf-8') as f:
        caption = json.load(f)['caption']

    return new_image, original_shape, caption


def func(pipe, dataset_name, image_path, inpaint_save_path, crop_save_path, prompt_type, step=50, max_size=1024, min_size=512, fix_size=False, strength=0.6, guidance_scale=7.5, seed=2023):
    """Reconstruct one image with img2img and save the results.

    The image is loaded and size-normalised by ``read_image``, reconstructed by
    ``stable_diffusion_img2img`` using its caption as prompt, and the
    reconstruction -- with the multiple-of-8 padding cropped off -- is saved to
    ``inpaint_save_path``. The cropped input image is also saved to
    ``crop_save_path`` unless that file already exists.

    Args:
        pipe: img2img pipeline forwarded to ``stable_diffusion_img2img``.
        dataset_name, image_path, prompt_type, max_size, min_size, fix_size:
            forwarded to ``read_image``.
        inpaint_save_path: output path for the reconstructed image.
        crop_save_path: output path for the cropped input image.
        step, strength, guidance_scale: img2img settings.
        seed: RNG seed for the img2img pass (default keeps the old fixed 2023).
    """
    image, original_shape, caption = read_image(dataset_name, image_path, prompt_type, max_size, min_size, fix_size)

    # The caption entry is normally a list (first caption is used). Guard the
    # plain-string case: indexing a str with [0] would prompt with one char.
    prompt = caption if isinstance(caption, str) else caption[0]
    prompt = prompt.strip()

    # PIL crop box (left, upper, right, lower) that removes the padding.
    box = (0, 0, original_shape[1], original_shape[0])

    recon = stable_diffusion_img2img(pipe, image, prompt=prompt, steps=step, seed=seed, strength=strength, guidance_scale=guidance_scale)
    recon.crop(box=box).save(inpaint_save_path)

    if not os.path.exists(crop_save_path):
        Image.fromarray(image).crop(box=box).save(crop_save_path)


if __name__ == '__main__':
    # ---- Load the Stable Diffusion img2img pipeline (fp16, no safety checker) ----
    device = 'cuda'
    root = 'dataset'

    sd_model_name = "runwayml/stable-diffusion-v1-5"

    pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
        sd_model_name,
        torch_dtype=torch.float16,
        safety_checker=None)
    pipe.enable_xformers_memory_efficient_attention()
    # pipe.enable_model_cpu_offload()
    pipe = pipe.to(device)
    print(f"Load model successful:{sd_model_name}")

    # Generator sub-directories of the GenImage dataset to reconstruct.
    GenImage_LIST = [
        'stable_diffusion_v_1_4/imagenet_ai_0419_sdv4',
        'stable_diffusion_v_1_5/imagenet_ai_0424_sdv5',
        'ADM/imagenet_ai_0508_adm',
        'BigGAN/imagenet_ai_0419_biggan',
        'glide/imagenet_glide',
        'Midjourney/imagenet_midjourney',
        'VQDM/imagenet_ai_0419_vqdm',
        'wukong/imagenet_ai_0424_wukong',
    ]

    dataset_name = "Genimage"
    prompt_type = "BLIP"
    max_size = 1024
    min_size = 512
    fix_size = True
    step = 50
    strength = 0.5
    guidance_scale = 7.5

    # Output folder name encodes the settings, e.g. "75_5_50_fix512".
    recon_name = f"{int(guidance_scale*10)}_{int(strength*10)}_{step}_"
    recon_name += f"fix{min_size}" if fix_size else f"over{min_size}"
    phase_lst = ['val']
    label_lst = ['ai', 'nature']
    model_index_lst = [0, 1, 2, 3, 4, 5, 6, 7]

    # Only the GenImage directory layout is implemented below; fail clearly
    # instead of hitting a NameError on model_name/image_root.
    if dataset_name != "Genimage":
        raise ValueError(f"Unsupported dataset_name: {dataset_name!r}")

    failed_num = 0
    for model_index in model_index_lst:
        model_name = GenImage_LIST[model_index]
        save_root = f'{root}/{dataset_name}/{dataset_name}_recon2/w_{prompt_type}/{recon_name}/{model_name}'
        for phase in phase_lst:
            for label in label_lst:
                image_root = f'{root}/{dataset_name}/{dataset_name}/{model_name}/{phase}/{label}'
                inpaint_root = f'{save_root}/inpainting/{phase}/{label}'
                crop_root = f'{save_root}/crop/{phase}/{label}'

                os.makedirs(inpaint_root, exist_ok=True)
                os.makedirs(crop_root, exist_ok=True)
                image_paths = sorted(glob(f"{image_root}/*.*"))
                print(f'image_root:{model_name}/{phase}/{label}, {len(image_paths)} images')

                for image_path in tqdm(image_paths):
                    # splitext strips only the final extension, so names such as
                    # "a.b.png" and "a.c.png" no longer collide on "a.png".
                    image_name = os.path.splitext(os.path.basename(image_path))[0]
                    inpaint_save_path = os.path.join(inpaint_root, image_name + '.png')
                    crop_save_path = os.path.join(crop_root, image_name + '.png')
                    # Resume support: skip images whose outputs both exist.
                    if os.path.exists(inpaint_save_path) and os.path.exists(crop_save_path):
                        continue
                    try:
                        func(pipe,
                             dataset_name,
                             image_path,
                             inpaint_save_path,
                             crop_save_path,
                             prompt_type=prompt_type,
                             step=step,
                             max_size=max_size,
                             min_size=min_size,
                             fix_size=fix_size,
                             strength=strength,
                             guidance_scale=guidance_scale)
                    except Exception as e:
                        # Count and report per-image failures, but let
                        # KeyboardInterrupt/SystemExit stop the run.
                        failed_num += 1
                        tqdm.write(f'failed: {image_path}: {e!r}')
    print(f'Inference finished! failed_num:{failed_num}')