import os
import time

# Wall-clock timestamp taken at script start; kept so the total elapsed time
# can be reported when the script finishes.
start_time = time.time()
_start_stamp = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(start_time))
print("Start Time: " + _start_stamp)
import json
import torch
import sys
sys.path.append('/home/user/llava/LLaVA')
from PIL import Image
from llava.model.builder import load_pretrained_model
from llava.mm_utils import tokenizer_image_token, process_images, get_model_name_from_path
from llava.conversation import conv_templates, SeparatorStyle
from llava.constants import (
    IMAGE_TOKEN_INDEX,
    DEFAULT_IMAGE_TOKEN,
    DEFAULT_IM_START_TOKEN,
    DEFAULT_IM_END_TOKEN,
    IMAGE_PLACEHOLDER,
)

# Turn off NCCL's peer-to-peer and InfiniBand transports for this process.
# NOTE(review): these are set after `import torch` above; presumably fine
# because NCCL reads them when a communicator is initialized — confirm.
os.environ.update(NCCL_P2P_DISABLE="1", NCCL_IB_DISABLE="1")
import argparse

# Command-line interface. The path options default to the author's local
# directory layout; the decoding options are passed to model.generate().
parser = argparse.ArgumentParser(description="POPE QA script")

# (flag, default, help) for every filesystem location the script uses.
_path_options = (
    ("--checkpoints_dir", "/hdd/user/checkpoints/mplug-owl2/baseline_45", "Directory for checkpoints"),
    ("--pope_file", "/hdd/user/vlm/pope/output/coco/coco_pope_popular.json", "Path to POPE JSON file"),
    ("--output_dir", "playground/data/baseline/", "Output directory"),
    ("--image_folder", "/hdd/user/vlm/coco/val2014/", "Image folder"),
)
for _flag, _default, _help in _path_options:
    parser.add_argument(_flag, type=str, default=_default, help=_help)

# (flag, type, default) for the generation settings; these carry no help text.
_generation_options = (
    ("--temperature", float, 0),
    ("--top_p", float, None),
    ("--num_beams", int, 1),
    ("--max_new_tokens", int, 128),
)
for _flag, _type, _default in _generation_options:
    parser.add_argument(_flag, type=_type, default=_default)

args = parser.parse_args()

# Directories
checkpoints_dir = args.checkpoints_dir
pope_file = args.pope_file
output_dir = args.output_dir
image_folder = args.image_folder


def _load_questions(path):
    """Load POPE question records from *path* (JSON array or JSON Lines).

    The file is read once. If the whole text parses as one JSON value, that
    value is used, wrapped in a list when it is not already one (so a
    single-record JSON Lines file still yields a list). Otherwise the text
    is treated as JSON Lines: one JSON object per non-blank line.

    Raises:
        json.JSONDecodeError: if the text is neither valid JSON nor valid
            JSON Lines.
    """
    with open(path, "r", encoding="utf-8") as f:
        text = f.read()
    try:
        data = json.loads(text)
    except json.JSONDecodeError:
        # Several top-level values -> JSON Lines. Blank lines (e.g. a
        # trailing empty line) are skipped instead of aborting the load.
        return [json.loads(line) for line in text.splitlines() if line.strip()]
    return data if isinstance(data, list) else [data]


# Load POPE JSON data
pope_data = _load_questions(pope_file)

# Ensure output directory exists
os.makedirs(output_dir, exist_ok=True)

checkpoint_path = checkpoints_dir
# Results are named after the checkpoint directory, e.g. "<ckpt_name>.json".
output_name = os.path.basename(checkpoints_dir) + ".json"
print(os.path.basename(pope_file))
print(output_name)
# Load tokenizer, model and image processor from the checkpoint onto the
# first GPU in float16. image_start / image_length are extra arguments
# accepted by this project's load_pretrained_model.
# NOTE(review): 576 presumably matches the number of image patch tokens and
# 35 the prompt position where they begin — confirm against the builder.
model_name = get_model_name_from_path(checkpoint_path)
tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path=checkpoint_path,
    model_base=None,
    model_name=model_name,
    torch_dtype=torch.float16,
    device="cuda:0",
    cache_dir="/hdd/user/",
    image_start=35,
    image_length=576,
)

# Prepare responses
responses = []

# Image marker prepended to every question. The start/end-wrapped form is
# used when the model config says it was trained with those extra tokens.
# model.config does not change inside the loop, so this is computed once.
if model.config.mm_use_im_start_end:
    image_prefix = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n'
else:
    image_prefix = DEFAULT_IMAGE_TOKEN + '\n'

for item in pope_data:
    image_name = item["image"]
    # Keep the raw question for the output record; only the prompt gets
    # the image marker.
    question_text = item["text"]

    conv = conv_templates["llava_v1"].copy()
    conv.append_message(conv.roles[0], image_prefix + question_text)
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()

    # unsqueeze(0) adds a leading batch dimension of size 1. The tensor is
    # placed on the same device as the model, like image_tensor below.
    input_ids = tokenizer_image_token(
        prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt'
    ).unsqueeze(0).to(model.device)

    image_path = os.path.join(image_folder, image_name)
    # `with` closes the underlying file; convert() returns an independent copy.
    with Image.open(image_path) as img:
        image = img.convert("RGB")
    image_tensor = process_images([image], image_processor, model.config).to(model.device, dtype=torch.float16)

    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            images=image_tensor,
            # Greedy/beam decoding at temperature 0, sampling otherwise.
            do_sample=args.temperature > 0,
            temperature=args.temperature,
            top_p=args.top_p,
            num_beams=args.num_beams,
            max_new_tokens=args.max_new_tokens,
            # NOTE(review): attentions are requested but never read in this
            # script; presumably used inside the customized model — confirm,
            # or drop to save memory.
            output_attentions=True,
        )

    # NOTE(review): assumes generate() returns only newly generated tokens
    # (as in recent LLaVA versions); otherwise the prompt text would be
    # included in the response — confirm for this fork.
    response = tokenizer.decode(output_ids[0], skip_special_tokens=True).strip()
    responses.append({
        "question_id": item["question_id"],
        "image": image_name,
        "text": question_text,
        "response": response,
    })

# Persist all responses as one indented JSON array. The directory is
# (re)created right before writing as a safeguard after a long inference run.
output_file = os.path.join(output_dir, output_name)
os.makedirs(output_dir, exist_ok=True)
with open(output_file, "w") as out_f:
    json.dump(responses, out_f, indent=4)

# Drop the references to the model objects, then ask PyTorch to release its
# cached GPU memory.
del model, tokenizer, image_processor
torch.cuda.empty_cache()

print("Processing complete.")
end_time = time.time()
print("End Time: " + time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(end_time)))

# Elapsed wall-clock time since start_time was recorded at script start.
duration = end_time - start_time
print("Duration: {:.2f} seconds".format(duration))