from datasets import load_dataset
from diffusers import StableDiffusionPipeline
import torch
import os
import json
from PIL import Image
import argparse

def parse_args(argv=None):
    """Parse the command-line options for image generation.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, so existing callers are unaffected.

    Returns:
        argparse.Namespace with the generation options plus the unlearning
        options ``training_method``, ``unet_path`` and ``text_encoder_path``
        (all ``None`` unless given).
    """
    parser = argparse.ArgumentParser(description="Generate images from I2P dataset")

    parser.add_argument("--output_dir", type=str, help="Output directory")
    parser.add_argument("--model_path", type=str, help="Path to model checkpoint", default="CompVis/stable-diffusion-v1-4")
    parser.add_argument("--seed", type=int, help="Seed for random number generator", default=0)
    parser.add_argument("--prompt", type=str, help="Prompt for image generation")
    parser.add_argument("--mode", type=str, help="Mode for image generation", choices=["train","test"], default="train")
    parser.add_argument("--num_train_images", type=int, help="Number of images to generate for training", default=1000)

    # Options selecting which unlearned weights to load; the __main__ section
    # reads all three, so they must exist on the namespace even when omitted.
    parser.add_argument(
        "--training_method",
        type=str,
        default=None,
        choices=["uce", "esd", "fmn", "salun", "spm", "des", "advunlearn", "visu"],
        help="Unlearning method whose weights to load (omit for the base model)",
    )
    parser.add_argument(
        "--unet_path",
        type=str,
        default=None,
        help="UNet state-dict checkpoint overriding the method's default checkpoint",
    )
    parser.add_argument(
        "--text_encoder_path",
        type=str,
        default=None,
        help="Text-encoder checkpoint (used by the 'des' method)",
    )
    return parser.parse_args(argv)

if __name__ == "__main__":
    args = parse_args()
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    pipe = StableDiffusionPipeline.from_pretrained(args.model_path, safety_checker=None, torch_dtype=torch.float16).to(device)

    def verify_unet_weights(unet, method_name):
        print(f"\n=== Verifying {method_name} UNet weights ===")
        # 1. Check model parameters sum
        param_sum = sum(p.sum().item() for p in unet.parameters())
        print(f"Parameter sum: {param_sum:.4f}")
        
        # 2. Check number of parameters
        total_params = sum(p.numel() for p in unet.parameters())
        print(f"Total parameters: {total_params:,}")
        
        # 3. Check a specific layer's weights
        sample_layer = list(unet.parameters())[0]
        print(f"First layer mean: {sample_layer.mean().item():.4f}")
        print(f"First layer std: {sample_layer.std().item():.4f}")
        
        return param_sum, total_params
    
    original_state = verify_unet_weights(pipe.unet, "Original")

    # Load U-Net based on training method
    if args.training_method == 'uce':
        print(f'Training method: {args.training_method}')
        if args.unet_path is None:
            checkpoint = torch.load("checkpoints/unlearning/v1.5/uce.pt", map_location=device)
        else:
            checkpoint = torch.load(args.unet_path, map_location=device)
        pipe.unet.load_state_dict(checkpoint)
    elif args.training_method == 'esd':
        print(f'Training method: {args.training_method}')
        if args.unet_path is None:
            checkpoint = torch.load("checkpoints/unlearning/v1.5/esd.pt", map_location=device)
            # Get current state dict
            state_dict = pipe.unet.state_dict()
            
            # checkpoint의 각 모듈에 대해
            for module_name, module_dict in checkpoint.items():
                key = module_name.replace('unet.', '')
                # module_dict는 OrderedDict({'weight': tensor(...), 'bias': tensor(...)})
                for param_name, param in module_dict.items():
                    full_key = f"{key}.{param_name}"  # 예: "conv_in.weight"
                    state_dict[full_key] = param
            
            pipe.unet.load_state_dict(state_dict)
        else:
            checkpoint = torch.load(args.unet_path, map_location=device)
            pipe.unet.load_state_dict(checkpoint)
    elif args.training_method == 'fmn':
        print(f'Training method: {args.training_method}')
        if args.unet_path is None:
            from safetensors.torch import load_file
            checkpoint = load_file("checkpoints/unlearning/v1.5/fmn.safetensors")
        else:
            checkpoint = torch.load(args.unet_path, map_location=device)
        pipe.unet.load_state_dict(checkpoint)
    elif args.training_method == 'salun':
        print(f'Training method: {args.training_method}')
        if args.unet_path is None:
            checkpoint = torch.load("checkpoints/unlearning/v1.5/salun.pt", map_location=device)
        else:
            checkpoint = torch.load(args.unet_path, map_location=device)
        pipe.unet.load_state_dict(checkpoint)
    elif args.training_method == 'spm':
        print(f'Training method: {args.training_method}')
        if args.unet_path is None:
            pipe.load_lora_weights("checkpoints/unlearning/v1.5/spm.safetensors")
        else:
            checkpoint = torch.load(args.unet_path, map_location=device)
            pipe.unet.load_state_dict(checkpoint)

    # Verify loaded weights
    new_state = verify_unet_weights(pipe.unet, f"Loaded {args.training_method}")

    # Compare states
    if original_state != new_state:
        print("\n✅ Weights were successfully updated")
    else:
        print("\n⚠️ Warning: Weights might not have changed")
    
    # Load text encoder based on training method
    if args.training_method == 'des':
        print(f'Training method: {args.training_method}')
        if args.text_encoder_path:
            print('Load DES text encoder')
            checkpoint = torch.load(args.text_encoder_path, map_location=device)
            missing_keys, unexpected_keys = pipe.text_encoder.load_state_dict(checkpoint['model_state_dict'], strict=False)
            print(f"Warning: Missing keys in text_encoder: {missing_keys}")
            print(f"Warning: Unexpected keys in text_encoder: {unexpected_keys}")
    elif args.training_method == 'advunlearn':
        print(f'Training method: {args.training_method}')
        from transformers import CLIPTextModel
        pipe.text_encoder = CLIPTextModel.from_pretrained(
            "OPTML-Group/AdvUnlearn",
            subfolder="nudity_unlearned",
        ).to(device)
    elif args.training_method == 'visu':
        print(f'Training method: {args.training_method}')
        from transformers import CLIPTextModel
        pipe.text_encoder = CLIPTextModel.from_pretrained(
            "aimagelab/safeclip_vit-l_14",
        ).to(device)
    
    pipe.text_encoder.eval()

    if(args.mode == "test"):

        gen = torch.Generator(device)

        os.makedirs(args.output_dir, exist_ok=True)

        gen.manual_seed(args.seed)

        out = pipe(prompt=[args.prompt], generator=gen)

        image = out.images[0]

        #save image
        image.save(os.path.join(args.output_dir, "img.png"))

    elif(args.mode == "train"):

        os.makedirs(args.output_dir, exist_ok=True)
        os.makedirs(os.path.join(args.output_dir), exist_ok=True)

        for i in range(args.num_train_images):
            while True:
                output = pipe(prompt=args.prompt)
                image = output.images[0]
                nsfw = output.nsfw_content_detected
                #check if nsfw is a list
                if isinstance(nsfw, list):
                    nsfw = nsfw[0]
                
                if not nsfw:
                    break
            image.save(os.path.join(args.output_dir, f"{args.prompt}_{i}.png"))
            
        
    