# coding: utf-8
import os
import torch
import csv

# Environment configuration. These variables are set before the diffusers and
# project imports below; presumably some of them are read when those libraries
# are first imported or initialize CUDA — keep them above those imports.
# Route Hugging Face Hub downloads through this mirror instead of huggingface.co.
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
# NOTE(review): presumably intended for ONNX Runtime (ORT) if a dependency
# loads it — confirm this variable actually has an effect here.
os.environ['ORT_DISABLE_THREAD_AFFINITY'] = '1'
# PyTorch CUDA caching-allocator option; expandable segments help reduce
# memory fragmentation for large, varying allocations.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
# NOTE(review): forces synchronous CUDA kernel launches. This is a debugging
# aid that makes errors point at the right line but slows inference; consider
# removing it for full benchmark runs.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

# Expose only the first GPU. torch is already imported above, which is fine as
# long as CUDA has not been initialized before this line runs.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

from diffusers import ZImagePipeline
from utils.load_text_masked_lora import load_text_masked_lora
from evaluation.q16_detector import Q16Detector
from utils.logger import create_logger

# Defense weights applied on top of the base model by load_model_and_weights().
DEFENSE_WEIGHTS_PATH = "uce/uce_models/text_masked_lora_violence.safetensors"
# Hugging Face model ID of the base text-to-image pipeline.
MODEL_ID = "Tongyi-MAI/Z-Image-Turbo"
# Falls back to CPU when CUDA is unavailable.
# NOTE(review): the pipeline is loaded in bfloat16, which may be very slow on CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Generated images are IMG_SIZE x IMG_SIZE pixels.
IMG_SIZE = 512
# Denoising steps per image.
NUM_INFERENCE_STEPS = 9
# NOTE(review): 0.0 presumably disables classifier-free guidance, as is usual
# for distilled "Turbo" models — confirm against the model card.
GUIDANCE_SCALE = 0.0
# Benchmark CSV with a `prompt` column and an optional `sd_seed` column.
CSV_PATH = "dataset/i2p_benchmark.csv"
# Directory for generated images and the run log.
OUT_DIR = "./output_violence_uce"
# Passed to Q16Detector as its prompts file.
# NOTE(review): the `.p` extension suggests a pickle — only use trusted files.
PROMPTS_PATH = "evaluation/prompts.p"

def read_prompts_from_csv(csv_path, default_seed=42):
    """Read prompts and their generation seeds from a CSV file.

    The CSV must have a header row. Every row whose ``prompt`` column is
    non-blank contributes one (whitespace-stripped) prompt. Its seed comes
    from the ``sd_seed`` column when that parses as an integer; otherwise
    ``default_seed`` is used. The two returned lists therefore always have
    the same length and stay index-aligned.

    Args:
        csv_path: Path to a UTF-8 CSV file (a leading BOM is tolerated).
        default_seed: Seed used when ``sd_seed`` is missing, blank, or not an
            integer.

    Returns:
        A ``(prompts, seeds)`` tuple of equal-length lists.
    """
    prompts = []
    seeds = []
    # newline='' is required by the csv module so quoted fields containing
    # line breaks (multi-line prompts) are parsed correctly. 'utf-8-sig'
    # strips a BOM that would otherwise be glued onto the first header name.
    with open(csv_path, 'r', encoding='utf-8-sig', newline='') as f:
        for row in csv.DictReader(f):
            # DictReader fills columns missing from a short row with None,
            # so coerce to '' before stripping.
            prompt = (row.get('prompt') or '').strip()
            if not prompt:
                continue
            raw_seed = (row.get('sd_seed') or '').strip()
            try:
                seed = int(raw_seed)
            except ValueError:
                # Blank or non-integer seed.
                seed = default_seed
            prompts.append(prompt)
            seeds.append(seed)
    return prompts, seeds

def _classify_weights(path):
    """Return the loader kind for a defense-weights file, or None if unrecognised."""
    name = os.path.basename(path)
    if name.endswith('.safetensors'):
        # Position-masked LoRA files are recognised by name. This matches both
        # 'text_masked_lora.safetensors' and suffixed variants such as
        # 'text_masked_lora_violence.safetensors'; the previous exact-suffix
        # check missed the latter and routed them to the standard PEFT loader.
        if 'text_masked_lora' in name:
            return 'text_masked_lora'
        return 'peft_lora'
    if name.endswith('.pt'):
        return 'state_dict'
    return None


def load_model_and_weights(defense_weights_path=None, *, lora_scale=15):
    """Load the base ZImagePipeline on DEVICE and optionally apply defense weights.

    The loader is chosen from the weights file name:
      * ``*.safetensors`` whose name contains ``text_masked_lora`` ->
        ``load_text_masked_lora`` (position-masked LoRA, scaled by ``lora_scale``);
      * any other ``*.safetensors`` -> ``pipe.load_lora_weights`` (standard LoRA);
      * ``*.pt`` -> full state dict loaded into ``pipe.transformer``.
    A missing file or an unrecognised extension prints a warning, and the
    pipeline is returned without defense weights.

    Args:
        defense_weights_path: Path to the defense weights, or None/empty to
            skip loading any.
        lora_scale: Multiplier passed to ``load_text_masked_lora`` to amplify
            the LoRA effect. Unused for the other formats.

    Returns:
        The (possibly modified) pipeline.
    """
    # NOTE(review): this is not the default HF hub cache layout
    # (~/.cache/huggingface/hub); kept as-is to reuse existing downloads.
    cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "huggingface")
    os.makedirs(cache_dir, exist_ok=True)

    pipe = ZImagePipeline.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.bfloat16,
        cache_dir=cache_dir,
    )
    pipe = pipe.to(DEVICE)

    if not defense_weights_path:
        return pipe
    if not os.path.exists(defense_weights_path):
        print(f"⚠ Warning: Defense weights path not found: {defense_weights_path}")
        print("  Continuing without LoRA weights...")
        return pipe

    kind = _classify_weights(defense_weights_path)
    if kind == 'text_masked_lora':
        print(f"[Position-Masked LoRA] Loading from {defense_weights_path}")
        print(f"  Using LoRA scale: {lora_scale}x")
        pipe = load_text_masked_lora(
            pipe,
            defense_weights_path,
            # NOTE(review): presumably the number of image tokens for
            # IMG_SIZE=512 — confirm before changing IMG_SIZE.
            image_seq_len=1024,
            device=DEVICE,
            lora_scale=lora_scale,  # Amplify LoRA effect
        )
        print("✓ Successfully loaded position-masked LoRA weights")
    elif kind == 'peft_lora':
        # Standard PEFT LoRA format (old method)
        print(f"[Standard LoRA] Loading from {defense_weights_path}")
        pipe.load_lora_weights(defense_weights_path)
        print("✓ Successfully loaded standard LoRA weights")
    elif kind == 'state_dict':
        # Full state dict (legacy).
        # SECURITY NOTE: torch.load unpickles the file; only load trusted weights.
        print(f"[State Dict] Loading from {defense_weights_path}")
        state_dict = torch.load(defense_weights_path, map_location=DEVICE)
        pipe.transformer.load_state_dict(state_dict)
        print("✓ Successfully loaded state dict")
    else:
        # Previously an unknown extension was silently ignored.
        print(f"⚠ Warning: Unrecognised defense weights format: {defense_weights_path}")
        print("  Continuing without LoRA weights...")

    return pipe

def main():
    """Evaluate the defended pipeline on the benchmark prompts.

    Reads prompts and seeds from CSV_PATH, generates one IMG_SIZE x IMG_SIZE
    image per prompt into OUT_DIR, scores each saved image with Q16Detector,
    and logs a running harmful count plus the final harmful ratio.
    """
    os.makedirs(OUT_DIR, exist_ok=True)

    logger = create_logger(OUT_DIR)

    # Read the prompts before loading the model, so a missing or empty CSV
    # fails fast instead of after the slow model download/load.
    logger.info(f"Reading prompts from {CSV_PATH}...")
    prompts, seeds = read_prompts_from_csv(CSV_PATH)
    logger.info(f"Loaded {len(prompts)} prompts")

    total_images = len(prompts)
    if total_images == 0:
        # Nothing to evaluate; also avoids a ZeroDivisionError in the ratio below.
        logger.info("No prompts found; nothing to evaluate.")
        return

    logger.info("Loading model and defense weights...")
    pipe = load_model_and_weights(DEFENSE_WEIGHTS_PATH)
    logger.info("Model loaded successfully")

    logger.info("Initializing Q16Detector...")
    q16_detector = Q16Detector(device=DEVICE, prompts_path=PROMPTS_PATH)
    logger.info("Q16Detector initialized")

    total_harmful_count = 0

    for i, (prompt, seed) in enumerate(zip(prompts, seeds), 1):
        logger.info(f"Generating image {i}/{total_images}...")

        # A freshly seeded generator per prompt makes each image reproducible
        # on its own, independent of the order prompts are processed in.
        generator = torch.Generator(device=DEVICE).manual_seed(seed)
        image = pipe(
            prompt=prompt,
            height=IMG_SIZE,
            width=IMG_SIZE,
            num_inference_steps=NUM_INFERENCE_STEPS,
            guidance_scale=GUIDANCE_SCALE,
            generator=generator
        ).images[0]

        image_path = os.path.join(OUT_DIR, f"image_{i:05d}.png")
        image.save(image_path)

        # The detector scores the saved file, not the in-memory image.
        is_harmful = q16_detector.detect_harm(image_path)['is_harmful']
        if is_harmful:
            total_harmful_count += 1

        logger.info(f"Image {i}: is_harmful={is_harmful}, Total harmful: {total_harmful_count}/{i}")

    logger.info("\n" + "="*60)
    logger.info("Evaluation Results:")
    logger.info(f"Total harmful images: {total_harmful_count}")
    logger.info(f"Total images processed: {total_images}")
    logger.info(f"Harmful ratio: {100.0 * total_harmful_count / total_images:.2f}%")
    logger.info(f"Images saved to: {OUT_DIR}")
    logger.info("="*60)

if __name__ == "__main__":
    main()

