import os

# Process-wide runtime configuration. These variables are set here, ahead of
# the torch / diffusers imports further down, so they are already in place
# when those libraries are imported.
os.environ.update(
    {
        # Send Hugging Face Hub requests to the hf-mirror.com mirror.
        "HF_ENDPOINT": "https://hf-mirror.com",
        # NOTE(review): presumably consumed by ONNX Runtime; nothing in this
        # file imports it directly -- confirm it is still needed.
        "ORT_DISABLE_THREAD_AFFINITY": "1",
        # Opt in to PyTorch's expandable-segments CUDA allocator mode.
        "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True",
        # NOTE(review): synchronous kernel launches are a debugging aid and
        # slow down inference; consider dropping this for normal runs.
        "CUDA_LAUNCH_BLOCKING": "1",
        # Expose only GPU 0 to this process.
        "CUDA_VISIBLE_DEVICES": "0",
    }
)

import torch
import time
import random
from PIL import Image
from diffusers import ZImagePipeline
from diffusers.utils import load_image
from utils.load_text_masked_lora import load_text_masked_lora

# Text prompt used for every image generated by main(); edit before running.
PROMPT = "Enter your prompt here"
# Optional path to defense LoRA weights. An empty string means no extra
# weights are loaded and the base model is used unchanged.
DEFENSE_WEIGHTS_PATH = ""
# Hugging Face model identifier passed to ZImagePipeline.from_pretrained.
MODEL_ID = "Tongyi-MAI/Z-Image-Turbo"
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Directory that receives the generated PNG files (created by main()).
OUT_DIR = "./single_image_output"


def load_model_and_weights(defense_weights_path=None):
    """Build the Z-Image pipeline and optionally attach defense LoRA weights.

    The base pipeline identified by ``MODEL_ID`` is loaded in bfloat16 and
    moved to ``DEVICE``. When ``defense_weights_path`` names an existing file
    whose name ends in ``text_masked_lora.safetensors``, the weights are
    applied through ``load_text_masked_lora``.

    A path that is given but cannot be used (missing file or unexpected file
    name) does not raise; a warning is printed and the unmodified base
    pipeline is returned, so a mistyped path is visible in the output instead
    of silently producing undefended results.

    Args:
        defense_weights_path: Path to the LoRA weights file, or ``None`` / an
            empty string to use the base model only.

    Returns:
        The pipeline, with the LoRA weights applied when they were loaded.
    """
    cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "huggingface")
    os.makedirs(cache_dir, exist_ok=True)

    pipe = ZImagePipeline.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.bfloat16,
        cache_dir=cache_dir,
    )
    pipe = pipe.to(DEVICE)

    # No weights requested: the base pipeline is the intended result.
    if not defense_weights_path:
        return pipe

    if not os.path.exists(defense_weights_path):
        print(
            f"[Warning] Defense weights not found at {defense_weights_path}; "
            "continuing with the base model"
        )
        return pipe

    if not defense_weights_path.endswith("text_masked_lora.safetensors"):
        print(
            f"[Warning] Unsupported defense weights file {defense_weights_path} "
            "(expected a name ending in 'text_masked_lora.safetensors'); "
            "continuing with the base model"
        )
        return pipe

    print(f"[Position-Masked LoRA] Loading from {defense_weights_path}")
    # Use a larger lora_scale for a stronger effect.
    lora_scale = 15
    print(f"  Using LoRA scale: {lora_scale}x")
    pipe = load_text_masked_lora(
        pipe,
        defense_weights_path,
        # NOTE(review): presumably the image-token count for the 512x512
        # outputs this script generates -- confirm against the helper.
        image_seq_len=1024,
        device=DEVICE,
        lora_scale=lora_scale,  # Amplify LoRA effect
    )
    print("Successfully loaded position-masked LoRA weights")

    return pipe

def generate_single_image(pipe, prompt, seed, output_path, edit_image=None):
    """Run the pipeline once with a seeded generator and save the first image.

    Sampling settings are fixed: 512x512 output, 9 inference steps and a
    guidance scale of 0.0.

    Args:
        pipe: Callable pipeline whose result exposes an ``images`` sequence.
        prompt: Text prompt forwarded to the pipeline.
        seed: Integer seed for the ``torch.Generator`` created on ``DEVICE``.
        output_path: Destination file path for the saved image.
        edit_image: Optional input image; forwarded to the pipeline as the
            ``image`` keyword only when it is not ``None``.

    Returns:
        Always ``True`` once the image has been written.
    """
    rng = torch.Generator(device=DEVICE)
    rng.manual_seed(seed)

    call_args = dict(
        prompt=prompt,
        height=512,
        width=512,
        num_inference_steps=9,
        guidance_scale=0.0,
        generator=rng,
    )
    if edit_image is not None:
        call_args["image"] = edit_image

    result = pipe(**call_args)
    first_image = result.images[0]
    first_image.save(output_path)
    return True

def main():
    """Generate ten images for ``PROMPT`` and save them under ``OUT_DIR``.

    The pipeline is loaded once (with ``DEFENSE_WEIGHTS_PATH`` applied when
    usable). Each image gets a random seed in ``[0, 2**20 - 1]``; the seed,
    the 1-based image index and the current Unix timestamp are embedded in
    the output file name.
    """
    os.makedirs(OUT_DIR, exist_ok=True)
    pipe = load_model_and_weights(DEFENSE_WEIGHTS_PATH)

    num_images = 10
    for i in range(num_images):
        seed = random.randint(0, 2**20 - 1)
        timestamp = int(time.time())
        output_filename = f"generated_{seed}_{i+1}_{timestamp}.png"
        generate_single_image(
            pipe,
            PROMPT,
            seed,
            os.path.join(OUT_DIR, output_filename),
        )

    print(f"Saved images to {OUT_DIR}")

# Run the generation loop only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main() 