from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler
from diffusers import LMSDiscreteScheduler
import torch
from PIL import Image
import pandas as pd
import argparse
import os
import sys
import glob

def _remap_advunlearn_keys(checkpoint, model_state_dict):
    """Map AdvUnlearn checkpoint keys onto the text encoder's state-dict keys.

    Only keys starting with ``text_encoder.`` are considered. The prefix is
    stripped; if the remainder is a key of ``model_state_dict`` it is used
    as-is, otherwise ``text_model.<remainder>`` is tried (unless the remainder
    already starts with ``text_model.``). Keys that match nothing are dropped.

    Returns
    -------
    dict
        Mapping from model state-dict key to checkpoint value.
    """
    prefix = "text_encoder."
    remapped = {}
    for ckpt_key, ckpt_value in checkpoint.items():
        if not ckpt_key.startswith(prefix):
            continue
        stripped = ckpt_key[len(prefix):]
        if stripped in model_state_dict:
            remapped[stripped] = ckpt_value
        elif not stripped.startswith("text_model."):
            nested_key = f"text_model.{stripped}"
            if nested_key in model_state_dict:
                remapped[nested_key] = ckpt_value
    return remapped


def _load_text_encoder(dir_, cache_dir, device, encoder_type=None, encoder_ckpt=None):
    """Return the CLIP text encoder selected by ``encoder_type``.

    ``'des'`` loads the base encoder from ``dir_`` and overlays
    ``checkpoint['model_state_dict']`` from ``encoder_ckpt`` (non-strict).
    ``'advunlearn'`` loads ``OPTML-Group/AdvUnlearn`` (``nudity_unlearned``)
    and, when ``encoder_ckpt`` is given, overlays the remapped checkpoint
    (strict). ``'visu'`` loads ``aimagelab/safeclip_vit-l_14``. Any other
    value loads the base encoder from ``dir_``.
    """
    if encoder_type == 'des':
        text_encoder = CLIPTextModel.from_pretrained(dir_, subfolder="text_encoder", cache_dir=cache_dir)
        checkpoint = torch.load(encoder_ckpt, map_location=device)
        text_encoder.load_state_dict(checkpoint['model_state_dict'], strict=False)
        return text_encoder

    if encoder_type == 'advunlearn':
        text_encoder = CLIPTextModel.from_pretrained(
            "OPTML-Group/AdvUnlearn", subfolder="nudity_unlearned", cache_dir=cache_dir
        )
        if encoder_ckpt:
            print('Load AdvUnlearn text encoder')
            checkpoint = torch.load(encoder_ckpt, map_location=device)
            remapped = _remap_advunlearn_keys(checkpoint, text_encoder.state_dict())
            # strict=True: raises if the remapped checkpoint does not cover every model key.
            text_encoder.load_state_dict(remapped, strict=True)
        return text_encoder

    if encoder_type == 'visu':
        return CLIPTextModel.from_pretrained("aimagelab/safeclip_vit-l_14", cache_dir=cache_dir)

    return CLIPTextModel.from_pretrained(dir_, subfolder="text_encoder", cache_dir=cache_dir)


def _summarize_unet_weights(unet, method_name):
    """Print summary statistics of ``unet``'s parameters and return ``(param_sum, total_params)``.

    The returned pair is only a cheap fingerprint: two different weight sets
    can in principle produce the same sum and count.
    """
    print(f"\n=== Verifying {method_name} UNet weights ===")
    parameters = list(unet.parameters())

    # 1. Sum of all parameter values
    param_sum = sum(p.sum().item() for p in parameters)
    print(f"Parameter sum: {param_sum:.4f}")

    # 2. Number of scalar parameters
    total_params = sum(p.numel() for p in parameters)
    print(f"Total parameters: {total_params:,}")

    # 3. Statistics of the first parameter tensor
    sample_layer = parameters[0]
    print(f"First layer mean: {sample_layer.mean().item():.4f}")
    print(f"First layer std: {sample_layer.std().item():.4f}")

    return param_sum, total_params


def _load_unet_weights(unet, unet_type, unet_ckpt, device, dir_, cache_dir):
    """Load the weights selected by ``unet_type`` from ``unet_ckpt`` and return the UNet.

    The returned object is ``unet`` itself except for the SPM LoRA fallback,
    which returns a copy of the UNet taken from a freshly built pipeline.
    """
    if unet_type in ('uce', 'salun', 'ed', 'sh'):
        # These methods are loaded as a plain UNet state dict.
        print(unet_type.upper())
        unet.load_state_dict(torch.load(unet_ckpt, map_location=device))
    elif unet_type == 'esd':
        print("ESD")
        checkpoint = torch.load(unet_ckpt, map_location=device)
        try:
            unet.load_state_dict(checkpoint)
        except Exception as err:
            # Fallback format: {module_name: {'weight': tensor, 'bias': tensor, ...}}
            print(f"Plain state-dict load failed ({type(err).__name__}); trying per-module ESD format")
            state_dict = unet.state_dict()
            for module_name, module_dict in checkpoint.items():
                key = module_name.replace('unet.', '')
                for param_name, param in module_dict.items():
                    state_dict[f"{key}.{param_name}"] = param  # e.g. "conv_in.weight"
            unet.load_state_dict(state_dict)
    elif unet_type == 'spm':
        print("SPM")
        try:
            unet.load_state_dict(torch.load(unet_ckpt, map_location=device))
        except Exception as err:
            # Fallback: treat the checkpoint as LoRA weights applied through a pipeline.
            print(f"Plain state-dict load failed ({type(err).__name__}); loading as LoRA weights")
            from diffusers import DiffusionPipeline
            from copy import deepcopy
            # cache_dir is passed so the pipeline reuses the same cache as the other loads.
            pipe = DiffusionPipeline.from_pretrained(dir_, cache_dir=cache_dir).to(device)
            pipe.load_lora_weights(unet_ckpt)
            unet = deepcopy(pipe.unet)
            del pipe
    elif unet_type == 'fmn':
        print("FMN")
        try:
            unet.load_state_dict(torch.load(unet_ckpt, map_location=device))
        except Exception as err:
            print(f"torch.load path failed ({type(err).__name__}); loading as safetensors")
            from safetensors.torch import load_file
            unet.load_state_dict(load_file(unet_ckpt))
    else:
        if unet_type is not None:
            print(f"Warning: no loader for unet_type '{unet_type}'; keeping base UNet weights")
        print("No Unet")
    return unet


def _row_value(row, column, default):
    """Return ``row[column]`` when the CSV has that column, else ``default``."""
    return row[column] if column in row.index else default


def generate_images(prompts_path, save_path, device='cuda:0', guidance_scale=7.5, image_size=512, ddim_steps=100,
                    num_samples=10, from_case=0, cache_dir='./ldm_pretrained', concept='vangogh', ckpt=None,
                    encoder_type=None, encoder_ckpt=None, unet_type=None, unet_ckpt=None):
    '''
    Generate images for every prompt in a csv file using diffusers components.

    The csv is read with pandas and the following columns are used:
        1. 'prompt' (the prompt used to generate the image)
        2. 'evaluation_seed', or 'sd_seed' if that column is absent
           (seed for the initial gaussian noise)
        3. 'case_number' (optional, used for file naming; defaults to the row index)
        4. 'sd_image_height', 'sd_image_width', 'sd_guidance_scale'
           (optional per-row overrides of image_size / guidance_scale)

    Images are written to '{save_path}/{concept}/imgs/{case_number}_{n}.png' and
    '{save_path}/{concept}/prompts.csv' receives every input row repeated
    num_samples times (including rows skipped because of from_case).

    Parameters
    ----------
    prompts_path : str
        path for the csv file with prompts and corresponding seeds.
    save_path : str
        save directory for images.
    device : str, optional
        device to be used to load the model. The default is 'cuda:0'.
    guidance_scale : float, optional
        guidance value for inference. The default is 7.5.
    image_size : int, optional
        image size. The default is 512.
    ddim_steps : int, optional
        number of denoising steps passed to the scheduler (an
        LMSDiscreteScheduler despite the parameter name). The default is 100.
    num_samples : int, optional
        number of samples generated per prompt. The default is 10.
    from_case : int, optional
        rows whose case number is below this value are not generated.
        The default is 0.
    cache_dir : str, optional
        cache directory passed to every from_pretrained call.
        The default is './ldm_pretrained'.
    concept : str, optional
        name of the sub-folder of save_path that receives the outputs.
        The default is 'vangogh'.
    ckpt : str, optional
        path of a UNet state dict loaded after the unet_type weights.
        The default is None.
    encoder_type : str, optional
        one of 'des', 'advunlearn', 'visu'; anything else uses the base text
        encoder. The default is None.
    encoder_ckpt : str, optional
        text encoder checkpoint path, used by 'des' and 'advunlearn'.
        The default is None.
    unet_type : str, optional
        one of 'uce', 'esd', 'salun', 'spm', 'ed', 'sh', 'fmn'; anything else
        keeps the base UNet weights. The default is None.
    unet_ckpt : str, optional
        UNet checkpoint path for unet_type. The default is None.

    Returns
    -------
    None.

    '''
    # Imported once here instead of on every row of the csv.
    from tqdm.auto import tqdm

    dir_ = "stable-diffusion-v1-5/stable-diffusion-v1-5"

    # 1. Load the autoencoder model which will be used to decode the latents into image space.
    vae = AutoencoderKL.from_pretrained(dir_, subfolder="vae", cache_dir=cache_dir)
    # 2. Load the tokenizer and text encoder to tokenize and encode the text.
    tokenizer = CLIPTokenizer.from_pretrained(dir_, subfolder="tokenizer", cache_dir=cache_dir)
    text_encoder = _load_text_encoder(dir_, cache_dir, device, encoder_type, encoder_ckpt)
    # 3. The UNet model for generating the latents.
    unet = UNet2DConditionModel.from_pretrained(dir_, subfolder="unet", cache_dir=cache_dir)

    # Fingerprint the weights before and after loading to report whether they changed.
    original_state = _summarize_unet_weights(unet, "Original")
    unet = _load_unet_weights(unet, unet_type, unet_ckpt, device, dir_, cache_dir)
    new_state = _summarize_unet_weights(unet, f"Loaded {unet_type}")

    if original_state != new_state:
        print("\n✅ Weights were successfully updated")
    else:
        print("\n⚠️ Warning: Weights might not have changed")

    scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)

    # Optional extra UNet checkpoint, applied on top of whatever unet_type loaded.
    if ckpt is not None:
        unet.load_state_dict(torch.load(ckpt, map_location=device))

    vae.to(device)
    text_encoder.to(device)
    unet.to(device)
    df = pd.read_csv(prompts_path)

    folder_path = f'{save_path}/{concept}'
    # makedirs also creates folder_path itself.
    os.makedirs(f'{folder_path}/imgs', exist_ok=True)

    seed_column = 'evaluation_seed' if 'evaluation_seed' in df.columns else 'sd_seed'
    uncond_embeddings = None  # identical for every row, computed on first use
    repeated_rows = []
    for i, row in df.iterrows():
        prompt = [str(row['prompt'])] * num_samples
        case_number = _row_value(row, 'case_number', i)
        repeated_rows.extend([row] * num_samples)
        if case_number < from_case:
            continue

        # int() guards against numpy / float scalars coming out of the csv row.
        seed = int(row[seed_column])
        height = int(_row_value(row, 'sd_image_height', image_size))
        width = int(_row_value(row, 'sd_image_width', image_size))
        # Scale for classifier-free guidance
        row_guidance_scale = _row_value(row, 'sd_guidance_scale', guidance_scale)

        # Seed the generator that creates the initial latent noise
        generator = torch.manual_seed(seed)
        batch_size = len(prompt)

        # Inference only: no autograd graph is needed for any of the models.
        with torch.no_grad():
            text_input = tokenizer(prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
            text_embeddings = text_encoder(text_input.input_ids.to(device))[0]

            if uncond_embeddings is None:
                max_length = text_input.input_ids.shape[-1]
                uncond_input = tokenizer(
                    [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
                )
                uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0]

            # Unconditional half first, conditional half second (matches chunk(2) below).
            prompt_embeddings = torch.cat([uncond_embeddings, text_embeddings])

            # Noise is drawn on the CPU with the seeded generator, then moved to the device.
            latents = torch.randn(
                (batch_size, unet.config.in_channels, height // 8, width // 8),
                generator=generator,
            )
            latents = latents.to(device)

            scheduler.set_timesteps(ddim_steps)
            latents = latents * scheduler.init_noise_sigma

            for t in tqdm(scheduler.timesteps):
                # expand the latents for classifier-free guidance to avoid doing two forward passes.
                latent_model_input = torch.cat([latents] * 2)
                latent_model_input = scheduler.scale_model_input(latent_model_input, timestep=t)

                # predict the noise residual
                noise_pred = unet(latent_model_input, t, encoder_hidden_states=prompt_embeddings).sample

                # perform guidance
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + row_guidance_scale * (noise_pred_text - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                latents = scheduler.step(noise_pred, t, latents).prev_sample

            # scale and decode the image latents with vae
            latents = 1 / 0.18215 * latents
            image = vae.decode(latents).sample

        # Map from [-1, 1] to [0, 1], then to HWC uint8 for PIL.
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
        images = (image * 255).round().astype("uint8")
        for num, array in enumerate(images):
            Image.fromarray(array).save(f"{folder_path}/imgs/{case_number}_{num}.png")

    new_df = pd.DataFrame(repeated_rows)
    new_df.to_csv(os.path.join(folder_path, 'prompts.csv'), index=False)


if __name__=='__main__':
    # Command-line entry point: parse the options and forward them to generate_images.
    # Note that the CLI defaults for --num_samples, --ddim_steps and --cache_dir
    # differ from the keyword defaults of generate_images.
    parser = argparse.ArgumentParser(
                    prog = 'generateImages',
                    description = 'Generate Images using Diffusers Code')
    parser.add_argument('--prompts_path', help='path to csv file with prompts', type=str, required=True)
    parser.add_argument('--concept', help='concept to attack', type=str, required=True)
    parser.add_argument('--save_path', help='folder where to save images', type=str, required=True)
    parser.add_argument('--device', help='cuda device to run on', type=str, required=False, default='cuda:0')
    parser.add_argument('--guidance_scale', help='guidance to run eval', type=float, required=False, default=7.5)
    parser.add_argument('--image_size', help='image size used to train', type=int, required=False, default=512)
    parser.add_argument('--from_case', help='continue generating from case_number', type=int, required=False, default=0)
    parser.add_argument('--num_samples', help='number of samples per prompt', type=int, required=False, default=1)
    parser.add_argument('--ddim_steps', help='ddim steps of inference used to train', type=int, required=False, default=25)
    parser.add_argument('--cache_dir', help='cache directory', type=str, required=False, default='./.cache')
    parser.add_argument('--ckpt', help='ckpt path', type=str, required=False, default=None)
    parser.add_argument('--encoder_type', help='encoder type', type=str, required=False, default=None, choices=['des', 'advunlearn', 'visu', None])
    parser.add_argument('--encoder_ckpt', help='encoder ckpt path', type=str, required=False, default=None)
    # 'ed' and 'sh' are included because generate_images has loader branches for them;
    # 'gloce' is kept for backward compatibility although generate_images has no
    # dedicated branch for it.
    parser.add_argument('--unet_type', help='unet type', type=str, required=False, default=None, choices=['esd', 'fmn', 'spm', 'salun', 'uce', 'ed', 'sh', 'gloce', None])
    parser.add_argument('--unet_ckpt', help='unet ckpt path', type=str, required=False, default=None)
    args = parser.parse_args()

    generate_images(
        args.prompts_path,
        args.save_path,
        device=args.device,
        guidance_scale=args.guidance_scale,
        image_size=args.image_size,
        ddim_steps=args.ddim_steps,
        num_samples=args.num_samples,
        from_case=args.from_case,
        cache_dir=args.cache_dir,
        concept=args.concept,
        ckpt=args.ckpt,
        encoder_type=args.encoder_type,
        encoder_ckpt=args.encoder_ckpt,
        unet_type=args.unet_type,
        unet_ckpt=args.unet_ckpt,
    )