#-------------
# 2025.09.23
# Implementation of ControlSwap
# env: env_controlswap.yaml
# run: python controlswap.py
#-------------

from typing import Tuple, Union, Optional, List

import os
import numpy as np
from PIL import Image
from tqdm import tqdm, trange

import torch
import torch.nn as nn
from torch.optim.adamw import AdamW
from torch.optim.sgd import SGD
import torch_optimizer as optim
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel, StableDiffusionXLImg2ImgPipeline, StableDiffusionXLInpaintPipeline
from transformers import CLIPTokenizer, CLIPTextModelWithProjection
from IPython.display import display, clear_output

import torch.nn.functional as F

# Short type aliases used in the signatures below.
T = torch.Tensor                      # a tensor
TN = Union[T, None]                   # an optional tensor (same as Optional[T])
TS = Union[Tuple[T, ...], List[T]]    # a tuple or a list of tensors

# Device the models and working tensors of this script are moved to.
device = torch.device("cuda")

# TensorBoard writer; image_optimization logs its per-iteration loss through it.
writer = SummaryWriter(log_dir="./logs/")


def load_1024(image_path: str, left=0, right=0, top=0, bottom=0):
    """Load an image as RGB, trim its borders, center-crop it to a square and resize to 1024x1024.

    Args:
        image_path: path of the image file.
        left, right, top, bottom: number of pixels to trim from each side. Each value is clamped
            to be >= 0 and small enough that at least one row and one column remain.

    Returns:
        A (1024, 1024, 3) numpy array with the RGB pixels.
    """
    # convert("RGB") also accepts grayscale/palette/CMYK files; for RGBA it drops the alpha
    # channel, which is what the previous `[:, :, :3]` slice did.
    with Image.open(image_path) as img:
        image = np.array(img.convert("RGB"))
    h, w, _ = image.shape
    left = min(max(left, 0), w - 1)
    right = min(max(right, 0), w - left - 1)
    # The top trim is bounded by the height alone (it used to be `h - left - 1`, mixing in the
    # horizontal trim by mistake).
    top = min(max(top, 0), h - 1)
    bottom = min(max(bottom, 0), h - top - 1)
    image = image[top:h - bottom, left:w - right]

    # Center-crop the longer side so the final resize does not distort the aspect ratio.
    h, w, _ = image.shape
    if h < w:
        offset = (w - h) // 2
        image = image[:, offset:offset + h]
    elif w < h:
        offset = (h - w) // 2
        image = image[offset:offset + w]
    return np.array(Image.fromarray(image).resize((1024, 1024)))


@torch.no_grad()
def get_text_embeddings0(pipe: StableDiffusionXLPipeline, text: str) -> T:
    """Encode `text` with only the first tokenizer/text encoder of `pipe`.

    The prompt is padded/truncated to 77 tokens and the encoder's `last_hidden_state` is
    returned, detached. Unlike get_text_embeddings this ignores the second SDXL text encoder;
    it is not called anywhere in this file.
    """
    encoding = pipe.tokenizer(
        [text],
        padding="max_length",
        max_length=77,
        truncation=True,
        return_tensors="pt",
        return_overflowing_tokens=True,
    )
    token_ids = encoding.input_ids.to(device)
    hidden_states = pipe.text_encoder(token_ids).last_hidden_state
    return hidden_states.detach()

@torch.no_grad()
def get_text_embeddings(pipe: StableDiffusionXLPipeline, text: str) -> Tuple[T, T]:
    """Encode a single prompt with both SDXL text encoders through `pipe.encode_prompt`.

    Args:
        pipe: SDXL pipeline whose tokenizers/text encoders are used.
        text: the prompt.

    Returns:
        ``(prompt_embeds, pooled_prompt_embeds)`` for `text`. For SDXL base these are expected to
        be shaped (1, 77, 2048) and (1, 1280) — TODO confirm for other checkpoints. The negative
        embeddings that `encode_prompt` also returns (classifier-free guidance is enabled here)
        are discarded.
    """
    prompt_embeds, _, text_embeds, _ = pipe.encode_prompt(
        prompt=text,
        device=device,
        num_images_per_prompt=1,
        do_classifier_free_guidance=True,
        negative_prompt="",
    )
    return prompt_embeds, text_embeds

@torch.no_grad()
def denormalize(image):
    """Convert the first image of an NCHW tensor batch in [-1, 1] to an HWC uint8 numpy array.

    Values are mapped with ``x / 2 + 0.5``, clamped to [0, 1], multiplied by 255 and truncated
    (not rounded) by the uint8 cast.
    """
    unit_range = image.div(2).add(0.5).clamp(min=0, max=1)
    channels_last = unit_range.cpu().permute(0, 2, 3, 1).numpy()
    first = channels_last[0]
    return (first * 255).astype(np.uint8)


@torch.no_grad()
def decode(latent: T, pipe: StableDiffusionXLPipeline, im_cat: TN = None, scaling_factor: float = 0.18215):
    """Decode a VAE latent into a PIL image, optionally placing another image on its left.

    Args:
        latent: latent batch; only the first decoded image is kept (via denormalize).
        pipe: pipeline whose VAE decodes the latent.
        im_cat: optional HWC uint8 array with the decoded image's height; it is concatenated on
            the left (``axis=1``), giving a side-by-side comparison.
        scaling_factor: the factor the latent was multiplied by at encode time. The default
            0.18215 is the SD 1.x constant used by the encoder in this file. SDXL's own VAE
            config (``pipe.vae.config.scaling_factor``) uses 0.13025, so pass that value for
            latents produced with the pipeline's native scaling.

    Returns:
        A ``PIL.Image.Image``.
    """
    # NOTE(review): the VAE is run under a float32 autocast, presumably to avoid fp16 overflow in
    # the SDXL VAE decoder — confirm.
    with torch.autocast(device_type="cuda", dtype=torch.float32):
        image = pipe.vae.decode((1 / scaling_factor) * latent, return_dict=False)[0]
        image = denormalize(image)
        if im_cat is not None:
            image = np.concatenate((im_cat, image), axis=1)
    return Image.fromarray(image)


def init_pipe(device, dtype, unet, scheduler) -> Tuple[UNet2DConditionModel, T, T]:
    """Freeze `unet` and precompute the noise-schedule coefficients of `scheduler`.

    Returns:
        ``(unet, alphas, sigmas)`` with ``alphas = sqrt(alphas_cumprod)`` and
        ``sigmas = sqrt(1 - alphas_cumprod)``, moved to `device` / `dtype`. Both coefficient
        tensors are created under ``torch.inference_mode``.
    """
    with torch.inference_mode():
        alpha_bar = scheduler.alphas_cumprod
        sqrt_alpha_bar = alpha_bar.sqrt().to(device, dtype=dtype)
        sqrt_one_minus_alpha_bar = (1 - alpha_bar).sqrt().to(device, dtype=dtype)
    # The UNet is only queried, never trained: stop tracking gradients for its weights.
    unet.requires_grad_(False)
    return unet, sqrt_alpha_bar, sqrt_one_minus_alpha_bar


def get_time_ids(height=1024, width=1024, batch_size=4, dtype=torch.float16):
    """Build SDXL ``time_ids`` micro-conditioning, one identical row per batch element.

    Each row is ``(original_size, crop_top_left, target_size)`` flattened to
    ``[height, width, 0, 0, height, width]``; the result has shape (batch_size, 6) and lives on
    the module-level `device`.
    """
    original_size = (height, width)
    crop_top_left = (0, 0)
    target_size = (height, width)
    single_row = torch.tensor(
        [original_size + crop_top_left + target_size],
        device=device,
        dtype=dtype,
    )
    return single_row.repeat(batch_size, 1)


class ControlLoss:
    """Masked, score-distillation-style guidance loss used to edit an SDXL latent.

    Whenever the gradient is (re)computed, the source and target latents are noised with the same
    noise and timestep. The frozen UNet then predicts their noise with classifier-free guidance,
    and the guided predictions are turned into a gradient-like tensor ``grad``. All of this runs
    under ``torch.inference_mode``, so the UNet is never differentiated. With the default
    reduction, the returned loss is ``sum(z_target * grad) / (H * W)``. Its gradient w.r.t.
    ``z_target`` is therefore ``grad / (H * W)``.
    """

    # The UNet-based gradient is recomputed when `iter % skip == 0` and reused otherwise.
    skip = 3
    # Iterations with `iter < phase_switch_iter` use the sub/seg masks; later ones the bbox mask.
    phase_switch_iter = 550
    # Default weight of the source term when `symmetric=True` (it must exist for that path).
    rescale = 1.0

    def noise_input(self, z, eps=None, timestep: Optional[Union[int, T]] = None):
        """Forward-diffuse `z`: ``z_t = alpha_t * z + sigma_t * eps``.

        Args:
            z: latent batch (batch on axis 0).
            eps: noise; sampled with ``randn_like(z)`` when None.
            timestep: None (sampled uniformly from ``[t_min, min(t_max, 1000) - 1)``), a 1-D long
                tensor with one step per batch element, or a plain int used for the whole batch.

        Returns:
            ``(z_t, eps, timestep, alpha_t, sigma_t)``; `timestep` is always a tensor.
        """
        b = z.shape[0]
        if timestep is None:
            timestep = torch.randint(
                low=self.t_min,
                high=min(self.t_max, 1000) - 1,  # Avoid the highest timestep.
                size=(b,),
                device=z.device, dtype=torch.long)
        elif isinstance(timestep, int):
            timestep = torch.full((b,), timestep, device=z.device, dtype=torch.long)
        if eps is None:
            eps = torch.randn_like(z)
        alpha_t = self.alphas[timestep, None, None, None]
        sigma_t = self.sigmas[timestep, None, None, None]
        z_t = alpha_t * z + sigma_t * eps
        return z_t, eps, timestep, alpha_t, sigma_t

    def get_eps_prediction(self, z_t: T, timestep: T, text_embeddings: T, text_embeddings_negative: T, text_embeds: T,
                           alpha_t: T, sigma_t: T, get_raw=False, guidance_scale=7.5):
        """Predict the noise in `z_t` with classifier-free guidance.

        Args:
            z_t: noisy latents, batch B.
            timestep: (B,) timesteps for `z_t`.
            text_embeddings: prompt embeddings assumed shaped (B, 2, seq, dim) with
                [unconditional, conditional] on axis 1 (how torch.stack(..., dim=1) builds them).
            text_embeddings_negative: accepted for interface compatibility; not used.
            text_embeds: pooled embeddings for the 2B UNet rows (B unconditional, then B
                conditional).
            alpha_t, sigma_t: sqrt(alpha_bar_t) and sqrt(1 - alpha_bar_t) per sample; only used
                to convert a v-prediction into a noise prediction.
            get_raw: if True, return ``(e_t_uncond, e_t_cond)`` without applying guidance.
            guidance_scale: classifier-free guidance scale.

        Returns:
            The guided noise prediction with the batch size of `z_t`.
        """
        latent_input = torch.cat([z_t] * 2)
        timestep = torch.cat([timestep] * 2)
        # (B, 2, seq, dim) -> (2, B, seq, dim) -> (2B, seq, dim): all unconditional rows first,
        # then all conditional rows, matching the [z_t, z_t] layout of latent_input.
        embedd = text_embeddings.permute(1, 0, 2, 3).reshape(-1, *text_embeddings.shape[2:])

        with torch.autocast(device_type="cuda", dtype=torch.float16):
            e_t = self.unet(
                latent_input,
                timestep,
                encoder_hidden_states=embedd,
                added_cond_kwargs={
                    "text_embeds": text_embeds,
                    "time_ids": get_time_ids(batch_size=latent_input.shape[0]),
                },
                return_dict=False,
            )[0]
            if self.prediction_type == 'v_prediction':
                # eps = alpha_t * v + sigma_t * x_t
                e_t = torch.cat([alpha_t] * 2) * e_t + torch.cat([sigma_t] * 2) * latent_input
            e_t_uncond, e_t = e_t.chunk(2)
            if get_raw:
                return e_t_uncond, e_t
            e_t = e_t_uncond + guidance_scale * (e_t - e_t_uncond)
            assert torch.isfinite(e_t).all()
        return e_t

    def get_control_loss(self, z_source: T, z_target: T, text_emb_source: T, text_emb_target: T, text_emb_negative: T,
                         pooled: T, pooled_negative: T, eps=None, reduction='mean', symmetric: bool = False, calibration_grad=None,
                         timestep: Optional[Union[int, T]] = None, guidance_scale=10, raw_log=False, iter: int = None,
                         region_mask=None, mask_source=None, sub_mask_source=None, seg_mask_source=None) -> TS:
        """Return ``(loss, log_loss)`` for the current target latent.

        Args:
            z_source: fixed latent of the source image.
            z_target: latent being optimized.
            text_emb_source, text_emb_target: prompt embeddings with [unconditional, conditional]
                on axis 1 (see get_eps_prediction).
            text_emb_negative: forwarded to get_eps_prediction, which ignores it.
            pooled, pooled_negative: pooled embeddings, fed to the UNet as
                ``cat([pooled_negative, pooled])`` (see the NOTE in _compute_grad).
            eps, timestep: optional fixed noise / timestep; sampled when None.
            reduction: 'mean' sums the loss and divides by H * W; any other value returns the
                unreduced element-wise product (unless `symmetric`).
            symmetric: also add ``rescale * z_source * -grad`` (summed, divided by H * W).
            guidance_scale: classifier-free guidance scale.
            iter: current optimization step (required); drives gradient reuse and mask phase.
            mask_source: 2-D bbox mask used once ``iter >= phase_switch_iter``.
            sub_mask_source, seg_mask_source: 2-D masks used before that.
            calibration_grad, raw_log, region_mask: accepted for compatibility; unused.

        Returns:
            ``(loss, log_loss)`` where `log_loss` is a clone of `loss`.

        Raises:
            TypeError: if `iter` is None.
        """
        if iter is None:
            raise TypeError("get_control_loss() requires the current iteration index `iter`")
        with torch.inference_mode():
            # Recompute every `skip` iterations, and also whenever nothing is cached yet (e.g.
            # the first call uses an `iter` that is not a multiple of `skip`).
            if self.former_grad is None or iter % self.skip == 0:
                self.former_grad = self._compute_grad(
                    z_source, z_target, text_emb_source, text_emb_target, text_emb_negative,
                    pooled, pooled_negative, eps, timestep, guidance_scale, iter,
                    mask_source, sub_mask_source, seg_mask_source)
            grad = self.former_grad.clone()
        # `grad` is an inference tensor, which autograd cannot save for backward; cloning it
        # outside inference mode yields a normal tensor that can multiply the trainable latent.
        loss = z_target * grad.clone()
        if symmetric:
            loss = loss.sum() / (z_target.shape[2] * z_target.shape[3])
            loss_symm = self.rescale * z_source * (-grad.clone())
            loss += loss_symm.sum() / (z_target.shape[2] * z_target.shape[3])
        elif reduction == 'mean':
            loss = loss.sum() / (z_target.shape[2] * z_target.shape[3])
        log_loss = loss.clone()
        return loss, log_loss

    def _compute_grad(self, z_source, z_target, text_emb_source, text_emb_target, text_emb_negative,
                      pooled, pooled_negative, eps, timestep, guidance_scale, iter,
                      mask_source, sub_mask_source, seg_mask_source):
        """Compute the masked guidance gradient for the target latent (call under inference mode)."""
        z_t_source, eps, timestep, alpha_t, sigma_t = self.noise_input(z_source, eps, timestep)
        # Same noise and timestep, so the two predictions differ only by latent and prompt.
        z_t_target, _, _, _, _ = self.noise_input(z_target, eps, timestep)

        # UNet rows are [src_uncond, tgt_uncond, src_cond, tgt_cond].
        # NOTE(review): with 2-row `pooled` / `pooled_negative`, the target-unconditional row gets
        # pooled_negative[1] and the source-conditional row gets pooled[0]. With the arguments
        # image_optimization passes, those are the negative-prompt and the null pooled
        # embeddings, so the source prompt's pooled embedding is never used — confirm intent.
        text_embeds = torch.cat([pooled_negative, pooled], dim=0)

        eps_pred = self.get_eps_prediction(torch.cat((z_t_source, z_t_target)),
                                           torch.cat((timestep, timestep)),
                                           torch.cat((text_emb_source, text_emb_target)),
                                           text_emb_negative,
                                           text_embeds,
                                           torch.cat((alpha_t, alpha_t)),
                                           torch.cat((sigma_t, sigma_t)),
                                           guidance_scale=guidance_scale)
        eps_pred_source, eps_pred_target = eps_pred.chunk(2)

        size = tuple(eps_pred_target.shape[-2:])
        weight = (alpha_t ** self.alpha_exp) * (sigma_t ** self.sigma_exp)
        if iter < self.phase_switch_iter:
            # Phase 1: source/target difference inside the sub mask, plus the raw target
            # prediction inside the seg mask.
            sub_mask_fuse = self._binarize_and_resize(sub_mask_source, size)
            seg_mask_fuse = self._binarize_and_resize(seg_mask_source, size)
            grad1 = weight * (eps_pred_target - eps_pred_source) * sub_mask_fuse
            grad2 = weight * eps_pred_target * seg_mask_fuse
            return grad1 + grad2
        # Phase 2: source/target difference restricted to the bbox mask.
        mask_fuse = self._binarize_and_resize(mask_source, size)
        return weight * (eps_pred_target - eps_pred_source) * mask_fuse

    @staticmethod
    def _binarize_and_resize(mask: T, size) -> T:
        """Set positive entries of a 2-D mask to 1 (out of place) and nearest-resize it to (1, 1, *size)."""
        binary = mask.masked_fill(mask > 0, 1)
        return F.interpolate(binary.unsqueeze(0).unsqueeze(0), size=size)

    def get_grad(self):
        """Return the most recently computed (cached) gradient, or None before the first call."""
        return self.former_grad

    def __init__(self, device, pipe: StableDiffusionXLPipeline, dtype=torch.float32):
        self.t_min = 50
        self.t_max = 950
        # Exponents of the alpha_t / sigma_t weighting; 0 makes the weighting 1.
        self.alpha_exp = 0
        self.sigma_exp = 0
        self.dtype = dtype
        self.unet, self.alphas, self.sigmas = init_pipe(device, dtype, pipe.unet, pipe.scheduler)
        self.prediction_type = pipe.scheduler.prediction_type
        self.former_grad = None
            
            
# Subjects that can be swapped in. Each name is also the LoRA directory name under
# ./diffusers/examples/dreambooth/ (see `model_path` below).
classes = [
    "dog6",
    "backpack_dog",
    "duck_toy",
    "monster_toy",
    "poop_emoji",
    "clock",
    "pink_sunglasses",
    "robot_toy",
    #"teapot",
    #"cat",
]

# Target prompt for each subject. "sks" is presumably the identifier token the DreamBooth
# LoRA was trained with.
diff_prompt_target = {
    "dog6": "sks dog playing soccer, with a soccerball",
    "backpack_dog": "sks backpack, in purple color",
    "duck_toy": "sks duck toy",
    "monster_toy": "sks monster toy, with two arms raised",
    "poop_emoji": "sks poop emoji, in yellow color",
    "clock": "sks clock, made of wood, with a natural wooden texture",
    "pink_sunglasses": "sks pink sunglasses, shining brightly under the sunlight",
    "robot_toy": "sks robot toy, with a rusty metallic body",
    #"teapot": "sks teapot, made of transparent glass",
    #"cat": "sks cat",
}

class_name = classes[0]  # subject to swap in
insert_token = diff_prompt_target[class_name]
model_id = "stabilityai/stable-diffusion-xl-base-1.0"
model_path = "./diffusers/examples/dreambooth/" + class_name

pipeline = StableDiffusionXLPipeline.from_pretrained(
    model_id, torch_dtype=torch.float16, variant="fp16", use_safetensors=True
).to(device)
pipeline.load_lora_weights(model_path)

# The SDXL refiner is not referenced anywhere else in this script, so it is only loaded on
# request instead of always occupying GPU memory next to the base pipeline.
LOAD_REFINER = False
refiner_model_id = "stabilityai/stable-diffusion-xl-refiner-1.0"
refiner = None
if LOAD_REFINER:
    refiner = StableDiffusionXLImg2ImgPipeline.from_pretrained(
        refiner_model_id, torch_dtype=torch.float16, use_safetensors=True, variant="fp16"
    ).to(device)


def image_optimization(pipe: StableDiffusionXLPipeline, image: np.ndarray, text_source: str, text_target: str, text_negative: str,
                       num_iters=200, region_mask=None, mask_source=None, sub_mask_source=None, seg_mask_source=None, output_dir=None) -> Image.Image:
    """Optimize the VAE latent of `image` toward `text_target` with ControlLoss, then decode it.

    Args:
        pipe: SDXL pipeline providing the VAE, text encoders, UNet and scheduler. All model calls
            in this function go through this argument.
        image: HWC uint8 RGB array (e.g. from load_1024).
        text_source: prompt describing `image`.
        text_target: prompt describing the desired result.
        text_negative: negative prompt; its embeddings are passed on to ControlLoss.
        num_iters: number of optimization steps.
        region_mask, mask_source, sub_mask_source, seg_mask_source: forwarded unchanged to
            ControlLoss.get_control_loss.
        output_dir: currently unused.

    Returns:
        A PIL image with `image` on the left and the decoded optimized latent on the right.
    """
    control_loss = ControlLoss(device, pipe)

    with torch.no_grad():
        # uint8 HWC -> float32 NCHW in [-1, 1].
        image_source = (torch.from_numpy(image).float().permute(2, 0, 1) / 127.5 - 1).unsqueeze(0).to(device)
        # NOTE(review): SDXL's VAE config uses scaling_factor 0.13025; 0.18215 is the SD 1.x value.
        # It is kept because decode() undoes the same constant, and the loss scale / learning rate
        # below were presumably tuned with it. Change both sides together if at all.
        with torch.autocast(device_type="cuda", dtype=torch.float32):
            z_source = pipe.vae.encode(image_source)['latent_dist'].mean * 0.18215

        embedding_null, embedding_null_pooled = get_text_embeddings(pipe, "")
        embedding_text, _ = get_text_embeddings(pipe, text_source)  # source pooled embedding is unused
        embedding_text_target, embedding_text_target_pooled = get_text_embeddings(pipe, text_target)
        embedding_text_negative, embedding_text_negative_pooled = get_text_embeddings(pipe, text_negative)

        # stack(..., dim=1) puts [unconditional, conditional] on axis 1, the layout ControlLoss expects.
        embedding_source = torch.stack([embedding_null, embedding_text], dim=1)
        embedding_target = torch.stack([embedding_null, embedding_text_target], dim=1)
        embedding_negative = torch.stack([embedding_null, embedding_text_negative], dim=1)

        embedding_target_pooled = torch.cat([embedding_null_pooled, embedding_text_target_pooled], dim=0)
        embedding_negative_pooled = torch.cat([embedding_null_pooled, embedding_text_negative_pooled], dim=0)

    # The target latent starts at the source latent and is the only optimized tensor.
    z_target = z_source.clone().requires_grad_(True)
    optimizer = optim.Lookahead(SGD(params=[z_target], lr=1e-1), k=5, alpha=0.5)

    for i in tqdm(range(num_iters)):
        loss, log_loss = control_loss.get_control_loss(
            z_source, z_target, embedding_source, embedding_target, embedding_negative,
            embedding_target_pooled, embedding_negative_pooled, iter=i,
            region_mask=region_mask, mask_source=mask_source,
            sub_mask_source=sub_mask_source, seg_mask_source=seg_mask_source)
        optimizer.zero_grad()
        # Fixed loss scale, presumably tuned together with the SGD learning rate above.
        (5000 * loss).backward()
        optimizer.step()

        # TensorBoard
        writer.add_scalar("Loss/log_loss", log_loss.item(), i + 1)

    return decode(z_target, pipe, im_cat=image)


def mask_find_bboxs(mask):
    """Return the normalized bounding box of the positive pixels of a 2-D mask.

    Args:
        mask: tensor indexed as (row, col); entries > 0 are foreground.

    Returns:
        ``(top, bottom, left, right)``: the first/last foreground row divided by the number of
        rows, and the first/last foreground column divided by the number of columns. The max
        indices are inclusive, so a box touching the last row gives ``(rows - 1) / rows``.

    Raises:
        ValueError: if the mask has no positive entry.
    """
    coords = torch.nonzero(mask > 0)
    if coords.shape[0] == 0:
        raise ValueError("mask_find_bboxs: mask has no positive pixels")
    n_rows, n_cols = mask.shape[0], mask.shape[1]
    top = coords[:, 0].min().item() / n_rows
    bottom = coords[:, 0].max().item() / n_rows
    # Columns are normalized by the width (they were divided by the height before, which was
    # only correct for square masks).
    left = coords[:, 1].min().item() / n_cols
    right = coords[:, 1].max().item() / n_cols
    return top, bottom, left, right


# image optimization

def load_prompt_as_string(filepath: str) -> str:
    """Read the whole UTF-8 text file at `filepath` and return it (text-mode newline handling)."""
    with open(filepath, mode="r", encoding="utf-8") as handle:
        return handle.read()

def fill_placeholders(prompts, insert_token):
    """Return a new list where every "{ }" placeholder in each prompt is replaced by `insert_token`."""
    placeholder = "{ }"
    return [template.replace(placeholder, insert_token) for template in prompts]

# Benchmark prompt lists, one prompt per line. Target prompts carry a "{ }" placeholder that is
# filled with the selected subject's prompt; the lists are only printed below.
source_prompt_path = "./ConSwapBench/source_prompt.txt"
source_prompts = load_prompt_as_string(source_prompt_path)
source_prompt_list = source_prompts.splitlines()

target_prompt_path = "./ConSwapBench/target_prompt.txt"
target_prompts = load_prompt_as_string(target_prompt_path)
target_prompt_list_empty = target_prompts.splitlines()
target_prompt_list = fill_placeholders(target_prompt_list_empty, insert_token)

print("Filled Prompts:", target_prompt_list)

output_dir = "./results/zebra"

text_source = "a zebra eating grass in a field"
# Derive the target prompt from the selected subject so it follows `class_name` (the previous
# hard-coded dog prompt also contained the typo "soccerk").
text_target = insert_token + " in a field"
text_negative = "low quality, worst quality, blurry, bad anatomy"

path = "./ConSwapBench/zebra/"
# Sorted so the chosen source image does not depend on the arbitrary os.listdir order.
jpg_files = sorted(
    f for f in os.listdir(path)
    if f.lower().endswith(".jpg")
    and f.lower() != "gt_bbox.jpg"
    and not f.lower().startswith("square")
)
if not jpg_files:
    raise FileNotFoundError(f"No source .jpg image found in {path!r}")

image_path = path + jpg_files[0]
image = load_1024(image_path)

mask_path = path + "GT_bbox.jpg"
sub_mask_path = path + "GT_sub.png"
seg_mask_path = path + "GT_seg.png"
transform = transforms.Compose([transforms.ToTensor()])


def _load_mask(mask_file):
    """Load a mask image as a 2-D float tensor on `device` (channel mean of the ToTensor output)."""
    with Image.open(mask_file) as mask_img:
        return transform(mask_img).to(device).mean(dim=0)


# NOTE(review): masks are used at their native resolution, while load_1024 center-crops and
# resizes the source image; they only line up if the masks already match that 1024x1024 crop.
mask_source = _load_mask(mask_path)
sub_mask_source = _load_mask(sub_mask_path)
seg_mask_source = _load_mask(seg_mask_path)

region_mask = mask_find_bboxs(mask_source)

out = image_optimization(pipeline, image, text_source, text_target, text_negative,
                         num_iters=1150, region_mask=region_mask, mask_source=mask_source,
                         sub_mask_source=sub_mask_source, seg_mask_source=seg_mask_source,
                         output_dir=output_dir)

# Refinement: low-strength SDXL inpainting of the bbox region of the edited image.
# Reuse the modules of the already-loaded base pipeline (same base model, LoRA already applied)
# instead of loading a second copy of the base model and applying the LoRA again.
pipe = StableDiffusionXLInpaintPipeline(**pipeline.components)

# `out` is [source | edited] side by side; keep the edited right half.
out = out.crop((out.width // 2, 0, out.width, out.height))

with Image.open(mask_path) as mask_file:
    mask = mask_file.convert("L").resize((1024, 1024), resample=Image.LANCZOS)

output = pipe(
    prompt=text_target,
    image=out,
    mask_image=mask,
    strength=0.3,
    num_inference_steps=25,
    guidance_scale=7.5,
).images[0]

# Create the output directory if needed so the long run does not fail at the final save.
os.makedirs(output_dir, exist_ok=True)
output.save(os.path.join(output_dir, "output_image2.png"))