import math
import numpy as np
import torch
import torchvision.transforms as T
from decord import VideoReader, cpu
from PIL import Image
from torchvision.transforms.functional import InterpolationMode
from modelscope import AutoModel, AutoTokenizer
from internvl.conversation import Conversation, SeparatorStyle
import argparse
import os
import json
from tqdm import tqdm

# Placeholder put at the start of the user prompt to mark the image position.
DEFAULT_IMAGE_TOKEN = '<image>'
# The answer file is flushed whenever the sample index is a multiple of this.
FLUSH_EVERY = 200

# Per-channel mean / std handed to T.Normalize in build_transform
# (presumably the usual ImageNet statistics, going by the names).
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

def build_transform(input_size):
    """Build the per-tile preprocessing pipeline.

    The pipeline converts the image to RGB when needed, resizes it to an
    ``input_size`` x ``input_size`` square with bicubic interpolation, turns
    it into a tensor and normalizes it with ``IMAGENET_MEAN`` /
    ``IMAGENET_STD``.

    Args:
        input_size: Edge length of the square output, in pixels.

    Returns:
        A ``T.Compose`` applying the four steps in order.
    """
    def _ensure_rgb(img):
        # Leave images that are already RGB untouched.
        return img if img.mode == 'RGB' else img.convert('RGB')

    steps = [
        T.Lambda(_ensure_rgb),
        T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
    return T.Compose(steps)

def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    """Pick the tile grid whose aspect ratio is closest to the image's.

    Args:
        aspect_ratio: Width / height of the source image.
        target_ratios: Iterable of ``(cols, rows)`` candidate grids, scanned
            in the given order.
        width: Source image width; only used to compute the image area.
        height: Source image height; only used to compute the image area.
        image_size: Edge length of a single square tile.

    Returns:
        The candidate with the smallest absolute aspect-ratio difference, or
        ``(1, 1)`` if ``target_ratios`` is empty. When a later candidate ties
        with the current best, it replaces the best only if the image area
        exceeds half of the area that candidate's grid would cover.
    """
    image_area = width * height
    half_tile_area = 0.5 * image_size * image_size
    chosen = (1, 1)
    smallest_gap = float('inf')
    for candidate in target_ratios:
        cols, rows = candidate[0], candidate[1]
        gap = abs(aspect_ratio - cols / rows)
        if gap < smallest_gap:
            smallest_gap, chosen = gap, candidate
            continue
        # Exact tie: switch to this candidate only for large enough images.
        if gap == smallest_gap and image_area > half_tile_area * cols * rows:
            chosen = candidate
    return chosen

def dynamic_preprocess(image, min_num=1, max_num=12, image_size=448, use_thumbnail=False):
    """Resize an image onto a tile grid and cut it into square tiles.

    Args:
        image: Source image exposing ``size``, ``resize`` and ``crop``
            (a PIL image in this file).
        min_num: Minimum number of tiles a candidate grid may have.
        max_num: Maximum number of tiles a candidate grid may have.
        image_size: Edge length of each square tile, in pixels.
        use_thumbnail: If true and more than one tile was produced, append
            the whole image resized to a single tile.

    Returns:
        List of tiles in row-major order, optionally followed by the
        thumbnail.
    """
    src_w, src_h = image.size
    src_aspect = src_w / src_h

    # Every (cols, rows) grid whose tile count lies in [min_num, max_num],
    # ordered by tile count.
    grids = set(
        (cols, rows)
        for n in range(min_num, max_num + 1)
        for cols in range(1, n + 1)
        for rows in range(1, n + 1)
        if min_num <= cols * rows <= max_num)
    grids = sorted(grids, key=lambda grid: grid[0] * grid[1])

    best_grid = find_closest_aspect_ratio(src_aspect, grids, src_w, src_h, image_size)

    # Canvas size for the chosen grid.
    canvas_w = image_size * best_grid[0]
    canvas_h = image_size * best_grid[1]
    tile_count = best_grid[0] * best_grid[1]

    canvas = image.resize((canvas_w, canvas_h))
    tiles_per_row = canvas_w // image_size

    tiles = []
    for tile_idx in range(tile_count):
        row, col = divmod(tile_idx, tiles_per_row)
        # (left, upper, right, lower) of this tile on the resized canvas.
        tiles.append(canvas.crop((
            col * image_size,
            row * image_size,
            (col + 1) * image_size,
            (row + 1) * image_size,
        )))
    assert len(tiles) == tile_count

    if use_thumbnail and len(tiles) != 1:
        tiles.append(image.resize((image_size, image_size)))
    return tiles

def load_image(image_file, input_size=448, max_num=12):
    """Load an image file and return its tiles as one stacked tensor.

    Args:
        image_file: Path (or file object) accepted by ``Image.open``.
        input_size: Edge length of each square tile, in pixels.
        max_num: Maximum number of grid tiles passed to
            ``dynamic_preprocess``.

    Returns:
        ``torch.stack`` of the transformed tiles; the thumbnail is requested
        from ``dynamic_preprocess``, so it is included whenever more than one
        tile is produced.
    """
    rgb_image = Image.open(image_file).convert('RGB')
    to_tensor = build_transform(input_size=input_size)
    tiles = dynamic_preprocess(
        rgb_image, image_size=input_size, use_thumbnail=True, max_num=max_num)
    return torch.stack([to_tensor(tile) for tile in tiles])


def _load_questions(question_file):
    """Read the evaluation samples from a ``.json`` or ``.jsonl`` file.

    Args:
        question_file: Path to the question file; ``~`` is expanded.

    Returns:
        The parsed JSON document for ``.json`` files, or a list holding one
        parsed object per non-blank line for ``.jsonl`` files.

    Raises:
        SystemExit: With status 1 if the extension is neither ``.json`` nor
            ``.jsonl``.
    """
    if not question_file.endswith(('.json', '.jsonl')):
        print('Only Support JSON or JSONL question file.')
        # Non-zero status so that shell pipelines notice the failure.
        raise SystemExit(1)

    with open(os.path.expanduser(question_file), 'r') as handle:
        if question_file.endswith('.json'):
            return json.load(handle)
        # Skip blank lines (e.g. a trailing newline) instead of crashing.
        return [json.loads(row) for row in handle if row.strip()]


def _compute_tcap(attentions, token_ids, img_start_token_id, img_end_token_id,
                  im_end_token_id, device):
    """Split the last token's attention into system / image / user shares.

    The prompt is partitioned into three index sets:

    * image: from the first ``<img>`` token through the token right after
      the first ``</img>`` (presumably the newline that follows the image
      placeholder in the user prompt);
    * user: from there up to, but excluding, the second ``<|im_end|>``
      (presumably the one closing the user turn);
    * system: everything else, i.e. all tokens before ``<img>`` and all
      tokens from the second ``<|im_end|>`` onward.

    Args:
        attentions: Per-layer attention tensors of one generation step, each
            indexed as ``[batch, head, query, key]``.
        token_ids: Prompt token ids as a plain list of ints.
        img_start_token_id: Id of ``<img>``.
        img_end_token_id: Id of ``</img>``.
        im_end_token_id: Id of ``<|im_end|>``.
        device: Device on which the shares are computed.

    Returns:
        ``(num_tokens, tcap_map)`` where ``num_tokens`` is the
        ``(system, image, user)`` token-count tuple and ``tcap_map`` is a
        nested ``[layer][head][3]`` list of attention shares.
    """
    def _index_tensor(positions):
        # Explicit integer dtype: without it an empty range would become a
        # float tensor, which cannot be used for indexing.
        return torch.tensor(list(positions), dtype=torch.long, device=device)

    seq_len = len(token_ids)
    im_end_positions = [pos for pos, tok in enumerate(token_ids) if tok == im_end_token_id]
    img_start = token_ids.index(img_start_token_id)
    img_end = token_ids.index(img_end_token_id)
    user_end = im_end_positions[1]

    # tri-component attention extraction during single image inference
    sys_range = _index_tensor(list(range(0, img_start)) + list(range(user_end, seq_len)))
    img_range = _index_tensor(range(img_start, img_end + 2))
    usr_range = _index_tensor(range(img_end + 2, user_end))

    total_indexed_count = sys_range.numel() + img_range.numel() + usr_range.numel()
    actual_seq_len = attentions[0].shape[-1]
    assert total_indexed_count == actual_seq_len, f"Token count mismatch: {total_indexed_count} indexed vs {actual_seq_len} actual, please check the input formattion"

    # attentions: list of (1, H, T, T) -> keep the last query row of each
    # layer, giving (L, H, T). Layers are moved to one device first because
    # a model sharded with device_map='auto' may return them on several.
    last_token_attn = torch.stack(
        [layer_out[0, :, -1, :].to(device) for layer_out in attentions], dim=0).float()

    # (L, H, 1)
    total = last_token_attn.sum(dim=-1, keepdim=True)

    # (L, H, 3): share of attention on the system / image / user tokens.
    shares = [
        last_token_attn[:, :, positions].sum(dim=-1, keepdim=True) / total
        for positions in (sys_range, img_range, usr_range)
    ]
    tcap = torch.cat(shares, dim=-1)

    return (len(sys_range), len(img_range), len(usr_range)), tcap.tolist()


def eval_model(args):
    """Run the model over a question file and write one JSON answer per line.

    Each sample must provide ``id``, ``image`` (path relative to
    ``args.image_folder``) and ``conversations[0]['value']`` (the question;
    any ``<image>`` placeholder in it is stripped and an empty question is
    replaced by a fixed default sentence).

    In normal mode each output line holds ``question_id`` and ``answer``.
    With ``args.tcap`` set, exactly one token is generated with attention
    outputs enabled, and each line additionally holds ``num_tokens`` and
    ``tcap_map`` (see ``_compute_tcap``).

    Args:
        args: Namespace with ``model_path``, ``image_folder``,
            ``question_file``, ``answer_file``, ``system_prompt_add``,
            ``temperature``, ``new_tokens`` and ``tcap``.

    Raises:
        SystemExit: With status 1 if the question file is neither ``.json``
            nor ``.jsonl``.
    """
    pretrained = args.model_path
    device = 'cuda'

    # Parse the questions first so that a wrong path or extension fails
    # before the model is loaded and before the answer file is truncated.
    questions = _load_questions(args.question_file)

    model = AutoModel.from_pretrained(
        pretrained,
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
        # Flash attention is disabled in TCAP mode, which asks generate()
        # for attention maps (presumably the fused kernel cannot return them).
        use_flash_attn=not args.tcap,
        trust_remote_code=True,
        device_map='auto')

    model.eval()
    model.tie_weights()
    tokenizer = AutoTokenizer.from_pretrained(pretrained, trust_remote_code=True, use_fast=False)

    # Special tokens and their ids do not depend on the sample; resolve once.
    sep = '<|im_end|>\n'
    img_start_token = '<img>'
    img_end_token = '</img>'
    img_context_token = '<IMG_CONTEXT>'

    eos_token_id = tokenizer.convert_tokens_to_ids(sep.strip())
    im_end_token_id = tokenizer.convert_tokens_to_ids('<|im_end|>')
    img_start_token_id = tokenizer.convert_tokens_to_ids(img_start_token)
    img_end_token_id = tokenizer.convert_tokens_to_ids(img_end_token)
    model.img_context_token_id = tokenizer.convert_tokens_to_ids(img_context_token)

    ans_path = os.path.expanduser(args.answer_file)
    parent_dir = os.path.dirname(ans_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    # The context manager closes (and therefore flushes) the answer file even
    # if a sample raises, so already written answers are not lost.
    with open(ans_path, 'w') as answer_file:
        for index, line in enumerate(tqdm(questions)):
            idx = line['id']
            question = line['conversations'][0]['value'].replace('<image>', '').strip()
            if not question:
                question = "This is the picture."
            image_file = line['image']

            pixel_values = load_image(
                os.path.join(args.image_folder, image_file), max_num=3
            ).to(torch.bfloat16).to(device)

            user_prompt = DEFAULT_IMAGE_TOKEN + '\n' + question

            # A fresh conversation per sample so that turns do not accumulate.
            conv = Conversation(
                name='internvl2_5',
                system_template='<|im_start|>system\n{system_message}',
                system_message="You are InternVL, a multimodal large language model jointly developed by the Shanghai Artificial Intelligence Laboratory, Tsinghua University, and multiple partner institutions.",
                roles=('<|im_start|>user\n', '<|im_start|>assistant\n'),
                sep_style=SeparatorStyle.MPT,
                sep=sep,
                messages=[]
            )
            if args.system_prompt_add:
                conv.system_message += " " + args.system_prompt_add

            conv.append_message(conv.roles[0], user_prompt)
            conv.append_message(conv.roles[1], None)
            query = conv.get_prompt()

            if index == 0:
                print('[Prompt]', query)

            # Expand the single image placeholder into one context token per
            # visual token of every tile.
            num_patches = pixel_values.shape[0]
            image_tokens = (
                img_start_token
                + img_context_token * model.num_image_token * num_patches
                + img_end_token
            )
            query = query.replace('<image>', image_tokens, 1)

            model_inputs = tokenizer(query, return_tensors='pt')
            input_ids = model_inputs['input_ids'].to(device)
            attention_mask = model_inputs['attention_mask'].to(device)

            with torch.inference_mode():
                output_ids = model.generate(
                    # TCAP only inspects the first generation step.
                    max_new_tokens=1 if args.tcap else args.new_tokens,
                    temperature=args.temperature,
                    do_sample=args.temperature > 0,
                    pixel_values=pixel_values,
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    pad_token_id=eos_token_id,
                    eos_token_id=eos_token_id,
                    return_dict_in_generate=True,
                    output_attentions=bool(args.tcap)
                )

            outputs = tokenizer.batch_decode(output_ids.sequences, skip_special_tokens=True)[0].strip()
            record = {'question_id': idx, 'answer': outputs}

            if args.tcap:
                # One bulk copy to the host instead of comparing CUDA tensor
                # elements one by one.
                token_ids = input_ids[0].tolist()
                num_tokens, tcap_map = _compute_tcap(
                    output_ids['attentions'][0],
                    token_ids,
                    img_start_token_id,
                    img_end_token_id,
                    im_end_token_id,
                    device,
                )
                record['num_tokens'] = num_tokens
                record['tcap_map'] = tcap_map

            answer_file.write(json.dumps(record) + '\n')
            if index % FLUSH_EVERY == 0:
                answer_file.flush()


if __name__ == '__main__':
    # (flag, type, default) for every value-taking command-line option.
    typed_options = (
        ("--model-path", str, "OpenGVLab/InternVL2_5-8B"),
        ("--image-folder", str, ""),
        ("--question-file", str, ""),
        ("--answer-file", str, ""),
        ("--system-prompt-add", str, ""),
        ("--temperature", float, 0.2),
        ("--new-tokens", int, 1),
    )

    cli = argparse.ArgumentParser()
    for flag, value_type, fallback in typed_options:
        cli.add_argument(flag, type=value_type, default=fallback)
    # Switches eval_model to its attention-extraction mode.
    cli.add_argument("--tcap", action="store_true")
    args = cli.parse_args()

    eval_model(args)
