import os
import time
import numpy as np
import random
# Wall-clock start of the run; the script reports total duration at the end.
start_time = time.time()
_start_stamp = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(start_time))
print(f"Start Time: {_start_stamp}")
import json
import torch
import sys
sys.path.append('/home/user/llava/LLaVA')
from PIL import Image
from llava.model.builder import load_pretrained_model
from llava.mm_utils import tokenizer_image_token, process_images, get_model_name_from_path
from llava.conversation import conv_templates, SeparatorStyle
from llava.constants import (
    IMAGE_TOKEN_INDEX,
    DEFAULT_IMAGE_TOKEN,
    DEFAULT_IM_START_TOKEN,
    DEFAULT_IM_END_TOKEN,
    IMAGE_PLACEHOLDER,
)

# Import seed control utility
from utils import setup_seeds

# Seed every RNG once, up front, via the project's seed helper.
setup_seeds(42)

# Set NCCL_P2P_DISABLE and NCCL_IB_DISABLE to "1" (per the NCCL docs these
# turn off the peer-to-peer and InfiniBand transports).
for _nccl_var in ("NCCL_P2P_DISABLE", "NCCL_IB_DISABLE"):
    os.environ[_nccl_var] = "1"
import argparse

parser = argparse.ArgumentParser(description="POPE QA script")

# --- Data / checkpoint locations -------------------------------------------
parser.add_argument(
    "--checkpoints_dir",
    type=str,
    default="/hdd/user/checkpoints/mplug-owl2/baseline_45",
    help="Directory for checkpoints",
)
parser.add_argument(
    "--pope_file",
    type=str,
    default="/hdd/user/vlm/pope/output/coco/coco_pope_popular.json",
    help="Path to POPE JSON file",
)
parser.add_argument(
    "--output_dir",
    type=str,
    default="playground/data/baseline/",
    help="Output directory",
)
parser.add_argument(
    "--image_folder",
    type=str,
    default="/hdd/user/vlm/coco/val2014/",
    help="Image folder",
)
# Layer index handed to load_pretrained_model(toy_attention_layers=...); it
# also selects the layer_<N>.npy matrix and appears in the output file name.
parser.add_argument(
    "--toy_attention_layer",
    type=int,
    required=True,
    help="Specify the toy attention layer (0 or 31)",
)

# --- Decoding options forwarded to model.generate() ------------------------
_generation_options = (
    ("--temperature", float, 0),  # 0 selects greedy decoding (do_sample off)
    ("--top_p", float, None),
    ("--num_beams", int, 1),
    ("--max_new_tokens", int, 128),
)
for _flag, _flag_type, _flag_default in _generation_options:
    parser.add_argument(_flag, type=_flag_type, default=_flag_default)

args = parser.parse_args()

# Directories
checkpoints_dir = args.checkpoints_dir
pope_file = args.pope_file
output_dir = args.output_dir
image_folder = args.image_folder

# Load the POPE question records. Two on-disk layouts are accepted:
#   * a single JSON document -- a list of records (or one bare record), and
#   * JSON Lines -- one record per line; blank lines are ignored.
# The file is read once; only a JSON decode failure triggers the JSONL
# fallback, so unrelated errors (missing file, Ctrl-C) still surface.
with open(pope_file, "r", encoding="utf-8") as f:
    _pope_text = f.read()
try:
    _parsed = json.loads(_pope_text)
except json.JSONDecodeError:
    # More than one top-level value ("Extra data") => JSON Lines.
    pope_data = [json.loads(line) for line in _pope_text.splitlines() if line.strip()]
else:
    # Always hand the rest of the script a list of records.
    pope_data = _parsed if isinstance(_parsed, list) else [_parsed]


# Create the results directory before any expensive work starts.
os.makedirs(output_dir, exist_ok=True)

checkpoint_path = checkpoints_dir
_ckpt_tag = os.path.basename(checkpoints_dir)
output_name = f"{_ckpt_tag}_layer{args.toy_attention_layer}.json"
print(os.path.basename(pope_file))
print(output_name)

# Per-layer attention matrix read from disk, cast to float16 and converted to
# nested Python lists, keyed by the selected layer index.
_layer = args.toy_attention_layer
_matrix_path = f"/home/user/llava/LLaVA/toy2/layer_{_layer}.npy"
whiten_attn_matrix = {_layer: np.load(_matrix_path).astype(np.float16).tolist()}

# Load the model through the project's (patched) LLaVA builder.
# NOTE(review): image_start=35 / image_length=576 presumably give the token
# offset where the image embeddings begin and the number of image tokens
# (576 = 24x24 patches) -- confirm against the builder.
tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path=checkpoint_path,
    model_base=None,
    torch_dtype=torch.float16,
    cache_dir="/hdd/user/",
    device="cuda:0",
    model_name=get_model_name_from_path(checkpoint_path),
    image_start=35,
    image_length=576,
    toy_attention_layers=[_layer],
    toy_attention_metrics=whiten_attn_matrix,
)

# Prepare responses
responses = []

# Group the questions by image in one pass over the data (O(n) instead of
# rescanning pope_data for every image). A dict keeps keys in first-appearance
# order, so the processing order -- and hence the order of `responses` and of
# any sampling RNG draws -- is identical on every run. Iterating a set of
# strings is not: str hashes are randomized per process (PYTHONHASHSEED).
questions_by_image = {}
for item in pope_data:
    questions_by_image.setdefault(item["image"], []).append(item)
unique_images = list(questions_by_image)
# unique_images = unique_images[100:]  # e.g. to resume a partial run
print("processing unique images", len(unique_images))

for image_name in unique_images:
    # Load and preprocess each image once, then reuse the tensor for all of its
    # questions. The `with` block closes the underlying file handle.
    image_path = os.path.join(image_folder, image_name)
    with Image.open(image_path) as raw_image:
        image = raw_image.convert("RGB")
    image_tensor = process_images([image], image_processor, model.config).to(
        model.device, dtype=torch.float16
    )

    for question in questions_by_image[image_name]:
        # Prepend the image placeholder token(s) in the form the model config asks for.
        qs = question["text"]
        if model.config.mm_use_im_start_end:
            qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs
        else:
            qs = DEFAULT_IMAGE_TOKEN + '\n' + qs

        conv = conv_templates["llava_v1"].copy()
        conv.append_message(conv.roles[0], qs)
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()

        # Same device as the image tensor (the model is loaded on cuda:0).
        input_ids = tokenizer_image_token(
            prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt'
        ).unsqueeze(0).to(model.device)

        with torch.inference_mode():
            output_ids = model.generate(
                input_ids,
                images=image_tensor,
                do_sample=args.temperature > 0,  # temperature 0 => greedy decoding
                temperature=args.temperature,
                top_p=args.top_p,
                num_beams=args.num_beams,
                max_new_tokens=args.max_new_tokens,
                # NOTE(review): kept from the original; presumably required by the
                # toy-attention hooks. It raises memory use -- confirm it is needed.
                output_attentions=True,
            )

        response = tokenizer.decode(output_ids[0], skip_special_tokens=True).strip()
        responses.append({
            "question_id": question["question_id"],
            "image": image_name,
            "text": question["text"],
            "response": response
        })

# Persist all collected responses as one pretty-printed JSON array.
os.makedirs(output_dir, exist_ok=True)
output_file = os.path.join(output_dir, output_name)
with open(output_file, "w") as out_f:
    json.dump(responses, out_f, indent=4)


# Release GPU memory: drop the script's references to the model objects, then
# ask PyTorch's CUDA caching allocator to release its unused cached blocks.
del model, tokenizer, image_processor
torch.cuda.empty_cache()

print("Processing complete.")
end_time = time.time()
_end_stamp = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(end_time))
print(f"End Time: {_end_stamp}")

# Total wall-clock time since start_time was recorded at the top of the script.
duration = end_time - start_time
print(f"Duration: {duration:.2f} seconds")