import os

os.environ['CURL_CA_BUNDLE'] = ''
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
import sys
sys.path.append("Anonymous")
from ddpo_pytorch.adaptive_model import Discriminator_P
from ddpo_pytorch.reward_judger import RewardJudger
import argparse
import torch
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import (
    StableDiffusionPipeline,
    rescale_noise_cfg,
)
import tqdm
from diffusers import DDIMScheduler
import numpy as np
from PIL import Image
from einops import rearrange
from ddim_with_logprob import ddim_step_with_logprob_LitE, ddim_step_forward_from_xt_1_to_xt
from torch import nn
from torch.cuda.amp import autocast
from diffusers.utils.torch_utils import randn_tensor
import math
import os
import random
from torchvision import transforms
from preference_models import get_preference_model_func, get_compare_func
from typing import Any, Callable, Dict, Optional, Union, List, Tuple
import spacy
from torch.nn import functional as F
from diffusers import StableDiffusionPipeline
from ddpo_pytorch.attention_utils import (
    AttentionStore,
    AttendExciteAttnProcessor,
    get_attention_maps_list)
import matplotlib.pyplot as plt
from DiffAugment import DiffAugment
import json

from utils import *
from format_pos_neg_pairs import process_p_n_groups

# Comma-separated DiffAugment policy names. Passed as ``policy=POLICY`` when the
# decoded images are augmented before being scored by the reward models.
POLICY = 'color,translation,resize,cutout'


###########

def augment_image(image):
    """Return one independently augmented view of ``image`` per transform.

    The four torchvision transforms are not composed. Each one receives the
    untouched ``image`` and contributes exactly one entry to the returned list,
    in this order:

    1. random resized crop to 224,
    2. random horizontal flip,
    3. brightness/contrast jitter,
    4. random rotation of up to +/-15 degrees.
    """
    view_makers = (
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.5, contrast=0.5),
        transforms.RandomRotation(degrees=15),
    )
    return [make_view(image) for make_view in view_makers]


def register_attention_control(pipeline):
    """Install attention-recording processors on the pipeline's UNet.

    Processing steps:

    - A new ``AttentionStore`` with a 16x16 attention resolution is attached to
      ``pipeline.attention_store``.
    - Every UNet attention processor whose name starts with ``mid_block``,
      ``up_blocks`` or ``down_blocks`` is replaced by an
      ``AttendExciteAttnProcessor``. The processor is tagged with its location
      ("mid", "up" or "down") and writes into that store. Processors with any
      other name are left out of the mapping.
    - The number of installed processors is saved as
      ``attention_store.num_att_layers``.

    Returns the same ``pipeline`` object, modified in place.
    """
    store = AttentionStore((16, 16))
    pipeline.attention_store = store

    # (name prefix, place_in_unet tag), checked in this order.
    location_by_prefix = (
        ("mid_block", "mid"),
        ("up_blocks", "up"),
        ("down_blocks", "down"),
    )

    processors = {}
    for proc_name in pipeline.unet.attn_processors.keys():
        location = next(
            (tag for prefix, tag in location_by_prefix if proc_name.startswith(prefix)),
            None,
        )
        if location is None:
            continue
        processors[proc_name] = AttendExciteAttnProcessor(
            attnstore=store, place_in_unet=location
        )

    pipeline.unet.set_attn_processor(processors)
    # One entry per distinct processor name, i.e. the number of layers registered.
    store.num_att_layers = len(processors)
    return pipeline


def calc_mean_std(feat, eps=1e-5):
    """Compute per-sample, per-channel mean and standard deviation.

    Args:
        feat: 4-D tensor ``(N, C, H, W)``. Contiguity is not required.
        eps: Small value added to the variance before the square root, so that
            dividing by the returned std never divides by zero.

    Returns:
        ``(mean, std)``, each shaped ``(N, C, 1, 1)`` so they broadcast against
        ``feat``. The variance uses ``Tensor.var``'s default (unbiased)
        estimator.

    Raises:
        ValueError: if ``feat`` is not 4-D. This is an explicit check rather
            than an ``assert``, so it also fires under ``python -O``.
    """
    if feat.dim() != 4:
        raise ValueError(
            f"calc_mean_std expects a 4-D (N, C, H, W) tensor, got shape {tuple(feat.shape)}"
        )
    n, c = feat.shape[:2]
    # reshape (not view): works for non-contiguous inputs such as permuted tensors.
    flat = feat.reshape(n, c, -1)
    feat_std = (flat.var(dim=2) + eps).sqrt().reshape(n, c, 1, 1)
    feat_mean = flat.mean(dim=2).reshape(n, c, 1, 1)
    return feat_mean, feat_std


def normal(feat, eps=1e-5):
    """Standardize ``feat`` per sample and per channel.

    The per-(N, C) mean is subtracted, and the result is divided by the
    per-(N, C) standard deviation from ``calc_mean_std``, which adds ``eps`` to
    the variance. The output has the same shape as ``feat``.
    """
    mean, std = calc_mean_std(feat, eps)
    centered = feat - mean
    return centered / std


def single_inner_step(pipeline,preference_model_fn,alignment_model_fn, noise_pred, t, latents_grad, extra_info,extra_step_kwargs, generator,replicate=10,p=None):
    """Take one DDIM step and score the resulting x0 estimate.

    The x0 estimate is decoded with the VAE, and each decoded image is grouped
    with ``replicate`` DiffAugment copies of itself. Each group of
    ``replicate + 1`` images is scored either by the preference model or by the
    alignment model, depending on ``p``.

    Args:
        pipeline: pipeline providing ``scheduler`` and ``vae``.
        preference_model_fn: called as ``fn(images, prompts)``. Element ``[0]``
            of its result is reshaped into groups of ``replicate + 1``, so it
            presumably holds one score per image.
        alignment_model_fn: called as ``fn(images, prompts, metadata='v2')``.
            Element ``[0]`` is grouped the same way.
        noise_pred: noise prediction for ``latents_grad`` at timestep ``t``.
        t: current timestep.
        latents_grad: current latents. Gradients flow through the DDIM step and
            the VAE decode back to them.
        extra_info: dict providing ``'prompt'`` (preference branch),
            ``'input_clip_prompts'`` (alignment branch) and, optionally,
            ``'generator'`` for VAE decoding.
        extra_step_kwargs: extra keyword arguments for the DDIM step.
        generator: used for VAE decoding when ``extra_info`` has no
            ``'generator'`` entry.
        replicate: number of augmented copies per decoded image.
        p: branch selector (required). ``p <= 0.5`` uses 25 x the mean
            preference score per group. Otherwise 100 x the per-group mean
            alignment score, averaged over the CLIP prompts, is used.

    Returns:
        ``(loss, pred_x0_dno_images, latents_prev_sample, pred_x0_latents)``.

        - ``loss`` has one entry per sample.
        - ``pred_x0_dno_images`` stacks the groups sample by sample as
          ``[img_0, aug_0 x replicate, img_1, ...]``.

    Raises:
        TypeError: if ``p`` is ``None``.
    """
    if p is None:
        raise TypeError(
            "single_inner_step requires `p` to choose between the preference "
            "(p <= 0.5) and alignment (p > 0.5) losses"
        )

    # One DDIM step; pred_x0_latents is the scheduler's x0 estimate.
    latents_prev_sample, pred_x0_latents, _log_prob = ddim_step_with_logprob_LitE(
        pipeline.scheduler, noise_pred, t, latents_grad, variance_noise=None, **extra_step_kwargs
    )

    # Decode the x0 estimate into image space (kept differentiable).
    pred_x0_dno = pipeline.vae.decode(
        pred_x0_latents.to(pipeline.vae.dtype) / pipeline.vae.config.scaling_factor,
        return_dict=False,
        generator=extra_info.get('generator', generator),
    )[0]

    # repeat_interleave keeps each sample's augmented copies contiguous. The
    # per-sample reshape below therefore never mixes different samples when
    # the batch holds more than one image (a plain .repeat tiles the whole
    # batch and would).
    batch, *image_shape = pred_x0_dno.shape
    group_size = replicate + 1
    augmented = DiffAugment(pred_x0_dno.repeat_interleave(replicate, dim=0), policy=POLICY)
    augmented = augmented.reshape(batch, replicate, *image_shape)
    pred_x0_dno_images = torch.cat([pred_x0_dno.unsqueeze(1), augmented], dim=1)
    pred_x0_dno_images = pred_x0_dno_images.reshape(batch * group_size, *image_shape)
    num_images = pred_x0_dno_images.shape[0]

    if p <= 0.5:
        # Preference loss: mean score over each sample's group.
        input_prompts = [extra_info['prompt']] * num_images
        scores = preference_model_fn(pred_x0_dno_images, input_prompts)[0]
        loss = scores.reshape(-1, group_size).mean(-1) * 25
    else:
        # Alignment loss: per-group mean for each CLIP prompt, then averaged
        # across prompts. np.array(...).T yields one row of per-image prompts
        # for each entry of extra_info['input_clip_prompts'].
        input_clip_prompts = [extra_info['input_clip_prompts']] * num_images
        per_prompt_losses = []
        for each_p in np.array(input_clip_prompts).T:
            scores = alignment_model_fn(pred_x0_dno_images, each_p, metadata='v2')[0]
            per_prompt_losses.append(scores.reshape(-1, group_size).mean(-1).unsqueeze(0))
        loss = torch.cat(per_prompt_losses, dim=0).mean(0) * 100

    return loss, pred_x0_dno_images, latents_prev_sample, pred_x0_latents

def compute_custom_loss(
    pipeline,
    pred_x0_latents,
    generator,
    DiffAugment,
    POLICY,
    extra_info,
    t,
    preference_model_fn,
    contrast_pairs,
    attention_maps_list,
    pred_x0_prev=None,
    args=None
):
    """Combine a reward loss and an attention-contrast loss for x0 latents.

    Reward term: ``pred_x0_latents`` is decoded, each decoded image is grouped
    with 10 ``DiffAugment`` copies of itself, and all images are scored by
    ``preference_model_fn``. Each sample's 11 scores become
    ``-||scores||_2``.

    Attention term: for each token-index pair in ``contrast_pairs``,
    ``cos_dist`` of the two attention maps is added (positive pairs) or
    subtracted (negative pairs). Each sum is then averaged over its number of
    pairs.

    Args:
        pipeline: pipeline providing ``vae``.
        pred_x0_latents: predicted x0 latents with batch size B. Must be 4-D
            when ``pred_x0_prev`` is given, because ``normal`` requires it.
        generator: passed to ``vae.decode``.
        DiffAugment: augmentation callable, used as ``DiffAugment(x, policy=...)``.
        POLICY: augmentation policy string.
        extra_info: must contain ``'input_ids'``. ``'timesteps'`` is
            overwritten in place with ``t`` repeated once per scored image.
        t: current timestep (tensor).
        preference_model_fn: called as ``fn(images, info_dict)``.
        contrast_pairs: dict with ``'pos'`` and ``'neg'`` lists of token-index
            pairs. Pairs with an index >= 77 are skipped.
        attention_maps_list: attention maps indexed by token index.
        pred_x0_prev: previous x0 latents. When given, the attention and reward
            terms are re-weighted by the relative change of the normalized
            latents.
        args: provides ``flag_dpo_loss`` / ``flag_attn_loss``. ``None``
            selects the reward term alone.

    Returns:
        ``(total_loss, pred_x0_dno_no_grad)``:

        - ``total_loss`` is built from the per-sample ``(B,)`` reward and
          attention terms.
        - ``pred_x0_dno_no_grad`` holds the detached decoded images.
    """
    # Decode predicted x0 latents; gradients flow back to the latents.
    pred_x0_dno = pipeline.vae.decode(
        pred_x0_latents.to(pipeline.vae.dtype) / pipeline.vae.config.scaling_factor,
        return_dict=False,
        generator=generator,
    )[0]
    pred_x0_dno_no_grad = pred_x0_dno.clone().detach()

    # Group each decoded image with `replicate` augmented copies of itself:
    # [img_0, aug_0 x R, img_1, aug_1 x R, ...]. repeat_interleave keeps every
    # sample's copies contiguous, so per-sample scores never mix candidates
    # when the batch holds several of them (a plain .repeat would).
    replicate = 10
    group_size = replicate + 1
    batch, *image_shape = pred_x0_dno.shape
    augmented = DiffAugment(pred_x0_dno.repeat_interleave(replicate, dim=0), policy=POLICY)
    augmented = augmented.reshape(batch, replicate, *image_shape)
    pred_x0_dno_images = torch.cat([pred_x0_dno.unsqueeze(1), augmented], dim=1)
    pred_x0_dno_images = pred_x0_dno_images.reshape(batch * group_size, *image_shape)
    num_images = pred_x0_dno_images.shape[0]

    # NOTE: updates the caller's dict in place (kept for compatibility).
    extra_info['timesteps'] = t.repeat(num_images)
    # The caller's 'input_ids' is left unchanged; the tiled copy goes into a merged dict.
    preference_model_input_ids = torch.cat([extra_info['input_ids']] * num_images, dim=0)
    scores = preference_model_fn(
        pred_x0_dno_images, {**extra_info, 'input_ids': preference_model_input_ids}
    ).reshape(-1, group_size)
    loss_preference = -torch.linalg.norm(scores, dim=-1)

    # Attention-contrast term, accumulated on the latents' device.
    device = pred_x0_latents.device
    num_samples = pred_x0_latents.shape[0]
    positive_groups, negative_pairs = contrast_pairs["pos"], contrast_pairs["neg"]
    loss_pos = torch.zeros(num_samples, device=device)
    for i, j in positive_groups:
        # 77 presumably matches the text encoder's token length -- confirm.
        if i >= 77 or j >= 77:
            continue
        loss_pos = loss_pos + cos_dist(attention_maps_list[i], attention_maps_list[j])
    loss_neg = torch.zeros(num_samples, device=device)
    for i, j in negative_pairs:
        if i >= 77 or j >= 77:
            continue
        loss_neg = loss_neg - cos_dist(attention_maps_list[i], attention_maps_list[j])
    # The 1e-6 keeps the averages defined when a pair list is empty.
    loss_cos = loss_pos / (len(positive_groups) + 1e-6) + loss_neg / (len(negative_pairs) + 1e-6)

    if pred_x0_prev is not None:
        # Relative per-sample change of the standardized latents.
        x_lat = normal(pred_x0_latents)
        x_prev = normal(pred_x0_prev)
        per_sample_dims = (-1, -2, -3)
        delta_content_relative = (
            torch.norm(x_lat - x_prev, p=2, dim=per_sample_dims)
            / torch.norm(x_prev, p=2, dim=per_sample_dims)
        )
        k = 1
        # Larger content change -> more weight on attention, less on reward.
        alpha_attn = 10 * (1 - torch.exp(-k * delta_content_relative))
        alpha_preference = torch.exp(-k * delta_content_relative)
    else:
        alpha_attn = 1.0
        alpha_preference = 1.0

    if args is not None and args.flag_dpo_loss:
        total_loss = loss_cos * 10
    elif args is not None and args.flag_attn_loss:
        total_loss = loss_cos * alpha_attn + loss_preference * alpha_preference
    else:
        total_loss = loss_preference

    return total_loss, pred_x0_dno_no_grad

def optimize_step_LitE_second(args, latents, pipeline, t, pred_x0_prev, prompt_embeds,
                       do_classifier_free_guidance, cross_attention_kwargs, preference_model_fn, extra_info,
                       generator, contrast_pairs, extra_step_kwargs, extra_lite_kwargs):
    """Evaluate perturbed candidate latents against the first-pass loss.

    The UNet is run on ``latents`` (``par_num`` candidates stacked on the batch
    axis) and a DDIM step is taken. The candidates are then scored with
    ``compute_custom_loss`` and compared with the first-pass loss via
    ``delta_loss = total_loss - total_loss_0``; negative means improvement.

    Args:
        args: run configuration. Reads ``cfg_scale``, ``cfg_rescale``,
            ``flag_dpo_loss`` and ``opt_stength``, plus whatever
            ``compute_custom_loss`` reads.
        latents: candidate ``x_t`` latents, batch ``par_num``.
        prompt_embeds: text embeddings. Under CFG they are the unconditional
            and conditional embeddings in the order ``noise_pred.chunk(2)``
            below assumes.
        extra_info: must contain ``'p'``, ``'l2_weight'`` and ``'range_i'``,
            plus whatever ``compute_custom_loss`` reads.
        extra_lite_kwargs: must contain ``'par_num'``, ``'travel_iter'``,
            ``'latent_prev_0'``, ``'total_loss_0'`` and ``'reward_judger'``.
        Other arguments are forwarded to the UNet, scheduler and
        ``compute_custom_loss``.

    Returns:
        One of two tuple shapes; callers tell them apart by unpacking.

        * 3-tuple ``(latents_prev_sample, pred_x0_latents, delta_latent)`` when a
          selection decision is made:

          - Single candidate, ``travel_iter > 0``, mean ``delta_loss < 0``: the
            candidate is accepted.
          - Otherwise, single candidate with ``travel_iter == range_i``: both
            latents outputs fall back to ``latent_prev_0`` and
            ``delta_latent`` is ``None``.
          - Several candidates and ``travel_iter > 0``: the candidate with the
            smallest ``delta_loss`` is accepted if that value is negative.
            Otherwise the same fallback applies.

        * 5-tuple ``(latents_prev_sample, pred_x0_latents, log_prob,
          grad_direction, pred_x0_dno_no_grad)`` otherwise. Here
          ``latents_prev_sample`` has taken one gradient step on
          ``delta_loss + l2_loss * (1 - p) * l2_weight``.

    Raises:
        RuntimeError: if a gradient step with a non-zero gradient is needed but
            classifier-free guidance is off. The step size is scaled by the CFG
            correction norm, which then does not exist.
    """
    latents_grad = latents.detach().requires_grad_(True)
    p = extra_info['p']
    l2_weight = extra_info['l2_weight']
    par_num = extra_lite_kwargs['par_num']
    travel_iter = extra_lite_kwargs['travel_iter']
    latent_prev_0 = extra_lite_kwargs['latent_prev_0']
    total_loss_0 = extra_lite_kwargs['total_loss_0']
    reward_judger = extra_lite_kwargs['reward_judger']
    with torch.enable_grad():
        latent_model_input = torch.cat([latents_grad] * 2) if do_classifier_free_guidance else latents_grad
        latent_model_input = pipeline.scheduler.scale_model_input(latent_model_input, t)
        # Under CFG the input is [cand_0..cand_{P-1}, cand_0..cand_{P-1}]. The
        # embeddings must therefore be [emb_0 x P, emb_1 x P], which
        # repeat_interleave gives. A plain .repeat would give
        # [emb_0, emb_1, emb_0, ...] and mis-pair candidates with embeddings
        # whenever par_num > 1. Without CFG (one embedding) both are identical.
        # NOTE(review): confirm aggregate_seperate_attention expects this
        # standard [uncond x P, cond x P] batch order.
        prompt_embeds_input = prompt_embeds.repeat_interleave(par_num, dim=0)
        noise_pred = pipeline.unet(
            latent_model_input,
            t,
            encoder_hidden_states=prompt_embeds_input,
            cross_attention_kwargs=cross_attention_kwargs,
            return_dict=False,
        )[0]
        if par_num == 1:
            attention_maps = pipeline.attention_store.aggregate_attention(
                from_where=("up", "down", "mid"),
            )
        else:
            attention_maps = pipeline.attention_store.aggregate_seperate_attention(
                from_where=("up", "down", "mid"),
            )
        attention_maps_list = get_attention_maps_list(attention_maps=attention_maps)

        correction = None
        if do_classifier_free_guidance:
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + args.cfg_scale * (noise_pred_text - noise_pred_uncond)
            correction = noise_pred_text - noise_pred_uncond
            if args.cfg_rescale > 0.0:
                noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text,
                                               guidance_rescale=args.cfg_rescale)

        latents_prev_sample, pred_x0_latents, log_prob = ddim_step_with_logprob_LitE(
            pipeline.scheduler, noise_pred, t, latents_grad,
            variance_noise=None, **extra_step_kwargs)
        reward_judger.update(pred_x0_latents)
        # Per-candidate Euclidean distance from the first-pass x_{t-1}.
        diff = latent_prev_0.repeat(par_num, 1, 1, 1) - latents_prev_sample
        l2_loss = torch.sqrt(torch.sum(diff ** 2, dim=(-1, -2, -3)))

        total_loss, pred_x0_dno_no_grad = compute_custom_loss(
            pipeline,
            pred_x0_latents,
            generator,
            DiffAugment,
            POLICY,
            extra_info,
            t,
            preference_model_fn,
            contrast_pairs,
            attention_maps_list,
            pred_x0_prev,
            args
        )

        # Selection: negative delta_loss means the candidate beat the first pass.
        delta_loss = total_loss - total_loss_0
        if delta_loss.mean() < 0 and par_num == 1 and travel_iter > 0:
            delta_latent = latents_prev_sample - latent_prev_0
            print('\n------single updated-----')
            return latents_prev_sample.detach(), pred_x0_latents.detach(), delta_latent.detach()
        elif par_num == 1 and travel_iter == extra_info['range_i']:
            print('\n------fail to updated-----')
            return latent_prev_0.detach(), latent_prev_0.detach(), None
        elif par_num > 1 and travel_iter > 0:
            best = torch.argmin(delta_loss).item()
            if delta_loss[best] < 0:
                print('\n------mont carlo updated-----')
                pred_x0_latents = pred_x0_latents[best].unsqueeze(0)
                latents_prev_sample = latents_prev_sample[best].unsqueeze(0)
                delta_latent = (latents_prev_sample - latent_prev_0).detach()
            else:
                print('\n------fail to updated-----')
                pred_x0_latents = latent_prev_0
                latents_prev_sample = latent_prev_0
                delta_latent = None
            return latents_prev_sample.detach(), pred_x0_latents.detach(), delta_latent

        print('delta_loss', delta_loss)
        print('l2_loss', l2_loss)
        # Penalize drifting from the first-pass x_{t-1}; the weight shrinks as p -> 1.
        delta_loss = delta_loss + l2_loss * (1 - p) * l2_weight
        grad_direction = torch.autograd.grad(outputs=delta_loss.mean(), inputs=latents_grad)[0]

        grad_norm = (grad_direction * grad_direction).sum().sqrt().item()
        if grad_norm == 0.0:
            # A zero gradient gives a zero update for any finite rho; this
            # avoids dividing by zero.
            rho = 0.01
        else:
            if correction is None:
                raise RuntimeError(
                    "optimize_step_LitE_second needs classifier-free guidance "
                    "(cfg_scale > 1) to scale the latent update"
                )
            correction_norm = (correction * correction).sum().sqrt().item()
            if args.flag_dpo_loss:
                rho = correction_norm * args.cfg_scale / grad_norm * 0.2
            else:
                # Step norm equals opt_stength * ||CFG correction||.
                rho = correction_norm / grad_norm

        latents_prev_sample = latents_prev_sample - args.opt_stength * rho * grad_direction.detach()

        return latents_prev_sample.detach(), pred_x0_latents.detach(), log_prob.detach(), grad_direction.detach(), pred_x0_dno_no_grad




def optimize_step_LitE_first(args, latents, pipeline, t, pred_x0_prev, prompt_embeds,
                       do_classifier_free_guidance, cross_attention_kwargs, preference_model_fn, extra_info,
                       generator, contrast_pairs, extra_step_kwargs, delta_latent=None):
    """Run the reference (first) guided DDIM step for timestep ``t``.

    Processing steps:

    1. Predict noise with the UNet, applying classifier-free guidance when
       enabled.
    2. Take a DDIM step to obtain ``x_{t-1}`` and the x0 estimate.
    3. Score the x0 estimate with ``compute_custom_loss``.
    4. Take one gradient step on ``x_{t-1}``, scaled by the ratio of the CFG
       correction norm to the gradient norm (times ``cfg_scale * 0.2`` for the
       DPO loss).

    Args:
        args: run configuration. Reads ``cfg_scale``, ``cfg_rescale``,
            ``flag_dpo_loss``, ``opt_stength``, ``flag_round_stop`` and
            ``num_round``, plus whatever ``compute_custom_loss`` reads.
        latents: current ``x_t`` latents.
        prompt_embeds: text embeddings for the UNet.
        Other arguments are forwarded to the UNet, scheduler and
        ``compute_custom_loss``.
        delta_latent: accepted for signature compatibility; unused.

    Returns:
        ``(latents_prev_0, total_loss, latents_prev_sample, pred_x0_latents,
        log_prob, grad_direction, pred_x0_dno_no_grad)``:

        - ``latents_prev_0`` is the un-optimized DDIM ``x_{t-1}``.
        - ``total_loss`` is returned without detaching.
        - ``latents_prev_sample`` is ``x_{t-1}`` after the gradient step.

    Side effects:
        If ``args.flag_round_stop`` is set, ``args.num_optimization`` is set to
        ``int(||grad|| * args.num_round)`` and the flag is cleared.

    Raises:
        RuntimeError: if the gradient is non-zero but classifier-free guidance
            is off, because there is then no correction norm to scale by.
    """
    latents_grad = latents.detach().requires_grad_(True)
    with torch.enable_grad():
        latent_model_input = torch.cat([latents_grad] * 2) if do_classifier_free_guidance else latents_grad
        latent_model_input = pipeline.scheduler.scale_model_input(latent_model_input, t)
        noise_pred = pipeline.unet(
            latent_model_input,
            t,
            encoder_hidden_states=prompt_embeds,
            cross_attention_kwargs=cross_attention_kwargs,
            return_dict=False,
        )[0]

        attention_maps = pipeline.attention_store.aggregate_attention(
            from_where=("up", "down", "mid"),
        )
        attention_maps_list = get_attention_maps_list(attention_maps=attention_maps)

        correction = None
        if do_classifier_free_guidance:
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + args.cfg_scale * (noise_pred_text - noise_pred_uncond)
            correction = noise_pred_text - noise_pred_uncond
            if args.cfg_rescale > 0.0:
                noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text,
                                               guidance_rescale=args.cfg_rescale)

        latents_prev_sample, pred_x0_latents, log_prob = ddim_step_with_logprob_LitE(
            pipeline.scheduler, noise_pred, t, latents_grad,
            variance_noise=None, **extra_step_kwargs)
        # Snapshot of the un-optimized x_{t-1}, returned as latents_prev_0.
        latents_prev_0 = latents_prev_sample.clone()

        total_loss, pred_x0_dno_no_grad = compute_custom_loss(
            pipeline,
            pred_x0_latents,
            generator,
            DiffAugment,
            POLICY,
            extra_info,
            t,
            preference_model_fn,
            contrast_pairs,
            attention_maps_list,
            pred_x0_prev,
            args
        )
        grad_direction = torch.autograd.grad(outputs=total_loss, inputs=latents_grad)[0]

        grad_norm = (grad_direction * grad_direction).sum().sqrt().item()
        if grad_norm == 0.0:
            # A zero gradient gives a zero update for any finite rho; this
            # avoids dividing by zero.
            rho = 0.01
        else:
            if correction is None:
                raise RuntimeError(
                    "optimize_step_LitE_first needs classifier-free guidance "
                    "(cfg_scale > 1) to scale the latent update"
                )
            correction_norm = (correction * correction).sum().sqrt().item()
            if args.flag_dpo_loss:
                rho = correction_norm * args.cfg_scale / grad_norm * 0.2
            else:
                rho = correction_norm / grad_norm

        latents_prev_sample = latents_prev_sample - args.opt_stength * rho * grad_direction.detach()

        if args.flag_round_stop:
            grad_abs = torch.norm(grad_direction.detach(), p=2) * args.num_round
            args.num_optimization = int(grad_abs)
            args.flag_round_stop = False

        return latents_prev_0.detach(), total_loss, latents_prev_sample.detach(), pred_x0_latents.detach(), log_prob.detach(), grad_direction.detach(), pred_x0_dno_no_grad




def optimize_step_LitE(args, latents, pipeline, t, pred_x0_prev, prompt_embeds,
                       do_classifier_free_guidance, cross_attention_kwargs, preference_model_fn, extra_info,
                       generator, contrast_pairs, extra_step_kwargs, delta_latent=None):
    """Single-pass guided DDIM step with an optional inherited perturbation.

    Processing steps:

    1. Predict noise (with CFG when enabled) and take a DDIM step.
    2. If ``delta_latent`` is given, add ``delta_latent * p`` to ``x_{t-1}``.
    3. Score the decoded x0 estimate plus 10 DiffAugment copies with
       ``preference_model_fn``.
    4. Combine that score with an attention-contrast term.
    5. Take one gradient step on ``x_{t-1}``.

    Args:
        args: run configuration. Reads ``cfg_scale``, ``cfg_rescale``,
            ``flag_dpo_loss``, ``flag_attn_loss``, ``num_optimization``,
            ``opt_stength``, ``flag_round_stop`` and ``num_round``.
        latents: current ``x_t`` latents.
        extra_info: must contain ``'p'`` and ``'input_ids'``.
            ``'timesteps'`` is overwritten in place. ``'input_ids'`` is left
            unchanged.
        contrast_pairs: dict with ``'pos'`` / ``'neg'`` lists of token-index
            pairs. Pairs with an index >= 77 are skipped.
        delta_latent: optional perturbation added to ``x_{t-1}`` (scaled by
            ``p``). ``None`` means no perturbation.
        Other arguments are forwarded to the UNet, scheduler and VAE.

    Returns:
        ``(latents_prev_sample, pred_x0_latents, log_prob, grad_direction,
        pred_x0_dno_no_grad)``, all detached.

    Side effects:
        If ``args.flag_round_stop`` is set, ``args.num_optimization`` is set to
        ``int(||grad|| * args.num_round)`` and the flag is cleared.

    Raises:
        RuntimeError: if the gradient is non-zero but classifier-free guidance
            is off, because there is then no correction norm to scale by.
    """
    latents_grad = latents.detach().requires_grad_(True)
    p = extra_info['p']
    device = latents_grad.device
    with torch.enable_grad():
        latent_model_input = torch.cat([latents_grad] * 2) if do_classifier_free_guidance else latents_grad
        latent_model_input = pipeline.scheduler.scale_model_input(latent_model_input, t)
        noise_pred = pipeline.unet(
            latent_model_input,
            t,
            encoder_hidden_states=prompt_embeds,
            cross_attention_kwargs=cross_attention_kwargs,
            return_dict=False,
        )[0]

        attention_maps = pipeline.attention_store.aggregate_attention(
            from_where=("up", "down", "mid"),
        )
        attention_maps_list = get_attention_maps_list(attention_maps=attention_maps)

        correction = None
        if do_classifier_free_guidance:
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + args.cfg_scale * (noise_pred_text - noise_pred_uncond)
            correction = noise_pred_text - noise_pred_uncond
            if args.cfg_rescale > 0.0:
                noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text,
                                               guidance_rescale=args.cfg_rescale)

        latents_prev_sample, pred_x0_latents, log_prob = ddim_step_with_logprob_LitE(
            pipeline.scheduler, noise_pred, t, latents_grad,
            variance_noise=None, **extra_step_kwargs)

        # Apply the inherited perturbation only when one was supplied.
        if delta_latent is not None:
            latents_prev_sample = delta_latent * p + latents_prev_sample

        pred_x0_dno = pipeline.vae.decode(
            pred_x0_latents.to(pipeline.vae.dtype) / pipeline.vae.config.scaling_factor,
            return_dict=False,
            generator=generator,
        )[0]
        pred_x0_dno_no_grad = pred_x0_dno.clone().detach()

        replicate = 10
        pred_x0_dno_images = DiffAugment(pred_x0_dno.repeat(replicate, 1, 1, 1), policy=POLICY)
        pred_x0_dno_images = torch.cat([pred_x0_dno, pred_x0_dno_images], 0)
        num_images = pred_x0_dno_images.shape[0]
        extra_info['timesteps'] = t.repeat(num_images)
        # Pass tiled input_ids via a merged dict. Overwriting
        # extra_info['input_ids'] would make it grow on every call that reuses
        # the same dict.
        preference_model_input_ids = torch.cat([extra_info['input_ids']] * num_images, dim=0)
        loss_preference = preference_model_fn(
            pred_x0_dno_images, {**extra_info, 'input_ids': preference_model_input_ids}
        )
        loss_preference = -torch.linalg.norm(loss_preference)

        # Positive/negative attention-contrast loss.
        positive_groups, negative_pairs = contrast_pairs["pos"], contrast_pairs["neg"]
        loss_pos = torch.tensor(0.0, device=device)
        for i, j in positive_groups:
            if i >= 77 or j >= 77:
                continue
            loss_pos = loss_pos + cos_dist(attention_maps_list[i], attention_maps_list[j])

        loss_neg = torch.tensor(0.0, device=device)
        for i, j in negative_pairs:
            if i >= 77 or j >= 77:
                continue
            loss_neg = loss_neg - cos_dist(attention_maps_list[i], attention_maps_list[j])

        loss_cos = loss_pos / (len(positive_groups) + 1e-6) + loss_neg / (len(negative_pairs) + 1e-6)

        # Default weights when there is no previous x0 to compare against
        # (same defaults as compute_custom_loss).
        alpha_attn = 1.0
        alpha_preference = 1.0
        if pred_x0_prev is not None:
            prev_normed = normal(pred_x0_prev)
            delta_content_relative = torch.norm(normal(pred_x0_latents) - prev_normed, p=2) / torch.norm(prev_normed, p=2)
            k = 1
            alpha_attn = 10 * (1 - torch.exp(-k * delta_content_relative))
            alpha_preference = torch.exp(-k * delta_content_relative)

        if args.flag_dpo_loss:
            total_loss = loss_cos * 10
        elif args.flag_attn_loss:
            total_loss = loss_cos * alpha_attn + loss_preference * alpha_preference
        else:
            total_loss = loss_preference
        if args.num_optimization > 1:
            print(total_loss)
        grad_direction = torch.autograd.grad(outputs=total_loss, inputs=latents_grad)[0]

        grad_norm = (grad_direction * grad_direction).sum().sqrt().item()
        if grad_norm == 0.0:
            # A zero gradient gives a zero update for any finite rho; this
            # avoids dividing by zero.
            rho = 0.01
        else:
            if correction is None:
                raise RuntimeError(
                    "optimize_step_LitE needs classifier-free guidance "
                    "(cfg_scale > 1) to scale the latent update"
                )
            correction_norm = (correction * correction).sum().sqrt().item()
            if args.flag_dpo_loss:
                rho = correction_norm * args.cfg_scale / grad_norm * 0.2
            else:
                rho = correction_norm / grad_norm

        latents_prev_sample = latents_prev_sample - args.opt_stength * rho * grad_direction.detach()

        if args.flag_round_stop:
            grad_abs = torch.norm(grad_direction.detach(), p=2) * args.num_round
            args.num_optimization = int(grad_abs)
            args.flag_round_stop = False

        return latents_prev_sample.detach(), pred_x0_latents.detach(), log_prob.detach(), grad_direction.detach(), pred_x0_dno_no_grad


def main(args, prompt, save_name, pipeline, preference_model_fn, contrast_pairs):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # random seed setting
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    os.environ['PYTHONHASHSEED'] = str(args.seed)


    pipeline.to(device)
    pipeline.unet = pipeline.unet.eval()
    # freeze parameters of models to save more memory
    pipeline.vae.requires_grad_(False)
    pipeline.text_encoder.requires_grad_(False)
    pipeline.unet.requires_grad_(False)

    pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)
    pipeline.scheduler.alphas_cumprod = pipeline.scheduler.alphas_cumprod.to(device)

    generator = torch.Generator(device=device).manual_seed(args.seed)
    do_classifier_free_guidance = args.cfg_scale > 1.0

    # generate negative prompt embeddings
    neg_prompt_embed = pipeline.text_encoder(
        pipeline.tokenizer(
            [""],
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=pipeline.tokenizer.model_max_length,
        ).input_ids.to(device)
    )[0]

    prompt_embed = pipeline.text_encoder(
        pipeline.tokenizer(
            prompt,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=pipeline.tokenizer.model_max_length,
        ).input_ids.to(device)
    )[0]

    prompt_input_ids = pipeline.tokenizer(
        prompt,
        max_length=pipeline.tokenizer.model_max_length,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    ).input_ids.to(device)

    # the size of generated image
    height = args.height
    width = args.width
    rewardjudger = RewardJudger()
    pipeline.scheduler.set_timesteps(args.num_timesteps, device=device)
    timesteps = pipeline.scheduler.timesteps
    disnet_p= Discriminator_P(64, label_len=15)
    disnet_p_path = 'Anonymous/checkpoints/checkpoint_10/disnet_p'
    disnet_p.load_state_dict(torch.load(disnet_p_path, map_location=device))  # ckpts/disnet_p
    disnet_p.to(device)
    disnet_p.eval()
    # Prepare latent variables
    num_channels_latents = pipeline.unet.config.in_channels

    height = height or pipeline.unet.config.sample_size * pipeline.vae_scale_factor
    width = width or pipeline.unet.config.sample_size * pipeline.vae_scale_factor

    cross_attention_kwargs = {}
    text_encoder_lora_scale = cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None

    prompt_embeds = pipeline._encode_prompt(
        None,
        device,
        1,
        do_classifier_free_guidance,
        None,
        prompt_embeds=prompt_embed,
        negative_prompt_embeds=neg_prompt_embed,
        lora_scale=text_encoder_lora_scale,
    )

    latents_ori = pipeline.prepare_latents(
        1,  # num_images_per_prompt
        num_channels_latents,
        height,
        width,
        prompt_embeds.dtype,
        device,
        generator,
        None
    )

    latents = latents_ori.clone().detach()
    latents_ori = latents_ori.clone().detach()

    eta = 1.0
    extra_step_kwargs = pipeline.prepare_extra_step_kwargs(generator, eta)
    extra_info = {}

    pred_x0_prev = None
    delta_latent = None
    save_inter_array = [0,10,20,30,35,40,45,47,49]
    with pipeline.progress_bar(total=timesteps.shape[0]) as progress_bar:
        for i, t in enumerate(timesteps):
            latents_prev_sample = None

            args.num_optimization = 1
            args.flag_optimize = True
            args.flag_round = True
            if i >= args.time2:
                args.flag_attn_loss = False
                args.flag_dpo_loss = False
                args.num_round = 3
            elif (i >= 0) and (i < args.time1):
                args.flag_dpo_loss = True
                args.flag_attn_loss = False
                args.num_round = 1
                if (contrast_pairs["pos"] == []) and (contrast_pairs["neg"] == []):
                    args.flag_optimize = False
            elif (i >= args.time1) and (i < args.time2):
                args.flag_dpo_loss = False
                args.flag_attn_loss = True
                args.num_round = 1
            else:
                args.flag_optimize = False

            if True:
                for i_optimize in range(1):
                    if args.flag_dno:
                        extra_info['input_ids'] = prompt_input_ids
                        extra_info['timesteps'] = t.repeat(latents.shape[0])
                        extra_info['eps'] = args.eps
                        extra_info['range_i'] = args.range_i
                        extra_info['par_num'] = args.par_num
                        extra_info['p_min'] = args.p_min
                        extra_info['l2_weight']=args.l2_weight
                        extra_info['p_max'] = args.p_max
                        with autocast():
                            disoutput = disnet_p(latents).reshape(latents.shape[0], -1).squeeze()
                            #disoutput = disnet_p.normalize_and_clip(i, min=extra_info['p_min'], max=extra_info['p_max'])
                        p = (1 - disoutput)
                        extra_info['p'] = p

                        if p == 0 and args.p_min<args.p_max and RewardJudger.verify(latents):
                            break
                        latent_prev_0,total_loss_0,latents_prev_sample, pred_x0_latents, _, grad_direction,pred_x0_dno_no_grad_0 = optimize_step_LitE_first(
                            args,
                            latents,
                            pipeline, t,
                            pred_x0_prev,
                            prompt_embeds,
                            do_classifier_free_guidance,
                            cross_attention_kwargs,
                            preference_model_fn,
                            extra_info,
                            generator,
                            contrast_pairs,
                            extra_step_kwargs)

                        pred_x0_prev = pred_x0_latents
                        if args.range_i>0:
                            if delta_latent is None or i<args.inh_begin:
                                extra_info['par_num']=1
                                latents_prev_sample = latents.repeat(extra_info['par_num'], 1, 1, 1)
                                delta_latent = torch.randn_like(latents_prev_sample, requires_grad=False) * extra_info[
                                    'eps']
                                par_num = extra_info['par_num']
                            elif delta_latent is None or i<args.mc_begin:
                                latents_prev_sample = latents.repeat(extra_info['par_num'], 1, 1, 1)
                                delta_latent = torch.randn_like(latents_prev_sample, requires_grad=False) * extra_info[
                                    'eps']
                                par_num = extra_info['par_num']
                            else:
                                delta_latent = delta_latent.detach() * extra_info['eps']
                                par_num = 1
                            latents_prev_sample = latents_prev_sample+p*delta_latent
                            extra_lite_kwargs={}
                            extra_lite_kwargs['par_num']=par_num
                            extra_lite_kwargs['latent_prev_0'] = latent_prev_0
                            extra_lite_kwargs['total_loss_0'] = total_loss_0
                            extra_lite_kwargs['pred_x0_latents'] = pred_x0_latents
                            extra_lite_kwargs['reward_judger'] = reward_judger
                            for travel_iter in range(args.range_i+1):
                                extra_lite_kwargs['travel_iter'] = travel_iter
                                latents = ddim_step_forward_from_xt_1_to_xt(pipeline.scheduler, t, latents_prev_sample, generator)
                                extra_info['input_ids'] = prompt_input_ids
                                extra_info['timesteps'] = t.repeat(latents.shape[0])

                                result_package = optimize_step_LitE_second(
                                    args,
                                    latents,
                                    pipeline, t,
                                    pred_x0_prev,
                                    prompt_embeds,
                                    do_classifier_free_guidance,
                                    cross_attention_kwargs,
                                    preference_model_fn,
                                    extra_info,
                                    generator,
                                    contrast_pairs,
                                    extra_step_kwargs, extra_lite_kwargs)
                                try:
                                    latents_prev_sample, pred_x0_latents, delta_latent = result_package
                                    pred_x0_prev =  extra_lite_kwargs['pred_x0_latents']
                                    break
                                except Exception as e:
                                    latents_prev_sample, pred_x0_latents, _, grad_direction, pred_x0_dno_no_grad = result_package
                                    pred_x0_prev = pred_x0_latents

                            #latents = ddim_step_forward_from_xt_1_to_xt(pipeline.scheduler, t, latents_prev_sample, generator)
            if latents_prev_sample is not None:
                latents = latents_prev_sample
            else:
                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                latent_model_input = pipeline.scheduler.scale_model_input(latent_model_input, t)
                noise_pred = pipeline.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=prompt_embeds,
                    cross_attention_kwargs=cross_attention_kwargs,
                    return_dict=False,
                )[0]

                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + args.cfg_scale * (noise_pred_text - noise_pred_uncond)

                if do_classifier_free_guidance and args.cfg_rescale > 0.0:
                    noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text,
                                                   guidance_rescale=args.cfg_rescale)
                latents, latent_pred_x0, _ = ddim_step_with_logprob_LitE(pipeline.scheduler, noise_pred, t, latents, **extra_step_kwargs)

                pred_x0_prev = latent_pred_x0
            '''
            image = pipeline.vae.decode(pred_x0_prev.to(pipeline.vae.dtype) / pipeline.vae.config.scaling_factor,
                            return_dict=False,
                            generator=generator, )[0]
            image = pipeline.image_processor.postprocess(image, output_type="pt",
                                                         do_denormalize=[True] * image.shape[0])
            Image.fromarray((image[0].detach().cpu().numpy().transpose(1, 2, 0) * 255).astype(np.uint8)).save('temp.jpg')
            '''
            latent_model_input_ori = torch.cat([latents_ori] * 2) if do_classifier_free_guidance else latents_ori
            latent_model_input_ori = pipeline.scheduler.scale_model_input(latent_model_input_ori, t)
            noise_pred_ori = pipeline.unet(
                latent_model_input_ori,
                t,
                encoder_hidden_states=prompt_embeds,
                cross_attention_kwargs=cross_attention_kwargs,
                return_dict=False,
            )[0]
            if do_classifier_free_guidance:
                noise_pred_uncond_ori, noise_pred_text_ori = noise_pred_ori.chunk(2)
                noise_pred_ori = noise_pred_uncond_ori + args.cfg_scale * (noise_pred_text_ori - noise_pred_uncond_ori)
            if do_classifier_free_guidance and args.cfg_rescale > 0.0:
                noise_pred_ori = rescale_noise_cfg(noise_pred_ori, noise_pred_text_ori, guidance_rescale=args.cfg_rescale)
            latents_ori, latents_ori_x0, _ = ddim_step_with_logprob_LitE(pipeline.scheduler, noise_pred_ori, t,
                                                                        latents_ori,
                                                                        variance_noise=None,
                                                                        **extra_step_kwargs)
            if i in save_inter_array and args.save_inter == 1:
                image = pipeline.vae.decode(extra_lite_kwargs['pred_x0_latents'].to(pipeline.vae.dtype) / pipeline.vae.config.scaling_factor,
                                            return_dict=False,
                                            generator=generator, )[0]
                image = pipeline.image_processor.postprocess(image, output_type="pt",
                                                             do_denormalize=[True] * image.shape[0])

                image_ori = pipeline.vae.decode(latents_ori_x0.to(pipeline.vae.dtype) / pipeline.vae.config.scaling_factor,
                                                return_dict=False,
                                                generator=generator, )[0]
                image_ori = pipeline.image_processor.postprocess(image_ori, output_type="pt",
                                                                 do_denormalize=[True] * image_ori.shape[0])

                image = np.concatenate(((image_ori[0].detach().cpu().numpy().transpose(1, 2, 0) * 255).astype(np.uint8),
                                        (image[0].detach().cpu().numpy().transpose(1, 2, 0) * 255).astype(np.uint8)),
                                       axis=1)
                image = Image.fromarray(image)
                inter_path = save_name[:-4]+f'_{i}.jpg'
                image.save(inter_path)

            if i == len(timesteps) - 1 or ((i + 1) > 0 and (i + 1) % pipeline.scheduler.order == 0):
                progress_bar.update()
                torch.cuda.empty_cache()

    image = pipeline.vae.decode(latents.to(pipeline.vae.dtype) / pipeline.vae.config.scaling_factor, return_dict=False,
                                generator=generator, )[0]
    image = pipeline.image_processor.postprocess(image, output_type="pt", do_denormalize=[True] * image.shape[0])

    image_ori = pipeline.vae.decode(latents_ori.to(pipeline.vae.dtype) / pipeline.vae.config.scaling_factor, return_dict=False,
                                generator=generator, )[0]
    image_ori = pipeline.image_processor.postprocess(image_ori, output_type="pt", do_denormalize=[True] * image_ori.shape[0])

    image = np.concatenate(((image_ori[0].detach().cpu().numpy().transpose(1, 2, 0) * 255).astype(np.uint8), (image[0].detach().cpu().numpy().transpose(1, 2, 0) * 255).astype(np.uint8)), axis=1)
    image = Image.fromarray(image)
    image.save(save_name)



def parse_bool_arg(value):
    """Convert a command-line string into a bool, for use as argparse ``type=``.

    ``type=bool`` is a trap: ``bool("False")`` is ``True`` because every
    non-empty string is truthy, so e.g. ``--flag_dno False`` would silently
    keep the flag enabled. This parser understands the usual spellings.

    Args:
        value: The raw argument string (a ``bool`` is passed through as-is).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    # The empty string maps to False because that was the only way to get
    # False under the previous ``type=bool`` definition.
    if text in ('', '0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


def list_prompt_jsons(folder, skip=0):
    """Return the sorted ``*.json`` file names in ``folder``, minus the first ``skip``.

    Sorting happens *before* skipping: ``os.listdir`` returns entries in
    arbitrary order, so skipping first would drop a different set of files
    from run to run.

    Args:
        folder: Directory to scan (not recursive).
        skip: Number of leading names (after sorting) to drop; negatives act as 0.

    Returns:
        List of bare file names (not full paths) of regular ``.json`` files.
    """
    names = sorted(
        f for f in os.listdir(folder)
        if f.endswith('.json') and os.path.isfile(os.path.join(folder, f))
    )
    return names[max(skip, 0):]


if __name__ == '__main__':
    # ---- command-line interface ---------------------------------------------
    parser = argparse.ArgumentParser()
    parser.add_argument('--ckpt_id', default='runwayml/stable-diffusion-v1-5')
    parser.add_argument('--device', default='cuda')
    parser.add_argument(
        '--prompt',
        default='a photo',
    )
    parser.add_argument(
        '--prompt_json',
        default='./show_prompt_files/prompt6.json',
    )
    parser.add_argument(
        '--cfg_scale',
        default=7.5,
        type=float,
    )
    parser.add_argument(
        '--cfg_rescale',
        default=0.0,
        type=float,
    )
    parser.add_argument(
        '--output_filename',
        default='sdv1-5_RLHF_img.png',
    )
    parser.add_argument(
        '--seed',
        default=40,
        type=int,
    )
    parser.add_argument(
        '--num_timesteps',
        default=50,
        type=int,
    )
    parser.add_argument(
        '--height',
        default=512,
        type=int,
    )
    parser.add_argument(
        '--width',
        default=512,
        type=int,
    )
    parser.add_argument(
        '--num_optimization',
        default=1,
        type=int,
    )
    parser.add_argument(
        '--flag_optimize',
        default=True,
        type=parse_bool_arg,
    )
    parser.add_argument(
        '--opt_stength',
        default=1,
        type=int,
    )
    # Noise Optimize setting
    parser.add_argument(
        '--flag_dno',
        default=True,
        type=parse_bool_arg,
    )
    parser.add_argument(
        '--flag_round',
        default=True,
        type=parse_bool_arg,
    )
    parser.add_argument(
        '--flag_round_stop',
        default=True,
        type=parse_bool_arg,
    )
    parser.add_argument(
        '--num_round',
        default=1,
        type=int,
    )
    parser.add_argument(
        '--flag_dpo_loss',
        default=False,
        type=parse_bool_arg,
    )
    parser.add_argument(
        '--flag_attn_loss',
        default=True,
        type=parse_bool_arg,
    )
    parser.add_argument(
        '--time1',
        default=8,
        type=int,
    )
    parser.add_argument(
        '--time2',
        default=25,
        type=int,
    )
    parser.add_argument(
        '--save_sub_dir',
        default='LitE_15',
        type=str,
    )
    parser.add_argument('--eps', default=0.3, type=float, help='Epsilon value')
    parser.add_argument('--range_i', default=2, type=int, help='Range index')
    parser.add_argument('--par_num', default=2, type=int, help='Number of parameters')
    parser.add_argument('--p_min', default=7, type=int, help='Minimum value for p')
    parser.add_argument('--p_max', default=40, type=int, help='Maximum value for p')
    parser.add_argument('--inh_begin', default=20, type=int, help='inherit begin')
    parser.add_argument('--mc_begin', default=30, type=int, help='Monte Carlo begin')
    parser.add_argument('--seed_num', default=1, type=int, help='seed number for diversity')
    parser.add_argument('--l2_weight', default=0.2, type=float, help='weight of the L2 term')
    parser.add_argument('--json_folder', default='test_caption/pick_caption', type=str, help='json_folder')
    parser.add_argument('--json_skip', default=5, type=int,
                        help='number of leading (sorted) json files in --json_folder to skip')
    parser.add_argument('--eps_batch', default=0, type=int, help='batch eps')
    parser.add_argument('--l2_batch', default=0, type=int, help='batch l2')
    parser.add_argument('--save_inter', default=0, type=int, help='save intermediate res')
    parser.add_argument('--seed_base', default=40, type=int,
                        help='first seed of the sweep (seeds seed_base .. seed_base+seed_num-1)')

    args = parser.parse_args()
    nlp = spacy.load('en_core_web_sm')
    # Honour --device when CUDA is available; otherwise fall back to CPU.
    device = torch.device(args.device if torch.cuda.is_available() else "cpu")
    inference_dtype = torch.float16

    # ---- preference / reward model ------------------------------------------
    preference_model_func_cfg = dict(
        type="step_aware_preference_model_func",
        model_pretrained_model_name_or_path='yuvalkirstain/PickScore_v1',
        processor_pretrained_model_name_or_path='laion/CLIP-ViT-H-14-laion2B-s32B-b79K',
        ckpt_path='model_ckpts/sd-v1-5_step-aware_preference_model.bin',
        device=device,
        inference_dtype=inference_dtype,
    )
    preference_model_fn = get_preference_model_func(preference_model_func_cfg)

    # ---- diffusion pipeline -------------------------------------------------
    ckpt_id = args.ckpt_id

    pipeline = StableDiffusionPipeline.from_pretrained(
        ckpt_id,
        torch_dtype=inference_dtype
    )

    # Route every down/mid/up attention processor through the shared
    # AttentionStore so attention maps can be collected during sampling.
    attn_res = (16, 16)
    pipeline.attention_store = AttentionStore(attn_res)
    attn_procs = {}
    cross_att_count = 0
    for name in pipeline.unet.attn_processors.keys():
        if name.startswith("mid_block"):
            place_in_unet = "mid"
        elif name.startswith("up_blocks"):
            place_in_unet = "up"
        elif name.startswith("down_blocks"):
            place_in_unet = "down"
        else:
            continue
        cross_att_count += 1
        attn_procs[name] = AttendExciteAttnProcessor(
            attnstore=pipeline.attention_store, place_in_unet=place_in_unet
        )
    pipeline.unet.set_attn_processor(attn_procs)
    pipeline.attention_store.num_att_layers = cross_att_count

    # ---- sweep settings (fixed for the whole run, so built once) ------------
    if args.l2_batch > 0:
        # Rounded so directory names read "l20.6", not "l20.6000000000000001".
        l2_weight_array = [round(0.2 * k, 1) for k in range(7)]
    else:
        l2_weight_array = [args.l2_weight]
    if args.eps_batch > 0:
        eps_arr = [0.05, 0.2, 0.4, 0.7, 1.0]
    else:
        eps_arr = [args.eps]
    # basename(normpath(...)) also handles a trailing slash on --json_folder.
    json_folder_name = os.path.basename(os.path.normpath(args.json_folder))

    # ---- per-prompt generation ----------------------------------------------
    all_json = list_prompt_jsons(args.json_folder, args.json_skip)
    for json_name in all_json:
        print(f'doing {json_name}')
        args.prompt_json = json_name
        json_file_path = os.path.join(args.json_folder, args.prompt_json)
        with open(json_file_path, 'r', encoding='utf-8') as f:
            json_data = json.load(f)
        args.prompt = json_data["prompt"].lower()
        doc = nlp(args.prompt)
        try:
            tokens_dict = find_tokens_from_text(json_data, doc)
        except Exception as e:
            # Best-effort: a prompt whose tokens cannot be parsed is skipped,
            # but the reason is reported instead of being swallowed.
            print(f'skipping {json_name}: {e!r}')
            continue

        # Map each (key, related-tokens) pair onto prompt token indices.
        token_p_n_pairs = []
        for key in tokens_dict:
            pairs = [[key], tokens_dict[key]]
            match_list_indices = align_indices(pipeline, args.prompt, pairs)
            result = flatten_lists(match_list_indices)
            token_p_n_pairs.append(result)
        positive_pairs, negative_pairs = process_p_n_groups(token_p_n_pairs)
        p_n_dict = {"pos": positive_pairs, "neg": negative_pairs}
        if (p_n_dict["pos"] == []) and (p_n_dict["neg"] == []):
            continue

        pipeline.to(device)
        pipeline.text_encoder.requires_grad_(False)
        for seed in range(args.seed_base, args.seed_base + args.seed_num):
            for l2_weight in l2_weight_array:
                for eps in eps_arr:
                    args.l2_weight = l2_weight
                    args.eps = eps
                    json_stem = os.path.splitext(json_name)[0]
                    if args.seed_num > 1:
                        args.seed = seed
                        args.output_filename = f"seed{seed}_json{json_stem}.png"
                    else:
                        args.output_filename = f"json{json_stem}.png"
                    save_path = os.path.join("./results", args.save_sub_dir, json_folder_name,
                                             f'eps{args.eps}_l2{args.l2_weight}_{args.p_min}-{args.p_max}')
                    os.makedirs(save_path, exist_ok=True)
                    save_name = os.path.join(save_path, args.output_filename)
                    main(args, args.prompt, save_name, pipeline=pipeline,
                         preference_model_fn=preference_model_fn, contrast_pairs=p_n_dict)