import torch
import torchvision.models as models
from PIL import Image
from diffusers import StableDiffusionPipeline, UNet2DConditionModel, DDIMInverseScheduler, DDIMScheduler
import os
import argparse
import torchvision.datasets as datasets
import torchvision.transforms as transforms

def load_diffusion_pipeline(device, stable_diffusion_model_path, unet_path):
    """Load a Stable Diffusion pipeline, optionally swapping in a custom UNet.

    Args:
        device: Target device (string such as ``"cuda"``/``"cpu"`` or a
            ``torch.device``) that the pipeline is moved to.
        stable_diffusion_model_path: Path or identifier passed to
            ``StableDiffusionPipeline.from_pretrained``.
        unet_path: Local directory containing a ``unet`` subfolder. When
            falsy (``None`` or ``""``) the pipeline keeps its own UNet.

    Returns:
        The pipeline, moved to ``device``.
    """
    # Half precision is only used off-CPU: diffusers warns that float16
    # pipelines cannot run on the CPU, and the caller falls back to "cpu"
    # when CUDA is unavailable.
    on_cpu = torch.device(device).type == "cpu"
    weights_dtype = torch.float32 if on_cpu else torch.float16

    pipe = StableDiffusionPipeline.from_pretrained(
        stable_diffusion_model_path,
        torch_dtype=weights_dtype,
    ).to(device)

    if unet_path:
        # The custom UNet is loaded with the same dtype as the rest of the
        # pipeline so the components stay compatible.
        unet = UNet2DConditionModel.from_pretrained(
            unet_path,
            subfolder="unet",
            torch_dtype=weights_dtype,
            local_files_only=True
        ).to(device)
        pipe.unet = unet

    return pipe

def data_expansion(pipe, device, expansion_factor, total_split, split, class_to_idx, 
                  dataset_dir, save_dir, optimized_embedding_dir, stable_diffusion_model_path):
    """Generate synthetic images for one split of the dataset's classes.

    Class sub-directories of ``dataset_dir`` are sorted by name and divided
    into ``total_split`` contiguous groups; the last group also takes the
    remainder. For every class in group ``split`` the function generates
    ``expansion_factor`` images per regular file found in the class
    directory and writes them to ``save_dir/<class_name>/<index><ext>``.

    Each image is produced by three pipeline calls using the class's stored
    prompt embeddings: a 1-step DDIM run (guidance 5.5) that returns latents,
    a 1-step DDIM-inverse run (guidance 0) on those latents, and a final
    default-length DDIM run (guidance 5) started from the inverted latents.

    Args:
        pipe: Pipeline object; called with ``prompt_embeds=...`` and has its
            ``scheduler`` attribute reassigned between calls.
        device: Device onto which the stored embeddings are mapped on load.
        expansion_factor: Number of images to generate per source file.
        total_split: Number of groups the classes are divided into (>= 1).
        split: Zero-based index of the group to process.
        class_to_idx: Accepted for interface compatibility; not read.
        dataset_dir: Directory with one sub-directory per class.
        save_dir: Output root; created if missing.
        optimized_embedding_dir: Directory holding
            ``<class_name>/learned_embeds_final.safetensors`` per class.
        stable_diffusion_model_path: Model path whose ``scheduler`` subfolder
            provides the scheduler configuration.

    Raises:
        ValueError: If ``total_split`` < 1 or ``split`` is outside
            ``[0, total_split)``.
    """
    if total_split < 1:
        raise ValueError(f"total_split must be >= 1, got {total_split}")
    if not 0 <= split < total_split:
        raise ValueError(
            f"split must satisfy 0 <= split < total_split ({total_split}), got {split}"
        )

    os.makedirs(save_dir, exist_ok=True)

    # os.listdir order is arbitrary; sorting makes every split job see the
    # same partition. Non-directory entries are dropped up front so they do
    # not distort the per-split class count.
    class_names = sorted(
        entry for entry in os.listdir(dataset_dir)
        if os.path.isdir(os.path.join(dataset_dir, entry))
    )
    num_classes = len(class_names)
    classes_per_split = num_classes // total_split
    start_index = split * classes_per_split
    end_index = start_index + classes_per_split if split < total_split - 1 else num_classes
    selected_classes = class_names[start_index:end_index]
    if not selected_classes:
        return

    # The scheduler configs never change, so load them once instead of
    # re-reading them from disk three times per generated image. The pipeline
    # is invoked with an explicit or default step count on every call.
    forward_scheduler = DDIMScheduler.from_pretrained(stable_diffusion_model_path, subfolder="scheduler")
    inverse_scheduler = DDIMInverseScheduler.from_pretrained(stable_diffusion_model_path, subfolder="scheduler")

    for class_name in selected_classes:
        class_dir = os.path.join(dataset_dir, class_name)

        class_save_dir = os.path.join(save_dir, class_name)
        os.makedirs(class_save_dir, exist_ok=True)

        source_files = sorted(
            entry for entry in os.listdir(class_dir)
            if os.path.isfile(os.path.join(class_dir, entry))
        )
        if not source_files:
            continue

        # Reuse the extension of the first source file that has one so the
        # generated images match the dataset's format; fall back to PNG when
        # no file carries an extension (Image.save needs one to pick a format).
        sample_ext = next(
            (ext for ext in (os.path.splitext(name)[1] for name in source_files) if ext),
            ".png",
        )

        optimized_embedding_path = os.path.join(
            optimized_embedding_dir, class_name, "learned_embeds_final.safetensors"
        )
        # NOTE(review): the file carries a .safetensors extension but is read
        # with torch.load (pickle-based). Confirm how these embeddings are
        # written; only load files from a trusted source.
        prompt_embeds = torch.load(optimized_embedding_path, map_location=device)

        for image_index in range(len(source_files) * expansion_factor):
            pipe.scheduler = forward_scheduler
            latents = pipe(
                prompt_embeds=prompt_embeds,
                guidance_scale=5.5,
                num_inference_steps=1,
                output_type="latent",
                return_dict=False
            )[0]

            pipe.scheduler = inverse_scheduler
            inv_latents = pipe(
                prompt_embeds=prompt_embeds,
                guidance_scale=0,
                num_inference_steps=1,
                output_type="latent",
                return_dict=False,
                latents=latents
            )[0]

            pipe.scheduler = forward_scheduler
            new_image = pipe(prompt_embeds=prompt_embeds, guidance_scale=5, latents=inv_latents).images[0]

            save_path = os.path.join(class_save_dir, f"{image_index}{sample_ext}")
            new_image.save(save_path)

if __name__ == "__main__":
    # Command-line entry point: parse options, build the pipeline, and run
    # the expansion for the requested split.
    parser = argparse.ArgumentParser(description="Data expansion for dataset")
    parser.add_argument("--expansion_factor", type=int, required=True, help="Expansion factor for each class")
    parser.add_argument("--total_split", type=int, required=True, help="Total number of task splits")
    parser.add_argument("--split", type=int, required=True, help="Current split index (starting from 0)")
    parser.add_argument("--stable_diffusion_model", type=str, required=True, help="Path to Stable Diffusion model")
    # Optional: load_diffusion_pipeline keeps the base model's UNet when no
    # custom path is given, so the flag does not need to be mandatory.
    parser.add_argument("--unet_path", type=str, default=None,
                        help="Path to custom UNet model (optional; the base model's UNet is used when omitted)")
    parser.add_argument("--dataset_dir", type=str, required=True, help="Original dataset directory")
    parser.add_argument("--save_dir", type=str, required=True, help="Directory to save generated images")
    parser.add_argument("--optimized_embedding_dir", type=str, required=True, help="Directory of optimized text embeddings")
    args = parser.parse_args()

    device = "cuda" if torch.cuda.is_available() else "cpu"
    pipe = load_diffusion_pipeline(device, args.stable_diffusion_model, args.unet_path)

    # The ImageFolder is built only to obtain its class-name -> index mapping.
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    train_dataset = datasets.ImageFolder(root=args.dataset_dir, transform=transform)
    class_to_idx = train_dataset.class_to_idx

    data_expansion(pipe, device, args.expansion_factor, args.total_split, args.split, class_to_idx, 
                  args.dataset_dir, args.save_dir, args.optimized_embedding_dir, args.stable_diffusion_model)