import os
import torch
import torchvision
from models.lmm import PREPROCESS

def create_images_for_class(C, private_cluster_labels, syn_cluster_labels, N_c, model, private_embeddings, private_labels, syn_embeddings, syn_labels, split, device, sigma, output_dir, categories, noise_scale=20):
    """Shift each synthetic cluster of class ``C`` toward a noised private mean and render it.

    For every cluster ``c`` in ``range(N_c)``:

    1. Select the synthetic and private embeddings of class ``C`` that belong
       to cluster ``c``.
    2. Compute the private cluster mean plus Gaussian noise
       ``randn * sigma * noise_scale`` divided by the cluster size.
    3. Translate every synthetic embedding of the cluster by the difference
       between that noised mean and the synthetic cluster's own mean.
    4. For each shifted embedding whose output file does not exist yet, call
       ``save_image_with_retries``. The output file is
       ``<output_dir>/<C>/image_<C>_<c>_<i>.png``.

    Args:
        C: Class label to process.
        private_cluster_labels: Indexable by ``C``. Gives the cluster id of each
            private embedding of class ``C``. Presumably aligned row-for-row
            with ``private_embeddings[private_labels == C]`` (TODO confirm).
        syn_cluster_labels: Same as above, for the synthetic embeddings of
            class ``C``.
        N_c: Number of clusters iterated for this class.
        model: Forwarded unchanged to ``save_image_with_retries``.
        private_embeddings: Private embedding tensor. Rows are selected by
            ``private_labels``.
        private_labels: Label per private embedding.
        syn_embeddings: Synthetic embedding tensor. Rows are selected by
            ``syn_labels``.
        syn_labels: Label per synthetic embedding.
        split: Unused here; kept for interface compatibility.
        device: Forwarded to ``save_image_with_retries``.
        sigma: Factor scaling the standard deviation of the added noise.
        output_dir: Root directory for the saved images.
        categories: Class names, forwarded to ``save_image_with_retries``.
        noise_scale: Extra multiplier applied together with ``sigma``.
            Defaults to the previously hard-coded 20. Its meaning (e.g. a
            sensitivity bound) is not established here.
            NOTE(review): confirm.

    Note:
        An empty private cluster raises ``IndexError`` when the noise tensor
        is created from ``c_embeddings[0]``.
    """
    # Select class C once; the per-cluster masks below index into these subsets.
    C_syn_embeddings = syn_embeddings[syn_labels == C]
    C_private_embeddings = private_embeddings[private_labels == C]
    class_dir = os.path.join(output_dir, str(C))

    for c in range(N_c):
        c_syn_embeddings = C_syn_embeddings[syn_cluster_labels[C] == c]
        c_embeddings = C_private_embeddings[private_cluster_labels[C] == c]
        noise = torch.randn_like(c_embeddings[0]) * sigma * noise_scale

        # The noise is divided by the cluster size, matching a mean over len(c_embeddings) rows.
        dp_mean = c_embeddings.mean(0) + noise / len(c_embeddings)
        # Translate the synthetic cluster (out of place) so its mean becomes dp_mean.
        direction = dp_mean - c_syn_embeddings.mean(0)
        c_syn_embeddings = c_syn_embeddings + direction

        for i, embedding in enumerate(c_syn_embeddings):
            image_path = os.path.join(class_dir, f'image_{C}_{c}_{i}.png')
            # Skip images produced by an earlier (possibly interrupted) run.
            if not os.path.exists(image_path):
                save_image_with_retries(i, C, embedding, model, C, image_path, categories, device)

def save_image_with_retries(index, class_idx, embedding, model, C, image_path, categories, device, max_trials=1):
    """Decode ``embedding`` to an image and save it only if it is classified as ``C``.

    Each attempt does the following:

    - Generates an image with ``model.decode_one_embed``, prompted with
      ``"a photo of a <categories[class_idx]>"``.
    - Classifies the image with ``model.classification`` against one
      ``"a photo of a <category>"`` text per entry of ``categories``.

    The first attempt whose prediction equals ``C`` is resized to 224x224 and
    written to ``image_path``. If no attempt matches within ``max_trials``,
    nothing is written.

    Args:
        index: Position of the embedding within its cluster; used only in the log line.
        class_idx: Index into ``categories`` used for the prompt and the log line.
        embedding: Embedding to decode; moved to ``device`` once before decoding.
        model: Object providing ``decode_one_embed`` and ``classification``.
        C: Label the prediction must equal for the image to be kept.
        image_path: Destination file; its parent directory is created if missing.
        categories: Class names indexed by label.
        device: Target passed to ``embedding.to``.
        max_trials: Maximum number of generate-and-classify attempts. The default
            of 1 preserves the original single-attempt behaviour.

    Returns:
        True if an image was saved, False otherwise.
    """
    label_dir = os.path.dirname(image_path)
    if label_dir:
        # exist_ok removes the exists()/makedirs() race between concurrent workers.
        os.makedirs(label_dir, exist_ok=True)

    # Both values are identical for every attempt, so compute them once.
    texts = [f"a photo of a {category}" for category in categories]
    embedding = embedding.to(device)

    for n_trials in range(max_trials):
        print(f"Creating {class_idx}-{index}-th image of class {categories[class_idx]}, {n_trials}-th trial...")
        image = model.decode_one_embed(embedding, n_steps=50, negative_prompt="cropped", prompt=f"a photo of a {categories[class_idx]}")
        prediction = model.classification(PREPROCESS(image), texts=texts)
        if prediction == C:
            image.resize((224, 224)).save(image_path)
            return True
    return False
