import os
import json
import copy
import random
import pandas as pd
from typing import List, Dict, Optional, Sequence, Tuple

import torch
import transformers
import numpy as np

from torch.utils.data import Dataset, ConcatDataset
from PIL import Image
from dataclasses import dataclass

from llava.conversation import conv_templates
from llava.mm_utils import tokenizer_image_token, expand2square
from llava.data.process import *
from llava.constants import *


class EncyclopedicDocumentHardNegDataset(Dataset):
    """
    Query/section dataset for Encyclopedic-VQA with retrieval-mined hard negatives.

    Each sample holds one query plus a list of document sections. Section 0 is always the
    ground-truth evidence section; the remaining sections are hard negatives. The negatives are
    taken from a passage-retrieval result file (``data_args.passage_result_path``), with the
    positive section excluded.

    Each line of the train / eval / test csv file contains
    {
        - question: The question Q to be used for the VQA triplets.
        - answer: The answer to the question. This field may contain multiple answers:
          if the question was answered by multiple annotators, answers are separated by '|'.
          In case of the multi_answer questions, individual answers are separated by '&&'.
        - dataset_image_ids: A list of up to 5 identifiers for the image associated with the question.
          The IDs correspond to the images from the image dataset.
        - dataset_name: The name of the image dataset.
        - dataset_category_id: An identifier for the category of the subject in the image which corresponds
          to the original IDs of the respective datasets.
        - question_type: The type of the question, which can be one of ['templated', 'automatic', 'multi_answer', '2_hop'].
        - wikipedia_url: The URL of the Wikipedia article that corresponds to the knowledge base for the question.
          This URL acts as a key to our provided knowledge base. For two_hop questions this field contains the two
          consecutive URLs separated by '|'.
        - wikipedia_title: The title of the Wikipedia article that corresponds to the knowledge base for the question.
          Warning: not a stable identifier, use the wikipedia_url instead.
        - evidence: The evidence supporting the answer, in the form of a string. Only for templated questions.
        - evidence_section_id: An integer identifier indicating the section of the knowledge base where the evidence
          can be found. For two_hop questions there are two IDs separated by '|'.
        - evidence_section_title: The title of the section of the knowledge base where the evidence can be found.
          Corresponds to evidence_section_id. For two_hop questions the two titles are separated by '|'.
        - encyclopedic_vqa_split: This defines the split in our Encyclopedic-VQA dataset: train, val, or test.
        - question_original: The original text of the question before any rephrasing.
        - wikipedia_url_used_in_train: Boolean denoting whether the wikipedia_url of this question occurs also
          in the training set. When this is 'False', the subject of the question (C in our paper)
          with its corresponding wikipedia page is unseen during training.
    }

    For the KB dictionary (encyclopedic_kb_wiki.json), the key is the wikipedia_url and the value contains the following properties:
    {
        - title: Title of the Wikipedia article
        - section_titles: List with titles of each section. Its first element is identical to title.
        - section_texts: List with contents for each section.
        - image_urls: List with urls to images within the Wikipedia article.
        - image_reference_descriptions: List with reference descriptions (i.e. captions) of the images.
        - image_section_indices: List of integers denoting the sections where each image belongs to (i.e. index in
          section_titles and section_texts).
        - url: The wikipedia_url (again)
        - tables / table_section_indices (optional): tables of the article and the section index of each
          table. Documents without these keys are treated as having no tables.
    }
    """

    def __init__(self, query_path: str, document_path: str, tokenizer: "transformers.PreTrainedTokenizer",
                 data_args, image_processor, is_training=True, concat_sec=False):
        super(EncyclopedicDocumentHardNegDataset, self).__init__()
        # Load query dataset
        query_df = pd.read_csv(query_path)
        # For debugging
        if data_args.debugging:
            query_df = query_df[:100]
        # For subset experiment
        if data_args.train_subset and is_training:
            random_seed = 42  # Fixed so models trained on the same subset are comparable.
            subset_size = int(len(query_df) * data_args.train_subset_ratio)
            # Randomly sample about N% of the dataset. NOTE: this shuffles rows, so row position
            # no longer equals the original CSV row number (the index label keeps it).
            query_df = query_df.sample(n=subset_size, random_state=random_seed)
        self.query_df = query_df
        # Load id to path mapping dictionary
        self.query_id_to_path = self._load_json(data_args.dataset_id_to_path)

        # Load document KB
        self.document_kb = self._load_json(document_path)

        # Subset number of sections per sample.
        # Since the number of sections for each document varies significantly, for stable GPU memory usage,
        # we limit the number of sections selected.
        self.subset_sec_num = data_args.subset_sec_num

        self.tokenizer = tokenizer
        self.data_args = data_args
        self.is_training = is_training
        # Image processor
        self.query_image_processor = image_processor

        self.concat_sec = concat_sec

        # Model variants
        self.query_use_image = data_args.query_use_image
        self.mixed_query_modality = data_args.mixed_query_modality and is_training

        self.doc_use_image = data_args.doc_use_image
        self.doc_use_table = data_args.doc_use_table

        if self.doc_use_image:
            # Load image_url to image_id mapping dictionary
            self.doc_image_url_to_id = self._load_json(data_args.image_url_to_id_path)
        else:
            self.doc_image_url_to_id = None

        # The model can take only a single image
        self.single_img = data_args.single_img

        # Private copy so resizing never mutates the shared query image processor.
        self.doc_image_processor = copy.deepcopy(image_processor)
        if not self.single_img:
            # Halve the doc image processing size so four images can be tiled into one 2x2 input.
            size = self.doc_image_processor.size
            self.doc_image_processor.size = (size[0] // 2, size[1] // 2)

        # The model refers to only the summary section of a document.
        self.only_summary = data_args.only_summary
        self.only_entity = data_args.only_entity

        # Section retrieval results from passage retrieval, used to provide hard negatives.
        retrieval_result = self._load_json(data_args.passage_result_path)
        self.hard_neg_set = self._build_hard_negatives(retrieval_result)

    @staticmethod
    def _load_json(path):
        """Read a JSON file and close the handle afterwards."""
        with open(path, 'r') as f:
            return json.load(f)

    def _build_hard_negatives(self, retrieval_result):
        """Map each query position to its ranked hard negatives.

        Args:
            retrieval_result: dict from query key (string) to ``{section_key: score}``, where a section
                key looks like ``'<wiki_url>_section_<NN>'``.

        Returns:
            dict from query position (0..len-1, as used by ``__getitem__``) to a list of
            ``[wiki_url, section_idx]`` pairs sorted by descending score. The ground-truth section
            and sections of documents missing from the KB are excluded. Each list keeps at most
            ``subset_sec_num - 1`` entries.
        """
        kb_urls = set(self.document_kb)
        max_negs = max(self.subset_sec_num - 1, 0)
        hard_neg_set = {}
        rows = zip(self.query_df.index, self.query_df['wikipedia_url'], self.query_df['evidence_section_id'])
        for q_idx, (row_label, wiki_url, pos_idx) in enumerate(rows):
            # Retrieval results are presumably keyed by the query's row number in the CSV. Look them
            # up by the row label, not the position. The two are identical unless `sample()`
            # shuffled the rows, and only the label stays aligned after a shuffle.
            scores = retrieval_result[str(row_label)]
            pos_sec_id = f'{wiki_url}_section_{int(pos_idx):02d}'
            negatives = []
            for sec_key in sorted(scores, key=scores.get, reverse=True):
                if sec_key == pos_sec_id:
                    continue
                neg_url, _, neg_idx = sec_key.rsplit('_', 2)
                if neg_url in kb_urls:
                    negatives.append([neg_url, int(neg_idx)])
            hard_neg_set[q_idx] = negatives[:max_negs]
        return hard_neg_set

    def __len__(self):
        return len(self.query_df)

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        """Build one sample: a tokenized query plus its positive and hard-negative sections.

        The positive (evidence) section is always placed first, which is why
        ``query_evidence_section_id`` is 0. ``doc_num_section`` is the number of sections
        actually emitted.
        """
        query = self.query_df.iloc[i]
        document = self.document_kb[query['wikipedia_url']]

        # Mixed modality query version. Randomly give multimodal or uni-modal query
        if self.mixed_query_modality:
            assert self.query_use_image, "Use of query image should be activated for the mixed format"
            multimodal_choice = random.choice([True, False])
        else:
            multimodal_choice = True

        # Generate image + text query pair (or a text-only query)
        query_has_image = bool(multimodal_choice and self.query_use_image and 'dataset_image_ids' in query)
        if query_has_image:
            query_image_file = self.query_id_to_path[str(query['dataset_image_ids'])]
            query_image = self._load_image(
                os.path.join(self.data_args.image_folder, query_image_file), self.query_image_processor)
        # Image queries use `question`; text-only queries use `question_original` (pre-rephrasing text).
        query_text = query['question'] if query_has_image else query['question_original']
        sources = preprocess_query(query_text, is_multimodal=query_has_image)
        query_dict = preprocess(
            sources,
            self.tokenizer,
            has_image=query_has_image,
            input_type='query_sec' if self.concat_sec else 'query',
        )

        # Sample 0 is the positive (evidence) section; the rest are hard negatives. The total is
        # capped by subset_sec_num and by the positive document's section count.
        num_sub_secs = min(len(document['section_titles']), self.subset_sec_num)
        section_samples = [[query['wikipedia_url'], int(query['evidence_section_id'])]]
        if num_sub_secs > 1:
            section_samples.extend(self.hard_neg_set[i][:num_sub_secs - 1])

        sub_document = {
            'url': document['url'],
            'title': document['title'],
            'section_titles': [],
            'section_texts': [],
            'image_section_indices': [],
            'image_reference_descriptions': [],
            'table_section_indices': [],
            'tables': [],
        }
        doc_merged_image_list = []
        for sample_idx, (sec_wiki_url, sec_id) in enumerate(section_samples):
            section_title, section_text, merged_images, image_descriptions, tables = \
                self.extract_single_section(sec_wiki_url, sec_id)
            sub_document['section_titles'].append(section_title)
            sub_document['section_texts'].append(section_text)
            doc_merged_image_list.extend(merged_images)
            # Images and tables are re-indexed to the section's position inside this sample.
            sub_document['image_section_indices'].extend([sample_idx] * len(merged_images))
            sub_document['image_reference_descriptions'].extend(image_descriptions)
            sub_document['table_section_indices'].extend([sample_idx] * len(tables))
            sub_document['tables'].extend(tables)

        doc_has_image = bool(self.doc_use_image and doc_merged_image_list)
        doc_has_table = bool(self.doc_use_table and sub_document['tables'])

        sources = preprocess_interleaved_section(
            sub_document, is_multimodal=doc_has_image, is_tabular=doc_has_table, is_title_only=self.only_entity)
        doc_input_type = 'document_sec' if self.concat_sec else 'document'
        doc_input_ids = [
            preprocess(source, self.tokenizer, has_image=doc_has_image, input_type=doc_input_type)["input_ids"][0]
            for source in sources
        ]

        data_dict = dict(
            query_input_ids=query_dict["input_ids"][0],
            query_evidence_section_id=0,
            doc_input_ids=doc_input_ids,
            # Count the sections actually emitted: hard_neg_set[i] may hold fewer negatives than
            # num_sub_secs - 1, and this value must agree with the sections in this sample.
            doc_num_section=len(section_samples),
        )

        if query_has_image:
            data_dict['query_image'] = query_image
        elif self.data_args.is_multimodal:
            data_dict['query_image'] = None

        if doc_has_image:
            data_dict['doc_image'] = torch.stack(doc_merged_image_list, dim=0)
        elif self.data_args.is_multimodal:
            data_dict['doc_image'] = None

        return data_dict

    def _load_image(self, image_path, processor):
        """Open an image as RGB, optionally pad it to a square, and return ``processor``'s pixel tensor."""
        with Image.open(image_path) as raw_image:
            image = raw_image.convert('RGB')
        if self.data_args.image_aspect_ratio == 'pad':
            image = expand2square(image, tuple(int(x * 255) for x in processor.image_mean))
        return processor.preprocess(image, return_tensors='pt')['pixel_values'][0]

    def _load_doc_image(self, image_url):
        """Load and preprocess a document image given its URL (mapped to a file via doc_image_url_to_id)."""
        image_file = self.doc_image_url_to_id[image_url]
        image_path = os.path.join(os.path.dirname(self.data_args.document_path), 'AToMiC-Images-v0.2/data', image_file)
        return self._load_image(image_path, self.doc_image_processor)

    def extract_single_section(self, wiki_url, sec_id):
        """Return the content of one section of a KB document.

        Args:
            wiki_url: key into ``self.document_kb``.
            sec_id: index into the document's ``section_titles`` / ``section_texts``.

        Returns:
            ``(section_title, section_text, merged_images, image_descriptions, tables)``:
            - ``merged_images``: one tiled tensor per group of up to four section images. It is
              empty unless ``doc_use_image``.
            - ``image_descriptions``: the matching '|'-joined reference descriptions.
            - ``tables``: the document tables whose section index equals ``sec_id``. It is empty
              when the document has no ``tables`` key.
        """
        document = self.document_kb[wiki_url]

        section_title = document['section_titles'][sec_id]
        section_text = document['section_texts'][sec_id]

        merged_images = []
        image_descriptions = []
        if self.doc_use_image:
            section_images = [
                (image_url, description)
                for image_url, description, image_sec in zip(document['image_urls'],
                                                             document['image_reference_descriptions'],
                                                             document['image_section_indices'])
                if image_sec == sec_id
            ]
            # Merge consecutive groups of up to four sub-images into a single 2x2 tiled image.
            for start in range(0, len(section_images), 4):
                group = section_images[start:start + 4]
                merged_images.append(self.merge_images([self._load_doc_image(url) for url, _ in group]))
                image_descriptions.append('|'.join(description for _, description in group))

        tables = [
            table
            for table, table_sec in zip(document.get('tables', []), document.get('table_section_indices', []))
            if table_sec == sec_id
        ]

        return section_title, section_text, merged_images, image_descriptions, tables

    def merge_images(self, image_list: List):
        """Tile up to four (3, N, M) image tensors into one zero-initialised (3, 2N, 2M) tensor.

        (N, M) is ``doc_image_processor.size``. Images fill the top-left, top-right, bottom-left
        and bottom-right quadrants in that order; unused quadrants stay zero.

        Raises:
            ValueError: if more than four images are given.
        """
        # Note that the Conv2d of visual encoder would scan the image starting
        # from the top-left corner -> top-right -> bottom-left -> bottom-right.
        if len(image_list) > 4:
            raise ValueError(f"merge_images expects at most 4 images, got {len(image_list)}")
        N, M = self.doc_image_processor.size[0], self.doc_image_processor.size[1]

        merged_img = torch.zeros((3, 2 * N, 2 * M), dtype=image_list[0].dtype)
        quadrants = [
            (slice(None, N), slice(None, M)),
            (slice(None, N), slice(M, None)),
            (slice(N, None), slice(None, M)),
            (slice(N, None), slice(M, None)),
        ]
        for image, (rows, cols) in zip(image_list, quadrants):
            merged_img[:, rows, cols] = image

        return merged_img

@dataclass
class DataCollatorForSupervisedDataset:
    """Collate dataset samples into a padded batch for query/section retrieval training.

    Queries and sections are padded separately with ``tokenizer.pad_token_id``. Queries are much
    shorter than sections, so sharing one padded tensor would waste memory. The sections of all
    samples are flattened into one ``doc_input_ids`` tensor; ``doc_num_sections`` records how many
    rows belong to each sample. Attention masks mark non-pad tokens. A text part that is absent
    from the samples yields ``None`` for its ids, mask and labels/counts.

    Image tensors are gathered only from samples that provide a non-``None`` image. A missing
    ``query_image`` / ``doc_image`` key counts as "no image".
    """

    tokenizer: "transformers.PreTrainedTokenizer"

    def _pad(self, sequences):
        """Right-pad a list of 1-D token tensors into a (batch, max_len) tensor."""
        return torch.nn.utils.rnn.pad_sequence(
            sequences,
            batch_first=True,
            padding_value=self.tokenizer.pad_token_id)

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        first = instances[0]

        # Query text batch
        if 'query_input_ids' in first:
            query_input_ids = self._pad([instance["query_input_ids"] for instance in instances])
            query_attention_mask = query_input_ids.ne(self.tokenizer.pad_token_id)
            query_evidence_section_labels = torch.tensor(
                [instance['query_evidence_section_id'] for instance in instances], dtype=torch.int64)
        else:
            query_input_ids = query_attention_mask = query_evidence_section_labels = None

        # Section text batch. No truncation to model_max_length here: sections are split into
        # passages later.
        if 'doc_input_ids' in first:
            doc_input_ids = self._pad([item for instance in instances for item in instance["doc_input_ids"]])
            doc_attention_mask = doc_input_ids.ne(self.tokenizer.pad_token_id)
            doc_num_sections = torch.tensor([instance['doc_num_section'] for instance in instances], dtype=torch.int64)
        else:
            doc_input_ids = doc_attention_mask = doc_num_sections = None

        batch = dict(
            query_input_ids=query_input_ids,
            doc_input_ids=doc_input_ids,
            query_attention_mask=query_attention_mask,
            doc_attention_mask=doc_attention_mask,
            query_evidence_section_labels=query_evidence_section_labels,
            doc_num_sections=doc_num_sections,
        )

        # Mixed-modality batches may contain text-only queries; keep only real images.
        query_images = [instance.get('query_image') for instance in instances]
        query_images = [image for image in query_images if image is not None]
        if query_images:
            batch['query_images'] = torch.stack(query_images)  # At most a single image per query.

        # Sections may have no images; each present entry can hold several merged images.
        doc_images = [instance.get('doc_image') for instance in instances]
        doc_images = [image for image in doc_images if image is not None]
        if doc_images:
            batch['doc_images'] = torch.cat(doc_images)

        return batch


def build_encyclopedic_vqa_hard_neg_data_modules(
        tokenizer: transformers.PreTrainedTokenizer,
        data_args,
        image_processor,
        is_training = True,
):
    """Assemble the data module for hard-negative Encyclopedic-VQA training.

    Returns:
        dict with ``train_dataset`` (an ``EncyclopedicDocumentHardNegDataset`` built from
        ``data_args.train_query_path`` and ``data_args.document_path``), ``eval_dataset`` (always
        ``None``) and ``data_collator`` (a ``DataCollatorForSupervisedDataset``).
    """
    dataset_kwargs = dict(
        query_path=data_args.train_query_path,
        document_path=data_args.document_path,
        tokenizer=tokenizer,
        data_args=data_args,
        image_processor=image_processor,
        is_training=is_training,
        concat_sec=data_args.concat_sec,
    )
    # No evaluation split is built in this setup; eval_dataset is left as None.
    return {
        'train_dataset': EncyclopedicDocumentHardNegDataset(**dataset_kwargs),
        'eval_dataset': None,
        'data_collator': DataCollatorForSupervisedDataset(tokenizer=tokenizer),
    }
