import torch
import os

from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from llava.conversation import conv_templates, SeparatorStyle
from llava.model.builder import load_pretrained_model, load_pretrained_model_v2, load_pretrained_model_v3
from llava.utils import disable_torch_init
from llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
from transformers import LogitsProcessorList, TemperatureLogitsWarper, TopPLogitsWarper
from llava.backbone.detr_branch.utils_cfg_v2 import CFGLogits_v2

from PIL import Image

import requests
from PIL import Image
from io import BytesIO

from llava.model.builder import load_pretrained_vision_module
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
import transformers
from llava.training_module.load_args import ModelArguments, DataArguments, TrainingArguments, OtherArguments, VisionModuleArguments, VisionModuleArguments_with_vm_prefix
from llava.training_module.utils import print_trainable_layers, print_model_size, print_param_device


def load_image(image_file, timeout=30):
    """Load an image from a local path or an http(s) URL and convert it to RGB.

    Args:
        image_file: Local filesystem path, or a URL starting with ``http://``
            or ``https://``.
        timeout: Seconds to wait for the remote server before giving up
            (only used for URLs).

    Returns:
        A PIL image in RGB mode.

    Raises:
        requests.HTTPError: If the URL answers with an error status, rather
            than handing an error body to PIL to fail on later.
    """
    # Match the full scheme so local paths that merely begin with "http"
    # (e.g. "http_cache/img.jpg") are still read from disk.
    if image_file.startswith(('http://', 'https://')):
        response = requests.get(image_file, timeout=timeout)
        response.raise_for_status()
        return Image.open(BytesIO(response.content)).convert('RGB')
    # Image.open is lazy and keeps the file open; the context manager
    # guarantees the handle is released once the RGB copy is made.
    with Image.open(image_file) as image:
        return image.convert('RGB')


def _build_prompt(question, mm_use_im_start_end, conv_mode="llava_llama_2"):
    """Prefix ``question`` with the image placeholder and render it through a conversation template.

    Args:
        question: Plain-text user query.
        mm_use_im_start_end: If true, wrap the image token in
            ``DEFAULT_IM_START_TOKEN`` / ``DEFAULT_IM_END_TOKEN``.
        conv_mode: Key into ``conv_templates``.

    Returns:
        ``(prompt, conv)``: the rendered prompt string and the conversation
        object it came from (the caller needs it for the stop string).
    """
    if mm_use_im_start_end:
        image_prefix = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
    else:
        image_prefix = DEFAULT_IMAGE_TOKEN

    conv = conv_templates[conv_mode].copy()
    conv.append_message(conv.roles[0], image_prefix + '\n' + question)
    # A None assistant message leaves the assistant turn open; for llava_llama_2
    # the prompt ends with "[/INST]" so the model continues from there.
    conv.append_message(conv.roles[1], None)
    return conv.get_prompt(), conv


def _decode_generation(tokenizer, input_ids, output_ids, stop_str):
    """Decode the tokens generated after the prompt into stripped text.

    Warns if the echoed prompt prefix of ``output_ids`` differs from
    ``input_ids``. Removes one trailing ``stop_str`` if present.
    """
    input_token_len = input_ids.shape[1]
    n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
    if n_diff_input_output > 0:
        print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')

    text = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0].strip()
    # Guard against an empty stop string: text[:-0] would erase everything.
    if stop_str and text.endswith(stop_str):
        text = text[:-len(stop_str)]
    return text.strip()


def eval_model():
    """Run CFG-guided LLaVA generation on a fixed set of COCO images and print the answers.

    Command-line arguments are parsed by ``transformers.HfArgumentParser`` into
    ``ModelArguments``, ``DataArguments``, ``TrainingArguments``,
    ``OtherArguments`` and ``VisionModuleArguments_with_vm_prefix``. The
    ``vm_``-prefixed fields are then rebuilt as a plain ``VisionModuleArguments``.

    Two models are loaded:

    - the LLaVA model, via ``load_pretrained_model_v2``;
    - an aligned "clip" vision module, via ``load_pretrained_vision_module``.

    For each image and each CFG scale in ``cfg_ls``, ``model.generate`` is
    called with a ``CFGLogits_v2`` processor that also receives the aligned
    vision module and its inputs. The decoded answer is printed.
    """
    ## load args
    parser = transformers.HfArgumentParser(
        (ModelArguments, DataArguments, TrainingArguments, OtherArguments, VisionModuleArguments_with_vm_prefix))
    model_args, data_args, training_args, args, vm_args = parser.parse_args_into_dataclasses()
    # Drop the "vm_" prefix so field names match VisionModuleArguments.
    # NOTE: str.replace removes every "vm_" occurrence, not only the prefix.
    vm_args = VisionModuleArguments(**{k.replace('vm_', ''): v for k, v in vm_args.__dict__.items()})

    # TrainingArguments.local_rank defaults to -1 outside distributed launches,
    # and torch.device rejects a negative index, so fall back to GPU 0.
    device = torch.device("cuda", max(training_args.local_rank, 0))

    ## load model
    model_base = args.model_base
    model_path = model_args.model_name_or_path if model_base is None else args.model_path
    model_name = get_model_name_from_path(model_path)

    tokenizer, model, image_processor, _ = load_pretrained_model_v2(
        model_args, data_args, training_args, model_path, model_base, model_name)
    data_args.image_processor = image_processor

    ## load aligned vision module
    aligned_vision_tokenizer, aligned_vision_model, aligned_vision_image_processor, _ = load_pretrained_vision_module(
        model_args, vm_args, data_args, training_args, "clip", device=device)

    ## prompt
    prompt, conv = _build_prompt(
        "Provide a brief description of this image",
        model.config.mm_use_im_start_end,
        conv_mode="llava_llama_2",
    )

    # Tokenize with IMAGE_TOKEN_INDEX spliced in where the image placeholder sits,
    # once per tokenizer (main model and aligned vision module).
    input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
    aligned_vision_input_ids = tokenizer_image_token(
        prompt, aligned_vision_tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()

    stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
    stopping_criteria = KeywordsStoppingCriteria([stop_str], tokenizer, input_ids)

    print("model")
    print_trainable_layers(model)
    print("aligned_vision_model")
    print_trainable_layers(aligned_vision_model)

    ## generate
    cfg_ls = [1]  # other scales previously tried: 0.3, 0.5, 0.8, 1.0
    img_name_ls = ["COCO_val2014_000000054870", "COCO_val2014_000000144305"]
    for img_name in img_name_ls:
        image_file = f"../POPE/data/minival2014/minival2014/{img_name}.jpg"
        image = load_image(image_file)

        # The two processors may produce different spatial sizes.
        image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'].half().cuda()
        aligned_vision_image_tensor = aligned_vision_image_processor.preprocess(
            image, return_tensors='pt')['pixel_values'].half().cuda()

        outputs_cfg_ls = {}
        for cfg in cfg_ls:
            print("cfg", cfg)
            # NOTE(review): generate() receives temperature=0.2 and the processor list
            # adds TemperatureLogitsWarper(0.8). Depending on the transformers version,
            # these may compound or raise a duplicate-processor error. Confirm intent.
            with torch.inference_mode():
                outputs = model.generate(
                    input_ids,
                    images=image_tensor,
                    do_sample=True,
                    temperature=0.2,
                    max_new_tokens=1024,
                    use_cache=True,
                    stopping_criteria=[stopping_criteria],
                    logits_processor=LogitsProcessorList([
                        CFGLogits_v2(cfg, control_input_ids=aligned_vision_input_ids,
                                     control_images=aligned_vision_image_tensor,
                                     model=model, aligned_vision_model=aligned_vision_model),
                        TemperatureLogitsWarper(0.8),
                        TopPLogitsWarper(0.95),
                    ]),
                    return_dict_in_generate=True,
                    output_scores=True,
                )
            outputs_text = _decode_generation(tokenizer, input_ids, outputs.sequences, stop_str)
            print(outputs_text)
            outputs_cfg_ls[cfg] = outputs_text

        # # save to file (disabled)
        # import json
        # results_dict = {"img_name": img_name, "cfg_ls": cfg_ls, "outputs_cfg_ls": outputs_cfg_ls}
        # save_dir = "./llava/backbone/results/detr-v2-clip"
        # with open(os.path.join(save_dir, f"{img_name}_outputs.json"), 'w') as f:
        #     json.dump(results_dict, f)
        # print(f"saved to {os.path.join(save_dir, f'{img_name}_outputs.json')}")

if __name__ == "__main__":
    # eval_model() parses all command-line arguments itself via
    # transformers.HfArgumentParser (model/data/training/vision-module options).
    # The evaluated images and the query text are currently hard-coded in eval_model().
    eval_model()
