from torch.utils.data import Dataset, ConcatDataset
from torch.utils.data import Dataset, ConcatDataset
import torch
import transformers
import os
from typing import List, Dict, Optional, Sequence, Tuple
import json
import pandas as pd 
import copy
from langchain.text_splitter import RecursiveCharacterTextSplitter
import pickle
import random
from datasets import load_dataset
from PIL import Image
from llava.conversation import conv_templates
from llava.constants import (IGNORE_INDEX, IMAGE_TOKEN_INDEX, 
                                   DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, 
                                   DEFAULT_IM_END_TOKEN, ENT_TOKEN)
from llava.mm_utils import tokenizer_image_token, process_images
from dataclasses import dataclass
import spacy
from sklearn.preprocessing import LabelEncoder
import sys 
sys.path.append("/ROOT_DIR/VQA/myModel/common_utils")
from instructions_rank import instruction_dict, generate_prompt_and_response

"""
TODO:
1. Construct labels for answering.
2. Decide how to handle two-hop questions (two URLs).

Add more instruction-tuning tasks:

- What is the name of this entity?
- Recover the content of this entity.

Update:

Add a part for yes/no answers.


"""

def expand2square(pil_img, background_color):
    """Pad an image onto a square canvas, centering the original.

    Args:
        pil_img: a PIL image.
        background_color: fill color for the padded area, in a form accepted
            by ``Image.new`` for ``pil_img.mode``.

    Returns:
        ``pil_img`` itself when it is already square. Otherwise a new image
        whose side is the longer edge of ``pil_img``, with ``pil_img`` pasted
        in the middle of the shorter axis.
    """
    w, h = pil_img.size
    if w == h:
        return pil_img
    side = max(w, h)
    canvas = Image.new(pil_img.mode, (side, side), background_color)
    # Shift only along the shorter axis so the content ends up centered.
    offset = (0, (w - h) // 2) if w > h else ((h - w) // 2, 0)
    canvas.paste(pil_img, offset)
    return canvas
    
def append_space_if_not_ends_with_space(s):
    """Return ``s`` guaranteed to end with a space character.

    A single ``' '`` is appended only when ``s`` does not already end with one.
    """
    return s if s.endswith(' ') else s + ' '


def select_random_item(input_string):
    """Pick one ``|``-separated option from ``input_string`` at random.

    Uses the module-level ``random`` generator, so results follow
    ``random.seed``.
    """
    return random.choice(input_string.split('|'))

def generate_landmark_path(root, image_id):
    """Build the sharded ``.jpg`` path for ``image_id`` under ``root``.

    The file sits three directory levels deep. Each level is named after
    one of the first three characters of the id, e.g.
    ``generate_landmark_path("/r", "abc123")`` -> ``"/r/a/b/c/abc123.jpg"``.

    Raises:
        IndexError: if ``image_id`` has fewer than three characters.
    """
    shard = "/".join((image_id[0], image_id[1], image_id[2]))
    return os.path.join(root, f"{shard}/{image_id}.jpg")


def truncate_text(text, chunk_size, chunk_overlap):
    """Split ``text`` into chunks with LangChain's ``RecursiveCharacterTextSplitter``.

    Split points are tried in this order: blank lines, single newlines,
    spaces, then anywhere. ``chunk_size`` and ``chunk_overlap`` are forwarded
    to the splitter unchanged.

    Note: despite the name, this returns the full list of chunks; callers
    that want a truncated text must pick the chunk(s) themselves.
    """
    splitter = RecursiveCharacterTextSplitter(
        separators=["\n\n", "\n", " ", ""],
        chunk_overlap=chunk_overlap,
        chunk_size=chunk_size,
    )
    return splitter.split_text(text)



def replace_filepath_to_bowen(img_path, oven_index_dict):
    """Re-root an image path under ``/ROOT_DIR/VQA/oven``.

    * Paths containing ``oven_images`` become
      ``/ROOT_DIR/VQA/oven/oven_images/<subdir>/<file name>``.
      - ``<subdir>`` is normally the first two characters of the file name's
        numeric part (the text after the last ``_`` and before the first ``.``).
      - When that part starts with ``00`` or ``05``, ``<subdir>`` is looked up
        in ``oven_index_dict`` by file name instead (``KeyError`` if absent).
    * Paths containing ``wikipedia_images_full`` keep everything after
      ``wikipedia_images_full/`` and are moved under
      ``/ROOT_DIR/VQA/oven/wikipedia_images_full``.
    * Any other path is returned unchanged.
    """
    if "oven_images" in img_path:
        name = os.path.basename(img_path)
        digits = name.rpartition("_")[2].partition(".")[0]
        if digits.startswith(("00", "05")):
            subdir = oven_index_dict[name]
        else:
            subdir = digits[:2]
        return os.path.join("/ROOT_DIR/VQA/oven/oven_images", subdir, name)

    if "wikipedia_images_full" in img_path:
        tail = img_path.split("wikipedia_images_full/")[1]
        return os.path.join("/ROOT_DIR/VQA/oven/wikipedia_images_full", tail)

    return img_path



class EvqaQuerywithKBDataset(Dataset):
    """Dataset of questions paired with positive/negative knowledge-base entries.

    Samples come from a JSON list. Each sample provides:
      * a question image and query prompt;
      * a positive KB entry;
      * a negative KB entry, unless the task is in ``_NO_NEGATIVE_TASKS``.
    Each KB entry carries its own image/prompt plus an ``"autoreg"``
    prompt/response that is paired with the question image.

    Args:
        args: config object; only ``args.task`` is read here.
        data_file: path to a JSON file holding a list of sample dicts.
        url_to_idx_path: stored on the instance; not used by this class.
        tokenizer: tokenizer passed to ``tokenizer_image_token``.
        image_processor: stored on the instance; not used by this class.
        model_config: stored on the instance; not used by this class.
        split: accepted for interface compatibility; unused.
        max_token: accepted for interface compatibility; unused.
    """

    # Tasks whose autoreg "targets" are integer relevance labels rather than
    # token-level targets (process_rank is used instead of process_autoreg).
    _RANK_TASKS = ("rank_binary", "rank_relprob", "rank_preEmbed")
    # Tasks whose samples have no negative KB entry.
    _NO_NEGATIVE_TASKS = ("ansrank", "ansrank_noent")

    def __init__(self, args, data_file, url_to_idx_path, tokenizer, image_processor, model_config, split="train", max_token=512):
        super().__init__()

        self.args = args
        self.tokenizer = tokenizer
        self.image_processor = image_processor
        self.model_config = model_config
        self.task = args.task

        self.url_to_idx_path = url_to_idx_path
        self.data_file = data_file

        # JSON text is UTF-8 by spec; don't depend on the platform locale.
        with open(self.data_file, "r", encoding="utf-8") as f:
            self.data_list = json.load(f)

        # Images referenced by https URL are read from the AToMiC image
        # dataset; this pickle maps each URL to its row index there.
        # NOTE: pickle is only safe here because these are trusted local files.
        with open("/ROOT_DIR/EncyclopedicVQA/full_img_list.pkl", "rb") as f:
            img_index_list, img_url_list = pickle.load(f)
        self.img_url_to_idx = dict(zip(img_url_list, img_index_list))
        self.img_dataset = load_dataset("TREC-AToMiC/AToMiC-Images-v0.2", split="train")

        # File name -> sub-directory lookup used by replace_filepath_to_bowen.
        with open("/ROOT_DIR/VQA/oven/oven_images/index.pkl", "rb") as f:
            self.oven_index_dict = pickle.load(f)

    def __len__(self):
        return len(self.data_list)

    def _is_rank_task(self):
        """True when autoreg targets are integer relevance labels."""
        return self.task in self._RANK_TASKS

    def _has_negative(self):
        """True when samples of this task carry a negative KB entry."""
        return self.task not in self._NO_NEGATIVE_TASKS

    def _load_image(self, image_path):
        """Load an image from an https URL (via AToMiC) or from local disk.

        URL images are returned exactly as stored in ``self.img_dataset``.
        Local files are converted to RGB, and the file handle is closed.
        """
        if "https" in image_path:
            return self.img_dataset[self.img_url_to_idx[image_path]]["image"]
        with Image.open(image_path) as img:
            return img.convert('RGB')

    def process_input(self, prompt, image_path):
        """Tokenize ``prompt`` (expanding image tokens) and load its image."""
        input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt')
        return input_ids, self._load_image(image_path)

    def process_autoreg(self, prompt, response, image_path):
        """Tokenize ``prompt + response`` and build a target that masks the prompt.

        Target positions covering the prompt (tokenized on its own) are set to
        ``IGNORE_INDEX``, presumably so that only the response is trained on.

        Returns:
            ``(input_ids, target, image)``.
        """
        input_ids = tokenizer_image_token(prompt + response, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt')
        prompt_len = len(tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt'))
        target = input_ids.clone()
        target[:prompt_len] = IGNORE_INDEX
        return input_ids, target, self._load_image(image_path)

    def process_rank(self, prompt, response, image_path):
        """Tokenize only ``prompt``; ``response`` is ignored and the target is ``None``.

        The caller replaces the ``None`` target with an integer label.
        """
        input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt')
        return input_ids, None, self._load_image(image_path)

    def process_assistant(self, prompt, response):
        """Move the "ASSISTANT: " marker from the response to the end of the prompt.

        In the current data, "ASSISTANT:" is stored inside the response.
        """
        prompt += "ASSISTANT: "
        response = response.replace("ASSISTANT: ", "")
        return prompt, response

    def _process_kb_entry(self, entry):
        """Tokenize a KB entry's prompt and load its (re-rooted) image."""
        image_path = replace_filepath_to_bowen(entry["image_url"], self.oven_index_dict)
        return self.process_input(entry["prompt"], image_path)

    def _process_autoreg_entry(self, entry, img_path, process_func):
        """Run ``process_func`` on an entry's autoreg prompt/response with the question image.

        The adjusted prompt/response are deliberately kept local. Writing them
        back into ``entry`` would mutate ``self.data_list``, and every later
        access to the same sample (e.g. each epoch) would append another
        "ASSISTANT: " to the prompt.
        """
        autoreg = entry["autoreg"]
        prompt, response = self.process_assistant(autoreg["prompt"], autoreg["response"])
        return process_func(prompt, response, img_path)

    def __getitem__(self, index):
        """Build the model inputs for sample ``index``.

        Always returns:
          * ``input_ids`` / ``image`` for the question;
          * ``positive_*`` for the positive KB entry;
          * ``autoreg_pos_*`` for the positive autoreg inputs, paired with the
            question image.
        When the task has negatives, the matching ``negative_*`` and
        ``autoreg_neg_*`` entries are added. For rank tasks, the autoreg
        targets are int labels:
          * positive: 1 if its ``data_type`` is ``"positive_relevant"``, else 0;
          * negative: always 0.
        """
        data = self.data_list[index]
        positive = data["positive"]
        negative = data["negative"] if self._has_negative() else None

        # Re-root the question image: OVEN / Wikipedia first, then the
        # GoogleLandmarks and iNaturalist directories.
        img_path = replace_filepath_to_bowen(data["question_image"], self.oven_index_dict)
        img_path = img_path.replace(
            "/ROOT_DIR/EncyclopedicVQA/GoogleLandMarksV2/train", "/ROOT_DIR/VQA/GoogleLandMarksV2/train"
        ).replace("/ROOT_DIR/EncyclopedicVQA/inat", "/ROOT_DIR/VQA/iNaturalist")

        input_ids, image = self.process_input(data["query_prompt"], img_path)
        pos_input_ids, pos_image = self._process_kb_entry(positive)
        if negative is not None:
            neg_input_ids, neg_image = self._process_kb_entry(negative)

        # Rank tasks do not include the response in the autoreg inputs.
        process_func = self.process_rank if self._is_rank_task() else self.process_autoreg

        autoreg_pos_input_ids, autoreg_pos_target, autoreg_pos_image = self._process_autoreg_entry(
            positive, img_path, process_func
        )
        if negative is not None:
            autoreg_neg_input_ids, autoreg_neg_target, autoreg_neg_image = self._process_autoreg_entry(
                negative, img_path, process_func
            )

        if self._is_rank_task():
            autoreg_pos_target = 1 if positive['data_type'] == "positive_relevant" else 0
            autoreg_neg_target = 0

        data_dict = {
            "input_ids": input_ids,
            "image": image,
            "positive_input_ids": pos_input_ids,
            "positive_image": pos_image,

            "autoreg_pos_input_ids": autoreg_pos_input_ids,
            "autoreg_pos_target": autoreg_pos_target,
            "autoreg_pos_image": autoreg_pos_image,
            "autoreg_pos_correct_idx": positive["autoreg"]["correct_idx"],
        }

        if negative is not None:
            data_dict.update({
                "negative_input_ids": neg_input_ids,
                "negative_image": neg_image,

                "autoreg_neg_input_ids": autoreg_neg_input_ids,
                "autoreg_neg_target": autoreg_neg_target,
                "autoreg_neg_image": autoreg_neg_image,
                "autoreg_neg_correct_idx": negative["autoreg"]["correct_idx"],
            })

        return data_dict
            

    
 
def compute_evqa_image_tensor(model, processor, image):
    """Encode ``image`` with ``model.encode_image`` under no-grad CUDA autocast.

    Args:
        model: encoder exposing ``encode_image`` and ``device`` (presumably a
            CLIP-style model; confirm against callers).
        processor: callable returning an object with a ``pixel_values`` tensor.
        image: image(s) accepted by ``processor``.

    Returns:
        Whatever ``model.encode_image`` returns for the processed pixels.
    """
    input_pixels = processor(images=image, return_tensors="pt", padding=True).pixel_values
    input_pixels = input_pixels.to(model.device)

    # ``torch.cuda.amp.autocast()`` is deprecated in favour of the equivalent
    # ``torch.autocast(device_type="cuda")``.
    with torch.no_grad(), torch.autocast(device_type="cuda"):
        return model.encode_image(input_pixels)





@dataclass
class DataCollatorForSupervisedDataset(object):
    """Collate ``EvqaQuerywithKBDataset`` samples into one batch dict.

    * Token-id sequences are padded with ``tokenizer.pad_token_id`` and
      truncated to ``tokenizer.model_max_length``. Attention masks mark the
      non-pad positions.
    * Images are encoded twice: with ``process_images`` (the ``*images``
      entries) and with ``evqa_image_processor`` (the ``*_pixels`` entries).
      Raw images and their sizes are passed through as well.
    * Autoreg labels are ``torch.long`` relevance labels for the rank tasks
      and ``IGNORE_INDEX``-padded token targets otherwise.
    * Negative entries are collated unless ``task`` is ``ansrank`` or
      ``ansrank_noent``.
    """
    tokenizer: transformers.PreTrainedTokenizer
    image_processor: transformers.image_processing_utils.BaseImageProcessor

    model_config: transformers.PretrainedConfig
    evqa_image_processor: transformers.image_processing_utils.BaseImageProcessor
    task: str

    # Not dataclass fields (no annotation). These must agree with
    # EvqaQuerywithKBDataset, which emits integer targets exactly for
    # _RANK_TASKS and omits negatives exactly for _NO_NEGATIVE_TASKS.
    _RANK_TASKS = ("rank_binary", "rank_relprob", "rank_preEmbed")
    _NO_NEGATIVE_TASKS = ("ansrank", "ansrank_noent")

    def _pad(self, seqs, padding_value=None):
        """Pad sequences batch-first and truncate to ``tokenizer.model_max_length``.

        ``padding_value`` defaults to ``tokenizer.pad_token_id``.
        """
        if padding_value is None:
            padding_value = self.tokenizer.pad_token_id
        padded = torch.nn.utils.rnn.pad_sequence(seqs, batch_first=True, padding_value=padding_value)
        return padded[:, :self.tokenizer.model_max_length]

    def _labels(self, targets):
        """Collate autoreg targets.

        Rank tasks get a ``torch.long`` tensor of int labels. All other tasks
        get token targets padded with ``IGNORE_INDEX``. Positives and
        negatives use the same rule.
        """
        if self.task in self._RANK_TASKS:
            return torch.tensor(targets, dtype=torch.long)
        return self._pad(targets, padding_value=IGNORE_INDEX)

    def _evqa_pixels(self, images):
        """Preprocess ``images`` with ``evqa_image_processor`` and return ``pixel_values``."""
        return self.evqa_image_processor(images=images, return_tensors="pt").pixel_values

    def __call__(self, instances: Sequence[Dict]) -> Dict:
        pad_id = self.tokenizer.pad_token_id

        def gather(key):
            return [instance[key] for instance in instances]

        # Query and positive entries.
        input_ids = self._pad(gather("input_ids"))
        pos_input_ids = self._pad(gather("positive_input_ids"))
        images = gather("image")
        pos_images = gather("positive_image")

        batch = dict(
            input_ids=input_ids,
            attention_mask=input_ids.ne(pad_id),
            images=process_images(images, self.image_processor, self.model_config),
            image_sizes=[img.size for img in images],

            positive_input_ids=pos_input_ids,
            positive_attention_mask=pos_input_ids.ne(pad_id),
            positive_images=process_images(pos_images, self.image_processor, self.model_config),
            positive_image_sizes=[img.size for img in pos_images],

            input_images=list(images),
            input_pos_images=list(pos_images),
        )

        # Autoreg positive.
        autoreg_pos_input_ids = self._pad(gather("autoreg_pos_input_ids"))
        autoreg_pos_images = gather("autoreg_pos_image")
        batch.update(
            autoreg_pos_input_ids=autoreg_pos_input_ids,
            autoreg_pos_labels=self._labels(gather("autoreg_pos_target")),
            autoreg_pos_attention_mask=autoreg_pos_input_ids.ne(pad_id),
            autoreg_pos_images=process_images(autoreg_pos_images, self.image_processor, self.model_config),
            autoreg_pos_image_sizes=[img.size for img in autoreg_pos_images],
            # correct_idx (the entity position index), passed through per sample.
            autoreg_pos_correct_idx=gather("autoreg_pos_correct_idx"),
        )

        batch.update(
            image_input_pixels=self._evqa_pixels(images),
            pos_input_pixels=self._evqa_pixels(pos_images),
            autoreg_pos_image_pixels=self._evqa_pixels(autoreg_pos_images),
        )

        if self.task not in self._NO_NEGATIVE_TASKS:
            neg_input_ids = self._pad(gather("negative_input_ids"))
            neg_images = gather("negative_image")
            autoreg_neg_input_ids = self._pad(gather("autoreg_neg_input_ids"))
            autoreg_neg_images = gather("autoreg_neg_image")

            batch.update(
                negative_input_ids=neg_input_ids,
                negative_attention_mask=neg_input_ids.ne(pad_id),
                negative_images=process_images(neg_images, self.image_processor, self.model_config),
                negative_image_sizes=[img.size for img in neg_images],
                input_neg_images=list(neg_images),

                autoreg_neg_input_ids=autoreg_neg_input_ids,
                autoreg_neg_labels=self._labels(gather("autoreg_neg_target")),
                autoreg_neg_attention_mask=autoreg_neg_input_ids.ne(pad_id),
                autoreg_neg_images=process_images(autoreg_neg_images, self.image_processor, self.model_config),
                autoreg_neg_image_sizes=[img.size for img in autoreg_neg_images],
                autoreg_neg_correct_idx=gather("autoreg_neg_correct_idx"),

                neg_input_pixels=self._evqa_pixels(neg_images),
                autoreg_neg_image_pixels=self._evqa_pixels(autoreg_neg_images),
            )

        return batch


def build_evqa_data_modules(
        tokenizer, model_config,
        data_args,
        is_training=True
):
    """Create the train/eval datasets and the batch collator for EVQA training.

    Args:
        tokenizer: tokenizer shared by both datasets and the collator.
        model_config: model config forwarded to both datasets and the collator.
        data_args: config object, also passed to each dataset as ``args``.
            This function reads from it:
              * ``image_processor`` and ``task``;
              * ``data_file`` and ``url_to_idx_path`` (train dataset);
              * ``val_data_file`` and ``val_url_to_idx_path`` (eval dataset).
        is_training: accepted for interface compatibility; currently unused.

    Returns:
        dict with ``train_dataset``, ``data_collator`` and ``eval_dataset``.
    """
    image_processor = data_args.image_processor

    def make_dataset(data_file, url_to_idx_path):
        return EvqaQuerywithKBDataset(
            data_args,
            data_file,
            url_to_idx_path,
            tokenizer, image_processor, model_config,
        )

    train_dataset = make_dataset(data_args.data_file, data_args.url_to_idx_path)

    # CLIP preprocessing feeding the collator's ``*_pixels`` batch entries.
    from transformers import CLIPImageProcessor
    evqa_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14")

    data_collator = DataCollatorForSupervisedDataset(
        tokenizer=tokenizer,
        image_processor=image_processor,
        model_config=model_config,
        evqa_image_processor=evqa_processor,
        task=data_args.task,
    )

    eval_dataset = make_dataset(data_args.val_data_file, data_args.val_url_to_idx_path)
    print("validation dataset:", len(eval_dataset))

    return {
        "train_dataset": train_dataset,
        "data_collator": data_collator,
        "eval_dataset": eval_dataset,
    }






    