import os
import random
import numpy as np
# Process-wide NCCL/CUDA settings. They are written to os.environ here, before
# torch is imported further down, so they are already in place when torch loads.
os.environ["NCCL_P2P_DISABLE"] = "1"  # NOTE(review): presumably turns off NCCL peer-to-peer transport — confirm against NCCL docs
os.environ["NCCL_IB_DISABLE"] = "1"  # NOTE(review): presumably turns off NCCL's InfiniBand transport — confirm against NCCL docs
os.environ["CUDA_VISIBLE_DEVICES"] = "2"  # Only GPU index 2 is exposed to this process; the model is later loaded on "cuda:0"
import sys
# Extend the module search path with the project directory.
# NOTE(review): this is the `llava` package directory itself, so it lets
# top-level modules inside it (e.g. `utils`) resolve; the `llava.*` imports
# below need the parent directory on the path or an installed package — confirm.
sys.path.append('/home/user/llava/LLaVA/llava')

from llava.model.builder import load_pretrained_model
from llava.mm_utils import get_model_name_from_path
from llava.eval.run_llava import eval_model
import torch
from llava.model import *
from PIL import Image
from llava.mm_utils import tokenizer_image_token, process_images, get_model_name_from_path
from llava.conversation import conv_templates, SeparatorStyle
import numpy as np

# Import seed control utility (project-local `utils` module, not visible here)
from utils import setup_seeds

# Initialize seeds with a fixed value before the model is loaded and generation runs.
# NOTE(review): presumably seeds the Python/NumPy/torch RNGs — confirm in utils.setup_seeds
setup_seeds(42)


# Zero attention metric for layer 0: 32 independent rows of 576 zeros each,
# held as plain nested Python lists of floats.
whiten_attn_matrix = {
    0: [[0.0] * 576 for _ in range(32)],
}

# Load tokenizer, model and image processor for the LLaVA-1.5 7B checkpoint.
# The model name is derived from the checkpoint path; weights are requested in
# half precision on "cuda:0" and cached under /hdd/user/.
# NOTE(review): eager attention is presumably selected so that attention
# weights can be returned by generate() later — confirm.
model_path = "liuhaotian/llava-v1.5-7b"
tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path=model_path,
    model_base=None,
    model_name=get_model_name_from_path(model_path),
    device="cuda:0",
    torch_dtype=torch.float16,
    attn_implementation="eager",
    cache_dir="/hdd/user/",
)
    
    


from llava.constants import (
    IMAGE_TOKEN_INDEX,
    DEFAULT_IMAGE_TOKEN,
    DEFAULT_IM_START_TOKEN,
    DEFAULT_IM_END_TOKEN,
    IMAGE_PLACEHOLDER,
)


# Example input: one question record (id, image file name, question text)
line = dict(
    question_id="0001",
    image="COCO_val2014_000000399702.jpg",  # Replace with the actual image file path
    text="Please describe the image in detail.",
)

# Run settings: where the images live, decoding options, and the conversation
# template to use when building the prompt.
args = dict(
    image_folder="/hdd/user/vlm/coco/val2014/",  # Update this path
    temperature=0.9,
    top_p=None,
    num_beams=9,
    conv_mode="llava_v1",  # Adjust based on your conversation template setup
)


# Prepend the image placeholder to the question; wrap the placeholder in the
# start/end markers only when the model config asks for them.
qs = line["text"]
if model.config.mm_use_im_start_end:
    qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs
else:
    qs = DEFAULT_IMAGE_TOKEN + '\n' + qs

# Render the prompt from a copy of the conversation template: a message for the
# first role carrying the question, then an empty (None) message for the
# second role.
conv = conv_templates[args['conv_mode']].copy()
conv.append_message(conv.roles[0], qs)
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()

# Tokenize the prompt, add a batch dimension, and move the ids to the model's
# own device (not a hard-coded CUDA index) so they always sit where the model
# and the image tensor are.
input_ids = (
    tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt')
    .unsqueeze(0)
    .to(model.device)
)

# Open the reference image and shrink it to 244x244; in this script it only
# supplies the size that is reported to the model.
image_path = os.path.join(args["image_folder"], line["image"])
image = Image.open(image_path).convert("RGB").resize((244, 244))
image_size = image.size

# The pixels actually fed to the model come from a plain white canvas of that
# same size, preprocessed and moved to the model's device in half precision.
white_image = Image.new('RGB', image_size, (255, 255, 255))
image_tensor = process_images([white_image], image_processor, model.config)
image_tensor = image_tensor.to(model.device, dtype=torch.float16)


# Generate under inference mode, requesting attention maps and a dict-style
# result so sequences and attentions can be read back by key.
# NOTE(review): temperature is passed together with do_sample=False; it is
# presumably unused under greedy decoding — confirm.
with torch.inference_mode():
    output_ids = model.generate(
        input_ids,
        images=image_tensor,
        image_sizes=[image_size],
        do_sample=False,
        temperature=0.9,
        max_new_tokens=512,
        output_attentions=True,
        return_dict_in_generate=True,
    )

# Decode the first returned sequence to text, dropping special tokens
generated_sequence = output_ids['sequences'][0]
response = tokenizer.decode(generated_sequence, skip_special_tokens=True).strip()
print(f"Response: {response}")

# Define save directory
save_dir = "/home/user/llava/LLaVA/llava_next_toy"
os.makedirs(save_dir, exist_ok=True)  # Ensure directory exists

# Locate the image tokens from the prompt itself instead of assuming a fixed
# offset. The prompt contains a single image placeholder id; LLaVA replaces it
# in place with the image patch features, so the placeholder's index in
# input_ids is where the image tokens start in the attention maps. (With a
# system prompt in front of the placeholder, this is not a small constant.)
num_image_tokens = 576
image_token_positions = (input_ids[0] == IMAGE_TOKEN_INDEX).nonzero(as_tuple=True)[0]
if image_token_positions.numel() != 1:
    raise ValueError(
        "Expected exactly one image placeholder in the prompt, "
        f"found {image_token_positions.numel()}"
    )
image_start = int(image_token_positions[0])
image_end = image_start + num_image_tokens

# attentions[0] belongs to the first generation step; it holds one tensor per
# layer, indexed below as (batch, head, query position, key position).
prefill_attentions = output_ids['attentions'][0]
key_len = prefill_attentions[0].shape[-1]
if image_end > key_len:
    # A plain slice would silently return fewer columns; fail loudly instead.
    raise ValueError(
        f"Image token window [{image_start}, {image_end}) exceeds the "
        f"attention key length {key_len}"
    )

# Loop through every layer returned by the model
for layer_idx, layer_attn in enumerate(prefill_attentions):
    # Attention of the last query position to each image token, per head
    attn_matrix = layer_attn[0, :, -1, image_start:image_end]  # Shape: (heads, 576)

    # Scale each head's row by that head's mean attention over the image
    # tokens; the (heads, 1) mean broadcasts across the 576 columns.
    diag_metric = attn_matrix * attn_matrix.mean(dim=1, keepdim=True)  # Shape: (heads, 576)

    # Convert to NumPy array and save
    save_path = os.path.join(save_dir, f"layer_{layer_idx}.npy")
    np.save(save_path, diag_metric.cpu().numpy())

    print(f"Saved: {save_path}")  # Print confirmation

print("All layers processed and saved successfully!")