import torch
import os
import json
from tqdm import tqdm
import shortuuid

from torch.utils.data import DataLoader
from llava.eval.dataset_coco_v2 import COCOEvalData_v2, custom_collate_fn

from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from llava.conversation import conv_templates, SeparatorStyle
from llava.model.builder import load_pretrained_model, load_pretrained_model_v2, load_pretrained_model_v3
from llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
from transformers import LogitsProcessorList, TemperatureLogitsWarper, TopPLogitsWarper
from llava.backbone.detr_branch.utils_cfg_v2 import CFGLogits_v2
from llava.training_module.utils import get_chunk, get_answers_file_name
from PIL import Image

from PIL import Image

from llava.model.builder import load_pretrained_vision_module
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
import transformers
from llava.training_module.load_args import ModelArguments, DataArguments, TrainingArguments, OtherArguments, VisionModuleArguments, VisionModuleArguments_with_vm_prefix
from llava.training_module.utils import print_trainable_layers, print_model_size, print_param_device


def _trim_stop_str(text, stop_str):
    """Return ``text`` stripped of whitespace and of one trailing ``stop_str``.

    An empty or ``None`` ``stop_str`` is ignored: ``text.endswith("")`` is
    always true and ``text[:-0]`` is the empty string, so without this guard
    the whole answer would be thrown away.
    """
    text = text.strip()
    if stop_str and text.endswith(stop_str):
        text = text[:-len(stop_str)]
    return text.strip()


def eval_model():
    """Generate an answer for every question and save the answers as JSON lines.

    Command-line arguments are parsed with ``transformers.HfArgumentParser``
    into ``ModelArguments``, ``DataArguments``, ``TrainingArguments``,
    ``OtherArguments`` and ``VisionModuleArguments_with_vm_prefix``.

    Steps:
      1. Seed the RNGs with ``args.cfg_seed`` and set up the logger.
      2. Load the main model and the aligned vision module.
      3. Load the question file and keep the chunk selected by
         ``args.num_chunks`` / ``args.chunk_idx``.
      4. Generate per batch: plain sampling when ``args.cfg == 0``, otherwise
         sampling through a ``CFGLogits_v2`` logits processor fed with the
         aligned-vision inputs.
      5. Write one JSON object per answer (``question_id``, ``prompt``,
         ``text``, ``answer_id``, ``model_id``, ``metadata``) to the answers
         file, flushing after every batch.

    Returns:
        None. The results are the answers file and the console output.
    """
    # Sampling settings shared by both generation branches.
    temperature = 0.6
    top_p = 0.9
    max_new_tokens = 64

    ## load args
    parser = transformers.HfArgumentParser(
        (ModelArguments, DataArguments, TrainingArguments, OtherArguments, VisionModuleArguments_with_vm_prefix))
    model_args, data_args, training_args, args, vm_args = parser.parse_args_into_dataclasses()
    # Rebuild vm_args as VisionModuleArguments with 'vm_' removed from the field names.
    # NOTE(review): str.replace removes every 'vm_' occurrence, not only a leading one.
    vm_args = VisionModuleArguments(**{k.replace('vm_', ''): v for k, v in vm_args.__dict__.items()})
    args.seed = args.cfg_seed

    transformers.set_seed(args.seed)

    from llava.eval.utils_cfg import set_logger
    logger = set_logger(cfg=args.cfg, save_dir=args.answer_path)
    logger.info(args)

    ## load model
    # model_path is only used to derive the model name recorded in the answers.
    if args.model_base is None:
        model_path = model_args.model_name_or_path
    else:
        model_path = args.model_path
    model_name = get_model_name_from_path(model_path)

    tokenizer, model, image_processor, _ = load_pretrained_model_v3(
        model_args, data_args, training_args, model_name="detr", device='cuda')
    data_args.image_processor = image_processor

    ## load aligned vision module
    aligned_vision_tokenizer, aligned_vision_model, aligned_vision_image_processor, _ = load_pretrained_vision_module(
        model_args, vm_args, data_args, training_args)

    ## data
    question_file = os.path.expanduser(os.path.join(args.question_path, args.question_file))
    # Context manager so the handle is closed as soon as the JSON is parsed.
    with open(question_file, "r", encoding="utf-8") as question_fp:
        questions = json.load(question_fp)
    questions = get_chunk(questions, args.num_chunks, args.chunk_idx)

    if args.answers_file is None:
        args.answers_file = get_answers_file_name(
            args, model_name, model_args.pretrain_mm_mlp_adapter, vm_args.pretrain_mm_mlp_adapter)

    answers_file = os.path.expanduser(os.path.join(args.answer_path, args.answers_file))
    answers_dir = os.path.dirname(answers_file)
    # os.makedirs('') raises, so only create a directory when the path has one.
    if answers_dir:
        os.makedirs(answers_dir, exist_ok=True)

    dataset = COCOEvalData_v2(
        questions, args.eval_image_folder, image_processor, aligned_vision_image_processor,
        tokenizer, aligned_vision_tokenizer,
        args.conv_mode, getattr(model.config, 'mm_use_im_start_end', False))
    eval_dataloader = DataLoader(
        dataset, batch_size=args.batch_size, shuffle=False, collate_fn=custom_collate_fn)

    # The stop string depends only on the conversation template, so compute it once.
    # No stopping criterion is passed to generate(); a trailing stop string is
    # trimmed from the decoded text instead.
    conv = conv_templates[args.conv_mode]
    stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2

    # The with-block closes the answers file even if generation raises mid-run.
    with open(answers_file, "w", encoding="utf-8") as ans_file:
        for images, input_ids, cur_prompts, aligned_vision_images, aligned_vision_input_ids, question_ids, img_ids in eval_dataloader:
            aligned_vision_images = aligned_vision_image_processor.preprocess(
                aligned_vision_images, return_tensors='pt')['pixel_values'].half().cuda()
            print(aligned_vision_images.shape)

            # Keyword arguments common to both generation modes.
            generate_kwargs = dict(
                images=images,
                do_sample=True,
                max_new_tokens=max_new_tokens,
                use_cache=True,
                return_dict_in_generate=True,
                output_scores=True,
            )

            with torch.inference_mode():
                if args.cfg == 0:
                    # Plain sampling: temperature / top-p are handled by generate().
                    outputs = model.generate(
                        input_ids,
                        temperature=temperature,
                        top_p=top_p,
                        **generate_kwargs,
                    )
                else:
                    # Guided sampling: CFGLogits_v2 runs first, then the same
                    # temperature and top-p warping is applied explicitly.
                    outputs = model.generate(
                        input_ids,
                        logits_processor=LogitsProcessorList([
                            CFGLogits_v2(args.cfg, control_input_ids=aligned_vision_input_ids, control_images=aligned_vision_images, model=model, aligned_vision_model=aligned_vision_model, tokenizer=tokenizer, verbose=True),
                            TemperatureLogitsWarper(temperature),
                            TopPLogitsWarper(top_p),
                        ]),
                        **generate_kwargs,
                    )
            output_ids = outputs.sequences

            # The first input_token_len output tokens are compared with the prompt;
            # everything after them is decoded as the answer.
            input_token_len = input_ids.shape[1]
            n_diff_input_output = (
                input_ids != output_ids[:, :input_token_len]).sum(dim=1)

            decoded_outputs = tokenizer.batch_decode(
                output_ids[:, input_token_len:], skip_special_tokens=True)

            for i, output in enumerate(decoded_outputs):
                # Warn when the leading output tokens differ from the prompt tokens.
                if n_diff_input_output[i].item() > 0:
                    print(
                        f'[Warning] Output ID {i} is not the same as the corresponding input ID')

                output = _trim_stop_str(output, stop_str)
                print(f"{question_ids[i]}: {output}")

                ans_id = shortuuid.uuid()
                ans_file.write(json.dumps({"question_id": question_ids[i],
                                           "prompt": cur_prompts[i],
                                           "text": output,
                                           "answer_id": ans_id,
                                           "model_id": model_name,
                                           "metadata": {}}) + "\n")

            # Flush per batch so finished answers survive an interrupted run.
            ans_file.flush()
    print(f"Done! Saved answers to {answers_file}")

# Script entry point: run the evaluation only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    eval_model()
