import os
import torch
import torch
from PIL import Image
import sys

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from utils.stable_diffusion import (
    load_sd_components,
    load_text_components,
    generate_images,
    compute_text_embedding,
)
from utils.datasets import transform_image, load_and_encode_image
from tqdm import tqdm
from torchvision import transforms

from torchmetrics.functional import pairwise_cosine_similarity

# Device string used throughout this script when moving models and tensors
# (see the `.to(torch_device)` calls below).
torch_device = "cuda"

import argparse


def parse_args(argv=None):
    """Parse the command-line options of this script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse falls back to ``sys.argv[1:]``, which is
            what a plain ``parse_args()`` call always did.

    Returns:
        An ``argparse.Namespace`` with the attributes ``img_path``,
        ``prompt``, ``batch_size``, ``steps`` and ``thrs``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--img_path",
        type=str,
        default="/home/datasets/coco2014/train2014/COCO_train2014_000000403953.jpg",
        help="Path of the image that is encoded into latents and used as "
        "the reference for the SSCD comparison.",
    )
    parser.add_argument(
        "--prompt",
        type=str,
        default="A woman is in a kitchen doing work.",
        help="Prompt whose text embedding initializes the optimized embedding.",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=8,
        help="Number of copies of the image latents used per optimization step.",
    )
    parser.add_argument(
        "--steps", type=int, default=50, help="Number of optimization steps."
    )
    # The default is spelled as a float so that it has the same type as a
    # value given on the command line (argparse does not run non-string
    # defaults through `type`).
    parser.add_argument(
        "--thrs",
        type=float,
        default=220.0,
        help="Upper bound enforced on the norm of the optimized embedding.",
    )
    return parser.parse_args(argv)


# Loaded SSCD TorchScript models, keyed by checkpoint path, so that repeated
# calls to `sscd_between` do not deserialize the model again.
_SSCD_MODELS = {}


def _get_sscd_model(model_path):
    """Return the SSCD model stored at *model_path*, loading it on first use."""
    model = _SSCD_MODELS.get(model_path)
    if model is None:
        model = torch.jit.load(model_path, map_location=torch_device).eval()
        _SSCD_MODELS[model_path] = model
    return model


def sscd_between(
    original: torch.Tensor,
    generated: torch.Tensor,
    model_path: str = "sscd_disc_mixup.torchscript.pt",
) -> float:
    """Return the SSCD cosine similarity between two image batches.

    Both inputs are moved to ``torch_device``, resized to 320x320, normalized
    per channel and embedded with the SSCD TorchScript model. The result is
    the largest cosine similarity between the FIRST embedding of `original`
    and any embedding of `generated`.

    Args:
        original: Reference image tensor.
            NOTE(review): assumed to be a batched 3-channel float image
            tensor, as produced by ``ToTensor()(...).unsqueeze(0)`` in the
            caller — confirm for other callers.
        generated: Image tensor(s) to compare against the reference; same
            layout assumption as `original`.
        model_path: Path of the SSCD TorchScript checkpoint. Defaults to the
            file name that was previously hard-coded.

    Returns:
        The maximum cosine similarity as a Python float.
    """
    sscd_m = _get_sscd_model(model_path)
    # Global PyTorch setting: caps the number of intra-op CPU threads.
    torch.set_num_threads(4)

    preprocess = transforms.Compose(
        [
            # Resizes to a fixed square regardless of the aspect ratio.
            transforms.Resize([320, 320]),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
        ]
    )

    # Pure inference: no autograd graph is needed for the similarity score.
    with torch.no_grad():
        features_before = sscd_m(preprocess(original.to(torch_device))).cpu()
        features_after = sscd_m(preprocess(generated.to(torch_device))).cpu()

    similarities = pairwise_cosine_similarity(features_before, features_after)
    # Row 0 holds the similarities of the first reference embedding.
    return similarities[0].max().item()


def load_everything():
    """Load the "v1-4" Stable Diffusion and text components for this script.

    The VAE, text encoder and UNet are moved to ``torch_device`` and have
    gradient tracking disabled via ``requires_grad_(False)``; the scheduler
    and tokenizer are returned as loaded.

    Returns:
        The tuple ``(vae, unet, scheduler, tokenizer, text_encoder)``.
    """
    vae, unet, scheduler = load_sd_components("v1-4")
    tokenizer, text_encoder = load_text_components("v1-4")

    frozen_modules = (vae, text_encoder, unet)
    for module in frozen_modules:
        module.to(torch_device)
    for module in frozen_modules:
        module.requires_grad_(False)

    return vae, unet, scheduler, tokenizer, text_encoder


def get_latents_and_embedding(
    img_path, prompt, batch_size, vae, tokenizer, text_encoder
):
    """Build the optimization inputs: image latents and a trainable embedding.

    Args:
        img_path: Path of the image passed to ``load_and_encode_image``.
        prompt: Text passed to ``compute_text_embedding``.
        batch_size: How many times the encoded latents are repeated along
            dimension 0.
        vae: Model used by ``load_and_encode_image`` to encode the image.
        tokenizer: Tokenizer passed to ``compute_text_embedding``.
        text_encoder: Text encoder passed to ``compute_text_embedding``.

    Returns:
        ``(latents, embedding_optim)``, both on ``torch_device``;
        ``embedding_optim`` has ``requires_grad`` enabled.
    """
    encoded_image = load_and_encode_image(img_path, vae)
    repeated_latents = torch.repeat_interleave(encoded_image, batch_size, dim=0)
    latents = repeated_latents.to(torch_device)

    prompt_embedding = compute_text_embedding(prompt, tokenizer, text_encoder)
    embedding_optim = prompt_embedding.to(torch_device)
    # This tensor is the one handed to the optimizer later on.
    embedding_optim.requires_grad_(True)

    return latents, embedding_optim


def evaluate_sscd(
    img_path, embedding_optim, tokenizer, text_encoder, vae, unet, scheduler
):
    """Generate images from the current embedding and print their SSCD scores.

    A detached copy of `embedding_optim` is passed to ``generate_images``;
    every returned image is then compared with the image at `img_path` using
    ``sscd_between`` and the resulting similarity is printed, one per line.

    Args:
        img_path: Path of the reference image.
        embedding_optim: Text embedding tensor being optimized.
        tokenizer, text_encoder, vae, unet, scheduler: Forwarded unchanged to
            ``generate_images``.
    """
    # Work on a detached copy so generation cannot touch the tensor that is
    # being optimized.
    text_embedding_optim = embedding_optim.detach().clone()

    # NOTE(review): the leading `None` and the four trailing numbers are
    # forwarded positionally; their meaning is defined by `generate_images`
    # — confirm against that helper before changing them.
    images = generate_images(
        None,
        tokenizer,
        text_encoder,
        vae,
        unet,
        scheduler,
        text_embedding_optim,
        50,
        10,
        7,
        5,
    )

    to_tensor = transforms.ToTensor()

    # Load the reference once, not once per generated image, and close the
    # file afterwards. Converting to RGB guarantees the three channels that
    # the 3-channel normalization in `sscd_between` requires, even for
    # grayscale or RGBA files.
    with Image.open(img_path) as reference_image:
        reference = to_tensor(reference_image.convert("RGB")).unsqueeze(0)

    for image in images:
        # Each generated image is scored as its own batch of one.
        print(sscd_between(reference, to_tensor(image).unsqueeze(0)))


def _resolve_eval_component(name, value):
    """Return *value*, or the module-level global called *name* if it is None."""
    if value is not None:
        return value
    module_scope = globals()
    if name not in module_scope:
        raise NameError(
            f"name '{name}' is not defined; pass it to "
            "run_constrained_optimization() or define it at module level"
        )
    return module_scope[name]


def run_constrained_optimization(
    img_path,
    latents,
    embedding_optim,
    steps,
    batch_size,
    thrs,
    scheduler,
    unet,
    tokenizer=None,
    text_encoder=None,
    vae=None,
    eval_every=1000,
):
    """Optimize `embedding_optim` with Adam while bounding its norm by `thrs`.

    Each step adds freshly sampled noise to `latents` at random timesteps,
    lets `unet` predict from the noisy latents conditioned on the embedding,
    and minimizes the MSE between that prediction and the sampled noise.
    After every optimizer step the embedding is rescaled so that its norm
    does not exceed `thrs`. ``evaluate_sscd`` is run every `eval_every` steps
    (including step 0) and once more after the last step.

    Args:
        img_path: Path of the reference image, forwarded to ``evaluate_sscd``.
        latents: Latent tensor whose first dimension matches `batch_size`.
        embedding_optim: Leaf tensor with ``requires_grad=True``; updated
            in place.
        steps: Number of optimization steps.
        batch_size: Number of timesteps sampled per step and number of times
            the embedding is repeated along dimension 0.
        thrs: Maximum allowed norm of `embedding_optim`.
        scheduler: Noise scheduler providing ``config.num_train_timesteps``
            and ``add_noise``.
        unet: Model called as ``unet(noisy_latents, timesteps, embedding,
            return_dict=False)``; the first element of its result is used.
        tokenizer, text_encoder, vae: Components forwarded to
            ``evaluate_sscd``. When left as ``None`` the module-level globals
            of the same names are used, which is how the function obtained
            them before these parameters existed.
        eval_every: Evaluation period in steps; a falsy value disables the
            periodic evaluation (the final one still runs).

    Raises:
        NameError: If a component is neither passed nor defined at module
            level.
    """
    tokenizer = _resolve_eval_component("tokenizer", tokenizer)
    text_encoder = _resolve_eval_component("text_encoder", text_encoder)
    vae = _resolve_eval_component("vae", vae)

    optimizer = torch.optim.Adam([embedding_optim], lr=0.1)
    # Fixed seed so the sampled noise and timesteps are reproducible.
    generator = torch.Generator(device=latents.device).manual_seed(1)

    pb = tqdm(range(steps))
    for step in pb:
        noise = torch.empty_like(latents).normal_(generator=generator)
        timesteps = torch.randint(
            0,
            scheduler.config.num_train_timesteps,
            (batch_size,),
            device=latents.device,
            generator=generator,
        ).long()

        noisy_latents = scheduler.add_noise(latents, noise, timesteps)
        embedding_repeated = torch.repeat_interleave(
            embedding_optim, dim=0, repeats=batch_size
        )

        model_pred = unet(
            noisy_latents, timesteps, embedding_repeated, return_dict=False
        )[0]
        loss = torch.nn.functional.mse_loss(
            model_pred.float(), noise.float(), reduction="mean"
        )

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        # Project back onto the norm ball of radius `thrs`. Doing this under
        # no_grad (instead of through `.data`) keeps autograd out of the
        # in-place update, and skipping the division when the norm is already
        # within bounds avoids a 0/0 for an all-zero embedding.
        with torch.no_grad():
            norm = embedding_optim.norm()
            if norm > thrs:
                embedding_optim.div_(norm / thrs)

        pb.set_postfix(loss=loss.item(), norm=embedding_optim.norm().item())
        if eval_every and step % eval_every == 0:
            print("_" * 100)
            print(f"Step {step}")
            evaluate_sscd(
                img_path,
                embedding_optim,
                tokenizer,
                text_encoder,
                vae,
                unet,
                scheduler,
            )

    # Final evaluation with the fully optimized embedding.
    evaluate_sscd(
        img_path, embedding_optim, tokenizer, text_encoder, vae, unet, scheduler
    )


if __name__ == "__main__":
    # Parse the command line first so that `--help` and invalid arguments
    # exit immediately instead of after the models have been loaded.
    args = parse_args()

    # These names must stay bound at module level:
    # run_constrained_optimization looks up `tokenizer`, `text_encoder` and
    # `vae` as module globals.
    vae, unet, scheduler, tokenizer, text_encoder = load_everything()

    latents, embedding_optim = get_latents_and_embedding(
        args.img_path,
        args.prompt,
        args.batch_size,
        vae,
        tokenizer,
        text_encoder,
    )
    run_constrained_optimization(
        args.img_path,
        latents,
        embedding_optim,
        args.steps,
        args.batch_size,
        args.thrs,
        scheduler,
        unet,
    )
