import os
import torch
import argparse
import numpy as np

from PIL import Image
from diffusers import StableDiffusionInpaintPipeline, DPMSolverMultistepScheduler
from sde_inversion import set_seed


def center_crop(width, height, img):
    """Crop the largest centered square out of ``img`` and resize it.

    Args:
        width: Width in pixels of the returned image.
        height: Height in pixels of the returned image.
        img: A PIL image, or anything ``np.array`` turns into an image array of
            shape ``(H, W)`` or ``(H, W, C)``. Values are cast to uint8.

    Returns:
        A ``PIL.Image.Image`` of size ``(width, height)``. Its mode is inferred
        from the cropped array: ``(H, W, 3)`` -> ``'RGB'``, ``(H, W)`` -> ``'L'``,
        ``(H, W, 4)`` -> ``'RGBA'``.
    """
    arr = np.array(img).astype(np.uint8)
    side = min(arr.shape[:2])
    top = (arr.shape[0] - side) // 2
    left = (arr.shape[1] - side) // 2
    arr = arr[top:top + side, left:left + side]

    # Let Pillow infer the mode from the array shape instead of forcing
    # mode='RGB' behind a bare `except:`. Forcing RGB silently misreads
    # 4-channel data, and the bare except also swallowed KeyboardInterrupt.
    cropped = Image.fromarray(arr)
    return cropped.resize((width, height))


def mask_inversion(mask, threshold=128):
    """Binarize ``mask`` and invert it.

    Every pixel with value ``>= threshold`` becomes 0 and every other pixel
    becomes 255, so the result is strictly two-valued.

    Args:
        mask: A PIL image or array-like accepted by ``np.array``.
        threshold: Pixels at or above this value are mapped to 0. Defaults
            to 128.

    Returns:
        A ``PIL.Image.Image`` built from the inverted array. The array keeps
        the input's dtype (e.g. a mode ``'L'`` mask yields a uint8 ``'L'``
        image).
    """
    arr = np.array(mask)
    # A single comparison puts every pixel in exactly one bucket. Two separate
    # tests (>= 128 and < 127) would leave value 127 unchanged.
    inverted = np.where(arr >= threshold, 0, 255).astype(arr.dtype)
    return Image.fromarray(inverted)


def get_args():
    """Parse the command-line options for an inpainting evaluation run.

    Returns:
        argparse.Namespace with attributes ``img_path``, ``mask_path``,
        ``begin``, ``num``, ``steps``, ``algorithm`` and ``order``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--img_path",
        type=str,
        default='../data/Places/val_large',
        help="directory containing the Places365 validation .jpg images",
    )
    parser.add_argument(
        "--mask_path",
        type=str,
        default='../data/Places/masks_small_places_val_512/masks_val_512_small_eval',
        help="directory containing the .png masks",
    )
    parser.add_argument(
        "--begin",
        type=int,
        default=0,
        help="index of the first image to process",
    )
    parser.add_argument(
        "--num",
        type=int,
        default=1,
        help="number of images to process",
    )
    parser.add_argument(
        "--steps",
        type=int,
        default=100,
        help="number of inference steps",
    )
    # Validate here so a typo fails immediately, before the (slow) model load.
    parser.add_argument(
        "--algorithm",
        type=str,
        default="sde",
        choices=["sde", "ode"],
        help="'sde' selects sde-dpmsolver++, 'ode' selects dpmsolver++",
    )
    parser.add_argument(
        "--order",
        type=int,
        default=2,
        choices=[1, 2, 3],
        help="DPM-Solver order",
    )

    return parser.parse_args()

def main():
    """Inpaint a range of Places365 validation images and save the results.

    For each dataset index in ``[opt.begin, opt.begin + opt.num)``, clamped to
    the first 36,500 indices, this function:

    - loads ``Places365_val_<index+1>.jpg`` and center-crops it to 512x512;
    - loads ``<index>.png`` from the mask directory and inverts it;
    - runs the Stable Diffusion inpainting pipeline with a DPM-Solver++
      scheduler (SDE or ODE variant);
    - writes the first output image to
      ``inpainting_evaluation/<small|large>/<algorithm>-order=<order>-<steps>/<index>.png``.

    Raises:
        NotImplementedError: if ``--algorithm`` is neither 'sde' nor 'ode'.
        ValueError: if a mask is not 512x512.
    """
    opt = get_args()
    seed = 1234
    set_seed(seed)

    num_val_images = 36500  # indices 0..36499 are processed
    prompt = "photograph of a beautiful empty scene, highest quality settings"
    algorithm_types = {'sde': 'sde-dpmsolver++', 'ode': 'dpmsolver++'}

    # Check the algorithm before paying for the model download/load.
    if opt.algorithm not in algorithm_types:
        raise NotImplementedError(opt.algorithm)
    algorithm = algorithm_types[opt.algorithm]
    num_inference_steps = opt.steps

    pipeline = StableDiffusionInpaintPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        torch_dtype=torch.float16,
    )
    pipeline = pipeline.to("cuda")

    # Pass the overrides to from_config so the scheduler is constructed and
    # validated with them. Assigning to scheduler.config afterwards only
    # patches attributes on an already-built object.
    pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
        pipeline.scheduler.config,
        algorithm_type=algorithm,
        solver_order=opt.order,
    )

    mask_type = 'small' if 'small' in opt.mask_path else 'large'
    out_dir = os.path.join('inpainting_evaluation', mask_type,
                           f'{algorithm}-order={opt.order}-{num_inference_steps}')
    os.makedirs(out_dir, exist_ok=True)

    end = min(opt.begin + opt.num, num_val_images)
    for index in range(opt.begin, end):
        image_name = 'Places365_val_' + str(index + 1).zfill(8)
        image_path = os.path.join(opt.img_path, f'{image_name}.jpg')
        with Image.open(image_path) as raw_image:
            image = center_crop(512, 512, raw_image)

        mask_path = os.path.join(opt.mask_path, f'{str(index).zfill(6)}.png')
        with Image.open(mask_path) as raw_mask:
            mask = mask_inversion(raw_mask)

        mask_shape = np.array(mask).shape
        if mask_shape[:2] != (512, 512):
            raise ValueError(f'expected a 512x512 mask at {mask_path}, got shape {mask_shape}')

        # A per-image generator keyed on the dataset index makes each output
        # independent of --begin/--num, so sharded runs match a single run.
        generator = torch.Generator(device="cuda").manual_seed(seed + index)
        images, _ = pipeline(prompt=prompt, image=image, mask_image=mask,
                             num_inference_steps=num_inference_steps,
                             generator=generator, return_dict=False)
        images[0].save(os.path.join(out_dir, f'{index}.png'))


if __name__ == "__main__":
    main()