from transformers import Qwen3VLForConditionalGeneration, AutoProcessor, AutoTokenizer
import torch
import numpy as np
import argparse
import os
import json
from tqdm import tqdm
from peft import PeftModel

# Answer-file flush cadence: flushed whenever the record index is a multiple of this (including index 0).
FLUSH_EVERY: int = 200

_SYSTEM_PROMPT = "You are a helpful, generic assistant. Your answers should be accurate, concise, and direct. Do not ramble. If you do not know the answer, admit it instead of making up information. Prioritize clarity and structure (use bullet points or headers) in your responses."

IM_END_TOKEN = '<|im_end|>'
VISION_START_TOKEN = '<|vision_start|>'
VISION_END_TOKEN = '<|vision_end|>'


def _load_model(args, device_map):
    """Load Qwen3-VL in bfloat16 and optionally merge a LoRA adapter into it.

    In ``--tcap`` mode attention probabilities are requested from ``generate``
    (``output_attentions=True``). FlashAttention-2 does not return attention
    weights, so eager attention is used in that mode. Otherwise
    FlashAttention-2 is used.
    """
    attn_implementation = 'eager' if args.tcap else 'flash_attention_2'
    model = Qwen3VLForConditionalGeneration.from_pretrained(
        args.model_path,
        dtype=torch.bfloat16,
        attn_implementation=attn_implementation,
        device_map=device_map,
    )

    if args.lora_path:
        print(f"[Lora] Loading lora weights from {args.lora_path}")
        model = PeftModel.from_pretrained(
            model,
            args.lora_path,
            device_map=device_map,
            dtype=torch.bfloat16,
        )
        print("[Lora] Merging weights...")
        model = model.merge_and_unload()
        print("[Lora] Done.")

    model.eval()
    model.tie_weights()
    return model


def _load_questions(question_path):
    """Read question records from a ``.json`` array or a ``.jsonl`` file.

    Raises:
        SystemExit: (non-zero status) if the extension is neither ``.json``
            nor ``.jsonl``.
    """
    path = os.path.expanduser(question_path)
    if question_path.endswith('.json'):
        with open(path, 'r', encoding='utf-8') as f:
            return json.load(f)
    if question_path.endswith('.jsonl'):
        with open(path, 'r', encoding='utf-8') as f:
            # Skip blank lines (e.g. a trailing newline), which json.loads rejects.
            return [json.loads(row) for row in f if row.strip()]
    # Exit non-zero: an unsupported input is a failure, not a successful run.
    raise SystemExit('Only Support JSON or JSONL question file.')


def _build_messages(line, image_folder, system_prompt_add):
    """Build the chat messages (system prompt, then one image + question) for a record."""
    question = line['conversations'][0]['value'].replace('<image>', '').strip()
    if not question:
        question = "This is the picture."

    system_text = _SYSTEM_PROMPT
    if system_prompt_add:
        system_text += " " + system_prompt_add

    return [
        {
            "role": "system",
            "content": [
                {"type": "text", "text": system_text}
            ]
        },
        {
            "role": "user",
            "content": [
                {"type": "image", "image": os.path.join(image_folder, line['image'])},
                {"type": "text", "text": question},
            ]
        }
    ]


def _component_ranges(ids, im_end_id, vision_start_id, vision_end_id):
    """Partition prompt token positions into (system, image, user-text) index lists.

    ``ids`` is the prompt as a plain Python list of token ids. The groups are:
      * image: positions strictly between the first vision-start and the
        first vision-end token;
      * user text: positions after vision-end, up to but excluding the
        second ``<|im_end|>``;
      * system: everything else. That is the prefix before vision-start, the
        vision-start/end markers themselves, and the suffix from the second
        ``<|im_end|>`` to the end of the prompt.

    Assumes a single image and a system turn preceding the user turn, so the
    user turn's ``<|im_end|>`` is the second one.

    Raises:
        ValueError: if a vision marker token is missing.
        IndexError: if fewer than two ``<|im_end|>`` tokens are present.
    """
    vision_start = ids.index(vision_start_id)
    vision_end = ids.index(vision_end_id)
    user_end = [i for i, tok in enumerate(ids) if tok == im_end_id][1]

    sys_idx = list(range(0, vision_start)) + list(range(user_end, len(ids))) + [vision_start, vision_end]
    img_idx = list(range(vision_start + 1, vision_end))
    usr_idx = list(range(vision_end + 1, user_end))
    return sys_idx, img_idx, usr_idx


def _tcap_map(attentions, ranges, device):
    """Return the last prompt token's attention share per (layer, head, group).

    Args:
        attentions: per-layer attention tensors from the prefill step of
            ``generate``, each of shape (1, H, T, T).
        ranges: (sys_idx, img_idx, usr_idx) position lists from
            ``_component_ranges``.
        device: device onto which all per-layer slices are gathered.

    Returns:
        Nested list of shape (L, H, 3). Each entry is the fraction of the last
        query token's attention mass that falls on the system, image and
        user-text positions respectively.
    """
    # With device_map='auto' layers may sit on different GPUs. torch.stack and
    # index-tensor gathers need one common device, so move the small
    # (H, T) slices first.
    last_token_attn = torch.stack(
        [layer_attn[0, :, -1, :].to(device) for layer_attn in attentions], dim=0
    ).float()  # (L, H, T)

    total = last_token_attn.sum(dim=-1, keepdim=True)  # (L, H, 1)

    shares = [
        last_token_attn[:, :, torch.tensor(idx, dtype=torch.long, device=device)].sum(dim=-1, keepdim=True) / total
        for idx in ranges
    ]
    return torch.cat(shares, dim=-1).tolist()  # (L, H, 3)


def eval_model(args):
    """Run Qwen3-VL over a question file and write one JSON answer per line.

    Each record in ``args.question_file`` must provide ``id``, ``image``
    (joined onto ``args.image_folder``) and ``conversations[0]['value']``.

    Without ``args.tcap``, up to ``args.new_tokens`` tokens are generated and
    ``{"question_id", "answer"}`` is written. With ``args.tcap``, exactly one
    token is generated, and the record also holds:
      * ``num_tokens``: counts of (system, image, user-text) prompt positions;
      * ``tcap_map``: per-layer/per-head attention shares of the last prompt
        token over those three groups.

    The answer file is overwritten and flushed every ``FLUSH_EVERY`` records.
    """
    device = 'cuda'
    device_map = 'auto'

    # Validate/load the questions before the expensive model load so a bad
    # input fails fast and does not truncate an existing answer file.
    questions = _load_questions(args.question_file)

    model = _load_model(args, device_map)
    tokenizer = AutoTokenizer.from_pretrained(args.model_path)
    processor = AutoProcessor.from_pretrained(args.model_path, min_pixels=4*28*28, max_pixels=1024*28*28)

    do_sample = args.temperature > 0
    gen_kwargs = {
        'max_new_tokens': 1 if args.tcap else args.new_tokens,
        'do_sample': do_sample,
        'return_dict_in_generate': True,
        'output_attentions': args.tcap,
    }
    if do_sample:
        # Greedy decoding ignores temperature; passing it only triggers warnings.
        gen_kwargs['temperature'] = args.temperature

    # Special-token ids are fixed for the whole run; look them up once.
    special_ids = tuple(
        tokenizer.convert_tokens_to_ids(tok)
        for tok in (IM_END_TOKEN, VISION_START_TOKEN, VISION_END_TOKEN)
    )

    ans_path = os.path.expanduser(args.answer_file)
    parent_dir = os.path.dirname(ans_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    with open(ans_path, 'w', encoding='utf-8') as answer_file:
        for index, line in enumerate(tqdm(questions)):
            messages = _build_messages(line, args.image_folder, args.system_prompt_add)

            if index == 0:
                query = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
                print(query)

            model_inputs = processor.apply_chat_template(
                messages,
                tokenize=True,
                add_generation_prompt=True,
                return_dict=True,
                return_tensors="pt"
            ).to(device)
            input_ids = model_inputs.input_ids

            with torch.inference_mode():
                output_ids = model.generate(**model_inputs, **gen_kwargs)

            sequences_trimmed = [
                out_ids[len(in_ids):] for in_ids, out_ids in zip(input_ids, output_ids.sequences)
            ]
            outputs = tokenizer.batch_decode(sequences_trimmed, skip_special_tokens=True)[0].strip()

            record = {'question_id': line['id'], 'answer': outputs}

            if args.tcap:
                # Step 0 is the prefill pass: one attention tensor per layer over the full prompt.
                attentions = output_ids.attentions[0]

                # Work on a Python list: iterating a CUDA tensor element-wise
                # costs one device op/sync per token.
                ranges = _component_ranges(input_ids[0].tolist(), *special_ids)

                total_indexed_count = sum(len(idx) for idx in ranges)
                actual_seq_len = attentions[0].shape[-1]
                if total_indexed_count != actual_seq_len:
                    raise ValueError(f"Token count mismatch: {total_indexed_count} indexed vs {actual_seq_len} actual, please check the input formattion.")

                record['num_tokens'] = tuple(len(idx) for idx in ranges)
                record['tcap_map'] = _tcap_map(attentions, ranges, device)
                del attentions

            # Drop this sample's outputs (including any attention tensors) so
            # their memory can be released before the next generate() call.
            del output_ids

            answer_file.write(json.dumps(record) + '\n')

            if index % FLUSH_EVERY == 0:
                answer_file.flush()

def build_arg_parser():
    """Return the command-line parser for this evaluation script.

    ``--question-file`` and ``--answer-file`` are required. Previously they
    defaulted to ``""``, which only failed after the model had been loaded.
    """
    parser = argparse.ArgumentParser(
        description="Run Qwen3-VL inference over a JSON/JSONL question file."
    )
    parser.add_argument("--model-path", type=str, default="Qwen/Qwen3-VL-8B-Instruct",
                        help="Base model (also used for the tokenizer/processor).")
    parser.add_argument("--lora-path", type=str, default="",
                        help="Optional LoRA adapter to merge into the base model.")
    parser.add_argument("--image-folder", type=str, default="",
                        help="Directory that each record's 'image' path is joined onto.")
    parser.add_argument("--question-file", type=str, required=True,
                        help="Input questions: a .json array or a .jsonl file.")
    parser.add_argument("--answer-file", type=str, required=True,
                        help="Output .jsonl file (overwritten).")
    parser.add_argument("--system-prompt-add", type=str, default="",
                        help="Text appended to the built-in system prompt.")
    parser.add_argument("--temperature", type=float, default=0.2,
                        help="Sampling temperature; values <= 0 use greedy decoding.")
    parser.add_argument("--new-tokens", type=int, default=1,
                        help="max_new_tokens when --tcap is not set.")
    parser.add_argument("--tcap", action="store_true",
                        help="Generate one token and record attention shares over system/image/user tokens.")
    return parser


if __name__ == '__main__':
    args = build_arg_parser().parse_args()

    eval_model(args)
