from tqdm.auto import tqdm
import random
from PIL import Image

import torch as T
import transformers, diffusers
import imageio

import os
from typing import Optional, Union, Tuple, List, Callable, Dict
from tqdm import tqdm
import json
import torch
from diffusers import StableDiffusionPipeline, DDIMScheduler
import torch.nn.functional as nnf
import numpy as np
import abc
import shutil
from PIL import Image
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
import utils

import cv2
from skimage import img_as_ubyte
import matplotlib.pyplot as plt

from typing import List, Union, Tuple, Dict, Optional
import abc
import torch.nn.functional as nnf
import torch

import torch
import torchvision.transforms as transforms
from PIL import Image
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput

from llava.conversation import conv_templates
from llava.model import *

import sys

# Get the filename of the current script
script_name = os.path.basename(sys.argv[0])
# Fourth character from the end of the script name, i.e. the character right
# before a three-character extension such as ".py".  It is used further down
# as the CUDA device index (device = f"cuda:{_index}").
# NOTE(review): raises IndexError for names shorter than four characters and
# yields an invalid device string when that character is not a digit.
_index = script_name[-4]

# Folder holding the fine-tuned weights ("mllm.pt" and "unet.pt" are loaded
# from it below).
ckpt_folder = "./_ckpt/mgie_7b"
# Number of diffusion steps; the per-run step settings below are derived from it.
diffusion_steps = 10

# Recorded in the saved run configuration.
saved_video_fps = 15

# Recorded in the saved run configuration.
seed = 23

from edit_0_instruction import *

# [image folder, instruction, seed] triples for two "kun" clips.
# `robot_instruction` is not defined in this file — presumably it comes from
# the star import of `edit_0_instruction` above; TODO confirm.
image_instruction_seed_pairs = [
    [f"./dataset/video_frame_0/kun/kun-{idx}", instruction, 23]
    for idx in [1,2]
    for instruction in robot_instruction]


# Example instructions kept for reference.
_instruction_list = [
"Introduce a vibrant sunset or sunrise",
"Japanese woodblock print style",
"Include a waterfall flowing from the mountain",
"Apply a winter theme with snow-capped peaks"]

# The instruction actually used by the run loop below.
instruction = "Create a night scene with street lights."

# Folder with the extracted .jpg frames of the video to edit.
image_folder = "./dataset/video_frame_0/driving/driving-4"
# Frame indices used to build `group_images` below.
group_idx = [
        0, 75, 149
    ]

group_size = len(group_idx)
start_idx = 0
end_idx = 1769
# Every 50th frame index in [start_idx, end_idx); used to build `total_images`.
L = [one for one in range(start_idx, end_idx, 50)]

# Single-iteration loop: wraps one full editing run for `instruction`.
for instruction in [instruction]:

    use_mask = False
    use_dilation = False
    dilate_size = 0
    latent_starts_step = diffusion_steps - 1

    # GPU index comes from the script file name (see `_index` above).
    device = f"cuda:{_index}"

    # Last path component of the image folder, ignoring one trailing slash.
    if image_folder.endswith('/'):
        _last_part = os.path.basename(image_folder[:-1])
    else:
        _last_part = os.path.basename(image_folder)


    # Count the .jpg frames on disk and cap the count at 150.
    _total_frame_number = len([f for f in os.listdir(image_folder) if f.endswith('.jpg')])

    total_frame_number = min(_total_frame_number, 150)

    # Frame file names follow "<folder>_frame_<4-digit index>.jpg".
    group_images = [f"{_last_part}_frame_{idx:04d}.jpg" for idx in group_idx]


    total_images = [f"{_last_part}_frame_{idx:04d}.jpg" for idx in L]

    # Per-component step counts, each a fraction (0.0 or 1.0) of diffusion_steps.
    variable_step_dict = {'self_q': int(0.0*diffusion_steps),
                                'self_k': int(1.0*diffusion_steps),
                                'self_v': int(1.0*diffusion_steps),
                                'cross_q': int(0.0*diffusion_steps),
                                'cross_k': int(1.0*diffusion_steps),
                                'cross_v': int(1.0*diffusion_steps),
                                'latent': int(1.0*diffusion_steps)}

    attention_stored = False
    use_stored_config = False


    # Optional overrides: when `data` is a dict its entries replace the
    # module-level variables of the same name.  It is None here, so nothing
    # is overridden.
    data = None

    if data is not None:
        globals().update(data)


    # Re-run with a previously saved configuration (disabled above).
    # NOTE(review): the path is hard-coded to one earlier result folder, and
    # `group_images` is reduced to its first element afterwards — confirm
    # both are intended before enabling.
    if use_stored_config:
        file_path = './ml-mgie/editing_results/idx_00028_clip_-7TCS_e4EVg_0_Make it Wood Sculpture.-good/config.json'
        with open(file_path, 'r') as file:
            data = json.load(file)
        globals().update(data)
        group_images = group_images[0]

    # Variables to be saved
    variables_dict = {
        "device": device,
        "ckpt_folder": ckpt_folder,
        "diffusion_steps": diffusion_steps,
        "saved_video_fps": saved_video_fps,
        "instruction": instruction,
        "group_size": group_size,
        "image_folder": image_folder,
        "group_images": group_images,
        "total_images": total_images,
        "variable_step_dict": variable_step_dict,
        "attention_stored": attention_stored,
        "use_mask": use_mask,
        "latent_starts_step": latent_starts_step,
        "seed": seed
    }


    def crop_resize(f, sz=512):
        """Center-crop ``f`` to a square on its longer axis, then resize it.

        ``f`` must provide ``size`` (width, height), ``crop(box)`` and
        ``resize(size)`` the way a PIL image does.  Returns the resized
        ``sz`` x ``sz`` image.
        """
        width, height = f.size
        side = min(width, height)
        # Offsets that centre the square window; the shorter axis gets 0.
        left = (width - side) // 2
        top = (height - side) // 2
        if width != height:
            f = f.crop([left, top, left + side, top + side])
        return f.resize([sz, sz])
    def remove_alter(s):  # hack expressive instruction
        """Trim a raw model reply down to at most its first two sentences.

        Steps, in order: keep only the text after ``'ASSISTANT:'``, cut at
        ``'</s>'``, cut at the first case-insensitive ``'alternative'``, cut at
        ``'[IMG0]'``, keep the first two ``'.'``-separated pieces (stripped and
        re-joined with ``'.'``), and make sure the result ends with a period.

        Returns the cleaned string.  If nothing is left after trimming, the
        empty string is returned instead of raising ``IndexError``.
        """
        marker = 'ASSISTANT:'
        if marker in s:
            s = s[s.index(marker) + len(marker):].strip()
        if '</s>' in s:
            s = s[:s.index('</s>')].strip()
        lowered = s.lower()
        if 'alternative' in lowered:
            s = s[:lowered.index('alternative')]
        if '[IMG0]' in s:
            s = s[:s.index('[IMG0]')]
        s = '.'.join(part.strip() for part in s.split('.')[:2])
        if not s:
            # Nothing survived the trimming; there is no last character to test.
            return s
        if not s.endswith('.'):
            s += '.'
        return s.strip()


    from PIL import Image
    import numpy as np
    from torchvision import transforms
    from scipy.ndimage import grey_dilation

    # Define a custom transform that applies dilation
    class DilationTransform:
        """Callable transform that grey-dilates an image with a square footprint."""

        def __init__(self, size):
            # size x size footprint of ones handed to scipy's grey_dilation.
            self.structure = np.full((size, size), 1, dtype=np.uint8)

        def __call__(self, img):
            """Return ``img`` dilated with the stored footprint, as a PIL image."""
            pixels = np.array(img)
            return Image.fromarray(grey_dilation(pixels, footprint=self.structure))


    def create_result_folder(video_name, edit_instruction):
        """Create and return a fresh, numbered sub-folder of ``./editing_results``.

        The folder is named ``idx_<5-digit index>_<video_name>_<edit_instruction>``.
        The index is the larger of (a) the number of existing ``idx_*`` folders
        and (b) one more than the highest numeric index already present, so a
        new run never reuses an index that is still on disk, even after earlier
        result folders were deleted.

        Args:
            video_name: Name of the source video, embedded in the folder name.
            edit_instruction: Instruction text, embedded in the folder name.

        Returns:
            Path of the newly created folder (relative, under ``./editing_results``).
        """
        # Path to the 'editing_results' directory
        base_path = "./editing_results"

        # exist_ok avoids the check-then-create race of a separate exists() test.
        os.makedirs(base_path, exist_ok=True)

        # Only directories whose name starts with 'idx_' take part in numbering.
        idx_folders = [
            d for d in os.listdir(base_path)
            if d.startswith('idx_') and os.path.isdir(os.path.join(base_path, d))
        ]

        highest = -1
        for folder in idx_folders:
            digits = folder[len('idx_'):].split('_', 1)[0]
            if digits.isdecimal():
                highest = max(highest, int(digits))
        idx_count = max(len(idx_folders), highest + 1)

        # New folder name with zero-padded index
        new_folder_name = f"idx_{str(idx_count).zfill(5)}_{video_name}_{edit_instruction}"
        new_folder_path = os.path.join(base_path, new_folder_name)

        os.makedirs(new_folder_path, exist_ok=True)

        print(f"New folder created: {new_folder_path}")
        return new_folder_path

    # Create the output folder for this run, named after the video folder and
    # the instruction.
    result_subfolder = create_result_folder(_last_part, instruction)

    json_file_path = os.path.join(result_subfolder, 'config.json')
    
    # Persist the run configuration next to the results.
    # NOTE(review): the bare `except` drops into the debugger on any failure
    # (including KeyboardInterrupt) and then execution continues, so the
    # message below is printed even if nothing was written.
    try:
        with open(json_file_path, 'w') as json_file:
            json.dump(variables_dict, json_file, indent=4)
    except:
        breakpoint()

    print(f"Variables have been saved to JSON at {json_file_path}")

    # Special token strings and the local LLaVA checkpoint path.
    DEFAULT_IMAGE_TOKEN = '<image>'
    DEFAULT_IMAGE_PATCH_TOKEN = '<im_patch>'
    DEFAULT_IM_START_TOKEN = '<im_start>'
    DEFAULT_IM_END_TOKEN = '<im_end>'
    PATH_LLAVA = './_ckpt/LLaVA-7B-v1'



    # Load tokenizer, language model (fp16, moved to `device`) and the CLIP
    # image processor named in the model config.
    tokenizer = transformers.AutoTokenizer.from_pretrained(PATH_LLAVA)
    model = LlavaLlamaForCausalLM.from_pretrained(PATH_LLAVA, low_cpu_mem_usage=True, torch_dtype=T.float16, use_cache=True).to(device)
    image_processor = transformers.CLIPImageProcessor.from_pretrained(model.config.mm_vision_tower, torch_dtype=T.float16)

    # Register the eight [IMG*] special tokens, grow the embedding table to
    # match, then overlay the weights from mllm.pt.  strict=False tolerates
    # keys that do not match the model (the checkpoint also carries 'emb',
    # read below).
    tokenizer.padding_side = 'left'
    tokenizer.add_tokens(['[IMG0]', '[IMG1]', '[IMG2]', '[IMG3]', '[IMG4]', '[IMG5]', '[IMG6]', '[IMG7]'], special_tokens=True)
    model.resize_token_embeddings(len(tokenizer))
    ckpt = T.load(os.path.join(ckpt_folder, "mllm.pt"), map_location='cpu')
    model.load_state_dict(ckpt, strict=False)

    # Image patch token, plus start/end tokens when the model config asks for them.
    mm_use_im_start_end = getattr(model.config, 'mm_use_im_start_end', False)
    tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
    if mm_use_im_start_end: tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)

    # Replace the model's vision tower with a freshly loaded fp16 CLIP vision
    # model on `device`, and record the token ids on its config.
    vision_tower = model.get_model().vision_tower[0]
    vision_tower = transformers.CLIPVisionModel.from_pretrained(vision_tower.config._name_or_path, torch_dtype=T.float16, low_cpu_mem_usage=True).to(device)
    model.get_model().vision_tower[0] = vision_tower
    vision_config = vision_tower.config
    vision_config.im_patch_token = tokenizer.convert_tokens_to_ids([DEFAULT_IMAGE_PATCH_TOKEN])[0]
    vision_config.use_im_start_end = mm_use_im_start_end
    if mm_use_im_start_end: vision_config.im_start_token, vision_config.im_end_token = tokenizer.convert_tokens_to_ids([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN])
    # Number of patches per image: (image_size // patch_size) squared.
    image_token_len = (vision_config.image_size//vision_config.patch_size)**2

    _ = model.eval()
    # 'emb' entry of the checkpoint, moved to the run device.
    EMB = ckpt['emb'].to(device)
    # Output of the edit head for an all-zero (1, 8, 4096) fp16 input.
    with T.inference_mode(): NULL = model.edit_head(T.zeros(1, 8, 4096).half().to(device) , EMB)
    print('NULL:', NULL.shape)

    # InstructPix2Pix pipeline (fp16, safety checker disabled) whose UNet
    # weights are replaced by those in unet.pt.
    pipe = diffusers.StableDiffusionInstructPix2PixPipeline.from_pretrained('timbrooks/instruct-pix2pix', torch_dtype=T.float16, safety_checker=None).to(device)
    pipe.set_progress_bar_config(disable=True)
    pipe.unet.load_state_dict(T.load(os.path.join(ckpt_folder, "unet.pt"), map_location='cpu'))


    def getROI(mask, margin_ratio, thres=0.25):
        """Return the bounding box ``(x1, y1, x2, y2)`` of the non-zero mask area.

        The mask is converted with ``img_as_ubyte`` and binarised with
        ``cv2.threshold`` at ``thres``.  The extreme coordinates of the non-zero
        pixels form the box, which is then grown on every side by
        ``margin_ratio`` times its shorter edge and clipped to the mask size.
        The full mask extent is returned when there are no non-zero pixels or
        when the resulting box is empty in either direction.
        """
        mask = cv2.threshold(img_as_ubyte(mask, force_copy=True), thres, 1, cv2.THRESH_BINARY)[1]
        height, width = mask.shape[0], mask.shape[1]
        full_frame = (0, 0, width, height)

        points = cv2.findNonZero(mask)
        if points is None:
            x11, y11, x21, y21 = full_frame
        else:
            # minMaxLoc returns (min, max, min_loc, max_loc); keep the values only.
            x11, x21 = cv2.minMaxLoc(points[:, :, 0])[:2]
            y11, y21 = cv2.minMaxLoc(points[:, :, 1])[:2]

        margin = int(margin_ratio * min(x21 - x11, y21 - y11))
        if margin > 0:
            x11, y11 = max(0, x11 - margin), max(0, y11 - margin)
            x21, y21 = min(width, x21 + margin), min(height, y21 + margin)

        # fixme: here should we use None?
        if x11 >= x21 or y11 >= y21:
            x11, y11, x21, y21 = full_frame

        return (int(x11), int(y11), int(x21), int(y21))


    def convert_image_to_binary_mask(path):
        """Load the image at ``path`` as greyscale and threshold it at 127.

        Returns a ``uint8`` array holding 1 where the grey value is above 127
        and 0 elsewhere.
        """
        from PIL import Image
        import numpy as np

        grey = np.array(Image.open(path).convert("L"))
        return (grey > 127).astype('uint8')


    def apply_soft_mask(mask):
        """Return a float32 copy of ``mask`` with a soft, widened border.

        The mask is dilated with a 10x10 elliptical kernel and blurred with a
        25x25 Gaussian; pixels equal to 1 in the input are then reset to 1.
        """
        kernel_size = 10
        blur_window = (25, 25)

        ellipse = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size))
        grown = cv2.dilate(mask, ellipse, iterations=1)
        soft_mask = cv2.GaussianBlur(grown.astype(np.float32), blur_window, 0)

        # The blur must not weaken the original foreground.
        soft_mask[mask == 1] = 1
        return soft_mask


    def masked_replace_with_rescale(A, B, mask):
        """Match the masked statistics of ``A`` to those of ``B``, in place.

        Mean and standard deviation are taken over the ``mask``-selected
        elements of each input.  ``A`` is standardised with its own statistics
        and rescaled with those of ``B``; only the masked elements of ``A``
        are overwritten.  Returns ``A`` (the same object).
        """
        inside_A, inside_B = A[mask], B[mask]
        mean_A, std_A = inside_A.mean(), inside_A.std()
        mean_B, std_B = inside_B.mean(), inside_B.std()

        A_rescaled = ((A - mean_A) / std_A) * std_B + mean_B
        A[mask] = A_rescaled[mask]
        return A



    class LocalBlend:
        """Blend the edited latents with the first (reference) latent under a mask.

        The mask comes either from averaged cross-attention maps of the blend
        words or from an external mask held on the module-level ``config``
        object.  ``__call__`` reads the module-level names ``config``,
        ``end_blend`` and ``apply_soft_mask``, which are not defined in this
        class and must exist before the first call.
        """

        def get_mask(self, x_t, maps, alpha, use_pool):
            """Turn stacked attention maps into a boolean mask at latent size.

            ``maps`` is weighted by ``alpha``, summed over its last axis and
            averaged over axis 1, optionally 3x3 max-pooled, resized to the
            spatial size of ``x_t``, normalised by its per-sample maximum and
            thresholded with ``self.th[0]`` (pooled) or ``self.th[1]``.
            """
            k = 1
            maps = (maps * alpha).sum(-1).mean(1)
            if use_pool:
                maps = nnf.max_pool2d(maps, (k * 2 + 1, k * 2 +1), (1, 1), padding=(k, k))
            mask = nnf.interpolate(maps, size=(x_t.shape[2:]))
            mask = mask / mask.max(2, keepdims=True)[0].max(3, keepdims=True)[0]
            mask = mask.gt(self.th[1-int(use_pool)])
            return mask

        def _substruct_keep_mask(self, x_t, maps):
            """Return the inverted, un-pooled mask of the ``substruct_words`` tokens."""
            # x_t must be passed along: get_mask needs it for the output size.
            return ~self.get_mask(x_t, maps, self.substruct_layers, False)

        @staticmethod
        def _masked_replace_with_rescale(A, B, mask):
            """Rescale the masked part of ``A`` to the masked mean/std of ``B`` (in place)."""
            mean_A = A[mask].mean()
            std_A = A[mask].std()
            mean_B = B[mask].mean()
            std_B = B[mask].std()

            A_normalized = (A - mean_A) / (std_A)
            A_rescaled = A_normalized * std_B + mean_B
            A[mask] = A_rescaled[mask]
            return A

        def __call__(self, x_t, attention_store):
            """Apply the blend for one step and return the updated latents.

            ``attention_store`` must provide the ``"down_cross"`` and
            ``"up_cross"`` lists of attention maps; five of them are reshaped
            to 16x16x77 and combined.
            """
            self.counter += 1

            # Blending is active from the first call; the value computed in
            # __init__ is overridden here.
            self.start_blend = -1
            self.end_blend = end_blend
            if self.counter > self.start_blend and self.counter < self.end_blend:

                maps = attention_store["down_cross"][2:4] + attention_store["up_cross"][:3]
                maps = [item.reshape(self.alpha_layers.shape[0], -1, 1, 16, 16, 77) for item in maps]
                maps = torch.cat(maps, dim=1)
                mask = self.get_mask(x_t, maps, self.alpha_layers, True)
                mask = mask[:1] + mask

                # Every 8th pixel (offset 3) of the external mask.
                # presumably a 512x512 mask sampled down to the 64x64 latent grid — TODO confirm
                mask_64 = config.external_mask[3::8, 3::8]

                if config.area_mask_soft != 0.0:
                    mask_64 = apply_soft_mask(mask_64.astype(np.float32))

                if self.substruct_layers is not None:
                    mask = mask * self._substruct_keep_mask(x_t, maps)
                mask = mask.float()

                if config.use_external_mask:
                    new_mask = torch.ones_like(mask)
                    # Stay on the latents' device rather than the default CUDA device.
                    new_mask[:,:] = torch.tensor(mask_64).to(mask.device)
                else:
                    new_mask = mask.clone().detach().to(mask.device)

                if config.time_step_soft !=0 and config.use_external_mask:
                    # Ramp the mask strength up with the step counter, capped at 1.
                    rate = min(self.counter / 100.0 * config.time_step_soft, 1.0)
                    new_mask = new_mask * rate

                if self.counter > -1 and config.latent_change_distribution:
                    x_t[1] = self._masked_replace_with_rescale(x_t[1], x_t[0], new_mask[0].repeat(4, 1, 1).bool())
                x_t = x_t[:1] + new_mask * (x_t - x_t[:1])
            # NOTE(review): this reads config.start_with_same_latent, not the
            # value stored on the instance by __init__ — confirm which is intended.
            if self.counter < config.start_with_same_latent:
                x_t[1] = x_t[0]
            return x_t

        @staticmethod
        def _word_layers(prompts, words, tokenizer):
            """Build a (len(prompts), 1, 1, 1, 1, 77) tensor with 1 at the tokens of ``words``."""
            layers = torch.zeros(len(prompts),  1, 1, 1, 1, 77)
            for i, (prompt, words_) in enumerate(zip(prompts, words)):
                if isinstance(words_, str):
                    words_ = [words_]
                for word in words_:
                    ind = utils.get_word_inds(prompt, word, tokenizer)
                    layers[i, :, :, :, :, ind] = 1
            return layers

        def __init__(self, prompts: List[str], words: List[List[str]], tokenizer, device, NUM_DDIM_STEPS,
                     substruct_words=None, start_blend=0.2, th=(.3, .3), start_with_same_latent=0):
            """Prepare the token-selection tensors for the blend words.

            Args:
                prompts: One prompt per batch element.
                words: Per prompt, the word (or words) whose attention drives the mask.
                tokenizer: Tokenizer passed on to ``utils.get_word_inds``.
                device: Device the selection tensors are moved to.
                NUM_DDIM_STEPS: Step count used to scale ``start_blend``.
                substruct_words: Optional per-prompt words whose attention is
                    removed from the mask.
                start_blend: Fraction of the steps stored as ``self.start_blend``.
                th: Thresholds for the pooled / un-pooled mask.
                start_with_same_latent: Stored on the instance.
            """
            self.start_with_same_latent = start_with_same_latent
            alpha_layers = self._word_layers(prompts, words, tokenizer)

            if substruct_words is not None:
                self.substruct_layers = self._word_layers(prompts, substruct_words, tokenizer).to(device)
            else:
                self.substruct_layers = None
            self.alpha_layers = alpha_layers.to(device)
            self.start_blend = int(start_blend * NUM_DDIM_STEPS)
            self.counter = 0
            self.th=th


    class AttentionControlEdit(abc.ABC):
        """Base controller that records attention maps and edits them across prompts.

        Subclasses implement ``replace_cross_attention``.  Several attributes
        read here (``LOW_RESOURCE``, ``use_mask_on_attention``, ``mask``,
        ``attention_change_distribution``) are not assigned in ``__init__`` and
        must be set by the subclass before the controller is used.
        ``num_att_layers`` starts at -1 and has to be overwritten externally.
        """

        @staticmethod
        def get_empty_store():
            """Return a fresh dict with one empty list per UNet place / attention kind."""
            return {"down_cross": [], "mid_cross": [], "up_cross": [],
                    "down_self": [],  "mid_self": [],  "up_self": []}

        @property
        def num_uncond_att_layers(self):
            """Number of leading attention calls to skip (only in low-resource mode)."""
            return self.num_att_layers if self.LOW_RESOURCE else 0

        def __call__(self, attn, is_cross: bool, place_in_unet: str):
            """Process one attention call and advance the layer / step counters."""
            if self.cur_att_layer >= self.num_uncond_att_layers:
                if self.LOW_RESOURCE:
                    attn = self.forward(attn, is_cross, place_in_unet)
                else:
                    # Only the second half of the batch is edited.
                    h = attn.shape[0]
                    attn[h // 2:] = self.forward(attn[h // 2:], is_cross, place_in_unet)
            self.cur_att_layer += 1
            if self.cur_att_layer == self.num_att_layers + self.num_uncond_att_layers:
                # All layers of this step were seen: start the next step.
                self.cur_att_layer = 0
                self.cur_step += 1
                self.between_steps()
            return attn

        def step_callback(self, x_t):
            """Run the local blend (if any) on the latents after a step."""
            if self.local_blend is not None:
                x_t = self.local_blend(x_t, self.attention_store)
            return x_t

        def replace_self_attention(self, attn_base, att_replace, place_in_unet):
            """Return the self-attention to use for the edited prompts.

            Maps whose ``shape[2]`` exceeds 32**2 are returned unchanged.
            Without masking, the base map is broadcast to every edited prompt.
            With ``use_mask_on_attention``, rows and columns selected by the
            down-sampled ``self.mask`` take the edited attention and the rest
            the base attention.
            """
            if self.use_mask_on_attention:

                def majority_pool(array, target_length):
                    # Shrink a flat boolean array to target_length by majority
                    # vote over equally sized consecutive segments.
                    segment_size = len(array) // target_length
                    compressed_array = np.zeros(target_length, dtype=bool)

                    for i in range(target_length):
                        segment = array[i * segment_size: (i + 1) * segment_size]
                        compressed_array[i] = np.sum(segment) > segment_size / 2

                    return compressed_array

                compressed_mask = majority_pool(self.mask.reshape(-1), attn_base.shape[1])
                if att_replace.shape[2] <= 32 ** 2:
                    self_attention_mask = torch.ones_like(attn_base[0])
                    self_attention_mask[compressed_mask, :] = 0
                    self_attention_mask[:, compressed_mask] = 0
                    if self.attention_change_distribution:
                        att_replace_ = masked_replace_with_rescale(att_replace, attn_base.unsqueeze(0), self_attention_mask.bool().unsqueeze(0).unsqueeze(0).repeat(1, attn_base.shape[0],1,1))
                        return attn_base * self_attention_mask  + att_replace_ * (1-self_attention_mask)
                    return attn_base * self_attention_mask  + att_replace * (1-self_attention_mask)
                else:
                    return att_replace
            else:
                if att_replace.shape[2] <= 32 ** 2:
                    attn_base = attn_base.unsqueeze(0).expand(att_replace.shape[0], *attn_base.shape)
                    return attn_base
                else:
                    return att_replace

        def between_steps(self):
            """Fold the maps recorded during the last step into the running sums."""
            if len(self.attention_store) == 0:
                self.attention_store = self.step_store
            else:
                for key in self.attention_store:
                    for i in range(len(self.attention_store[key])):
                        self.attention_store[key][i] += self.step_store[key][i]
            self.step_store = self.get_empty_store()

        def get_average_attention(self):
            """Return the accumulated maps divided by the number of finished steps."""
            average_attention = {key: [item / self.cur_step for item in self.attention_store[key]] for key in self.attention_store}
            return average_attention

        def reset(self):
            """Clear the counters and every stored attention map."""
            self.cur_step = 0
            self.cur_att_layer = 0

            self.step_store = self.get_empty_store()
            self.attention_store = {}

        @abc.abstractmethod
        def replace_cross_attention(self, attn_base, att_replace):
            """Return the cross-attention to blend into the edited prompts."""
            raise NotImplementedError

        def forward(self, attn, is_cross: bool, place_in_unet: str):
            """Record ``attn`` and, when editing is active, overwrite the edited prompts' maps."""
            key = f"{place_in_unet}_{'cross' if is_cross else 'self'}"
            if attn.shape[1] <= 32 ** 2:
                self.step_store[key].append(attn)

            # Cross-attention is always edited; self-attention only inside the
            # configured step window.
            if is_cross or (self.num_self_replace[0] <= self.cur_step < self.num_self_replace[1]):
                h = attn.shape[0] // (self.batch_size)
                attn = attn.reshape(self.batch_size, h, *attn.shape[1:])
                # Prompt 0 is the base; the remaining prompts are edited.
                attn_base, attn_replace = attn[0], attn[1:]
                if is_cross:
                    alpha_words = self.cross_replace_alpha[self.cur_step]
                    attn_replace_new = self.replace_cross_attention(attn_base, attn_replace) * alpha_words + (1 - alpha_words) * attn_replace
                    attn[1:] = attn_replace_new
                else:
                    attn[1:] = self.replace_self_attention(attn_base, attn_replace, place_in_unet)
                attn = attn.reshape(self.batch_size * h, *attn.shape[2:])
            return attn

        @staticmethod
        def _self_replace_window(num_steps, self_replace_steps):
            """Convert a fraction or a (start, end) fraction pair into integer step bounds.

            A bare number (float or int) is taken as the end fraction with the
            start at 0.
            """
            if isinstance(self_replace_steps, (int, float)):
                self_replace_steps = 0, self_replace_steps
            return int(num_steps * self_replace_steps[0]), int(num_steps * self_replace_steps[1])

        def __init__(self, prompts, num_steps: int,
                     cross_replace_steps: Union[float, Tuple[float, float], Dict[str, Tuple[float, float]]],
                     self_replace_steps: Union[float, Tuple[float, float]],
                     local_blend: Optional["LocalBlend"], tokenizer, device):
            """Set up counters, stores and the per-step replacement schedules.

            Args:
                prompts: Prompt batch; the first entry is the base prompt.
                num_steps: Number of diffusion steps.
                cross_replace_steps: Schedule passed to
                    ``utils.get_time_words_attention_alpha``.
                self_replace_steps: End fraction, or (start, end) fractions, of
                    the steps during which self-attention is replaced.
                local_blend: Optional blend object called from ``step_callback``.
                tokenizer: Tokenizer passed to the alpha schedule helper.
                device: Device for the cross-attention alpha tensor.
            """
            self.cur_step = 0
            self.num_att_layers = -1
            self.cur_att_layer = 0

            self.step_store = self.get_empty_store()
            self.attention_store = {}

            self.batch_size = len(prompts)
            self.cross_replace_alpha = utils.get_time_words_attention_alpha(prompts, num_steps, cross_replace_steps, tokenizer).to(device)
            self.num_self_replace = self._self_replace_window(num_steps, self_replace_steps)
            self.local_blend = local_blend


    class AttentionSwap(AttentionControlEdit):
        """Attention controller that carries the base prompt's attention into the edited prompts."""

        def replace_cross_attention(self, attn_base, att_replace):
            """Mix the base cross-attention, re-indexed through ``self.mapper``, with the edited one.

            ``self.alphas`` gives the weight of the remapped base attention;
            the edited attention receives ``1 - self.alphas``.
            """
            remapped_base = attn_base[:, :, self.mapper].permute(2, 0, 1, 3)
            return remapped_base * self.alphas + att_replace * (1 - self.alphas)

        def __init__(self, prompts, num_steps: int, cross_map_replace_steps: float, self_map_replace_steps: float, self_output_replace_steps: float,
                     source_subject_word=None, target_subject_word=None, tokenizer=None, device=None, LOW_RESOURCE=False, use_local_blend=True, mask=None,
                     use_mask_on_attention=None, use_mask_on_latent=None, attention_change_distribution=None, latent_change_distribution=None,
                     start_with_same_latent=0):
            """Build the controller, its optional local blend and the token mapper.

            The self-attention replacement fraction handed to the base class is
            the sum of ``self_map_replace_steps`` and ``self_output_replace_steps``.
            """
            self_map_replace_steps = self_map_replace_steps + self_output_replace_steps

            local_blend = None
            if use_local_blend:
                # One single-word tuple per prompt: source first, then target.
                blend_word = ((source_subject_word,), (target_subject_word,))
                local_blend = LocalBlend(prompts, blend_word, tokenizer, device, num_steps, start_with_same_latent=start_with_same_latent)

            super().__init__(prompts, num_steps, cross_map_replace_steps, self_map_replace_steps, local_blend, tokenizer, device)
            self.cross_map_replace_steps = cross_map_replace_steps
            self.self_map_replace_steps = self_map_replace_steps
            self.self_output_replace_steps = self_output_replace_steps

            mapper, alphas = get_refinement_mapper(prompts, tokenizer)
            self.mapper = mapper.to(device)
            alphas = alphas.to(device)
            self.alphas = alphas.reshape(alphas.shape[0], 1, 1, alphas.shape[1])

            # Flags read by the base class and by the masking code paths.
            self.LOW_RESOURCE = LOW_RESOURCE
            self.mask = mask
            self.use_mask_on_attention = use_mask_on_attention
            self.use_mask_on_latent = use_mask_on_latent
            self.latent_change_distribution = latent_change_distribution
            self.attention_change_distribution = attention_change_distribution


    import numpy as np
    import torch
    from PIL import Image, ImageDraw, ImageFont
    import cv2
    from typing import Optional, Union, Tuple, List, Callable, Dict
    from IPython.display import display
    from tqdm import tqdm
    from einops import rearrange, repeat


    class ScoreParams:
        """Scores used by the global sequence alignment: gap, match and mismatch."""

        def __init__(self, gap, match, mismatch):
            self.gap, self.match, self.mismatch = gap, match, mismatch

        def mis_match_char(self, x, y):
            """Return the mismatch score when ``x != y``, else the match score."""
            return self.mismatch if x != y else self.match


    def get_matrix(size_x, size_y, gap):
        """Return the (size_x + 1, size_y + 1) int32 score matrix with gap-initialised borders.

        The first row and first column hold 1*gap, 2*gap, ...; every other
        entry, including the corner, is 0.
        """
        scores = np.zeros((size_x + 1, size_y + 1), dtype=np.int32)
        scores[0, 1:] = np.arange(1, size_y + 1) * gap
        scores[1:, 0] = np.arange(1, size_x + 1) * gap
        return scores


    def get_traceback_matrix(size_x, size_y):
        """Return the (size_x + 1, size_y + 1) int32 traceback matrix with its borders set.

        Border codes: 1 along the first row, 2 down the first column and 4 in
        the top-left corner; interior cells start at 0.
        """
        moves = np.zeros((size_x + 1, size_y + 1), dtype=np.int32)
        moves[0, :] = 1
        moves[:, 0] = 2
        moves[0, 0] = 4  # corner code stops the traceback
        return moves


    def global_align(x, y, score):
        """Fill the score and traceback matrices for a global alignment of ``x`` and ``y``.

        Each cell takes the best of the left, upper and diagonal candidates.
        The traceback code is 1 (left), 2 (up) or 3 (diagonal); ties are
        resolved in that order.  Returns ``(matrix, trace_back)``.
        """
        rows, cols = len(x), len(y)
        matrix = get_matrix(rows, cols, score.gap)
        trace_back = get_traceback_matrix(rows, cols)
        for row in range(1, rows + 1):
            for col in range(1, cols + 1):
                left = matrix[row, col - 1] + score.gap
                up = matrix[row - 1, col] + score.gap
                diag = matrix[row - 1, col - 1] + score.mis_match_char(x[row - 1], y[col - 1])
                matrix[row, col] = max(left, up, diag)
                # Compare against the stored value, exactly as written to the matrix.
                best = matrix[row, col]
                if best == left:
                    move = 1
                elif best == up:
                    move = 2
                else:
                    move = 3
                trace_back[row, col] = move
        return matrix, trace_back


    def get_aligned_sequences(x, y, trace_back):
        """Walk ``trace_back`` from the bottom-right corner and collect the alignment.

        Returns ``(x_seq, y_seq, mapper)``.  The two sequences are listed in
        walk order (end to start) with ``'-'`` marking gaps.  ``mapper`` is an
        int64 tensor of ``(y_index, x_index)`` pairs in forward order, with an
        x index of -1 for y items that have no partner.
        """
        x_seq, y_seq, mapper_y_to_x = [], [], []
        i, j = len(x), len(y)
        while i > 0 or j > 0:
            move = trace_back[i, j]
            if move == 4:
                break
            if move == 3:
                # Diagonal: both items are consumed and paired.
                x_seq.append(x[i - 1])
                y_seq.append(y[j - 1])
                i -= 1
                j -= 1
                mapper_y_to_x.append((j, i))
            elif move == 1:
                # Left: only the y item is consumed.
                x_seq.append('-')
                y_seq.append(y[j - 1])
                j -= 1
                mapper_y_to_x.append((j, -1))
            elif move == 2:
                # Up: only the x item is consumed.
                x_seq.append(x[i - 1])
                y_seq.append('-')
                i -= 1
        mapper_y_to_x.reverse()
        return x_seq, y_seq, torch.tensor(mapper_y_to_x, dtype=torch.int64)


    def get_mapper(x: str, y: str, tokenizer, max_len=77):
        """Align the token sequences of ``x`` and ``y`` and return ``(mapper, alphas)``.

        The alignment uses gap 0, match 1 and mismatch -1.  ``mapper`` (int64,
        length ``max_len``) holds, for each aligned ``y`` token, the index of
        its partner in ``x`` (or -1); the remaining slots continue counting
        from ``len(y tokens)``.  ``alphas`` (float, length ``max_len``) is 0
        where a ``y`` token has no partner and 1 elsewhere.
        """
        x_tokens = tokenizer.encode(x)
        y_tokens = tokenizer.encode(y)
        _, trace_back = global_align(x_tokens, y_tokens, ScoreParams(0, 1, -1))
        mapper_base = get_aligned_sequences(x_tokens, y_tokens, trace_back)[-1]

        n_aligned = mapper_base.shape[0]
        partner = mapper_base[:, 1]

        alphas = torch.ones(max_len)
        alphas[:n_aligned] = partner.ne(-1).float()

        mapper = torch.zeros(max_len, dtype=torch.int64)
        mapper[:n_aligned] = partner
        mapper[n_aligned:] = len(y_tokens) + torch.arange(max_len - len(y_tokens))
        return mapper, alphas


    def get_refinement_mapper(prompts, tokenizer, max_len=77):
        """Stack the ``get_mapper`` results of the first prompt against every other prompt.

        Returns ``(mappers, alphas)``, each stacked along a new leading axis
        with one row per prompt after the first.
        """
        base_prompt = prompts[0]
        pairs = [get_mapper(base_prompt, other, tokenizer, max_len) for other in prompts[1:]]
        mappers = [mapper for mapper, _ in pairs]
        alphas = [alpha for _, alpha in pairs]
        return torch.stack(mappers), torch.stack(alphas)


    def get_word_inds(text: str, word_place: int, tokenizer):
        """Return the token positions of a word of ``text`` as a NumPy array.

        ``word_place`` is either a word (every space-separated occurrence is
        used) or a word index.  Positions are offsets into
        ``tokenizer.encode(text)``; the first and last encoded tokens are
        skipped, so the first word token has position 1.
        """
        words = text.split(" ")
        if type(word_place) is str:
            word_place = [pos for pos, word in enumerate(words) if word_place == word]
        elif type(word_place) is int:
            word_place = [word_place]
        if len(word_place) == 0:
            return np.array([])

        pieces = [tokenizer.decode([tok]).strip("#") for tok in tokenizer.encode(text)][1:-1]
        hits = []
        consumed, word_idx = 0, 0
        for piece_idx, piece in enumerate(pieces):
            consumed += len(piece)
            if word_idx in word_place:
                hits.append(piece_idx + 1)
            # Move to the next word once its characters are all accounted for.
            if consumed >= len(words[word_idx]):
                word_idx += 1
                consumed = 0
        return np.array(hits)


    def diffusion_step(model, controller, latents, context, t, guidance_scale, low_resource=False):
        """Run one guided denoising step and return the updated latents.

        The unconditional and text-conditioned noise predictions come from one
        batched UNet call, or from two separate calls on ``context[:2]`` and
        ``context[2:]`` when ``low_resource`` is set.  They are combined with
        ``guidance_scale``, passed through ``model.scheduler.step`` and finally
        through ``controller.step_callback``.
        """
        if not low_resource:
            batched_latents = torch.cat([latents] * 2)
            both = model.unet(batched_latents, t, encoder_hidden_states=context)["sample"]
            noise_pred_uncond, noise_prediction_text = both.chunk(2)
        else:
            noise_pred_uncond = model.unet(latents, t, encoder_hidden_states=context[:2])["sample"]
            noise_prediction_text = model.unet(latents, t, encoder_hidden_states=context[2:])["sample"]

        guided = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)
        stepped = model.scheduler.step(guided, t, latents)["prev_sample"]
        return controller.step_callback(stepped)


    def latent2image(vae, latents):
        """Decode ``latents`` with ``vae`` into a uint8 NumPy image batch.

        The latents are scaled by 1 / 0.18215 before decoding.  The decoded
        sample is mapped with ``x / 2 + 0.5``, clamped to [0, 1], moved to the
        CPU, permuted with ``(0, 2, 3, 1)`` and scaled to 0..255.
        """
        scaled = 1 / 0.18215 * latents
        decoded = vae.decode(scaled)['sample']
        pixels = (decoded / 2 + 0.5).clamp(0, 1)
        pixels = pixels.cpu().permute(0, 2, 3, 1).numpy()
        return (pixels * 255).astype(np.uint8)


    def init_latent(latent, model, height, width, generator, batch_size):
        """Return ``(latent, latents)``: a single latent and its batch-expanded copy.

        When ``latent`` is None a random one of shape
        ``(1, in_channels, height // 8, width // 8)`` is drawn with
        ``generator``.  ``latents`` is that latent expanded to ``batch_size``
        and moved to ``model.device``.
        """
        shape = (model.unet.in_channels, height // 8, width // 8)
        if latent is None:
            latent = torch.randn((1, *shape), generator=generator)
        latents = latent.expand(batch_size, *shape).to(model.device)
        return latent, latents


    def get_word_inds(text: str, word_place: int, tokenizer):
        """Return the token positions of a word of ``text`` as a NumPy array.

        ``word_place`` is either a word (every space-separated occurrence is
        used) or a word index.  Positions are offsets into
        ``tokenizer.encode(text)``; the first and last encoded tokens are
        skipped, so the first word token has position 1.
        """
        words = text.split(" ")
        if type(word_place) is str:
            word_place = [pos for pos, word in enumerate(words) if word_place == word]
        elif type(word_place) is int:
            word_place = [word_place]
        if len(word_place) == 0:
            return np.array([])

        pieces = [tokenizer.decode([tok]).strip("#") for tok in tokenizer.encode(text)][1:-1]
        hits = []
        consumed, word_idx = 0, 0
        for piece_idx, piece in enumerate(pieces):
            consumed += len(piece)
            if word_idx in word_place:
                hits.append(piece_idx + 1)
            # Move to the next word once its characters are all accounted for.
            if consumed >= len(words[word_idx]):
                word_idx += 1
                consumed = 0
        return np.array(hits)


    def update_alpha_time_word(alpha, bounds: Union[float, Tuple[float, float]], prompt_ind: int,
                               word_inds: Optional[torch.Tensor]=None):
        """Set ``alpha`` to 1 inside a step window and 0 outside it, in place.

        ``bounds`` is an end fraction (start 0) or a (start, end) pair of
        fractions of ``alpha.shape[0]``.  Only the entries for ``prompt_ind``
        and ``word_inds`` (all indices of axis 2 when None) are written.
        Returns ``alpha``.
        """
        if type(bounds) is float:
            bounds = 0, bounds

        num_steps = alpha.shape[0]
        start = int(bounds[0] * num_steps)
        end = int(bounds[1] * num_steps)
        if word_inds is None:
            word_inds = torch.arange(alpha.shape[2])

        # Before the window, inside it, after it.
        for window, value in ((slice(None, start), 0), (slice(start, end), 1), (slice(end, None), 0)):
            alpha[window, prompt_ind, word_inds] = value
        return alpha


    def get_time_words_attention_alpha(prompts, num_steps,
                                       cross_replace_steps: Union[float, Dict[str, Tuple[float, float]]],
                                       tokenizer, max_num_words=77):
        """Build the per-step, per-word 0/1 schedule for every edited prompt.

        Args:
            prompts: list of prompts; ``prompts[0]`` is the reference and each
                later prompt gets one row in the schedule.
            num_steps: number of steps; the schedule has ``num_steps + 1`` rows.
            cross_replace_steps: either a single bound applied to all words, or
                a dict mapping a word to its own bound.  The optional
                ``"default_"`` key gives the bound for all words and falls back
                to ``(0., 1.)``.
            tokenizer: passed through to ``get_word_inds``.
            max_num_words: size of the word axis.

        Returns:
            Tensor of shape ``(num_steps + 1, len(prompts) - 1, 1, 1, max_num_words)``.
        """
        if isinstance(cross_replace_steps, dict):
            # Shallow copy: the "default_" entry added below must not leak into
            # the dict owned by the caller.
            cross_replace_steps = dict(cross_replace_steps)
        else:
            cross_replace_steps = {"default_": cross_replace_steps}
        cross_replace_steps.setdefault("default_", (0., 1.))

        num_edits = len(prompts) - 1
        alpha_time_words = torch.zeros(num_steps + 1, num_edits, max_num_words)

        # First apply the default window to every word of every edited prompt.
        for i in range(num_edits):
            alpha_time_words = update_alpha_time_word(alpha_time_words, cross_replace_steps["default_"], i)

        # Then override the window for the individually listed words.
        for key, item in cross_replace_steps.items():
            if key == "default_":
                continue
            inds = [get_word_inds(prompts[i], key, tokenizer) for i in range(1, len(prompts))]
            for i, ind in enumerate(inds):
                if len(ind) > 0:
                    alpha_time_words = update_alpha_time_word(alpha_time_words, item, i, ind)

        return alpha_time_words.reshape(num_steps + 1, num_edits, 1, 1, max_num_words)


    def generate_json_config(source_mask_path=None):
        """Write ``video-config.json`` for a mask image and return the JSON path.

        The mask (``'mask.png'`` when ``source_mask_path`` is None) is read with
        OpenCV, its first channel is binarised at 0.5, and a crop rectangle is
        taken from ``getROI(mask, 0.2)`` unless the module-level ``do_not_crop``
        flag is set, in which case the whole image is used.  The cropped mask is
        resized to 512x512 and handed to ``plt.imshow``.  The written JSON holds
        ``json_path``, ``source_mask_path`` and ``crop_area`` (``[x1, y1, x2, y2]``).

        Raises:
            FileNotFoundError: if OpenCV cannot read the mask image.
        """
        config = {
            'json_path': 'video-config.json',
            'source_mask_path': 'mask.png' if source_mask_path is None else source_mask_path,
        }

        mask_image = cv2.imread(config['source_mask_path'])
        if mask_image is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f"Could not read mask image: {config['source_mask_path']}")

        # Binarise the first channel at 0.5, keeping the float dtype.
        mask = mask_image[:, :, 0] / 255.0
        mask = (mask > 0.5).astype(mask.dtype)

        x1, y1, x2, y2 = getROI(mask, 0.2)
        mask = mask * 255.0

        if do_not_crop:
            x1, y1 = 0, 0
            y2, x2 = mask.shape[0], mask.shape[1]

        # Plain ints so json.dump also accepts coordinates that arrive as NumPy integers.
        config['crop_area'] = [int(x1), int(y1), int(x2), int(y2)]

        new_mask = mask[y1:y2, x1:x2]
        # The interpolation flag must be passed by keyword: the third positional
        # parameter of cv2.resize is `dst`, not `interpolation`.
        new_mask = cv2.resize(new_mask, (512, 512), interpolation=cv2.INTER_NEAREST)
        plt.imshow(new_mask)

        # An existing config file is overwritten; the message only reports that it was there.
        if os.path.exists(config['json_path']):
            print("The file exists.")
        with open(config['json_path'], 'w') as f:
            json.dump(config, f)

        return config['json_path']


    # Read by generate_json_config(): when True the crop area is the full mask image.
    do_not_crop = True

    blend_width = 20


    # Candidate values combined below into (self_output, self_map, cross_map) triples.
    self_output_range = [0.1, 0.3, 0.5, 0.7]
    self_map_range = [0.0]
    cross_map_range = [0.1, 0.3, 0.5, 0.7]
    end_blend = 52

    # Starts with the all-zero triple, followed by every combination of the three ranges.
    range_combination_list = [(0.0, 0.0, 0.0)]
    for aa in self_output_range:
        for bb in self_map_range:
            for cc in cross_map_range:
                range_combination_list.append((aa, bb, cc))

    # No config path is preset here, so one is always generated (with the default mask path).
    config_path = None
    if config_path is None:
        config_path = generate_json_config()


    def register_attention_control(model, controller):
        """Replace the forward of every ``Attention`` module in ``model.unet``.

        Children of ``model.unet`` whose name contains ``"down"``, ``"up"`` or
        ``"mid"`` are walked recursively; every module whose class is named
        ``Attention`` gets a new ``forward`` that

        * optionally copies batch entry 0 of query/key/value onto entry 1
          (self-attention only, gated by the controller's step/layer counters),
        * optionally saves query/key/value to disk and/or replaces key/value
          with tensors loaded from disk, depending on ``controller.save_load_mode``,
        * increments ``controller.current_layer`` once per call.

        The number of patched modules is stored in ``controller.num_att_layers``.
        When ``controller`` is None a neutral stand-in is used that never swaps,
        saves or loads.
        """
        def ca_forward(self, place_in_unet):
            # Only the first entry of a ModuleList output projection is applied.
            to_out = self.to_out
            if isinstance(to_out, torch.nn.ModuleList):
                to_out = to_out[0]

            def forward(hidden_states, encoder_hidden_states=None, attention_mask=None,temb=None,):
                is_cross = encoder_hidden_states is not None

                residual = hidden_states

                if self.spatial_norm is not None:
                    hidden_states = self.spatial_norm(hidden_states, temb)

                input_ndim = hidden_states.ndim

                if input_ndim == 4:
                    batch_size, channel, height, width = hidden_states.shape
                    hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

                batch_size, sequence_length, _ = (
                    hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
                )
                attention_mask = self.prepare_attention_mask(attention_mask, sequence_length, batch_size)

                if self.group_norm is not None:
                    hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

                query = self.to_q(hidden_states)

                if encoder_hidden_states is None:
                    encoder_hidden_states = hidden_states
                elif self.norm_cross:
                    encoder_hidden_states = self.norm_encoder_hidden_states(encoder_hidden_states)

                key = self.to_k(encoder_hidden_states)
                value = self.to_v(encoder_hidden_states)

                # NOTE(review): the step limit is scaled by a hard-coded 50 rather than the
                # configured number of diffusion steps - confirm this is intended.
                swap_self_attention = (
                    controller.cur_att_layer > controller.num_uncond_att_layers
                    and query.shape[0] == 2
                    and not is_cross
                    and 0 <= controller.cur_step <= int(controller.self_output_replace_steps * 50)
                )
                if swap_self_attention:
                    print('swappping swapping')
                    # Batch entry 1 attends exactly like batch entry 0.
                    query[1, :, :] = query[0, :, :]
                    key[1, :, :] = key[0, :, :]
                    value[1, :, :] = value[0, :, :]

                def save_tensors(query, key, value, step, layer):
                    # One sub-folder per (step, layer) under the controller's save root.
                    subfolder = os.path.join(controller.save_root_dir, "saved_attention", f'step_{step}_layer_{layer}')
                    os.makedirs(subfolder, exist_ok=True)

                    torch.save(query, os.path.join(subfolder, 'query.pt'))
                    torch.save(key, os.path.join(subfolder, 'key.pt'))
                    torch.save(value, os.path.join(subfolder, 'value.pt'))

                def load_tensors(step, layer, query, key, value):
                    # Optionally keep the current tensors as the first entry of each list.
                    if controller.include_origin:
                        query_list = [query]; key_list = [key]; value_list = [value]
                    else:
                        query_list = []; key_list = []; value_list = []
                    for _load_folder in controller.load_folder_list:
                        subfolder = os.path.join(_load_folder, 'saved_attention', f'step_{step}_layer_{layer}')
                        query_list.append(torch.load(os.path.join(subfolder, 'query.pt'), map_location=device))
                        key_list.append(torch.load(os.path.join(subfolder, 'key.pt'), map_location=device))
                        value_list.append(torch.load(os.path.join(subfolder, 'value.pt'), map_location=device))
                    if controller.use_shuffle:
                        # Keys and values are shuffled independently of each other.
                        random.shuffle(key_list)
                        random.shuffle(value_list)
                    # Everything collected is concatenated along dim 1.
                    _query, _key, _value = torch.cat(query_list, dim=1), torch.cat(key_list, dim=1), torch.cat(value_list, dim=1)

                    # Each tensor is only replaced while the current step is below its threshold.
                    prefix = 'cross' if is_cross else 'self'
                    if controller.current_step < controller.variable_step_dict[prefix + '_q']:
                        query = _query
                    if controller.current_step < controller.variable_step_dict[prefix + '_k']:
                        key = _key
                    if controller.current_step < controller.variable_step_dict[prefix + '_v']:
                        value = _value
                    return query, key, value

                if 'save' in controller.save_load_mode:
                    save_tensors(query, key, value, controller.current_step, controller.current_layer)
                if 'load' in controller.save_load_mode:
                    # The query returned by load_tensors is deliberately discarded here.
                    _, key, value = load_tensors(controller.current_step, controller.current_layer, query, key, value)

                controller.current_layer = controller.current_layer + 1

                query = self.head_to_batch_dim(query)
                key = self.head_to_batch_dim(key)
                value = self.head_to_batch_dim(value)

                attention_probs = self.get_attention_scores(query, key, attention_mask)
                hidden_states = torch.bmm(attention_probs, value)
                hidden_states = self.batch_to_head_dim(hidden_states)

                # linear proj
                hidden_states = to_out(hidden_states)

                if input_ndim == 4:
                    hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

                if self.residual_connection:
                    hidden_states = hidden_states + residual

                hidden_states = hidden_states / self.rescale_output_factor

                return hidden_states
            return forward

        class DummyController:
            """Neutral controller: never swaps, saves or loads."""

            def __call__(self, *args):
                return args[0]

            def __init__(self):
                self.num_att_layers = 0
                # Defaults for every attribute the patched forward reads, chosen
                # so that none of its optional branches is taken.
                self.cur_att_layer = 0
                self.num_uncond_att_layers = 0
                self.cur_step = 0
                self.self_output_replace_steps = 0.0
                self.save_load_mode = ""
                self.current_step = 0
                self.current_layer = 0

        if controller is None:
            controller = DummyController()

        def register_recr(net_, count, place_in_unet):
            # Patch Attention modules; recurse into everything else.
            if net_.__class__.__name__ == 'Attention':
                net_.forward = ca_forward(net_, place_in_unet)
                return count + 1
            elif hasattr(net_, 'children'):
                for net__ in net_.children():
                    count = register_recr(net__, count, place_in_unet)
            return count

        cross_att_count = 0
        for child_name, child in model.unet.named_children():
            # Same precedence as the branch order: "down", then "up", then "mid".
            for place in ("down", "up", "mid"):
                if place in child_name:
                    cross_att_count += register_recr(child, 0, place)
                    break

        controller.num_att_layers = cross_att_count


    def load_360(image_path):
        """Load an image as RGB, centre-crop it to a square and resize to 360x360.

        The image is converted to RGB on load (as the frame loader further down
        this script does), so grayscale, palette and RGBA files all yield a
        three-channel array.  Returns the resized image as a NumPy array.
        """
        # The context manager closes the file handle once the pixels are copied.
        with Image.open(image_path) as source:
            image = np.array(source.convert('RGB'))

        h, w, _ = image.shape
        side = min(h, w)
        top = (h - side) // 2
        left = (w - side) // 2
        # Crop the longer axis symmetrically; a square image is left untouched.
        image = image[top:top + side, left:left + side]

        return np.array(Image.fromarray(image).resize((360, 360)))

    def normal_load(img_path):
        """Load an image as a three-channel RGB NumPy array, without resizing.

        Converting to RGB on load keeps grayscale and palette images from
        failing on a missing channel axis; an alpha channel is dropped.
        """
        # The context manager closes the file handle once the pixels are copied.
        with Image.open(img_path) as source:
            return np.array(source.convert('RGB'))

    class NullInversion:
        """Helpers for DDIM inversion and null-text optimisation around a diffusion pipeline.

        ``model`` is expected to expose ``unet``, ``vae``, ``tokenizer``,
        ``text_encoder``, ``scheduler`` and ``device``.  In its current state
        ``invert`` / ``ddim_inversion`` only encode the image to a latent; the
        DDIM loop and the null-text optimisation are still available as methods
        but are not called from ``invert``.
        """

        def __init__(self, model, ddim_steps, guidance_scale):
            """Store the pipeline and configure its scheduler for ``ddim_steps`` steps."""
            self.model = model
            self.tokenizer = self.model.tokenizer
            self.model.scheduler.set_timesteps(ddim_steps)
            self.prompt = None
            self.context = None
            self.NUM_DDIM_STEPS = ddim_steps
            self.GUIDANCE_SCALE = guidance_scale

        @property
        def scheduler(self):
            """The scheduler of the wrapped pipeline."""
            return self.model.scheduler

        def prev_step(self, model_output: Union[torch.FloatTensor, np.ndarray], timestep: int, sample: Union[torch.FloatTensor, np.ndarray]):
            """Move ``sample`` one scheduler step from ``timestep`` to the previous timestep."""
            prev_timestep = timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps
            alpha_prod_t = self.scheduler.alphas_cumprod[timestep]
            # Before the first timestep the scheduler's final alpha is used.
            alpha_prod_t_prev = self.scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.scheduler.final_alpha_cumprod
            beta_prod_t = 1 - alpha_prod_t
            pred_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5
            pred_sample_direction = (1 - alpha_prod_t_prev) ** 0.5 * model_output
            prev_sample = alpha_prod_t_prev ** 0.5 * pred_original_sample + pred_sample_direction
            return prev_sample

        def next_step(self, model_output: Union[torch.FloatTensor, np.ndarray], timestep: int, sample: Union[torch.FloatTensor, np.ndarray]):
            """Move ``sample`` one scheduler step in the opposite direction, up to ``timestep``."""
            # `sample` is taken to live one step below `timestep` (capped at 999).
            timestep, next_timestep = min(timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps, 999), timestep
            alpha_prod_t = self.scheduler.alphas_cumprod[timestep] if timestep >= 0 else self.scheduler.final_alpha_cumprod
            alpha_prod_t_next = self.scheduler.alphas_cumprod[next_timestep]
            beta_prod_t = 1 - alpha_prod_t
            next_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5
            next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output
            next_sample = alpha_prod_t_next ** 0.5 * next_original_sample + next_sample_direction
            return next_sample

        def get_noise_pred_single(self, latents, t, context):
            """Run the UNet once with ``context`` and return its ``"sample"`` output."""
            noise_pred = self.model.unet(latents, t, encoder_hidden_states=context)["sample"]
            return noise_pred

        def get_noise_pred(self, latents, t, is_forward=True, context=None):
            """Run the UNet on a doubled batch, combine both halves and step the latents.

            With ``is_forward`` the guidance scale is 1 and ``next_step`` is
            used; otherwise ``GUIDANCE_SCALE`` and ``prev_step``.  ``context``
            defaults to the embeddings stored by ``init_prompt``.
            """
            latents_input = torch.cat([latents] * 2)
            if context is None:
                context = self.context
            guidance_scale = 1 if is_forward else self.GUIDANCE_SCALE
            noise_pred = self.model.unet(latents_input, t, encoder_hidden_states=context)["sample"]
            noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)
            if is_forward:
                latents = self.next_step(noise_pred, t, latents)
            else:
                latents = self.prev_step(noise_pred, t, latents)
            return latents

        @torch.no_grad()
        def latent2image(self, latents, return_type='np'):
            """Decode latents with the VAE.

            For ``return_type='np'`` the first image of the batch is returned
            as a uint8 NumPy array; otherwise the raw decoded tensor is returned.
            """
            latents = 1 / 0.18215 * latents.detach()
            image = self.model.vae.decode(latents)['sample']
            if return_type == 'np':
                image = (image / 2 + 0.5).clamp(0, 1)
                image = image.cpu().permute(0, 2, 3, 1).numpy()[0]
                image = (image * 255).astype(np.uint8)
            return image

        @torch.no_grad()
        def image2latent(self, image):
            """Encode an image to latents (mean of the VAE posterior times 0.18215).

            A 4-dimensional tensor is returned unchanged.  PIL images are
            converted to NumPy arrays first; arrays are rescaled with
            ``x / 127.5 - 1`` and permuted from (H, W, C) to (1, C, H, W).
            """
            if isinstance(image, torch.Tensor) and image.dim() == 4:
                return image
            # `Image` is the PIL module, so the class to test against is Image.Image.
            if isinstance(image, Image.Image):
                image = np.array(image)
            image = torch.from_numpy(image).float() / 127.5 - 1
            image = image.permute(2, 0, 1).unsqueeze(0).to(self.model.device)
            latents = self.model.vae.encode(image)['latent_dist'].mean
            return latents * 0.18215

        @torch.no_grad()
        def init_prompt(self, prompt: str):
            """Encode the empty prompt and ``prompt`` and store them as ``self.context``."""
            uncond_input = self.model.tokenizer(
                [""], padding="max_length", max_length=self.model.tokenizer.model_max_length,
                return_tensors="pt"
            )
            uncond_embeddings = self.model.text_encoder(uncond_input.input_ids.to(self.model.device))[0]
            text_input = self.model.tokenizer(
                [prompt],
                padding="max_length",
                max_length=self.model.tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            text_embeddings = self.model.text_encoder(text_input.input_ids.to(self.model.device))[0]
            # Order matters: unconditional first, text second (see the chunk(2) calls).
            self.context = torch.cat([uncond_embeddings, text_embeddings])
            self.prompt = prompt

        @torch.no_grad()
        def ddim_loop(self, latent):
            """Step ``latent`` through the timesteps in reverse order and return every latent."""
            _, cond_embeddings = self.context.chunk(2)
            all_latent = [latent]
            latent = latent.clone().detach()
            timesteps = self.model.scheduler.timesteps
            for i in range(self.NUM_DDIM_STEPS):
                t = timesteps[len(timesteps) - i - 1]
                noise_pred = self.get_noise_pred_single(latent, t, cond_embeddings)
                latent = self.next_step(noise_pred, t, latent)
                all_latent.append(latent)
            return all_latent

        @torch.no_grad()
        def ddim_inversion(self, image):
            """Return the latent of ``image``.

            TODO: the DDIM loop is currently bypassed (testing); only the
            encoding step runs.
            """
            return self.image2latent(image)

        def null_optimization(self, latents, num_inner_steps, epsilon):
            """Optimise one unconditional embedding per DDIM step.

            For each step the unconditional embedding is tuned with Adam for at
            most ``num_inner_steps`` iterations so that the guided ``prev_step``
            reproduces the corresponding entry of ``latents``; it stops early
            once the MSE drops below ``epsilon + i * 2e-5``.  Returns the list
            of optimised embeddings.
            """
            uncond_embeddings, cond_embeddings = self.context.chunk(2)
            uncond_embeddings_list = []
            latent_cur = latents[-1]
            bar = tqdm(total=num_inner_steps * self.NUM_DDIM_STEPS)
            for i in range(self.NUM_DDIM_STEPS):
                uncond_embeddings = uncond_embeddings.clone().detach()
                uncond_embeddings.requires_grad = True
                # torch.optim.Adam is reachable through the `torch` import the file already has.
                optimizer = torch.optim.Adam([uncond_embeddings], lr=1e-2 * (1. - i / 100.))
                latent_prev = latents[len(latents) - i - 2]
                t = self.model.scheduler.timesteps[i]
                with torch.no_grad():
                    noise_pred_cond = self.get_noise_pred_single(latent_cur, t, cond_embeddings)
                steps_done = 0
                for _ in range(num_inner_steps):
                    noise_pred_uncond = self.get_noise_pred_single(latent_cur, t, uncond_embeddings)
                    noise_pred = noise_pred_uncond + self.GUIDANCE_SCALE * (noise_pred_cond - noise_pred_uncond)
                    latents_prev_rec = self.prev_step(noise_pred, t, latent_cur)
                    loss = nnf.mse_loss(latents_prev_rec, latent_prev)
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                    steps_done += 1
                    bar.update()
                    if loss.item() < epsilon + i * 2e-5:
                        break
                # Account for skipped iterations so the bar still reaches its total
                # (also well-defined when num_inner_steps is 0).
                if steps_done < num_inner_steps:
                    bar.update(num_inner_steps - steps_done)
                uncond_embeddings_list.append(uncond_embeddings[:1].detach())
                with torch.no_grad():
                    context = torch.cat([uncond_embeddings, cond_embeddings])
                    latent_cur = self.get_noise_pred(latent_cur, t, False, context)
            bar.close()
            return uncond_embeddings_list

        def invert(self, image_path: str, prompt: str, num_inner_steps=10, early_stop_epsilon=1e-5, is_normal_load=False):
            """Encode the image at ``image_path`` to a latent.

            The prompt embeddings are initialised and attention control is
            reset via ``utils.register_attention_control(self.model, None)``.
            The image is loaded with ``normal_load`` when ``is_normal_load`` is
            set, otherwise with ``load_360``.

            TODO: ``num_inner_steps`` and ``early_stop_epsilon`` are unused
            while the null-text optimisation is bypassed (testing).
            """
            self.init_prompt(prompt)
            utils.register_attention_control(self.model, None)
            image_gt = normal_load(image_path) if is_normal_load else load_360(image_path)
            return self.ddim_inversion(image_gt)


    # Only when masking is enabled: build a second InstructPix2Pix pipeline that
    # backs the NullInversion helper.
    if use_mask:
        pipe2 = diffusers.StableDiffusionInstructPix2PixPipeline.from_pretrained('timbrooks/instruct-pix2pix', torch_dtype=T.float16, safety_checker=None).to(device)
        pipe2.set_progress_bar_config(disable=True)
        # Replace the pretrained UNet weights with the checkpoint stored in ckpt_folder.
        pipe2.unet.load_state_dict(T.load(os.path.join(ckpt_folder, "unet.pt"), map_location='cpu'))
        # The pipeline was loaded in float16; its three sub-models are cast to float32 here.
        pipe2.unet.to(torch.float32)
        pipe2.vae.to(torch.float32)
        pipe2.text_encoder.to(torch.float32)
        latent_inversion = NullInversion(pipe2, ddim_steps=diffusion_steps, guidance_scale=7.5)



    class Config:
        """Per-frame run settings.

        Most fields are seeded from module-level settings (``diffusion_steps``,
        ``saved_video_fps``, ``result_subfolder``, ``use_mask``,
        ``latent_starts_step``, ``seed``) and placeholder values that the
        script overwrites before calling ``update()``.
        """

        def __init__(self, image_json_path=None):
            """Create a config.

            Args:
                image_json_path: path to a JSON file with a ``source_mask_path``
                    entry.  When None, no file is read and
                    ``external_mask_path`` is None (this is what ``from_json``
                    relies on).
            """
            if image_json_path is None:
                self.external_mask_path = None
            else:
                with open(image_json_path, 'r') as f:
                    image_dict = json.load(f)
                self.external_mask_path = image_dict['source_mask_path']

            self.diffusion_steps = diffusion_steps
            # Step thresholds below which loaded attention tensors / latents are used.
            self.variable_step_dict = {'self_q': int(0.0*self.diffusion_steps),
                                        'self_k': int(1.0*self.diffusion_steps),
                                        'self_v': int(1.0*self.diffusion_steps),
                                        'cross_q': int(0.0*self.diffusion_steps),
                                        'cross_k': int(0.0*self.diffusion_steps),
                                        'cross_v': int(0.0*self.diffusion_steps),
                                        'latent': int(0.0*self.diffusion_steps)}

            # Placeholders; the script assigns the real values per frame.
            self.img_location = "dummp"
            self.load_images = []

            self.saved_video_fps = saved_video_fps

            self.result_subfolder = result_subfolder

            self.edit_instruction = "dummy"
            self.save_load_mode = "dummy"

            self.use_mask = use_mask
            self.latent_starts_step = latent_starts_step

            self.include_origin = False
            self.use_shuffle = False

            self.update()

            # NOTE: `SEED` (module-level seed) and `seed` (fixed to 1) are two distinct fields.
            self.SEED = seed

            self.cross_attention_map = 0.0
            self.self_attention_output = 0.0
            self.use_external_mask = True
            self.use_mask_on_self_attention = True
            self.area_mask_soft = 0.0
            self.time_step_soft = 0.0
            self.external_mask = None
            self.seed = 1
            self.change_distribution_after_swap = False
            self.note = 'face'

        def update(self):
            """Recompute ``load_folder_list`` and ``save_root_dir`` from the current fields."""
            step_folder = f"total_step_{self.diffusion_steps}"
            if 'load' in self.save_load_mode:
                self.load_folder_list = [os.path.join(step_folder, one_img, self.edit_instruction) for one_img in self.load_images]
                assert len(self.load_folder_list) > 0
            else:
                self.load_folder_list = []
            # The save root uses the image file name without directory and extension.
            self.save_root_dir = os.path.join(step_folder,
                                self.img_location.split('/')[-1].split('.')[0],
                                self.edit_instruction
                                )

        def to_json(self):
            """Serialise all fields except ``external_mask`` to a JSON string."""
            data = self.__dict__.copy()
            # The mask is left out of the JSON dump.
            data.pop('external_mask', None)
            return json.dumps(data)

        @classmethod
        def from_json(cls, json_str):
            """Create a default config and overwrite its fields from ``json_str``."""
            data = json.loads(json_str)
            # No JSON path: __init__ skips the file read instead of failing on open(None).
            instance = cls()
            instance.__dict__.update(data)
            return instance




    # One Config per frame to process, in processing order.
    config_list = []

    if not attention_stored:
        # Group frames come first: frame 0 only saves its attention tensors, every
        # later group frame loads those of the earlier group frames and saves its own.
        for idx in range(group_size):
            if idx >= len(group_images):
                # Fail with a clear message instead of dropping into a debugger.
                raise IndexError(
                    f"group_size={group_size} exceeds the {len(group_images)} available group images"
                )

            config_list.append(Config(config_path))

            config_list[-1].load_images = [one.split(".")[0] for one in group_images[:idx]]
            config_list[-1].img_location = os.path.join(image_folder, group_images[idx])

            if idx == 0:
                config_list[-1].save_load_mode = "save"
            else:
                config_list[-1].save_load_mode = "save_load"

    # Every frame of the video only loads the attention tensors of the whole group.
    for img in total_images:
        config_list.append(Config(config_path))
        config_list[-1].load_images = [one.split(".")[0] for one in group_images]
        config_list[-1].img_location = os.path.join(image_folder, img)
        config_list[-1].save_load_mode = "load"

    # Apply the shared instruction and step thresholds, then refresh the derived folders.
    for idx, _ in enumerate(config_list):
        config_list[idx].edit_instruction = instruction
        config_list[idx].variable_step_dict = variable_step_dict
        config_list[idx].update()

    total_source_frame_list = []
    total_edited_frame_list = []
    for config in config_list:

        txt = config.edit_instruction 
        img = Image.open(config.img_location).convert('RGB').resize((360, 360))
        img = image_processor.preprocess(img, return_tensors='pt')['pixel_values'][0]
        txt = "what will this image be like if '%s'"%(txt)
        txt = txt+'\n'+DEFAULT_IM_START_TOKEN+DEFAULT_IMAGE_PATCH_TOKEN*image_token_len+DEFAULT_IM_END_TOKEN
        conv = conv_templates['vicuna_v1_1'].copy()
        conv.append_message(conv.roles[0], txt), conv.append_message(conv.roles[1], None)
        txt = conv.get_prompt()
        txt = tokenizer(txt)
        txt, mask = T.as_tensor(txt['input_ids']), T.as_tensor(txt['attention_mask'])

        with T.inference_mode():
            out = model.generate(txt.unsqueeze(dim=0).to(device), images=img.half().unsqueeze(dim=0).to(device), attention_mask=mask.unsqueeze(dim=0).to(device), 
                                 do_sample=False, max_new_tokens=96, num_beams=1, no_repeat_ngram_size=3, 
                                 return_dict_in_generate=True, output_hidden_states=True)
            out, hid = out['sequences'][0].tolist(), T.cat([x[-1] for x in out['hidden_states']], dim=1)[0]

            p = min(out.index(32003)-1 if 32003 in out else len(hid)-9, len(hid)-9)
            hid = hid[p:p+8]

            out = remove_alter(tokenizer.decode(out))
            emb = model.edit_head(hid.unsqueeze(dim=0), EMB)


        with T.inference_mode():

            prompt = None
            image = Image.open(config.img_location).convert('RGB').resize((360, 360))
            num_inference_steps = config.diffusion_steps
            guidance_scale = 7.5
            image_guidance_scale = 1.5
            negative_prompt = None
            num_images_per_prompt = 1
            eta = 0.0
            generator = T.Generator(device=device).manual_seed(config.SEED)
            latents = None
            prompt_embeds = emb
            negative_prompt_embeds = NULL
            ip_adapter_image = None
            output_type = "pil"
            return_dict = True
            callback_on_step_end = None
            callback_on_step_end_tensor_inputs = ["latents"]
            # **kwargs,
            kwargs = {}


            callback = kwargs.pop("callback", None)
            callback_steps = kwargs.pop("callback_steps", None)

            if callback is not None:
                deprecate(
                    "callback",
                    "1.0.0",
                    "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
                )
            if callback_steps is not None:
                deprecate(
                    "callback_steps",
                    "1.0.0",
                    "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
                )

            # 0. Check inputs
            pipe.check_inputs(
                prompt,
                callback_steps,
                negative_prompt,
                prompt_embeds,
                negative_prompt_embeds,
                callback_on_step_end_tensor_inputs,
            )
            pipe._guidance_scale = guidance_scale
            pipe._image_guidance_scale = image_guidance_scale

            device = pipe._execution_device

            if ip_adapter_image is not None:
                output_hidden_state = False if isinstance(pipe.unet.encoder_hid_proj, ImageProjection) else True
                image_embeds, negative_image_embeds = pipe.encode_image(
                    ip_adapter_image, device, num_images_per_prompt, output_hidden_state
                )
                if pipe.do_classifier_free_guidance:
                    image_embeds = torch.cat([image_embeds, negative_image_embeds, negative_image_embeds])

            if image is None:
                raise ValueError("`image` input cannot be undefined.")

            # 1. Define call parameters
            if prompt is not None and isinstance(prompt, str):
                batch_size = 1
            elif prompt is not None and isinstance(prompt, list):
                batch_size = len(prompt)
            else:
                batch_size = prompt_embeds.shape[0]

            device = pipe._execution_device

            # 2. Encode input prompt
            prompt_embeds = pipe._encode_prompt(
                prompt,
                device,
                num_images_per_prompt,
                pipe.do_classifier_free_guidance,
                negative_prompt,
                prompt_embeds=prompt_embeds,
                negative_prompt_embeds=negative_prompt_embeds,
            )

            # 3. Preprocess image
            image = pipe.image_processor.preprocess(image)

            # 4. set timesteps
            pipe.scheduler.set_timesteps(num_inference_steps, device=device)
            timesteps = pipe.scheduler.timesteps

            # 5. Prepare Image latents
            image_latents = pipe.prepare_image_latents(
                image,  
                batch_size,
                num_images_per_prompt,
                prompt_embeds.dtype,
                device,
                pipe.do_classifier_free_guidance,
            )

            height, width = image_latents.shape[-2:]
            height = height * pipe.vae_scale_factor
            width = width * pipe.vae_scale_factor

            # 6. Prepare latent variables
            num_channels_latents = pipe.vae.config.latent_channels
            latents = pipe.prepare_latents(
                batch_size * num_images_per_prompt,
                num_channels_latents,
                height,
                width,
                prompt_embeds.dtype,
                device,
                generator,
                latents,
            )

            # 7. Check that shapes of latents and image match the UNet channels
            num_channels_image = image_latents.shape[1]
            if num_channels_latents + num_channels_image != pipe.unet.config.in_channels:
                raise ValueError(
                    f"Incorrect configuration settings! The config of `pipeline.unet`: {pipe.unet.config} expects"
                    f" {pipe.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
                    f" `num_channels_image`: {num_channels_image} "
                    f" = {num_channels_latents+num_channels_image}. Please verify the config of"
                    " `pipeline.unet` or your `image` input."
                )

            # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
            extra_step_kwargs = pipe.prepare_extra_step_kwargs(generator, eta)

            # 8.1 Add image embeds for IP-Adapter
            added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None



        with T.inference_mode():

            # Load the external mask, keep only its first channel and divide by 255.0
            # (maps 8-bit pixel values to [0, 1]).
            target_mask = cv2.imread(config.external_mask_path)
            if target_mask is None:
                # cv2.imread reports a missing/unreadable file by returning None
                # instead of raising, so fail here with an explicit message.
                raise FileNotFoundError(
                    f"Could not read external mask image: {config.external_mask_path}"
                )
            target_mask = target_mask[:, :, 0] / 255.0
            config.external_mask = target_mask

            # All three replace-step settings are forced to 0.0 for this run.
            config.self_output_replace_steps = 0.0
            config.self_map_replace_steps = 0.0
            config.cross_map_replace_steps = 0.0

            self_output_replace_steps = config.self_output_replace_steps
            cross_map_replace_steps = config.cross_map_replace_steps

            assert config.self_output_replace_steps + config.self_map_replace_steps <= 1.0

            # Masking / distribution flags: the latent variants mirror the attention variants.
            config.use_external_mask = True
            config.use_mask_on_attention = True
            config.use_mask_on_latent = config.use_mask_on_attention
            config.attention_change_distribution = True
            config.latent_change_distribution = config.attention_change_distribution
            config.start_with_same_latent = 0

            # Placeholder prompts / subject words handed to the controller.
            prompts = ['nothing', 'nothing']
            target_subject_word = 'nothing'
            source_subject_word = 'nothing'
            LOW_RESOURCE = True

            # Build the attention controller that is registered on the pipeline
            # before every UNet call in the denoising loop.
            controller = AttentionSwap(
                prompts,
                config.diffusion_steps,
                cross_map_replace_steps=config.cross_map_replace_steps,
                self_map_replace_steps=config.self_map_replace_steps,
                self_output_replace_steps=config.self_output_replace_steps,
                source_subject_word=source_subject_word,
                target_subject_word=target_subject_word,
                tokenizer=tokenizer,
                device=device,
                LOW_RESOURCE=LOW_RESOURCE,
                mask=config.external_mask,
                use_mask_on_attention=config.use_mask_on_attention,
                use_mask_on_latent=config.use_mask_on_latent,
                attention_change_distribution=config.attention_change_distribution,
                latent_change_distribution=config.latent_change_distribution,
                start_with_same_latent=config.start_with_same_latent,
            )

            # 9. Denoising loop
            num_warmup_steps = len(timesteps) - num_inference_steps * pipe.scheduler.order
            pipe._num_timesteps = len(timesteps)
            with pipe.progress_bar(total=num_inference_steps) as progress_bar:
                for i, t in enumerate(timesteps[:]):
                    # Expand the latents if we are doing classifier free guidance.
                    # The latents are expanded 3 times because for pix2pix the guidance
                    # is applied for both the text and the input image.
                    latent_model_input = torch.cat([latents] * 3) if pipe.do_classifier_free_guidance else latents

                    # Scale the input for the scheduler, then concatenate latents and
                    # image_latents in the channel dimension (dim=1).
                    scaled_latent_model_input = pipe.scheduler.scale_model_input(latent_model_input, t)
                    scaled_latent_model_input = torch.cat([scaled_latent_model_input, image_latents], dim=1)

                    # predict the noise residual

                    # Refresh the controller's per-step state from the loop index and
                    # the current config before the UNet runs.
                    controller.current_step = i
                    controller.current_layer = 0
                    controller.save_root_dir = config.save_root_dir 
                    controller.save_load_mode = config.save_load_mode
                    controller.load_folder_list = config.load_folder_list
                    controller.variable_step_dict = config.variable_step_dict
                    controller.use_shuffle = config.use_shuffle
                    controller.include_origin = config.include_origin

                    # NOTE(review): the controller is re-registered on every step —
                    # confirm whether registering once before the loop would suffice.
                    register_attention_control(pipe, controller)

                    noise_pred = pipe.unet(
                        scaled_latent_model_input,
                        t,
                        encoder_hidden_states=prompt_embeds,
                        added_cond_kwargs=added_cond_kwargs,
                        return_dict=False,
                    )[0]

                    # perform guidance: combine the three predictions using the text
                    # guidance scale and the image guidance scale.
                    if pipe.do_classifier_free_guidance:
                        noise_pred_text, noise_pred_image, noise_pred_uncond = noise_pred.chunk(3)
                        noise_pred = (
                            noise_pred_uncond
                            + pipe.guidance_scale * (noise_pred_text - noise_pred_image)
                            + pipe.image_guidance_scale * (noise_pred_image - noise_pred_uncond)
                        )


                    # compute the previous noisy sample x_t -> x_t-1
                    latents = pipe.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
                    # print("latent shape: ", latents.shape)

                    def get_mask_image_path(original_path):
                        """Return the path of the mask image that accompanies ``original_path``.

                        The result keeps the directory and extension of the input and
                        appends ``_mask`` to the base name, e.g. ``a/b.png`` becomes
                        ``a/b_mask.png``.
                        """
                        parent_dir, file_name = os.path.split(original_path)
                        stem, suffix = os.path.splitext(file_name)
                        return os.path.join(parent_dir, stem + "_mask" + suffix)

                    # Mask-guided blending, active only for steps in
                    # [config.latent_starts_step, config.variable_step_dict['latent']):
                    # inside the mask the freshly denoised latents are kept, outside
                    # it they are replaced by the inverted latents of the source image.
                    if config.use_mask and i < config.variable_step_dict['latent'] and i >= config.latent_starts_step:
                        inverted_latent = latent_inversion.invert(config.img_location, "nothing")

                        # Load the mask as grayscale
                        mask_image = Image.open(get_mask_image_path(config.img_location)).convert('L')

                        # Resize the mask to the spatial size of the latents rather than a
                        # hard-coded 45x45 grid, so other input resolutions work as well.
                        resize_transform = transforms.Resize(tuple(latents.shape[-2:]))

                        if use_dilation:
                            # Combine resizing and dilation in a single transform pipeline
                            combined_transform = transforms.Compose([
                                resize_transform,
                                DilationTransform(size=dilate_size),  # Adjust dilation size as needed
                            ])
                            resized_mask = combined_transform(mask_image)
                        else:
                            resized_mask = resize_transform(mask_image)

                        # Binarise at 0.5 and broadcast the single mask channel over all
                        # latent channels: (1, 1, H, W) -> (1, C, H, W).
                        mask_tensor = transforms.ToTensor()(resized_mask).unsqueeze(0)
                        mask_tensor = (mask_tensor > 0.5).half()
                        mask_tensor = mask_tensor.expand(-1, latents.shape[1], -1, -1).to(device)

                        print(mask_tensor.shape, inverted_latent.shape, latents.shape)
                        latents = (1 - mask_tensor) * inverted_latent.half().to(device) + mask_tensor * latents
                    # Step-end callback: hand the requested variables to the callback and
                    # take back any tensors it chose to override.
                    if callback_on_step_end is not None:
                        # Keep this as an explicit loop: inside a comprehension, locals()
                        # may resolve to the comprehension's own scope.
                        callback_kwargs = {}
                        for k in callback_on_step_end_tensor_inputs:
                            callback_kwargs[k] = locals()[k]
                        # The pipeline object in this script is `pipe`; `self` is a leftover
                        # from the original pipeline method and is not defined here.
                        callback_outputs = callback_on_step_end(pipe, i, t, callback_kwargs)

                        latents = callback_outputs.pop("latents", latents)
                        prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                        negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
                        image_latents = callback_outputs.pop("image_latents", image_latents)

                    # Advance the progress bar on the last step, or after warm-up on every
                    # `scheduler.order`-th step, and call the legacy callback if provided.
                    if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % pipe.scheduler.order == 0):
                        progress_bar.update()
                        if callback is not None and i % callback_steps == 0:
                            step_idx = i // getattr(pipe.scheduler, "order", 1)
                            callback(step_idx, t, latents)

            if output_type == "latent":
                # Raw latents requested: skip VAE decoding and the safety checker.
                image = latents
                has_nsfw_concept = None
            else:
                image = pipe.vae.decode(latents / pipe.vae.config.scaling_factor, return_dict=False)[0]
                image, has_nsfw_concept = pipe.run_safety_checker(image, device, prompt_embeds.dtype)

            # Denormalize every image when no safety result exists; otherwise only
            # the images that were not flagged.
            do_denormalize = (
                [True] * image.shape[0]
                if has_nsfw_concept is None
                else [not has_nsfw for has_nsfw in has_nsfw_concept]
            )

            image = pipe.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

            # Offload all models
            pipe.maybe_free_model_hooks()

            # Package the result as a pipeline output object, or as a plain tuple
            # when `return_dict` is falsy.
            final = (
                StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
                if return_dict
                else (image, has_nsfw_concept)
            )

        # def save_image(image, base_path="_output"):
        #     # Ensure the output directory exists
        #     if not os.path.exists(base_path):
        #         os.makedirs(base_path)

        #     # Format the filename as five digits with leading zeros
        #     for i in range(100000):  # This allows for file names from 00000 to 99999
        #         filename = f"{i:05d}.png"
        #         filepath = os.path.join(base_path, filename)

        #         # Check if the file already exists
        #         if not os.path.exists(filepath):
        #             # Save the image if the file does not exist
        #             image.save(filepath)
        #             print(f"Image saved as {filename}")
        #             break
        #     else:
        #         print("All possible filenames are already used.")

        # display(Image.open(config.img_location).convert('RGB'))

        # Re-save the source frame into the result folder under its original file name.
        source_image_filename = os.path.basename(config.img_location)
        source_save_path = os.path.join(config.result_subfolder, source_image_filename)
        source_image = Image.open(config.img_location)
        source_image.save(source_save_path)
        print(f"Source image saved: {source_save_path}")

        # The edited frame reuses the source name with "_edit" inserted before the
        # final extension.
        edited_save_path = os.path.join(
            config.result_subfolder,
            "{}_edit.{}".format(*source_image_filename.rsplit('.', 1)),
        )
        final.images[0].save(edited_save_path)
        print(f"Edited image saved: {edited_save_path}")

        # Remember both paths so the frames can be assembled into videos later.
        total_source_frame_list.append(source_save_path)
        total_edited_frame_list.append(edited_save_path)


    def images_to_video_and_gif(image_list, base_filename, fps=15):
        """Encode a sequence of image files as an MP4 video and an animated GIF.

        Each readable image is resized to 512x512 and appended, in order, to both
        outputs. The files are written to ``config.result_subfolder`` as
        ``0000_<base_filename>.mp4`` and ``0000_<base_filename>.gif``. Paths that
        cannot be read are reported and skipped. If no frame could be read, the
        GIF is skipped with a warning instead of raising ``IndexError``.

        Args:
            image_list: Paths of the frames, in playback order.
            base_filename: Output name without extension; ``"0000_"`` is prepended.
            fps: Frames per second of the video; each GIF frame lasts
                ``1000 / fps`` milliseconds.
        """
        # Ensure the output directory exists
        os.makedirs(config.result_subfolder, exist_ok=True)

        base_filename = "0000_" + base_filename

        # Video and GIF size
        target_size = (512, 512)

        # Define the codec and create a VideoWriter object for MP4
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')  # 'mp4v' is an encoding format compatible with MP4
        video_path = os.path.join(config.result_subfolder, f"{base_filename}.mp4")
        video = cv2.VideoWriter(video_path, fourcc, fps, target_size)

        # Frames collected for the GIF
        gif_images = []

        try:
            for image_path in image_list:
                image = cv2.imread(image_path)
                if image is None:
                    print(f"Warning: Could not read image {image_path}")
                    continue

                resized_image = cv2.resize(image, target_size)
                video.write(resized_image)

                # OpenCV delivers BGR; convert to RGB before handing the frame to PIL.
                gif_images.append(Image.fromarray(cv2.cvtColor(resized_image, cv2.COLOR_BGR2RGB)))
        finally:
            # Always release the writer, even if a frame fails to process.
            video.release()

        print(f"Video saved: {video_path}")

        # Without at least one frame there is nothing to save as a GIF.
        if not gif_images:
            print(f"Warning: No readable frames for {base_filename}; GIF not created")
            return

        # Save the GIF
        gif_path = os.path.join(config.result_subfolder, f"{base_filename}.gif")
        gif_images[0].save(gif_path, save_all=True, append_images=gif_images[1:], optimize=False, duration=1000/fps, loop=0)
        print(f"GIF saved: {gif_path}")


    # Assemble the saved frames into videos/GIFs. The first `group_size` entries of
    # each list are left out (presumably the group/reference frames — confirm).
    # NOTE(review): the file names say "15fps" while the actual rate comes from
    # config.saved_video_fps — keep the two in sync.
    images_to_video_and_gif(total_source_frame_list[group_size:], 'video_source_15fps', config.saved_video_fps)
    images_to_video_and_gif(total_edited_frame_list[group_size:], 'video_edited_15fps', config.saved_video_fps)
    print("Videos and GIFs have been created and saved.")

    # When the save/load mode contains 'load', delete the per-image folders
    # total_step_<diffusion_steps>/<image>/<edit_instruction>.
    if 'load' in config.save_load_mode:
        load_folder_list = [os.path.join(f"total_step_{config.diffusion_steps}", one_img, config.edit_instruction) for one_img in config.load_images]
        for folder in load_folder_list:
            # Best-effort cleanup: a failure is reported and the remaining folders are still processed.
            try:
                shutil.rmtree(folder)
                print(f"Successfully deleted {folder}")
            except Exception as e:
                print(f"Failed to delete {folder}: {e}")