import os
import torch
from PIL import Image
import torch
import os
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence

from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from llava.conversation import conv_templates, SeparatorStyle
from llava.mm_utils import tokenizer_image_token, KeywordsStoppingCriteria


class COCOEvalData(torch.utils.data.Dataset):
    """Evaluation dataset pairing each image with a positive and a negative prompt.

    Every entry of ``questions`` is expected to be a mapping with the keys
    ``"conversations"`` (a list whose first two items each carry a ``"value"``
    string), ``"image"`` (file name relative to ``image_dir``) and ``"id"``.

    Args:
        questions: Sequence of question records as described above.
        image_dir: Directory that contains the image files.
        image_processor: Object exposing
            ``preprocess(image, return_tensors='pt')['pixel_values']``.
        tokenizer: Tokenizer handed to ``tokenizer_image_token``.
        conv_mode: Key into ``conv_templates`` selecting the chat template.
        mm_use_im_start_end: If true, the image placeholder is wrapped in the
            image start/end tokens.
    """

    def __init__(self, questions, image_dir, image_processor, tokenizer, conv_mode, mm_use_im_start_end):
        self.image_dir = image_dir
        self.image_processor = image_processor
        self.questions = questions
        self.conv_mode = conv_mode
        self.mm_use_im_start_end = mm_use_im_start_end
        self.tokenizer = tokenizer

    def __len__(self):
        """Return the number of question records."""
        return len(self.questions)

    def _with_image_token(self, text):
        """Prefix ``text`` with the image placeholder (optionally wrapped) and a newline."""
        if self.mm_use_im_start_end:
            return DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + text
        return DEFAULT_IMAGE_TOKEN + '\n' + text

    def _encode_prompt(self, text):
        """Wrap ``text`` in the conversation template and tokenize it to a (1, seq_len) CUDA tensor."""
        conv = conv_templates[self.conv_mode].copy()
        conv.append_message(conv.roles[0], self._with_image_token(text))
        # The assistant turn is left empty so the prompt ends where generation starts.
        conv.append_message(conv.roles[1], None)
        ids = tokenizer_image_token(conv.get_prompt(), self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt')
        return ids.unsqueeze(0).cuda()

    def __getitem__(self, idx):
        """Build the model inputs for the record at ``idx``.

        Returns:
            A tuple ``(image_tensor, input_ids, cur_prompt, neg_ids, question_id, img_id)``:

            * ``image_tensor``: half-precision CUDA tensor with a leading
              singleton dimension (e.g. ``(1, 3, 224, 224)`` for a 224px
              processor -- the exact size depends on ``image_processor``).
            * ``input_ids``: ``(1, seq_len)`` CUDA tensor for the positive prompt.
            * ``cur_prompt``: ``'<image>\\n'`` followed by the stripped question text.
            * ``neg_ids``: ``(1, seq_len)`` CUDA tensor for the negative prompt,
              which also carries the image token.
            * ``question_id``: value of the record's ``"id"`` field.
            * ``img_id``: value of the record's ``"image"`` field.
        """
        data = self.questions[idx]

        # Positive prompt: first conversation turn with any literal '<image>' removed.
        question = data['conversations'][0]['value'].replace('<image>', '').strip()
        # Negative prompt: second conversation turn, used verbatim.
        negative = data["conversations"][1]['value']

        img_id = data["image"]
        question_id = data["id"]

        # Close the file deterministically and normalize the mode so that
        # grayscale / palette / RGBA images reach the processor as 3-channel RGB.
        with Image.open(os.path.join(self.image_dir, img_id)) as raw_image:
            image = raw_image.convert('RGB')

        image_tensor = self.image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
        image_tensor = image_tensor.unsqueeze(0).half().cuda()

        cur_prompt = '<image>' + '\n' + question

        input_ids = self._encode_prompt(question)
        # The negative prompt is tokenized with the image token as well.
        neg_ids = self._encode_prompt(negative)

        return image_tensor, input_ids, cur_prompt, neg_ids, question_id, img_id


def custom_collate_fn(batch, padding_value=0):
    """Collate dataset samples into a batch, left-padding the token sequences.

    Args:
        batch: Iterable of 6-tuples
            ``(image_tensor, input_ids, cur_prompt, neg_ids, question_id, img_id)``
            where ``image_tensor`` and both id tensors carry a leading
            singleton dimension.
        padding_value: Token id used to fill the left side of shorter
            sequences. Defaults to ``0``; bind another value with
            ``functools.partial`` when passing this function to a ``DataLoader``.

    Returns:
        ``(image_tensor_batch, input_id_batch, cur_prompt_batch, neg_ids_batch,
        question_id_batch, img_id_batch)`` where the image batch has the
        singleton dimension removed, both id batches have shape
        ``(batch_size, max_seq_len)`` and the remaining three items are lists.
    """
    image_tensors, input_idss, cur_prompts, neg_idss, question_ids, img_ids = zip(*batch)

    # Stacking yields (batch, 1, ...); drop the per-sample singleton dimension.
    image_tensor_batch = torch.squeeze(torch.stack(image_tensors, dim=0), dim=1)

    def _left_pad(sequences):
        # pad_sequence pads on the right, so reverse every sequence, pad,
        # then reverse the padded batch back to obtain left padding.
        reversed_seqs = [torch.squeeze(seq, dim=0).flip(dims=[0]) for seq in sequences]
        padded = pad_sequence(reversed_seqs, batch_first=True, padding_value=padding_value)
        return padded.flip(dims=[1])

    input_id_batch = _left_pad(input_idss)
    neg_ids_batch = _left_pad(neg_idss)

    return (
        image_tensor_batch,
        input_id_batch,
        list(cur_prompts),
        neg_ids_batch,
        list(question_ids),
        list(img_ids),
    )


def main():
    """Smoke-test the dataset and collate function on one batch.

    Parses the command-line options, loads the question file named by
    ``--question_path`` / ``--question_file``, builds a ``COCOEvalData``
    dataset and prints the shapes and metadata of the first batch.
    """
    import argparse
    import json

    from transformers import AutoTokenizer, CLIPImageProcessor, set_seed

    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", type=str, default="./checkpoints/llava-llama-2-7b-chat-lightning-preview")
    parser.add_argument("--model-base", type=str, default=None)
    parser.add_argument("--image-folder", type=str, default="../POPE/data/minival2014/minival2014")
    parser.add_argument("--question_path", type=str, default="../POPE/llava_qa/question")
    parser.add_argument("--question_file", type=str, default="I1_sub240_control.json")
    parser.add_argument("--answer_path", type=str, default="../POPE/llava_qa/answer")
    parser.add_argument("--answers_file", type=str, default=None)

    parser.add_argument("--conv-mode", type=str, default="llava_llama_2")
    parser.add_argument("--cfg_mode", type=str, default="text")
    parser.add_argument("--num-chunks", type=int, default=1)
    parser.add_argument("--chunk-idx", type=int, default=0)
    parser.add_argument("--cfg", type=float, default=0.8)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--answer-prompter", action="store_true")
    args = parser.parse_args()

    set_seed(args.seed)

    # Honor the parsed CLI options instead of a machine-specific absolute path.
    questions_file = os.path.join(args.question_path, args.question_file)
    with open(questions_file, "r") as f:
        questions = json.load(f)

    # The image processor checkpoint is fixed for this smoke test.
    image_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")
    tokenizer = AutoTokenizer.from_pretrained(args.model_path)

    dataset = COCOEvalData(questions, args.image_folder, image_processor, tokenizer, args.conv_mode, False)
    eval_dataloader = DataLoader(dataset, batch_size=16, shuffle=False, collate_fn=custom_collate_fn)

    # Only the first batch is inspected.
    for images, input_ids, cur_prompt, neg_ids, question_id, img_id in eval_dataloader:
        print(f"test1 {images.shape} {input_ids.shape} {cur_prompt} {neg_ids.shape} {question_id} {img_id}")
        break


if __name__ == "__main__":
    main()
