import argparse
import torch
import os
import json
import shortuuid

from torch.utils.data import DataLoader
from transformers import LogitsProcessorList, TemperatureLogitsWarper, TopPLogitsWarper

from llava.conversation import conv_templates, SeparatorStyle
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from llava.mm_utils import get_model_name_from_path, KeywordsStoppingCriteria
from llava.training_module.utils import get_chunk, get_answers_file_name

from llava.eval.dataset_coco import COCOEvalData, custom_collate_fn
from llava.eval.utils_cfg import CFGLogits


def eval_model(args):
    """Generate answers for a chunk of a question file and write them as JSONL.

    Loads the model given by ``args.model_path`` / ``args.model_base``. It then
    evaluates chunk ``args.chunk_idx`` of ``args.num_chunks`` of the question
    file ``args.question_path``/``args.question_file``. One JSON object per
    question is written to ``args.answer_path``/``args.answers_file``. When
    ``args.answers_file`` is None, the file name is derived via
    ``get_answers_file_name`` and stored back on ``args``.

    If ``args.cfg == 0``, plain ``model.generate`` is called with
    temperature=0.6 and top_p=0.9. Otherwise a ``CFGLogits`` processor (built
    from ``args.cfg`` and the dataset's negative prompt) is applied, followed by
    temperature (0.6) and top-p (0.9) warpers. Each record also carries the
    per-token probability summary returned by ``output_probs``.

    Args:
        args: parsed command-line namespace for this script.

    Raises:
        RuntimeError: if ``outputs.scores`` does not have exactly one entry per
            generated token. That would misalign the probabilities passed to
            ``output_probs``.
    """
    # Function-scoped import (previously re-executed for every batch).
    from llava.eval.output_probs import output_probs

    # Model
    disable_torch_init()
    model_path = os.path.expanduser(args.model_path)
    model_name = get_model_name_from_path(model_path)
    tokenizer, model, image_processor, context_len = load_pretrained_model(
        model_path, args.model_base, model_name)
    # Batched decoder-only generation needs left padding so that
    # output_ids[:, input_token_len:] is the generated part for every row.
    # TODO: confirm custom_collate_fn honours tokenizer.padding_side.
    tokenizer.padding_side = 'left'
    print(f"tokenizer.padding_side: {tokenizer.padding_side}")
    model = model.cuda()

    # QA Data
    question_file = os.path.expanduser(
        os.path.join(args.question_path, args.question_file))
    with open(question_file, "r") as f:
        questions = json.load(f)
    questions = get_chunk(questions, args.num_chunks, args.chunk_idx)

    if args.answers_file is None:
        args.answers_file = get_answers_file_name(args, model_name)

    answers_file = os.path.expanduser(
        os.path.join(args.answer_path, args.answers_file))
    answers_dir = os.path.dirname(answers_file)
    if answers_dir:  # os.makedirs("") raises, e.g. for a bare file name
        os.makedirs(answers_dir, exist_ok=True)

    dataset = COCOEvalData(questions, args.image_folder, image_processor, tokenizer,
                           args.conv_mode, getattr(model.config, 'mm_use_im_start_end', False))
    eval_dataloader = DataLoader(
        dataset, batch_size=args.batch_size, shuffle=False, collate_fn=custom_collate_fn)

    # The stop string depends only on the conversation template, not on the batch.
    conv = conv_templates[args.conv_mode].copy()
    stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2

    with open(answers_file, "w") as ans_file:
        for images, input_ids, cur_prompts, neg_prompt, question_ids, img_ids in eval_dataloader:
            # NOTE: no stopping_criteria is passed to generate() (TODO: check);
            # the stop string is only stripped from the decoded text below.
            with torch.inference_mode():
                if args.cfg == 0:
                    outputs = model.generate(
                        input_ids,
                        images=images,
                        do_sample=args.sampling,
                        temperature=0.6,
                        top_p=0.9,
                        max_new_tokens=64,
                        use_cache=True,
                        return_dict_in_generate=True,
                        output_scores=True,
                    )
                else:
                    outputs = model.generate(
                        input_ids,
                        images=images,
                        do_sample=args.sampling,
                        max_new_tokens=64,
                        use_cache=True,
                        logits_processor=LogitsProcessorList([
                            CFGLogits(args.cfg, neg_prompt, images, model,
                                      tokenizer, verbose=False),
                            TemperatureLogitsWarper(0.6),
                            TopPLogitsWarper(0.9),
                        ]),
                        return_dict_in_generate=True,
                        output_scores=True,
                    )

            output_ids = outputs.sequences
            # One softmax-normalised distribution per generation step.
            probs = [torch.softmax(logit, dim=-1) for logit in outputs.scores]
            input_token_len = input_ids.shape[1]

            # Per-row count of prompt positions that generate() did not echo back.
            n_diff_input_output = (
                input_ids != output_ids[:, :input_token_len]).sum(dim=1)
            output_token = output_ids[:, input_token_len:]
            if len(probs) != output_token.shape[1]:
                raise RuntimeError(
                    f"Got {len(probs)} score steps but {output_token.shape[1]} "
                    f"generated tokens for question_ids {question_ids}; "
                    "per-token probabilities would be misaligned.")
            results = output_probs(output_token, probs, tokenizer)

            decoded_outputs = tokenizer.batch_decode(
                output_token, skip_special_tokens=True)

            for i, output in enumerate(decoded_outputs):
                if n_diff_input_output[i].item() > 0:
                    print(
                        f'[Warning] Output ID {i} is not the same as the corresponding input ID')

                output = output.strip()
                if stop_str and output.endswith(stop_str):
                    output = output[:-len(stop_str)]
                output = output.strip()
                print(f"{question_ids[i]}: {output} {results[i]['p_all']}")

                ans_id = shortuuid.uuid()
                ans_file.write(json.dumps({"question_id": question_ids[i],
                                           "image_id": img_ids[i],
                                           "prompt": cur_prompts[i],
                                           "text": output,
                                           "objs": results[i]["objs"],
                                           "plist": results[i]["plist"],
                                           "p_all": results[i]["p_all"],
                                           "answer_id": ans_id,
                                           "model_id": model_name,
                                           "metadata": {}}) + "\n")

            # Flush per batch so partial results survive an interrupted run.
            ans_file.flush()
    print(f"Done! Saved answers to {answers_file}")


def build_arg_parser():
    """Return the command-line parser for this evaluation script.

    Option names and defaults match the attributes ``eval_model`` reads from
    ``args``. The conversation template flag is accepted both as
    ``--conv-mode`` (original spelling) and ``--conv_mode`` (matching the
    underscore style of every other flag); both store to ``args.conv_mode``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", type=str,
                        default="./checkpoints/llava-llama-2-13b-chat-lightning-preview",
                        help="Model checkpoint path ('~' is expanded).")
    parser.add_argument("--model_base", type=str, default=None,
                        help="Optional base model passed to load_pretrained_model.")
    parser.add_argument("--image_folder", type=str,
                        default="../POPE/data/minival2014/minival2014",
                        help="Image directory passed to COCOEvalData.")
    parser.add_argument("--question_path", type=str,
                        default="../POPE/llava_qa/question",
                        help="Directory containing the question file.")
    parser.add_argument("--question_file", type=str,
                        default="I1_sub240_control.json",
                        help="Question JSON file name, relative to --question_path.")
    parser.add_argument("--answer_path", type=str,
                        default="../POPE/llava_qa/answer",
                        help="Directory for the answers file and the run log.")
    parser.add_argument("--answers_file", type=str, default=None,
                        help="Answers file name relative to --answer_path; "
                             "derived from the model name when omitted.")

    parser.add_argument("--conv-mode", "--conv_mode", dest="conv_mode",
                        type=str, default="llava_llama_2",
                        help="Key into llava.conversation.conv_templates.")
    parser.add_argument("--num_chunks", type=int, default=1,
                        help="Number of chunks to split the questions into.")
    parser.add_argument("--chunk_idx", type=int, default=0,
                        help="Index of the chunk to evaluate.")
    parser.add_argument("--cfg", type=float, default=0.8,
                        help="CFG scale passed to CFGLogits; 0 uses plain generation.")
    parser.add_argument("--seed", type=int, default=42,
                        help="Seed passed to transformers.set_seed.")
    parser.add_argument("--batch_size", type=int, default=3,
                        help="Evaluation batch size.")
    parser.add_argument("--sampling", action="store_true",
                        help="Pass do_sample=True to generate().")
    return parser


if __name__ == "__main__":
    args = build_arg_parser().parse_args()

    from transformers import set_seed
    set_seed(args.seed)
    from llava.eval.utils_cfg import set_logger
    logger = set_logger(cfg=args.cfg, save_dir=args.answer_path)
    # The full namespace (including --sampling) is recorded in the run log.
    logger.info(args)

    eval_model(args)
