import os
import torch
from pathlib import Path
from diffusers import StableDiffusionPipeline
from tqdm import tqdm

def setup_pipeline(
    lora_path='./models--SPO-Diffusion-Models--SPO-SD-v1-5_4k-p_10ep/spo-sd-v1-5_4k-p_10ep_lora_diffusers.safetensors',
    model_id='runwayml/stable-diffusion-v1-5',
):
    """Set up and initialize Stable Diffusion pipeline with LoRA weights.

    Args:
        lora_path: Path to the LoRA ``.safetensors`` file passed to
            ``pipeline.load_lora_weights``.
        model_id: Hugging Face model id (or local directory) of the base model.

    Returns:
        The pipeline with LoRA weights loaded and the safety checker disabled.
        It is moved to ``cuda:0`` when CUDA is available, otherwise to CPU.

    If the ``HUGGING_FACE_CACHE_DIR`` environment variable is set, it is used
    as the download cache directory for the base model.
    """
    # Default to the first GPU, but respect a device mask the caller already set
    # (the original version silently overrode it).
    os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0")
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda:0" if use_cuda else "cpu")

    # float16 only on GPU: many CPU ops have no Half kernels, so the CPU
    # fallback must run in float32.
    inference_dtype = torch.float16 if use_cuda else torch.float32
    huggingface_cache_dir = os.environ.get('HUGGING_FACE_CACHE_DIR')

    # Load the base pipeline
    pipeline = StableDiffusionPipeline.from_pretrained(
        model_id,
        torch_dtype=inference_dtype,
        cache_dir=huggingface_cache_dir,
    )

    # Apply the LoRA weights on top of the base model
    pipeline.load_lora_weights(lora_path)

    pipeline.safety_checker = None
    return pipeline.to(device)

def generate_image(prompt, pipeline, seed=42, guidance_scale=7.5):
    """Generate a single image for ``prompt`` using a seeded generator.

    Args:
        prompt: Text prompt passed to the pipeline.
        pipeline: A loaded diffusers pipeline, e.g. the one returned by
            ``setup_pipeline``.
        seed: Seed for the ``torch.Generator`` handed to the pipeline.
        guidance_scale: Classifier-free guidance scale forwarded to the pipeline.

    Returns:
        The first entry of the pipeline output's ``images`` (presumably a PIL
        image with the default output type).
    """
    # Build the generator on the pipeline's own device. A hard-coded 'cuda:0'
    # crashes when setup_pipeline has fallen back to CPU.
    generator = torch.Generator(device=pipeline.device).manual_seed(seed)
    output = pipeline(
        prompt=prompt,
        generator=generator,
        guidance_scale=guidance_scale,
    )
    return output.images[0]

def load_prompts(file_path):
    """Read ``file_path`` as UTF-8 text and return one prompt per line.

    Leading and trailing whitespace (including the newline) is stripped from
    every line. Blank lines are kept as empty strings, so index ``i`` of the
    result always corresponds to line ``i + 1`` of the file.
    """
    with open(file_path, 'r', encoding='utf-8') as handle:
        return [raw_line.strip() for raw_line in handle]

def main(
    prompt_file='./T2I-Com/texture.txt',
    save_dir='./Generation/SPO-with-LoRA/texture',
):
    """Generate one image per prompt line and save each one as ``img{idx}.jpg``.

    To generate a different category, pass other paths instead of editing
    this function.

    Args:
        prompt_file: Text file with one prompt per line (see ``load_prompts``).
        save_dir: Output directory (``str`` or ``Path``), created if missing.
            Its final component is used as the category name in the progress bar.

    A failure on one prompt is reported and the loop continues with the next.
    """
    save_dir = Path(save_dir)

    # Create save directory
    save_dir.mkdir(parents=True, exist_ok=True)

    # Initialize pipeline
    pipeline = setup_pipeline()

    # Load prompts
    prompts = load_prompts(prompt_file)
    print(f"Loaded {len(prompts)} prompts")

    # Derive the label from the output folder so it never goes stale.
    category = save_dir.name

    # Generation
    for idx, prompt in enumerate(tqdm(prompts, desc=f"Generating {category} images")):
        image_path = save_dir / f"img{idx}.jpg"
        try:
            image = generate_image(prompt, pipeline)
            image.save(image_path)
        except Exception as e:
            # Best-effort per prompt. tqdm.write keeps the live progress bar intact,
            # whereas a plain print would garble it.
            tqdm.write(f"Error generating image {idx}: {e}")

# Run generation only when executed as a script, not when imported.
if __name__ == "__main__":
    main()