import os
import json
import copy
import random
import pandas as pd
from typing import List, Dict, Optional, Sequence, Tuple

import torch
import transformers
import numpy as np

from torch.utils.data import Dataset, ConcatDataset
from PIL import Image
from dataclasses import dataclass

from llava.conversation import conv_templates
from llava.mm_utils import tokenizer_image_token, expand2square
from llava.data.process import *
from llava.constants import *


class EncyclopedicDocumentDataset(Dataset):
    """
    Query/document dataset for Encyclopedic-VQA section-level retrieval.

    Each line of the train / eval / test csv file contains
    {
        - question: The question Q to be used for the VQA triplets.
        - answer: The answer to the question. This field may contain multiple answers:
          if the question was answered by multiple annotators, answers are separated by '|'.
          In case of the multi_answer questions, individual answers are separated by '&&'.
        - dataset_image_ids: A list of up to 5 identifiers for the images associated with the question.
          The IDs correspond to the images from the image dataset.
        - dataset_name: The name of the image dataset.
        - dataset_category_id: An identifier for the category of the subject in the image which corresponds
          to the original IDs of the respective datasets.
        - question_type: The type of the question, which can be one of ['templated', 'automatic', 'multi_answer', '2_hop'].
        - wikipedia_url: The URL of the Wikipedia article that corresponds to the knowledge base for the question.
          This URL acts as a key to our provided knowledge base. For two_hop questions this field contains the two
          consecutive URLs separated by '|'.
        - wikipedia_title: The title of the Wikipedia article that corresponds to the knowledge base for the question.
          Warning: not a stable identifier, use the wikipedia_url instead.
        - evidence: The evidence supporting the answer, in the form of a string. Only for templated questions.
        - evidence_section_id: An integer identifier indicating the section of the knowledge base where the evidence
          can be found. For two_hop questions there are two IDs separated by '|'.
        - evidence_section_title: The title of the section of the knowledge base where the evidence can be found.
          Corresponds to evidence_section_id. For two_hop questions the two titles are separated by '|'.
        - encyclopedic_vqa_split: This defines the split in our Encyclopedic-VQA dataset: train, val, or test.
        - question_original: The original text of the question before any rephrasing.
        - wikipedia_url_used_in_train: Boolean denoting whether the wikipedia_url of this question occurs also
          in the training set. When this is 'False', the subject of the question (C in our paper)
          with its corresponding wikipedia page is unseen during training.
    }

    For the KB dictionary (encyclopedic_kb_wiki.json), the key is the wikipedia_url and the value contains the following properties:
    {
        - title: Title of the Wikipedia article
        - section_titles: List with titles of each section. Its first element is identical to title.
        - section_texts: List with contents for each section.
        - image_urls: List with urls to images within the Wikipedia article.
        - image_reference_descriptions: List with reference descriptions (i.e. captions) of the images.
        - image_section_indices: List of integers denoting the sections where each image belongs to (i.e. index in
          section_titles and section_texts).
        - url: The wikipedia_url (again)
        - tables / table_section_indices: Tables of the article and the section index of each table
          (read by this class for the tabular modality).
    }
    """

    def __init__(self, query_path: str, document_path: str, tokenizer: "transformers.PreTrainedTokenizer",
                 data_args, image_processor, is_training=True, concat_sec=False):
        """
        Args:
            query_path: Path to the query csv file (columns described in the class docstring).
            document_path: Path to the KB json file keyed by wikipedia_url.
            tokenizer: Tokenizer handed to ``preprocess`` for queries and document sections.
            data_args: Data configuration namespace; the flags read from it are documented inline.
            image_processor: Image processor for query images. A deep copy of it (with halved ``size``
                unless ``data_args.single_img``) is used for document images.
            is_training: Enables train-only behavior (subset sampling, mixed query modality).
            concat_sec: Selects the ``'query_sec'`` / ``'document_sec'`` input types for ``preprocess``.
        """
        super().__init__()
        # Load query dataset
        query_df = pd.read_csv(query_path)
        # For debugging: keep only the first 100 queries.
        if data_args.debugging:
            query_df = query_df[:100]
        # For subset experiment
        if data_args.train_subset and is_training:
            random_seed = 42  # We fix the random seed to compare models trained with the same subset dataset.
            subset_size = int(len(query_df) * data_args.train_subset_ratio)
            query_df = query_df.sample(n=subset_size, random_state=random_seed)  # randomly sample about N% of dataset.
        self.query_df = query_df
        # Load query image id -> image file path mapping dictionary
        self.query_id_to_path = self._load_json(data_args.dataset_id_to_path)

        # Load document KB
        self.document_kb = self._load_json(document_path)

        # Subset number of sections per document.
        # Since the number of sections for each document varies significantly, for stable GPU memory usage,
        # we limit the number of sections selected.
        self.subset_sec_num = data_args.subset_sec_num

        self.tokenizer = tokenizer
        self.data_args = data_args
        self.is_training = is_training
        # Image processor
        self.query_image_processor = image_processor

        self.concat_sec = concat_sec

        # Model variants
        self.query_use_image = data_args.query_use_image
        self.mixed_query_modality = data_args.mixed_query_modality and is_training

        self.doc_use_image = data_args.doc_use_image
        self.doc_use_table = data_args.doc_use_table

        if self.doc_use_image:
            # Load image_url to image_id mapping dictionary
            self.doc_image_url_to_id = self._load_json(data_args.image_url_to_id_path)
        else:
            self.doc_image_url_to_id = None

        # The model can take only a single image
        self.single_img = data_args.single_img

        self.doc_image_processor = copy.deepcopy(image_processor)
        if not self.single_img:
            # Halve the doc image processing size so that four doc images can be tiled into one 2x2 input image.
            self.doc_image_processor.size = (self.doc_image_processor.size[0] // 2, self.doc_image_processor.size[1] // 2)

        # The model refers to only the summary section of a document.
        self.only_summary = data_args.only_summary
        # The document is reduced to its title (see is_title_only in __getitem__).
        self.only_entity = data_args.only_entity

    @staticmethod
    def _load_json(path: str):
        """Load a JSON file and close the handle afterwards (JSON text is UTF-8)."""
        with open(path, 'r', encoding='utf-8') as f:
            return json.load(f)

    @staticmethod
    def _sample_section_indices(num_secs, subset_sec_num, evidence_section_id):
        """Randomly pick the sections of a sub-document, always keeping the evidence section.

        Args:
            num_secs: Number of sections in the full document.
            subset_sec_num: Maximum number of sections in the sub-document.
            evidence_section_id: Index (in the full document) of the evidence section.

        Returns:
            ``(num_sub_secs, sec_indices, evidence_position)`` where ``sec_indices`` is a sorted
            numpy array of selected document section indices and ``evidence_position`` is the
            position of the evidence section inside it.

        Raises:
            ValueError: if the evidence section does not occur exactly once among the selected
                indices (e.g. a missing/NaN ``evidence_section_id``).
        """
        num_sub_secs = min(num_secs, subset_sec_num)
        candidates = [idx for idx in range(num_secs) if idx != evidence_section_id]
        sec_indices = np.random.choice(candidates, num_sub_secs - 1, replace=False)
        # Essentially add the evidence section of the query, then restore document order.
        sec_indices = np.sort(np.append(sec_indices, evidence_section_id))
        # Explicit lookup instead of int(np.where(...)[0]): converting an ndim>0 array to a
        # scalar is deprecated in NumPy and gives an opaque TypeError when nothing matches.
        positions = np.flatnonzero(sec_indices == evidence_section_id)
        if positions.size != 1:
            raise ValueError(
                f"Evidence section id {evidence_section_id!r} not found exactly once in "
                f"selected sections {sec_indices.tolist()}")
        return num_sub_secs, sec_indices, int(positions[0])

    def __len__(self):
        return len(self.query_df)

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        """Build example ``i``: a tokenized query and a tokenized, section-split sub-document.

        Returns a dict with:
            - query_input_ids: token ids of the query.
            - query_evidence_section_id: position of the evidence section inside the sub-document
              (-1 in ``only_entity`` mode).
            - doc_input_ids: list with one token-id tensor per source produced by
              ``preprocess_interleaved_section``.
            - doc_num_section: number of sections selected for the sub-document.
            - query_image / doc_image: set when an image exists, otherwise set to ``None`` only if
              ``data_args.is_multimodal``.
        """
        query = self.query_df.iloc[i]
        document = self.document_kb[query['wikipedia_url']]

        # Mixed modality query version. Randomly give multimodal or uni-modal query
        if self.mixed_query_modality:
            assert self.query_use_image, "Use of query image should be activated for the mixed format"
            multimodal_choice = random.choice([True, False])
        else:
            multimodal_choice = True

        # Generate image + text query pair
        if multimodal_choice and self.query_use_image and 'dataset_image_ids' in query:
            # NOTE(review): the whole field value is used as one key of query_id_to_path; confirm the
            # csv stores a single image id per row here.
            query_image_id = query['dataset_image_ids']
            query_image_file = self.query_id_to_path[str(query_image_id)]
            query_image = Image.open(os.path.join(self.data_args.image_folder, query_image_file)).convert('RGB')
            if self.data_args.image_aspect_ratio == 'pad':
                query_image = expand2square(query_image, tuple(int(x * 255) for x in self.query_image_processor.image_mean))
            query_image = self.query_image_processor.preprocess(query_image, return_tensors='pt')['pixel_values'][0]
            query_has_image = True
        else:
            query_has_image = False
        # Image queries use 'question'; text-only queries use 'question_original'.
        sources = preprocess_query(
            copy.deepcopy(query['question'] if query_has_image else query['question_original']), is_multimodal=query_has_image
        )
        query_dict = preprocess(
            sources,
            self.tokenizer,
            has_image=query_has_image,
            input_type='query' if not self.concat_sec else 'query_sec',
        )

        if self.only_entity:
            # No section is selected; the document is processed with is_title_only=True below.
            num_sub_secs = 1
            random_sec_indices = np.array([])
            subdoc_evid_section_id = -1
        elif self.only_summary:
            # The model can refer to only the summary section (the first section of a document).
            num_sub_secs = 1
            random_sec_indices = np.array([0])
            subdoc_evid_section_id = 0
        else:
            # Generate interleaved document from a random subset of sections that includes the evidence.
            num_sub_secs, random_sec_indices, subdoc_evid_section_id = self._sample_section_indices(
                len(document['section_titles']), self.subset_sec_num, query['evidence_section_id'])

        # Build the membership set once instead of scanning the numpy array for every element.
        selected_secs = set(random_sec_indices.tolist())

        # Mapping function: document indices to sub-document indices
        doc_to_subdoc_indices = {sec_idx: sub_sec_idx
                                 for sub_sec_idx, sec_idx in enumerate(random_sec_indices.tolist())}

        sub_document = {}
        sub_document['url'] = document['url']
        sub_document['title'] = document['title']
        sub_document['section_titles'] = [content for sec_idx, content in enumerate(document['section_titles'])
                                          if sec_idx in selected_secs]
        sub_document['section_texts'] = [content for sec_idx, content in enumerate(document['section_texts'])
                                         if sec_idx in selected_secs]

        # In this implementation, we merge the images from the same section into a single image (2x2).
        # Hence, we restrict the image meta-data to the selected sections, re-indexed to sub-document positions.
        temp_document = {
            'image_section_indices': [doc_to_subdoc_indices[image_sec_idx]
                                      for image_sec_idx in document['image_section_indices']
                                      if image_sec_idx in selected_secs],
            'image_urls': [content for idx, content in enumerate(document['image_urls'])
                           if document['image_section_indices'][idx] in selected_secs],
            'image_reference_descriptions': [content for idx, content in
                                             enumerate(document['image_reference_descriptions'])
                                             if document['image_section_indices'][idx] in selected_secs],
        }
        sub_document['image_section_indices'] = []
        sub_document['image_reference_descriptions'] = []
        if self.doc_use_image and len(temp_document['image_urls']) != 0:
            doc_merged_image_list = []

            doc_image_list = []
            doc_image_descript_list = []
            for idx, (doc_image_descript, doc_image_url) in enumerate(
                    zip(temp_document['image_reference_descriptions'],
                        temp_document['image_urls'])):

                doc_image_file = self.doc_image_url_to_id[doc_image_url]
                doc_image = Image.open(os.path.join(os.path.dirname(self.data_args.document_path),'AToMiC-Images-v0.2/data', doc_image_file)).convert('RGB')
                if self.data_args.image_aspect_ratio == 'pad':
                    doc_image = expand2square(doc_image, tuple(int(x * 255) for x in self.doc_image_processor.image_mean))
                doc_image = self.doc_image_processor.preprocess(doc_image, return_tensors='pt')['pixel_values'][0]
                doc_image_list.append(doc_image)
                doc_image_descript_list.append(doc_image_descript)

                # If the model takes only a single image, keep just the first selected image.
                if self.single_img:
                    doc_merged_image_list = doc_image_list
                    current_section_id = temp_document['image_section_indices'][idx]
                    sub_document['image_section_indices'].append(current_section_id)
                    sub_document['image_reference_descriptions'].append('|'.join(doc_image_descript_list))

                    doc_image_list = []
                    doc_image_descript_list = []
                    break

                # Check if the section idx will be changed or the end has arrived.
                section_will_change = ((idx == len(temp_document['image_section_indices']) - 1)
                                       or (temp_document['image_section_indices'][idx] != temp_document['image_section_indices'][idx+1]))
                # Merge the four sub-images (W//2 x H//2) into a single image (W x H)
                # This happens when
                # 1. there are four images in the temporal list
                # 2. the section has changed to the next section
                # 3. the for loop ends.
                if section_will_change or len(doc_image_list) == 4:
                    doc_merged_image = self.merge_images(doc_image_list)
                    doc_merged_image_list.append(doc_merged_image)

                    current_section_id = temp_document['image_section_indices'][idx]
                    sub_document['image_section_indices'].append(current_section_id)
                    sub_document['image_reference_descriptions'].append('|'.join(doc_image_descript_list))

                    doc_image_list = []
                    doc_image_descript_list = []
            doc_has_image = True
        else:
            doc_has_image = False

        # Add tabular modality
        sub_document['table_section_indices'] = [doc_to_subdoc_indices[table_sec_idx]
                                                 for table_sec_idx in document['table_section_indices']
                                                 if table_sec_idx in selected_secs]
        sub_document['tables'] = [content for idx, content in enumerate(document['tables'])
                                  if document['table_section_indices'][idx] in selected_secs]
        doc_has_table = bool(self.doc_use_table and len(sub_document['tables']) != 0)

        sources = preprocess_interleaved_section(
            sub_document, is_multimodal=doc_has_image, is_tabular=doc_has_table, is_title_only=self.only_entity)

        doc_input_ids = []
        for source in sources:
            section_dict = preprocess(
                source,
                self.tokenizer,
                has_image=doc_has_image,
                input_type='document' if not self.concat_sec else 'document_sec',
            )
            doc_input_ids.append(section_dict["input_ids"][0])

        assert subdoc_evid_section_id < num_sub_secs, "The section evidence index is longer than the number of sections"

        data_dict = dict(query_input_ids=query_dict["input_ids"][0],
                         query_evidence_section_id=subdoc_evid_section_id,
                         doc_input_ids=doc_input_ids,
                         doc_num_section=num_sub_secs,
                         )

        # Image entries are present when an image exists; otherwise they are set to None only for
        # multimodal runs. The collator skips None entries when stacking images.
        if query_has_image:
            data_dict['query_image'] = query_image
        elif self.data_args.is_multimodal:
            data_dict['query_image'] = None

        if doc_has_image:
            data_dict['doc_image'] = torch.stack(doc_merged_image_list, dim=0)
        elif self.data_args.is_multimodal:
            data_dict['doc_image'] = None

        return data_dict

    def merge_images(self, image_list: List):
        """Tile up to four (3, N, M) image tensors into one zero-initialized (3, 2N, 2M) tensor.

        ``N, M`` come from ``self.doc_image_processor.size``. Images are placed top-left, top-right,
        bottom-left, bottom-right in list order; unused quadrants stay zero. A fifth or later image
        would overwrite the bottom-right quadrant (``__getitem__`` flushes after four images).
        """
        # Note that the Conv2d of visual encoder would scan the image starting
        # from the top-left corner -> top-right -> bottom-left -> bottom-right.
        N, M = self.doc_image_processor.size[0], self.doc_image_processor.size[1]

        merged_img = torch.zeros((3, 2*N, 2*M), dtype=image_list[0].dtype)

        for pos, image in enumerate(image_list):
            if pos == 0:
                merged_img[:, :N, :M] = image
            elif pos == 1:
                merged_img[:, :N, M:] = image
            elif pos == 2:
                merged_img[:, N:, :M] = image
            else:
                merged_img[:, N:, M:] = image

        return merged_img

@dataclass
class DataCollatorForSupervisedDataset(object):
    """Collate ``EncyclopedicDocumentDataset`` examples into a padded batch.

    Queries and document sections are padded separately: queries are much shorter than
    sections, so sharing one padded length would add many pad tokens to the queries.
    All sections of all documents in the batch are flattened into a single
    ``doc_input_ids`` tensor; ``doc_num_sections`` records how many rows belong to
    each example.
    """

    tokenizer: "transformers.PreTrainedTokenizer"

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        """Pad token ids, build attention masks and stack any available images.

        Returns a dict with ``query_input_ids``, ``doc_input_ids``, their attention masks,
        ``query_evidence_section_labels`` and ``doc_num_sections``. Each of these is ``None``
        when the first instance has no query / document ids. ``query_images`` and
        ``doc_images`` are added only when at least one instance carries a non-None image.
        """
        pad_id = self.tokenizer.pad_token_id

        # Process query text to make batch
        if 'query_input_ids' in instances[0]:
            query_input_ids = torch.nn.utils.rnn.pad_sequence(
                [instance["query_input_ids"] for instance in instances],
                batch_first=True,
                padding_value=pad_id)
            query_attention_mask = query_input_ids.ne(pad_id)
            query_evidence_section_labels = torch.tensor(
                [instance['query_evidence_section_id'] for instance in instances], dtype=torch.int64)
        else:
            # Previously the mask was computed unconditionally, crashing with AttributeError on None.
            query_input_ids = query_attention_mask = query_evidence_section_labels = None

        # Process list of section text to make batch.
        # No truncation to tokenizer.model_max_length here, since sections are later split into passages.
        if 'doc_input_ids' in instances[0]:
            doc_input_ids = torch.nn.utils.rnn.pad_sequence(
                [section for instance in instances for section in instance["doc_input_ids"]],
                batch_first=True,
                padding_value=pad_id)
            doc_attention_mask = doc_input_ids.ne(pad_id)
            doc_num_sections = torch.tensor(
                [instance['doc_num_section'] for instance in instances], dtype=torch.int64)
        else:
            doc_input_ids = doc_attention_mask = doc_num_sections = None

        batch = dict(
            query_input_ids=query_input_ids,
            doc_input_ids=doc_input_ids,
            query_attention_mask=query_attention_mask,
            doc_attention_mask=doc_attention_mask,
            query_evidence_section_labels=query_evidence_section_labels,
            doc_num_sections=doc_num_sections,
        )

        # The dataset only sets the image keys for multimodal runs, so use .get() to tolerate their absence.
        query_images = [instance.get('query_image') for instance in instances]
        query_images = [image for image in query_images if image is not None]  # For the mixed modality query
        if query_images:
            batch['query_images'] = torch.stack(query_images)  # At most a single image per query.

        doc_images = [instance.get('doc_image') for instance in instances]
        doc_images = [image for image in doc_images if image is not None]  # There could be no-image documents
        if doc_images:
            batch['doc_images'] = torch.cat(doc_images)  # Different amounts of images per document.

        return batch


def build_encyclopedic_vqa_data_modules(
        tokenizer: transformers.PreTrainedTokenizer,
        data_args,
        image_processor,
        is_training = True,
):
    """Create the train/eval datasets and the shared collator for Encyclopedic-VQA.

    Both datasets read the same document KB (``data_args.document_path``) and share the
    tokenizer, image processor and ``concat_sec`` setting. The train dataset uses
    ``data_args.train_query_path`` and honors ``is_training``; the eval dataset uses
    ``data_args.eval_query_path`` and is always built with ``is_training=False``.

    Returns:
        dict with keys ``train_dataset``, ``eval_dataset`` and ``data_collator``.
    """
    common = {
        'document_path': data_args.document_path,
        'tokenizer': tokenizer,
        'data_args': data_args,
        'image_processor': image_processor,
        'concat_sec': data_args.concat_sec,
    }
    train_split = EncyclopedicDocumentDataset(
        query_path=data_args.train_query_path, is_training=is_training, **common)
    eval_split = EncyclopedicDocumentDataset(
        query_path=data_args.eval_query_path, is_training=False, **common)
    return {
        'train_dataset': train_split,
        'eval_dataset': eval_split,
        'data_collator': DataCollatorForSupervisedDataset(tokenizer=tokenizer),
    }
