from PIL import Image as PILImage
import os
import warnings
from pathlib import Path
import cv2

from voi_gs.diffusion_play.sds_like._sd_utils_controlnet import SDControlNet
import torch
import torchvision

from functools import partial
from tqdm.auto import tqdm
import math

import matplotlib.pyplot as plt
import numpy as np
import argparse

# Command-line configuration for the ControlNet-conditioned SDS latent optimization.
parser = argparse.ArgumentParser(
    description="Optimize a Stable Diffusion latent with ControlNet-conditioned SDS guidance."
)
parser.add_argument("--prompt", type=str, default="A cat", help="Text prompt for the guidance model.")
parser.add_argument("--exp_desc", type=str, default="", help="Experiment description forwarded to guidance.save_images.")
parser.add_argument("--seed", type=int, default=0, help="Seed passed to torch.manual_seed.")
parser.add_argument(
    "--sd_model_name",
    type=str,
    default="stabilityai/stable-diffusion-2-1-base",
    help="Stable Diffusion model id (passed to SDControlNet as sd_version).",
)
parser.add_argument("--sd_unet_path", type=str, default=None, help="Optional UNet path passed to SDControlNet.")
parser.add_argument(
    "--controlnet_model_name",
    type=str,
    default="thibaud/controlnet-sd21-depth-diffusers",
    help="ControlNet model id passed to SDControlNet.",
)
parser.add_argument("--guidance_scale", type=float, default=7.5, help="Guidance scale passed to train_step.")
parser.add_argument("--width", type=int, default=512, help="Image width in pixels (latent width is width // 8).")
parser.add_argument("--height", type=int, default=512, help="Image height in pixels (latent height is height // 8).")
# Required: the script always reads this image, and cv2.imread(None) fails with
# an opaque error instead of a usage message.
parser.add_argument(
    "--controlnet_image_path",
    type=str,
    required=True,
    help="Path to the ControlNet conditioning image (e.g. a depth map).",
)
parser.add_argument(
    "--controlnet_conditioning_scale",
    type=float,
    default=1.0,
    help="ControlNet conditioning scale passed to train_step.",
)

parser.add_argument("--num_train_iterations", type=int, default=1000, help="Number of training iterations.")
parser.add_argument("--show_iter", type=int, default=100, help="Save a decoded image every this many iterations.")
parser.add_argument(
    "--acc_step",
    type=int,
    default=1,
    help="Accumulation steps; the LR schedule spans num_train_iterations // acc_step scheduler steps.",
)
parser.add_argument("--lr", type=float, default=1e-1, help="Adam learning rate.")
parser.add_argument("--power", type=float, default=2.0, help="Power of the polynomial LR decay.")
# NOTE(review): user-specific default path; override it on other machines.
parser.add_argument(
    "--save_dir",
    type=str,
    default="/scratch/izar/skorokho/sds_play_output/",
    help="Base output directory (resolved through guidance.get_save_dir).",
)
parser.add_argument(
    "--use_random_noise",
    type=int,
    default=1,
    help="0: sample one noise tensor and reuse it every step; otherwise pass noise=None to train_step.",
)
parser.add_argument(
    "--use_step_ratio",
    type=int,
    default=0,
    help="Nonzero: pass a step_ratio to train_step that ramps linearly from 1 - initial_step towards 1.",
)
parser.add_argument("--initial_step", type=float, default=0.0, help="Ramp amount used by --use_step_ratio.")
parser.add_argument("--fp16", action="store_true", help="Passed as fp16 to SDControlNet.")
args = parser.parse_args()


# Every tensor this script creates is placed on the GPU.
torch_device = "cuda"

# Stable Diffusion + ControlNet wrapper. It supplies the SDS loss (train_step)
# and the embedding, decoding and saving helpers used further down.
guidance = SDControlNet(
    device=torch_device,
    sd_version=args.sd_model_name,
    controlnet_name=args.controlnet_model_name,
    height=args.height,
    width=args.width,
    sd_unet_path=args.sd_unet_path,
    fp16=args.fp16,
)

# Seed only after the models are built, so the latent below is reproducible.
torch.manual_seed(args.seed)

# The optimized variable is a UNet-input latent at 1/8 of the pixel resolution.
latent_shape = (1, guidance.unet.config.in_channels, args.height // 8, args.width // 8)
sds_image = torch.randn(latent_shape, device=torch_device, requires_grad=True)

# Adam on the latent, with polynomial LR decay over the scheduler's step budget.
optimizer = torch.optim.Adam(params=[sds_image], lr=args.lr)
scheduler_steps = args.num_train_iterations // args.acc_step
opt_scheduler = torch.optim.lr_scheduler.PolynomialLR(
    optimizer,
    total_iters=scheduler_steps,
    power=args.power,
)

def center_crop_to_square(cv2_image):
    """Crop the central square out of an image array.

    The square's side is ``min(height, width)``. When the excess along the
    longer axis is odd, the extra pixel is dropped from the bottom/right edge.

    Args:
        cv2_image: Array indexed as ``[row, col, ...]`` (such as what
            ``cv2.imread`` returns); any trailing channel axes are kept.

    Returns:
        A slice of ``cv2_image`` (a view for NumPy arrays) with equal height and
        width. An already-square input is returned unchanged (same object).
    """
    h, w = cv2_image.shape[:2]
    if h == w:
        return cv2_image
    side = min(h, w)
    top = (h - side) // 2
    left = (w - side) // 2
    # Slice by `side` rather than `h - start` / `w - start`: with an odd excess
    # the latter keeps one row/column too many and the result is not square.
    return cv2_image[top:top + side, left:left + side, ...]

# Load the ControlNet conditioning image (a depth map in the author's use).
# cv2.imread returns None instead of raising when the file cannot be read.
if args.controlnet_image_path is None:
    parser.error("--controlnet_image_path is required")
controlnet_image = cv2.imread(args.controlnet_image_path)
if controlnet_image is None:
    raise FileNotFoundError(f"Could not read ControlNet image: {args.controlnet_image_path}")
controlnet_image = center_crop_to_square(controlnet_image)  # depth in my case
# cv2.resize expects dsize as (width, height); the reverse order transposes the
# output whenever height != width. Pixel values are scaled to [0, 1].
controlnet_image = cv2.resize(controlnet_image, (args.width, args.height)) / 255.0
# NOTE(review): channels stay in OpenCV's BGR order. That does not matter for a
# grayscale depth map, but confirm what get_image_embeds expects for color inputs.
controlnet_image_emb = guidance.get_image_embeds(controlnet_image).to(torch_device)
print(controlnet_image_emb.shape)
# Debug dump of the processed conditioning tensor into the current working directory.
torchvision.utils.save_image(controlnet_image_emb, "controlnet_image_emb.jpg")

# Text conditioning: the unconditional ("") embedding is stacked first, then the
# prompt embedding, presumably for classifier-free guidance inside train_step.
real_emb = guidance.get_text_embeds(args.prompt)
uncond_emb = guidance.get_text_embeds("")
text_embs = torch.cat([uncond_emb, real_emb])

# Resolve the concrete output directory and report it.
args.save_dir = guidance.get_save_dir(args.save_dir)
print("Save dir:", args.save_dir)

# Noise handling for train_step:
#   use_random_noise == 0 -> sample ONE noise tensor here and reuse it on every
#                            iteration (fixed noise);
#   any other value       -> pass noise=None and let train_step choose the noise
#                            (presumably fresh random noise per call; verify).
noise = None
# Stays None unless --use_step_ratio is nonzero; the training loop overwrites it.
step_ratio = None
if args.use_random_noise == 0:  # fixed noise, NOT per-step random noise
    noise = torch.randn_like(sds_image)


# Main SDS optimization loop. Each iteration computes the guidance loss for the
# current latent. The optimizer and LR scheduler step once every `acc_step`
# iterations, which matches the scheduler's
# total_iters = num_train_iterations // acc_step. The loop runs exactly
# num_train_iterations times; the previous `+ 1` ran one iteration past the end
# of the LR schedule.
for i in tqdm(range(args.num_train_iterations)):
    # Start of an accumulation window: drop gradients from the previous update.
    if i % args.acc_step == 0:
        optimizer.zero_grad()

    if args.use_step_ratio != 0:
        # Linear ramp from (1 - initial_step) at i=0 towards 1, capped at 1.
        # How train_step maps this ratio to a timestep is not visible here.
        step_ratio = min(1, (1 - args.initial_step) + args.initial_step * i / args.num_train_iterations)

    loss = guidance.train_step(
        text_embeddings=text_embs,
        image_embeddings=controlnet_image_emb,
        pred_rgb=sds_image,
        guidance_scale=args.guidance_scale,
        controlnet_conditioning_scale=args.controlnet_conditioning_scale,
        as_latent=True,
        step_ratio=step_ratio,
        noise=noise,
    )

    # Scale so the accumulated gradient is the mean over the window.
    (loss / args.acc_step).backward()

    # End of an accumulation window: apply the update and advance the LR schedule.
    if (i + 1) % args.acc_step == 0:
        optimizer.step()
        opt_scheduler.step()

    if (i + 1) % args.show_iter == 0:
        # Decoding is only for visualization, so no autograd graph is needed.
        with torch.no_grad():
            result_image = guidance.torch2np(guidance.latents2torch(sds_image.detach()))
        guidance.save_images(
            images=result_image,
            save_name=f"sds_image_{i + 1}.png",
            prompt=args.prompt,
            save_dir=args.save_dir,
            exp_desc=args.exp_desc,
        )
