import argparse

import torch.utils
from llava.mm_utils import get_model_name_from_path, process_images
from llava.model.builder import load_pretrained_model
import requests
from PIL import Image
from io import BytesIO
import torch
from transformers import AutoImageProcessor, Dinov2Model
from torch.utils.data import Dataset, DataLoader
import copy
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, IGNORE_INDEX
from llava import conversation as conversation_lib
from llava.conversation import conv_templates, SeparatorStyle
import json
from packaging import version
import tokenizers
from accelerate.logging import get_logger
import torch.nn as nn
import os
import numpy as np
import logging
from torch.nn import functional as F
from torch.nn import CrossEntropyLoss
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration
from accelerate import Accelerator
from tqdm import tqdm
import math
from llava.mm_utils import get_anyres_image_grid_shape
from llava.model.llava_arch import unpad_image
import types
from typing import Union, List
from transformers import TextStreamer
import json
import pickle
from einops import rearrange
from einops.layers.torch import Rearrange
from accelerate.hooks import add_hook_to_module
from transformers.integrations import is_deepspeed_zero3_enabled
from torchinfo import summary

# True when the installed `tokenizers` package is version 0.14 or newer (note: the
# comparison is >=, despite the "GREATER_THAN" name). preprocess_v1 reads it to
# adjust per-round token counts for non-legacy tokenizers.
IS_TOKENIZER_GREATER_THAN_0_14 = version.parse(tokenizers.__version__) >= version.parse('0.14')
# Placeholder string for the personalized concept; tokenizer_image_token splits
# prompts on it and replaces each occurrence with the personalization token id.
PERSONALIZATION_TOKEN = "<sks>"
# Module-level logger obtained from accelerate.logging.get_logger.
logger = get_logger(__name__)


def _personalization_intro_ids(tokenizer, personalization_token_id, soft_prompt_ids, offset):
    """Token ids for the "<sks> is <soft_1> ... <soft_n>." introduction sentence."""
    ids = [personalization_token_id]
    ids.extend(tokenizer('is').input_ids[offset:])
    ids.extend(soft_prompt_ids)
    ids.append(tokenizer.convert_tokens_to_ids('.'))
    return ids


def _tokenize_around_personalization_token(text, tokenizer, personalization_token, personalization_token_id, offset):
    """Tokenize ``text``, replacing each ``personalization_token`` occurrence by its id.

    Every piece between occurrences is stripped of surrounding spaces and tokenized
    on its own, dropping the first ``offset`` ids of each piece. Empty pieces (text
    ending with the token, or two adjacent tokens) contribute no ids instead of
    raising ``IndexError``.
    """
    ids = []
    pieces = text.split(personalization_token)
    last = len(pieces) - 1
    for i, piece in enumerate(pieces):
        piece = piece.strip(' ')
        if piece:
            ids.extend(tokenizer(piece).input_ids[offset:])
        if i < last:
            ids.append(personalization_token_id)
    return ids


def tokenizer_image_token(prompt, tokenizer, personalization_token_id, image_token_index=IMAGE_TOKEN_INDEX, personalization_token=PERSONALIZATION_TOKEN, num_soft_tokens=16, return_tensors="pt"):
    """Tokenize a personalized prompt, splicing in special and soft-prompt token ids.

    The prompt must contain the literal introduction
    ``f"{personalization_token} is <soft token>. "``. That sentence is emitted as
    ``personalization_token_id``, the ids of ``"is"``, the ``num_soft_tokens``
    consecutive ids ``personalization_token_id + 1 .. personalization_token_id +
    num_soft_tokens``, and the id of ``"."``. Every other occurrence of
    ``personalization_token`` in the user/assistant text becomes
    ``personalization_token_id``.

    If ``"<image>"`` occurs, ``image_token_index`` is placed at its first
    occurrence. Only the text between the first and the second ``"<image>"`` (if
    any) is tokenized after it, so a single image is assumed.

    Args:
        prompt: Fully rendered conversation prompt.
        tokenizer: HF-style tokenizer (called, plus ``bos_token_id`` and
            ``convert_tokens_to_ids``).
        personalization_token_id: Id used for ``personalization_token``; the soft
            prompt ids follow it consecutively.
        image_token_index: Id inserted at the ``<image>`` position.
        personalization_token: Placeholder string for the personalized concept.
        num_soft_tokens: Number of soft-prompt ids to insert.
        return_tensors: ``"pt"`` for a ``torch.long`` tensor, ``None`` for a list.

    Returns:
        1-D ``torch.LongTensor`` or ``list[int]`` of token ids.

    Raises:
        IndexError: if the introduction sentence is missing from ``prompt``.
        ValueError: for an unsupported ``return_tensors`` value.
    """
    soft_prompt_ids = [personalization_token_id + i for i in range(1, num_soft_tokens + 1)]
    intro_marker = f'{personalization_token} is <soft token>. '

    if '<image>' in prompt:
        sep_prompt_list = prompt.split('<image>')
        head = sep_prompt_list[0]
        input_ids = []
        offset = 0
        # If the tokenizer prepends BOS, keep one BOS and strip it from every
        # piece tokenized afterwards.
        head_ids = tokenizer(head).input_ids
        if len(head_ids) > 0 and head_ids[0] == tokenizer.bos_token_id:
            offset = 1
            input_ids.append(head_ids[0])
        # A prompt consists of: system prompt, "<sks> is <soft prompt>.", then the
        # user query (and answer). The system prompt is everything before <sks>.
        sys_prompt = head.split(personalization_token)[0].strip()
        input_ids.extend(tokenizer(sys_prompt).input_ids[offset:])
        input_ids.extend(_personalization_intro_ids(tokenizer, personalization_token_id, soft_prompt_ids, offset))
        # User-turn prefix between the introduction sentence and <image>.
        user_prompt = head.split(intro_marker)[1]
        input_ids.extend(tokenizer(user_prompt).input_ids[offset:])
        input_ids.append(image_token_index)
        input_ids.extend(_tokenize_around_personalization_token(
            sep_prompt_list[1], tokenizer, personalization_token, personalization_token_id, offset))
    else:
        offset = 1
        # Keep the first id of the full prompt once (presumably BOS).
        input_ids = [tokenizer(prompt).input_ids[0]]
        prompt_parts = prompt.split(f' {intro_marker}')
        sys_prompt = prompt_parts[0]
        input_ids.extend(tokenizer(sys_prompt).input_ids[offset:])
        input_ids.extend(_personalization_intro_ids(tokenizer, personalization_token_id, soft_prompt_ids, offset))
        user_prompt = prompt_parts[1]
        input_ids.extend(_tokenize_around_personalization_token(
            user_prompt, tokenizer, personalization_token, personalization_token_id, offset))

    if return_tensors is not None:
        if return_tensors == "pt":
            return torch.tensor(input_ids, dtype=torch.long)
        raise ValueError(f'Unsupported tensor type: {return_tensors}')
    return input_ids

def load_image(image_file, timeout=None):
    """Load an image from a local path or an http(s) URL as an RGB PIL image.

    Args:
        image_file: Local filesystem path, or a URL starting with ``http://`` or
            ``https://``.
        timeout: Optional ``requests`` timeout in seconds for URL downloads;
            ``None`` (the default) waits indefinitely, as before.

    Returns:
        ``PIL.Image.Image`` in RGB mode.

    Raises:
        requests.HTTPError: if the server answers with a 4xx/5xx status.
    """
    if image_file.startswith(('http://', 'https://')):
        response = requests.get(image_file, timeout=timeout)
        # Surface the HTTP failure instead of a confusing "cannot identify image" error.
        response.raise_for_status()
        return Image.open(BytesIO(response.content)).convert('RGB')
    # Close the underlying file handle once the RGB copy has been made.
    with Image.open(image_file) as image:
        return image.convert('RGB')

def parse_text_files(txt_file):
    """Read a UTF-8 text file and return its lines without newline characters.

    A single trailing newline does not produce an extra empty entry. This keeps
    question/answer files aligned line-by-line and prevents an empty string from
    being sampled as a question or answer. Interior empty lines are preserved so
    that line indices still correspond across paired files.

    Args:
        txt_file: Path to the text file.

    Returns:
        list[str]: one entry per line.
    """
    with open(txt_file, "r", encoding="utf-8") as f:
        data = f.read()

    lines = data.split("\n")
    if lines and lines[-1] == "":
        lines.pop()
    return lines

def preprocess_v1(
    sources,
    tokenizer,
    personalization_id,
    num_soft_tokens,
    has_image=False
):
    """Render conversations with the "personalized" template, tokenize them and build labels.

    Args:
        sources: List of conversations; each conversation is a list of
            ``{"from": "human" | "gpt", "value": str}`` dicts that must alternate
            human/gpt (a leading non-human turn is dropped).
        tokenizer: HF-style tokenizer; it is called, and its ``pad_token_id``,
            ``model_max_length`` and ``legacy`` attributes are read.
        personalization_id: Personalization token id forwarded to
            ``tokenizer_image_token``.
        num_soft_tokens: Number of soft-prompt ids forwarded to
            ``tokenizer_image_token``.
        has_image: If True, tokenize with ``tokenizer_image_token``; otherwise use
            the plain tokenizer (padded to the longest prompt and truncated to
            ``tokenizer.model_max_length``).

    Returns:
        dict with ``input_ids`` and ``labels`` tensors of shape
        (len(sources), seq_len). In ``labels`` the first token, each round's
        instruction part and everything after the last processed round are set to
        ``IGNORE_INDEX``. If the recomputed length disagrees with the non-pad
        token count, the whole sample is set to ``IGNORE_INDEX``.
    """
    conv = conv_templates["personalized"].copy()
    roles = {"human": conv.roles[0], "gpt": conv.roles[1]}

    # Apply prompt templates
    conversations = []
    for i, source in enumerate(sources):
        if roles[source[0]["from"]] != conv.roles[0]:
            # Skip the first one if it is not from human
            source = source[1:]

        conv.messages = []
        for j, sentence in enumerate(source):
            role = roles[sentence["from"]]
            # Turns must strictly alternate user / assistant.
            assert role == conv.roles[j % 2], f"{i}"
            conv.append_message(role, sentence["value"])
        conversations.append(conv.get_prompt())

    # Tokenize conversations
    if has_image:
        # torch.stack requires every prompt to tokenize to the same length; the
        # callers in this file pass a single conversation.
        input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer,personalization_id,num_soft_tokens=num_soft_tokens, return_tensors='pt') for prompt in conversations], dim=0)
    else:
        input_ids = tokenizer(
            conversations,
            return_tensors="pt",
            padding="longest",
            max_length=tokenizer.model_max_length,
            truncation=True,
        ).input_ids
    targets = input_ids.clone()

    # Only the TWO separator style is handled by the masking below.
    assert conv.sep_style == conversation_lib.SeparatorStyle.TWO

    # Mask targets
    # `sep` marks where the assistant reply starts inside one round.
    sep = conv.sep + conv.roles[1] + ": "
    for conversation, target in zip(conversations, targets):
        # Number of non-padding tokens, used below as a consistency check.
        total_len = int(target.ne(tokenizer.pad_token_id).sum())

        rounds = conversation.split(conv.sep2)
        # Skip the first token (presumably BOS).
        cur_len = 1
        target[:cur_len] = IGNORE_INDEX
        for i, rou in enumerate(rounds):
            if rou == "":
                break

            parts = rou.split(sep)
            # A round without exactly one assistant separator ends the masking.
            if len(parts) != 2:
                break
            parts[0] += sep

            if has_image:
                round_len = len(tokenizer_image_token(rou, tokenizer, personalization_id, num_soft_tokens=num_soft_tokens))
                # NOTE(review): -1 here vs. -2 in the text-only branch; presumably
                # because tokenizer_image_token strips trailing spaces, so only the
                # leading BOS needs discounting. TODO confirm.
                instruction_len = len(tokenizer_image_token(parts[0], tokenizer, personalization_id, num_soft_tokens=num_soft_tokens)) - 1
            else:
                round_len = len(tokenizer(rou).input_ids)
                instruction_len = len(tokenizer(parts[0]).input_ids) - 2

            # Non-legacy tokenizers (tokenizers >= 0.14) yield one token fewer per
            # round after the first.
            if i != 0 and not tokenizer.legacy and IS_TOKENIZER_GREATER_THAN_0_14:
                round_len -= 1
                instruction_len -= 1

            # Do not supervise the user/instruction part of the round.
            target[cur_len : cur_len + instruction_len] = IGNORE_INDEX

            cur_len += round_len
        # Everything after the last processed round (e.g. padding) is ignored.
        target[cur_len:] = IGNORE_INDEX

        if cur_len < tokenizer.model_max_length:
            if cur_len != total_len:
                # Length bookkeeping disagrees: drop supervision for this sample.
                target[:] = IGNORE_INDEX
                print(
                    f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
                    f" (ignored)"
                )

    return dict(
        input_ids=input_ids,
        labels=targets,
    )


def parse_arguments(argv=None):
    """Parse the command-line options of the personalization script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, exactly as before.

    Returns:
        argparse.Namespace with one attribute per option; dashes become
        underscores (e.g. ``--model-path`` -> ``args.model_path``).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--output_dir", type=str)
    parser.add_argument("--logging_dir", type=str)
    parser.add_argument("--gradient_accumulation_steps", type=int)
    parser.add_argument("--mixed_precision", type=str, default="no")
    parser.add_argument("--report_to", type=str, default="tensorboard")
    parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
    parser.add_argument("--model-base", type=str, default=None)
    parser.add_argument("--ref_img", type=str)
    parser.add_argument("--tgt_img", type=str)
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--conv-mode", type=str, default=None)
    parser.add_argument("--temperature", type=float, default=0.2)
    parser.add_argument("--max-new-tokens", type=int, default=512)
    parser.add_argument("--load-8bit", action="store_true")
    parser.add_argument("--load-4bit", action="store_true")
    parser.add_argument("--debug", action="store_true")
    parser.add_argument("--prompt", type=str, required=True)
    parser.add_argument("--num_train_steps", type=int, default=200)
    parser.add_argument("--img_dir", type=str)
    # Question/answer text files (one entry per line, aligned by line index).
    parser.add_argument("--pos_ques", type=str, default="personalization_dataset/pos_question.txt")
    parser.add_argument("--pos_ans", type=str, default="personalization_dataset/pos_answer.txt")
    parser.add_argument("--neg_ques", type=str, default="personalization_dataset/neg_question.txt")
    parser.add_argument("--neg_ans", type=str, default="personalization_dataset/neg_answer.txt")
    parser.add_argument("--resume", type=str)
    parser.add_argument("--importance_weight", type=float, default=1.0)
    parser.add_argument("--other_weights", type=float, default=0.1)
    parser.add_argument("--task", type=str)
    parser.add_argument("--infer_ref_img", type=str)
    parser.add_argument("--infer_query_img", type=str)
    parser.add_argument("--checkpoint_path", type=str)
    parser.add_argument("--pos_prob", type=float, default=0.5)
    parser.add_argument("--data_eval_file", type=str)
    parser.add_argument("--use_features", type=str, choices=["dino", "face"], default="dino")
    parser.add_argument("--face_embedding_dir", type=str)
    parser.add_argument("--yes_no_ratio", type=float, default=0.5)
    parser.add_argument("--question", type=str)
    parser.add_argument("--num_query", type=int, default=16)
    parser.add_argument("--ref_img_list", type=str)
    parser.add_argument("--query_img_path", type=str)
    return parser.parse_args(argv)

class EncoderBasedDataset(Dataset):
    """Personalization dataset mixing yes/no recognition samples and caption QA samples.

    Each item is built around one reference ("identity") image:

    * with probability ``yes_no_ratio`` a yes/no question is asked about a query
      image showing either the same identity (probability ``pos_prob``) or a
      uniformly chosen different identity;
    * otherwise a question/answer pair is drawn from the caption JSON file of the
      reference image and an all-zero query image is returned.
    """

    def __init__(self,
                 identity_dict_file,
                 pos_ques_file,
                 pos_ans_file,
                 neg_ques_file,
                 neg_ans_file,
                 caption_folder,
                 image_path,
                 query_img_path,
                 ref_img_processor,
                 query_img_processor,
                 tokenizer,
                 personalized_id,
                 model_config,
                 pos_prob,
                 yes_no_ratio,
                 num_soft_tokens,
                 max_identities=20000,
                 ):
        """
        Args:
            identity_dict_file: Pickle of ``{reference image file: [query image files]}``;
                keys are joined with ``image_path``, values with ``query_img_path``.
            pos_ques_file, pos_ans_file: Line-aligned positive questions/answers.
            neg_ques_file, neg_ans_file: Line-aligned negative questions/answers.
            caption_folder: Folder holding ``ques_ans_<image stem>.json`` files.
            image_path: Folder of reference images.
            query_img_path: Folder of query images.
            ref_img_processor: Image processor applied to the reference image.
            query_img_processor: Image processor passed to ``process_images``.
            tokenizer: Tokenizer used for prompts and answer-token lookup.
            personalized_id: Personalization token id.
            model_config: Model config passed to ``process_images``.
            pos_prob: Probability that a yes/no sample is positive.
            yes_no_ratio: Probability that a sample is yes/no rather than caption QA.
            num_soft_tokens: Number of soft-prompt tokens.
            max_identities: Use only the first ``max_identities`` keys of the
                identity dict (default 20000, as before; ``None`` keeps all).
        """
        super().__init__()
        self.pos_ques = parse_text_files(pos_ques_file)
        self.pos_ans = parse_text_files(pos_ans_file)
        self.neg_ques = parse_text_files(neg_ques_file)
        self.neg_ans = parse_text_files(neg_ans_file)
        # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
        with open(identity_dict_file, "rb") as f:
            self.identity_dict = pickle.load(f)
        self.identity_imgs = list(self.identity_dict.keys())[:max_identities]
        self.caption_folder = caption_folder
        self.image_path = image_path
        self.query_img_path = query_img_path
        self.ref_img_processor = ref_img_processor
        self.query_img_processor = query_img_processor
        self.tokenizer = tokenizer
        self.personalized_id = personalized_id
        self.model_config = model_config

        # Ids used to locate the answer word inside yes/no samples.
        self.yes_ids = self._answer_token_ids(["Yes"])
        self.no_ids = self._answer_token_ids(["No"])

        self.pos_prob = pos_prob
        self.yes_no_ratio = yes_no_ratio
        self.num_soft_tokens = num_soft_tokens

    def _answer_token_ids(self, prompts):
        """Concatenated token ids of ``prompts``, each without its first id (presumably BOS)."""
        ids = []
        for prompt in prompts:
            ids.extend(self.tokenizer(prompt).input_ids[1:])
        return ids

    @staticmethod
    def _answer_mask(input_ids, token_ids):
        """Bool mask that is True wherever ``input_ids`` equals any id in ``token_ids``."""
        mask = torch.zeros_like(input_ids, dtype=torch.bool)
        for token_id in token_ids:
            mask = torch.logical_or(mask, input_ids == token_id)
        return mask

    @staticmethod
    def _open_rgb(path):
        """Open an image file as RGB (so grayscale/RGBA inputs work) and close the file."""
        with Image.open(path) as img:
            return img.convert("RGB")

    def _sample_other_index(self, idx):
        """Draw an identity index uniformly from every index except ``idx`` in O(1)."""
        other = np.random.randint(len(self.identity_imgs) - 1)
        return other + 1 if other >= idx else other

    def __len__(self):
        return len(self.identity_imgs)

    def __getitem__(self, idx):
        """Build one training sample for reference image ``idx``.

        Returns:
            dict with 1-D ``input_ids``, ``labels`` and bool ``mask_weights`` (True
            on the "Yes"/"No" answer tokens of yes/no samples), ``ref_img`` and
            ``query_img`` pixel tensors, and one-element bool tensors ``yes_no`` and
            ``yes_answer``.
        """
        prob = np.random.rand()
        ref_img = self._open_rgb(os.path.join(self.image_path, self.identity_imgs[idx]))
        ref_img = self.ref_img_processor(ref_img, return_tensors="pt")["pixel_values"]
        if isinstance(ref_img, np.ndarray):
            ref_img = torch.from_numpy(ref_img)
        if prob < self.yes_no_ratio:
            # Yes/no recognition question.
            yes_no = True
            if np.random.rand() < self.pos_prob:
                # Positive: a query image of the same identity.
                ind = np.random.choice(np.arange(len(self.pos_ques)))
                question = self.pos_ques[ind]
                answer = self.pos_ans[ind]
                query_key = self.identity_imgs[idx]
                yes_answer = True
            else:
                # Negative: a query image of a uniformly chosen different identity.
                ind = np.random.choice(np.arange(len(self.neg_ques)))
                question = self.neg_ques[ind]
                answer = self.neg_ans[ind]
                query_key = self.identity_imgs[self._sample_other_index(idx)]
                yes_answer = False
            query_img_file = np.random.choice(self.identity_dict[query_key])
            query_img = self._open_rgb(os.path.join(self.query_img_path, query_img_file))
            sources = [[
                {"from": "human", "value": f"<image>\n{question}"},
                {"from": "gpt", "value": f"{answer}"},
            ]]
            data_dict = preprocess_v1(
                sources,
                self.tokenizer,
                self.personalized_id,
                num_soft_tokens=self.num_soft_tokens,
                has_image=True
            )
            answer_ids = self.yes_ids if yes_answer else self.no_ids
            mask_weights = self._answer_mask(data_dict["input_ids"], answer_ids)
            query_img = process_images([query_img], self.query_img_processor, self.model_config)
        else:
            # Caption-based question answering about the reference image.
            yes_no = False
            yes_answer = False
            img_idx = self.identity_imgs[idx].split('.')[0]
            ques_ans_file = os.path.join(self.caption_folder, f"ques_ans_{img_idx}.json")
            with open(ques_ans_file, "r") as f:
                ques_list = json.load(f)["conversations"]
            ques_ans = np.random.choice(ques_list)
            sources = [[
                {"from": "human", "value": ques_ans["USER"]},
                {"from": "gpt", "value": ques_ans["gpt"]},
            ]]
            data_dict = preprocess_v1(
                sources,
                self.tokenizer,
                self.personalized_id,
                num_soft_tokens=self.num_soft_tokens,
                has_image=True
            )
            # All-zero placeholder query image (presumably ignored downstream).
            query_img = torch.zeros((1, 3, 336, 336))
            # Same dtype as the yes/no branch so batches collate consistently.
            mask_weights = torch.zeros_like(data_dict["input_ids"], dtype=torch.bool)
        return {
            "input_ids": data_dict["input_ids"][0],
            "labels": data_dict["labels"][0],
            "ref_img": ref_img[0],
            "query_img": query_img[0],
            "mask_weights": mask_weights[0],
            "yes_no": torch.tensor([yes_no]),
            "yes_answer": torch.tensor([yes_answer]),
        }
    
class EncoderBasedDatasetWithHardNegativeSamples(Dataset):
    """Like ``EncoderBasedDataset`` but negatives come from precomputed candidate lists.

    For a negative yes/no sample, the query identity is drawn from the
    ``"hard_samples"`` list of the reference image with probability
    ``hard_negative_prob``, and from its ``"soft_samples"`` list otherwise. The
    lists are loaded from ``hard_negative_file``.
    """

    def __init__(self,
                 identity_dict_file,
                 pos_ques_file,
                 pos_ans_file,
                 neg_ques_file,
                 neg_ans_file,
                 caption_folder,
                 image_path,
                 query_img_path,
                 ref_img_processor,
                 query_img_processor,
                 tokenizer,
                 personalized_id,
                 model_config,
                 pos_prob,
                 yes_no_ratio,
                 num_soft_tokens,
                 max_identities=20000,
                 hard_negative_file="hard_negative_samples.pkl",
                 hard_negative_prob=0.2,
                 ):
        """
        Args:
            identity_dict_file: Pickle of ``{reference image file: [query image files]}``.
            pos_ques_file, pos_ans_file: Line-aligned positive questions/answers.
            neg_ques_file, neg_ans_file: Line-aligned negative questions/answers.
            caption_folder: Folder holding ``ques_ans_<image stem>.json`` files.
            image_path: Folder of reference images.
            query_img_path: Folder of query images.
            ref_img_processor: Image processor applied to the reference image.
            query_img_processor: Image processor passed to ``process_images``.
            tokenizer: Tokenizer used for prompts and answer-token lookup.
            personalized_id: Personalization token id.
            model_config: Model config passed to ``process_images``.
            pos_prob: Probability that a yes/no sample is positive.
            yes_no_ratio: Probability that a sample is yes/no rather than caption QA.
            num_soft_tokens: Number of soft-prompt tokens.
            max_identities: Use only the first ``max_identities`` identity keys
                (default 20000, as before; ``None`` keeps all).
            hard_negative_file: Pickle mapping each reference image to a dict with
                ``"hard_samples"`` / ``"soft_samples"`` lists of identity keys.
            hard_negative_prob: Probability of drawing from ``"hard_samples"``.
        """
        super().__init__()
        self.pos_ques = parse_text_files(pos_ques_file)
        self.pos_ans = parse_text_files(pos_ans_file)
        self.neg_ques = parse_text_files(neg_ques_file)
        self.neg_ans = parse_text_files(neg_ans_file)
        # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
        with open(identity_dict_file, "rb") as f:
            self.identity_dict = pickle.load(f)
        self.identity_imgs = list(self.identity_dict.keys())[:max_identities]
        self.caption_folder = caption_folder
        self.image_path = image_path
        self.query_img_path = query_img_path
        self.ref_img_processor = ref_img_processor
        self.query_img_processor = query_img_processor
        self.tokenizer = tokenizer
        self.personalized_id = personalized_id
        self.model_config = model_config

        # Ids used to locate the answer word inside yes/no samples.
        self.yes_ids = self._answer_token_ids(["Yes"])
        self.no_ids = self._answer_token_ids(["No"])

        self.pos_prob = pos_prob
        with open(hard_negative_file, "rb") as f:
            self.hard_negative_samples = pickle.load(f)
        self.hard_negative_prob = hard_negative_prob
        self.yes_no_ratio = yes_no_ratio
        self.num_soft_tokens = num_soft_tokens

    def _answer_token_ids(self, prompts):
        """Concatenated token ids of ``prompts``, each without its first id (presumably BOS)."""
        ids = []
        for prompt in prompts:
            ids.extend(self.tokenizer(prompt).input_ids[1:])
        return ids

    @staticmethod
    def _answer_mask(input_ids, token_ids):
        """Bool mask that is True wherever ``input_ids`` equals any id in ``token_ids``."""
        mask = torch.zeros_like(input_ids, dtype=torch.bool)
        for token_id in token_ids:
            mask = torch.logical_or(mask, input_ids == token_id)
        return mask

    @staticmethod
    def _open_rgb(path):
        """Open an image file as RGB (so grayscale/RGBA inputs work) and close the file."""
        with Image.open(path) as img:
            return img.convert("RGB")

    def _pick_negative_identity(self, idx):
        """Pick the identity key of a negative sample for reference image ``idx``."""
        candidates = self.hard_negative_samples[self.identity_imgs[idx]]
        if np.random.rand() < self.hard_negative_prob:
            return np.random.choice(candidates["hard_samples"])
        return np.random.choice(candidates["soft_samples"])

    def __len__(self):
        return len(self.identity_imgs)

    def __getitem__(self, idx):
        """Build one training sample for reference image ``idx``.

        Returns:
            dict with 1-D ``input_ids``, ``labels`` and bool ``mask_weights``,
            ``ref_img`` and ``query_img`` pixel tensors, and one-element bool tensors
            ``yes_no`` and ``yes_answer``. ``yes_answer`` is required by
            ``collate_fn``, and is False for caption samples.
        """
        prob = np.random.rand()
        ref_img = self._open_rgb(os.path.join(self.image_path, self.identity_imgs[idx]))
        ref_img = self.ref_img_processor(ref_img, return_tensors="pt")["pixel_values"]
        if isinstance(ref_img, np.ndarray):
            ref_img = torch.from_numpy(ref_img)
        if prob < self.yes_no_ratio:
            # Yes/no recognition question.
            yes_no = True
            if np.random.rand() < self.pos_prob:
                # Positive: a query image of the same identity.
                ind = np.random.choice(np.arange(len(self.pos_ques)))
                question = self.pos_ques[ind]
                answer = self.pos_ans[ind]
                query_key = self.identity_imgs[idx]
                yes_answer = True
            else:
                # Negative: a query image of a hard or soft negative identity.
                ind = np.random.choice(np.arange(len(self.neg_ques)))
                question = self.neg_ques[ind]
                answer = self.neg_ans[ind]
                query_key = self._pick_negative_identity(idx)
                yes_answer = False
            query_img_file = np.random.choice(self.identity_dict[query_key])
            query_img = self._open_rgb(os.path.join(self.query_img_path, query_img_file))
            sources = [[
                {"from": "human", "value": f"<image>\n{question}"},
                {"from": "gpt", "value": f"{answer}"},
            ]]
            data_dict = preprocess_v1(
                sources,
                self.tokenizer,
                self.personalized_id,
                num_soft_tokens=self.num_soft_tokens,
                has_image=True
            )
            answer_ids = self.yes_ids if yes_answer else self.no_ids
            mask_weights = self._answer_mask(data_dict["input_ids"], answer_ids)
            query_img = process_images([query_img], self.query_img_processor, self.model_config)
        else:
            # Caption-based question answering about the reference image.
            yes_no = False
            yes_answer = False
            img_idx = self.identity_imgs[idx].split('.')[0]
            ques_ans_file = os.path.join(self.caption_folder, f"ques_ans_{img_idx}.json")
            with open(ques_ans_file, "r") as f:
                ques_list = json.load(f)["conversations"]
            ques_ans = np.random.choice(ques_list)
            sources = [[
                {"from": "human", "value": ques_ans["USER"]},
                {"from": "gpt", "value": ques_ans["gpt"]},
            ]]
            data_dict = preprocess_v1(
                sources,
                self.tokenizer,
                self.personalized_id,
                num_soft_tokens=self.num_soft_tokens,
                has_image=True
            )
            # All-zero placeholder query image (presumably ignored downstream).
            query_img = torch.zeros((1, 3, 336, 336))
            # Same dtype as the yes/no branch so batches collate consistently.
            mask_weights = torch.zeros_like(data_dict["input_ids"], dtype=torch.bool)
        return {
            "input_ids": data_dict["input_ids"][0],
            "labels": data_dict["labels"][0],
            "ref_img": ref_img[0],
            "query_img": query_img[0],
            "mask_weights": mask_weights[0],
            "yes_no": torch.tensor([yes_no]),
            "yes_answer": torch.tensor([yes_answer]),
        }

def save_checkpoint(accelerator, model, args, step):
    """Save the unwrapped model's state dict to ``<args.output_dir>/checkpoint_<step>.ckpt``.

    ``accelerator.save`` is used instead of ``torch.save`` so the file is written
    once rather than by every process of a distributed run. The output directory
    is created if missing.

    Args:
        accelerator: The ``accelerate.Accelerator`` driving training.
        model: The (possibly wrapped) model to checkpoint.
        args: Namespace providing ``output_dir``.
        step: Training step used in the file name.

    Returns:
        The checkpoint file path.
    """
    model_state_dict = accelerator.unwrap_model(model).state_dict()
    os.makedirs(args.output_dir, exist_ok=True)
    ckpt_path = os.path.join(args.output_dir, f"checkpoint_{step}.ckpt")
    accelerator.save(model_state_dict, ckpt_path)
    return ckpt_path

def collate_fn(batch, pad_token_id=0, model_max_length=2048):
    """Collate dataset samples into a padded training batch.

    Args:
        batch: List of sample dicts with 1-D ``input_ids``, ``labels`` and
            ``mask_weights`` tensors, plus ``query_img``, ``yes_no`` and
            ``yes_answer``. ``ref_img`` / ``ref_embeddings`` are stacked when the
            first sample has them.
        pad_token_id: Padding id for ``input_ids``; positions equal to it are
            False in ``attention_mask``.
        model_max_length: Length cap applied to every per-token tensor.

    Returns:
        dict of batched tensors. ``input_ids``, ``labels``, ``attention_mask`` and
        ``mask_weights`` share the same (batch, seq_len) shape.
    """
    input_ids = torch.nn.utils.rnn.pad_sequence(
        [sample["input_ids"] for sample in batch],
        batch_first=True,
        padding_value=pad_token_id,
    )
    labels = torch.nn.utils.rnn.pad_sequence(
        [sample["labels"] for sample in batch],
        batch_first=True,
        padding_value=IGNORE_INDEX,
    )
    mask_weights = torch.nn.utils.rnn.pad_sequence(
        [sample["mask_weights"] for sample in batch],
        batch_first=True,
        padding_value=0,
    )

    # Truncate every per-token tensor identically so mask_weights stays aligned
    # with input_ids / attention_mask for over-long sequences.
    input_ids = input_ids[:, :model_max_length]
    labels = labels[:, :model_max_length]
    mask_weights = mask_weights[:, :model_max_length]
    attention_mask = input_ids.ne(pad_token_id)
    batch_data = dict(
        input_ids=input_ids,
        labels=labels,
        attention_mask=attention_mask,
        mask_weights=mask_weights
    )
    # Optional reference inputs, present depending on the dataset.
    for key in ("ref_img", "ref_embeddings"):
        if key in batch[0]:
            batch_data[key] = torch.stack([sample[key] for sample in batch])

    batch_data["query_img"] = torch.stack([sample["query_img"] for sample in batch])
    batch_data["yes_no"] = torch.stack([sample["yes_no"] for sample in batch])
    batch_data["yes_answer"] = torch.stack([sample["yes_answer"] for sample in batch])

    return batch_data

def prepare_inputs_labels_for_multimodal(
    self, input_ids,
    position_ids, attention_mask,
    past_key_values, labels, images, personalized_id,
    learnable_features, image_sizes=None, mask_weights=None
):
    vision_tower = self.get_vision_tower()
    if vision_tower is None or images is None or input_ids.shape[1] == 1:
        return input_ids, position_ids, attention_mask, past_key_values, None, labels
    
    if type(images) is list or images.ndim == 5:
        if type(images) is list:
            images = [x.unsqueeze(0) if x.ndim == 3 else x for x in images]
        concat_images = torch.cat([image for image in images], dim=0)
        image_features = self.encode_images(concat_images)
        split_sizes = [image.shape[0] for image in images]
        image_features = torch.split(image_features, split_sizes, dim=0)
        mm_patch_merge_type = getattr(self.config, 'mm_patch_merge_type', 'flat')
        image_aspect_ratio = getattr(self.config, 'image_aspect_ratio', 'square')
        if mm_patch_merge_type == 'flat':
            image_features = [x.flatten(0, 1) for x in image_features]
        elif mm_patch_merge_type.startswith('spatial'):
            new_image_features = []
            for image_idx, image_feature in enumerate(image_features):
                if image_feature.shape[0] > 1:
                    base_image_feature = image_feature[0]
                    image_feature = image_feature[1:]
                    height = width = self.get_vision_tower().num_patches_per_side
                    assert height * width == base_image_feature.shape[0]
                    if image_aspect_ratio == 'anyres':
                        num_patch_width, num_patch_height = get_anyres_image_grid_shape(image_sizes[image_idx], self.config.image_grid_pinpoints, self.get_vision_tower().config.image_size)
                        image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
                    else:
                        raise NotImplementedError
                    if 'unpad' in mm_patch_merge_type:
                        image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
                        image_feature = image_feature.flatten(1, 2).flatten(2, 3)
                        image_feature = unpad_image(image_feature, image_sizes[image_idx])
                        image_feature = torch.cat((
                            image_feature,
                            self.model.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.device)
                        ), dim=-1)
                        image_feature = image_feature.flatten(1, 2).transpose(0, 1)
                    else:
                        image_feature = image_feature.permute(0, 2, 1, 3, 4).contiguous()
                        image_feature = image_feature.flatten(0, 3)
                    image_feature = torch.cat((base_image_feature, image_feature), dim=0)
                else:
                    image_feature = image_feature[0]
                    if 'unpad' in mm_patch_merge_type:
                        image_feature = torch.cat((
                            image_feature,
                            self.model.image_newline[None].to(image_feature.device)
                        ), dim=0)
                new_image_features.append(image_feature)
            image_features = new_image_features
        else:
            raise ValueError(f"Unexpected mm_patch_merge_type: {self.config.mm_patch_merge_type}")
    else:
        image_features = self.encode_images(images)
    
    if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False):
        raise NotImplementedError
    _labels = labels
    _position_ids = position_ids
    _attention_mask = attention_mask
    if attention_mask is None:
        attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
    else:
        attention_mask = attention_mask.bool()
    
    if position_ids is None:
        position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
    if labels is None:
        labels = torch.full_like(input_ids, IGNORE_INDEX)
    
    if mask_weights is not None:
        mask_weights_list = []
        mask_weights = [cur_mask_weights[cur_attention_mask] for cur_mask_weights, cur_attention_mask in zip(mask_weights, attention_mask)]
    
    _input_ids = input_ids
    input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)]
    labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)]
    new_input_embeds = []
    new_labels = []
    cur_image_idx = 0
    for batch_idx, cur_input_ids in enumerate(input_ids):
        num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
        if num_images == 0:
            cur_image_features = image_features[cur_image_idx]
            personalized_token_indices = [-1] + torch.where(cur_input_ids == personalized_id)[0].tolist() + [len(cur_input_ids)]
            token_embeddings = []
            for i in range(len(personalized_token_indices) - 1):
                token_ids = cur_input_ids[personalized_token_indices[i] + 1: personalized_token_indices[i+1]]
                if (token_ids == personalized_id + 1).any():
                    pos = torch.where(token_ids == personalized_id + 1)[0]
                    token_embeddings.append(self.get_model().embed_tokens(token_ids[:pos]))
                    token_embeddings.append(learnable_features[batch_idx, 1:])
                    token_embeddings.append(self.get_model().embed_tokens(token_ids[pos+len(learnable_features[batch_idx, 1:]):]))
                else:
                    token_embeddings.append(self.get_model().embed_tokens(token_ids))
                
                if i < len(personalized_token_indices) - 2:
                    token_embeddings.append(learnable_features[batch_idx, 0].unsqueeze(0))
            
            cur_input_embeds = torch.cat(token_embeddings)
            new_input_embeds.append(cur_input_embeds)
            new_labels.append(labels[batch_idx])
            if mask_weights is not None:
                cur_mask_weights = mask_weights[batch_idx]
                mask_weights_list.append(cur_mask_weights)
            cur_image_idx += 1
            continue
        image_token_indices = [-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]]
        cur_input_ids_noim = []
        cur_labels = labels[batch_idx]
        cur_labels_noim = []
        for i in range(len(image_token_indices) - 1):
            cur_input_ids_noim.append(cur_input_ids[image_token_indices[i]+1:image_token_indices[i+1]])
            cur_labels_noim.append(cur_labels[image_token_indices[i]+1:image_token_indices[i+1]])
        # Get sks features
        split_sizes = [x.shape[0] for x in cur_labels_noim]
        cur_input_ids_noim = torch.cat(cur_input_ids_noim, dim=0)
        personalized_token_indices = [-1] + torch.where(cur_input_ids_noim == personalized_id)[0].tolist() + [len(cur_input_ids_noim)]
        token_embeddings = []
        for i in range(len(personalized_token_indices) - 1):
            token_ids = cur_input_ids_noim[personalized_token_indices[i] + 1 : personalized_token_indices[i+1]]
            if (token_ids == personalized_id + 1).any():
                pos = torch.where(token_ids == personalized_id + 1)[0]
                token_embeddings.append(self.get_model().embed_tokens(token_ids[:pos]))

                token_embeddings.append(learnable_features[batch_idx][1:])
                token_embeddings.append(self.get_model().embed_tokens(token_ids[pos+len(learnable_features[batch_idx, 1:]):]))
            else:
                token_embeddings.append(self.get_model().embed_tokens(token_ids))
            if i < len(personalized_token_indices) - 2:
                token_embeddings.append(learnable_features[batch_idx][0].unsqueeze(0))
        
        
        cur_input_embeds = torch.cat(token_embeddings)

        # cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_noim))
        cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0)
        cur_new_input_embeds = []
        cur_new_labels = []

        for i in range(num_images + 1):
            cur_new_input_embeds.append(cur_input_embeds_no_im[i])
            cur_new_labels.append(cur_labels_noim[i])
            if i < num_images:
                cur_image_features = image_features[cur_image_idx]
                cur_image_idx += 1
                cur_new_input_embeds.append(cur_image_features)
                cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype))
        
        if mask_weights is not None:
            cur_mask_weights = mask_weights[batch_idx]
            cur_new_mask_weights = []
            for i in range(len(image_token_indices) - 1):
                cur_new_mask_weights.append(cur_mask_weights[image_token_indices[i] + 1 : image_token_indices[i+1]])
                if i < len(image_token_indices) - 2:
                    cur_new_mask_weights.append(torch.zeros((cur_image_features.shape[0],), dtype=torch.bool, device=cur_mask_weights.device))
            cur_new_mask_weights = torch.cat(cur_new_mask_weights)
            mask_weights_list.append(cur_new_mask_weights)

        cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds]

        cur_new_input_embeds = torch.cat(cur_new_input_embeds)
        cur_new_labels = torch.cat(cur_new_labels)

        new_input_embeds.append(cur_new_input_embeds)
        new_labels.append(cur_new_labels)

    # Truncate sequences to max length as image embeddings can make the sequence longer
    tokenizer_model_max_length = getattr(self.config, 'tokenizer_model_max_length', None)
    if tokenizer_model_max_length is not None:
        new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds]
        new_labels = [x[:tokenizer_model_max_length] for x in new_labels]

    # Combine them
    max_len = max(x.shape[0] for x in new_input_embeds)
    batch_size = len(new_input_embeds)

    new_input_embeds_padded = []
    new_labels_padded = torch.full((batch_size, max_len), IGNORE_INDEX, dtype=new_labels[0].dtype, device=new_labels[0].device)
    attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device)
    position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device)
    if mask_weights is not None:
        mask_weights_padded = torch.zeros((batch_size, max_len), dtype=mask_weights_list[0].dtype, device=mask_weights_list[0].device)
    for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)):
        cur_len = cur_new_embed.shape[0]
        if getattr(self.config, 'tokenizer_padding_side', 'right') == "left":
            new_input_embeds_padded.append(torch.cat((
                torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device),
                cur_new_embed
            ), dim=0))
            if cur_len > 0:
                new_labels_padded[i, -cur_len:] = cur_new_labels
                attention_mask[i, -cur_len:] = True
                position_ids[i, -cur_len:] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
                if mask_weights is not None:
                    mask_weights_padded[i, -cur_len:] = mask_weights_list[i]
        else:
            new_input_embeds_padded.append(torch.cat((
                cur_new_embed,
                torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)
            ), dim=0))
            if cur_len > 0:
                new_labels_padded[i, :cur_len] = cur_new_labels
                attention_mask[i, :cur_len] = True
                position_ids[i, :cur_len] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
                if mask_weights is not None:
                    mask_weights_padded[i, :cur_len] = mask_weights_list[i]
    new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)

    if _labels is None:
        new_labels = None
    else:
        new_labels = new_labels_padded

    if _attention_mask is None:
        attention_mask = None
    else:
        attention_mask = attention_mask.to(dtype=_attention_mask.dtype)

    if _position_ids is None:
        position_ids = None
    
    if mask_weights is not None:
        mask_weights = mask_weights_padded

    return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels, mask_weights

class FeatureProjection(nn.Module):
    """Project one feature tensor into two spaces with independent two-layer MLPs.

    Both heads share the architecture Linear -> GELU -> Linear but have separate
    weights: ``embed_projection`` produces ``embed_output_dim`` features and
    ``lm_head_projection`` produces ``lm_head_output_dim`` features.

    Args:
        input_dim: size of the last dimension of the input.
        hidden_dim: width of each MLP's hidden layer.
        embed_output_dim: output width of the embedding head.
        lm_head_output_dim: output width of the LM-head head.
    """

    def __init__(self, input_dim, hidden_dim, embed_output_dim, lm_head_output_dim):
        super().__init__()

        def mlp(out_dim):
            # Linear -> GELU -> Linear (state_dict sub-keys 0 and 2).
            return nn.Sequential(
                nn.Linear(input_dim, hidden_dim),
                nn.GELU(),
                nn.Linear(hidden_dim, out_dim),
            )

        # Built embed head first, then LM head, so parameter init order is fixed.
        self.embed_projection = mlp(embed_output_dim)
        self.lm_head_projection = mlp(lm_head_output_dim)

    def forward(self, x):
        """Return the pair ``(embed_projection(x), lm_head_projection(x))``."""
        embed_out = self.embed_projection(x)
        lm_head_out = self.lm_head_projection(x)
        return embed_out, lm_head_out
    
def masked_mean(t, *, dim, mask=None):
    """Average ``t`` along ``dim``, optionally counting only positions where ``mask`` is True.

    Args:
        t: tensor to reduce. With a mask, ``t`` has one more trailing axis than
            ``mask`` (e.g. ``(b, n, d)`` with a ``(b, n)`` mask); the mask is
            broadcast over that trailing axis.
        dim: dimension to reduce over (keyword-only).
        mask: optional boolean tensor; ``True`` marks positions to include.
            ``None`` means a plain ``t.mean(dim=dim)``.

    Returns:
        The masked mean with ``dim`` removed. Slices whose mask is entirely
        False yield zeros rather than NaN.
    """
    if mask is None:
        return t.mean(dim=dim)

    # Count included positions in t's floating dtype so the 1e-5 floor below is
    # applied as a float, independent of PyTorch's int/scalar promotion rules
    # (an integer clamp would floor at 0 and make fully-masked slices 0/0 = NaN).
    denom = mask.sum(dim=dim, keepdim=True).to(t.dtype)
    masked_t = t.masked_fill(~mask.unsqueeze(-1), 0.0)

    return masked_t.sum(dim=dim) / denom.clamp(min=1e-5)

def FeedForward(dim, mult=4):
    """Build a pre-norm MLP: LayerNorm -> Linear(up) -> GELU -> Linear(down).

    Args:
        dim: input and output feature width.
        mult: expansion factor for the hidden width (``int(dim * mult)``).

    Returns:
        An ``nn.Sequential`` whose two Linear layers have no bias.
    """
    hidden = int(dim * mult)
    layers = [
        nn.LayerNorm(dim),
        nn.Linear(dim, hidden, bias=False),
        nn.GELU(),
        nn.Linear(hidden, dim, bias=False),
    ]
    return nn.Sequential(*layers)


def reshape_tensor(x, heads):
    """Split the last axis of ``x`` into attention heads.

    Args:
        x: tensor of shape ``(bs, length, width)``; ``width`` must be divisible
            by ``heads`` (otherwise ``view`` raises).
        heads: number of attention heads.

    Returns:
        Tensor of shape ``(bs, heads, length, width // heads)``.
    """
    bs, length, _ = x.shape
    # (bs, length, width) -> (bs, length, heads, d) -> (bs, heads, length, d)
    per_head = x.view(bs, length, heads, -1).transpose(1, 2)
    return per_head.reshape(bs, heads, length, -1)


class PerceiverAttention(nn.Module):
    """Multi-head cross-attention from latent queries onto ``[x; latents]``.

    Queries come from the (normalized) latents; keys and values come from the
    concatenation of the (normalized) input features and the latents along the
    sequence axis. The output has the same shape as ``latents``.

    Args:
        dim: feature width of both ``x`` and ``latents``.
        dim_head: per-head width.
        heads: number of attention heads.
    """

    def __init__(self, *, dim, dim_head=64, heads=8):
        super().__init__()
        inner = dim_head * heads
        self.scale = dim_head**-0.5
        self.dim_head = dim_head
        self.heads = heads

        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

        self.to_q = nn.Linear(dim, inner, bias=False)
        self.to_kv = nn.Linear(dim, inner * 2, bias=False)
        self.to_out = nn.Linear(inner, dim, bias=False)

    def forward(self, x, latents):
        """
        Args:
            x (torch.Tensor): input (e.g. image) features, shape (b, n1, D)
            latents (torch.Tensor): latent queries, shape (b, n2, D)

        Returns:
            torch.Tensor of shape (b, n2, D).
        """
        x = self.norm1(x)
        latents = self.norm2(latents)
        batch, n_latents, _ = latents.shape

        query = self.to_q(latents)
        context = torch.cat((x, latents), dim=-2)
        key, value = self.to_kv(context).chunk(2, dim=-1)

        query = reshape_tensor(query, self.heads)
        key = reshape_tensor(key, self.heads)
        value = reshape_tensor(value, self.heads)

        # Apply dim_head**-0.25 to both q and k instead of dim_head**-0.5 to the
        # product; the comment in the original notes this is more f16-stable.
        qk_scale = 1 / math.sqrt(math.sqrt(self.dim_head))
        scores = (query * qk_scale) @ (key * qk_scale).transpose(-2, -1)
        # Softmax in float32, then cast back to the working dtype.
        probs = torch.softmax(scores.float(), dim=-1).type(scores.dtype)
        attended = probs @ value

        merged = attended.permute(0, 2, 1, 3).reshape(batch, n_latents, -1)
        return self.to_out(merged)

class Resampler(nn.Module):
    """Perceiver-style resampler that compresses a token sequence into learned queries.

    ``num_queries`` learned latents go through ``depth`` residual blocks of
    PerceiverAttention (cross-attending to the projected input) followed by
    FeedForward. A separate MLP maps the raw first input token to a single
    ``output_dim`` vector that is returned alongside the latents.

    Args:
        dim: internal latent width.
        depth: number of attention + feed-forward blocks.
        dim_head: per-head width inside PerceiverAttention.
        heads: number of attention heads.
        num_queries: number of learned latent queries.
        embedding_dim: width of the incoming features.
        output_dim: width of both outputs.
        ff_mult: hidden expansion factor of each FeedForward.
        max_seq_len: size of the positional-embedding table (used only when
            ``apply_pos_emb`` is True).
        apply_pos_emb: add learned absolute position embeddings to the input.
        num_latents_mean_pooled: number of extra latents derived from the mean of
            the projected sequence and prepended to the learned ones (0 disables).
    """

    def __init__(
        self,
        dim=1024,
        depth=8,
        dim_head=64,
        heads=16,
        num_queries=8,
        embedding_dim=768,
        output_dim=1024,
        ff_mult=4,
        max_seq_len: int = 257,  # CLIP tokens + CLS token
        apply_pos_emb: bool = False,
        num_latents_mean_pooled: int = 0,  # number of latents derived from mean pooled representation of the sequence
    ):
        super().__init__()
        # Attribute names are the state_dict keys and creation order fixes the
        # parameter-init RNG draws; keep both stable.
        self.pos_emb = nn.Embedding(max_seq_len, embedding_dim) if apply_pos_emb else None
        self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
        self.proj_in = nn.Linear(embedding_dim, dim)
        self.proj_out = nn.Linear(dim, output_dim)
        self.norm_out = nn.LayerNorm(output_dim)

        if num_latents_mean_pooled > 0:
            self.to_latents_from_mean_pooled_seq = nn.Sequential(
                nn.LayerNorm(dim),
                nn.Linear(dim, dim * num_latents_mean_pooled),
                Rearrange("b (n d) -> b n d", n=num_latents_mean_pooled),
            )
        else:
            self.to_latents_from_mean_pooled_seq = None

        self.layers = nn.ModuleList(
            nn.ModuleList(
                [
                    PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
                    FeedForward(dim=dim, mult=ff_mult),
                ]
            )
            for _ in range(depth)
        )

        # First-token head; its hidden width is fixed at 1024 regardless of ``dim``.
        self.linear_weight_mapping = nn.Sequential(
            nn.Linear(embedding_dim, 1024),
            nn.GELU(),
            nn.Linear(1024, 1024),
            nn.GELU(),
            nn.Linear(1024, output_dim),
        )

    def forward(self, x):
        """
        Args:
            x: features of shape ``(b, n, embedding_dim)``.

        Returns:
            ``(latents, lm_weight)``: ``latents`` is
            ``(b, num_latents_mean_pooled + num_queries, output_dim)`` after the
            output LayerNorm; ``lm_weight`` is ``(b, output_dim)`` from ``x[:, 0]``.
        """
        # Read from the raw first token, before positional embeddings are added.
        lm_weight = self.linear_weight_mapping(x[:, 0, :])

        if self.pos_emb is not None:
            positions = torch.arange(x.shape[1], device=x.device)
            x = x + self.pos_emb(positions)

        latents = self.latents.repeat(x.size(0), 1, 1)
        x = self.proj_in(x)

        if self.to_latents_from_mean_pooled_seq is not None:
            all_valid = torch.ones(x.shape[:2], device=x.device, dtype=torch.bool)
            pooled = masked_mean(x, dim=1, mask=all_valid)
            pooled_latents = self.to_latents_from_mean_pooled_seq(pooled)
            latents = torch.cat((pooled_latents, latents), dim=-2)

        for cross_attn, mlp in self.layers:
            latents = cross_attn(x, latents) + latents
            latents = mlp(latents) + latents

        return self.norm_out(self.proj_out(latents)), lm_weight

class Resampler2(nn.Module):
    """Resampler variant that splits off the first input token.

    The first token of ``x`` feeds two separate MLP heads:
    - ``linear_weight_mapping`` produces ``lm_weight``;
    - ``main_word_mapping`` produces a "main word" embedding.
    The remaining tokens are compressed by ``num_queries`` learned latents through
    ``depth`` residual PerceiverAttention + FeedForward blocks. The main-word
    embedding is prepended to the resulting latents.

    Args:
        dim: internal latent width.
        depth: number of attention + feed-forward blocks.
        dim_head: per-head width inside PerceiverAttention.
        heads: number of attention heads.
        num_queries: number of learned latent queries.
        embedding_dim: width of the incoming features.
        output_dim: width of every output vector.
        ff_mult: hidden expansion factor of each FeedForward.
        max_seq_len: size of the positional-embedding table (used only when
            ``apply_pos_emb`` is True; indexed over the tokens after the first).
        apply_pos_emb: add learned absolute position embeddings to the input.
        num_latents_mean_pooled: number of extra latents derived from the mean of
            the projected sequence and prepended to the learned ones (0 disables).
    """

    def __init__(
        self,
        dim=1024,
        depth=8,
        dim_head=64,
        heads=16,
        num_queries=8,
        embedding_dim=768,
        output_dim=1024,
        ff_mult=4,
        max_seq_len: int = 257,  # CLIP tokens + CLS token
        apply_pos_emb: bool = False,
        num_latents_mean_pooled: int = 0,  # number of latents derived from mean pooled representation of the sequence
    ):
        super().__init__()
        # Attribute names are the state_dict keys (checkpoints are loaded into
        # this class) and creation order fixes parameter-init RNG draws.
        self.pos_emb = nn.Embedding(max_seq_len, embedding_dim) if apply_pos_emb else None
        self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
        self.proj_in = nn.Linear(embedding_dim, dim)
        self.proj_out = nn.Linear(dim, output_dim)
        self.norm_out = nn.LayerNorm(output_dim)

        if num_latents_mean_pooled > 0:
            self.to_latents_from_mean_pooled_seq = nn.Sequential(
                nn.LayerNorm(dim),
                nn.Linear(dim, dim * num_latents_mean_pooled),
                Rearrange("b (n d) -> b n d", n=num_latents_mean_pooled),
            )
        else:
            self.to_latents_from_mean_pooled_seq = None

        self.layers = nn.ModuleList(
            nn.ModuleList(
                [
                    PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
                    FeedForward(dim=dim, mult=ff_mult),
                ]
            )
            for _ in range(depth)
        )

        # Two first-token heads with identical architecture (hidden width 4096).
        self.linear_weight_mapping = nn.Sequential(
            nn.Linear(embedding_dim, embedding_dim),
            nn.GELU(),
            nn.Linear(embedding_dim, 4096),
            nn.GELU(),
            nn.Linear(4096, output_dim),
        )
        self.main_word_mapping = nn.Sequential(
            nn.Linear(embedding_dim, embedding_dim),
            nn.GELU(),
            nn.Linear(embedding_dim, 4096),
            nn.GELU(),
            nn.Linear(4096, output_dim),
        )

    def forward(self, x):
        """
        Args:
            x: features of shape ``(b, n, embedding_dim)``; ``x[:, 0]`` is treated
                separately from ``x[:, 1:]``.

        Returns:
            ``(word_embeddings, lm_weight)``: ``word_embeddings`` is
            ``(b, 1 + num_latents_mean_pooled + num_queries, output_dim)`` with the
            main-word embedding first; ``lm_weight`` is ``(b, output_dim)``.
        """
        first_token = x[:, 0, :]
        lm_weight = self.linear_weight_mapping(first_token)
        main_word_embedding = self.main_word_mapping(first_token)

        x = x[:, 1:, :]
        if self.pos_emb is not None:
            positions = torch.arange(x.shape[1], device=x.device)
            x = x + self.pos_emb(positions)

        latents = self.latents.repeat(x.size(0), 1, 1)
        x = self.proj_in(x)

        if self.to_latents_from_mean_pooled_seq is not None:
            all_valid = torch.ones(x.shape[:2], device=x.device, dtype=torch.bool)
            pooled = masked_mean(x, dim=1, mask=all_valid)
            pooled_latents = self.to_latents_from_mean_pooled_seq(pooled)
            latents = torch.cat((pooled_latents, latents), dim=-2)

        for cross_attn, mlp in self.layers:
            latents = cross_attn(x, latents) + latents
            latents = mlp(latents) + latents

        context_embeddings = self.norm_out(self.proj_out(latents))
        word_embeddings = torch.cat(
            [main_word_embedding.unsqueeze(1), context_embeddings], dim=1
        )
        return word_embeddings, lm_weight

def convert_ids_to_tokens(
    self, ids: Union[int, List[int]], skip_special_tokens: bool = False
) -> Union[str, List[str]]:
    """
    Converts a single index or a sequence of indices in a token or a sequence of tokens, using the vocabulary and
    added tokens.

    Intended to be bound to a tokenizer via ``types.MethodType``. In addition to the
    stock behaviour, the id equal to ``self.vocab_size`` decodes to ``'<sks>'`` (the
    callers in this file use ``tokenizer.vocab_size`` as the personalization-token
    id). Added tokens take precedence over that mapping.

    Args:
        ids (`int` or `List[int]`):
            The token id (or token ids) to convert to tokens.
        skip_special_tokens (`bool`, *optional*, defaults to `False`):
            Whether or not to remove special tokens in the decoding. Only applied
            to sequences, as in the original single-id path.

    Returns:
        `str` or `List[str]`: The decoded token(s).
    """

    def _to_token(index):
        # Single source of truth for both the int and the sequence path, so a lone
        # personalization id decodes the same way it does inside a list.
        if index in self._added_tokens_decoder:
            return self._added_tokens_decoder[index].content
        if index == self.vocab_size:
            return '<sks>'
        return self._convert_id_to_token(index)

    if isinstance(ids, int):
        return _to_token(ids)

    tokens = []
    for index in ids:
        # Accept numpy / torch scalars as well as plain ints.
        index = int(index)
        if skip_special_tokens and index in self.all_special_ids:
            continue
        tokens.append(_to_token(index))
    return tokens

def _resize_token_embeddings(self, new_num_tokens, pad_to_multiple_of=None):
    old_embeddings = self.get_input_embeddings()
    new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens + 16, pad_to_multiple_of)
    if hasattr(old_embeddings, "_hf_hook"):
        hook = old_embeddings._hf_hook
        add_hook_to_module(new_embeddings, hook)
    old_embeddings_requires_grad = old_embeddings.weight.requires_grad
    new_embeddings.requires_grad_(old_embeddings_requires_grad)
    self.set_input_embeddings(new_embeddings)

    # Update new_num_tokens with the actual size of new_embeddings
    if pad_to_multiple_of is not None:
        if is_deepspeed_zero3_enabled():
            import deepspeed

            with deepspeed.zero.GatheredParameters(new_embeddings.weight, modifier_rank=None):
                new_num_tokens = new_embeddings.weight.shape[0]
        else:
            new_num_tokens = new_embeddings.weight.shape[0]
    
    if self.get_output_embeddings() is not None and not self.config.tie_word_embeddings:
        old_lm_head = self.get_output_embeddings()
        new_lm_head = self._get_resized_lm_head(old_lm_head, new_num_tokens)
        if hasattr(old_lm_head, "_hf_hook"):
            hook = old_lm_head._hf_hook
            add_hook_to_module(new_lm_head, hook)
        old_lm_head_requires_grad = old_lm_head.weight.requires_grad
        new_lm_head.requires_grad_(old_lm_head_requires_grad)
        self.set_output_embeddings(new_lm_head)
    
    return self.get_input_embeddings()

def _resize_token_embeddings_multiple(self, new_num_tokens, pad_to_multiple_of=None):
    old_embeddings = self.get_input_embeddings()
    new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens + 32, pad_to_multiple_of)
    if hasattr(old_embeddings, "_hf_hook"):
        hook = old_embeddings._hf_hook
        add_hook_to_module(new_embeddings, hook)
    old_embeddings_requires_grad = old_embeddings.weight.requires_grad
    new_embeddings.requires_grad_(old_embeddings_requires_grad)
    self.set_input_embeddings(new_embeddings)

    # Update new_num_tokens with the actual size of new_embeddings
    if pad_to_multiple_of is not None:
        if is_deepspeed_zero3_enabled():
            import deepspeed

            with deepspeed.zero.GatheredParameters(new_embeddings.weight, modifier_rank=None):
                new_num_tokens = new_embeddings.weight.shape[0]
        else:
            new_num_tokens = new_embeddings.weight.shape[0]
    
    if self.get_output_embeddings() is not None and not self.config.tie_word_embeddings:
        old_lm_head = self.get_output_embeddings()
        new_lm_head = self._get_resized_lm_head(old_lm_head, new_num_tokens)
        if hasattr(old_lm_head, "_hf_hook"):
            hook = old_lm_head._hf_hook
            add_hook_to_module(new_lm_head, hook)
        old_lm_head_requires_grad = old_lm_head.weight.requires_grad
        new_lm_head.requires_grad_(old_lm_head_requires_grad)
        self.set_output_embeddings(new_lm_head)
    
    return self.get_input_embeddings()

# @torch.no_grad()
# def inference(args):
#     model_name = get_model_name_from_path(args.model_path)
#     tokenizer, model, image_processor, _ = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device)
#     tokenizer.convert_ids_to_tokens = types.MethodType(convert_ids_to_tokens, tokenizer)
#     if args.use_features == "dino":
#         dino_processor = AutoImageProcessor.from_pretrained("facebook/dinov2-base")
#         dino_model = Dinov2Model.from_pretrained("facebook/dinov2-base")
#         proj_layer = Resampler(
#             dim=768,
#             depth=4,
#             dim_head=64,
#             heads=16,
#             num_queries=17,
#             output_dim=4096,
#             ff_mult=4
#         )
#     else:
#         dino_processor = None
#         dino_model = None
#         max_seg_length = 1
#         proj_layer = Resampler(
#             dim=768,
#             depth=4,
#             dim_head=64,
#             heads=16,
#             num_queries=17,
#             embedding_dim=512,
#             output_dim=4096,
#             max_seq_len=max_seg_length,
#             ff_mult=4
#         )
#         from facenet_pytorch import MTCNN, InceptionResnetV1
#         mtcnn = MTCNN(image_size=160,
#                         margin=0,
#                         min_face_size=20,
#                         thresholds=[0.6, 0.7, 0.7],
#                         factor=0.709,
#                         post_process=True,
#                         device="cuda")
#         face_model = InceptionResnetV1(pretrained="vggface2").eval().to("cuda")

#     vocab_size = tokenizer.vocab_size
#     model._resize_token_embeddings = types.MethodType(_resize_token_embeddings, model)
#     model.resize_token_embeddings(vocab_size + 1)

#     state_dict = torch.load(os.path.join(args.output_dir, args.checkpoint_path))
#     try:
#         proj_layer.load_state_dict(state_dict)
#     except:
#         proj_layer.load_state_dict(state_dict["proj_layer"])
#     proj_layer.to("cuda", dtype=torch.float16)
#     if dino_model is not None:
#         dino_model.to("cuda", dtype=torch.float16)

#     ref_img_dir = args.infer_ref_img
#     query_img_dir = args.infer_query_img
#     ref_img = Image.open(ref_img_dir).convert("RGB")
#     query_img = Image.open(query_img_dir)
#     image_size = query_img.size
#     if args.use_features == "dino":
#         ref_img = dino_processor(ref_img, return_tensors="pt")["pixel_values"]
#     else:
#         batch_boxes, batch_probs = mtcnn.detect(ref_img, landmarks=False)
#         ref_img = mtcnn.extract(ref_img, batch_boxes, save_path=None)
#     query_img = process_images([query_img], image_processor, model.config)
#     if args.use_features=="dino":
#         ref_img = ref_img.to("cuda", dtype=torch.float16)
#         dino_features = dino_model(pixel_values=ref_img).last_hidden_state
#     else:
#         dino_features = face_model(ref_img.unsqueeze(0).cuda()).unsqueeze(0).to(dtype=torch.float16)
#     learned_embeddings, proj_weight = proj_layer(dino_features)
#     for i in range(learned_embeddings.shape[1]):
#         model.model.embed_tokens.weight.data[i+vocab_size] = learned_embeddings[0, i, :]
#     model.lm_head.weight.data[vocab_size] = proj_weight[0]

#     query_img = query_img.to("cuda", dtype=torch.float16)
#     conv_mode = "personalized"
#     conv = conv_templates[conv_mode].copy()
#     roles = conv.roles
#     while(True):
#         i = 0
#         try:
#             inp = input(f"{roles[0]}: ")
#         except EOFError:
#             inp = ""
#         if not inp:
#             print("exit...")
#             break

#         print(f"{roles[1]}: ", end="")
#         if query_img_dir is not None:
#             # first message
#             if model.config.mm_use_im_start_end:
#                 inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + inp
#             else:
#                 inp = DEFAULT_IMAGE_TOKEN + '\n' + inp
#             query_img_dir = None
        
#         conv.append_message(conv.roles[0], inp)
#         conv.append_message(conv.roles[1], None)
#         prompt = conv.get_prompt()
#         input_ids = tokenizer_image_token(prompt,
#                                       tokenizer,
#                                       vocab_size)
#         input_ids = input_ids.unsqueeze(0).to(model.device)
#         stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
#         keywords = [stop_str]
#         streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

#         with torch.inference_mode():
#             output_ids = model.generate(
#                 input_ids,
#                 images=query_img,
#                 image_sizes=[image_size],
#                 do_sample=True if args.temperature > 0 else False,
#                 temperature=args.temperature,
#                 max_new_tokens=args.max_new_tokens,
#                 streamer=streamer,
#                 use_cache=True)

#         outputs = tokenizer.decode(output_ids[0]).strip()
#         conv.messages[-1][-1] = outputs

@torch.no_grad()
def inference(args):
    """Answer one personalized question about a query image, conditioned on a reference image.

    Loads LLaVA and DINOv2, builds a ``Resampler2`` and loads its weights from
    ``args.output_dir / args.checkpoint_path``. It then:
    - encodes the reference image with DINOv2;
    - writes the resampler outputs into the input-embedding rows starting at
      ``tokenizer.vocab_size`` and into LM-head row ``vocab_size``;
    - generates an answer to ``args.question`` about ``args.infer_query_img``.
    Prints the prompt, the wall-clock time of feature extraction + generation,
    and the decoded output.

    Args:
        args: parsed CLI namespace. Reads model_path, model_base, load_8bit,
            load_4bit, device, output_dir, checkpoint_path, infer_ref_img,
            infer_query_img, question, use_features, temperature and
            max_new_tokens.

    Raises:
        NotImplementedError: if ``args.use_features`` is not ``"dino"``.
    """
    import time

    # The face-feature path (MTCNN + InceptionResnetV1) exists only in the
    # commented-out legacy code: ``mtcnn``/``face_model`` are never defined here
    # and Resampler2 below is sized for 768-d DINOv2 tokens. Fail fast instead of
    # raising NameError after every model has been loaded.
    if args.use_features != "dino":
        raise NotImplementedError(
            f"inference() only supports use_features='dino', got {args.use_features!r}"
        )

    model_name = get_model_name_from_path(args.model_path)
    tokenizer, model, image_processor, _ = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device)
    tokenizer.convert_ids_to_tokens = types.MethodType(convert_ids_to_tokens, tokenizer)
    dino_processor = AutoImageProcessor.from_pretrained("facebook/dinov2-base")
    dino_model = Dinov2Model.from_pretrained("facebook/dinov2-base")
    max_seq_length = 257
    proj_layer = Resampler2(
        dim=1024,
        depth=4,
        dim_head=64,
        heads=16,
        num_queries=16,
        embedding_dim=768,
        output_dim=4096,
        max_seq_len=max_seq_length,
        ff_mult=4
    )

    # Row ``vocab_size`` is the <sks> token; the patched resize also reserves
    # extra input rows after it for the soft tokens.
    vocab_size = tokenizer.vocab_size
    model._resize_token_embeddings = types.MethodType(_resize_token_embeddings, model)
    model.resize_token_embeddings(vocab_size + 1)

    state_dict = torch.load(os.path.join(args.output_dir, args.checkpoint_path))
    proj_layer.load_state_dict(state_dict)
    proj_layer.to("cuda", dtype=torch.float16)
    dino_model.to("cuda", dtype=torch.float16)

    ref_img_dir = args.infer_ref_img
    query_img_dir = args.infer_query_img
    # Example questions:
    # "<image>\nChoose the letter corresponding to the correct answer: Which person is <sks> in this photo?\nA.the person at the left of the photo\n.B.the person at the right of the photo"
    # "<image>\nWhere is the position of <sks> in this photo?"
    question = args.question

    conv = conv_templates["personalized"].copy()
    conv.append_message(conv.roles[0], question)
    conv.append_message(conv.roles[1], None)

    ref_img = Image.open(ref_img_dir).convert("RGB")
    query_img = Image.open(query_img_dir)
    image_size = query_img.size
    ref_img = dino_processor(ref_img, return_tensors="pt")["pixel_values"]
    query_img = process_images([query_img], image_processor, model.config)
    query_img = query_img.to("cuda", dtype=torch.float16)

    prompt = conv.get_prompt()
    print(prompt)
    input_ids = tokenizer_image_token(prompt,
                                      tokenizer,
                                      vocab_size)
    # The ids are looked up in the model's embedding table, so they must be on
    # the model's device (the legacy version above did the same).
    input_ids = input_ids.unsqueeze(0).to(model.device)

    start_time = time.time()
    ref_img = ref_img.to("cuda", dtype=torch.float16)
    dino_features = dino_model(pixel_values=ref_img).last_hidden_state
    learned_embeddings, proj_weight = proj_layer(dino_features)
    # learned_embeddings[0] is (num_learned, hidden): <sks> followed by its soft tokens.
    num_learned = learned_embeddings.shape[1]
    model.model.embed_tokens.weight.data[vocab_size:vocab_size + num_learned] = learned_embeddings[0]
    model.lm_head.weight.data[vocab_size] = proj_weight[0]

    output_ids = model.generate(
        input_ids,
        images=query_img,
        image_sizes=[image_size],
        do_sample=args.temperature > 0,
        temperature=args.temperature,
        max_new_tokens=args.max_new_tokens,
        use_cache=True
    )
    end_time = time.time()
    print(f"Running time {end_time-start_time}")

    outputs = tokenizer.decode(output_ids[0]).strip()
    print(outputs)

@torch.no_grad()
def debugging(args):
    """Run the unmodified LLaVA model on a fixed ``<sks>`` prompt and print the answer.

    Uses llava's stock ``tokenizer_image_token`` (shadowing this module's version)
    and the ``v1`` conversation template. The tokenizer and embeddings are not
    patched here, so no learned personalization features are involved.

    Args:
        args: parsed CLI namespace. Reads model_path, model_base, load_8bit,
            load_4bit, device, infer_query_img, temperature and max_new_tokens.
    """
    from llava.mm_utils import tokenizer_image_token
    model_name = get_model_name_from_path(args.model_path)
    tokenizer, model, image_processor, _ = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device)
    query_img_dir = args.infer_query_img
    # Alternative probe questions for manual experiments; not used below.
    question_list = [
        "What is this person's hair color?",
        "What color are this person's eyes?",
        "What is this person's skin tone?",
        "How would you describe this person's hairstyle?",
        "Does this person have any distinctive facial features?",
        # "What is this person general expression or demeanor?",
        # "What would you describe this person's face?",
        "Is this person young or old?",
        "What do you describe about this person's nose?",
        # "What do you describe about this person's mouth?"
    ]
    question = "<image>\nIs <sks> in the image?"
    conv = conv_templates["v1"].copy()
    conv.append_message(conv.roles[0], question)
    conv.append_message(conv.roles[1], None)

    query_img = Image.open(query_img_dir)
    image_size = query_img.size
    query_img = process_images([query_img], image_processor, model.config)
    query_img = query_img.to(torch.float16)

    prompt = conv.get_prompt()
    input_ids = tokenizer_image_token(prompt, tokenizer, return_tensors="pt")
    # The ids are looked up in the model's embedding table, so move them to the
    # model's device rather than leaving them on the CPU.
    input_ids = input_ids.unsqueeze(0).to(model.device)
    output_ids = model.generate(
        input_ids,
        images=query_img,
        image_sizes=[image_size],
        do_sample=args.temperature > 0,
        temperature=args.temperature,
        max_new_tokens=args.max_new_tokens,
        use_cache=True
    )
    outputs = tokenizer.decode(output_ids[0]).strip()
    print(outputs)

def main(args):
    """Train the Resampler2 projection that maps reference-image features to personalized tokens.

    The LLaVA model and DINOv2 encoder are frozen; only ``proj_layer`` is
    optimized (AdamW, lr=1e-4) through Accelerate. ``proj_layer`` returns soft
    token embeddings plus an extra output-head row. Each step computes
    per-token cross-entropy over ``vocab_size + 1`` logits. The loss is a
    weighted mean over yes/no samples and a plain mean over the other samples;
    the two means are then averaged.

    Training stops after ``args.num_train_steps`` optimizer steps (counting
    from ``args.resume`` when resuming). A checkpoint is written every 15000
    steps, and once more at the end if the final step was not already saved.
    """
    torch.manual_seed(42)
    np.random.seed(42)
    logging_dir = os.path.join(args.output_dir, args.logging_dir)
    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        log_with=args.report_to,
        project_config=accelerator_project_config,
    )

    # Disable AMP for MPS.
    if torch.backends.mps.is_available():
        accelerator.native_amp = False

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state, main_process_only=False)
    if accelerator.is_main_process:
        if args.output_dir is not None:
            os.makedirs(args.output_dir, exist_ok=True)

    model_name = get_model_name_from_path(args.model_path)
    tokenizer, model, image_processor, _ = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device)

    # Load pretrained Dinov2 model
    dino_processor = AutoImageProcessor.from_pretrained("facebook/dinov2-base")
    dino_model = Dinov2Model.from_pretrained("facebook/dinov2-base")
    max_seq_length = 257
    # NOTE(review): num_queries is fixed at 16 while the dataset below uses
    # args.num_query soft tokens; these presumably must agree — confirm.
    proj_layer = Resampler2(
        dim=1024,
        depth=4,
        dim_head=64,
        heads=16,
        num_queries=16,
        embedding_dim=768,
        output_dim=4096,
        max_seq_len=max_seq_length,
        ff_mult=4
    )
    if args.resume:
        print(f'Resume from {args.resume}')
        state_dict = torch.load(os.path.join(args.output_dir, f"checkpoint_{args.resume}.ckpt"))
        proj_layer.load_state_dict(state_dict)

    # The personalized token id is the first id past the base vocabulary.
    vocab_size = model.vocab_size
    dataset = EncoderBasedDataset(
        identity_dict_file="identity_dict2.pkl",
        pos_ques_file="personalization_dataset/pos_question.txt",
        pos_ans_file="personalization_dataset/pos_answer.txt",
        neg_ques_file="personalization_dataset/neg_question.txt",
        neg_ans_file="personalization_dataset/neg_answer.txt",
        caption_folder="caption_list2",
        image_path="data/CelebAMask-HQ/CelebA-HQ-img",
        query_img_path=args.query_img_path,
        ref_img_processor=dino_processor,
        query_img_processor=image_processor,
        tokenizer=tokenizer,
        personalized_id=vocab_size,
        model_config=model.config,
        pos_prob=args.pos_prob,
        yes_no_ratio=args.yes_no_ratio,
        num_soft_tokens=args.num_query
    )

    # Count scalar parameters, not parameter tensors.
    print(f"Number of parameters {sum(p.numel() for p in proj_layer.parameters())}")
    train_dataloader = DataLoader(dataset, batch_size=2, shuffle=True,
                                  collate_fn=collate_fn)

    for param in model.parameters():
        param.requires_grad = False
    if dino_model is not None:
        for param in dino_model.parameters():
            param.requires_grad = False

    optimizer = torch.optim.AdamW(
        proj_layer.parameters(),
        lr=1e-4
    )
    # Replace LLaVA's input preparation with the personalization-aware version.
    model.prepare_inputs_labels_for_multimodal = types.MethodType(prepare_inputs_labels_for_multimodal, model)
    weight_dtype = torch.float32
    if args.mixed_precision == "fp16":
        weight_dtype = torch.float16
    proj_layer.to(accelerator.device)
    if dino_model is not None:
        dino_model.to(accelerator.device, dtype=weight_dtype)
    proj_layer, optimizer, train_dataloader = accelerator.prepare(proj_layer, optimizer, train_dataloader)
    global_step = 0
    if args.resume:
        global_step = int(args.resume)
    start_step = global_step
    checkpoint_every = 15000
    progress_bar = tqdm(
        range(0, args.num_train_steps),
        initial=global_step,
        desc="Steps",
        # Only show the progress bar once on each machine.
        disable=not accelerator.is_local_main_process,
    )
    num_updates_per_epochs = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    num_train_epochs = math.ceil(args.num_train_steps / num_updates_per_epochs)

    for epoch in range(num_train_epochs):
        # Stop the epoch loop too once the step budget is reached; otherwise
        # every remaining epoch would still run one extra training step.
        if global_step >= args.num_train_steps:
            break
        for step, batch in enumerate(train_dataloader):
            with accelerator.accumulate(proj_layer):
                if args.use_features == "dino":
                    dino_features = dino_model(pixel_values=batch["ref_img"].to(dtype=weight_dtype)).last_hidden_state
                else:
                    dino_features = batch["ref_embeddings"].unsqueeze(1).to(dtype=weight_dtype)
                learned_embeddings, proj_weight = proj_layer(dino_features)
                learned_embeddings = learned_embeddings.to(dtype=weight_dtype)
                (
                    input_ids,
                    position_ids,
                    attention_mask,
                    past_key_values,
                    inputs_embeds,
                    labels,
                    mask_weights
                ) = model.prepare_inputs_labels_for_multimodal(
                    input_ids=batch["input_ids"],
                    position_ids=None,
                    attention_mask=batch["attention_mask"],
                    past_key_values=None,
                    labels=batch["labels"],
                    images=batch["query_img"].to(dtype=weight_dtype),
                    personalized_id=vocab_size,
                    learnable_features=learned_embeddings,
                    mask_weights=batch["mask_weights"]
                )
                output_attentions = model.config.output_attentions
                output_hidden_states = model.config.output_hidden_states
                return_dict = model.config.use_return_dict

                outputs = model.model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_values=past_key_values,
                    inputs_embeds=inputs_embeds,
                    use_cache=None,
                    output_attentions=output_attentions,
                    output_hidden_states=output_hidden_states,
                    return_dict=return_dict,
                )

                hidden_states = outputs[0]
                logits_wo_personalize = model.lm_head(hidden_states)
                # Extra logit for the personalized token: dot product of each
                # hidden state with the projected output-head row.
                personalized_logits = hidden_states.unsqueeze(2).float() @ proj_weight.unsqueeze(-1).unsqueeze(1)
                personalized_logits = personalized_logits.squeeze(2)
                logits = torch.cat([logits_wo_personalize.float(), personalized_logits], dim=-1)

                loss = None
                if labels is not None:
                    loss_fct = CrossEntropyLoss(reduction="none")
                    shift_logits = logits[..., :-1, :].contiguous()
                    shift_labels = labels[..., 1:].contiguous()
                    shift_mask_weights = mask_weights[..., 1:]
                    yes_no_batch = batch["yes_no"].squeeze(-1)
                    loss = []
                    if yes_no_batch.any():
                        yes_no_logits = shift_logits[yes_no_batch, :, :].view(-1, model.config.vocab_size + 1)
                        yes_no_labels = shift_labels[yes_no_batch, :].view(-1)
                        shift_mask_weights = shift_mask_weights[yes_no_batch, :].view(-1)
                        yes_no_loss = loss_fct(yes_no_logits, yes_no_labels)
                        # Ignore padded labels; up-weight masked ("important")
                        # tokens and double the weight of "No" answers.
                        weights = (yes_no_labels != -100).float() * torch.where(shift_mask_weights==True, args.importance_weight, args.other_weights)
                        weights[yes_no_labels==dataset.no_ids[0]] *= 2.0
                        yes_no_loss = (weights * yes_no_loss).sum() / weights.sum()
                        loss.append(yes_no_loss)

                    if (yes_no_batch == False).any():
                        non_yes_no_batch = torch.logical_not(yes_no_batch)
                        non_yes_no_logits = shift_logits[non_yes_no_batch, :, :].view(-1, model.config.vocab_size + 1)
                        non_yes_no_labels = shift_labels[non_yes_no_batch, :].view(-1)
                        non_yes_no_loss = loss_fct(non_yes_no_logits, non_yes_no_labels)
                        num_logits = torch.sum(non_yes_no_labels != -100).float()
                        loss.append(non_yes_no_loss.sum() / num_logits)

                    loss = torch.mean(torch.stack(loss))

                accelerator.backward(loss)
                optimizer.step()
                optimizer.zero_grad()
                if accelerator.sync_gradients:
                    progress_bar.update(1)
                    global_step += 1
                    if global_step % checkpoint_every == 0:
                        with torch.no_grad():
                            save_checkpoint(accelerator, proj_layer, args, global_step)

                logs = {"loss": loss.detach().item()}
                progress_bar.set_postfix(**logs)

            if global_step >= args.num_train_steps:
                break

    progress_bar.close()
    # Persist the final weights when the last step did not land on a periodic
    # checkpoint; previously such runs ended without saving anything new.
    if global_step > start_step and global_step % checkpoint_every != 0:
        with torch.no_grad():
            save_checkpoint(accelerator, proj_layer, args, global_step)

@torch.no_grad()
def eval_yes_no_question(args):
    """Evaluate yes/no recognition questions from ``args.data_eval_file``.

    For each sample, the reference image is encoded either with DINOv2
    (``args.use_features == "dino"``) or with MTCNN face crops fed to
    ``face_model``. The features are projected by ``proj_layer``. The resulting
    embeddings overwrite the token-embedding rows starting at ``vocab_size``,
    and the projected weight overwrites the lm_head row ``vocab_size``.
    The model then answers the sample's question about the query image.
    An answer counts as correct when the decoded output contains "Yes"
    (respectively "No").

    Prints the total number of correct answers and per-class accuracy. An
    accuracy is reported as ``nan`` when that class has no questions.
    """
    torch.manual_seed(42)
    model_name = get_model_name_from_path(args.model_path)
    tokenizer, model, image_processor, _ = load_pretrained_model(args.model_path,
                                                                 args.model_base,
                                                                 model_name,
                                                                 args.load_8bit,
                                                                 args.load_4bit,
                                                                 device=args.device)
    tokenizer.convert_ids_to_tokens = types.MethodType(convert_ids_to_tokens, tokenizer)
    dino_processor = AutoImageProcessor.from_pretrained("facebook/dinov2-base")
    dino_model = Dinov2Model.from_pretrained("facebook/dinov2-base")
    # NOTE(review): max_seq_length is not defined in this function; it is
    # presumably a module-level constant — confirm.
    proj_layer = Resampler2(
        dim=1024,
        depth=4,
        dim_head=64,
        heads=16,
        num_queries=16,
        embedding_dim=768,
        output_dim=4096,
        max_seq_len=max_seq_length,
        ff_mult=4
    )
    vocab_size = tokenizer.vocab_size
    # One extra row for the personalized token.
    model._resize_token_embeddings = types.MethodType(_resize_token_embeddings, model)
    model.resize_token_embeddings(vocab_size + 1)
    state_dict = torch.load(os.path.join(args.output_dir, args.checkpoint_path))
    # Checkpoints are either a bare proj_layer state dict or a wrapper holding
    # it under "proj_layer". Check explicitly instead of a bare `except:` so
    # real key/shape mismatches are reported rather than masked.
    if isinstance(state_dict.get("proj_layer"), dict):
        state_dict = state_dict["proj_layer"]
    proj_layer.load_state_dict(state_dict)
    proj_layer.eval()
    proj_layer.to("cuda", dtype=torch.float16)
    if dino_model is not None:
        dino_model.to("cuda", dtype=torch.float16)

    with open(args.data_eval_file, "r") as f:
        test_data = json.load(f)
    num_correct_answer = 0
    incorrect_indices = []
    num_correct_yes = 0
    num_correct_no = 0
    num_yes_question = 0
    num_no_question = 0
    wrong_answers = []
    for idx, sample in enumerate(tqdm(test_data)):
        question = f"<image>\n{sample['question']}"
        ref_img_dir = sample["ref_img"]
        query_img_dir = sample["query_img"]
        conv = conv_templates["personalized"].copy()
        conv.append_message(conv.roles[0], question)
        conv.append_message(conv.roles[1], None)
        ref_img = Image.open(ref_img_dir).convert("RGB")
        query_img = Image.open(query_img_dir)
        image_size = query_img.size
        if args.use_features == "dino":
            ref_img = dino_processor(ref_img, return_tensors="pt")["pixel_values"]
            ref_img = ref_img.to("cuda", dtype=torch.float16)
        else:
            # mtcnn / face_model are not defined in this function (presumably module-level).
            batch_boxes, batch_probs = mtcnn.detect(ref_img, landmarks=False)
            ref_img = mtcnn.extract(ref_img, batch_boxes, save_path=None)
        query_img = process_images([query_img], image_processor, model.config)
        query_img = query_img.to("cuda", dtype=torch.float16)

        prompt = conv.get_prompt()
        input_ids = tokenizer_image_token(prompt, tokenizer, vocab_size)
        input_ids = input_ids.unsqueeze(0)
        if args.use_features == "dino":
            dino_features = dino_model(pixel_values=ref_img).last_hidden_state
        else:
            dino_features = face_model(ref_img.unsqueeze(0).cuda()).unsqueeze(0).to(dtype=torch.float16)
        learned_embeddings, proj_weight = proj_layer(dino_features)

        # Inject this identity's soft tokens and output-head row into the model.
        for i in range(learned_embeddings.shape[1]):
            model.model.embed_tokens.weight.data[i+vocab_size] = learned_embeddings[0, i, :]
        model.lm_head.weight.data[vocab_size] = proj_weight[0]

        output_ids = model.generate(
            input_ids,
            images=query_img,
            image_sizes=[image_size],
            do_sample=args.temperature > 0,
            temperature=args.temperature,
            max_new_tokens=args.max_new_tokens,
            use_cache=True
        )
        outputs = tokenizer.decode(output_ids[0]).strip()
        if sample["answer"] == "Yes":
            if "Yes" in outputs:
                num_correct_answer += 1
                num_correct_yes += 1
            else:
                incorrect_indices.append(idx)
                wrong_answers.append(
                    {
                        "ref_img": ref_img_dir,
                        "query_img": query_img_dir
                    }
                )
            num_yes_question += 1
        if sample["answer"] == "No":
            if "No" in outputs:
                num_correct_answer += 1
                num_correct_no += 1
            else:
                incorrect_indices.append(idx)
            num_no_question += 1

    yes_acc = num_correct_yes / num_yes_question if num_yes_question else float("nan")
    no_acc = num_correct_no / num_no_question if num_no_question else float("nan")
    print(f"Number of correct answers {num_correct_answer}")
    print(f"Num yes correct {num_correct_yes} / {num_yes_question}, acc = {yes_acc}")
    print(f"Num no correct {num_correct_no}/{num_no_question}, acc={no_acc}")

@torch.no_grad()
def eval_non_yes_no(args):
    """Evaluate multiple-choice questions about the personalized subject.

    Reads ``non_query_ques.json``. Each entry has "question", "multiple_choice"
    (letter -> option text), "ref_img", "correct_answer" and optionally
    "query_img". The reference image is encoded with DINOv2 and projected into
    personalized embeddings, which overwrite the token-embedding rows starting at
    ``vocab_size`` and the lm_head row ``vocab_size``. Generation is greedy.
    The first character of the cleaned output is compared with "correct_answer".

    Prints the list of wrongly answered entries and the accuracy. The accuracy
    is 0.0 when the file is empty.
    """
    model_name = get_model_name_from_path(args.model_path)
    tokenizer, model, image_processor, _ = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device)
    tokenizer.convert_ids_to_tokens = types.MethodType(convert_ids_to_tokens, tokenizer)
    dino_processor = AutoImageProcessor.from_pretrained("facebook/dinov2-base")
    dino_model = Dinov2Model.from_pretrained("facebook/dinov2-base")
    # NOTE(review): max_seq_length is presumably a module-level constant — confirm.
    proj_layer = Resampler2(
        dim=1024,
        depth=4,
        dim_head=64,
        heads=16,
        num_queries=16,
        embedding_dim=768,
        output_dim=4096,
        max_seq_len=max_seq_length,
        ff_mult=4
    )
    vocab_size = tokenizer.vocab_size
    model._resize_token_embeddings = types.MethodType(_resize_token_embeddings, model)
    model.resize_token_embeddings(vocab_size + 1)
    state_dict = torch.load(os.path.join(args.output_dir, args.checkpoint_path))
    proj_layer.load_state_dict(state_dict)
    proj_layer.eval()
    proj_layer.to("cuda", dtype=torch.float16)
    dino_model.to("cuda", dtype=torch.float16)

    with open("non_query_ques.json", "r") as f:
        test_data = json.load(f)
    num_correct = 0
    wrong_answers = []
    for ques_ans in test_data:
        # NOTE(review): "\\n" inserts a literal backslash + "n", not a newline.
        # Confirm this matches the training prompts before changing it.
        options = "".join(f"\\n{key}.{value}" for key, value in ques_ans["multiple_choice"].items())
        question = f"Choose the letter corresponding to the correct answer: {ques_ans['question']}" + options
        # NOTE(review): the question carries no "<image>" placeholder; how the
        # query image is then used depends on model.generate — TODO confirm.
        conv = conv_templates["personalized"].copy()
        conv.append_message(conv.roles[0], question)
        conv.append_message(conv.roles[1], None)
        ref_img = Image.open(ques_ans["ref_img"]).convert("RGB")

        if "query_img" in ques_ans:
            query_img = Image.open(ques_ans["query_img"])
        else:
            # Fallback query image for questions that do not provide one.
            query_img = Image.open("validation_recognition/Billie/query_img8.jpg").convert("RGB")
        image_size = query_img.size
        ref_img = dino_processor(ref_img, return_tensors="pt")["pixel_values"]
        query_img = process_images([query_img], image_processor, model.config)
        query_img = query_img.to("cuda", dtype=torch.float16)
        prompt = conv.get_prompt()
        input_ids = tokenizer_image_token(prompt, tokenizer, vocab_size).unsqueeze(0)
        ref_img = ref_img.to("cuda", dtype=torch.float16)
        dino_features = dino_model(pixel_values=ref_img).last_hidden_state
        learned_embeddings, proj_weight = proj_layer(dino_features)
        for i in range(learned_embeddings.shape[1]):
            model.model.embed_tokens.weight.data[i+vocab_size] = learned_embeddings[0, i, :]
        model.lm_head.weight.data[vocab_size] = proj_weight[0]

        output_ids = model.generate(
            input_ids,
            images=query_img,
            image_sizes=[image_size],
            do_sample=False,
            temperature=args.temperature,
            max_new_tokens=args.max_new_tokens,
            use_cache=True
        )

        outputs = tokenizer.decode(output_ids[0]).strip()
        outputs = outputs.replace('<s>', '').replace('</s>', '').strip()
        # Slice instead of index so an empty generation counts as wrong
        # instead of raising IndexError.
        predicted = outputs[:1]
        if ques_ans["correct_answer"] == predicted:
            num_correct += 1
        else:
            wrong_answers.append(ques_ans)
    accuracy = num_correct / len(test_data) if test_data else 0.0
    print(wrong_answers)
    print(f"Correct percentage {num_correct, accuracy}")

@torch.no_grad()
def make_evaluation_for_yollava_dataset(args):
    """Measure yes/no recognition accuracy on the ``yollava_dataset`` folder.

    Each sub-directory of ``yollava_dataset`` is one identity containing
    ``ref_img.png`` and query images. For each identity, its reference image is
    projected into personalized embeddings, which overwrite the token-embedding
    rows starting at ``vocab_size`` and the lm_head row ``vocab_size``. The
    model is then asked "Is <sks> in this photo?" about the identity's own
    images (expected "Yes") and all other identities' images (expected "No").

    Prints the Yes and No accuracy. An accuracy is ``nan`` when there were no
    questions of that kind.
    """
    model_name = get_model_name_from_path(args.model_path)
    tokenizer, model, image_processor, _ = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device)
    tokenizer.convert_ids_to_tokens = types.MethodType(convert_ids_to_tokens, tokenizer)
    dino_processor = AutoImageProcessor.from_pretrained("facebook/dinov2-base")
    dino_model = Dinov2Model.from_pretrained("facebook/dinov2-base")
    # NOTE(review): max_seq_length is presumably a module-level constant — confirm.
    proj_layer = Resampler2(
        dim=1024,
        depth=4,
        dim_head=64,
        heads=16,
        num_queries=16,
        embedding_dim=768,
        output_dim=4096,
        max_seq_len=max_seq_length,
        ff_mult=4
    )
    vocab_size = tokenizer.vocab_size
    model._resize_token_embeddings = types.MethodType(_resize_token_embeddings, model)
    model.resize_token_embeddings(vocab_size + 1)
    state_dict = torch.load(os.path.join(args.output_dir, args.checkpoint_path))
    proj_layer.load_state_dict(state_dict)
    proj_layer.to("cuda", dtype=torch.float16)
    dino_model.to("cuda", dtype=torch.float16)

    dataset_root = "yollava_dataset"
    # Only directories are identities; sort everything for a deterministic order.
    all_identities = sorted(
        name for name in os.listdir(dataset_root)
        if os.path.isdir(os.path.join(dataset_root, name))
    )
    identity_images = {}
    for identity in all_identities:
        identity_dir = os.path.join(dataset_root, identity)
        identity_images[identity] = [
            os.path.join(identity_dir, img_name)
            for img_name in sorted(os.listdir(identity_dir))
            if img_name != "ref_img.png"
        ]

    # The prompt is identical for every query, so tokenize it once.
    question = "Is <sks> in this photo?"
    conv = conv_templates["personalized"].copy()
    conv.append_message(conv.roles[0], "<image>" + '\n' + question)
    conv.append_message(conv.roles[1], None)
    input_ids = tokenizer_image_token(conv.get_prompt(), tokenizer, vocab_size).unsqueeze(0)

    def _generate_answer(query_img_dir):
        """Return the decoded model answer for one query image."""
        query_img = Image.open(query_img_dir)
        image_size = query_img.size
        query_img = process_images([query_img], image_processor, model.config)
        query_img = query_img.to("cuda", dtype=torch.float16)
        output_ids = model.generate(
            input_ids,
            images=query_img,
            image_sizes=[image_size],
            do_sample=args.temperature > 0,
            temperature=args.temperature,
            max_new_tokens=args.max_new_tokens,
            use_cache=True
        )
        return tokenizer.decode(output_ids[0]).strip()

    num_yes_correct = 0
    num_no_correct = 0
    num_yes_question = 0
    num_no_question = 0
    for identity in tqdm(all_identities):
        positive_images = identity_images[identity]
        negative_images = [
            img for name in all_identities if name != identity
            for img in identity_images[name]
        ]
        ref_img = Image.open(os.path.join(dataset_root, identity, "ref_img.png")).convert("RGB")
        ref_img = dino_processor(ref_img, return_tensors="pt")["pixel_values"]
        ref_img = ref_img.to("cuda", dtype=torch.float16)
        dino_features = dino_model(pixel_values=ref_img).last_hidden_state
        learned_embeddings, proj_weight = proj_layer(dino_features)

        # Inject this identity's soft tokens and output-head row into the model.
        for i in range(learned_embeddings.shape[1]):
            model.model.embed_tokens.weight.data[i+vocab_size] = learned_embeddings[0, i, :]
        model.lm_head.weight.data[vocab_size] = proj_weight[0]

        for query_img_dir in positive_images:
            if "Yes" in _generate_answer(query_img_dir):
                num_yes_correct += 1
            num_yes_question += 1
        for query_img_dir in negative_images:
            if "No" in _generate_answer(query_img_dir):
                num_no_correct += 1
            num_no_question += 1
    yes_acc = num_yes_correct / num_yes_question if num_yes_question else float("nan")
    no_acc = num_no_correct / num_no_question if num_no_question else float("nan")
    print(f"Yes accuracy {yes_acc}")
    print(f"No accuracy {no_acc}")


def make_yes_no_evaluation_set(root="validation_recognition",
                               output_file="real_test_data.json",
                               question="Is <sks> in this image?"):
    """Build a yes/no recognition test set and write it to *output_file* as JSON.

    Every sub-directory of *root* is one identity. It holds a reference image
    (file name starting with "ref_img"; the first in sorted order is used) and
    query images (file names starting with "query_img"). Files directly in
    *root* are ignored.

    The written list contains, in this order:

    * for every identity (sorted), one "Yes" sample pairing its reference
      image with each of its own query images;
    * for every identity, one "No" sample pairing its reference image with
      every query image of every other identity.

    Each sample is ``{"ref_img", "query_img", "question", "answer"}``, with
    paths built as ``os.path.join(root, identity, file_name)``.

    Args:
        root: directory holding one sub-directory per identity.
        output_file: destination JSON path.
        question: question text stored in every sample.

    Raises:
        FileNotFoundError: if an identity directory has no "ref_img*" file.
    """
    identities = sorted(
        name for name in os.listdir(root) if os.path.isdir(os.path.join(root, name))
    )
    ref_img_list = []
    query_img_list = []
    for identity in identities:
        identity_dir = os.path.join(root, identity)
        file_names = sorted(os.listdir(identity_dir))
        ref_names = [name for name in file_names if name.startswith("ref_img")]
        if not ref_names:
            raise FileNotFoundError(f"No ref_img* file found in {identity_dir}")
        ref_img_list.append(os.path.join(identity_dir, ref_names[0]))
        query_img_list.append(
            [os.path.join(identity_dir, name) for name in file_names if name.startswith("query_img")]
        )

    def _sample(ref_img, query_img, answer):
        return {
            "ref_img": ref_img,
            "query_img": query_img,
            "question": question,
            "answer": answer,
        }

    test_data_anno = [
        _sample(ref_img, query_img, "Yes")
        for ref_img, query_imgs in zip(ref_img_list, query_img_list)
        for query_img in query_imgs
    ]
    for i, ref_img in enumerate(ref_img_list):
        for j, query_imgs in enumerate(query_img_list):
            if j == i:
                continue
            test_data_anno.extend(_sample(ref_img, query_img, "No") for query_img in query_imgs)

    with open(output_file, "w") as f:
        json.dump(test_data_anno, f)

def multi_personalized_image_token(prompt, tokenizer, vocab_size, num_soft_tokens=16, return_tensors="pt"):
    """Tokenize a two-identity personalized prompt into LLaVA input ids.

    Token-id layout relative to *vocab_size*:

    * ``vocab_size``     -> identity token ``<A>``
    * ``vocab_size + 1`` -> identity token ``<B>``
    * the next ``num_soft_tokens`` ids -> soft tokens describing ``<A>``
    * the following ``num_soft_tokens`` ids -> soft tokens describing ``<B>``

    The system part is rebuilt from a fixed chat preamble followed by
    "<A> is <A soft tokens>. <B> is <B soft tokens>."; the system text inside
    *prompt* itself is not used. The user part is the text after the first
    "<B_token>. " in *prompt*. Text before "<image>" is tokenized as-is,
    "<image>" becomes ``IMAGE_TOKEN_INDEX``, and in the remainder each literal
    "<A>" / "<B>" is replaced by its identity-token id. Spaces adjacent to the
    placeholders are stripped first.

    Args:
        prompt: full conversation prompt containing "<B_token>. " followed by
            a user turn that contains "<image>".
        tokenizer: tokenizer used via ``tokenizer(text).input_ids``; if it
            prepends ``tokenizer.bos_token_id``, BOS is emitted only once.
        vocab_size: base vocabulary size; identity/soft-token ids start here.
        num_soft_tokens: number of soft tokens per identity (default 16).
        return_tensors: ignored; a tensor is always returned.

    Returns:
        1-D ``torch.long`` tensor of input ids.
    """
    # Tokenizing the first chunk tells us whether this tokenizer prepends BOS.
    # If so, keep that single BOS and strip it from every later piece.
    first_chunk_ids = tokenizer(prompt.split('<image>')[0]).input_ids
    input_ids = []
    offset = 0
    if len(first_chunk_ids) > 0 and first_chunk_ids[0] == tokenizer.bos_token_id:
        offset = 1
        input_ids.append(first_chunk_ids[0])

    def encode(text):
        return tokenizer(text).input_ids[offset:]

    a_token_id = vocab_size
    b_token_id = vocab_size + 1
    soft_start_a = vocab_size + 2
    soft_start_b = soft_start_a + num_soft_tokens

    part1 = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions."
    input_ids.extend(encode(part1))
    input_ids.append(a_token_id)
    input_ids.extend(encode('is'))
    input_ids.extend(range(soft_start_a, soft_start_a + num_soft_tokens))
    input_ids.extend(encode('.'))
    input_ids.append(b_token_id)
    input_ids.extend(encode('is'))
    input_ids.extend(range(soft_start_b, soft_start_b + num_soft_tokens))
    input_ids.extend(encode('.'))

    def encode_with_b(text):
        """Append tokens of *text*, replacing each literal "<B>" with the <B> id."""
        if "<B>" not in text:
            input_ids.extend(encode(text))
            return
        pieces = text.split("<B>")
        for k, piece in enumerate(pieces):
            # rstrip (not a while-loop on piece[-1]) so empty pieces are safe.
            input_ids.extend(encode(piece.rstrip(" ")))
            if k < len(pieces) - 1:
                input_ids.append(b_token_id)

    user_parts = prompt.split("<B_token>. ")[1].split("<image>")
    input_ids.extend(encode(user_parts[0]))
    input_ids.append(IMAGE_TOKEN_INDEX)
    second_set = user_parts[1]
    if "<A>" in second_set:
        a_pieces = second_set.split("<A>")
        for j, piece in enumerate(a_pieces):
            encode_with_b(piece.strip(" "))
            if j < len(a_pieces) - 1:
                input_ids.append(a_token_id)
    else:
        # Encode the remainder exactly once. The previous code appended
        # tokenizer(second_set) a second time here.
        encode_with_b(second_set)

    return torch.tensor(input_ids, dtype=torch.long)

def convert_ids_to_tokens_multiple(
    self, ids: Union[int, List[int]], skip_special_tokens: bool = False
) -> Union[str, List[str]]:
    """
    Convert one id or a sequence of ids to tokens, including the two personalized ids.

    Meant to be bound onto a tokenizer instance as a replacement for its
    ``convert_ids_to_tokens``. Lookup order for each id:

    1. ``self._added_tokens_decoder`` (added tokens win);
    2. ``self.vocab_size`` -> ``'<A>'`` and ``self.vocab_size + 1`` -> ``'<B>'``;
    3. ``self._convert_id_to_token`` for everything else.

    The same mapping applies to a single ``int``. Previously a lone personalized
    id fell through to ``_convert_id_to_token`` with an out-of-vocabulary index.

    Args:
        ids (`int` or `List[int]`):
            The token id (or token ids) to convert to tokens.
        skip_special_tokens (`bool`, *optional*, defaults to `False`):
            Whether or not to drop ids in ``self.all_special_ids``. This applies
            to the sequence form only.

    Returns:
        `str` or `List[str]`: The decoded token(s).
    """
    def _to_token(index: int) -> str:
        if index in self._added_tokens_decoder:
            return self._added_tokens_decoder[index].content
        if index == self.vocab_size:
            return '<A>'
        if index == self.vocab_size + 1:
            return '<B>'
        return self._convert_id_to_token(index)

    if isinstance(ids, int):
        return _to_token(ids)
    tokens = []
    for index in ids:
        index = int(index)
        if skip_special_tokens and index in self.all_special_ids:
            continue
        tokens.append(_to_token(index))
    return tokens

@torch.no_grad()
def inference_with_multiple_identity(args):
    """Generate and print an answer to ``args.question`` about ``args.infer_query_img`` using two identities.

    Steps:
      1. Load the LLaVA model/tokenizer from ``args.model_path`` and patch the tokenizer so the ids
         ``vocab_size`` / ``vocab_size + 1`` decode to ``'<A>'`` / ``'<B>'``.
      2. Encode each reference image listed (whitespace-separated) in ``args.ref_img_list`` with DINOv2
         and the ``Resampler2`` projector restored from ``os.path.join(args.output_dir, args.checkpoint_path)``.
      3. Write the projected embeddings into the input-embedding rows and ``lm_head`` rows past the base
         vocabulary. Identity 0 goes to ``vocab_size`` and identity 1 goes to ``vocab_size + 1``.
      4. Build the ``multi_personalized`` conversation prompt, run ``model.generate`` on the query image,
         and print the decoded output.

    Raises:
        ValueError: if ``args.ref_img_list`` does not name exactly two images. The embedding and
            lm_head writes below are hard-wired for two identities.
    """
    # split() rather than split(" ") so repeated/trailing whitespace does not produce empty paths.
    ref_img_list = args.ref_img_list.split()
    if len(ref_img_list) != 2:
        # Validate before loading any model; otherwise a wrong count silently drops images or fails late.
        raise ValueError(
            "inference_with_multiple_identity expects exactly 2 reference images, "
            f"got {len(ref_img_list)}: {ref_img_list!r}"
        )

    model_name = get_model_name_from_path(args.model_path)
    tokenizer, model, image_processor, _ = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device)
    tokenizer.convert_ids_to_tokens = types.MethodType(convert_ids_to_tokens_multiple, tokenizer)
    dino_processor = AutoImageProcessor.from_pretrained("facebook/dinov2-base")
    dino_model = Dinov2Model.from_pretrained("facebook/dinov2-base")
    # NOTE(review): `max_seq_length` is not defined in this function; presumably a module-level
    # constant elsewhere in the file — confirm.
    proj_layer = Resampler2(
        dim=1024,
        depth=4,
        dim_head=64,
        heads=16,
        num_queries=16,
        embedding_dim=768,
        output_dim=4096,
        max_seq_len=max_seq_length,
        ff_mult=4
    )
    vocab_size = tokenizer.vocab_size
    model._resize_token_embeddings = types.MethodType(_resize_token_embeddings_multiple, model)
    model.resize_token_embeddings(vocab_size + 2)
    # torch.load unpickles the file; only point checkpoint_path at trusted checkpoints.
    state_dict = torch.load(os.path.join(args.output_dir, args.checkpoint_path))
    proj_layer.load_state_dict(state_dict)
    proj_layer.to("cuda", dtype=torch.float16)
    dino_model.to("cuda", dtype=torch.float16)

    question = DEFAULT_IMAGE_TOKEN + '\n' + args.question
    learned_embeddings, proj_weights = [], []
    for ref_img_name in ref_img_list:
        ref_img = Image.open(ref_img_name).convert("RGB")
        ref_img = dino_processor(ref_img, return_tensors="pt")["pixel_values"]
        ref_img = ref_img.to("cuda", dtype=torch.float16)
        dino_features = dino_model(pixel_values=ref_img).last_hidden_state
        learned_embedding, proj_weight = proj_layer(dino_features)
        learned_embeddings.append(learned_embedding)
        proj_weights.append(proj_weight)

    learned_embeddings = torch.cat(learned_embeddings)
    proj_weights = torch.stack(proj_weights)
    # Row vocab_size holds identity 0's first embedding and row vocab_size + 1 holds identity 1's.
    model.model.embed_tokens.weight.data[vocab_size] = learned_embeddings[0,0, :]
    model.model.embed_tokens.weight.data[vocab_size + 1] = learned_embeddings[1,0,:]
    # NOTE(review): the model was resized to vocab_size + 2 above. These two slices only have rows to
    # write into if `_resize_token_embeddings_multiple` allocates extra soft-token rows, and their
    # lengths must match learned_embeddings[i, 1:] — confirm against that helper and Resampler2.
    model.model.embed_tokens.weight.data[vocab_size+2:vocab_size + 18] = learned_embeddings[0, 1:, :]
    model.model.embed_tokens.weight.data[vocab_size+18:] = learned_embeddings[1, 1:, :]
    model.lm_head.weight.data[vocab_size] = proj_weights[0,0]
    model.lm_head.weight.data[vocab_size+1] = proj_weights[1,0]

    conv = conv_templates["multi_personalized"].copy()
    conv.append_message(conv.roles[0], question)
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()
    input_ids = multi_personalized_image_token(
        prompt, tokenizer, vocab_size
    )
    # Put the prompt on the same device as the model. Otherwise the embedding lookup inside
    # generate receives CPU indices for a GPU embedding table.
    input_ids = input_ids.unsqueeze(0).to(model.device)
    # Convert to RGB, as is done for the reference images, so grayscale/RGBA/palette queries
    # reach the image processor with three channels.
    query_img = Image.open(args.infer_query_img).convert("RGB")
    image_size = query_img.size
    query_img = process_images([query_img], image_processor, model.config)

    query_img = query_img.to("cuda", dtype=torch.float16)
    output_ids = model.generate(
        input_ids,
        images=query_img,
        image_sizes=[image_size],
        do_sample=args.temperature > 0,
        temperature=args.temperature,
        max_new_tokens=args.max_new_tokens,
        use_cache=True
    )

    outputs = tokenizer.decode(output_ids[0]).strip()
    print(outputs)

if __name__ == "__main__":
    # Script entry point: dispatch on `args.task` as returned by parse_arguments().
    # NOTE(review): `args` is bound at module level here, so other functions in the file may read it
    # as a global — keep the name if refactoring.
    args = parse_arguments()
    if args.task=="train":
        main(args)
    elif args.task=="infer":
        inference(args)
        # Alternative for two-identity (<A>/<B>) inference; swap with the call above to use it:
        # inference_with_multiple_identity(args)
    elif args.task=="eval":
        eval_yes_no_question(args)
        # Alternative evaluation entry points; swap with the call above to use one:
        # eval_non_yes_no(args)
        # make_evaluation_for_yollava_dataset(args)
    # Any other task value matches no branch, so the script exits without doing anything.