import torch
from diffusers import AutoencoderKL, UNet2DConditionModel, DDIMScheduler
from transformers import CLIPTokenizer, CLIPTextModel
from pytorch_lightning import seed_everything
from PIL import Image
import torch.nn.functional as F
import numpy as np
import argparse
from tqdm import tqdm
import os
import pandas as pd
from dataclasses import dataclass

@dataclass
class CLIPEmbeddingInfo:
    """Pair a prompt string with its CLIP text embedding.

    Attributes:
        prompt: The text that was encoded.
        embedding: The embedding array for ``prompt``.
    """

    prompt: str
    embedding: np.ndarray  # documented layout: [seq_len, hidden_dim]; not validated

# Command-line interface for the generation script.
parser = argparse.ArgumentParser(description='Image to Image generation with attention control')
parser.add_argument('--model_path', type=str, default="runwayml/stable-diffusion-inpainting",
                    help='Path to the stable diffusion model')
parser.add_argument('--device', type=str, default="cuda:0",
                    help='Device to use for generation')
parser.add_argument('--seed', type=int, default=42,
                    help='Seed for random number generator')
parser.add_argument('--prompts_csv', type=str, required=True,
                    help='Path to CSV file containing prompts')
parser.add_argument('--start_idx', type=int, default=0,
                    help='Starting index for prompts')
parser.add_argument('--end_idx', type=int, default=None,
                    help='Ending index for prompts (None for all)')
# Image generation arguments
parser.add_argument('--height', type=int, default=512,
                    help='Output image height')
parser.add_argument('--width', type=int, default=512,
                    help='Output image width')
parser.add_argument('--num_inference_steps', type=int, default=50,
                    help='Number of denoising steps')
parser.add_argument('--guidance_scale', type=float, default=7.5,
                    help='Guidance scale for classifier free guidance')
parser.add_argument('--strength', type=float, default=0.6,
                    help='Strength for noise addition')
parser.add_argument('--output_path', type=str, default="output.png",
                    help="Output file name template; '_ip' (if present) is replaced "
                         "with '_<prompt_num>p'. Images are written under results/")
# Prompt Filtering
parser.add_argument('--use_prompt_filtering', action='store_true',
                    help='Use prompt filtering')
# Custom CLIP encoder arguments
parser.add_argument('--training_method', type=str, default=None,
                    # The default None (original CLIP) cannot be typed on the command
                    # line; argparse does not check defaults against choices, so it is
                    # not listed here (listing it only made --help misleading).
                    choices=['des', 'advunlearn', 'visu', 'uce'],
                    help='Training method used (if None, use original CLIP)')
parser.add_argument('--text_encoder_path', type=str, default=None,
                    help='Path to trained checkpoint')
# Additional arguments
parser.add_argument('--wb_prompts_csv', type=str, default=None,
                    help='Path to CSV file containing white-box prompts')
parser.add_argument('--image_path', type=str, default=None,
                    help='Path to image to use')
parser.add_argument('--image_type', type=str, default='default',
                    choices=['default', 'adv'],
                    help='Type of image to use')
args = parser.parse_args()

# Pick the execution device and seed the RNGs before any model is built.
device = torch.device(args.device)
seed_everything(args.seed)

# ---- Base Stable Diffusion components (always taken from --model_path) ----
tokenizer = CLIPTokenizer.from_pretrained(args.model_path, subfolder="tokenizer")
vae = AutoencoderKL.from_pretrained(args.model_path, subfolder="vae").to(device)
scheduler = DDIMScheduler.from_pretrained(args.model_path, subfolder="scheduler")

# ---- U-Net: the 'uce' method overwrites the base weights with a checkpoint ----
unet = UNet2DConditionModel.from_pretrained(args.model_path, subfolder="unet").to(device)
if args.training_method == 'uce':
    print(f'Training method: {args.training_method}')
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load("checkpoints/unlearning/uce.pt", map_location=device)
    unet.load_state_dict(checkpoint)

# ---- Text encoder ----
# Every method announces itself before its encoder is loaded; print once up front.
print(f'Training method: {args.training_method}')
if args.training_method == 'advunlearn':
    text_encoder = CLIPTextModel.from_pretrained(
        "OPTML-Group/AdvUnlearn",
        subfolder="nudity_unlearned",
    )
elif args.training_method == 'visu':
    text_encoder = CLIPTextModel.from_pretrained("aimagelab/safeclip_vit-l_14")
else:
    # 'des', 'uce' and None all start from the base model's own text encoder.
    text_encoder = CLIPTextModel.from_pretrained(args.model_path, subfolder="text_encoder")
text_encoder = text_encoder.to(device)

# 'des' may replace the base encoder weights with a trained checkpoint.
if args.training_method == 'des' and args.text_encoder_path:
    print('Load DES text encoder')
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(args.text_encoder_path, map_location=device)
    text_encoder.load_state_dict(checkpoint['model_state_dict'])

text_encoder.eval()

# Load the prompt table; its rows also supply the per-prompt image and mask paths.
df = pd.read_csv(args.prompts_csv)

if args.wb_prompts_csv:
    # White-box prompts: read without a header and skip the first row
    # (presumably a header line -- confirm against the CSV producer).
    # NOTE(review): these prompts are matched to df rows purely by position.
    wb_df = pd.read_csv(args.wb_prompts_csv, header=None)
    prompts = wb_df[0].tolist()[1:]
else:
    # Ensure adv_prompt column exists
    if 'adv_prompt' not in df.columns:
        raise ValueError("CSV must contain 'adv_prompt' column when wb_prompts_csv is not provided")

    # Drop rows without a prompt. df itself is filtered so that df.iloc[i]
    # stays aligned with prompts[i].
    df = df.dropna(subset=['adv_prompt'])
    prompts = df['adv_prompt'].tolist()
img_folder = args.image_path

# The generation loop indexes both prompts[i] and df.iloc[i], so never iterate
# past either one. An explicit --end_idx beyond that is clamped instead of
# crashing with IndexError after partial work.
num_available = min(len(prompts), len(df))
if args.end_idx is None:
    args.end_idx = len(prompts) if len(prompts) <= len(df) else num_available
else:
    args.end_idx = min(args.end_idx, num_available)

def _load_filter_words(path="google_words.txt"):
    """Return the stripped, non-empty lines of *path* (words removed by prompt filtering)."""
    with open(path, "r", encoding="utf-8") as f:
        return [line.strip() for line in f if line.strip()]


def _filter_prompt(prompt, words):
    """Delete every occurrence of each word in *words* (plain substring match).

    Whitespace is normalised afterwards.
    """
    for word in words:
        prompt = prompt.replace(word, "")
    # Collapse runs of whitespace and strip leading/trailing whitespace.
    return " ".join(prompt.split())


def _encode_prompt(text):
    """Return the text encoder's first output for *text*.

    The text is padded/truncated to the tokenizer's max length before encoding.
    """
    input_ids = tokenizer(
        text,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    ).input_ids.to(device)
    with torch.no_grad():
        return text_encoder(input_ids)[0]


def _resolve_image_path(name):
    """Join *name* onto --image_path; if --image_path is unset, use *name* as given."""
    return os.path.join(img_folder if img_folder is not None else "", name)


def _load_image_tensor(name):
    """Load *name* as RGB resized to (--width, --height).

    Returns a [1, 3, H, W] float tensor in [-1, 1] on the target device.
    """
    with Image.open(_resolve_image_path(name)) as img:
        rgb = img.convert("RGB").resize((args.width, args.height))
    pixels = torch.from_numpy(np.array(rgb)).float() / 127.5 - 1
    return pixels.permute(2, 0, 1).unsqueeze(0).to(device)


def _encode_latents(image):
    """Sample VAE latents for *image*, scaled by the 0.18215 latent factor."""
    return 0.18215 * vae.encode(image).latent_dist.sample()


def _is_missing(value):
    """True for None, NaN (an empty pandas cell) or a blank string."""
    if isinstance(value, str):
        return not value.strip()
    return value is None or bool(pd.isna(value))


def _prepare_mask(mask_name, init_image):
    """Build the extra UNet inputs: (latent-resolution mask, masked-image latents).

    Mask pixels whose channel mean exceeds 127.5 become 1 and are zeroed out of
    the masked image. Without a mask path the whole image is treated as
    editable (all-ones mask, hence an all-zero masked image). Previously this
    case left the mask undefined, or silently reused the previous prompt's mask.
    """
    if _is_missing(mask_name):
        mask = torch.ones((1, 1, args.height, args.width), device=device)
    else:
        with Image.open(_resolve_image_path(mask_name)) as img:
            mask_rgb = img.convert("RGB").resize((args.width, args.height))
        binary = np.array(mask_rgb).mean(axis=2) > 127.5
        mask = torch.from_numpy(binary.astype(np.float32))[None, None].to(device=device)
    masked_image_latents = _encode_latents(init_image * (1 - mask))
    # Resize the mask to the latent grid (height // 8, width // 8).
    mask = F.interpolate(mask, size=(args.height // 8, args.width // 8))
    return mask, masked_image_latents


def _denoising_timesteps():
    """Configure the scheduler and return the trailing timesteps selected by --strength.

    int(steps * strength) timesteps are kept, clamped to the schedule length.
    A count of 0 now yields an empty schedule; the old `timesteps[-0:]` slice
    returned the full schedule instead.
    """
    scheduler.set_timesteps(args.num_inference_steps)
    all_timesteps = scheduler.timesteps
    keep = min(int(args.num_inference_steps * args.strength), len(all_timesteps))
    return all_timesteps[max(len(all_timesteps) - keep, 0):]


@torch.no_grad()  # covers the VAE encodes too, so no autograd graph is built
def _generate(text_embeddings, input_image_name, mask_image_name):
    """Run strength-based image-to-image denoising with the inpainting UNet.

    Args:
        text_embeddings: [uncond; cond] embeddings concatenated along dim 0.
        input_image_name: init image path, relative to --image_path.
        mask_image_name: mask path relative to --image_path, or None/NaN/blank.

    Returns:
        The generated PIL image.
    """
    init_image = _load_image_tensor(input_image_name)
    # RNG draw order: init-image latents, masked-image latents, then noise.
    init_latents = _encode_latents(init_image)
    mask, masked_image_latents = _prepare_mask(mask_image_name, init_image)

    timesteps = _denoising_timesteps()
    latents = init_latents
    if len(timesteps) > 0:
        # Noise the clean latents up to the first kept timestep.
        noise = torch.randn_like(latents)
        latents = scheduler.add_noise(latents, noise, timesteps[0])

    # The extra inpainting channels are identical for both CFG halves.
    mask_input = mask.repeat(2, 1, 1, 1)
    masked_latents_input = masked_image_latents.repeat(2, 1, 1, 1)
    for t in tqdm(timesteps, total=len(timesteps), desc="Denoising"):
        # Duplicate latents for classifier-free guidance, then append mask channels.
        latent_model_input = torch.cat(
            [torch.cat([latents] * 2), mask_input, masked_latents_input], dim=1
        )
        noise_pred = unet(
            latent_model_input,
            t,
            encoder_hidden_states=text_embeddings,
        ).sample
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + args.guidance_scale * (noise_pred_text - noise_pred_uncond)
        # compute the previous noisy sample x_t -> x_t-1
        latents = scheduler.step(noise_pred, t, latents).prev_sample

    decoded = vae.decode(1 / 0.18215 * latents).sample
    decoded = (decoded / 2 + 0.5).clamp(0, 1)
    pixels = decoded.cpu().permute(0, 2, 3, 1).numpy()
    return Image.fromarray((pixels * 255).round().astype("uint8")[0])


def _output_filename(template, prompt_num):
    """Return the per-prompt output file name (ending in '.png') for *template*."""
    if "_ip" in template:
        # '_ip' is the prompt-number placeholder; '.png' is appended as before.
        return template.replace("_ip", f"_{prompt_num}p") + ".png"
    # No placeholder: append the prompt number so prompts don't overwrite each other.
    stem, ext = os.path.splitext(template)
    if ext.lower() != ".png":
        stem = template
    return f"{stem}_{prompt_num}p.png"


# Loop-invariant work, done once instead of once per prompt.
filter_words = _load_filter_words() if args.use_prompt_filtering else []
uncond_embeddings = _encode_prompt("")
image_column = 'adv_image' if args.image_type == 'adv' else 'file_name'

# Process each prompt in the specified range
for idx in range(args.start_idx, args.end_idx):
    prompt = prompts[idx]
    prompt_num = idx + 1

    # Image and mask paths come from the same row of the prompts CSV.
    input_image_path = df[image_column].iloc[idx]
    mask_image_path = df['mask'].iloc[idx] if 'mask' in df.columns else None

    print(f"\nProcessing prompt {prompt_num}/{args.end_idx}: {prompt}")
    print(f"Input image: {input_image_path}")
    print(f"Mask image: {mask_image_path}")

    if args.use_prompt_filtering:
        print('Input prompt:', prompt)
        prompt = _filter_prompt(prompt, filter_words)
        print('Filtered prompt:', prompt)

    # [uncond; cond] batch for classifier-free guidance.
    text_embeddings = torch.cat([uncond_embeddings, _encode_prompt(prompt)])
    image = _generate(text_embeddings, input_image_path, mask_image_path)

    os.makedirs('results', exist_ok=True)
    save_path = os.path.join('results', _output_filename(args.output_path, prompt_num))
    image.save(save_path)
    print(f"Generated image saved to: {save_path}")
