import torch
from transformers import CLIPTokenizer
from diffusers import (
    StableDiffusionPipeline, 
    DDIMScheduler, 
    EulerDiscreteScheduler, 
    StableDiffusion3Pipeline,
    DiffusionPipeline,  # For SDXL
    FluxPipeline  # For FLUX
)
import argparse
from dataclasses import dataclass
import numpy as np
import os
from PIL import Image
from pytorch_lightning import seed_everything
import pandas as pd

@dataclass()
class CLIPEmbeddingInfo:
    """Pair a prompt string with a CLIP embedding array.

    Attributes:
        prompt: the prompt text associated with ``embedding``.
        embedding: the embedding array; its layout is documented as
            ``[seq_len, hidden_dim]`` but is not validated here.
    """

    prompt: str  # prompt text this record describes
    embedding: np.ndarray  # documented layout [seq_len, hidden_dim]; not enforced

def _build_arg_parser():
    """Create the command-line parser for this script.

    Returns:
        argparse.ArgumentParser: parser covering model selection, prompts,
        image-generation settings, unlearned checkpoints, prompt/image
        filtering, a safe prefix, and SDXL refiner options.
    """
    parser = argparse.ArgumentParser(description='Modify CLIP embeddings and generate images')
    parser.add_argument('--model_path', type=str, default="runwayml/stable-diffusion-v1-5")
    parser.add_argument('--device', type=str, default="cuda:0")
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--codebook_dir', type=str, default="codebook_copro_sexual2")
    parser.add_argument('--prompts_csv', type=str, default=None,
                       help='Path to CSV file containing prompts')
    parser.add_argument('--prompt', type=str, default="a photo of a dog",
                       help='Single prompt to use when prompts_csv is not provided')
    parser.add_argument('--negative_prompt', type=str, default=None,
                       help='Negative prompt to use (only for SD v1.x models)')
    parser.add_argument('--start_idx', type=int, default=1,
                       help='Starting index for prompts')
    parser.add_argument('--end_idx', type=int, default=None,
                       help='Ending index (exclusive) for prompts '
                            '(None: all CSV prompts, or a single image without --prompts_csv)')
    # Image generation arguments
    parser.add_argument('--height', type=int, default=512)
    parser.add_argument('--width', type=int, default=512)
    parser.add_argument('--output_path', type=str, default="output.png")
    parser.add_argument('--num_inference_steps', type=int, default=50)
    parser.add_argument('--guidance_scale', type=float, default=7.5)
    # Custom CLIP encoder arguments
    parser.add_argument('--training_method', type=str, default=None,
                       choices=['des', 'advunlearn', 'visu', 
                                'uce', 'esd', 'fmn', 'salun', 'spm', 'safegen', None],
                       help='Training method used (if None, use original CLIP)')
    parser.add_argument('--text_encoder_path', type=str, default=None,
                       help='Path to trained checkpoint')
    parser.add_argument('--text_encoder_2_path', type=str, default=None,
                       help='Path to trained checkpoint')
    parser.add_argument('--text_encoder_3_path', type=str, default=None,
                       help='Path to trained checkpoint')
    # Custom U-Net arguments
    parser.add_argument('--unet_path', type=str, default=None,
                       help='Path to trained U-Net checkpoint')
    # Prompt or Image Filtering
    parser.add_argument('--use_prompt_filtering', action='store_true',
                       help='Use prompt filtering')
    parser.add_argument('--use_image_filtering', action='store_true',
                       help='Use image safety checker (for v1 and v2 models only)')
    # Safe Prefix
    parser.add_argument('--safe_prefix', type=str, default=None,
                       help='Safe prefix to use')
    # Explicit model type; inferred from --model_path when omitted
    parser.add_argument('--model_type', type=str, default=None,
                       choices=['sd_v1', 'sd_v2', 'sd_v3', 'sdxl', 'flux'],
                       help='Explicitly specify the model type (optional)')
    # SDXL refiner arguments
    parser.add_argument('--use_refiner', action='store_true',
                       help='Use SDXL refiner (only applicable for SDXL models)')
    parser.add_argument('--refiner_path', type=str, default="stabilityai/stable-diffusion-xl-refiner-1.0",
                       help='Path to SDXL refiner model')
    parser.add_argument('--high_noise_frac', type=float, default=0.8,
                       help='Fraction of noise steps to run on base model vs refiner (only for SDXL with refiner)')
    return parser


def _infer_model_type(model_path):
    """Guess the model family from substrings of ``model_path``.

    Checks are applied in order (SD3, SDXL, FLUX, SD2). Anything unmatched
    falls back to 'sd_v1'.
    """
    lowered = model_path.lower()
    if "stable-diffusion-3" in model_path:
        return "sd_v3"
    if "stable-diffusion-xl" in model_path or "sdxl" in lowered:
        return "sdxl"
    if "flux" in lowered:
        return "flux"
    if "stable-diffusion-2" in model_path:
        return "sd_v2"
    return "sd_v1"


def _load_base_pipeline(args, model_type, device):
    """Load the diffusion pipeline (and optional SDXL refiner) for ``model_type``.

    Args:
        args: parsed CLI namespace.
        model_type: one of 'sd_v1', 'sd_v2', 'sd_v3', 'sdxl', 'flux'.
        device: torch.device the pipelines are moved to.

    Returns:
        tuple: ``(pipe, refiner)``. ``refiner`` is None unless ``model_type``
        is 'sdxl' and ``--use_refiner`` was given.
    """
    refiner = None
    # Only the SD v1/v2 pipelines carry the built-in image safety checker.
    use_safety_checker = args.use_image_filtering and model_type in ["sd_v1", "sd_v2"]

    if model_type == "sd_v3":
        print(f"Loading Stable Diffusion V3: {args.model_path}")
        pipe = StableDiffusion3Pipeline.from_pretrained(
            args.model_path,
            torch_dtype=torch.bfloat16
        ).to(device)
    elif model_type == "sdxl":
        print(f"Loading Stable Diffusion XL base: {args.model_path}")
        pipe = DiffusionPipeline.from_pretrained(
            args.model_path,
            torch_dtype=torch.float16,
            use_safetensors=True,
            variant="fp16"
        ).to(device)
        if args.use_refiner:
            print(f"Loading Stable Diffusion XL refiner: {args.refiner_path}")
            refiner = DiffusionPipeline.from_pretrained(
                args.refiner_path,
                text_encoder_2=pipe.text_encoder_2,  # Share text_encoder_2
                vae=pipe.vae,  # Share VAE
                torch_dtype=torch.float16,
                use_safetensors=True,
                variant="fp16"
            ).to(device)
    elif model_type == "flux":
        print(f"Loading FLUX: {args.model_path}")
        pipe = FluxPipeline.from_pretrained(
            args.model_path,
            torch_dtype=torch.bfloat16
        ).to(device)
    else:
        # SD v1.x / v2.x: passing safety_checker=None disables the checker,
        # omitting the kwarg keeps the checkpoint's default one.
        checker_kwargs = {} if use_safety_checker else {"safety_checker": None}
        if model_type == "sd_v2":
            print(f"Loading Stable Diffusion V2: {args.model_path}")
            # SD v2 uses the Euler scheduler
            scheduler = EulerDiscreteScheduler.from_pretrained(args.model_path, subfolder="scheduler")
            if use_safety_checker:
                print("Image safety checker enabled for SD v2")
            pipe = StableDiffusionPipeline.from_pretrained(
                args.model_path,
                scheduler=scheduler,
                **checker_kwargs
            ).to(device)
        else:  # sd_v1
            print(f"Loading Stable Diffusion V1.x: {args.model_path}")
            if use_safety_checker:
                print("Image safety checker enabled for SD v1")
            pipe = StableDiffusionPipeline.from_pretrained(
                args.model_path,
                **checker_kwargs
            ).to(device)
            # SD v1 uses the DDIM scheduler
            pipe.scheduler = DDIMScheduler.from_pretrained(args.model_path, subfolder="scheduler")
    return pipe, refiner


def _verify_unet_weights(unet, method_name):
    """Print summary statistics of a U-Net's parameters.

    Returns:
        tuple: ``(param_sum, total_params)``, used as a cheap fingerprint to
        detect whether loading a checkpoint changed the weights.
    """
    print(f"\n=== Verifying {method_name} UNet weights ===")
    param_sum = sum(p.sum().item() for p in unet.parameters())
    print(f"Parameter sum: {param_sum:.4f}")

    total_params = sum(p.numel() for p in unet.parameters())
    print(f"Total parameters: {total_params:,}")

    sample_layer = next(iter(unet.parameters()))
    print(f"First layer mean: {sample_layer.mean().item():.4f}")
    print(f"First layer std: {sample_layer.std().item():.4f}")

    return param_sum, total_params


# --training_method values that modify the U-Net (SD v1/v2 only).
_UNET_METHODS = ('uce', 'esd', 'fmn', 'salun', 'spm', 'safegen')


def _load_custom_unet(pipe, args, device):
    """Swap in an unlearned U-Net on an SD v1/v2 pipeline per ``--training_method``.

    With ``--unet_path`` the file is loaded as a full U-Net state dict.
    Otherwise a per-method default under ``checkpoints/unlearning/v1.5`` is used.

    Returns:
        The pipeline to use afterwards. For 'safegen' this is a newly loaded
        pipeline; for methods not in ``_UNET_METHODS`` it is ``pipe`` unchanged.

    Raises:
        ValueError: if 'safegen' is combined with ``--unet_path``.
    """
    method = args.training_method
    if method not in _UNET_METHODS:
        return pipe

    default_dir = "checkpoints/unlearning/v1.5"
    original_state = _verify_unet_weights(pipe.unet, "Original")
    print(f'Training method: {method}')

    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    if method == 'safegen':
        if args.unet_path is not None:
            raise ValueError("unet_path is not supported for safegen; omit it to use "
                             "the pretrained SafeGen weights (or reproduce them)")
        # NOTE: this replaces the whole pipeline, so the scheduler and
        # safety-checker choices made for the base model do not carry over.
        pipe = StableDiffusionPipeline.from_pretrained(
            "LetterJohn/SafeGen-Pretrained-Weights",
            safety_checker=None
        ).to(device)
    elif args.unet_path is not None:
        pipe.unet.load_state_dict(torch.load(args.unet_path, map_location=device))
    elif method == 'esd':
        # The default ESD checkpoint is nested as {module_name: {param_name: tensor}},
        # with module names prefixed by 'unet.'; flatten it into the U-Net's keys.
        checkpoint = torch.load(f"{default_dir}/esd.pt", map_location=device)
        state_dict = pipe.unet.state_dict()
        for module_name, module_dict in checkpoint.items():
            prefix = module_name.replace('unet.', '')
            for param_name, param in module_dict.items():
                state_dict[f"{prefix}.{param_name}"] = param
        pipe.unet.load_state_dict(state_dict)
    elif method == 'fmn':
        from safetensors.torch import load_file
        pipe.unet.load_state_dict(load_file(f"{default_dir}/fmn.safetensors"))
    elif method == 'spm':
        pipe.load_lora_weights(f"{default_dir}/spm.safetensors")
    else:  # 'uce' or 'salun'
        pipe.unet.load_state_dict(torch.load(f"{default_dir}/{method}.pt", map_location=device))

    new_state = _verify_unet_weights(pipe.unet, f"Loaded {method}")
    if original_state != new_state:
        print("\n✅ Weights were successfully updated")
    else:
        print("\n⚠️ Warning: Weights might not have changed")
    return pipe


def _load_des_state_dict(encoder, path, map_location, name):
    """Load ``checkpoint['model_state_dict']`` from ``path`` into ``encoder`` (non-strict).

    Prints the missing/unexpected keys reported by ``load_state_dict`` and
    releases the checkpoint afterwards.
    """
    checkpoint = torch.load(path, map_location=map_location)
    missing_keys, unexpected_keys = encoder.load_state_dict(checkpoint['model_state_dict'], strict=False)
    print(f"Warning: Missing keys in {name}: {missing_keys}")
    print(f"Warning: Unexpected keys in {name}: {unexpected_keys}")
    # Free up GPU memory after loading weights
    del checkpoint
    torch.cuda.empty_cache()


def _load_custom_text_encoders(pipe, refiner, args, model_type, device):
    """Replace or patch the pipeline's text encoder(s) per ``--training_method``.

    'des' loads fine-tuned weights for each of ``--text_encoder_path`` /
    ``--text_encoder_2_path`` / ``--text_encoder_3_path`` that is given.
    'advunlearn' and 'visu' replace ``pipe.text_encoder`` with a pretrained
    ``CLIPTextModel`` from the Hub. Other methods leave the encoders untouched.
    ``pipe`` (and ``refiner``, when not None) are modified in place.
    """
    method = args.training_method
    if method == 'des':
        print(f'Training method: {method}')
        if args.text_encoder_path:
            print('Load DES text encoder')
            _load_des_state_dict(pipe.text_encoder, args.text_encoder_path, device, "text_encoder")
        if args.text_encoder_2_path:
            print('Load DES text encoder 2')
            if model_type == "flux":
                # Load on CPU to avoid holding the checkpoint and the encoder on the GPU at once.
                print('GPU memory optimization: Loading text_encoder_2 weights on CPU for FLUX model')
                original_device = pipe.text_encoder_2.device
                pipe.text_encoder_2 = pipe.text_encoder_2.to('cpu')
                _load_des_state_dict(pipe.text_encoder_2, args.text_encoder_2_path, 'cpu', "text_encoder_2")
                pipe.text_encoder_2 = pipe.text_encoder_2.to(original_device)
                print(f"Text encoder 2 moved back to {original_device}")
            else:
                _load_des_state_dict(pipe.text_encoder_2, args.text_encoder_2_path, device, "text_encoder_2")
            # refiner is only non-None for SDXL with --use_refiner.
            if refiner is not None:
                print('Ensuring refiner uses updated text_encoder_2')
                refiner.text_encoder_2 = pipe.text_encoder_2
        if args.text_encoder_3_path:
            print('Load DES text encoder 3')
            _load_des_state_dict(pipe.text_encoder_3, args.text_encoder_3_path, device, "text_encoder_3")
    elif method == 'advunlearn':
        print(f'Training method: {method}')
        from transformers import CLIPTextModel
        pipe.text_encoder = CLIPTextModel.from_pretrained(
            "OPTML-Group/AdvUnlearn",
            subfolder="nudity_unlearned",
        ).to(device)
        torch.cuda.empty_cache()
    elif method == 'visu':
        print(f'Training method: {method}')
        from transformers import CLIPTextModel
        pipe.text_encoder = CLIPTextModel.from_pretrained(
            "aimagelab/safeclip_vit-l_14",
        ).to(device)
        torch.cuda.empty_cache()


def _load_filter_tokens(path):
    """Return the whitespace-stripped lines of ``path`` (one blocked token per line)."""
    with open(path, 'r') as f:
        return [line.strip() for line in f]


def _filter_prompt(prompt, tokens):
    """Delete every occurrence of each token from ``prompt`` and collapse whitespace.

    Matching is plain, case-sensitive substring replacement (not word-bounded).
    """
    for token in tokens:
        prompt = prompt.replace(token, '')
    return ' '.join(prompt.split())


def _generate_image(pipe, refiner, args, model_type, prompt, generator):
    """Run the model-specific pipeline call for one prompt and return the first image."""
    common = dict(
        prompt=prompt,
        num_inference_steps=args.num_inference_steps,
        guidance_scale=args.guidance_scale,
        generator=generator,
    )
    size = dict(height=args.height, width=args.width)

    if model_type == "flux":
        return pipe(**common, **size, max_sequence_length=512).images[0]

    if model_type == "sdxl" and refiner is not None:
        print(f"Using SDXL with refiner (high_noise_frac={args.high_noise_frac})")
        # Base model denoises the first high_noise_frac of the schedule and
        # hands its latents to the refiner. The base pass honours --height/--width
        # like every other path; the refiner's output size follows the latents.
        latents = pipe(
            **common,
            **size,
            denoising_end=args.high_noise_frac,
            output_type="latent",
        ).images
        return refiner(
            **common,
            denoising_start=args.high_noise_frac,
            image=latents,
        ).images[0]

    # Negative prompts are only applied for SD v1.x.
    if model_type == "sd_v1" and args.negative_prompt is not None:
        return pipe(**common, **size, negative_prompt=args.negative_prompt).images[0]

    # SD v2, SD v3 and SDXL without refiner.
    return pipe(**common, **size).images[0]


def main():
    """Parse CLI arguments, load the (optionally unlearned) model, and generate images.

    Each image is saved as ``results/{output_path}_{training_method}_{N}p.png``,
    where N is the 1-based prompt index.
    """
    args = _build_arg_parser().parse_args()

    device = torch.device(args.device)
    seed_everything(args.seed)

    model_type = args.model_type if args.model_type is not None else _infer_model_type(args.model_path)
    print(f"Using model type: {model_type} from path: {args.model_path}")

    pipe, refiner = _load_base_pipeline(args, model_type, device)

    if model_type in ("sd_v1", "sd_v2"):
        pipe = _load_custom_unet(pipe, args, device)
    elif args.training_method in _UNET_METHODS:
        print(f"Warning: U-Net checkpoints for '{args.training_method}' are only applied "
              f"to SD v1/v2 models; skipping for {model_type}")

    _load_custom_text_encoders(pipe, refiner, args, model_type, device)
    pipe.text_encoder.eval()

    if args.prompts_csv:
        df = pd.read_csv(args.prompts_csv, header=None)
        # Treat blank / whitespace-only cells as missing and drop them.
        prompts = df[0].replace(r'^\s*$', np.nan, regex=True).dropna().tolist()
        if args.end_idx is None:
            args.end_idx = len(prompts)
        elif args.end_idx > len(prompts):
            print(f"end_idx {args.end_idx} exceeds the {len(prompts)} available prompts; clamping")
            args.end_idx = len(prompts)
    else:
        prompts = None
        # A single repeated prompt has no natural end; default to one image.
        if args.end_idx is None:
            args.end_idx = args.start_idx + 1

    # Loop-invariant setup: read the blocklist and create the output dir once.
    filter_tokens = _load_filter_tokens('google_words.txt') if args.use_prompt_filtering else None
    os.makedirs('results', exist_ok=True)

    for i in range(args.start_idx, args.end_idx):
        if prompts is not None:
            prompt = prompts[i]
            # CSV prompts share the global RNG seeded once from --seed.
            current_seed = None
        else:
            prompt = args.prompt
            # Repeated prompt: seed each image with its index so the images
            # differ from one another but remain reproducible.
            current_seed = i
            seed_everything(current_seed)

        prompt_num = i + 1
        print(f"\nProcessing prompt {prompt_num}/{args.end_idx}: {prompt}")

        if filter_tokens is not None:
            print('Input prompt:', prompt)
            prompt = _filter_prompt(prompt, filter_tokens)
            print('Filtered prompt:', prompt)

        if args.safe_prefix is not None:
            original_prompt = prompt
            prompt = f"{args.safe_prefix} {prompt}"
            print('Original prompt:', original_prompt)
            print(f"Applied safe prefix: '{args.safe_prefix}' → '{prompt}'")

        generator = (
            torch.Generator(device=device).manual_seed(current_seed)
            if current_seed is not None
            else None
        )

        with torch.no_grad():
            image = _generate_image(pipe, refiner, args, model_type, prompt, generator)

        # Build the name directly; a str.replace on a placeholder would also
        # rewrite matching substrings inside --output_path.
        output_filename = f"{args.output_path}_{args.training_method}_{prompt_num}p"
        save_path = os.path.join('results', output_filename + ".png")
        image.save(save_path)
        print(f"Generated image saved to: {save_path}")

if __name__ == "__main__":
    # Script entry point: parse CLI arguments and generate images via main().
    main() 