import sys
import os

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import torch
import numpy as np
import pandas as pd
from PIL import Image
from tqdm import tqdm
import torch.nn.functional as F
import torch.nn.functional as F
from utils.stable_diffusion import (
    load_sd_components,
    load_text_components,
    generate_images,
    compute_text_embedding,
)
from utils.datasets import transform_image
from torchvision import transforms
from torchmetrics.functional import pairwise_cosine_similarity


# SSCD torchscript models keyed by checkpoint path. Loading the checkpoint is
# expensive and sscd_between runs once per generated image, so each model is
# loaded once per process and then reused.
_SSCD_MODELS = {}


def _get_sscd_model(model_path="sscd_disc_mixup.torchscript.pt"):
    """Return the torchscript model at ``model_path`` on CUDA in eval mode, loading it on first use."""
    model = _SSCD_MODELS.get(model_path)
    if model is None:
        model = torch.jit.load(model_path).cuda().eval()
        _SSCD_MODELS[model_path] = model
    return model


def sscd_between(
    original: torch.Tensor,
    generated: torch.Tensor,
    model_path: str = "sscd_disc_mixup.torchscript.pt",
) -> float:
    """Return the highest SSCD cosine similarity between ``original`` and ``generated``.

    Both inputs are resized to 320x320, normalized with the ImageNet mean/std
    and embedded by the SSCD torchscript model at ``model_path``. The result is
    the maximum over the first row of the pairwise cosine-similarity matrix,
    i.e. the similarity between the first image of ``original`` and its closest
    match in ``generated``.

    Args:
        original: image batch; presumably (N, 3, H, W) floats in [0, 1] as
            produced by ``transforms.ToTensor`` in this file -- TODO confirm
            for other callers.
        generated: image batch in the same format as ``original``.
        model_path: path of the SSCD torchscript checkpoint (loaded once and
            cached per path).

    Returns:
        The maximum cosine similarity as a Python float.
    """
    sscd_model = _get_sscd_model(model_path)
    # Caps the number of intra-op CPU threads torch uses (kept from the
    # original implementation).
    torch.set_num_threads(4)

    preprocess = transforms.Compose(
        [
            transforms.Resize([320, 320]),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
        ]
    )

    # Inference only: no autograd graph is needed for the descriptors.
    with torch.no_grad():
        features_before = sscd_model(preprocess(original.to("cuda"))).cpu()
        features_after = sscd_model(preprocess(generated.to("cuda"))).cpu()

    similarity = pairwise_cosine_similarity(features_before, features_after)
    return similarity[0].max().item()


def compute_sscd_scores(
    embedding: torch.Tensor,
    original_image: torch.Tensor,
    tokenizer,
    text_encoder,
    vae,
    unet,
    scheduler,
    num_images=5,
    guidance_scale=7,
    num_inference_steps=50,
    seed=100,
) -> float:
    """Generate images from ``embedding`` and return their highest SSCD score.

    ``num_images`` samples are produced by ``generate_images`` with an empty
    prompt and the detached ``embedding`` passed as ``text_embeddings``. How
    the two are combined is decided by ``generate_images``, which is not shown
    here. Each sample is converted with ``transforms.ToTensor``, given a batch
    dimension of one and compared to ``original_image`` with ``sscd_between``.

    Args:
        embedding: text embedding to condition generation on.
        original_image: reference image tensor passed to ``sscd_between``.
        tokenizer, text_encoder, vae, unet, scheduler: Stable Diffusion
            components forwarded to ``generate_images``.
        num_images: number of images to generate (must be >= 1).
        guidance_scale: guidance scale forwarded to ``generate_images``.
        num_inference_steps: sampling steps forwarded to ``generate_images``.
        seed: seed forwarded to ``generate_images``; the default keeps the
            previously hard-coded value.

    Returns:
        The maximum ``sscd_between`` score over the generated images.

    Raises:
        ValueError: if ``num_images`` is smaller than 1.
    """
    if num_images < 1:
        raise ValueError(f"num_images must be at least 1, got {num_images}")

    with torch.no_grad():
        generated_images = generate_images(
            [""],
            tokenizer,
            text_encoder,
            vae,
            unet,
            scheduler,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            samples_per_prompt=num_images,
            text_embeddings=embedding.detach(),
            seed=seed,
        )

        to_tensor = transforms.ToTensor()
        # Each generated image is scored as its own batch of one.
        return max(
            sscd_between(original_image, to_tensor(img).unsqueeze(0))
            for img in generated_images
        )


# Optimisation steps (1-indexed) at which optimize_embedding records embedding
# statistics and an SSCD score when no explicit ``target_steps`` is given.
_DEFAULT_TARGET_STEPS = (1, 10, 50, 100, 200, 500, 1000)

# Frozen Stable Diffusion components keyed by (model version, device).
# optimize_embedding is called once per CSV row by process_images, so the
# weights are loaded once per process instead of once per call.
_SD_COMPONENTS = {}


def _get_frozen_sd_components(version="v1-4", device="cuda"):
    """Return ``(vae, unet, scheduler, tokenizer, text_encoder)`` for ``version``.

    The VAE, text encoder and UNet are moved to ``device`` and have gradients
    disabled. Results are cached in ``_SD_COMPONENTS`` per (version, device).
    """
    key = (version, device)
    if key not in _SD_COMPONENTS:
        vae, unet, scheduler = load_sd_components(version)
        tokenizer, text_encoder = load_text_components(version)
        for module in (vae, text_encoder, unet):
            module.to(device)
            module.requires_grad_(False)
        _SD_COMPONENTS[key] = (vae, unet, scheduler, tokenizer, text_encoder)
    return _SD_COMPONENTS[key]


def _load_image(image_path, device):
    """Read ``image_path`` as RGB and return ``(original_image, model_input)``.

    ``original_image`` is the untouched image as a (1, 3, H, W) CPU tensor
    from ``transforms.ToTensor`` (used for SSCD scoring). ``model_input`` is
    the output of ``transform_image`` with a batch dimension, moved to
    ``device``.
    """
    # The context manager releases the file handle PIL keeps open.
    with Image.open(image_path) as img:
        rgb = img.convert("RGB")
    original_image = transforms.ToTensor()(rgb).unsqueeze(0)
    model_input = transform_image(rgb).unsqueeze(0).to(device)
    return original_image, model_input


def optimize_embedding(
    image_path,
    prompt,
    steps=1000,
    batch_size=5,
    lr=0.1,
    target_steps=None,
):
    """Optimize a text embedding so the frozen UNet denoises ``image_path``'s latents.

    The embedding is initialised from ``prompt`` and trained with Adam on the
    noise-prediction MSE loss. At every step the VAE latents of the image,
    repeated ``batch_size`` times, are noised at random timesteps, and the
    frozen UNet, conditioned on the embedding, predicts that noise.

    At each step listed in ``target_steps`` (1-indexed), the embedding's L2
    norm, mean and std are recorded together with an SSCD score from
    ``compute_sscd_scores``.

    Args:
        image_path: path of the image to fit.
        prompt: text used to initialise the embedding.
        steps: number of optimisation steps.
        batch_size: number of noised copies of the latents per step.
        lr: Adam learning rate.
        target_steps: steps at which metrics are recorded; defaults to
            ``_DEFAULT_TARGET_STEPS``.

    Returns:
        ``(embedding, saved_steps, norms, means, stds, sscd_scores)``, where
        ``embedding`` is the optimised tensor and the other items are lists
        with one entry per recorded step.
    """
    torch_device = "cuda"
    if target_steps is None:
        target_steps = _DEFAULT_TARGET_STEPS
    record_at = set(target_steps)

    vae, unet, scheduler, tokenizer, text_encoder = _get_frozen_sd_components(
        "v1-4", torch_device
    )
    original_image, image = _load_image(image_path, torch_device)

    # The VAE is frozen, so encoding needs no autograd graph.
    with torch.no_grad():
        latents = vae.encode(image).latent_dist.sample() * vae.config.scaling_factor
    latents = torch.repeat_interleave(latents, dim=0, repeats=batch_size)

    # Detach + clone yields a fresh leaf tensor that Adam can update in place
    # without touching whatever compute_text_embedding returned.
    embedding_optim = (
        compute_text_embedding([prompt], tokenizer, text_encoder)
        .detach()
        .to(torch_device)
        .clone()
        .requires_grad_(True)
    )
    optimizer = torch.optim.Adam([embedding_optim], lr=lr)

    saved_steps = []
    norms = []
    means = []
    stds = []
    sscd_scores = []

    # Seeded generator so the sampled noise and timesteps are reproducible.
    generator = torch.Generator(device=latents.device).manual_seed(1)

    for step in range(steps):
        noise = torch.empty_like(latents).normal_(generator=generator)
        timesteps = torch.randint(
            0,
            scheduler.config.num_train_timesteps,
            (batch_size,),
            device=latents.device,
            generator=generator,
        )

        noisy_latents = scheduler.add_noise(latents, noise, timesteps)
        embedding_repeated = torch.repeat_interleave(
            embedding_optim, dim=0, repeats=batch_size
        )
        model_pred = unet(
            noisy_latents, timesteps, embedding_repeated, return_dict=False
        )[0]
        loss = F.mse_loss(model_pred.float(), noise.float(), reduction="mean")

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if step + 1 not in record_at:
            continue

        saved_steps.append(step + 1)
        with torch.no_grad():
            norms.append(torch.norm(embedding_optim).item())
            means.append(embedding_optim.mean().item())
            stds.append(embedding_optim.std().item())

        sscd_scores.append(
            compute_sscd_scores(
                embedding_optim,
                original_image,
                tokenizer,
                text_encoder,
                vae,
                unet,
                scheduler,
            )
        )

    return (
        embedding_optim,
        saved_steps,
        norms,
        means,
        stds,
        sscd_scores,
    )


def process_images(csv_path, output_path, steps=1000, batch_size=5, lr=0.1, sep=";"):
    """Optimize an embedding for every CSV row and save all metrics in one ``.npz`` file.

    The CSV is read with ``sep`` as delimiter and must provide ``image_path``
    and ``prompt`` columns. Rows whose optimisation raises are reported and
    skipped.

    The saved archive contains ``saved_steps``, ``norms``, ``means``, ``stds``
    and ``sscd_scores``, each of shape (num_processed_images, num_recorded_steps).
    It also contains ``image_paths``, which lists the processed images in the
    same row order so the metrics can be mapped back to their inputs even
    when some rows were skipped.

    Args:
        csv_path: path of the input CSV.
        output_path: destination passed to ``np.savez``.
        steps, batch_size, lr: forwarded to ``optimize_embedding``.
        sep: CSV delimiter (defaults to the previously hard-coded ``";"``).
    """
    df = pd.read_csv(csv_path, sep=sep)

    image_paths = []
    all_saved_steps = []
    all_norms = []
    all_means = []
    all_stds = []
    all_sscd_scores = []

    rows = zip(df["image_path"], df["prompt"])
    for image_path, prompt in tqdm(rows, total=len(df)):
        try:
            _, saved_steps, norms, means, stds, sscd_scores = optimize_embedding(
                image_path,
                prompt,
                steps=steps,
                batch_size=batch_size,
                lr=lr,
            )
        except Exception as e:
            # Best effort: report the failure and keep processing the rest.
            print(f"Error optimizing embedding for image {image_path}: {e}")
            continue

        image_paths.append(str(image_path))
        all_saved_steps.append(saved_steps)
        all_norms.append(norms)
        all_means.append(means)
        all_stds.append(stds)
        all_sscd_scores.append(sscd_scores)

    np.savez(
        output_path,
        saved_steps=np.array(all_saved_steps),
        norms=np.array(all_norms),
        means=np.array(all_means),
        stds=np.array(all_stds),
        sscd_scores=np.array(all_sscd_scores),
        image_paths=np.array(image_paths, dtype=str),
    )


if __name__ == "__main__":
    import argparse

    # Command-line options as (flag, add_argument keyword arguments), in the
    # order they are registered and shown in --help.
    cli_options = (
        (
            "--csv_path",
            dict(
                type=str,
                required=True,
                help="Path to CSV file containing image paths",
            ),
        ),
        (
            "--output_path",
            dict(type=str, required=True, help="Path to save output metrics"),
        ),
        (
            "--steps",
            dict(type=int, default=1000, help="Number of optimization steps"),
        ),
        (
            "--batch_size",
            dict(type=int, default=5, help="Batch size for optimization"),
        ),
        ("--lr", dict(type=float, default=0.1, help="Learning rate")),
    )

    arg_parser = argparse.ArgumentParser()
    for flag, options in cli_options:
        arg_parser.add_argument(flag, **options)
    cli_args = arg_parser.parse_args()

    # Echo the parsed configuration before starting the (long) run.
    print(cli_args)

    process_images(
        cli_args.csv_path,
        cli_args.output_path,
        steps=cli_args.steps,
        batch_size=cli_args.batch_size,
        lr=cli_args.lr,
    )
