import argparse
import torch
import os
import json
from tqdm import tqdm
import shortuuid

from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, IGNORE_INDEX
from llava.conversation import conv_templates, SeparatorStyle
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from llava.mm_utils import tokenizer_image_token, process_images, get_model_name_from_path, crop_two_square
from torch.utils.data import Dataset, DataLoader

from PIL import Image
import math

from llava.eval.eval_utils.perplexity_utils import get_full_prompt, DatasetForPerplexity_SEED
from llava.eval.eval_utils.dataset_utils import Dataset_SugarCrepe, shuffle_choices, Dataset_common_diff
def split_list(lst, n):
    """Split a list into at most ``n`` contiguous, roughly equal-sized chunks.

    The chunk size is ``ceil(len(lst) / n)``, so every chunk except possibly
    the last one has the same length.  Because the size is rounded up, the
    result may contain fewer than ``n`` chunks (e.g. 5 items with ``n=4``
    gives chunks of size 2 and therefore only 3 chunks).

    Args:
        lst: The sequence to split; it is sliced, not modified.
        n: The requested number of chunks.

    Returns:
        A list of slices of ``lst``.  An empty ``lst`` yields ``[]``.
    """
    # An empty input would give chunk_size == 0, and range() rejects a zero step.
    if len(lst) == 0:
        return []
    chunk_size = math.ceil(len(lst) / n)  # ceiling division: the last chunk may be shorter
    return [lst[i:i + chunk_size] for i in range(0, len(lst), chunk_size)]


def get_chunk(lst, n, k):
    """Return the ``k``-th of the chunks produced by ``split_list(lst, n)``.

    ``split_list`` can produce fewer than ``n`` chunks (chunk sizes are
    rounded up), so a valid worker index ``k < n`` may still have no chunk
    assigned.  In that case, and for an empty ``lst``, an empty list is
    returned instead of raising, so the caller simply has nothing to process.

    Args:
        lst: The full list of items.
        n: The requested number of chunks.
        k: Index of the chunk to return; negative values index from the end,
            as with normal list indexing.

    Returns:
        The selected chunk, or ``[]`` when no such chunk exists.
    """
    # Guard first: the empty case must not depend on how split_list treats it.
    if len(lst) == 0:
        return []
    chunks = split_list(lst, n)
    if -len(chunks) <= k < len(chunks):
        return chunks[k]
    return []


# Custom dataset class
class CustomDataset(Dataset):
    """Dataset yielding ``(input_ids, image_tensor)`` pairs for generation.

    Each entry of ``questions`` is read through its ``"image"`` (file name
    relative to ``image_folder``) and ``"text"`` (question) keys.

    Args:
        questions: Indexable collection of question records.
        image_folder: Directory that the ``"image"`` entries are joined onto.
        tokenizer: Tokenizer handed to ``tokenizer_image_token``.
        image_processor: Image processor handed to ``process_images``.
        model_config: Model config; ``mm_use_im_start_end`` is read from it
            and it is also passed to ``process_images``.
        two_crops: When true, the image is replaced by the crops returned by
            ``crop_two_square`` before processing.
        conv_mode: Conversation template name passed to ``get_full_prompt``.
            When ``None`` (the default) the module-level ``args.conv_mode``
            is used, which is what this class has always relied on.
    """

    def __init__(self, questions, image_folder, tokenizer, image_processor, model_config, two_crops = False, conv_mode=None):
        self.questions = questions
        self.image_folder = image_folder
        self.tokenizer = tokenizer
        self.image_processor = image_processor
        self.model_config = model_config
        # All-zero stand-in derived from the first successfully processed image;
        # reused for samples whose image cannot be loaded.
        self.dummy_image = None
        self.two_crops = two_crops
        self.conv_mode = conv_mode

    def __getitem__(self, index):
        """Build the tokenized prompt and the processed image for one question.

        Returns:
            ``(input_ids, image_tensor)``.  ``image_tensor`` is the dummy
            (all-zero) image when the real one could not be loaded.

        Raises:
            RuntimeError: If the image cannot be loaded or processed and no
                earlier sample has succeeded yet, so there is no dummy image
                to substitute.
        """
        line = self.questions[index]
        image_file = line["image"]
        qs = line["text"]

        # Resolved lazily so that a conv_mode adjusted on the global `args`
        # after construction is still honoured when no explicit mode was given.
        conv_mode = self.conv_mode if self.conv_mode is not None else args.conv_mode

        # get_full_prompt() wraps the question into the conversation template.
        # NOTE(review): an earlier comment here said the prompt, and therefore
        # input_ids, may be a tuple; the consumer does unwrap list-typed
        # input_ids -- confirm against get_full_prompt / tokenizer_image_token.
        prompt = get_full_prompt(conv_mode=conv_mode, qs=qs, mm_use_im_start_end=self.model_config.mm_use_im_start_end)
        input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt')

        image_path = os.path.join(self.image_folder, image_file)
        try:
            image = Image.open(image_path).convert('RGB')
            if self.two_crops:
                crops = list(crop_two_square(image))
                image_tensor = process_images(crops, self.image_processor, self.model_config)
            else:
                image_tensor = process_images([image], self.image_processor, self.model_config)[0]
        except Exception as err:
            # Best effort: substitute the dummy image so one broken file does
            # not abort the run.  Without a dummy there is nothing valid to
            # return, so fail with the real cause instead of an unrelated
            # TypeError from multiplying None.
            if self.dummy_image is None:
                raise RuntimeError(
                    f'Could not load image {image_path!r} and no dummy image is available yet') from err
            print(f'[Warning] could not load image {image_path!r} ({err!r}); using a dummy image instead')
            image_tensor = self.dummy_image
        else:
            if self.dummy_image is None:
                # Multiplying by 0 yields a fresh all-zero object of the same
                # layout without touching the real image.
                self.dummy_image = image_tensor * 0

        return input_ids, image_tensor

    def __len__(self):
        """Return the number of questions."""
        return len(self.questions)


# DataLoader
def create_data_loader(questions, image_folder, tokenizer, image_processor, model_config, batch_size=1, num_workers=4, args = None, annotation_file = None):
    """Build the non-shuffling ``DataLoader`` for the selected evaluation dataset.

    The dataset class is chosen from ``args.eval_dataset`` (and, for
    ``'SEED'``, from ``args.inference_mode``).

    Args:
        questions: Question records handed to the dataset.
        image_folder: Image directory (not used for ``'Ours_baseline'`` /
            ``'Ours_new'``, which take their folders from
            ``args.sugarcrepe_image_folder`` and ``args.aro_image_folder``).
        tokenizer: Tokenizer handed to the dataset.
        image_processor: Image processor handed to the dataset.
        model_config: Model config handed to the dataset.
        batch_size: Must be 1.
        num_workers: Number of ``DataLoader`` worker processes.
        args: Parsed command-line namespace; required despite the ``None``
            default.
        annotation_file: Annotation file for SEED perplexity evaluation.
            When ``None``, ``args.annotation_file`` is used.

    Returns:
        A ``DataLoader`` over the selected dataset with ``shuffle=False``.

    Raises:
        ValueError: If ``args`` is ``None``.
        NotImplementedError: If ``args.eval_dataset`` is unknown, or
            ``args.inference_mode`` is unknown for the ``'SEED'`` dataset.
    """
    assert batch_size == 1, "batch_size must be 1"
    if args is None:
        raise ValueError("create_data_loader requires the parsed `args` namespace")

    if args.eval_dataset == 'SEED':
        if args.inference_mode == 'perplexity':
            # An explicitly passed annotation_file wins over the one in args.
            if annotation_file is None:
                annotation_file = args.annotation_file
            dataset = DatasetForPerplexity_SEED(questions, image_folder, tokenizer, image_processor, model_config,
                                                annotation_file=annotation_file,
                                                conv_mode=args.conv_mode,
                                                perplexity_prompt_version=args.perplexity_prompt_version)
        elif args.inference_mode == 'generate':
            dataset = CustomDataset(questions, image_folder, tokenizer, image_processor, model_config, two_crops=args.two_crops)
        else:
            raise NotImplementedError(f'Unknown inference_mode for SEED: {args.inference_mode}')
    elif args.eval_dataset == 'SugarCrepe':
        dataset = Dataset_SugarCrepe(questions, tokenizer, image_processor, model_config,
                                     image_folder=image_folder,
                                     eval_dataset=args.eval_dataset,
                                     inference_mode=args.inference_mode,
                                     perplexity_prompt_version=args.perplexity_prompt_version,
                                     conv_mode=args.conv_mode)
    elif args.eval_dataset == 'common_diff':
        dataset = Dataset_common_diff(questions, tokenizer, image_processor, model_config,
                                      image_folder=image_folder,
                                      eval_dataset=args.eval_dataset,
                                      inference_mode=args.inference_mode,
                                      perplexity_prompt_version=args.perplexity_prompt_version,
                                      conv_mode=args.conv_mode)
    elif args.eval_dataset in ['Ours_baseline', 'Ours_new']:
        # SugarCrepe + ARO with old ('Ours_baseline') or new ('Ours_new')
        # negatives; images come from two separate folders.
        image_folder_dict = {'SugarCrepe': args.sugarcrepe_image_folder,
                             'ARO': args.aro_image_folder}
        dataset = Dataset_SugarCrepe(questions, tokenizer, image_processor, model_config,
                                     image_folder_dict=image_folder_dict,
                                     eval_dataset=args.eval_dataset,
                                     inference_mode=args.inference_mode,
                                     perplexity_prompt_version=args.perplexity_prompt_version,
                                     conv_mode=args.conv_mode)
    else:
        # Previously fell through and failed later with an unbound `dataset`.
        raise NotImplementedError(f'Unknown eval_dataset: {args.eval_dataset}')

    # shuffle=False: callers zip the loader with the question list, so the
    # order must match.
    return DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False)


# Datasets whose questions carry two answer options and are shuffled on load.
_TWO_OPTION_DATASETS = ('SugarCrepe', 'Ours_baseline', 'Ours_new')
# Letters used to report the index of the best-scoring option.
_ANSWER_LETTERS = ['A', 'B', 'C', 'D']


def _load_questions(args):
    """Load the question records named by ``args.question_file``.

    ``.jsonl`` files are read as one JSON record per line.  ``.json`` files
    are read through their top-level ``"questions"`` entry; for
    ``'Ours_baseline'`` / ``'Ours_new'`` each record gets an ``'answers'``
    pair built from ``orig_pos`` and ``orig_neg`` / ``new_neg``, and for the
    two-option datasets the records go through ``shuffle_choices`` and are
    flattened into a list.

    Raises:
        ValueError: If the file name ends in neither ``jsonl`` nor ``json``.
    """
    question_path = os.path.expanduser(args.question_file)
    if args.question_file.endswith("jsonl"):  # e.g. SEED-Bench, no shuffling needed
        with open(question_path, "r") as f:
            return [json.loads(q) for q in f]
    if not args.question_file.endswith("json"):
        raise ValueError(f'Unsupported question file (expected .json or .jsonl): {args.question_file}')

    with open(question_path, "r") as f:
        questions = json.load(f)["questions"]
    if args.eval_dataset in ('Ours_baseline', 'Ours_new'):
        neg_key = 'orig_neg' if args.eval_dataset == 'Ours_baseline' else 'new_neg'
        for item in questions.values():
            item['answers'] = [item['orig_pos'], item[neg_key]]
    if args.eval_dataset in _TWO_OPTION_DATASETS:
        questions = shuffle_choices(questions)
        questions = list(questions.values())
    # NOTE(review): for any other eval_dataset the "questions" entry is
    # returned unchanged; presumably it must already be a list -- confirm.
    return questions


def _split_perplexity_batch(batch):
    """Split a perplexity batch into the image tensor and (input_ids, labels) pairs.

    The batch must be the image tensor followed by an input_ids/labels pair
    for each of the four options.

    Raises:
        ValueError: If the batch does not have exactly that layout.
    """
    image_tensor, *flat = batch
    expected = 2 * len(_ANSWER_LETTERS)
    if len(flat) != expected:
        raise ValueError(f'Expected an image tensor followed by {expected} input_ids/labels tensors, '
                         f'got {len(flat)} tensors after the image')
    return image_tensor, list(zip(flat[0::2], flat[1::2]))


def _score_options(model, image_tensor, option_pairs):
    """Return the negated model loss for each (input_ids, labels) option.

    A higher score means a lower loss, so the best option is the maximum.
    """
    image_tensor = image_tensor.to(dtype=torch.float16, device='cuda', non_blocking=True)
    on_device = [(input_ids.to(device='cuda', non_blocking=True), labels.to(device='cuda', non_blocking=True))
                 for input_ids, labels in option_pairs]
    scores = []
    with torch.inference_mode():
        for input_ids, labels in on_device:
            loss = model(input_ids=input_ids.long(), images=image_tensor, labels=labels, return_dict=True)["loss"]
            scores.append(-loss.to('cpu').numpy())
    return scores


def _run_perplexity(args, model, model_name, data_loader, questions, ans_file):
    """Score the answer options by loss; the branch is chosen from the question file name."""
    question_file = args.question_file.lower()

    if 'seed' in question_file and 'seedbench2' not in question_file:  # SEED-Bench 1
        for batch, line in tqdm(zip(data_loader, questions), total=len(questions)):
            image_tensor, option_pairs = _split_perplexity_batch(batch)
            if image_tensor is None:
                continue
            idx = line["question_id"]
            annotation = data_loader.dataset.annotated_data_dict[str(idx)]
            choices = [annotation['choice_a'], annotation['choice_b'],
                       annotation['choice_c'], annotation['choice_d']]
            scores = _score_options(model, image_tensor, option_pairs)
            pred_answer = _ANSWER_LETTERS[scores.index(max(scores))]
            ans_file.write(json.dumps({"question_id": idx,
                                       "question": annotation['question'],
                                       "choices": choices,
                                       "gt_answer": annotation['answer'],
                                       "text": pred_answer,
                                       "answer_id": shortuuid.uuid(),
                                       "model_id": model_name,
                                       "metadata": {}}) + "\n")

    elif 'seedbench2' in question_file:  # SEED-Bench 2
        for batch, line in tqdm(zip(data_loader, questions), total=len(questions)):
            image_tensor, option_pairs = _split_perplexity_batch(batch)
            if image_tensor.dim() == 5:
                image_tensor = image_tensor.squeeze(0)
            scores = _score_options(model, image_tensor, option_pairs)
            pred_answer = _ANSWER_LETTERS[scores.index(max(scores))]
            # NOTE(review): predictions are only written for 'common_diff';
            # for any other eval_dataset they are computed and dropped.
            # Looks unintended -- confirm the desired output format.
            if args.eval_dataset == 'common_diff':
                ans_file.write(json.dumps({"question_id": line["question_id"],
                                           "image": line["image"],
                                           "prompt": line['text'],
                                           "text": pred_answer,
                                           "model_id": model_name,
                                           "metadata": {}
                                           }) + "\n")

    elif 'winoground' in question_file:
        total_count = 0
        img_correct = 0
        text_correct = 0
        group_correct = 0
        for batch, line in tqdm(zip(data_loader, questions), total=len(questions)):
            image_tensor, option_pairs = _split_perplexity_batch(batch)
            cap_order = line['cap_order']
            if image_tensor.dim() == 5:
                image_tensor = image_tensor.squeeze(0)
            scores = _score_options(model, image_tensor, option_pairs)
            # iXcY is the score of caption Y for image X; cap_order says which
            # caption the first two options belong to.
            if cap_order == 'A':
                i0c0, i1c0, i0c1, i1c1 = scores
            elif cap_order == 'B':
                i0c1, i1c1, i0c0, i1c0 = scores
            else:
                # Otherwise the comparisons below would use unbound or stale
                # scores from the previous sample.
                raise ValueError(f'Unknown cap_order: {cap_order!r}')
            total_count += 1
            text_ok = i0c0 > i0c1 and i1c1 > i1c0  # each image prefers its own caption
            img_ok = i0c0 > i1c0 and i1c1 > i0c1   # each caption prefers its own image
            if text_ok:
                text_correct += 1
            if img_ok:
                img_correct += 1
            if text_ok and img_ok:
                group_correct += 1
        # An empty chunk must not crash with a division by zero.
        denom = float(total_count) if total_count else 1.0
        text_acc = text_correct / denom
        img_acc = img_correct / denom
        group_acc = group_correct / denom
        print(f'Text acc: {text_acc}, Image acc: {img_acc}, Group acc: {group_acc}, total count: {total_count}')
        ans_file.write(json.dumps({"text_acc": text_acc, "img_acc": img_acc, "group_acc": group_acc,
                                   "total_count": total_count,
                                   "text_correct": text_correct, "img_correct": img_correct, "group_correct":group_correct}) + "\n")

    else:
        # No scoring branch matches this question file; say so instead of
        # silently leaving an empty answers file.
        print(f'[Warning] perplexity mode has no branch for question file {args.question_file}; nothing was evaluated.')


def _run_generate(args, model, tokenizer, model_name, data_loader, questions, ans_file):
    """Generate an answer per question and write one JSON line for each."""
    two_option = args.eval_dataset in _TWO_OPTION_DATASETS
    answer_2options = ['A', 'B']
    for (input_ids, image_tensor), line in tqdm(zip(data_loader, questions), total=len(questions)):
        # The None check has to come before any tensor method is called.
        if image_tensor is None:
            continue
        if image_tensor.dim() == 5:
            image_tensor = image_tensor.squeeze(0)

        idx = line["question_id"]
        if two_option:
            gt_answer_idx = line['gt_answer_idx']  # index of the true caption after shuffling
        else:
            cur_prompt = line["text"]  # question with options

        if isinstance(input_ids, list):
            input_ids = input_ids[0]
        input_ids = input_ids.to(device='cuda', non_blocking=True)

        with torch.inference_mode():
            output_ids = model.generate(
                input_ids,
                images=image_tensor.to(dtype=torch.float16, device='cuda', non_blocking=True),
                do_sample=args.temperature > 0,
                temperature=args.temperature,
                top_p=args.top_p,
                num_beams=args.num_beams,
                max_new_tokens=args.max_new_tokens,
                use_cache=True)

        # The output is expected to start with the prompt; warn if it does
        # not, then decode only the newly generated tail.
        input_token_len = input_ids.shape[1]
        n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
        if n_diff_input_output > 0:
            print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
        outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
        outputs = outputs.strip()

        # Hotfix: drop anything after a literal end-of-sequence marker.
        outputs = outputs.split('</s>')[0]

        ans_id = shortuuid.uuid()
        if two_option:
            ans_file.write(json.dumps({"question_id": idx,
                                       "text": outputs,  # expected to be the answer option, A or B
                                       "gt_answer": answer_2options[gt_answer_idx],
                                       "answer_id": ans_id,
                                       "model_id": model_name,
                                       "metadata": {}
                                       }) + "\n")
        elif args.eval_dataset == 'common_diff':
            ans_file.write(json.dumps({"question_id": idx,
                                       "image": line["image"],
                                       "prompt": cur_prompt,
                                       "text": outputs,
                                       "answer_id": ans_id,
                                       "model_id": model_name,
                                       "metadata": {}
                                       }) + "\n")
        else:  # SEED
            ans_file.write(json.dumps({"question_id": idx,
                                       "prompt": cur_prompt,
                                       "text": outputs,
                                       "answer_id": ans_id,
                                       "model_id": model_name,
                                       "metadata": {}}) + "\n")


def eval_model(args):
    """Run the evaluation described by ``args`` and write answers as JSON lines.

    Loads the model, loads the question file and takes this process's chunk
    (``args.num_chunks`` / ``args.chunk_idx``), builds the data loader, then
    runs either perplexity scoring or generation depending on
    ``args.inference_mode``.  Results go to ``args.answers_file``, which is
    overwritten.

    Side effect: ``args.conv_mode`` gets ``'_mmtag'`` appended when the model
    name looks like a plain (non-finetuned) model and the mode is not already
    an mmtag one.

    Raises:
        ValueError: For an unsupported question file extension or an unknown
            Winoground ``cap_order``.
    """
    # Model
    disable_torch_init()
    model_path = os.path.expanduser(args.model_path)
    model_name = get_model_name_from_path(model_path)
    tokenizer, model, image_processor, _ = load_pretrained_model(model_path, args.model_base, model_name)

    # Questions for this process only.
    questions = _load_questions(args)
    questions = get_chunk(questions, args.num_chunks, args.chunk_idx)

    answers_file = os.path.expanduser(args.answers_file)
    answers_dir = os.path.dirname(answers_file)
    if answers_dir:  # a bare file name has no directory, and makedirs('') fails
        os.makedirs(answers_dir, exist_ok=True)

    if 'plain' in model_name and 'finetune' not in model_name.lower() and 'mmtag' not in args.conv_mode:
        args.conv_mode = args.conv_mode + '_mmtag'
        print(f'It seems that this is a plain model, but it is not using a mmtag prompt, auto switching to {args.conv_mode}.')

    # batch_size defaults to 1; the loader is iterated in lockstep with `questions`.
    data_loader = create_data_loader(questions, args.image_folder, tokenizer, image_processor, model.config, args=args)

    # The context manager closes the answers file even when evaluation fails part-way.
    with open(answers_file, "w") as ans_file:
        if args.inference_mode == 'perplexity':
            _run_perplexity(args, model, model_name, data_loader, questions, ans_file)
        elif args.inference_mode == 'generate':
            _run_generate(args, model, tokenizer, model_name, data_loader, questions, ans_file)

if __name__ == "__main__":
    # Command-line options as (flag, add_argument keyword arguments) pairs,
    # registered in this order.  Flag spellings (hyphen vs. underscore) are
    # part of the CLI and must stay as they are.
    _cli_options = [
        ("--model-path", {"type": str, "default": None}),
        ("--model-base", {"type": str, "default": None}),
        ("--image-folder", {"type": str, "default": "/system/user/publicdata/coco2017/coco2017/images/train2017/"}),
        ("--sugarcrepe-image-folder", {"type": str, "default": ""}),
        ("--aro-image-folder", {"type": str, "default": ""}),
        ("--question-file", {"type": str, "default": '/system/user/publicdata/llava_datasets/sgd_summarize_commdiff.jsonl'}),
        ("--answers-file", {"type": str, "default": "/system/user/publicdata/llava_datasets/sgd_summarize_commdiff_answer.jsonl"}),
        ("--conv-mode", {"type": str, "default": "llava_v1"}),
        ("--num-chunks", {"type": int, "default": 1}),
        ("--chunk-idx", {"type": int, "default": 0}),
        ("--temperature", {"type": float, "default": 0}),  # 0 selects greedy decoding in eval_model
        ("--top_p", {"type": float, "default": None}),
        ("--num_beams", {"type": int, "default": 1}),
        ("--max_new_tokens", {"type": int, "default": 256}),
        ("--inference_mode", {"type": str, "default": 'generate'}),
        ("--perplexity_prompt_version", {"type": str, "default": 'v1'}),
        ("--eval_dataset", {"type": str, "default": 'common_diff'}),
        ("--two_crops", {"action": 'store_true', "help": 'use two crops for the image'}),
        ("--annotation_file", {"type": str, "default": '/system/user/publicdata/LMM_benchmarks/SEED-Bench/SEED-Bench.json'}),
    ]

    parser = argparse.ArgumentParser()
    for _flag, _options in _cli_options:
        parser.add_argument(_flag, **_options)

    # Keep the module-level name `args`: CustomDataset.__getitem__ reads it
    # as a global.
    args = parser.parse_args()

    # For remote debugging, uncomment:
    # from cvar_pyutils.debugging_tools import set_remote_debugger
    # set_remote_debugger(None, 12345)

    eval_model(args)
