'''
DDS + Diff Init
(This file has only a few comments. See try_dds_controlnet_diff_init.py for more detailed comments.)

python3 sds_like/try_dds_diff_init.py \
    --prompt_1 "a <ktn> single-color parrot statue in a living room" \
    --prompt_2 "a living room" \
    --lora_adapters_paths "/scratch/izar/skorokho/voi_3dgs_all_3/f4_lin_dec_living_room_1_000/personalization_object_1/" \
    --image_nocomp_path /scratch/izar/skorokho/voi_3dgs_all_3/f4_lin_dec_living_room_1_000/rendering/real_rgb_scene/00000.png \
    --image_comp_path /scratch/izar/skorokho/voi_3dgs_all_3/f4_lin_dec_living_room_1_000/rendering/initial_rgb_obj_scene/00000.png \
    --mask_path /scratch/izar/skorokho/voi_3dgs_all_3/f4_lin_dec_living_room_1_000/rendering/voi_rgb_obj_scene/00000.png

python3 sds_like/try_dds_diff_init.py \
    --prompt_1 "<rare_token> two dustbins in front of an outdoor container" \
    --prompt_2 "an outdoor container" \
    --lora_adapters_paths "/scratch/izar/skorokho/personalization_play/121/" \
    --image_nocomp_path "/scratch/izar/skorokho/test/dustbins_ifo_container/images_scene/00075.png" \
    --image_comp_path "/scratch/izar/skorokho/test/dustbins_ifo_container/images/00075.png" \
    --mask_path "/scratch/izar/skorokho/test/dustbins_ifo_container/masks/00075.jpg"

python3 sds_like/try_dds_diff_init.py \
    --prompt_1 "<ktn> caution wet floor sign in a bathroom" \
    --prompt_2 "a bathroom" \
    --lora_adapters_paths "/scratch/izar/skorokho/voi_3dgs_all_3/sd15_bathroom_1_001/personalization_object_1/" \
    --image_nocomp_path "/scratch/izar/skorokho/voi_3dgs_all_3/sd15_bathroom_1_001/rendering/real_rgb_scene/00180.png" \
    --image_comp_path "/scratch/izar/skorokho/voi_3dgs_all_3/sd15_bathroom_1_001/rendering/initial_rgb_obj_scene/00180.png" \
    --mask_path "/scratch/izar/skorokho/voi_3dgs_all_3/sd15_bathroom_1_001/rendering/initial_mask_obj_scene/00180.jpg"

    
python3 sds_like/try_dds_diff_init.py \
    --prompt_1 "a cup on a plate" \
    --prompt_2 "a plate" \
    --image_nocomp_path "/scratch/izar/skorokho/real_images/no_cup_l1.jpg" \
    --image_comp_path "/scratch/izar/skorokho/real_images/cupl2l1.jpg" \
    --mask_path "/scratch/izar/skorokho/real_images/mask_composition.png"
'''

from PIL import Image
import cv2
import os
import warnings
from pathlib import Path
import json

from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import (
    AutoencoderKL, 
    UNet2DConditionModel, 
    UniPCMultistepScheduler,
    DDIMScheduler,
)

import torch
import torch.nn as nn
import torch.nn.functional as F

from functools import partial
from tqdm.auto import tqdm
import math
import sys
sys.path.append(os.path.join(os.path.dirname(__file__), "../../dn-splatter/dn_splatter/diffusion"))
from sd_utils import StableDiffusion

import matplotlib.pyplot as plt
import numpy as np
import argparse

from try_dds_controlnet_diff_init import (
    center_crop_to_square,
    read_image,
    get_args,
    masked_mean,
)

def main():
    """Optimise a latent with the DDS loss, starting from a composited image.

    Command-line arguments are parsed by ``get_args``. The flow is:

    1. Build the ``StableDiffusion`` guidance wrapper.
    2. Read the scene image without the object (``image_nocomp_path``), whose
       latent is passed as ``latents_initial``, and the composited image
       (``image_comp_path``), which initialises the optimised latent.
    3. Read the object mask, binarise it at 0.5, and build a copy at
       ``height // 8 x width // 8`` used to mask the latent gradient.
    4. Optionally overwrite the masked region of the composited image with
       its masked mean colour or with Gaussian noise before encoding.
    5. Run SGD with a polynomial learning-rate schedule on the latent and
       save a decoded preview every ``show_iter`` iterations.

    Raises:
        ValueError: if both ``add_random_noise_mask`` and ``add_mean_init``
            are enabled.
        FileNotFoundError: if the mask image cannot be read.
    """
    args = get_args()

    # Both options overwrite the masked region, so they are mutually exclusive.
    # The flags are tested with `> 0` everywhere below, so test them the same way here.
    if args.add_random_noise_mask > 0 and args.add_mean_init > 0:
        raise ValueError("Only one of add_random_noise_mask and add_mean_init can be set to True")

    torch_device = "cuda"
    guidance = StableDiffusion(
        device=torch_device,
        sd_version=args.sd_model_name,
        height=args.height,
        width=args.width,
        sd_unet_path=args.sd_unet_path,
        fp16=args.fp16,
        lora_adapters_paths=args.lora_adapters_paths,
        t_range=(0.02, args.t_range_max),
    )

    # Load images. Only the latent of the non-composited image and the image
    # tensor of the composited one are used; the composited latent is
    # re-encoded below after the optional re-initialisation of the masked area.
    _, _, init_latent_nocomp = read_image(args.image_nocomp_path, args.height, args.width, guidance)
    _, init_image_torch_comp, _ = read_image(args.image_comp_path, args.height, args.width, guidance)

    # Load mask
    init_mask = cv2.imread(args.mask_path, cv2.IMREAD_GRAYSCALE)
    if init_mask is None:
        # cv2.imread signals a missing/unreadable file by returning None.
        raise FileNotFoundError(f"Could not read mask image: {args.mask_path}")
    init_mask = center_crop_to_square(init_mask)
    # cv2.resize expects dsize as (width, height); the result has shape (height, width).
    init_mask = cv2.resize(init_mask, (args.width, args.height))
    # Binarise and add leading batch/channel axes -> (1, 1, height, width).
    init_mask = (init_mask[None, None, ...] / 255.0 > 0.5).astype(np.float32)
    with torch.no_grad():
        mask = torch.from_numpy(init_mask).to(device=torch_device, dtype=guidance.precision_t)
        # Bilinear downsampling gives a soft-edged mask that is multiplied
        # with the latent gradient inside the optimisation loop.
        mask_small = F.interpolate(mask, size=(args.height // 8, args.width // 8), mode="bilinear", align_corners=False)

    torch.manual_seed(args.seed)

    # initialize the sds_image (the correct name is dds)
    with torch.no_grad():
        if args.add_mean_init > 0:
            # Replace the masked region with its mean colour; computed in
            # float32 and cast back to the guidance precision afterwards.
            image_f32 = init_image_torch_comp.to(torch.float32)
            mask_f32 = mask.to(torch.float32)
            mean_color = masked_mean(image_f32, mask_f32, dim=(2, 3))[..., None, None]
            init_image_torch_comp = image_f32 * (1.0 - mask_f32) + mean_color * mask_f32
            init_image_torch_comp = init_image_torch_comp.to(guidance.precision_t)
        if args.add_random_noise_mask > 0:
            # Replace the masked region with standard Gaussian noise.
            init_image_torch_comp = init_image_torch_comp * (1.0 - mask) + torch.randn_like(init_image_torch_comp) * mask

        init_latent_comp = guidance.torch2latents(init_image_torch_comp)

    # The latent itself is the optimised variable.
    sds_image = init_latent_comp.detach().clone().requires_grad_(True)

    # optimizer and scheduler
    optimizer = torch.optim.SGD([sds_image], lr=args.lr)
    opt_scheduler = torch.optim.lr_scheduler.PolynomialLR(
        optimizer,
        total_iters=args.num_train_iterations,
        power=args.power,
    )

    args.save_dir = guidance.get_save_dir(args.save_dir)
    print("Save dir:", args.save_dir)

    # Load prompts: prompt_2 describes the initial scene, prompt_1 the desired
    # one. Each is concatenated after the empty-prompt embedding.
    prompt_initial = guidance.get_text_embeds(args.prompt_2)
    prompt_desired = guidance.get_text_embeds(args.prompt_1)
    uncond_emb = guidance.get_text_embeds("")
    text_embeddings_initial = torch.cat([uncond_emb, prompt_initial])
    text_embeddings_desired = torch.cat([uncond_emb, prompt_desired])

    step_ratio = None
    noise = None
    if args.use_random_noise == 0:
        # Fixed noise: sample once and pass the same tensor at every step.
        # Otherwise None is passed and the noise is left to train_step_dds
        # (presumably sampled per step -- confirm in sd_utils).
        noise = torch.randn_like(init_latent_comp)

    for i in tqdm(range(args.num_train_iterations)):
        # zero_grad
        optimizer.zero_grad()
        if args.use_step_ratio != 0:
            # Grows linearly from (1 - initial_step) towards 1, clamped at 1.
            step_ratio = min(1, (1 - args.initial_step) + args.initial_step * i / args.num_train_iterations)

        loss = guidance.train_step_dds(
            text_embeddings_initial=text_embeddings_initial,
            text_embeddings_desired=text_embeddings_desired,
            latents_initial=init_latent_nocomp,
            rgb_pred=sds_image,
            guidance_scale=args.guidance_scale,
            as_latent=True,
            step_ratio=step_ratio,
            noise=noise,
            use_weights=False,
        )

        loss.backward()

        # one should mask the gradient
        if args.mask_grad != 0:
            sds_image.grad = sds_image.grad * mask_small

        optimizer.step()
        opt_scheduler.step()

        if (i + 1) % args.show_iter == 0:
            # Decoding is only for visualisation; skip building an autograd
            # graph through the decoder for the grad-requiring latent.
            with torch.no_grad():
                result_image = guidance.torch2np(guidance.latents2torch(sds_image))
            guidance.save_images(
                images=result_image,
                save_name=f"dds_image_{i + 1}.png",
                prompt=f"{args.prompt_2} -> {args.prompt_1}",
                save_dir=args.save_dir,
                exp_desc=args.exp_desc,
            )
# Run the experiment only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
