import os
import torch
from PIL import Image
import torch
import os
import numpy as np
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence

from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from llava.conversation import conv_templates, SeparatorStyle
from llava.mm_utils import tokenizer_image_token, KeywordsStoppingCriteria


class COCOEvalData_v2(torch.utils.data.Dataset):
    """Evaluation dataset yielding one image/question pair per item.

    Each item carries two views of the same sample: one prepared with
    ``image_processor``/``tokenizer`` and one ("aligned vision") that holds
    the raw RGB pixels together with ids from ``aligned_vision_tokenizer``.

    Args:
        questions: Sequence of records. Each record is indexed with the keys
            ``"conversations"`` (whose first entry's ``"value"`` is the
            question text), ``"image"`` (path relative to ``image_dir``) and
            ``"id"``.
        image_dir: Directory that the ``"image"`` entries are joined onto.
        image_processor: Object exposing
            ``preprocess(image, return_tensors='pt')`` that returns a mapping
            with a ``'pixel_values'`` entry.
        aligned_vision_image_processor: Stored on the instance; not used by
            ``__getitem__`` (the raw pixels are returned instead).
        tokenizer: Tokenizer passed to ``tokenizer_image_token`` for the main
            prompt ids.
        aligned_vision_tokenizer: Tokenizer passed to
            ``tokenizer_image_token`` for the aligned-vision prompt ids.
        conv_mode: Key into ``conv_templates``.
        mm_use_im_start_end: When true, the image token in the prompt is
            wrapped in the image start/end tokens.
    """

    def __init__(self, questions, image_dir, image_processor, aligned_vision_image_processor, tokenizer, aligned_vision_tokenizer, conv_mode, mm_use_im_start_end):
        self.image_dir = image_dir
        self.image_processor = image_processor
        self.questions = questions
        self.conv_mode = conv_mode
        self.mm_use_im_start_end = mm_use_im_start_end
        self.tokenizer = tokenizer
        self.aligned_vision_tokenizer = aligned_vision_tokenizer
        self.aligned_vision_image_processor = aligned_vision_image_processor

    def __len__(self):
        """Return the number of question records."""
        return len(self.questions)

    def __getitem__(self, idx):
        """Build the model inputs for the record at position ``idx``.

        All returned tensors are moved to the current CUDA device, so a GPU
        is required and the dataset should be iterated in the main process.

        Returns:
            A 7-tuple ``(image_tensor, input_ids, cur_prompt,
            aligned_vision_image_tensor, aligned_vision_input_ids,
            question_id, img_id)``:

            * ``image_tensor``: first ``'pixel_values'`` entry from
              ``image_processor`` with a leading dimension of 1 added, in
              half precision.
            * ``input_ids``: prompt ids from ``tokenizer`` with a leading
              dimension of 1 added.
            * ``cur_prompt``: ``'<image>\\n'`` followed by the stripped
              question text.
            * ``aligned_vision_image_tensor``: the raw RGB pixels as a
              half-precision ``(height, width, 3)`` tensor (not resized, so
              the size varies per image).
            * ``aligned_vision_input_ids``: prompt ids from
              ``aligned_vision_tokenizer`` with a leading dimension of 1.
            * ``question_id``: the record's ``"id"`` value.
            * ``img_id``: the record's ``"image"`` value.
        """
        data = self.questions[idx]

        question = data['conversations'][0]['value'].replace('<image>', '').strip()
        img_id = data["image"]
        question_id = data["id"]

        # Normalise to 3-channel RGB once, before either consumer sees the
        # image: grayscale, palette and alpha images would otherwise reach
        # the processor / raw tensor with an unexpected channel layout.
        # convert() loads the pixel data, so the file can be closed right away.
        with Image.open(os.path.join(self.image_dir, img_id)) as raw_image:
            image = raw_image.convert('RGB')

        image_tensor = self.image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
        image_tensor = image_tensor.unsqueeze(0).half().cuda()

        # The aligned-vision branch receives the unprocessed pixels directly.
        aligned_vision_image_tensor = torch.tensor(np.asarray(image)).half().cuda()

        if self.mm_use_im_start_end:
            qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + question
        else:
            qs = DEFAULT_IMAGE_TOKEN + '\n' + question
        cur_prompt = '<image>' + '\n' + question

        conv = conv_templates[self.conv_mode].copy()
        conv.append_message(conv.roles[0], qs)
        # An empty assistant turn leaves the prompt open for generation.
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()

        input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
        aligned_vision_input_ids = tokenizer_image_token(prompt, self.aligned_vision_tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()

        return image_tensor, input_ids, cur_prompt, aligned_vision_image_tensor, aligned_vision_input_ids, question_id, img_id


def custom_collate_fn(batch):
    """Merge a list of per-sample 7-tuples into one batch tuple.

    Args:
        batch: Iterable of ``(image_tensor, input_ids, cur_prompt,
            aligned_vision_image_tensor, aligned_vision_input_ids,
            question_id, img_id)`` tuples.

    Returns:
        A 7-tuple in the same field order:

        * stacked image tensors, with dimension 1 squeezed out when it has
          size 1;
        * main input ids, right-padded with 0 to the longest sequence;
        * list of prompt strings;
        * list of aligned-vision image tensors (kept as a list, not stacked);
        * aligned-vision input ids, right-padded with 0;
        * list of question ids;
        * list of image ids.
    """

    def _pad_ids(id_tensors):
        # Drop each sample's leading size-1 dimension, then right-pad with 0.
        flattened = [ids.squeeze(dim=0) for ids in id_tensors]
        return pad_sequence(flattened, batch_first=True, padding_value=0)

    (
        images,
        main_ids,
        prompts,
        aligned_images,
        aligned_ids,
        qids,
        image_names,
    ) = zip(*batch)

    # Stacking adds the batch axis; the per-sample size-1 axis is then removed.
    stacked_images = torch.stack(images, dim=0).squeeze(dim=1)

    return (
        stacked_images,
        _pad_ids(main_ids),
        list(prompts),
        list(aligned_images),
        _pad_ids(aligned_ids),
        list(qids),
        list(image_names),
    )


if __name__ == "__main__":
    # Smoke test: build the dataset and print the contents of the first batch.
    # Requires a CUDA device because the dataset moves its tensors to the GPU.
    import argparse
    import json
    from transformers import AutoTokenizer, CLIPImageProcessor, set_seed

    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", type=str, default="./checkpoints/llava-llama-2-7b-chat-lightning-preview")
    parser.add_argument("--model-base", type=str, default=None)
    parser.add_argument("--image-folder", type=str, default="../POPE/data/minival2014/minival2014")
    parser.add_argument("--question_path", type=str, default="../POPE/llava_qa/question")
    parser.add_argument("--question_file", type=str, default="I1_sub240_control.json")
    parser.add_argument("--answer_path", type=str, default="../POPE/llava_qa/answer")
    parser.add_argument("--answers_file", type=str, default=None)

    parser.add_argument("--conv-mode", type=str, default="llava_llama_2")
    parser.add_argument("--cfg_mode", type=str, default="text")
    parser.add_argument("--num-chunks", type=int, default=1)
    parser.add_argument("--chunk-idx", type=int, default=0)
    parser.add_argument("--cfg", type=float, default=0.8)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--answer-prompter", action="store_true")
    args = parser.parse_args()

    set_seed(args.seed)

    # Honour the CLI flags rather than a machine-specific absolute path.
    questions_file = os.path.join(args.question_path, args.question_file)
    with open(questions_file, "r") as f:
        questions = json.load(f)

    image_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")
    tokenizer = AutoTokenizer.from_pretrained(args.model_path)

    # For this smoke test the same processor and tokenizer stand in for the
    # "aligned vision" counterparts.
    dataset = COCOEvalData_v2(
        questions,
        args.image_folder,
        image_processor,
        image_processor,
        tokenizer,
        tokenizer,
        args.conv_mode,
        False,
    )
    eval_dataloader = DataLoader(dataset, batch_size=16, shuffle=False, collate_fn=custom_collate_fn)

    for images, input_ids, cur_prompt, aligned_images, aligned_input_ids, question_id, img_id in eval_dataloader:
        # The aligned-vision images arrive as a list of tensors, so report
        # one shape per sample instead of a single batch shape.
        aligned_shapes = [tuple(t.shape) for t in aligned_images]
        print(f"test1 {images.shape} {input_ids.shape} {cur_prompt} {aligned_shapes} {aligned_input_ids.shape} {question_id} {img_id}")
        break
