from datasets import load_dataset
from diffusers import StableDiffusionPipeline
import torch
import os
import json
from PIL import Image
import argparse

def parse_args():
    """Parse the command-line options of the I2P image generation script.

    Returns:
        argparse.Namespace with the attributes ``output_dir``, ``model_path``,
        ``special_token``, ``training_method``, ``text_encoder_path`` and
        ``unet_path``.

    Exits with a usage error (``SystemExit``) when ``--output_dir`` is missing
    or ``--training_method`` is not one of the supported names.
    """
    parser = argparse.ArgumentParser(description="Generate images from I2P dataset")

    # The script writes every image below this directory, so it must be given;
    # making it required turns a late TypeError into a clear usage error.
    parser.add_argument("--output_dir", type=str, required=True, help="Output directory")
    parser.add_argument("--model_path", type=str, help="Path to model checkpoint", default="CompVis/stable-diffusion-v1-4")
    parser.add_argument("--special_token", type=str, help="Special token for image generation", default="")
    # Custom CLIP encoder arguments
    # `None` is only reachable as the default (argparse does not validate
    # defaults against `choices`), so it is not listed as a selectable value.
    parser.add_argument('--training_method', type=str, default=None,
                       choices=['des', 'advunlearn', 'visu',
                                'uce', 'esd', 'fmn', 'salun', 'spm'],
                       help='Training method used (if omitted, use original CLIP)')
    parser.add_argument('--text_encoder_path', type=str, default=None,
                       help='Path to trained checkpoint')
    parser.add_argument('--unet_path', type=str, default=None,
                       help='Path to trained U-Net checkpoint')

    return parser.parse_args()

if __name__ == "__main__":
    args = parse_args()

    # Collects one {"file_name", "prompt"} record per generated image; it is
    # dumped to metadata.json at the end of the run.
    metadata = []

    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Half precision is only used on GPU: diffusers warns that pipelines
    # loaded in float16 cannot run on the CPU device, so fall back to float32.
    pipe_dtype = torch.float16 if device == "cuda" else torch.float32

    # Images go to <output_dir>/train; makedirs creates output_dir as well.
    os.makedirs(os.path.join(args.output_dir, "train"), exist_ok=True)

    data = load_dataset('AIML-TUDA/i2p', split='train')

    print("Number of images: ", len(data))

    # safety_checker=None disables the pipeline's safety checker component.
    pipe = StableDiffusionPipeline.from_pretrained(
        args.model_path,
        safety_checker=None,
        torch_dtype=pipe_dtype,
    ).to(device)
    
    def verify_unet_weights(unet, method_name):
        """Print and return a cheap fingerprint of a module's parameters.

        Args:
            unet: module exposing ``parameters()`` (here the pipeline's UNet).
            method_name: label used in the printed section header.

        Returns:
            Tuple ``(param_sum, total_params)``: the sum of all parameter
            values and the total number of parameter elements. Comparing two
            such tuples gives a rough "did the weights change?" signal.
        """
        print(f"\n=== Verifying {method_name} UNet weights ===")
        params = list(unet.parameters())

        # 1. Check model parameters sum
        # Accumulate in float64: reducing in the parameter's own dtype can
        # overflow to inf/nan for float16 weights, which would make the
        # before/after comparison meaningless.
        param_sum = sum(p.detach().sum(dtype=torch.float64).item() for p in params)
        print(f"Parameter sum: {param_sum:.4f}")

        # 2. Check number of parameters
        total_params = sum(p.numel() for p in params)
        print(f"Total parameters: {total_params:,}")

        # 3. Check a specific layer's weights (skipped for parameterless modules)
        if params:
            sample_layer = params[0].detach().float()
            print(f"First layer mean: {sample_layer.mean().item():.4f}")
            print(f"First layer std: {sample_layer.std().item():.4f}")

        return param_sum, total_params
    
    # Fingerprint of the UNet before any method-specific weights are loaded.
    original_state = verify_unet_weights(pipe.unet, "Original")

    # Default checkpoint location for each method that replaces UNet weights.
    default_unet_checkpoints = {
        'uce': "checkpoints/unlearning/v1.5/uce.pt",
        'esd': "checkpoints/unlearning/v1.5/esd.pt",
        'fmn': "checkpoints/unlearning/v1.5/fmn.safetensors",
        'salun': "checkpoints/unlearning/v1.5/salun.pt",
        'spm': "checkpoints/unlearning/v1.5/spm.safetensors",
    }

    def load_checkpoint(path):
        """Read a checkpoint file, choosing the reader from the file extension."""
        if path.endswith(".safetensors"):
            from safetensors.torch import load_file
            return load_file(path)
        # NOTE(security): torch.load unpickles the file; only point this at
        # checkpoints from a trusted source.
        return torch.load(path, map_location=device)

    # Load U-Net based on training method
    modifies_unet = args.training_method in default_unet_checkpoints
    if modifies_unet:
        print(f'Training method: {args.training_method}')
        default_path = default_unet_checkpoints[args.training_method]
        if args.unet_path is not None:
            # A user-supplied checkpoint is treated as a full UNet state dict
            # for every method.
            pipe.unet.load_state_dict(load_checkpoint(args.unet_path))
        elif args.training_method == 'esd':
            # The default ESD checkpoint is iterated as a nested mapping
            # {module_name: {param_name: tensor}}; flatten it on top of the
            # current state dict so untouched parameters are kept.
            checkpoint = load_checkpoint(default_path)
            state_dict = pipe.unet.state_dict()
            for module_name, module_dict in checkpoint.items():
                key = module_name.replace('unet.', '')
                for param_name, param in module_dict.items():
                    state_dict[f"{key}.{param_name}"] = param
            pipe.unet.load_state_dict(state_dict)
        elif args.training_method == 'spm':
            # The default SPM checkpoint is applied through the pipeline's
            # LoRA loader rather than as a UNet state dict.
            pipe.load_lora_weights(default_path)
        else:
            pipe.unet.load_state_dict(load_checkpoint(default_path))

    # Verify loaded weights
    new_state = verify_unet_weights(pipe.unet, f"Loaded {args.training_method}")

    # Compare states. Only methods that are supposed to modify the UNet can
    # meaningfully trigger the "not changed" warning; text-encoder methods
    # (and no method at all) leave the UNet untouched by design.
    if not modifies_unet:
        print("\nUNet weights left unchanged (no UNet checkpoint for this method)")
    elif original_state != new_state:
        print("\n✅ Weights were successfully updated")
    else:
        print("\n⚠️ Warning: Weights might not have changed")

    # Load text encoder based on training method
    if args.training_method == 'des':
        print(f'Training method: {args.training_method}')
        if args.text_encoder_path:
            print('Load DES text encoder')
            checkpoint = load_checkpoint(args.text_encoder_path)
            # strict=False: report mismatched keys instead of failing on them.
            missing_keys, unexpected_keys = pipe.text_encoder.load_state_dict(checkpoint['model_state_dict'], strict=False)
            print(f"Warning: Missing keys in text_encoder: {missing_keys}")
            print(f"Warning: Unexpected keys in text_encoder: {unexpected_keys}")
        else:
            print("Warning: --text_encoder_path not given; using the original text encoder")
    elif args.training_method == 'advunlearn':
        print(f'Training method: {args.training_method}')
        from transformers import CLIPTextModel
        # Match the UNet's dtype: the pipeline takes the prompt-embedding
        # dtype from the text encoder, so an fp32 encoder would feed fp32
        # tensors into an fp16 UNet.
        pipe.text_encoder = CLIPTextModel.from_pretrained(
            "OPTML-Group/AdvUnlearn",
            subfolder="nudity_unlearned",
        ).to(device=device, dtype=pipe.unet.dtype)
    elif args.training_method == 'visu':
        print(f'Training method: {args.training_method}')
        from transformers import CLIPTextModel
        pipe.text_encoder = CLIPTextModel.from_pretrained(
            "aimagelab/safeclip_vit-l_14",
        ).to(device=device, dtype=pipe.unet.dtype)

    pipe.text_encoder.eval()

    count = 0
    for item in data:
        # Prefix the prompt with the special token only when one was given.
        if args.special_token != "":
            prompt = [args.special_token + " " + item["prompt"]]
        else:
            prompt = [item["prompt"]]
        seed = item["sd_seed"]

        # Fresh generator per prompt, seeded with the dataset's seed so each
        # image is reproducible on its own.
        gen = torch.Generator(device)
        gen.manual_seed(seed)

        out = pipe(prompt=prompt, generator=gen, negative_prompt=["nudity"])

        for image in out.images:
            file_name = "img_{}.png".format(count)
            image.save(os.path.join(args.output_dir, "train", file_name))
            count += 1

            # `prompt` is stored as the single-element list passed to the pipeline.
            metadata.append({"file_name": file_name, "prompt": prompt})

    with open(os.path.join(args.output_dir, "metadata.json"), "w") as f:
        json.dump(metadata, f)
