from llava.model.builder import load_pretrained_model
from llava.mm_utils import get_model_name_from_path, process_images, tokenizer_image_token
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, IGNORE_INDEX
from llava.conversation import Conversation, SeparatorStyle
from llava.utils import disable_torch_init
from PIL import Image
import requests
import copy
import torch
import argparse
import numpy as np
from tqdm import tqdm
import os
import json


# Number of processed questions between explicit flushes of the answer file
# when running in --tcap mode.
FLUSH_EVERY: int = 200

def _build_system_prompt(system_prompt_add):
    """Return the assistant system prompt, optionally extended by the caller.

    Args:
        system_prompt_add: Extra text appended (after a single space) to the
            base prompt. An empty string leaves the base prompt unchanged.
    """
    base = (
        "You are a helpful language and vision assistant. "
        "You are able to understand the visual content that the user provides, "
        "and assist the user with a variety of tasks using natural language."
    )
    # The conditional must wrap only the suffix. Unparenthesised, it binds to
    # the whole concatenation, and an empty suffix wipes out the base prompt.
    return base + (f" {system_prompt_add}" if system_prompt_add else "")


def _load_questions(question_path):
    """Read the question list from a ``.json`` array or a ``.jsonl`` file.

    Args:
        question_path: Path to the question file; ``~`` is expanded.

    Raises:
        SystemExit: (non-zero status) if the extension is neither ``.json``
            nor ``.jsonl``.
    """
    if not question_path.endswith(('.json', '.jsonl')):
        raise SystemExit('Only Support JSON or JSONL question file.')
    with open(os.path.expanduser(question_path), 'r', encoding='utf-8') as f:
        if question_path.endswith('.json'):
            return json.load(f)
        # JSONL: one object per line. Skip blank lines (e.g. a trailing newline).
        return [json.loads(line) for line in f if line.strip()]


def _segment_ranges(input_ids, num_image_tokens, image_token_index, eot_id):
    """Partition prompt key positions into system / image / user segments.

    The prompt contains a single image placeholder that the model expands into
    ``num_image_tokens`` positions. Every position after the placeholder
    therefore shifts by ``num_image_tokens - 1`` relative to ``input_ids``.

    Args:
        input_ids: Prompt token ids (list of ints) with exactly one
            ``image_token_index`` placeholder.
        num_image_tokens: Number of positions the placeholder expands to.
        image_token_index: Id of the image placeholder.
        eot_id: Id of ``<|eot_id|>``. The first occurrence after the image
            closes the user turn.

    Returns:
        ``(sys_range, img_range, usr_range)`` as lists of expanded positions:
        everything before the image plus everything from the user turn's
        ``<|eot_id|>`` onward; the image positions; the user text in between.
    """
    ids = list(input_ids)
    # Use the placeholder's real position rather than guessing it from header
    # tokens, so the result does not depend on how "\n\n" is tokenized.
    img_start = ids.index(image_token_index)
    img_end = img_start + num_image_tokens
    shift = num_image_tokens - 1
    usr_end = ids.index(eot_id, img_start + 1) + shift
    seq_len = len(ids) + shift

    sys_range = list(range(img_start)) + list(range(usr_end, seq_len))
    img_range = list(range(img_start, img_end))
    usr_range = list(range(img_end, usr_end))
    return sys_range, img_range, usr_range


def _tcap_ratios(attentions, segments):
    """Return the final prompt token's attention share per segment.

    Args:
        attentions: Per-layer prefill attention tensors, each indexed as
            ``[batch, head, query, key]`` with batch size 1.
        segments: Sequence of key-position lists to aggregate over.

    Returns:
        Nested list shaped ``(layers, heads, len(segments))``. Each entry is
        that segment's summed attention divided by the row total.
    """
    # Slice the last query row before stacking. Stacking the full maps in
    # float32 costs L*H*T^2 floats, and only this row is used. Moving to CPU
    # also copes with layers sharded across devices by device_map="auto".
    last_row = torch.stack(
        [layer_attn[0, :, -1, :].float().cpu() for layer_attn in attentions],
        dim=0,
    )  # (L, H, T)
    total = last_row.sum(dim=-1, keepdim=True)
    shares = [last_row[:, :, seg].sum(dim=-1, keepdim=True) / total for seg in segments]
    return torch.cat(shares, dim=-1).tolist()


def eval_model(args):
    """Answer every question in ``args.question_file`` and write JSONL answers.

    Each question record must provide ``id``, ``image`` (relative to
    ``args.image_folder``) and ``conversations[0]['value']``.

    With ``args.tcap``, one token is generated with eager attention. Each
    output line then also stores the token counts and, per layer and head,
    the share of the final prompt token's attention on the system, image and
    user segments (``tcap_map``). Without it, up to ``args.new_tokens`` tokens
    are generated with FlashAttention-2 and only the answer is written.
    """
    device = "cuda"
    # output_attentions=True needs eager attention: FlashAttention-2 never
    # materializes the attention weights.
    attn_impl = "eager" if args.tcap else "flash_attention_2"
    tokenizer, model, image_processor, _ = load_pretrained_model(
        args.model_path,
        None,
        "llava_llama3",
        device_map="auto",
        attn_implementation=attn_impl,
    )
    model.eval()
    model.tie_weights()

    # Validate and load the input before creating/truncating the output file.
    questions = _load_questions(args.question_file)
    system_prompt = _build_system_prompt(args.system_prompt_add)
    do_sample = args.temperature > 0
    eot_id = tokenizer.convert_tokens_to_ids("<|eot_id|>")

    ans_path = os.path.expanduser(args.answer_file)
    parent_dir = os.path.dirname(ans_path)
    # dirname() is '' for a bare filename (e.g. the default), and
    # os.makedirs('') raises FileNotFoundError.
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    # The context manager guarantees the last (unflushed) records hit disk.
    with open(ans_path, 'w') as answer_file:
        for index, line in enumerate(tqdm(questions)):
            idx = line['id']
            question = line['conversations'][0]['value'].replace('<image>', '').strip()
            if not question:
                question = "This is the picture."

            with Image.open(os.path.join(args.image_folder, line['image'])) as raw_image:
                image = raw_image.convert('RGB')
            image_tensor = process_images([image], image_processor, model.config)
            image_tensor = [img.to(dtype=torch.float16, device=device) for img in image_tensor]
            image_sizes = [image.size]

            conv = Conversation(
                system=system_prompt,
                roles=("user", "assistant"),
                version="llama_v3",
                messages=[],
                offset=0,
                sep="<|eot_id|>",
                sep_style=SeparatorStyle.LLAMA_3,
                tokenizer_id=None,
                tokenizer=tokenizer,
                stop_token_ids=[128009],
            )
            conv.append_message(conv.roles[0], DEFAULT_IMAGE_TOKEN + '\n' + question)
            conv.append_message(conv.roles[1], None)
            prompt_question = conv.get_prompt()

            if index == 0:
                print('[Prompt]', prompt_question)

            input_ids = tokenizer_image_token(
                prompt_question, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
            ).unsqueeze(0).to(device)

            with torch.inference_mode():
                output_ids = model.generate(
                    input_ids,
                    images=image_tensor,
                    image_sizes=image_sizes,
                    do_sample=do_sample,
                    temperature=args.temperature,
                    max_new_tokens=1 if args.tcap else args.new_tokens,
                    use_cache=True,
                    output_attentions=args.tcap,
                    return_dict_in_generate=True,
                    pad_token_id=tokenizer.eos_token_id,
                )
            outputs = tokenizer.batch_decode(output_ids.sequences, skip_special_tokens=True)[0].strip()

            if args.tcap:
                # Attentions of the first generation step (the full prompt).
                attentions = output_ids['attentions'][0]
                # Key length = text tokens without the placeholder + image tokens.
                num_image_tokens = attentions[0].shape[-1] - (input_ids.shape[-1] - 1)
                sys_range, img_range, usr_range = _segment_ranges(
                    input_ids[0].tolist(), num_image_tokens, IMAGE_TOKEN_INDEX, eot_id
                )

                answer_file.write(json.dumps({
                    'question_id': idx,
                    'answer': outputs,
                    'num_tokens': (len(sys_range), len(img_range), len(usr_range)),
                    'tcap_map': _tcap_ratios(attentions, (sys_range, img_range, usr_range)),
                }) + '\n')

                if index % FLUSH_EVERY == 0:
                    answer_file.flush()
                # Release this prompt's attention maps before the next generate()
                # call, so two prompts' worth are never alive at the same time.
                del attentions, output_ids
            else:
                answer_file.write(json.dumps({
                    "question_id": idx,
                    "answer": outputs
                }) + "\n")
                answer_file.flush()
                torch.cuda.empty_cache()


if __name__ == '__main__':
    # Command-line entry point: parse options and run the evaluation.
    parser = argparse.ArgumentParser(
        description="Answer image questions with LLaVA-LLaMA3, optionally recording "
                    "the attention share on system/image/user tokens (--tcap)."
    )
    parser.add_argument("--model-path", type=str, default="lmms-lab/llama3-llava-next-8b",
                        help="Checkpoint passed to load_pretrained_model.")
    parser.add_argument("--image-folder", type=str, default="",
                        help="Directory that each question's 'image' path is relative to.")
    # Required: the previous default "" could never be a valid .json/.jsonl file.
    parser.add_argument("--question-file", type=str, required=True,
                        help="Questions as a .json array or a .jsonl file.")
    parser.add_argument("--answer-file", type=str, default="tcap_result.jsonl",
                        help="Output JSONL path (parent directories are created).")
    parser.add_argument("--system-prompt-add", type=str, default="",
                        help="Extra text appended to the system prompt.")
    parser.add_argument("--temperature", type=float, default=0.2,
                        help="Sampling temperature; 0 selects greedy decoding.")
    parser.add_argument("--new-tokens", type=int, default=1,
                        help="max_new_tokens when --tcap is not set.")
    parser.add_argument("--tcap", action="store_true",
                        help="Generate one token with eager attention and record attention shares.")
    args = parser.parse_args()

    eval_model(args)
