import argparse
import os
import random

import numpy as np
import torch
import torch.backends.cudnn as cudnn
from tqdm import tqdm

import os
import json
import torch
import sys
sys.path.append('/home/user/llava/LLaVA')
from PIL import Image
from llava.model.builder import load_pretrained_model
from llava.mm_utils import tokenizer_image_token, process_images, get_model_name_from_path
from llava.conversation import conv_templates, SeparatorStyle
from llava.constants import (
    IMAGE_TOKEN_INDEX,
    DEFAULT_IMAGE_TOKEN,
    DEFAULT_IM_START_TOKEN,
    DEFAULT_IM_END_TOKEN,
    IMAGE_PLACEHOLDER,
)

# Import seed control utility
from utils import setup_seeds

# Set environment variables
os.environ["NCCL_P2P_DISABLE"] = "1"
os.environ["NCCL_IB_DISABLE"] = "1"
import argparse

# Command-line interface for the captioning run.
# NOTE(review): the description says "POPE QA" while the default output
# directory is the CHAIR one -- confirm which benchmark this script targets.
parser = argparse.ArgumentParser(description="POPE QA script")

# String-valued path options: (flag, default, help text).
for _flag, _default, _help in (
    ("--checkpoints_dir", "liuhaotian/llava-v1.5-7b", "Directory for checkpoints"),
    ("--output_dir", "playground/data/chair/", "Output directory"),
    ("--image_folder", "/hdd/user/vlm/coco/val2014/", "Image folder"),
    ("--custom_image_folder", "/hdd/user/vlm/at/coco_100/", "Custom image folder"),
):
    parser.add_argument(_flag, type=str, default=_default, help=_help)

# Generation and model options: (flag, type, default).
for _flag, _type, _default in (
    ("--temperature", float, 0),
    ("--top_p", float, 0),
    ("--num_beams", int, 1),
    ("--max_new_tokens", int, 512),
    # Leaving this unset (None) selects the plain model-loading path below.
    ("--toy_attention_layer", int, None),
):
    parser.add_argument(_flag, type=_type, default=_default)

args = parser.parse_args()

# Module-level aliases for the options used by the rest of the script.
toy_attention_layer, checkpoints_dir, output_dir, image_folder = (
    args.toy_attention_layer,
    args.checkpoints_dir,
    args.output_dir,
    args.image_folder,
)

# Fixed seed via the project's seed helper (imported from utils; what exactly
# it seeds is defined there, not in this file).
setup_seeds(42)

# ========================================
#             Model Initialization
# ========================================
print('Initializing Model')
# Make sure the directory that will receive the results file exists.
os.makedirs(output_dir, exist_ok=True)

checkpoint_path = checkpoints_dir
# The results file is named after the checkpoint.  normpath strips a trailing
# separator first: os.path.basename("path/to/ckpt/") would otherwise be ''
# and the results would end up in a file called ".jsonl".
output_name = os.path.basename(os.path.normpath(checkpoints_dir))
print(output_name)

# Extra keyword arguments that are only passed when --toy_attention_layer is
# given; empty otherwise so a single load call covers both modes.
_toy_kwargs = {}
if toy_attention_layer is not None:
    # Per-layer matrix loaded from disk, cast to float16 and converted to
    # nested Python lists, keyed by the layer index.
    whiten_attn_matrix = {
        toy_attention_layer: np.load(
            f"/home/user/llava/LLaVA/toy2/layer_{toy_attention_layer}.npy"
        ).astype(np.float16).tolist()
    }
    _toy_kwargs = {
        "toy_attention_layers": [toy_attention_layer],
        "toy_attention_metrics": whiten_attn_matrix,
    }

# Load the model.
# NOTE(review): image_start / image_length and the toy_attention_* arguments
# are consumed by the load_pretrained_model imported from the LLaVA checkout
# added to sys.path at the top of the file.  Presumably 35 is the index of the
# first image token in the prompt and 576 the number of image tokens -- TODO
# confirm against that checkout.
tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path=checkpoint_path,
    model_base=None,
    torch_dtype=torch.float16,
    cache_dir="/hdd/user/",
    device="cuda:0",
    model_name=get_model_name_from_path(checkpoint_path),
    image_start=35,
    image_length=576,
    **_toy_kwargs,
)



# Base ids of the images present in the custom image folder.  A custom file
# name starts with the base id followed by "-", e.g. '000000271620' from
# '000000271620-2-9.jpg'.
# The folder comes from the --custom_image_folder option; the bare name
# `custom_image_folder` was never defined and raised NameError here.
custom_base_ids = {
    filename.split("-")[0]
    for filename in os.listdir(args.custom_image_folder)
    if filename.endswith((".jpg", ".png"))
}

# Candidate images: COCO val2014 files whose id is NOT among the custom ids.
_COCO_PREFIX = "COCO_val2014_"
# os.listdir returns entries in arbitrary order, so sort before shuffling;
# otherwise the shuffled order would depend on the filesystem and not only on
# the state of the random generator.
img_files = sorted(
    fname
    for fname in os.listdir(args.image_folder)
    if fname.endswith((".jpg", ".png"))
    and fname.startswith(_COCO_PREFIX)
    and fname[len(_COCO_PREFIX):].split(".")[0] not in custom_base_ids
)

random.shuffle(img_files)

# Load the COCO val2014 instance annotations.
# json.load parses the whole file; the previous readlines()[0] approach only
# worked when the entire JSON document sat on the first line.
with open('/hdd/user/vlm/coco/val2014/anno/instances_val2014.json', 'r') as f:
    coco_anns = json.load(f)

categories = coco_anns["categories"]
category_names = [c["name"] for c in categories]
# category id -> category name
category_dict = {int(c["id"]): c["name"] for c in categories}

# image id -> {"name": file name, "anns": category names of its annotations}
img_dict = {
    img_info["id"]: {"name": img_info["file_name"], "anns": []}
    for img_info in coco_anns["images"]
}

# Attach the category name of every annotation to the image it belongs to
# (one entry per annotation, so names may repeat).
for ann_info in coco_anns["annotations"]:
    img_dict[ann_info["image_id"]]["anns"].append(
        category_dict[ann_info["category_id"]]
    )



# problematic_files = ["COCO_val2014_000000260230.jpg"]
# img_files = [file for file in img_files if file not in problematic_files]

# NOTE(review): the original comment here read "128 fail" -- presumably a past
# run failed at iteration 128; the cause is not recorded in this file.

# Only the first MAX_IMAGES entries of the shuffled list are captioned.
MAX_IMAGES = 500

# The prompt is the same for every image, so build and tokenize it once.
qs = "Please describe this image in detail."
if model.config.mm_use_im_start_end:
    qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs
else:
    qs = DEFAULT_IMAGE_TOKEN + '\n' + qs

conv = conv_templates["llava_v1"].copy()
conv.append_message(conv.roles[0], qs)
conv.append_message(conv.roles[1], None)  # empty assistant turn to be generated
prompt = conv.get_prompt()

# Put the token ids on the model's device (same device as the image tensor)
# instead of a hard-coded device index.
input_ids = (
    tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt')
    .unsqueeze(0)
    .to(model.device)
)

# One JSON object per line, appended to <output_dir>/<output_name>.jsonl.
results_path = os.path.join(output_dir, '{}.jsonl'.format(output_name))

for idx, img_file in enumerate(img_files[:MAX_IMAGES]):
    print("img_id: ", idx)

    # The numeric COCO id is the last "_"-separated field of the file stem,
    # e.g. 'COCO_val2014_000000260230.jpg' -> 260230.  Parsing the stem works
    # for both extensions accepted by the file filter.
    img_id = int(os.path.splitext(img_file)[0].rsplit("_", 1)[-1])

    # Sanity check: the annotation entry must refer to this very file.
    img_info = img_dict[img_id]
    if img_info["name"] != img_file:
        raise ValueError(
            "annotation entry {!r} does not match image file {!r}".format(
                img_info["name"], img_file
            )
        )

    # Close the file handle as soon as the pixels are converted to RGB.
    image_path = os.path.join(args.image_folder, img_file)
    with Image.open(image_path) as img:
        raw_image = img.convert("RGB")

    image_tensor = process_images([raw_image], image_processor, model.config).to(
        model.device, dtype=torch.float16
    )

    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            images=image_tensor,
            # Sample only when a positive temperature was requested.
            do_sample=args.temperature > 0,
            temperature=args.temperature,
            top_p=args.top_p,
            num_beams=args.num_beams,
            max_new_tokens=args.max_new_tokens,
            # NOTE(review): kept from the original; presumably required by the
            # modified attention path -- confirm before removing.
            output_attentions=True,
        )

    # Decode the first returned sequence, skipping special tokens.
    response = tokenizer.decode(output_ids[0], skip_special_tokens=True).strip()
    img_save = {"image_id": img_id, "caption": response}

    # Append after every image so finished captions survive an interrupted run.
    with open(results_path, "a") as f:
        json.dump(img_save, f)
        f.write('\n')



