import argparse
import json
import math
import os
import random
import re
import time
from contextlib import contextmanager, nullcontext
from typing import Callable

import shortuuid
import torch
import yaml
from PIL import Image
from tqdm import tqdm

# ====== Original LLaVA base loading ======
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from llava.conversation import conv_templates
from llava.model.builder import load_pretrained_model_both
from llava.utils import disable_torch_init
from llava.mm_utils import tokenizer_image_token, process_images, get_model_name_from_path
from torch.utils.data import Dataset, DataLoader
from peft import PeftModel

def get_random_indices(start_end_indices, num_indices: int) -> list:
    """Draw ``num_indices`` distinct random positions from the first index range.

    Args:
        start_end_indices: Sequence of ``(start, end)`` pairs. Only the first
            pair is used; ``end`` is exclusive.
        num_indices: How many distinct indices to sample.

    Returns:
        A list of ``num_indices`` unique integers in ``[start, end)``.

    Raises:
        ValueError: Propagated from ``random.sample`` when ``num_indices`` is
            larger than the range (or negative).
    """
    # Relies on the module-level ``import random``; the name was previously
    # used without being imported, so every call raised NameError.
    start, end = start_end_indices[0]
    return random.sample(range(start, end), num_indices)
# ===== Hook class used for ablation =====
class HookedLVLM:
    """LLaVA wrapper that can overwrite selected input embeddings via a forward pre-hook.

    The constructor stores ``model`` (moved to ``device``), ``tokenizer``,
    ``image_processor``, ``context_len`` and ``device`` on the instance.
    """

    def __init__(self, model_path, model_base=None, device="cuda:0", use_prompt_tuning=True):
        """Load the model bundle with ``load_pretrained_model_both``.

        Args:
            model_path: Checkpoint path; ``~`` is expanded.
            model_base: Optional base-model path forwarded to the loader.
            device: Device the loaded model is moved to.
            use_prompt_tuning: Forwarded unchanged to the loader.
        """
        disable_torch_init()
        resolved_path = os.path.expanduser(model_path)
        bundle = load_pretrained_model_both(
            resolved_path,
            model_base,
            get_model_name_from_path(resolved_path),
            use_prompt_tuning,
        )
        self.tokenizer, loaded_model, self.image_processor, self.context_len = bundle
        self.model = loaded_model.to(device)
        self.device = device

    @contextmanager
    def ablate_inputs(self, indices, replacement_tensor):
        """Context manager that replaces embeddings at ``indices`` while active.

        A forward pre-hook is registered on ``self.model.model``. For every call
        that receives a non-None ``inputs_embeds`` keyword whose second-to-last
        dimension is not 1, a clone of the embeddings is passed on with the
        positions in ``indices`` overwritten by ``replacement_tensor`` (cast to
        the embeddings' dtype and device). The hook is removed on exit, also
        when the body raises.
        """

        def _swap_embeddings(module, args, kwargs):
            embeds = kwargs.get("inputs_embeds")
            # Leave the call alone when no embeddings are passed, or when only
            # a single position is fed (i.e. not the first forward pass).
            if embeds is None or embeds.shape[-2] == 1:
                return args, kwargs

            patched = embeds.clone()
            patched[:, indices, :] = replacement_tensor.to(patched.dtype).to(patched.device)
            kwargs["inputs_embeds"] = patched
            return args, kwargs

        handle = self.model.model.register_forward_pre_hook(_swap_embeddings, with_kwargs=True)
        try:
            yield
        finally:
            handle.remove()

    def generate(self, image_tensor, prompt, max_new_tokens=50, do_sample=False):
        """Generate a reply to ``prompt`` about ``image_tensor``.

        The prompt is prefixed with the image token, wrapped in the
        ``llava_v1`` conversation template and followed by an empty assistant
        turn. The image is cast to float16 on ``self.device`` before
        generation.

        Returns:
            The first decoded sequence, special tokens skipped and whitespace
            stripped.
        """
        conversation = conv_templates["llava_v1"].copy()
        conversation.append_message(conversation.roles[0], DEFAULT_IMAGE_TOKEN + "\n" + prompt)
        conversation.append_message(conversation.roles[1], None)

        token_ids = tokenizer_image_token(
            conversation.get_prompt(), self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
        ).to(self.device)
        batched_ids = token_ids.unsqueeze(0)  # add the batch dimension
        pixels = image_tensor.to(dtype=torch.float16, device=self.device)

        with torch.inference_mode():
            generated = self.model.generate(
                inputs=batched_ids,
                images=pixels,
                do_sample=do_sample,
                max_new_tokens=max_new_tokens,
                use_cache=True,
            )
        decoded = self.tokenizer.batch_decode(generated, skip_special_tokens=True)
        return decoded[0].strip()


# ====== Dataset: keeps the original CustomDataset ======
class CustomDataset(Dataset):
    """Dataset yielding ``(question_id, input_ids, image_tensor, prompt)`` per question.

    Each entry of ``questions`` is read through its ``"image"``, ``"text"`` and
    ``"question_id"`` keys.
    """

    def __init__(self, questions, image_folder, tokenizer, image_processor, model_config, conv_mode="llava_v1"):
        """Store the question records and the tools needed to build each sample."""
        self.questions = questions
        self.image_folder = image_folder
        self.tokenizer = tokenizer
        self.image_processor = image_processor
        self.model_config = model_config
        self.conv_mode = conv_mode

    def __len__(self):
        """Return the number of question records."""
        return len(self.questions)

    def __getitem__(self, index):
        """Build the prompt, load the image and tokenize for record ``index``."""
        record = self.questions[index]
        image_name = record["image"]

        # The image placeholder is optionally wrapped in start/end markers,
        # depending on the model configuration.
        if self.model_config.mm_use_im_start_end:
            image_marker = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
        else:
            image_marker = DEFAULT_IMAGE_TOKEN
        question_text = image_marker + '\n' + record["text"]

        conversation = conv_templates[self.conv_mode].copy()
        conversation.append_message(conversation.roles[0], question_text)
        conversation.append_message(conversation.roles[1], None)
        prompt = conversation.get_prompt()

        image_path = os.path.join(self.image_folder, image_name)
        rgb_image = Image.open(image_path).convert('RGB')
        image_tensor = process_images([rgb_image], self.image_processor, self.model_config)[0]

        input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt')

        return record["question_id"], input_ids, image_tensor, prompt


# import torch

def preprocess_image_for_model(model, image_tensor):
    """Run ``image_tensor`` through the model's vision tower without tracking gradients.

    Args:
        model: Wrapper exposing ``device`` and ``model.model.vision_tower``.
        image_tensor: Image tensor; it is moved to ``model.device`` first.
            Presumably shaped (batch, channels, height, width) - confirm
            against the caller.

    Returns:
        Whatever the vision tower returns for the image.
    """
    pixels = image_tensor.to(model.device)
    vision_tower = model.model.model.vision_tower
    with torch.no_grad():
        return vision_tower(pixels)

def integrated_gradients(model, image_tensor, prompt, image_mean_tensor, steps=50):
    """Approximate integrated gradients of the "Yes" logit w.r.t. the input embeddings.

    The input sequence is the projected vision-tower output for
    ``image_tensor``, followed (when ``prompt`` is non-empty) by the token
    embeddings of ``prompt``. The baseline has the same shape with every
    position set to ``image_mean_tensor``. Gradients of the last position's
    "Yes" logit are summed over ``steps`` points interpolated between baseline
    and input, then scaled by ``(input - baseline) / steps``.

    Args:
        model: Wrapper exposing ``model`` (the language model), ``tokenizer``
            and ``device``.
        image_tensor: Image passed to ``preprocess_image_for_model``.
        prompt: Text appended after the image embeddings; may be empty.
        image_mean_tensor: Baseline embedding, broadcast over all positions.
        steps: Number of interpolation points (both endpoints included).

    Returns:
        Attribution tensor with the same shape as the input embeddings.

    Raises:
        ValueError: If ``steps`` is smaller than 1.
    """
    if steps < 1:
        # steps == 0 would otherwise divide by zero and return NaN/inf.
        raise ValueError(f"steps must be a positive integer, got {steps}")

    llm = model.model
    projector = llm.model.mm_projector

    # Build the input embeddings without tracking gradients so that the
    # interpolated tensors below are genuine leaf tensors.
    with torch.no_grad():
        image_features = preprocess_image_for_model(model, image_tensor)
        # Match the projector's parameter dtype whatever its structure is
        # (the previous `projector[0]` lookup only worked for sequential
        # projectors).
        first_param = next(projector.parameters(), None)
        if first_param is not None:
            image_features = image_features.to(first_param.dtype)
        # Project unconditionally: previously an empty prompt skipped the
        # projector and fed raw vision features to the language model.
        input_embeds = projector(image_features)

        if prompt:
            text_inputs = model.tokenizer(prompt, return_tensors="pt").to(model.device)
            text_embeds = llm.get_input_embeddings()(text_inputs["input_ids"])
            # Image embeddings first, then the text embeddings.
            input_embeds = torch.cat([input_embeds, text_embeds], dim=1)

    input_embeds = input_embeds.detach()

    # Baseline: every position is the mean image embedding.
    baseline_embeds = torch.empty_like(input_embeds)
    baseline_embeds[:] = image_mean_tensor.to(dtype=input_embeds.dtype, device=input_embeds.device)

    diff = input_embeds - baseline_embeds
    integrated_grads = torch.zeros_like(input_embeds)

    # Loop-invariant target id. Special tokens are disabled because tokenizers
    # that prepend a BOS token (LLaMA-style) would otherwise return the BOS id
    # at position 0 instead of the id for "Yes".
    yes_token_id = model.tokenizer.encode("Yes", add_special_tokens=False)[0]

    for alpha in torch.linspace(0, 1, steps):
        interpolated_embeds = (baseline_embeds + alpha * diff).detach().requires_grad_(True)
        logits = llm(inputs_embeds=interpolated_embeds).logits
        yes_logit = logits[0, -1, yes_token_id]
        # autograd.grad yields only the gradient we need: nothing accumulates
        # in the model parameters and the graph is freed after each step.
        (grad,) = torch.autograd.grad(yes_logit, interpolated_embeds, allow_unused=True)
        if grad is not None:
            integrated_grads += grad
        else:
            print("Warning: Gradients are None for alpha =", alpha.item())

    integrated_grads *= diff / steps
    return integrated_grads

# ====== Integrated Gradients (original version, kept for reference) ======
# def integrated_gradients(model, image_tensor, prompt, image_mean_tensor, steps=50):
#     with torch.no_grad():
#         import ipdb;ipdb.set_trace()
#         input_embeds = model.model.get_input_embeddings()(image_tensor.to(model.device))
#         baseline_embeds = input_embeds.clone()
#         replacement_tensor = image_mean_tensor.to(baseline_embeds.dtype).to(baseline_embeds.device)
#         baseline_embeds[:] = replacement_tensor

#     diff = input_embeds - baseline_embeds
#     integrated_grads = torch.zeros_like(input_embeds)

#     for alpha in torch.linspace(0, 1, steps):
#         interpolated_embeds = baseline_embeds + alpha * diff
#         interpolated_embeds.requires_grad_(True)
#         outputs = model.model(inputs_embeds=interpolated_embeds)
#         logits = outputs.logits
#         yes_logit = logits[0, -1, model.tokenizer.encode("Yes")[0]]
#         model.model.zero_grad()
#         yes_logit.backward(retain_graph=True)
#         integrated_grads += interpolated_embeds.grad

#     integrated_grads *= diff / steps
#     return integrated_grads


def get_high_gradient_indices(model, image_tensor, class_name, prompt, image_mean_tensor, steps=50):
    """Rank sequence positions by total absolute integrated-gradient attribution.

    ``class_name`` is accepted for signature compatibility but is not used.

    Returns:
        Position indices as a Python list, most important first.

    NOTE(review): the indices refer to positions in the embedding sequence
    built by ``integrated_gradients``; confirm they line up with the sequence
    the ablation hook sees.
    """
    attributions = integrated_gradients(model, image_tensor, prompt, image_mean_tensor, steps)
    # Collapse the embedding dimension, then drop singleton dimensions.
    per_token_score = attributions.abs().sum(dim=-1).squeeze()
    ranking = per_token_score.argsort(descending=True)
    return ranking.tolist()


# ====== Identification Check ======
def check_identification(model, image_tensor, class_name, ablate_indices, replacement_tensor):
    """Check whether the model still identifies ``class_name`` in the image.

    Two questions are asked, each under the same optional ablation:

    * generative: "Describe this image." (up to 200 new tokens); passes when
      ``class_name`` occurs in the description, ignoring case.
    * polling: "Is there a <class_name> in this image?" (up to 10 new tokens);
      passes when the answer contains the word "yes", ignoring case.

    Args:
        model: Object providing ``generate(image_tensor, prompt,
            max_new_tokens=...)`` and the ``ablate_inputs(indices=...,
            replacement_tensor=...)`` context manager.
        image_tensor: Image handed to ``model.generate``.
        class_name: Name to look for.
        ablate_indices: Positions to ablate; ``None`` or empty disables
            ablation.
        replacement_tensor: Replacement embedding used for ablation.

    Returns:
        Dict with keys ``"generative"`` (bool), ``"polling"`` (bool) and
        ``"description"`` (the generated description).
    """

    def _ask(question, max_new_tokens):
        # One code path for both cases: a no-op context when nothing is ablated.
        if ablate_indices:
            ablation = model.ablate_inputs(indices=ablate_indices, replacement_tensor=replacement_tensor)
        else:
            ablation = nullcontext()
        with ablation:
            return model.generate(image_tensor, question, max_new_tokens=max_new_tokens)

    # The bare question is passed on: generate() already wraps it in the
    # conversation template and opens the assistant turn, so appending
    # " ASSISTANT:" here duplicated the role marker in the final prompt.
    description = _ask("Describe this image.", 200)
    poll_answer = _ask(f"Is there a {class_name} in this image?", 10)

    # Case-insensitive, consistent with the polling check below.
    generative_ok = class_name.lower() in description.lower()
    # Whole-word match so that e.g. "eyes" does not count as "yes".
    polling_ok = re.search(r"\byes\b", poll_answer, flags=re.IGNORECASE) is not None

    return {"generative": generative_ok, "polling": polling_ok, "description": description}


# ====== Main ======
def _question_key(qid):
    """Turn a collated question id into a plain string usable as a JSON key."""
    # With batch_size=1 the DataLoader hands back a 1-element tensor (numeric
    # ids) or a 1-element list/tuple (string ids); unwrap both.
    if isinstance(qid, torch.Tensor):
        qid = qid.tolist()
    if isinstance(qid, (list, tuple)) and len(qid) == 1:
        qid = qid[0]
    return str(qid)


def main(args):
    """Run the gradient-guided ablation experiment and write the results as JSON.

    For each processed question the input positions are ranked with
    ``get_high_gradient_indices``; identification is then checked without
    ablation and with the top 5/10/20/40/60/100/250 positions ablated.

    Args:
        args: Namespace with ``model_path``, ``model_base``, ``device``,
            ``use_prompt_tuning``, ``mean_tensor``, ``question_file``,
            ``image_folder``, ``conv_mode`` and ``results_file``. An optional
            ``max_samples`` attribute limits how many questions are processed;
            it defaults to 1 (the previous hard-coded limit) and ``None``
            means no limit.
    """
    # Load model
    model = HookedLVLM(args.model_path, args.model_base, device=args.device, use_prompt_tuning=args.use_prompt_tuning)
    model.model.eval()
    # NOTE: torch.load unpickles the file - only pass trusted paths. Loading
    # onto the CPU avoids failures when the tensor was saved from a device
    # that does not exist here; every use site moves it to the right device.
    image_mean_tensor = torch.load(args.mean_tensor, map_location="cpu")

    # Load dataset (one JSON object per line; blank lines are skipped)
    with open(args.question_file, "r", encoding="utf-8") as question_fh:
        questions = [json.loads(line) for line in question_fh if line.strip()]
    dataset = CustomDataset(questions, args.image_folder, model.tokenizer, model.image_processor, model.model.config, conv_mode=args.conv_mode)
    dataloader = DataLoader(dataset, batch_size=1, shuffle=False)

    max_samples = getattr(args, "max_samples", 1)
    ablation_sizes = (5, 10, 20, 40, 60, 100, 250)
    # The target class is currently fixed rather than derived from the question.
    class_name = "Donald Trump"

    results = {}
    for n_done, (qid, input_ids, image_tensor, prompt) in enumerate(tqdm(dataloader)):
        # Stop before doing any work on a sample whose result would be dropped
        # (the limit used to be checked only after the full computation).
        if max_samples is not None and n_done >= max_samples:
            break

        grad_indices = get_high_gradient_indices(model, image_tensor, class_name, prompt, image_mean_tensor)
        print("grad_indices",len(grad_indices))

        res = {"no_ablation": check_identification(model, image_tensor, class_name, None, None)}
        for top_k in ablation_sizes:
            res[f"grad_ablation_{top_k}"] = check_identification(
                model, image_tensor, class_name, grad_indices[:top_k], image_mean_tensor
            )
        results[_question_key(qid)] = res

    with open(args.results_file, "w") as f:
        json.dump(results, f)


def build_arg_parser():
    """Build the command-line parser for this script.

    ``--use-prompt-tuning`` stays on by default; ``--no-use-prompt-tuning``
    switches it off (the flag previously could never be disabled because it
    was a ``store_true`` option whose default was already ``True``).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-path", type=str, required=True)
    parser.add_argument("--model-base", type=str, default=None)
    parser.add_argument("--image-folder", type=str, required=True)
    parser.add_argument("--question-file", type=str, required=True)
    parser.add_argument("--results-file", type=str, default="results.json")
    parser.add_argument("--mean-tensor", type=str, required=True)
    parser.add_argument("--device", type=str, default="cuda:0")
    parser.add_argument("--conv-mode", type=str, default="llava_v1")
    parser.add_argument("--use-prompt-tuning", dest="use_prompt_tuning", action="store_true", default=True)
    parser.add_argument(
        "--no-use-prompt-tuning",
        dest="use_prompt_tuning",
        action="store_false",
        help="Disable prompt tuning (enabled by default).",
    )
    return parser


if __name__ == "__main__":
    main(build_arg_parser().parse_args())
