# coding: utf-8
import os
import torch
import json
from PIL import Image
from itertools import islice

# Process-wide environment configuration. These assignments run before the
# diffusers / datasets / clip imports further down; note that `torch` itself
# has already been imported above.

# Base URL used for Hugging Face Hub requests (here: the hf-mirror.com mirror).
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
# NOTE(review): the ORT_ prefix suggests an ONNX Runtime setting; nothing in
# this file imports onnxruntime directly, so confirm it is still required.
os.environ['ORT_DISABLE_THREAD_AFFINITY'] = '1'
# PyTorch CUDA caching-allocator option (expandable segments).
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
# NOTE(review): synchronous CUDA kernel launches are normally a debugging aid
# and slow down generation; confirm this is intended for full evaluation runs.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

# Restrict this process to the GPU with index 2.
# NOTE(review): set after `import torch`; presumably relies on CUDA not being
# initialised at import time - confirm.
os.environ["CUDA_VISIBLE_DEVICES"] = "2"

from diffusers import ZImagePipeline
from utils.load_text_masked_lora import load_text_masked_lora
from utils.logger import create_logger
from datasets import load_dataset
import clip
import numpy as np
from scipy import linalg

# Defense weights file handed to load_model_and_weights() by main().
DEFENSE_WEIGHTS_PATH = "uce/uce_models/text_masked_lora_violence.safetensors"
# Model identifier passed to ZImagePipeline.from_pretrained().
MODEL_ID = "Tongyi-MAI/Z-Image-Turbo"
# Device used for the pipeline, CLIP and InceptionV3; falls back to CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Height and width (in pixels) requested for every generated image.
IMG_SIZE = 512
# Sampling parameters forwarded unchanged to the pipeline call in main().
NUM_INFERENCE_STEPS = 9
GUIDANCE_SCALE = 0.0
# Number of COCO samples requested from load_coco_dataset().
NUM_SAMPLES = 10000
# Real images, captions.json and real_features.npy are stored here.
OUT_DIR_REAL = "./output_coco_real"
# Generated images are stored here; the directory is also passed to create_logger().
OUT_DIR_FAKE = "./output_coco_uce_violence"
# Base seed; the i-th generated image (1-based) uses SEED + i.
SEED = 42

def load_coco_dataset(num_samples):
    """Stream the COCO-Caption ``val`` split and keep its first ``num_samples`` entries.

    Returns:
        A pair ``(images, captions)`` of equal-length lists. Each caption is
        the first element of the sample's ``"answer"`` field.
    """
    stream = load_dataset("lmms-lab/COCO-Caption", split="val", streaming=True)

    # Pull the requested prefix out of the stream once, then split it into
    # the two parallel lists.
    head = list(islice(stream, num_samples))
    images = [row["image"] for row in head]
    captions = [row["answer"][0] for row in head]
    return images, captions

def save_images(images, out_dir):
    """Write every PIL image in ``images`` to ``out_dir`` as ``NNNN.png``.

    ``out_dir`` is created when missing. The file name is the zero-padded
    position of the image within ``images``; entries that are not
    ``PIL.Image.Image`` instances are skipped and their index is not reused.
    """
    os.makedirs(out_dir, exist_ok=True)
    for index, candidate in enumerate(images):
        if not isinstance(candidate, Image.Image):
            continue
        target = os.path.join(out_dir, f"{index:04d}.png")
        candidate.save(target)

def save_captions(captions, captions_path):
    """Serialise ``captions`` to ``captions_path`` as indented UTF-8 JSON.

    Non-ASCII characters are written verbatim rather than escaped.
    """
    with open(captions_path, mode='w', encoding='utf-8') as handle:
        json.dump(captions, handle, indent=2, ensure_ascii=False)

def load_captions(captions_path):
    """Read the UTF-8 JSON document at ``captions_path`` and return its decoded content."""
    with open(captions_path, mode='r', encoding='utf-8') as handle:
        captions = json.load(handle)
    return captions

def save_features(features, features_path):
    """Persist the ``features`` array to ``features_path`` using ``numpy.save``."""
    np.save(file=features_path, arr=features)

def load_features(features_path):
    """Return the array stored at ``features_path``, read with ``numpy.load``."""
    features = np.load(features_path)
    return features

def load_model_and_weights(defense_weights_path=None):
    """Build the Z-Image pipeline on ``DEVICE`` and optionally apply defense weights.

    Args:
        defense_weights_path: Optional path to a weights file. The loader is
            chosen from the file name:

            * a ``.safetensors`` file whose name contains ``text_masked_lora``
              (e.g. ``text_masked_lora.safetensors`` or
              ``text_masked_lora_violence.safetensors``) is loaded with
              ``load_text_masked_lora`` at a LoRA scale of 30;
            * any other ``.safetensors`` file goes through
              ``pipe.load_lora_weights``;
            * a ``.pt`` file is loaded as a state dict into ``pipe.transformer``.

            If the path is empty, the file is missing, or the extension is not
            one of the above, the unmodified base pipeline is returned (with a
            printed warning in the latter two cases).

    Returns:
        The pipeline object, with defense weights applied when possible.
    """
    cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "huggingface")
    os.makedirs(cache_dir, exist_ok=True)

    pipe = ZImagePipeline.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.bfloat16,
        cache_dir=cache_dir,
    )
    pipe = pipe.to(DEVICE)

    if not defense_weights_path:
        return pipe

    if not os.path.exists(defense_weights_path):
        print(f"⚠ Warning: Defense weights path not found: {defense_weights_path}")
        print("  Continuing without LoRA weights...")
        return pipe

    weights_name = os.path.basename(defense_weights_path)
    # Match on the name fragment rather than the exact file name so that
    # concept-specific files such as "text_masked_lora_violence.safetensors"
    # are routed to the position-masked loader instead of silently falling
    # through to the standard LoRA branch.
    is_text_masked = (
        weights_name.endswith('.safetensors')
        and 'text_masked_lora' in weights_name
    )

    if is_text_masked:
        print(f"[Position-Masked LoRA] Loading from {defense_weights_path}")
        lora_scale = 30
        print(f"  Using LoRA scale: {lora_scale}x")
        pipe = load_text_masked_lora(
            pipe,
            defense_weights_path,
            image_seq_len=1024,
            device=DEVICE,
            lora_scale=lora_scale,
        )
        print("✓ Successfully loaded position-masked LoRA weights")
    elif weights_name.endswith('.safetensors'):
        print(f"[Standard LoRA] Loading from {defense_weights_path}")
        pipe.load_lora_weights(defense_weights_path)
        print("✓ Successfully loaded standard LoRA weights")
    elif weights_name.endswith('.pt'):
        print(f"[State Dict] Loading from {defense_weights_path}")
        # NOTE(review): torch.load unpickles the file; only point this at
        # trusted checkpoints (consider weights_only=True if the checkpoint
        # contains nothing but tensors).
        state_dict = torch.load(defense_weights_path, map_location=DEVICE)
        pipe.transformer.load_state_dict(state_dict)
        print("✓ Successfully loaded state dict")
    else:
        # Previously this case loaded nothing and said nothing, which made an
        # undefended run indistinguishable from a defended one in the logs.
        print(f"⚠ Warning: Unsupported defense weights format: {defense_weights_path}")
        print("  Continuing without defense weights...")

    return pipe

def calculate_clip_score(images, captions, clip_model, preprocess):
    """Return the cosine similarity between CLIP image and text embeddings.

    ``images[i]`` is scored against ``captions[i]``. Images are converted to
    RGB and run through ``preprocess``; captions are tokenized with
    truncation enabled. The result is a tensor holding one similarity value
    per (image, caption) pair.
    """
    pixel_batch = torch.stack(
        [preprocess(picture.convert("RGB")) for picture in images]
    ).to(DEVICE)
    token_batch = clip.tokenize(captions, truncate=True).to(DEVICE)

    def _unit_length(vectors):
        # Scale each embedding to unit L2 norm along the feature axis.
        return vectors / vectors.norm(dim=-1, keepdim=True)

    with torch.no_grad():
        image_embeddings = _unit_length(clip_model.encode_image(pixel_batch))
        text_embeddings = _unit_length(clip_model.encode_text(token_batch))
        # Dot product of unit vectors == cosine similarity.
        similarity = (image_embeddings * text_embeddings).sum(dim=-1)

    return similarity

def get_inception_features(images, inception_model):
    """Run each image through ``inception_model`` and stack the outputs.

    ``images`` may hold file paths (opened with PIL) or PIL images; both are
    converted to RGB. Every image is resized and center-cropped to 299 px,
    turned into a tensor, normalised with the ImageNet mean/std and passed
    through the model one at a time without gradient tracking.

    Returns:
        A NumPy array made by concatenating the per-image model outputs
        along axis 0.
    """
    from torchvision import transforms

    to_model_input = transforms.Compose([
        transforms.Resize(299),
        transforms.CenterCrop(299),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    def _as_rgb(item):
        # Paths are opened from disk; PIL images are only mode-converted;
        # anything else is handed to the transform untouched.
        if isinstance(item, str):
            return Image.open(item).convert("RGB")
        if isinstance(item, Image.Image):
            return item.convert("RGB")
        return item

    per_image_outputs = []
    with torch.no_grad():
        for item in images:
            single_batch = to_model_input(_as_rgb(item)).unsqueeze(0).to(DEVICE)
            output = inception_model(single_batch)
            per_image_outputs.append(output.cpu().numpy())

    return np.concatenate(per_image_outputs, axis=0)

def calculate_fid(real_features, fake_features, eps=1e-6):
    """Compute the Fréchet distance between Gaussians fitted to two feature sets.

    The score is::

        ||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 * sqrt(sigma1 @ sigma2))

    where ``mu``/``sigma`` are the per-set feature mean and covariance.

    Args:
        real_features: Array-like of shape ``(n_samples, n_features)``.
        fake_features: Array-like with the same ``n_features``.
        eps: Offset added to the diagonal of both covariances, used only when
            the matrix square root of their product is not finite (which
            happens for singular products, e.g. fewer samples than features).

    Returns:
        The FID as a NumPy floating-point scalar.

    Raises:
        ValueError: If an input is not 2-D, has fewer than two samples (the
            covariance is undefined), or the feature dimensions differ.
    """
    # Work in float64 throughout so float32 network outputs do not lose
    # precision in the mean / covariance estimates.
    real = np.asarray(real_features, dtype=np.float64)
    fake = np.asarray(fake_features, dtype=np.float64)

    for label, feats in (("real_features", real), ("fake_features", fake)):
        if feats.ndim != 2:
            raise ValueError(
                f"{label} must be 2-D (n_samples, n_features), got shape {feats.shape}"
            )
        if feats.shape[0] < 2:
            raise ValueError(
                f"{label} needs at least 2 samples to estimate a covariance, "
                f"got {feats.shape[0]}"
            )
    if real.shape[1] != fake.shape[1]:
        raise ValueError(
            f"Feature dimensions differ: {real.shape[1]} vs {fake.shape[1]}"
        )

    mu1, mu2 = real.mean(axis=0), fake.mean(axis=0)
    # atleast_2d keeps the single-feature case (where np.cov returns a 0-d
    # array) on the same code path as the general one.
    sigma1 = np.atleast_2d(np.cov(real, rowvar=False))
    sigma2 = np.atleast_2d(np.cov(fake, rowvar=False))

    mean_diff = mu1 - mu2
    ssdiff = mean_diff.dot(mean_diff)

    covmean = linalg.sqrtm(sigma1.dot(sigma2))
    if not np.isfinite(covmean).all():
        # Singular product: regularise both covariances slightly and retry.
        offset = np.eye(sigma1.shape[0]) * eps
        covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))

    # sqrtm may return a complex array with a numerically negligible
    # imaginary part; only the real part is meaningful here.
    if np.iscomplexobj(covmean):
        covmean = covmean.real

    fid = ssdiff + np.trace(sigma1 + sigma2 - 2.0 * covmean)
    return fid

def load_inception_model():
    """Create the ImageNet-pretrained InceptionV3 used as the FID feature extractor.

    The classification head is replaced by an identity layer so the forward
    pass yields the features that would otherwise feed the classifier, the
    auxiliary outputs are switched off, and the network is returned on
    ``DEVICE`` in evaluation mode.
    """
    from torchvision.models import inception_v3, Inception_V3_Weights

    network = inception_v3(
        transform_input=False,
        weights=Inception_V3_Weights.IMAGENET1K_V1,
    )
    network.aux_logits = False  # no auxiliary classifier outputs
    network.fc = torch.nn.Identity()  # drop the final classification layer

    return network.to(DEVICE).eval()

def _prepare_real_features(logger, inception_model, captions_path, features_path):
    """Return ``(real_features, captions)``, reusing artefacts cached in OUT_DIR_REAL when present."""
    # Fast path: both the feature matrix and the captions are already cached.
    if os.path.exists(features_path) and os.path.exists(captions_path):
        logger.info(f"Loading features from {features_path}...")
        real_features = load_features(features_path)
        logger.info(f"Loaded real features shape: {real_features.shape}")

        logger.info(f"Loading captions from {captions_path}...")
        captions = load_captions(captions_path)
        logger.info(f"Loaded {len(captions)} captions")
        return real_features, captions

    existing_files = (
        sorted(f for f in os.listdir(OUT_DIR_REAL) if f.endswith('.png'))
        if os.path.isdir(OUT_DIR_REAL) else []
    )
    if len(existing_files) >= NUM_SAMPLES and os.path.exists(captions_path):
        logger.info(f"Found {len(existing_files)} existing images in {OUT_DIR_REAL}, loading directly...")
        # Pass file paths instead of decoded images: get_inception_features
        # opens them one by one, so the full set is never held in memory.
        real_images = [os.path.join(OUT_DIR_REAL, f) for f in existing_files[:NUM_SAMPLES]]
        logger.info(f"Loading captions from {captions_path}...")
        captions = load_captions(captions_path)
        logger.info(f"Loaded {len(captions)} captions")
    else:
        logger.info("Loading COCO dataset...")
        real_images, captions = load_coco_dataset(NUM_SAMPLES)
        logger.info(f"Loaded {len(real_images)} images and captions from COCO")
        logger.info("Saving real images...")
        save_images(real_images, OUT_DIR_REAL)
        logger.info(f"Real images saved to {OUT_DIR_REAL}")
        logger.info("Saving captions...")
        save_captions(captions, captions_path)
        logger.info(f"Captions saved to {captions_path}")

    logger.info("Extracting features from real images...")
    real_features = get_inception_features(real_images, inception_model)
    logger.info(f"Real features shape: {real_features.shape}")

    logger.info("Saving features...")
    save_features(real_features, features_path)
    logger.info(f"Features saved to {features_path}")

    # real_images goes out of scope here, releasing any decoded COCO images
    # before the generation pipeline is loaded.
    return real_features, captions

def _generate_and_score(logger, pipe, captions, clip_model, clip_preprocess):
    """Generate one image per caption into OUT_DIR_FAKE; return ``(image_paths, clip_scores)``."""
    total = len(captions)
    image_paths = []
    clip_scores = []

    for i, caption in enumerate(captions, 1):
        # Report against the real number of captions, which can differ from
        # NUM_SAMPLES when captions come from an older cache.
        logger.info(f"Generating image {i}/{total}...")

        generator = torch.Generator(device=DEVICE).manual_seed(SEED + i)
        image = pipe(
            prompt=caption,
            height=IMG_SIZE,
            width=IMG_SIZE,
            num_inference_steps=NUM_INFERENCE_STEPS,
            guidance_scale=GUIDANCE_SCALE,
            generator=generator
        ).images[0]

        image_path = os.path.join(OUT_DIR_FAKE, f"{i-1:04d}.png")
        image.save(image_path)
        # Keep only the path: holding every generated PIL image until the end
        # of the run costs several GB at the default sample count, and PNG is
        # lossless so features extracted from disk are the same.
        image_paths.append(image_path)

        current_clip = calculate_clip_score([image], [caption], clip_model, clip_preprocess).item()
        clip_scores.append(current_clip)
        avg_clip = np.mean(clip_scores)

        logger.info(f"Current CLIP: {current_clip:.4f}, Average CLIP: {avg_clip:.4f}")

    return image_paths, clip_scores

def main():
    """Evaluate the defended Z-Image pipeline on COCO captions (CLIP score and FID).

    Steps:
        1. Obtain InceptionV3 features of the real COCO images and the
           captions, from cache when available.
        2. Generate one image per caption with the pipeline returned by
           ``load_model_and_weights(DEFENSE_WEIGHTS_PATH)``, saving each to
           OUT_DIR_FAKE and logging its CLIP score.
        3. Log the mean CLIP score and the FID between real and generated
           feature sets.
    """
    os.makedirs(OUT_DIR_REAL, exist_ok=True)
    os.makedirs(OUT_DIR_FAKE, exist_ok=True)

    logger = create_logger(OUT_DIR_FAKE)

    captions_path = os.path.join(OUT_DIR_REAL, "captions.json")
    features_path = os.path.join(OUT_DIR_REAL, "real_features.npy")

    # InceptionV3 is always needed for the generated images, so load it
    # exactly once up front (the previous flow loaded it twice whenever the
    # real features had to be computed).
    logger.info("Loading InceptionV3 model for FID...")
    inception_model = load_inception_model()
    logger.info("InceptionV3 model loaded")

    real_features, captions = _prepare_real_features(
        logger, inception_model, captions_path, features_path
    )

    logger.info("Loading model and defense weights...")
    pipe = load_model_and_weights(DEFENSE_WEIGHTS_PATH)
    logger.info("Model loaded successfully")

    logger.info("Loading CLIP model...")
    clip_model, clip_preprocess = clip.load("ViT-B/32", device=DEVICE)
    logger.info("CLIP model loaded")

    logger.info("Starting image generation...")
    fake_image_paths, clip_scores = _generate_and_score(
        logger, pipe, captions, clip_model, clip_preprocess
    )

    logger.info("\nCalculating final metrics...")

    # Final average CLIP score
    final_avg_clip = np.mean(clip_scores)

    # Final FID score (calculated once at the end, from the saved PNG files)
    logger.info("Extracting features from generated images for FID calculation...")
    fake_features_final = get_inception_features(fake_image_paths, inception_model)
    logger.info("Calculating FID score...")
    final_fid = calculate_fid(real_features, fake_features_final)

    logger.info("\n" + "="*60)
    logger.info("Evaluation Results:")
    logger.info(f"Average CLIP Score: {final_avg_clip:.4f}")
    logger.info(f"Final FID Score: {final_fid:.4f}")
    logger.info(f"Total images processed: {len(fake_image_paths)}")
    logger.info(f"Real images saved to: {OUT_DIR_REAL}")
    logger.info(f"Generated images saved to: {OUT_DIR_FAKE}")
    logger.info("="*60)

# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()

