'''
DDS' with the controlnet embeddings and the depth image as the initial image
'''


from PIL import Image
import cv2
import os
import warnings
from pathlib import Path
import json

from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import (
    AutoencoderKL, 
    UNet2DConditionModel, 
    UniPCMultistepScheduler,
    DDIMScheduler,
)

import torch
import torch.nn as nn
import torch.nn.functional as F

from functools import partial
from tqdm.auto import tqdm
import math
from voi_gs.diffusion_play.sds_like._sd_utils_controlnet import SDControlNet

import matplotlib.pyplot as plt
import numpy as np
import argparse

# Command-line interface.
# NOTE(review): in this script --image_comp_path, --image_nocomp_path,
# --controlnet_comp_path, --controlnet_nocomp_path, --mask_path, --exp,
# --add_random_noise_mask, --add_mean_init and --fp16 are parsed (and saved to
# args.json) but not otherwise read: all inputs come from hard-coded .pt files.
parser = argparse.ArgumentParser()
parser.add_argument("--prompt_1", type=str, default="a <ktn> wet floor sign in a room", help="desired prompt")
parser.add_argument("--prompt_2", type=str, default="a room", help="initial prompt")
parser.add_argument("--exp_desc", type=str, default="", help="the description of the experiment")
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--sd_model_name", type=str, default="stabilityai/stable-diffusion-2-1-base")
parser.add_argument("--sd_unet_path", type=str, default="/scratch/izar/skorokho/personalization_play/007/unet/", help="path to the unet model (for personalization)")
parser.add_argument("--controlnet_model_name", type=str, default="thibaud/controlnet-sd21-depth-diffusers")
parser.add_argument("--guidance_scale", type=float, default=7.5)
parser.add_argument("--width", type=int, default=512)
parser.add_argument("--height", type=int, default=512)
parser.add_argument("--controlnet_conditioning_scale", type=float, default=1.0, help="the scale of the controlnet conditioning. Values larger than 1.0 work poorly!")

# Example input: /scratch/izar/skorokho/real_images/yes_cup_l1.jpg
parser.add_argument("--image_comp_path", type=str, default=None, help="path to the obj + scene rgb")
parser.add_argument("--image_nocomp_path", type=str, default=None, help="path to the ONLY scene rgb")
parser.add_argument("--controlnet_comp_path", type=str, default=None, help="path to the obj + scene rgb depth/normal/etc for controlnet")
parser.add_argument("--controlnet_nocomp_path", type=str, default=None, help="path to the ONLY scene rgb depth/normal/etc for controlnet")
parser.add_argument("--mask_path", type=str, default=None, help="path to the mask of obj + scene rgb. 1 - obj, 0 - scene")

# Optimisation settings.
parser.add_argument("--num_train_iterations", type=int, default=400)
parser.add_argument("--show_iter", type=int, default=20, help="how often to show the generated image")
parser.add_argument("--lr", type=float, default=1e-1)
parser.add_argument("--power", type=float, default=0.5, help="power for the polynomial decay")
parser.add_argument("--save_dir", type=str, default="/scratch/izar/skorokho/dds_harmonization_play_output/")

parser.add_argument("--exp", type=str, default="implicit_dds", choices=["optimize_alpha", "implicit_dds", "explicit_dds"], help="experiment type. For some reason optimize_alpha doesn't work")
parser.add_argument("--add_random_noise_mask", action="store_true", help="initialize with random noise")
parser.add_argument("--add_mean_init", action="store_true", help="initialize with mean")
parser.add_argument("--mask_grad", type=int, default=1, help="mask the gradient with the object mask? (in latent space this is only approximate)")
parser.add_argument("--use_random_noise", type=int, default=1, help="1: pass noise=None to train_step_dds on every step; 0: sample ONE fixed noise tensor (after seeding) and reuse it for all steps")
parser.add_argument("--use_step_ratio", type=int, default=0, help="Larger steps at the beginning of the optimization")
parser.add_argument("--initial_step", type=float, default=0.0, help="starting step_ratio when --use_step_ratio != 0, e.g. 0.5 (-> timestep 500) instead of the default 0.0")
parser.add_argument("--fp16", action="store_true")

args = parser.parse_args()

# Device used for the model (and, by convention, for every tensor loaded below).
torch_device = "cuda"

# Load the Stable Diffusion + ControlNet guidance wrapper. `sd_unet_path`
# points to an optional personalised UNet (see --sd_unet_path).
guidance = SDControlNet(
    device=torch_device,
    sd_version=args.sd_model_name,
    controlnet_name=args.controlnet_model_name,
    height=args.height,
    width=args.width,
    sd_unet_path=args.sd_unet_path,
)

# Load the pre-computed inputs. NOTE(review): the directory is hard-coded;
# the --image_*/--controlnet_*/--mask_path CLI arguments are not used here.
data_dir = Path("/scratch/izar/skorokho/dds_3dgs_harmonisation/020")


def _load_tensor(name):
    """Load the tensor file ``data_dir / name`` onto ``torch_device`` as float32."""
    return torch.load(data_dir / name, map_location=torch_device).to(torch.float32)


# Images are rescaled with x * 2 - 1 (presumably from [0, 1] to the [-1, 1]
# range the VAE expects — TODO confirm) and then encoded to latents. These
# latents are constant inputs to the optimisation, so the encoding runs under
# no_grad: no autograd graph through the VAE is kept alive across the
# repeated backward() calls of the training loop.
init_image_comp_torch = _load_tensor("image_comp_path_0.pt") * 2.0 - 1.0
with torch.no_grad():
    init_latent_comp = guidance.torch2latents_resize(init_image_comp_torch).to(torch.float32)
init_image_nocomp_torch = _load_tensor("image_nocomp_path_0.pt") * 2.0 - 1.0
with torch.no_grad():
    init_latent_nocomp = guidance.torch2latents_resize(init_image_nocomp_torch).to(torch.float32)

# Object mask (1 = object, 0 = scene per the --mask_path help); presumably
# already at latent resolution given its name — TODO confirm.
mask_small = _load_tensor("mask_small_0.pt")

# Load controlnet embeddings: background-only scene and obj + scene composite.
controlnet_nocomp_emb = _load_tensor("controlnet_nocomp_emb_0.pt")
controlnet_comp_emb = _load_tensor("controlnet_comp_emb_0.pt")

# Seed before any noise is drawn below, so fixed-noise runs are reproducible.
torch.manual_seed(args.seed)

# The optimised variable: a fresh leaf copy of the composite latent.
# (Named "sds_image" historically; the objective actually used is DDS.)
sds_image = init_latent_comp.detach().clone()
sds_image.requires_grad_(True)

# Plain SGD on the latent; the learning rate decays polynomially
# (exponent --power) to zero over --num_train_iterations scheduler steps.
optimizer = torch.optim.SGD(params=[sds_image], lr=args.lr)
opt_scheduler = torch.optim.lr_scheduler.PolynomialLR(
    optimizer, total_iters=args.num_train_iterations, power=args.power
)

# Resolve the experiment output directory via the guidance helper and store
# it back in args, so that args.json records the directory actually used.
args.save_dir = guidance.get_save_dir(args.save_dir)
print("Save dir:", args.save_dir)

# Save the composite and background-only inputs for visual reference.
# cv2 writes BGR; torch2np presumably returns RGB (hence the conversion).
# cv2.imwrite signals failure only through its return value, so check it
# instead of failing silently.
for _ref_name, _ref_image in (
    ("comp", init_image_comp_torch),
    ("nocomp", init_image_nocomp_torch),
):
    _ref_path = f"{args.save_dir}/{_ref_name}.jpg"
    if not cv2.imwrite(_ref_path, cv2.cvtColor(guidance.torch2np(_ref_image)[0], cv2.COLOR_RGB2BGR)):
        warnings.warn(f"cv2.imwrite failed to write {_ref_path}")

# save the parameters of the experiment
with open(Path(args.save_dir) / "args.json", "w", encoding="utf-8") as f:
    json.dump(vars(args), f, indent=4)

# Load the pre-computed text embeddings for the "initial" and "desired"
# prompts (passed to train_step_dds under the same names), as float32 on GPU.
_text_embedding_files = (
    "/scratch/izar/skorokho/dds_3dgs_harmonisation/020/text_embeddings_initial_0.pt",
    "/scratch/izar/skorokho/dds_3dgs_harmonisation/020/text_embeddings_desired_0.pt",
)
text_embeddings_initial, text_embeddings_desired = (
    torch.load(emb_file, map_location="cuda").to(torch.float32)
    for emb_file in _text_embedding_files
)

# step_ratio stays None (passed as-is to train_step_dds) unless
# --use_step_ratio is set, in which case the training loop updates it.
step_ratio = None

# --use_random_noise 0: draw ONE noise tensor here (after seeding) and reuse it
# for every step. Otherwise noise=None is passed to train_step_dds, which
# presumably samples its own noise each call — TODO confirm.
noise = torch.randn_like(init_latent_comp) if args.use_random_noise == 0 else None

# The main train cycle!
# Each iteration: DDS loss -> backward -> (optionally) mask the latent
# gradient to the object region -> SGD step -> LR decay; every --show_iter
# iterations the current latent is decoded and saved.
for i in tqdm(range(args.num_train_iterations)):
    optimizer.zero_grad()

    if args.use_step_ratio != 0:
        # Anneal step_ratio linearly from --initial_step (at i = 0) towards
        # 1.0. Per the --initial_step help, a smaller step_ratio means a
        # larger timestep (0.5 -> timestep 500), giving larger steps early on.
        # --initial_step is the *starting* value, so the default 0.0 sweeps
        # the full range. (The coefficients used to be swapped, which made
        # step_ratio a constant 1.0 for initial_step=0.0.)
        step_ratio = min(
            1.0,
            args.initial_step
            + (1.0 - args.initial_step) * i / args.num_train_iterations,
        )

    # The main idea is plain DDS, but the optimised latent is initialised from
    # the composite, and the two branches get DIFFERENT ControlNet conditions:
    # background-only for the initial branch, composite for the desired one.
    # Otherwise the initial lighting breaks the optimisation
    # (see exp 126, where I tried to generate the statue head:
    # /scratch/izar/skorokho/dds_harmonization_play_output/126/dds_image_160_000.png)
    loss = guidance.train_step_dds(
        text_embeddings_initial=text_embeddings_initial,
        text_embeddings_desired=text_embeddings_desired,
        image_embeddings_initial=controlnet_nocomp_emb,
        image_embeddings_desired=controlnet_comp_emb,  # differs from the initial condition
        latents_initial=init_latent_nocomp,  # I call it a "pulling latent"
        rgb_pred=sds_image,
        guidance_scale=args.guidance_scale,
        controlnet_conditioning_scale=args.controlnet_conditioning_scale,
        as_latent=True,
        step_ratio=step_ratio,
        noise=noise,
        use_weights=False,
    )

    loss.backward()

    if args.mask_grad != 0:
        # Restrict updates to the object region (mask: 1 = object, 0 = scene).
        # The mask is applied in latent space, so this is only approximate.
        sds_image.grad = sds_image.grad * mask_small

    optimizer.step()
    opt_scheduler.step()

    # show_iter <= 0 disables saving (previously show_iter=0 crashed with
    # ZeroDivisionError).
    if args.show_iter > 0 and (i + 1) % args.show_iter == 0:
        # Decoding is for visualisation only: avoid recording an autograd
        # graph through the VAE decoder.
        with torch.no_grad():
            result_image = guidance.torch2np(guidance.latents2torch(sds_image.detach()))
        guidance.save_images(
            images=result_image,
            save_name=f"dds_image_{i + 1}.png",
            prompt=f"{args.prompt_2} -> {args.prompt_1}",
            save_dir=args.save_dir,
            exp_desc=args.exp_desc,
        )

