import torch
from utils.datasets import load_and_encode_image
from utils.stable_diffusion import compute_text_embedding

def find_adv_text_embeddings(img_path, unet, tokenizer, text_encoder, vae, scheduler, prompt=None, num_steps=15, batch_size=4, seed=1, lr=1e-2, fp16=False):
    """
    Find an adversarial text embedding that replicates the image in img_path using the Stable Diffusion model.

    The image is encoded to latents once; then, for ``num_steps`` iterations, random
    noise is added at random timesteps and the text embedding is updated with Adam so
    that the UNet's prediction matches the added noise (MSE loss). Only the text
    embedding is handed to the optimizer.

    Args:
        img_path (str): Path to the image.
        unet: The UNet model of Stable Diffusion. Must expose ``.device`` and be callable as
            ``unet(noisy_latents, timesteps, text_embedding, return_dict=False)``.
        tokenizer: Tokenizer forwarded to ``compute_text_embedding`` (only used when ``prompt`` is given).
        text_encoder: Text encoder forwarded to ``compute_text_embedding`` (only used when ``prompt`` is given).
        vae: VAE forwarded to ``load_and_encode_image`` to encode the image into latents.
        scheduler: Noise scheduler providing ``config.num_train_timesteps`` and ``add_noise``.
        prompt (str): The text prompt to start the optimization from. If None (or empty),
            a random text embedding of shape (1, 77, 768) is generated.
        num_steps (int): Number of optimization steps.
        batch_size (int): Batch size for the optimization (number of noise/timestep samples per step).
        seed (int): Random seed for reproducibility.
        lr (float): Learning rate for the optimizer.
        fp16 (bool): If True, run the forward pass under autocast. Note that the reduced
            precision dtype used is bfloat16, not float16.

    Returns:
        torch.Tensor: The optimized text embeddings, detached and located on ``unet.device``.
    """
    from contextlib import nullcontext

    torch.manual_seed(seed)
    device = unet.device

    if prompt:
        with torch.no_grad():
            text_embedding = compute_text_embedding(prompt, tokenizer, text_encoder)
    else:
        print("No prompt provided, generating random text embedding.")
        # Sampled on the CPU and moved to the UNet's device below, so no CUDA device is
        # required here and the seeded values do not depend on the target device.
        # NOTE(review): (1, 77, 768) presumably matches the text-encoder output of the
        # checkpoint in use -- confirm before using other model variants.
        text_embedding = torch.randn(1, 77, 768)

    with torch.no_grad():
        latents = load_and_encode_image(img_path, vae)
        latents = torch.repeat_interleave(latents, dim=0, repeats=batch_size)

    # The embedding becomes a fresh leaf tensor on the UNet's device; it is the only
    # tensor the optimizer updates.
    text_embedding = text_embedding.detach().to(device)
    text_embedding.requires_grad_(True)
    latents = latents.to(device)

    # set optimizer to update text embedding
    optimizer = torch.optim.Adam(lr=lr, params=[text_embedding])

    # run the optimization loop. Follows the standard diffusion training loop but updates the text embedding instead of the latents
    for step in range(num_steps):
        # Only enter autocast when mixed precision was requested: float32 is not a valid
        # autocast dtype and merely triggers a warning that disables autocast.
        if fp16:
            precision_ctx = torch.autocast(device_type=latents.device.type, dtype=torch.bfloat16)
        else:
            precision_ctx = nullcontext()

        with precision_ctx:
            noise = torch.randn_like(latents)
            # torch.randint already yields int64 timesteps, one per batch element.
            timesteps = torch.randint(0, scheduler.config.num_train_timesteps, (batch_size,), device=latents.device)

            noisy_latents = scheduler.add_noise(latents, noise, timesteps)
            text_embedding_repeated = torch.repeat_interleave(text_embedding, dim=0, repeats=batch_size)
            model_pred = unet(noisy_latents, timesteps, text_embedding_repeated, return_dict=False)[0]
            # The loss is always computed in float32, regardless of the autocast dtype.
            loss = torch.nn.functional.mse_loss(model_pred.float(), noise.float(), reduction="mean")

        # NOTE: backward() also accumulates gradients into any UNet parameter that has
        # requires_grad=True; those are never stepped or cleared by this function.
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

    return text_embedding.detach()