import torch

import pandas as pd
import time


from PIL import Image
from transformers import CLIPTextModel, CLIPTokenizer


from diffusers import AutoencoderKL, UNet2DConditionModel
from diffusers import EulerDiscreteScheduler
from diffusers import DPMSolverMultistepScheduler 
from diffusers import DDIMScheduler
from diffusers import DDPMScheduler
from diffusers import PNDMScheduler


import argparse
import os
from tqdm import tqdm

import pandas as pd




@torch.no_grad()
def get_image(latents, nrow, ncol, vae):
    """Decode latents with ``vae`` and tile the results into one PIL image.

    Args:
        latents: Latent tensor handed to ``vae.decode`` after being divided by
            ``vae.config.scaling_factor``.
        nrow: Number of grid rows.
        ncol: Number of grid columns.
        vae: Object exposing ``decode(..., return_dict=False)`` and
            ``config.scaling_factor``.

    Returns:
        A ``PIL.Image`` whose tiles are the first ``nrow * ncol`` decoded
        images laid out row by row.

    Raises:
        IndexError: If fewer than ``nrow * ncol`` images were decoded.
    """
    image = vae.decode(latents / vae.config.scaling_factor, return_dict=False)[0]
    # Map the decoder output x -> x / 2 + 0.5 and clamp it into [0, 1].
    image = (image / 2 + 0.5).clamp(0, 1).squeeze()
    # squeeze() drops a batch dimension of size 1; restore it so that the
    # permute below always sees four dimensions.
    if len(image.shape) < 4:
        image = image.unsqueeze(0)
    # Move channels last and convert to 8-bit values.
    image = (image.permute(0, 2, 3, 1) * 255).to(torch.uint8)

    # Row-major layout: the stride between consecutive rows is the number of
    # columns (the previous code used nrow, which is wrong unless nrow == ncol).
    rows = [
        torch.hstack([image[row_i * ncol + col_i] for col_i in range(ncol)])
        for row_i in range(nrow)
    ]
    grid = torch.vstack(rows)
    return Image.fromarray(grid.cpu().numpy())


@torch.no_grad
def get_text_embedding(prompt, tokenizer, text_encoder, device="cuda:0"):
    text_input = tokenizer(
        prompt,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    return text_encoder(text_input.input_ids.to(device=device))[0]

@torch.no_grad
def get_noise(t,  latents, embeddings, unet):
    v = lambda _x, _e: unet(
        _x, t, encoder_hidden_states=_e
    ).sample
    embeds = torch.cat(embeddings)

    
    vel = v(latents, embeds)
        
    return vel






    
def _select_prompt(row):
    """Return the prompt text of a dataset row, trying known column names in order."""
    for column in ("adv_prompt", "sensitive prompt", "prompt"):
        if column in row:
            return row[column]
    raise KeyError(
        "prompt file needs one of the columns "
        "'adv_prompt', 'sensitive prompt' or 'prompt'"
    )


def _select_guidance(row, default):
    """Return the per-row guidance scale if the row provides one, else ``default``."""
    for column in ("guidance", "evaluation_guidance", "sd_guidance_scale"):
        if column in row:
            # CSV values arrive as numpy scalars; use a plain Python float.
            return float(row[column])
    return default


def run(args):
    """Generate one image per row of the prompt CSV and save it as ``<row index>.png``.

    For every row the sampler combines three noise predictions (prompt,
    empty prompt and ``args.neg_prompt``).  Depending on the flags the final
    prediction is plain classifier-free guidance (``args.cfg``), guidance
    relative to the negative prompt (``args.np``), or classifier-free guidance
    minus an adaptively scaled negative term (default).

    Args:
        args: Namespace providing ``device``, ``prompt_path``, ``prompt_set``,
            ``save_path``, ``gamma``, ``thres``, ``k``, ``max_clip``, ``cfg``,
            ``np``, ``start_from``, ``guidance_scale``, ``seed``,
            ``neg_prompt``, ``batch_size``, ``height``, ``width`` and
            ``num_inference_steps``.  ``args.save_path`` is overwritten when
            ``cfg`` or ``np`` is set.

    Raises:
        ValueError: If ``args.prompt_path`` is ``None``.
        KeyError: If a row has none of the supported prompt columns.
    """
    # Fail before the (slow) model loading when there is nothing to read.
    if args.prompt_path is None:
        raise ValueError("prompt_path must point to a CSV file with prompts")

    sd_model = "CompVis/stable-diffusion-v1-4"

    vae = AutoencoderKL.from_pretrained(sd_model, subfolder="vae", use_safetensors=True)
    tokenizer = CLIPTokenizer.from_pretrained(sd_model, subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(
        sd_model, subfolder="text_encoder", use_safetensors=True
    )
    unet = UNet2DConditionModel.from_pretrained(
        sd_model, subfolder="unet", use_safetensors=True
    )

    vae.to(device=args.device)
    text_encoder.to(device=args.device)
    unet.to(device=args.device)

    # Only the DDPM scheduler is wired up; args.sampler is not read here.
    scheduler = DDPMScheduler.from_pretrained(sd_model, subfolder="scheduler")

    # Apply the baseline-specific output roots BEFORE the directory name is
    # formatted; previously the override happened afterwards and was ignored.
    if args.cfg:
        args.save_path = "./results_cfg"
    elif args.np:
        args.save_path = "./results_np"
    save_path = f"{args.save_path}_gamma_{args.gamma}_thres_{args.thres}_{args.k}"

    save_dir = os.path.join(save_path, args.prompt_set)
    print("file save directory name:", save_dir)
    os.makedirs(save_dir, exist_ok=True)

    print("Running inference...")
    start_time = time.time()

    dataset = pd.read_csv(args.prompt_path)

    for _num, data in dataset.iterrows():
        if _num < args.start_from:
            continue

        print(_num)

        adv_prompt = _select_prompt(data)
        guidance = _select_guidance(data, args.guidance_scale)
        # A seed read from the CSV is a numpy scalar (or a float when the
        # column contains gaps); the generator needs a plain int.
        seed = int(data["evaluation_seed"]) if "evaluation_seed" in data else int(args.seed)

        prompt = [adv_prompt]
        neg_prompt = [args.neg_prompt]

        prompt_embeddings = get_text_embedding(
            prompt * args.batch_size,
            tokenizer=tokenizer,
            text_encoder=text_encoder,
            device=args.device,
        )
        neg_embeddings = get_text_embedding(
            neg_prompt * args.batch_size,
            tokenizer=tokenizer,
            text_encoder=text_encoder,
            device=args.device,
        )
        uncond_embeddings = get_text_embedding(
            [""] * args.batch_size,
            tokenizer=tokenizer,
            text_encoder=text_encoder,
            device=args.device,
        )

        # The global seeds cover any sampling done without an explicit
        # generator; the dedicated generator fixes the initial latents.
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

        generator = torch.Generator(device=args.device)
        generator.manual_seed(seed)
        latents = torch.randn(
            (args.batch_size, unet.config.in_channels, args.height // 8, args.width // 8),
            generator=generator,
            device=args.device,
        )

        scheduler.set_timesteps(args.num_inference_steps)
        timesteps = scheduler.timesteps
        num_steps = len(timesteps)

        # Running per-sample accumulator of the k_t-weighted gap between the
        # distance to the unconditional and to the negative prediction.
        # One column per batch element so that batch_size > 1 works.
        mi = torch.zeros((num_steps + 1, args.batch_size), device=args.device)

        for i, t in enumerate(tqdm(timesteps, colour="MAGENTA")):
            # Predict noise from diffusion model
            noise_prompt = get_noise(t, latents, [prompt_embeddings], unet)
            noise_uncond = get_noise(t, latents, [uncond_embeddings], unet)
            noise_neg = get_noise(t, latents, [neg_embeddings], unet)

            a_bar_t = scheduler.alphas_cumprod[timesteps[i]]
            if i == num_steps - 1:
                # No later timestep exists: reuse a_bar_t so that a_t == 1.
                a_bar_tmin1 = a_bar_t
            else:
                a_bar_tmin1 = scheduler.alphas_cumprod[timesteps[i + 1]]
            a_t = a_bar_t / a_bar_tmin1

            # Plain classifier-free guidance prediction.
            noise_cfg = noise_uncond + guidance * (noise_prompt - noise_uncond)

            class_weight = 1 / 2 * args.k

            # Weighted difference of squared distances at this step, plus the
            # amount accumulated over the previous steps.
            current_gap = class_weight * (
                ((noise_cfg - noise_uncond) ** 2).sum((1, 2, 3))
                - ((noise_cfg - noise_neg) ** 2).sum((1, 2, 3))
            )
            current_gap = current_gap + mi[i]

            print(f"currrent_gap: {current_gap}")
            print(f"mi: {mi[i]}")

            # Calculate for optimal negative guidance
            neg_direction = noise_neg - noise_uncond
            under = 2 * class_weight * (neg_direction * neg_direction).sum((1, 2, 3))

            if i < num_steps - 1:
                safe_guidance = torch.clip(current_gap - args.thres, 0, 1e6)
                # Guard the division: when the negative and unconditional
                # predictions coincide, under is 0 and 0 / 0 would poison the
                # latents with NaN.
                safe_guidance = safe_guidance / under.clamp_min(torch.finfo(under.dtype).tiny)
                safe_guidance = safe_guidance * args.gamma
                safe_guidance = torch.clip(safe_guidance, 0, args.max_clip)
                print(safe_guidance)
            else:
                # No negative guidance on the final step.
                safe_guidance = torch.zeros_like(under)
            # One scale per sample, broadcast over channel and spatial dims
            # (done for both branches so the last step also handles batches).
            safe_guidance = safe_guidance.reshape([-1, 1, 1, 1])

            # Find optimal mixed guidance.
            if args.cfg:
                noise_pred = noise_cfg
            elif args.np:
                noise_pred = noise_neg + guidance * (noise_prompt - noise_neg)
            else:
                noise_pred = noise_cfg - safe_guidance * neg_direction

            latents = scheduler.step(noise_pred, t, latents)['prev_sample']

            # k_t = (1 - a_t) / (2 * a_t * (1 - a_bar_t)); it is 0 on the last
            # step because a_t == 1 there.
            k_t = (1 - a_t) / 2 / a_t / (1 - a_bar_t)
            mi[i + 1] = mi[i] + k_t * (
                ((noise_pred - noise_uncond) ** 2).sum((1, 2, 3))
                - ((noise_pred - noise_neg) ** 2).sum((1, 2, 3))
            )

        print("===================================")

        # Only the first sample of the batch is written to disk.
        img_mixed_to_save = get_image(latents[0].unsqueeze(0), 1, 1, vae=vae)
        img_mixed_to_save.save(os.path.join(save_dir, f"{_num}.png"))

    end_time = time.time()
    execution_time = end_time - start_time
    print(f"The function took {execution_time:.4f} seconds to run.")



    



def _str2bool(value):
    """Parse a command-line boolean such as ``True``/``false``/``1``/``no``.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so every
    non-empty value -- including ``"False"`` -- becomes ``True``.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    if lowered in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def main():
    """Parse the command-line arguments, echo them and call ``run``."""
    # arguments
    parser = argparse.ArgumentParser(description="Your script description here")
    parser.add_argument('--num_inference_steps', type=int, default=50)
    parser.add_argument("--seed", type=int, default=42)

    parser.add_argument("--batch_size", type=int, default=1)

    parser.add_argument("--height", type=int, default=512)
    parser.add_argument("--width", type=int, default=512)

    # Guidance for positive prompt
    parser.add_argument("--guidance_scale", type=float, default=7.5)

    parser.add_argument("--prompt_set", type=str, default="MMA-Diff")

    # Concept to remove
    parser.add_argument("--neg_prompt", type=str, default="Sexual Fantasy, Nudity, Pornography, Erotic Art, Nude, Naked, Sexual Acts") # v7
    parser.add_argument("--prompt_path", type=str, default=None)

    parser.add_argument("--save_path", type=str, default="./results")
    parser.add_argument("--device", type=str, default="cuda:0")

    # Maximum negative guidance scale
    parser.add_argument("--max_clip", type=float, default=150.0)
    # Threshold
    parser.add_argument("--thres", type=float, default=30.0)

    parser.add_argument("--k", type=float, default=16.0)
    # Clip
    parser.add_argument("--clip", type=float, default=150.0)
    parser.add_argument("--gamma", type=float, default=1.8)

    parser.add_argument("--start_from", type=int, default=0)

    # Accept both the bare flag ("--cfg") and an explicit value
    # ("--cfg True" / "--cfg False"); the latter now really yields False.
    parser.add_argument("--np", type=_str2bool, nargs="?", const=True, default=False)
    parser.add_argument("--cfg", type=_str2bool, nargs="?", const=True, default=False)

    parser.add_argument("--sampler", type=str, default="DPM")

    args = parser.parse_args()

    # Without a prompt file there is nothing to generate; report it as a
    # usage error instead of failing later while reading the CSV.
    if args.prompt_path is None:
        parser.error("--prompt_path is required")

    print(f"Prompt set: {args.prompt_set}")
    print(f"Negative prompt: {args.neg_prompt}")

    print("Script is running with the provided arguments.\n")
    print(args)
    run(args)
    

# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()

