import torch
import torchvision
from data_loader import get_dataloader, extract_embeddings
from image_creation import create_images_for_class
from clustering import Clustering
from utils import compute_sigma
from models.lmm import LargeMultiModalModel, PREPROCESS
import argparse

def parse_args():
    """Build the command-line interface and parse ``sys.argv``.

    Optional flags (with defaults): ``--device``, ``--split``, ``--dataset``,
    ``--K`` and ``--batch_size``. Required flags: ``--epsilon``, ``--delta``,
    ``--private_dir``, ``--output_dir`` and ``--syn_dir``.

    Returns:
        argparse.Namespace: the parsed options, one attribute per flag.

    Raises:
        SystemExit: raised by argparse when a required flag is missing or a
            value is invalid (e.g. a dataset outside the allowed choices).
    """
    # (flag, add_argument keyword options), kept in the order they should
    # appear in the generated --help output.
    option_specs = (
        ('--device', dict(type=str, default="cuda:0", help="Device to use: e.g., 'cuda:0' or 'cpu'")),
        ('--split', dict(type=str, default="0", help="Split of the data")),
        ('--dataset', dict(type=str, choices=['CIFAR10', 'STL10'], default='CIFAR10', help="Dataset to use: CIFAR10 or STL10")),
        ('--K', dict(type=int, default=5, help="Number of clusters (N_c) for spectral clustering")),
        ('--batch_size', dict(type=int, default=32, help="Batch size used for sampling data")),
        ('--epsilon', dict(type=float, required=True, help="Privacy parameter epsilon")),
        ('--delta', dict(type=float, required=True, help="Privacy parameter delta")),
        ('--private_dir', dict(type=str, required=True, help="Root directory for private dataset")),
        ('--output_dir', dict(type=str, required=True, help="Root directory to save generated images")),
        ('--syn_dir', dict(type=str, required=True, help="Root directory for synthetic dataset")),
    )

    cli = argparse.ArgumentParser(description="Image generation with LargeMultiModalModel")
    for flag, options in option_specs:
        cli.add_argument(flag, **options)
    return cli.parse_args()

# Main script execution
if __name__ == "__main__":
    # Parse command-line arguments
    args = parse_args()

    # Category names for each supported dataset. A table lookup (rather than
    # an if/elif with no else) guarantees `categories` is always bound, or
    # fails right here with a KeyError for an unsupported dataset.
    # NOTE(review): presumably list position corresponds to the class label
    # id -- confirm against create_images_for_class.
    categories_by_dataset = {
        "CIFAR10": ["an airplane", "an automobile", "a bird", "a cat", "a deer", "a dog", "a frog", "a horse", "a ship", "a truck"],
        "STL10": ["an airplane", "a bird", "a car", "a cat", "a deer", "a dog", "a horse", "a monkey", "a ship", "a truck"],
    }
    categories = categories_by_dataset[args.dataset]

    # Use the requested device only when CUDA is available; otherwise run on CPU.
    device = torch.device(args.device if torch.cuda.is_available() else "cpu")

    # Load the LargeMultiModalModel
    lmm = LargeMultiModalModel(device)

    # Load datasets based on user input: the private data uses is_train=True,
    # the synthetic data uses is_train=False.
    private_loader = get_dataloader(args.dataset, batch_size=args.batch_size, transform=PREPROCESS, root=args.private_dir, is_train=True)
    syn_loader = get_dataloader(args.dataset, batch_size=args.batch_size, transform=PREPROCESS, root=args.syn_dir, is_train=False)

    # Extract embeddings from the datasets. The first returned value (the
    # images) is not used anywhere in this script, so it is discarded.
    _, private_labels, private_embeddings = extract_embeddings(private_loader, lmm, device)
    _, syn_labels, syn_embeddings = extract_embeddings(syn_loader, lmm, device)

    # Calculate sigma from the privacy parameters epsilon and delta
    sigma = compute_sigma(args.epsilon, args.delta)

    # Clustering images into sub-categories (K clusters)
    private_cluster_labels, syn_cluster_labels = Clustering(args.K, private_embeddings, private_labels, syn_embeddings, syn_labels)

    # Create images for each class; the number of classes follows the category
    # list instead of a hard-coded 10.
    for class_index in range(len(categories)):
        create_images_for_class(class_index, private_cluster_labels, syn_cluster_labels, args.K, lmm, private_embeddings, private_labels, syn_embeddings, syn_labels, args.split, device, sigma, args.output_dir, categories)
