import torch
import warnings

# Silence every Python warning for the whole process (this includes warnings
# raised by torch / transformers through the `warnings` module).
warnings.filterwarnings("ignore")

# NOTE: this is NOT a warning setting. It lets CUDA float32 matmuls use TF32
# tensor-core math, trading some float32 precision for speed.
torch.backends.cuda.matmul.allow_tf32 = True

# Keep PyTorch's "warn once" warnings emitted only once.
torch.set_warn_always(False)

import os
import json
import copy
from datetime import timedelta
from typing import List, Optional, Tuple, Union

from accelerate import Accelerator, DistributedType, InitProcessGroupKwargs
from accelerate.state import AcceleratorState
from packaging import version
from tqdm import tqdm

from lmms_eval import utils
from lmms_eval.api.instance import Instance
from lmms_eval.api.model import lmms
from lmms_eval.api.registry import register_model

from loguru import logger as eval_logger

try:
    # LLaVA is an optional dependency: if importing it fails for any reason the
    # error is only logged at debug level so this module can still be imported.
    from llava.constants import DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX
    from llava.conversation import conv_templates
    from llava.mm_utils import get_model_name_from_path, process_images, tokenizer_image_token
    from llava.model.builder import load_pretrained_model
except Exception as import_err:
    eval_logger.debug(f"LLaVA is not installed. Please install LLaVA to use this model.\nError: {import_err}")

# Default attention backend handed to the model loader ("sdpa", "eager" and
# "flash_attention_2" are the accepted values): PyTorch SDPA from torch 2.1.2
# onwards, the eager implementation on older versions.
best_fit_attn_implementation = (
    "sdpa" if version.parse(torch.__version__) >= version.parse("2.1.2") else "eager"
)


@register_model("llava")
class Llava(lmms):
    """
    LLaVA model wrapper for lmms_eval.

    On top of the standard ``loglikelihood`` / ``generate_until`` interface this
    class can (a) append first-token log-probability records to a JSONL file and
    (b) run per-head masking sweeps (``generate_until_with_head_masking``). Both
    features are driven by keys in ``cfg["metadata"]``.
    """

    def __init__(
        self,
        pretrained: str = "liuhaotian/llava-v1.5-7b",
        truncation: Optional[bool] = True,
        device: Optional[str] = "cuda:0",
        batch_size: Optional[Union[int, str]] = 1,
        model_name=None,
        attn_implementation=best_fit_attn_implementation,
        device_map="cuda:0",
        conv_template="vicuna_v1",
        use_cache=True,
        tie_weights: bool = True,
        truncate_context=False,  # whether to truncate the context in generation, set it False for LLaVA-1.6
        customized_config=None,  # ends in json
        cfg=None,
        **kwargs,
    ) -> None:
        """Load the checkpoint and configure single- or multi-process execution.

        Args:
            pretrained: Checkpoint path or hub id passed to ``load_pretrained_model``.
            truncation: Stored as ``self.truncation``.
            device: Device used when a single process runs with ``device_map="auto"``.
            batch_size: Per-process batch size (converted with ``int``).
            model_name: Overrides the name derived from ``pretrained``.
            attn_implementation: Forwarded to the loader when not ``None``.
            device_map: ``"auto"`` (single process) is forwarded to the loader;
                any other value places the model on ``cuda:<local_process_index>``.
            conv_template: Key into ``conv_templates`` used to build prompts.
            use_cache: Forwarded to ``generate``.
            tie_weights: Call ``model.tie_weights()`` after loading.
            truncate_context: Stored as ``self.truncate_context``.
            customized_config: Forwarded to the loader when given.
            cfg: Experiment config; only ``cfg["metadata"]`` is read.
        """
        super().__init__()
        assert kwargs == {}, f"Unexpected kwargs: {kwargs}"

        self.cfg = cfg or {}

        accelerator_kwargs = InitProcessGroupKwargs(timeout=timedelta(weeks=52))
        accelerator = Accelerator(kwargs_handlers=[accelerator_kwargs])
        self.accelerator = accelerator
        if accelerator.num_processes > 1:
            self._device = torch.device(f"cuda:{accelerator.local_process_index}")
            self.device_map = f"cuda:{accelerator.local_process_index}"
        elif accelerator.num_processes == 1 and device_map == "auto":
            self._device = torch.device(device)
            self.device_map = device_map
        else:
            self._device = torch.device(f"cuda:{accelerator.local_process_index}")
            self.device_map = f"cuda:{accelerator.local_process_index}"

        llava_model_args = {"multimodal": True}
        if customized_config is not None:
            llava_model_args["customized_config"] = customized_config
        if attn_implementation is not None:
            llava_model_args["attn_implementation"] = attn_implementation
        # Only reachable when asserts are disabled (python -O); kwargs is empty otherwise.
        if "use_flash_attention_2" in kwargs:
            llava_model_args["use_flash_attention_2"] = kwargs["use_flash_attention_2"]

        model_name = model_name if model_name is not None else get_model_name_from_path(pretrained)
        try:
            self._tokenizer, self._model, self._image_processor, self._max_length = load_pretrained_model(
                pretrained, None, model_name, device_map=self.device_map, **llava_model_args
            )
        except TypeError:
            # Some loader versions do not accept the "multimodal" keyword; retry without it.
            llava_model_args.pop("multimodal", None)
            self._tokenizer, self._model, self._image_processor, self._max_length = load_pretrained_model(
                pretrained, None, model_name, device_map=self.device_map, **llava_model_args
            )
        self._config = self._model.config
        self.model.eval()
        if tie_weights:
            self.model.tie_weights()

        self.truncation = truncation
        self.batch_size_per_gpu = int(batch_size)
        self.conv_template = conv_template
        self.use_cache = use_cache
        self.truncate_context = truncate_context

        if accelerator.num_processes > 1:
            assert accelerator.distributed_type in [
                DistributedType.FSDP,
                DistributedType.MULTI_GPU,
                DistributedType.DEEPSPEED,
            ], "Unsupported distributed type provided. Only DDP, FSDP and DeepSpeed are supported."
            if accelerator.distributed_type == DistributedType.DEEPSPEED:
                kwargs = {
                    "train_micro_batch_size_per_gpu": self.batch_size_per_gpu,
                    "train_batch_size": self.batch_size_per_gpu * accelerator.num_processes,
                }
                AcceleratorState().deepspeed_plugin.deepspeed_config_process(must_match=True, **kwargs)
                eval_logger.info(
                    "Detected that you are using DistributedType.DEEPSPEED. Make sure you run `accelerate config` and set zero stage to 0"
                )

            if accelerator.distributed_type in (DistributedType.FSDP, DistributedType.DEEPSPEED):
                self._model = accelerator.prepare(self.model)
            else:
                self._model = accelerator.prepare_model(self.model, evaluation_mode=True)
            self.accelerator = accelerator
            if self.accelerator.is_local_main_process:
                eval_logger.info(f"Using {accelerator.num_processes} devices with data parallelism")
            self._rank = self.accelerator.local_process_index
            self._world_size = self.accelerator.num_processes
        elif accelerator.num_processes == 1 and device_map == "auto":
            eval_logger.info(f"Using {accelerator.num_processes} devices with tensor parallelism")
            self._rank = 0
            self._world_size = 1
        else:
            eval_logger.info(f"Using single device: {self._device}")
            self.model.to(self._device)
            self._rank = 0
            self._world_size = 1

        # Head-masking sweep ranges (inclusive bounds, see generate_until_with_head_masking).
        md = self._get_metadata()
        self.head_masking_start_layer = md.get("head_masking_start_layer", 0)
        self.head_masking_end_layer = md.get("head_masking_end_layer", 31)
        self.head_masking_start_head = md.get("head_masking_start_head", 0)
        self.head_masking_end_head = md.get("head_masking_end_head", 31)

    def _get_metadata(self) -> dict:
        """Return ``cfg["metadata"]`` as a dict (empty when cfg or metadata is missing/None)."""
        if not self.cfg:
            return {}
        return self.cfg.get("metadata") or {}

    @property
    def config(self):
        return self._config

    @property
    def tokenizer(self):
        return self._tokenizer

    @property
    def model(self):
        # Unwrap any accelerate/DDP wrapper so callers always see the bare model.
        if hasattr(self, "accelerator"):
            return self.accelerator.unwrap_model(self._model)
        else:
            return self._model

    @property
    def eot_token_id(self):
        return self.tokenizer.eos_token_id

    @property
    def max_length(self):
        return self._max_length

    def pad_sequence(self, input_ids, batch_first, padding_value):
        """Pad 1-D id tensors into a batch, honouring ``tokenizer.padding_side``.

        Left padding is implemented by flipping each sequence, right-padding,
        then flipping the padded batch back.
        """
        if self.tokenizer.padding_side == "left":
            input_ids = [torch.flip(_input_ids, [0]) for _input_ids in input_ids]
        input_ids = torch.nn.utils.rnn.pad_sequence(input_ids, batch_first=batch_first, padding_value=padding_value)
        if self.tokenizer.padding_side == "left":
            input_ids = torch.flip(input_ids, [1])
        return input_ids

    @property
    def batch_size(self):
        return self.batch_size_per_gpu

    @property
    def device(self):
        return self._device

    @property
    def rank(self):
        return self._rank

    @property
    def world_size(self):
        return self._world_size

    def tok_encode(self, string: str, left_truncate_len=None, add_special_tokens=None) -> List[int]:
        """Tokenize *string* (without special tokens unless requested).

        When ``left_truncate_len`` is truthy only the last ``left_truncate_len``
        token ids are kept.
        """
        add_special_tokens = False if add_special_tokens is None else add_special_tokens
        encoding = self.tokenizer.encode(string, add_special_tokens=add_special_tokens)
        if left_truncate_len:
            encoding = encoding[-left_truncate_len:]
        return encoding

    def tok_decode(self, tokens):
        """Decode a token-id sequence, or a single token id."""
        try:
            return self.tokenizer.decode(tokens)
        except Exception:
            return self.tokenizer.decode([tokens])

    def _prepare_images(self, visuals):
        """Run ``process_images`` on *visuals* and move the result to ``self.device`` as float16.

        Returns ``None`` for an empty *visuals*; otherwise a tensor or a list of
        tensors, mirroring what ``process_images`` returns.
        """
        if not visuals:
            return None
        images = process_images(visuals, self._image_processor, self._config)
        if isinstance(images, list):
            return [img.to(dtype=torch.float16, device=self.device) for img in images]
        return images.to(dtype=torch.float16, device=self.device)

    @staticmethod
    def _has_image(image_tensor) -> bool:
        """True when *image_tensor* (tensor, list of tensors or None) holds any image data."""
        if image_tensor is None:
            return False
        if isinstance(image_tensor, list):
            return len(image_tensor) > 0
        return image_tensor.numel() > 0

    def _new_conversation(self):
        """Return a fresh, mutable copy of the configured conversation template."""
        if "llama_3" in self.conv_template:
            return copy.deepcopy(conv_templates[self.conv_template])
        return conv_templates[self.conv_template].copy()

    def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
        """Score each (context, continuation) request.

        Returns one ``(loss, is_greedy)`` tuple per request: ``loss`` is the
        model's ``outputs["loss"]`` with all context positions masked out of the
        labels (-100), and ``is_greedy`` reports whether the argmax predictions
        over the compared slice equal the continuation tokens.
        """
        res = []
        pbar = tqdm(total=len(requests), disable=(self.rank != 0), desc="Model Responding")

        for contexts, doc_to_target, doc_to_visual, doc_id, task, split in [reg.args for reg in requests]:
            doc = self.task_dict[task][split][doc_id]
            if isinstance(doc_to_target, str):
                continuation = doc_to_target
            else:
                continuation = doc_to_target(doc)
            visuals = self.flatten([doc_to_visual(doc)])
            image_sizes = [[visual.size[0], visual.size[1]] for visual in visuals]
            image = self._prepare_images(visuals)

            context = contexts[0] if isinstance(contexts, list) else contexts
            prompts_input = context
            if image is not None and len(image) != 0 and DEFAULT_IMAGE_TOKEN not in prompts_input:
                image_tokens = " ".join([DEFAULT_IMAGE_TOKEN] * len(visuals))
                prompts_input = image_tokens + "\n" + context

            conv = self._new_conversation()
            conv.append_message(conv.roles[0], prompts_input)
            conv.append_message(conv.roles[1], None)
            prompt = conv.get_prompt()
            contxt_id = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(self.device)

            # Fill the (just appended, hence last) assistant turn with the
            # continuation; [-1] stays correct for templates with pre-filled turns.
            conv.messages[-1][1] = continuation
            prompt = conv.get_prompt()
            input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(self.device)
            labels = input_ids.clone()
            labels[0, : contxt_id.shape[1]] = -100
            with torch.inference_mode():
                outputs = self.model(input_ids=input_ids, labels=labels, images=image, use_cache=True, image_sizes=image_sizes)
            loss = outputs["loss"]
            logits = outputs["logits"]
            greedy_tokens = logits.argmax(dim=-1)
            cont_toks = input_ids[:, contxt_id.shape[1] :]
            # NOTE(review): logits at position i predict token i + 1, and with images
            # the logits sequence presumably also contains the expanded image
            # features, so this slice may not line up with `cont_toks`. Kept as-is.
            greedy_tokens = greedy_tokens[:, contxt_id.shape[1] : input_ids.shape[1]]
            max_equal = (greedy_tokens == cont_toks).all()
            res.append((float(loss.item()), bool(max_equal)))
            pbar.update(1)
        pbar.close()
        return res

    def flatten(self, input):
        """Flatten a list of lists; returns [] if *input* is empty or contains None."""
        if not input or any(i is None for i in input):
            return []
        new_list = []
        for i in input:
            if i:
                for j in i:
                    new_list.append(j)
        return new_list

    def _save_first_token_dists(self, cont, doc_id, save_path, save_topk):
        """Append one JSONL record per batch row describing the first generated token.

        Each record holds the decoded generated text, the first token id, its
        log-probability under the first generation step, and the ``save_topk``
        most likely first tokens. Best effort: any failure is logged and
        swallowed so evaluation continues.
        """
        try:
            save_dir = os.path.dirname(save_path)
            if save_dir:  # a bare filename has no directory to create
                os.makedirs(save_dir, exist_ok=True)

            sequences = cont.sequences
            scores = cont.scores  # one [rows, vocab] tensor per generated step
            num_steps = len(scores)
            # `generate` appends the new tokens at the end of `sequences` (which
            # may or may not start with the prompt, depending on whether the model
            # was fed input_ids or inputs_embeds) and emits one score tensor per
            # generated step, so the generated part is the last `num_steps` columns.
            gen_start = max(sequences.size(1) - num_steps, 0)
            # NOTE(review): with num_beams > 1 the score rows are per beam, so row
            # b is not necessarily request b — confirm before relying on it.
            first_step_logprobs = torch.log_softmax(scores[0].float(), dim=-1) if num_steps > 0 else None

            with open(save_path, "a", encoding="utf-8") as f:
                for b in range(sequences.size(0)):
                    gen_ids = sequences[b, gen_start:]
                    first_token_id = int(gen_ids[0].item()) if gen_ids.shape[0] > 0 else None
                    first_token_logprob = None
                    first_token_topk = []

                    if first_token_id is not None and first_step_logprobs is not None:
                        row_logprobs = first_step_logprobs[b]  # [V]
                        first_token_logprob = float(row_logprobs[first_token_id].item())
                        topv, topi = torch.topk(row_logprobs, k=min(save_topk, row_logprobs.shape[-1]))
                        first_token_topk = [
                            {"tok": self.tokenizer.decode([int(i.item())]),
                             "id": int(i.item()),
                             "logp": float(v.item())}
                            for i, v in zip(topi, topv)
                        ]

                    rec = {
                        "doc_id": int(doc_id[b]) if isinstance(doc_id, (list, tuple)) else (int(doc_id) if isinstance(doc_id, int) else None),
                        "text": self.tokenizer.decode(gen_ids, skip_special_tokens=True),
                        "first_token_id": first_token_id,
                        "first_token_logprob": first_token_logprob,
                        "first_token_topk": first_token_topk,
                    }
                    f.write(json.dumps(rec, ensure_ascii=False) + "\n")
        except Exception as e:
            eval_logger.error(f"[TokenDist-SaveError] {e}")

    def generate_until(self, requests: List[Instance], mask_info: dict = None) -> List[str]:
        """Generate a continuation for every request.

        Args:
            requests: Instances whose ``args`` are
                ``(context, gen_kwargs, doc_to_visual, doc_id, task, split)``.
            mask_info: Optional head-masking spec (see ``update_mask_info``),
                forwarded to ``self.model.generate`` on the image path only.
                NOTE(review): relies on a patched LLaVA ``generate`` that
                accepts ``mask_info`` — confirm against the model fork.

        Returns:
            Generated strings in the original request order. When both
            ``metadata.head_masking`` and ``metadata.output_scores`` are truthy,
            returns ``(strings, attentions)`` instead, where ``attentions`` stacks
            the per-layer attention tensors of the first generation step of the
            last processed chunk (moved to CPU).

        Raises:
            RuntimeError: head-masking output was requested but generation
                produced no attentions (e.g. ``output_attentions`` disabled or
                no requests).
        """
        res = []

        def _collate(x):
            # Descending prompt token length, ties broken by the prompt text.
            toks = self.tok_encode(x[0])
            return -len(toks), x[0]

        # Experiment switches are constant for the whole call; read them once.
        md = self._get_metadata()
        want_scores = bool(md.get("output_scores", False))
        want_attentions = md.get("output_attentions", False)
        save_topk = int(md.get("save_topk", 5))
        save_path = md.get("token_scores_out", "/tmp/llava_token_dists.jsonl")
        last_cont = None

        re_ords = utils.Collator([reg.args for reg in requests], _collate, grouping=True)
        chunks = re_ords.get_batched(n=self.batch_size, batch_fn=None)

        for chunk in chunks:
            contexts, all_gen_kwargs, doc_to_visual, doc_id, task, split = zip(*chunk)
            task = task[0]
            split = split[0]
            batched_visuals = [doc_to_visual[0](self.task_dict[task][split][ids]) for ids in doc_id]  # [B, N]
            flattened_visuals = self.flatten(batched_visuals)

            # Work on a copy so defaults / image sizes added below do not leak
            # into the request's (shared) gen_kwargs dict.
            gen_kwargs = dict(all_gen_kwargs[0])

            if "image_aspect_ratio" in gen_kwargs and "image_aspect_ratio" not in self._config.__dict__:
                self._config.image_aspect_ratio = gen_kwargs.pop("image_aspect_ratio")
                eval_logger.info(f"Setting image aspect ratio: {self._config.image_aspect_ratio}")

            image_tensor = self._prepare_images(flattened_visuals)
            has_image = self._has_image(image_tensor)

            question_input = []
            for visual, context in zip(batched_visuals, contexts):
                if has_image and DEFAULT_IMAGE_TOKEN not in context:
                    n_image_tokens = len(visual) if isinstance(visual, list) else 1
                    question = " ".join([DEFAULT_IMAGE_TOKEN] * n_image_tokens) + "\n" + context
                else:
                    question = context
                conv = self._new_conversation()
                conv.append_message(conv.roles[0], question)
                conv.append_message(conv.roles[1], None)
                question_input.append(conv.get_prompt())

            if has_image:
                gen_kwargs["image_sizes"] = [visual.size for visual in flattened_visuals]

            gen_kwargs.setdefault("max_new_tokens", 1024)
            gen_kwargs.setdefault("temperature", 0)
            gen_kwargs.setdefault("top_p", None)
            gen_kwargs.setdefault("num_beams", 1)

            # Tokenize; drop a trailing EOS so generation does not stop immediately.
            eos_id = self.tokenizer.eos_token_id
            input_ids_list = []
            for prompt in question_input:
                ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
                if ids.shape[0] > 0 and int(ids[-1].item()) == eos_id:
                    ids = ids[:-1]
                input_ids_list.append(ids)

            pad_token_id = self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else eos_id
            input_ids = self.pad_sequence(input_ids_list, batch_first=True, padding_value=pad_token_id).to(self.device)
            attention_masks = input_ids.ne(pad_token_id).to(self.device)

            do_sample = gen_kwargs["temperature"] > 0
            generate_kwargs = dict(
                inputs=input_ids,
                attention_mask=attention_masks,
                pad_token_id=pad_token_id,
                do_sample=do_sample,
                num_beams=gen_kwargs["num_beams"],
                max_new_tokens=gen_kwargs["max_new_tokens"],
                use_cache=self.use_cache,
                output_scores=want_scores,
                return_dict_in_generate=want_scores,
                output_attentions=want_attentions,
            )
            if do_sample:
                # Forward the requested sampling parameters; otherwise sampling
                # would use the model's own generation-config defaults.
                generate_kwargs["temperature"] = gen_kwargs["temperature"]
                if gen_kwargs["top_p"] is not None:
                    generate_kwargs["top_p"] = gen_kwargs["top_p"]
            if has_image:
                generate_kwargs["images"] = image_tensor
                generate_kwargs["image_sizes"] = gen_kwargs["image_sizes"]
                generate_kwargs["mask_info"] = mask_info

            with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.float16):
                cont = self.model.generate(**generate_kwargs)
            last_cont = cont

            text_outputs = self.tokenizer.batch_decode(
                cont.sequences if want_scores else cont, skip_special_tokens=True
            )

            if want_scores:
                self._save_first_token_dists(cont, doc_id, save_path, save_topk)

            res.extend(text_outputs)
            self.cache_hook.add_partial("generate_until", (contexts, gen_kwargs), text_outputs)

        res = re_ords.get_original(res)

        if md.get("head_masking", False) and want_scores:
            attentions = getattr(last_cont, "attentions", None)
            if not attentions:
                raise RuntimeError(
                    "head_masking requires generation attentions; enable metadata.output_attentions "
                    "and make sure there is at least one request."
                )
            # Only the first generation step is used, so only that step is moved to CPU.
            return res, torch.stack([att.cpu() for att in attentions[0]])
        return res

    def _attn_diff_record(self, pos_attn, neg_attn) -> dict:
        """Summarise ``get_attn_diff`` for the first batch row as plain floats."""
        attn_diff, pos, neg = self.get_attn_diff(pos_attn, neg_attn)
        return {
            'attn_diff': attn_diff[0].item(),
            'pos_attn': pos[0].item(),
            'neg_attn': neg[0].item(),
        }

    def generate_until_with_head_masking(self, requests: List[Instance]) -> List[str]:
        """Rank attention heads by how much masking them changes the pos/neg attention gap.

        For every request, generation runs on the "positive" and "negative"
        variants of the prompt (see ``update_samples_with_target_word``): once
        unmasked, then once per ``(layer, head)`` in the configured inclusive
        ranges with that head masked. The ten entries with the largest absolute
        ablated ``attn_diff`` are appended, one JSON line per sample, to a
        timestamped file in ``metadata.attn_diff_output_dir`` (defaults to the
        original hard-coded directory). Samples hitting CUDA OOM are reported
        and skipped.

        Requires ``generate_until`` to return attentions (``metadata.head_masking``
        and ``metadata.output_scores`` set, attentions enabled).

        Returns:
            A single-element list holding a JSON status string.
        """
        from datetime import datetime

        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        output_dir = self._get_metadata().get(
            "attn_diff_output_dir", "/root/Desktop/workspace/miso/LLaVA-NeXT/logs/attn_diff_results"
        )
        os.makedirs(output_dir, exist_ok=True)
        output_file = os.path.join(output_dir, f"attn_diff_results_{timestamp}.jsonl")

        print(f"Starting attention difference analysis for {len(requests)} samples...")
        print(f"Results will be saved to: {output_file}")

        with torch.no_grad():
            for request in tqdm(requests, total=len(requests), desc="Processing samples"):
                contexts, doc_to_target, doc_to_visual, doc_id, task, split = request.args
                sample_data = self.task_dict[task][split][doc_id]

                image_path = sample_data.get('image_path', '')
                text_input = contexts[0] if isinstance(contexts, list) else contexts
                pos_target_word = sample_data.get('positive_target_word', '')
                neg_target_word = sample_data.get('negative_target_word', '')

                try:
                    single_pos_requests = self.update_samples_with_target_word([request], 'positive')
                    single_neg_requests = self.update_samples_with_target_word([request], 'negative')

                    # Unmasked baseline.
                    _, pos_attn = self.generate_until(single_pos_requests, mask_info=None)
                    _, neg_attn = self.generate_until(single_neg_requests, mask_info=None)
                    attn_diff_results = {'original': self._attn_diff_record(pos_attn, neg_attn)}

                    # One ablation per (layer, head); bounds are inclusive.
                    for layer in tqdm(range(self.head_masking_start_layer, self.head_masking_end_layer + 1), desc=f"Sample {doc_id} - Layers"):
                        for head in tqdm(range(self.head_masking_start_head, self.head_masking_end_head + 1), desc=f"Sample {doc_id} - Heads"):
                            mask_info = self.update_mask_info(layer, head)
                            _, pos_attn = self.generate_until(single_pos_requests, mask_info)
                            _, neg_attn = self.generate_until(single_neg_requests, mask_info)
                            attn_diff_results[(layer, head)] = self._attn_diff_record(pos_attn, neg_attn)

                    ablated_items = [(k, v) for k, v in attn_diff_results.items() if k != 'original']
                    ranked = sorted(ablated_items, key=lambda item: abs(item[1]['attn_diff']), reverse=True)

                    sample_result = {
                        "image_path": image_path,
                        "text_input": text_input,
                        "pos_target_word": pos_target_word,
                        "neg_target_word": neg_target_word,
                        "sorted_attn_diff": ranked[:10],
                        "original_attn_diff": attn_diff_results['original'],
                        "doc_id": doc_id,
                        "timestamp": timestamp,
                        "output_file": output_file
                    }

                    with open(output_file, 'a', encoding='utf-8') as f:
                        f.write(json.dumps(sample_result, ensure_ascii=False) + '\n')

                except RuntimeError as e:
                    if "out of memory" not in str(e).lower():
                        raise
                    print(f"❌ OOM ERROR on Sample {doc_id}!")
                    print(f"❌ Image path: {image_path}")
                    print(f"❌ Text length: {len(text_input)} chars")
                    if torch.cuda.is_available():
                        memory_after = torch.cuda.memory_allocated() / 1024**3
                        print(f"❌ GPU memory at OOM: {memory_after:.2f} GB")
                        # Release cached blocks so the next sample has a chance to fit.
                        torch.cuda.empty_cache()

        print(f"Completed attention difference analysis for all {len(requests)} samples!")
        print(f"All results saved to: {output_file}")
        return [json.dumps({"status": "completed", "output_file": output_file, "num_samples": len(requests)})]

    def get_attn_diff(self, pos_attn, neg_attn, layer_start: int = 3, layer_end: int = 15):
        """Compare the last position's self-attention between positive and negative prompts.

        Assumes ``pos_attn`` / ``neg_attn`` are shaped
        ``(num_layers, batch, num_heads, query, key)`` as produced by
        ``generate_until`` — TODO confirm. Layers ``[layer_start, layer_end)`` and
        all heads are averaged at index ``[-1, -1]``.

        Returns:
            ``(neg - pos, pos, neg)``, each of shape ``(batch,)``.
        """
        pos_attn = pos_attn[layer_start:layer_end, :, :, -1, -1].mean(0).mean(1)
        neg_attn = neg_attn[layer_start:layer_end, :, :, -1, -1].mean(0).mean(1)
        attn_diff = neg_attn - pos_attn
        return attn_diff, pos_attn, neg_attn

    def update_mask_info(self, layer, head):
        """Build the head-masking spec for ``(layer, head)``.

        Returns ``None`` unless ``metadata.head_masking`` is truthy; otherwise a
        dict with ``layer``, ``head``, ``mask_qkv`` (default ``['q']``) and
        ``mask_scale_factor`` (default ``0``; strings are converted with ``float``).
        """
        md = self._get_metadata()
        if not md.get("head_masking", False):
            return None
        mask_qkv = md.get("mask_qkv", ['q'])
        mask_scale_factor = md.get("mask_scale_factor", 0)
        if isinstance(mask_scale_factor, str):
            mask_scale_factor = float(mask_scale_factor)
        return {
            'layer': layer,
            'head': head,
            'mask_qkv': mask_qkv,
            'mask_scale_factor': mask_scale_factor,
        }

    def generate_until_multi_round(self, requests) -> List[str]:
        raise NotImplementedError("TODO: Implement multi-round generation for LLaVA")

    @classmethod
    def from_config(cls, cfg):
        """Build a LLaVA-1.6 Vicuna-7B instance with eager attention from *cfg*."""
        return cls(
            pretrained="liuhaotian/llava-v1.6-vicuna-7b",
            device="cuda:0",
            batch_size=1,
            attn_implementation="eager",
            conv_template="vicuna_v1",
            use_cache=True,
            truncate_context=False,
            cfg=cfg,
        )

    def update_samples_with_target_word(self, requests: List[Instance], target_type: str) -> List[Instance]:
        """Return copies of *requests* with ``{target word}`` replaced in the context.

        ``target_type == 'positive'`` uses the document's ``positive_target_word``;
        any other value uses ``negative_target_word`` (empty string if missing).
        """
        updated_requests = []
        for request in requests:
            contexts, doc_to_target, doc_to_visual, doc_id, task, split = request.args
            sample_data = self.task_dict[task][split][doc_id]
            if target_type == 'positive':
                target_word = sample_data.get('positive_target_word', '')
            else:
                target_word = sample_data.get('negative_target_word', '')
            if isinstance(contexts, list):
                updated_contexts = [ctx.replace('{target word}', target_word) for ctx in contexts]
            else:
                updated_contexts = contexts.replace('{target word}', target_word)
            updated_args = (updated_contexts, doc_to_target, doc_to_visual, doc_id, task, split)
            updated_request = Instance(
                request_type=request.request_type,
                arguments=updated_args,
                idx=request.idx,
                # Instance unpacks metadata as a (task_name, doc_id, repeats)
                # tuple; a dict would be unpacked into its keys.
                metadata=(task, doc_id, request.repeats),
            )
            updated_requests.append(updated_request)
        return updated_requests
