import cv2
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision

from functools import partial
from tqdm.auto import tqdm
import math
import sys
from d3dr.diffusion.sd_utils import StableDiffusion

import matplotlib.pyplot as plt
import numpy as np
import argparse

from try_dds_controlnet_diff_init import (
    center_crop_to_square,
    read_image,
    get_args as _get_args,
    masked_mean,
)

def get_args():
    """Parse the command line for this script.

    Starts from the base parser of ``try_dds_controlnet_diff_init`` (obtained
    unparsed) and adds the options controlling the alternating
    latent/image optimisation.

    Returns:
        The parsed ``argparse.Namespace``, including ``image_steps`` (int),
        ``latent_steps`` (int) and ``lr_latent`` (float).
    """
    cli = _get_args(do_parse=False)
    extra_options = (
        ("--image_steps", int, 16),
        ("--latent_steps", int, 16),
        ("--lr_latent", float, 0.1),
    )
    for flag, value_type, default_value in extra_options:
        cli.add_argument(flag, type=value_type, default=default_value)
    return cli.parse_args()

def get_image_shadows(image_full, image_shadows, mask):
    """Blend the optimised image with a shadow-darkened background.

    Where ``mask`` is 1 the output is ``image_full`` itself, so gradients reach
    it. Elsewhere the output is a *detached* copy of ``image_full`` minus the
    shadow magnitude ``|clamp(image_shadows, -1, 1)|``. The background can
    therefore only be changed through ``image_shadows``.

    Args:
        image_full: image tensor being optimised.
        image_shadows: shadow parameters, broadcastable against ``image_full``.
        mask: blend weights, 1 where ``image_full`` is used unchanged.

    Returns:
        The blended tensor (broadcast shape of the inputs).
    """
    # Subtracting a non-negative value darkens the background.
    shadow_strength = torch.abs(torch.clamp(image_shadows, min=-1, max=1))
    darkened_background = image_full.detach() - shadow_strength
    inverse_mask = 1.0 - mask
    return mask * image_full + inverse_mask * darkened_background

def _load_mask(mask_path, height, width, device, dtype):
    """Load ``mask_path`` as a binary mask tensor of shape (1, 1, height, width).

    The grayscale image is center-cropped to a square, resized to
    ``height`` x ``width``, divided by 255 and thresholded at 0.5.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``mask_path``.
    """
    mask_np = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
    if mask_np is None:  # cv2.imread signals failure by returning None, not by raising
        raise FileNotFoundError(f"Could not read mask image: {mask_path}")
    mask_np = center_crop_to_square(mask_np)
    # cv2.resize takes dsize=(width, height); the resulting array has shape (height, width).
    mask_np = cv2.resize(mask_np, (width, height))
    mask_np = (mask_np[None, None, ...] / 255.0 > 0.5).astype(np.float32)
    return torch.from_numpy(mask_np).to(device=device, dtype=dtype)


def _save_snapshots(guidance, args, sds_image_full, sds_image_shadows, mask, step):
    """Save the shadow-composited image and the shadow-free image for ``step``."""
    prompt = f"{args.prompt_2} -> {args.prompt_1}"
    with torch.no_grad():  # snapshots only; no autograd graph needed
        composited = get_image_shadows(sds_image_full, sds_image_shadows, mask)
        snapshots = (
            (composited, f"dds_image_{step}.png"),
            (sds_image_full, f"dds_noshadow_{step}.png"),
        )
        for image, save_name in snapshots:
            guidance.save_images(
                images=guidance.torch2np(image),
                save_name=save_name,
                prompt=prompt,
                save_dir=args.save_dir,
                exp_desc=args.exp_desc,
            )


def main():
    """Run alternating latent-space DDS and image-space fitting with a shadow layer.

    Each outer iteration:
      1. encodes the current composited image (``get_image_shadows``) to latents
         and applies ``args.latent_steps`` DDS gradient steps to those latents
         without autograd (step size ``args.lr_latent``);
      2. decodes the updated latents into a target image;
      3. runs ``args.image_steps`` SGD steps that pull the composited image
         (``sds_image_full`` inside the mask, shadows ``sds_image_shadows``
         outside) towards the target with a summed squared-error loss.
    Intermediate images are written to ``args.save_dir``.

    Raises:
        ValueError: if both ``add_random_noise_mask`` and ``add_mean_init`` are set.
        FileNotFoundError: if the mask image cannot be read.
    """
    args = get_args()

    if (args.add_random_noise_mask + args.add_mean_init) > 1:
        raise ValueError("Only one of add_random_noise_mask and add_mean_init can be set to True")

    torch_device = "cuda"
    guidance = StableDiffusion(
        device=torch_device,
        sd_version=args.sd_model_name,
        height=args.height,
        width=args.width,
        sd_unet_path=args.sd_unet_path,
        fp16=args.fp16,
        lora_adapters_paths=args.lora_adapters_paths,
        t_range=(0.02, args.t_range_max),
    )

    # Load images. Only the outputs used below are kept. The composite's latent
    # is recomputed after the optional in-mask initialisation.
    _, _, init_latent_nocomp = read_image(args.image_nocomp_path, args.height, args.width, guidance)
    _, init_image_torch_comp, _ = read_image(args.image_comp_path, args.height, args.width, guidance)

    mask = _load_mask(args.mask_path, args.height, args.width, torch_device, guidance.precision_t)

    torch.manual_seed(args.seed)

    # Optionally re-initialise the masked region of the composite before optimising.
    with torch.no_grad():
        if args.add_mean_init > 0:
            # Fill the mask with the masked mean, computed in float32 for precision.
            image32 = init_image_torch_comp.to(torch.float32)
            mask32 = mask.to(torch.float32)
            mean_fill = masked_mean(image32, mask32, dim=(2, 3))[..., None, None]
            init_image_torch_comp = (image32 * (1.0 - mask32) + mean_fill * mask32).to(guidance.precision_t)
        if args.add_random_noise_mask > 0:
            init_image_torch_comp = init_image_torch_comp * (1.0 - mask) + torch.randn_like(init_image_torch_comp) * mask

        init_latent_comp = guidance.torch2latents(init_image_torch_comp)

    sds_image_full = init_image_torch_comp.detach().clone().requires_grad_(True)
    # Single-channel shadow layer at the working resolution. It is broadcast against
    # sds_image_full in get_image_shadows (assumes the image is (1, C, height, width);
    # TODO confirm read_image's layout). It is sampled on the CPU generator and then
    # moved, as before, so seeded runs draw the same values.
    sds_image_shadows = (torch.rand((1, 1, args.height, args.width)).to(sds_image_full) * 0.1).requires_grad_(True)

    # optimizer and scheduler
    optimizer = torch.optim.SGD([sds_image_full, sds_image_shadows], lr=args.lr)
    opt_scheduler = torch.optim.lr_scheduler.PolynomialLR(
        optimizer,
        total_iters=args.num_train_iterations,
        power=args.power,
    )

    args.save_dir = guidance.get_save_dir(args.save_dir)
    print("Save dir:", args.save_dir)

    # Load prompts
    prompt_initial = guidance.get_text_embeds(args.prompt_2)
    prompt_desired = guidance.get_text_embeds(args.prompt_1)
    uncond_emb = guidance.get_text_embeds("")
    text_embeddings_initial = torch.cat([uncond_emb, prompt_initial])
    text_embeddings_desired = torch.cat([uncond_emb, prompt_desired])

    step_ratio = None
    noise = None
    if args.use_random_noise == 0:
        # Draw one fixed noise sample and reuse it for every DDS step. With noise=None,
        # train_step_dds presumably samples fresh noise per call (confirm in sd_utils).
        noise = torch.randn_like(init_latent_comp)

    num_outer_iters = args.num_train_iterations // args.image_steps
    for i in tqdm(range(num_outer_iters)):
        # Optimiser (image-space) iterations completed before this outer iteration.
        done_iters = i * args.image_steps

        with torch.no_grad():
            image_full = get_image_shadows(sds_image_full, sds_image_shadows, mask)
            opt_latent = guidance.torch2latents(image_full).detach()

            if args.use_step_ratio != 0:
                # Progress is measured in optimiser iterations, so the schedule spans the
                # whole run. The outer index i alone only reaches num_train_iterations / image_steps.
                step_ratio = min(
                    1,
                    (1 - args.initial_step) + args.initial_step * done_iters / args.num_train_iterations,
                )

            for _ in range(args.latent_steps):
                grad = guidance.train_step_dds(
                    text_embeddings_initial=text_embeddings_initial,
                    text_embeddings_desired=text_embeddings_desired,
                    latents_initial=init_latent_nocomp,
                    rgb_pred=opt_latent,
                    guidance_scale=args.guidance_scale,
                    as_latent=True,
                    step_ratio=step_ratio,
                    noise=noise,
                    use_weights=False,
                    return_grad=True,
                )
                opt_latent -= args.lr_latent * grad

            # Decoded target image for the image-space fitting below.
            sds_image_1 = guidance.latents2torch(opt_latent)
            torchvision.utils.save_image(sds_image_1 * 0.5 + 0.5, os.path.join(args.save_dir, f"opt_latent_{i}.png"))

        for _ in range(args.image_steps):
            optimizer.zero_grad()

            image_full = get_image_shadows(sds_image_full, sds_image_shadows, mask)
            loss = (image_full - sds_image_1).pow(2).sum(dim=(1, 2, 3)).mean()
            loss.backward()
            # NOTE: masking sds_image_full.grad with `mask` was considered here but is disabled.
            optimizer.step()
            opt_scheduler.step()

        # Save whenever this outer iteration crossed a multiple of show_iter optimiser steps.
        if (done_iters + args.image_steps) // args.show_iter > done_iters // args.show_iter:
            _save_snapshots(guidance, args, sds_image_full, sds_image_shadows, mask, i + 1)

        
# Script entry point: run the optimisation only when executed directly, not on import.
if __name__ == "__main__":
    main()
