import torch
import json
from einops import rearrange
import numpy as np
from functools import partial
import os
from tqdm import tqdm
import argparse
import sys
import re
from sentence_transformers import SentenceTransformer
import torch.nn.functional as F
from transformers import AutoProcessor, LlavaForConditionalGeneration
from PIL import Image
from baukit import TraceDict
from pca import PCA
import random

from yolov12.ultralytics import YOLO

def extract_assistant_output(text):
    """Return the assistant's part of a decoded chat transcript.

    Everything after the first ``ASSISTANT:`` marker is returned, stripped of
    surrounding whitespace. When there is no such marker but a ``USER:`` turn
    is present, that turn (from ``USER:`` up to the first blank line or the
    end of the text) is removed wherever it occurs and the stripped remainder
    is returned. With neither marker the stripped text is returned unchanged.
    """
    _, marker, reply = text.partition("ASSISTANT:")
    if marker:
        return reply.strip()

    # No assistant marker exists past this point, so only USER: is checked.
    if "USER:" in text:
        user_turn = re.search(r'USER:.*?(?=\n\n|\Z)', text, re.DOTALL)
        if user_turn is not None:
            return text.replace(user_turn.group(0), "").strip()

    return text.strip()

def process_new_dataset(input_file, image_dir_path=None):
    """Load a JSON array of samples and fill in the fields used downstream.

    Each record is updated in place:

    * ``image`` is prefixed with ``image_dir_path`` when both are present;
    * ``text`` is copied from ``query`` when the record has a ``query``;
    * ``question_id`` is the string form of ``image_id``, falling back to
      ``id`` and then to ``"unknown"``;
    * ``image_name`` is the basename of ``image`` or ``"unknown.jpg"``.

    Args:
        input_file: Path to a UTF-8 JSON file holding the records.
        image_dir_path: Optional directory joined in front of each ``image``.

    Returns:
        A list of the updated records, in file order.
    """
    with open(input_file, 'r', encoding='utf-8') as handle:
        records = json.load(handle)

    for record in records:
        has_image = 'image' in record

        if image_dir_path and has_image:
            record["image"] = os.path.join(image_dir_path, record["image"])

        if "query" in record:
            record["text"] = record["query"]

        fallback_id = record.get("id", "unknown")
        record["question_id"] = str(record.get("image_id", fallback_id))

        if has_image:
            record["image_name"] = os.path.basename(record["image"])
        else:
            record["image_name"] = "unknown.jpg"

    return list(records)

def layer_head_to_flattened_idx(layer, head, num_heads):
    """Map a ``(layer, head)`` pair to its row-major flat index."""
    layer_offset = layer * num_heads
    return layer_offset + head

def flattened_idx_to_layer_head(flattened_idx, num_heads):
    """Split a row-major flat head index back into ``(layer, head)``."""
    layer, head = divmod(flattened_idx, num_heads)
    return layer, head

def get_question_embedding(question, model):
    """Encode ``question`` with ``model`` and return the encoder's output.

    The call is ``model.encode(question, convert_to_tensor=True)``; whatever
    that returns is passed straight back to the caller.
    """
    return model.encode(question, convert_to_tensor=True)

def find_most_similar_center(question_embedding, stored_centers, center_direction_map):
    """Pick the stored center with the highest cosine similarity to the query.

    Args:
        question_embedding: 1-D tensor for the query.
        stored_centers: Sequence of 1-D center tensors.
        center_direction_map: Container indexed by the winning position in
            ``stored_centers``.

    Returns:
        ``(center_direction_map[best_index], best_similarity)``.

    Raises:
        ValueError: If ``stored_centers`` is empty.
    """
    query = question_embedding.unsqueeze(0)
    scores = [
        F.cosine_similarity(query.to(center.device), center.unsqueeze(0)).item()
        for center in stored_centers
    ]

    # max() keeps the first index on ties, matching a left-to-right scan.
    most_similar_idx = max(range(len(scores)), key=scores.__getitem__)
    print(most_similar_idx)
    return center_direction_map[most_similar_idx], scores[most_similar_idx]

def add_diffusion_noise(encoded_tensor, noise_step):
    """Blend ``encoded_tensor`` with Gaussian noise at a given schedule step.

    A 1000-step beta schedule is built from a sigmoid over ``linspace(-6, 6)``
    scaled into ``[1e-5, 0.5e-2]``. With ``a`` the cumulative product of
    ``1 - beta``, the result is
    ``sqrt(a[t]) * x + sqrt(1 - a[t]) * noise`` where ``t = int(noise_step)``
    and ``noise`` is drawn with ``torch.randn_like``.

    Args:
        encoded_tensor: Tensor to perturb.
        noise_step: Index into the schedule; converted with ``int``.

    Returns:
        The noised tensor, same shape as ``encoded_tensor``.
    """
    total_steps = 1000
    beta_low, beta_high = 1e-5, 0.5e-2

    schedule = torch.sigmoid(torch.linspace(-6, 6, total_steps)) * (beta_high - beta_low) + beta_low
    cumulative_alpha = torch.cumprod(1 - schedule, dim=0)
    signal_scale = torch.sqrt(cumulative_alpha)
    noise_scale = torch.sqrt(1 - cumulative_alpha)

    step = int(noise_step)
    gaussian = torch.randn_like(encoded_tensor)
    return (signal_scale[step] * encoded_tensor + noise_scale[step] * gaussian)

def generate_hallucinations(input_list, synonyms_dict, category_dict):
    """Replace each word with a random, non-synonymous word from the vocabulary.

    For every input word the replacement is drawn (via ``random.choice``)
    from, in order of preference:

    1. words of the same category that are not synonyms of the input word;
    2. any vocabulary word that is not a synonym of the input word.

    Words that belong to no category skip step 1. If no candidate exists at
    all (for example an empty vocabulary) the original word is kept, since no
    substitute can be produced. A replacement is capitalised when the input
    word starts with an uppercase letter.

    Args:
        input_list: Iterable of words (strings) to replace.
        synonyms_dict: List of synonym groups (lists of words), or a JSON
            string encoding such a list.
        category_dict: List of ``{"类别": name, "词语": [words]}`` entries, or
            a JSON string encoding such a list.

    Returns:
        A list with one replacement per input word, in input order.
    """
    if isinstance(synonyms_dict, str):
        synonyms_dict = json.loads(synonyms_dict)
    if isinstance(category_dict, str):
        category_dict = json.loads(category_dict)

    # word -> its whole (lower-cased) synonym group; later groups win.
    word_to_synonyms = {}
    for synonym_group in synonyms_dict:
        lowered_group = [syn.lower() for syn in synonym_group]
        for syn in lowered_group:
            word_to_synonyms[syn] = lowered_group

    word_to_category = {}
    category_to_words = {}
    all_words = []
    for category in category_dict:
        category_name = category["类别"]
        words = [word.lower() for word in category["词语"]]
        category_to_words[category_name] = words
        all_words.extend(words)
        for word in words:
            word_to_category[word] = category_name

    result = []
    for word in input_list:
        word_lower = word.lower()
        excluded = set(word_to_synonyms.get(word_lower, (word_lower,)))

        candidates = []
        if word_lower in word_to_category:
            same_category = category_to_words[word_to_category[word_lower]]
            candidates = [w for w in same_category if w not in excluded]
        if not candidates:
            # Fall back to the whole vocabulary, still excluding synonyms so
            # the replacement never names the original object.
            candidates = [w for w in all_words if w not in excluded]

        if not candidates:
            # Nothing to substitute with; keep the word rather than crash.
            result.append(word)
            continue

        replacement = random.choice(candidates)
        # word[:1] is safe for empty strings, unlike word[0].
        if word[:1].isupper():
            replacement = replacement.capitalize()
        result.append(replacement)

    return result

def detect_objects(model, image_path, conf_threshold=0.5):
    """Return the distinct class names detected in an image.

    Runs ``model.predict(source=image_path)`` and keeps every box whose
    confidence is strictly greater than ``conf_threshold``.

    Args:
        model: Detector exposing ``predict(source=...)``; each result must
            provide ``boxes`` (items with ``conf`` and ``cls``) and ``names``.
        image_path: Image source forwarded to ``model.predict``.
        conf_threshold: Minimum (exclusive) confidence for a detection.

    Returns:
        Unique class names in first-detection order. The order is
        deterministic so that seeded downstream sampling is reproducible;
        a plain ``set`` of strings would vary with hash randomisation.
    """
    detected = {}  # dict keys give de-duplication with stable ordering

    for result in model.predict(source=image_path):
        for box in result.boxes:
            if box.conf.item() > conf_threshold:
                class_id = int(box.cls.item())
                detected[result.names[class_id]] = None

    return list(detected)

def get_head_activations(model, processor, image, prompt, device, noise_level=None, num_layers=32):
    """Run one forward pass and capture every layer's ``o_proj`` output.

    Forward hooks are registered on
    ``language_model.model.layers.{i}.self_attn.o_proj`` for
    ``i in range(num_layers)``; layers that cannot be found are reported with
    a printed message and skipped. Hooks are always removed, even when the
    forward pass raises, so a failed call never leaves stale hooks on the
    model.

    Args:
        model: Module exposing ``named_modules()`` and callable with the
            processor's outputs as keyword arguments.
        processor: Callable taking ``text``, ``images`` and ``return_tensors``
            and returning a mapping of tensors.
        image: Image forwarded to ``processor``.
        prompt: Text forwarded to ``processor``.
        device: Device the inputs are moved to.
        noise_level: When not ``None`` and the inputs contain
            ``pixel_values``, those are replaced by
            ``add_diffusion_noise(pixel_values, noise_level)``.
        num_layers: Number of layer indices to look up.

    Returns:
        Dict mapping each hooked module name to its output, detached and
        moved to the CPU.
    """
    inputs = processor(text=prompt, images=image, return_tensors="pt")

    if noise_level is not None and "pixel_values" in inputs:
        encoded_image = inputs["pixel_values"].clone()
        inputs["pixel_values"] = add_diffusion_noise(encoded_image, noise_level)

    inputs = {k: v.to(device) for k, v in inputs.items()}

    activations = {}

    def hook_fn(name):
        def _hook(module, hook_inputs, output):
            activations[name] = output.detach().cpu()
        return _hook

    # Build the name -> module table once instead of once per layer.
    modules_by_name = dict(model.named_modules())

    hooks = []
    try:
        for i in range(num_layers):
            layer_name = f"language_model.model.layers.{i}.self_attn.o_proj"
            module = modules_by_name.get(layer_name)
            if module is None:
                print(f"Module not found: {layer_name}")
                continue
            hooks.append(module.register_forward_hook(hook_fn(layer_name)))

        with torch.no_grad():
            model(**inputs)
    finally:
        for hook in hooks:
            hook.remove()

    return activations

def compute_activation_diff(orig_activations, noisy_activations, n_components=128):
    """Difference two activation captures at the last token and PCA-filter it.

    For every layer name present in both dicts, the last-position slice
    (``[:, -1, :]``) of the noisy capture is subtracted from the original
    one. The per-layer differences are concatenated along dim 0, given a
    leading axis when the result is 2-D, passed through
    ``PCA(n_components).fit_transform`` followed by ``inverse_transform``,
    and finally split back into per-layer chunks.

    NOTE(review): ``PCA`` comes from the project-local ``pca`` module; this
    assumes it accepts and returns tensors with a leading batch axis.

    Args:
        orig_activations: Dict of layer name -> 3-D activation tensor.
        noisy_activations: Dict with the same layout for the perturbed run.
        n_components: Number of components handed to ``PCA``.

    Returns:
        Dict mapping each shared layer name to its reconstructed difference.
    """
    raw_diff = {
        name: orig_activations[name][:, -1, :] - noisy_activations[name][:, -1, :]
        for name in orig_activations
        if name in noisy_activations
    }

    layer_names = list(raw_diff)
    stacked = torch.cat([raw_diff[name] for name in layer_names], dim=0)
    if stacked.dim() == 2:
        stacked = stacked.unsqueeze(0)

    pca = PCA(n_components=n_components)
    reconstructed = pca.inverse_transform(pca.fit_transform(stacked))

    # Walk the stacked rows in the same order they were concatenated.
    split_diffs = {}
    offset = 0
    for name in layer_names:
        rows = raw_diff[name].size(0)
        split_diffs[name] = reconstructed[0, offset:offset + rows]
        offset += rows

    return split_diffs

def select_top_heads_from_image(activation_diff, num_heads_to_select):
    """Rank heads by mean absolute difference and keep the strongest ones.

    Each layer's difference tensor is reshaped to
    ``(-1, num_heads, head_size)``. The head size is the first of 128, 64 or
    40 that divides the last dimension; otherwise 32 heads are assumed.
    Layers that cannot be reshaped are reported with a printed message and
    skipped.

    Args:
        activation_diff: Dict of layer name -> difference tensor. The layer
            index is read from the fourth dot-separated field of the name.
        num_heads_to_select: Maximum number of heads to return overall.

    Returns:
        Dict mapping ``language_model.model.layers.{idx}.self_attn.o_proj``
        to a list of ``(head_idx, direction, 1.0)`` tuples, where
        ``direction`` is the head's difference slice as a NumPy array. Keys
        appear in order of decreasing head importance.
    """
    scored_heads = []

    for layer_name, diff_tensor in activation_diff.items():
        layer_idx = int(layer_name.split('.')[3])
        hidden_size = diff_tensor.shape[-1]

        head_size = next((hs for hs in (128, 64, 40) if hidden_size % hs == 0), None)
        if head_size is not None:
            num_heads = hidden_size // head_size
        else:
            num_heads = 32
            head_size = hidden_size // num_heads

        try:
            diff_reshaped = diff_tensor.reshape(-1, num_heads, head_size)
        except (RuntimeError, ValueError):
            # Only shape mismatches are expected here; anything else should surface.
            print(f"Couldn't reshape tensor of size {diff_tensor.shape} to [-1, {num_heads}, {head_size}]")
            continue

        for head_idx in range(num_heads):
            head_diff = diff_reshaped[:, head_idx, :]
            importance = torch.mean(torch.abs(head_diff)).item()
            scored_heads.append((importance, layer_idx, head_idx, head_diff))

    # Stable sort: heads with equal importance keep their discovery order.
    scored_heads.sort(reverse=True, key=lambda entry: entry[0])

    selected_heads = {}
    for _, layer_idx, head_idx, head_diff in scored_heads[:max(num_heads_to_select, 0)]:
        layer_name = f"language_model.model.layers.{layer_idx}.self_attn.o_proj"
        selected_heads.setdefault(layer_name, []).append((head_idx, head_diff.numpy(), 1.0))

    return selected_heads

def get_top_heads_by_activation_diff(activation_diff, num_heads_to_select=128):
    """Rank heads by the magnitude of their mean activation difference.

    Args:
        activation_diff: Array (or tensor, converted via ``.cpu().numpy()``)
            with three axes, unpacked as ``(num_layers, num_heads, hidden_dim)``.
        num_heads_to_select: How many of the highest-ranked heads to return.

    Returns:
        ``(top_heads, head_importance)`` where ``top_heads`` is a list of
        ``(layer, head)`` pairs ordered by decreasing absolute importance and
        ``head_importance`` is the ``(num_layers, num_heads)`` array of
        per-head signed means.
    """
    if not isinstance(activation_diff, np.ndarray):
        activation_diff = activation_diff.cpu().numpy()

    num_layers, num_heads, _ = activation_diff.shape

    head_importance = np.zeros((num_layers, num_heads))
    for layer, head in np.ndindex(num_layers, num_heads):
        head_importance[layer, head] = np.mean(activation_diff[layer, head])

    # Ascending argsort on |importance|, reversed to get the largest first.
    magnitude_order = np.argsort(np.abs(head_importance.reshape(-1)))
    top_indices = magnitude_order[::-1][:num_heads_to_select]

    top_heads = [flattened_idx_to_layer_head(flat_idx, num_heads) for flat_idx in top_indices]

    return top_heads, head_importance

def build_intervention_dict(top_heads, activation_diff, num_heads):
    """Group unit-norm steering directions for the selected heads by layer.

    Args:
        top_heads: Iterable of ``(layer, head)`` pairs.
        activation_diff: Array (or tensor, converted via ``.cpu().numpy()``)
            indexed as ``[layer, head]``. A 1-D input is reshaped to
            ``(32, num_heads, -1)``.
        num_heads: Heads per layer, used only for the 1-D reshape.

    Returns:
        Dict mapping ``language_model.model.layers.{layer}.self_attn.o_proj``
        to a list of ``(head, direction, 1.0)`` tuples sorted by head index.
        Each ``direction`` is the head's difference vector scaled to unit
        L2 norm. A vector whose norm is zero (or not a number) becomes an
        all-zero direction, so it contributes nothing instead of injecting
        NaNs when later added to activations.
    """
    if not isinstance(activation_diff, np.ndarray):
        activation_diff = activation_diff.cpu().numpy()

    if activation_diff.ndim == 1:
        num_layers = 32
        hidden_dim = activation_diff.shape[0] // (num_layers * num_heads)
        activation_diff = activation_diff.reshape(num_layers, num_heads, hidden_dim)

    interventions = {}
    for layer, head in top_heads:
        direction = activation_diff[layer, head]
        norm = np.linalg.norm(direction)
        if norm > 0:
            direction = direction / norm
        else:
            # Dividing by a zero norm would yield NaN directions.
            direction = np.zeros_like(direction)
        proj_val_std = 1.0

        layer_key = f"language_model.model.layers.{layer}.self_attn.o_proj"
        interventions.setdefault(layer_key, []).append((head, direction, proj_val_std))

    for entries in interventions.values():
        entries.sort(key=lambda entry: entry[0])

    return interventions

def merge_interventions(embedding_interventions, image_interventions, embedding_weight=0.6, image_weight=0.4):
    """Bundle the two intervention sets under the keys ``embedding`` and ``image``.

    The inputs are stored as-is (not copied). ``embedding_weight`` and
    ``image_weight`` are accepted but currently have no effect on the result.
    """
    merged = dict(embedding=embedding_interventions, image=image_interventions)
    return merged

def main():
    """Caption every dataset item with LLaVA while steering attention outputs.

    For each item the script:

    1. embeds the query and looks up the stored direction of the most similar
       cluster center, from which the embedding-based interventions are built;
    2. detects objects with YOLO, swaps them for hallucinated ones, and
       compares ``o_proj`` activations of a "truth" prompt against a noised
       "hallucination" prompt to build the image-based interventions;
    3. generates a response with both sets of directions added to the last
       position of the selected heads (scaled by ``--alpha`` / ``--beta``);
    4. appends ``{"image_id", "caption"}`` to the JSONL output file.

    A failure on one item is logged and written as an ``Error: ...`` caption;
    processing then continues with the next item.
    """
    import time  # local import: ``time`` is not among the file-level imports

    parser = argparse.ArgumentParser()
    parser.add_argument("--question_file", type=str, default="path", help="Input question file path")
    parser.add_argument("--image_dir_path", type=str, default='path', help="Directory containing the images")
    parser.add_argument('--num_heads_embedding', type=int, default=32, help='Number of top heads to select based on embedding')
    parser.add_argument('--num_heads_image', type=int, default=32, help='Number of top heads to select based on image noise')
    parser.add_argument('--alpha', type=float, default=2, help='Intervention strength for embedding-based interventions')
    parser.add_argument('--beta', type=float, default=2, help='Intervention strength for image-based interventions')
    parser.add_argument('--seed', type=int, default=42, help='Random seed')
    parser.add_argument('--noise_level', type=int, default=500, help='Diffusion noise step (0-999)')
    parser.add_argument('--model_path', type=str, default="path", help="Path to the LLaVA model")
    parser.add_argument('--activation_diff_path', type=str, default="path", help="Path to embedding->activation_diff map")
    parser.add_argument('--output_file', type=str, default="results.jsonl", help="Path to output results in jsonl format")
    parser.add_argument('--pca_components', type=int, default=1, help='Number of PCA components to use')
    parser.add_argument('--yolo_model_path', type=str, default="path", help="Path to YOLO model")
    parser.add_argument('--synonyms_dict_path', type=str, default="path", help="Path to synonyms dictionary")
    parser.add_argument('--category_dict_path', type=str, default="path", help="Path to category dictionary")

    args = parser.parse_args()

    # Seed every RNG the pipeline draws from.
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    sentence_model = SentenceTransformer("path")
    sentence_model.to(device)

    print("Loading YOLO model...")
    yolo_model = YOLO(args.yolo_model_path)

    print("Loading object dictionaries...")
    with open(args.synonyms_dict_path, 'r', encoding='utf-8') as fp:
        synonyms_dict = json.load(fp)
    with open(args.category_dict_path, 'r', encoding='utf-8') as f:
        category_dict = json.load(f)

    print("Loading pre-computed cluster centers and directions...")
    # NOTE(security): torch.load relies on pickle; only load trusted files here.
    center_direction_map = torch.load(args.activation_diff_path)
    stored_centers = [center_direction_map[item]['center'] for item in center_direction_map]

    print("Loading LLaVA model...")
    model = LlavaForConditionalGeneration.from_pretrained(args.model_path)
    model.to(device)
    processor = AutoProcessor.from_pretrained(args.model_path)

    print("Processing questions...")
    questions = process_new_dataset(args.question_file, args.image_dir_path)

    os.makedirs(os.path.dirname(os.path.abspath(args.output_file)), exist_ok=True)

    num_heads = 32
    generate_kwargs = dict(do_sample=False, use_cache=True, max_new_tokens=512)

    # The context manager guarantees the file is closed even if the loop aborts.
    with open(args.output_file, 'w', encoding='utf-8') as output_file:

        def write_result(image_id, caption):
            """Append one JSONL record and flush so partial runs keep their output."""
            result_entry = {
                "image_id": image_id,
                "caption": caption
            }
            output_file.write(json.dumps(result_entry, ensure_ascii=False) + '\n')
            output_file.flush()

        for item in tqdm(questions):
            item_id = item["question_id"]
            image_id = item.get("image_id", item_id)

            try:
                # Read inside the try so a record lacking these keys is logged
                # as a per-item error instead of aborting the whole run.
                image_file = item["image"]
                query = item["text"]

                image = Image.open(image_file)

                prompt = f"USER: <image>\n{query} ASSISTANT:"

                start = time.time()

                question_embedding = get_question_embedding(query, sentence_model)
                most_similar_item, similarity = find_most_similar_center(
                    question_embedding, stored_centers, center_direction_map
                )

                embedding_activation_diff = most_similar_item['direction']

                embedding_top_heads, _ = get_top_heads_by_activation_diff(
                    embedding_activation_diff, args.num_heads_embedding
                )

                embedding_interventions = build_intervention_dict(
                    embedding_top_heads, embedding_activation_diff, num_heads
                )

                true_objects = detect_objects(yolo_model, image_file)
                hallu_objects = generate_hallucinations(true_objects, synonyms_dict, category_dict)

                true_objects_str = ", ".join(true_objects)
                hallu_objects_str = ", ".join(hallu_objects)

                prompt_truth = f"USER: <image>\nPlease describe this image. ASSISTANT: The image depicts {true_objects_str}"
                prompt_hallu = f"USER: <image>\nPlease describe this image. ASSISTANT: The image depicts {hallu_objects_str}"

                orig_activations = get_head_activations(
                    model, processor, image, prompt_truth, device
                )

                hallu_activations = get_head_activations(
                    model, processor, image, prompt_hallu, device, noise_level=args.noise_level
                )

                activation_diff_dict = compute_activation_diff(
                    orig_activations, hallu_activations, n_components=args.pca_components
                )

                image_interventions = select_top_heads_from_image(activation_diff_dict, args.num_heads_image)

                merged_interventions = merge_interventions(
                    embedding_interventions,
                    image_interventions
                )

                def lt_modulated_vector_add(head_output, layer_name):
                    """Add the scaled directions to the last position of the selected heads."""
                    sources = (("embedding", args.alpha), ("image", args.beta))
                    active = [
                        (merged_interventions[kind][layer_name], strength)
                        for kind, strength in sources
                        if layer_name in merged_interventions[kind]
                    ]
                    if not active:
                        return head_output

                    head_output = rearrange(head_output, 'b s (h d) -> b s h d', h=num_heads)

                    for entries, strength in active:
                        for head, direction, proj_val_std in entries:
                            direction_tensor = torch.tensor(
                                direction, dtype=head_output.dtype, device=head_output.device
                            )
                            head_output[:, -1, head, :] += strength * proj_val_std * direction_tensor

                    head_output = rearrange(head_output, 'b s h d -> b s (h d)')

                    return head_output

                inputs = processor(text=prompt, images=image, return_tensors="pt").to(device)

                # Union of both layer sets, de-duplicated in a stable order.
                layers_to_intervene = list(dict.fromkeys(
                    [*merged_interventions["embedding"], *merged_interventions["image"]]
                ))

                if layers_to_intervene:
                    with TraceDict(model, layers_to_intervene, edit_output=lt_modulated_vector_add):
                        output = model.generate(**inputs, **generate_kwargs)
                else:
                    output = model.generate(**inputs, **generate_kwargs)
                end = time.time()

                print(end - start)

                outputs = processor.batch_decode(output, skip_special_tokens=True)[0]
                caption = extract_assistant_output(outputs)
                print(f"ID: {image_id}, Response: {caption[:100]}...")

                write_result(image_id, caption)

            except Exception as e:
                # Best-effort batch run: record the failure and keep going.
                print(f"Error processing item {item_id}: {e}")
                write_result(image_id, f"Error: {str(e)}")

    print(f"Results saved to {args.output_file}")

# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()