import os
import json
import copy
import random
import pandas as pd
from typing import List, Dict, Optional, Sequence, Tuple

import torch
import transformers
import numpy as np

from torch.utils.data import Dataset, ConcatDataset
from PIL import Image
from dataclasses import dataclass

from llava.conversation import conv_templates
from llava.mm_utils import tokenizer_image_token, expand2square
from llava.data.process import *
from llava.constants import *


class OpenWikiTableDocumentDataset(Dataset):
    """Query-to-section retrieval dataset over OpenWikiTable documents.

    Each item pairs one query (a row of the query CSV) with a sub-document built
    from at most ``data_args.subset_sec_num`` sections of the query's Wikipedia
    page. The query's evidence section is always kept; the other sections are
    sampled at random on every access. Depending on ``data_args``, section
    images (merged into 2x2 tiles per section) and tables are attached to the
    sub-document before it is tokenized section by section.
    """

    def __init__(self, query_path: str, document_path: str, tokenizer: transformers.PreTrainedTokenizer,
                 data_args, image_processor, is_training=True, concat_sec=False):
        """
        Args:
            query_path: CSV of queries. The columns read by this class are
                ``question``, ``wikipedia_url`` and ``evidence_section_id``.
            document_path: JSON knowledge base indexed by Wikipedia URL.
            tokenizer: tokenizer handed to ``preprocess``.
            data_args: config namespace (debugging/subset switches, modality
                flags, ``subset_sec_num``, image paths, aspect-ratio mode).
            image_processor: processor for query images. A deep copy with half
                the configured size is used for document images so that four
                of them fill one full-size 2x2 tile.
            is_training: only when True is ``data_args.train_subset`` applied.
            concat_sec: selects the ``*_sec`` input types passed to ``preprocess``.
        """
        super(OpenWikiTableDocumentDataset, self).__init__()
        # Load query dataset
        query_df = pd.read_csv(query_path)
        # For debugging: keep only the first 100 queries.
        if data_args.debugging:
            query_df = query_df[:100]
        # For subset experiment. The seed is fixed so that models trained with the
        # same subset ratio see exactly the same queries and can be compared.
        if data_args.train_subset and is_training:
            random_seed = 42
            subset_size = int(len(query_df) * data_args.train_subset_ratio)
            query_df = query_df.sample(n=subset_size, random_state=random_seed)  # randomly sample about N% of dataset.
        self.query_df = query_df

        # Load document KB. The context manager closes the handle promptly, and
        # JSON is UTF-8 by specification, independent of the platform locale.
        with open(document_path, 'r', encoding='utf-8') as kb_file:
            self.document_kb = json.load(kb_file)

        # Upper bound on the number of sections per sampled sub-document.
        # The number of sections per document varies significantly, so we cap
        # it to keep GPU memory usage stable.
        self.subset_sec_num = data_args.subset_sec_num

        self.tokenizer = tokenizer
        self.data_args = data_args
        self.is_training = is_training
        # Image processor
        self.query_image_processor = image_processor
        # Halve the doc image processing size so that four doc images form one 2x2 input image.
        self.doc_image_processor = copy.deepcopy(image_processor)
        self.doc_image_processor.size = (self.doc_image_processor.size[0] // 2, self.doc_image_processor.size[1] // 2)

        self.concat_sec = concat_sec

        # Model variants
        self.query_use_image = data_args.query_use_image

        self.doc_use_image = data_args.doc_use_image
        self.doc_use_table = data_args.doc_use_table

        if self.doc_use_image:
            # Load image_url -> image file name mapping.
            with open(data_args.image_url_to_id_path, 'r', encoding='utf-8') as map_file:
                self.doc_image_url_to_id = json.load(map_file)
        else:
            self.doc_image_url_to_id = None

    def __len__(self):
        return len(self.query_df)

    @staticmethod
    def _sample_section_indices(num_secs, num_sub_secs, evidence_sec_id):
        """Return sorted section ids: the evidence section plus ``num_sub_secs - 1`` random others.

        The candidate array is typed int64. Otherwise a single-section document
        would make ``np.random.choice`` return an empty float64 array, turning
        every section id into a float.
        """
        candidates = np.array([sec for sec in range(num_secs) if sec != evidence_sec_id], dtype=np.int64)
        chosen = np.random.choice(candidates, num_sub_secs - 1, replace=False)
        return np.sort(np.append(chosen, evidence_sec_id))

    def _load_doc_image(self, doc_image_url):
        """Load one document image and turn it into the pixel tensor from ``doc_image_processor``."""
        doc_image_file = self.doc_image_url_to_id[doc_image_url]
        image_path = os.path.join(os.path.dirname(self.data_args.document_path), 'images', doc_image_file)
        # NOTE(review): images are converted to RGBA, not RGB. merge_images allocates
        # 3 channels, so this presumably relies on the image processor dropping
        # the alpha channel -- confirm.
        with Image.open(image_path) as raw_image:
            doc_image = raw_image.convert('RGBA')
        if self.data_args.image_aspect_ratio == 'pad':
            doc_image = expand2square(doc_image, tuple(int(x * 255) for x in self.doc_image_processor.image_mean))
        return self.doc_image_processor.preprocess(doc_image, return_tensors='pt')['pixel_values'][0]

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        """Build one example: the tokenized query plus the tokenized sections of a sampled sub-document.

        Returns a dict with ``query_input_ids``, ``query_evidence_section_id``
        (the evidence section's position inside the sub-document),
        ``doc_input_ids`` (one tensor per section) and ``doc_num_section``.
        When ``data_args.is_multimodal`` is set it also contains
        ``query_image``/``doc_image`` (``None`` when absent). ``doc_image`` is
        also present whenever the sub-document has images.
        """
        query = self.query_df.iloc[i]
        document = self.document_kb[query['wikipedia_url']]

        # Queries are handled as text-only here.
        query_has_image = False
        sources = preprocess_query(
            copy.deepcopy(query['question']), is_multimodal=query_has_image
        )
        query_dict = preprocess(
            sources,
            self.tokenizer,
            has_image=query_has_image,
            input_type='query' if not self.concat_sec else 'query_sec',
        )

        # Sample a sub-document that always contains the evidence section.
        evidence_sec_id = query['evidence_section_id']
        num_secs = len(document['section_titles'])
        num_sub_secs = min(num_secs, self.subset_sec_num)
        random_sec_indices = self._sample_section_indices(num_secs, num_sub_secs, evidence_sec_id)
        # Position of the evidence section within the (sorted) sub-document.
        subdoc_evid_section_id = int(np.flatnonzero(random_sec_indices == evidence_sec_id)[0])

        # Document section index -> sub-document section index. Also used as the
        # O(1) membership test for "is this section kept".
        doc_to_subdoc_indices = {sec_idx: sub_sec_idx
                                 for sub_sec_idx, sec_idx in enumerate(random_sec_indices.tolist())}

        sub_document = {
            'url': document['url'],
            'title': document['title'],
            'section_titles': [content for sec_idx, content in enumerate(document['section_titles'])
                               if sec_idx in doc_to_subdoc_indices],
            'section_texts': [content for sec_idx, content in enumerate(document['section_texts'])
                              if sec_idx in doc_to_subdoc_indices],
        }

        # Images from the same section are merged into a single 2x2 image, so the
        # per-image metadata is first restricted to the kept sections and re-indexed.
        image_sec_ids = document['image_section_indices']
        temp_document = {
            'image_section_indices': [doc_to_subdoc_indices[image_sec_idx]
                                      for image_sec_idx in image_sec_ids
                                      if image_sec_idx in doc_to_subdoc_indices],
            'image_urls': [content for idx, content in enumerate(document['image_urls'])
                           if image_sec_ids[idx] in doc_to_subdoc_indices],
            'image_reference_descriptions': [content for idx, content in
                                             enumerate(document['image_reference_descriptions'])
                                             if image_sec_ids[idx] in doc_to_subdoc_indices],
        }
        sub_document['image_section_indices'] = []
        sub_document['image_reference_descriptions'] = []
        doc_has_image = bool(self.doc_use_image and len(temp_document['image_urls']) != 0)
        if doc_has_image:
            doc_merged_image_list = []
            image_secs = temp_document['image_section_indices']
            last_idx = len(image_secs) - 1
            pending_images = []
            pending_descripts = []
            for idx, (doc_image_descript, doc_image_url) in enumerate(
                    zip(temp_document['image_reference_descriptions'],
                        temp_document['image_urls'])):
                pending_images.append(self._load_doc_image(doc_image_url))
                pending_descripts.append(doc_image_descript)

                # Only neighbouring entries are compared, so a section's images are
                # assumed contiguous in the KB -- presumably sorted by section; confirm.
                section_will_change = idx == last_idx or image_secs[idx] != image_secs[idx + 1]

                # Flush the pending (W//2 x H//2) images into one (W x H) tile when
                # four are pending, the section changes, or the list ends.
                if section_will_change or len(pending_images) == 4:
                    doc_merged_image_list.append(self.merge_images(pending_images))
                    sub_document['image_section_indices'].append(image_secs[idx])
                    sub_document['image_reference_descriptions'].append('|'.join(pending_descripts))
                    pending_images = []
                    pending_descripts = []

        # Add tabular modality, restricted to the kept sections.
        table_sec_ids = document['table_section_indices']
        sub_document['table_section_indices'] = [doc_to_subdoc_indices[table_sec_idx]
                                                 for table_sec_idx in table_sec_ids
                                                 if table_sec_idx in doc_to_subdoc_indices]
        sub_document['tables'] = [content for idx, content in enumerate(document['tables'])
                                  if table_sec_ids[idx] in doc_to_subdoc_indices]
        doc_has_table = bool(self.doc_use_table and len(sub_document['tables']) != 0)

        sources = preprocess_interleaved_section(
            sub_document, is_multimodal=doc_has_image, is_tabular=doc_has_table)

        doc_input_ids = []
        for source in sources:
            section_dict = preprocess(
                source,
                self.tokenizer,
                has_image=doc_has_image,
                input_type='document' if not self.concat_sec else 'document_sec',
            )
            doc_input_ids.append(section_dict["input_ids"][0])

        assert subdoc_evid_section_id < num_sub_secs, "The section evidence index is longer than the number of sections"

        data_dict = dict(query_input_ids=query_dict["input_ids"][0],
                         query_evidence_section_id=subdoc_evid_section_id,
                         doc_input_ids=doc_input_ids,
                         doc_num_section=num_sub_secs,
                         )

        # Image placeholders are only emitted in multimodal mode.
        if self.data_args.is_multimodal:
            data_dict['query_image'] = None

        if doc_has_image:
            data_dict['doc_image'] = torch.stack(doc_merged_image_list, dim=0)
        elif self.data_args.is_multimodal:
            data_dict['doc_image'] = None

        return data_dict

    def merge_images(self, image_list: List):
        """Tile up to four image tensors into one (3, 2N, 2M) tensor.

        ``N, M`` are ``doc_image_processor.size[0]`` and ``[1]``. Each tensor must
        fit a (3, N, M) slot. Slots are filled top-left, top-right, bottom-left,
        bottom-right, matching the order in which the visual encoder's Conv2d
        scans the image. Unfilled slots stay zero, and any tensor past the fourth
        overwrites the bottom-right slot.
        """
        N, M = self.doc_image_processor.size[0], self.doc_image_processor.size[1]

        merged_img = torch.zeros((3, 2 * N, 2 * M), dtype=image_list[0].dtype)

        slot_offsets = ((0, 0), (0, M), (N, 0), (N, M))
        for pos, image in enumerate(image_list):
            row, col = slot_offsets[min(pos, 3)]
            merged_img[:, row:row + N, col:col + M] = image

        return merged_img

@dataclass
class DataCollatorForSupervisedDataset(object):
    """Collate retrieval examples into a padded batch.

    Queries and document sections are padded independently. Queries are much
    shorter than sections, so padding them together would mostly add pad
    tokens to the queries. The sections of all documents in the batch are
    flattened into one padded tensor, and ``doc_num_sections`` records how many
    rows belong to each document.

    If the first instance lacks query (or document) token ids, the matching
    batch entries (ids, attention mask, labels/section counts) are ``None``.
    Image entries are added only when at least one instance carries a
    non-``None`` image. A missing image key counts as ``None``, so text-only
    datasets that never set ``query_image``/``doc_image`` are supported.
    """

    tokenizer: transformers.PreTrainedTokenizer

    def _pad(self, sequences: List[torch.Tensor]) -> torch.Tensor:
        """Right-pad 1-D token-id tensors into a (batch, max_len) tensor using the tokenizer's pad id."""
        return torch.nn.utils.rnn.pad_sequence(
            sequences,
            batch_first=True,
            padding_value=self.tokenizer.pad_token_id)

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        pad_token_id = self.tokenizer.pad_token_id

        # Queries: one sequence per instance.
        if 'query_input_ids' in instances[0]:
            query_input_ids = self._pad([instance["query_input_ids"] for instance in instances])
            query_attention_mask = query_input_ids.ne(pad_token_id)
            query_evidence_section_labels = torch.tensor(
                [instance['query_evidence_section_id'] for instance in instances], dtype=torch.int64)
        else:
            query_input_ids = query_attention_mask = query_evidence_section_labels = None

        # Documents: every section of every instance becomes one row.
        # No truncation to tokenizer.model_max_length here, because sections are
        # split into passages downstream.
        if 'doc_input_ids' in instances[0]:
            doc_input_ids = self._pad([item for instance in instances for item in instance["doc_input_ids"]])
            doc_attention_mask = doc_input_ids.ne(pad_token_id)
            doc_num_sections = torch.tensor(
                [instance['doc_num_section'] for instance in instances], dtype=torch.int64)
        else:
            doc_input_ids = doc_attention_mask = doc_num_sections = None

        batch = dict(
            query_input_ids=query_input_ids,
            doc_input_ids=doc_input_ids,
            query_attention_mask=query_attention_mask,
            doc_attention_mask=doc_attention_mask,
            query_evidence_section_labels=query_evidence_section_labels,
            doc_num_sections=doc_num_sections,
        )

        # Mixed-modality queries: at most one image per query; queries without one are skipped.
        query_images = [instance.get('query_image') for instance in instances]
        query_images = [image for image in query_images if image is not None]
        if query_images:
            batch['query_images'] = torch.stack(query_images)

        # Documents contribute a variable number of merged images each, so concatenate.
        doc_images = [instance.get('doc_image') for instance in instances]
        doc_images = [image for image in doc_images if image is not None]
        if doc_images:
            batch['doc_images'] = torch.cat(doc_images)

        return batch


def build_open_wikitable_data_modules(
        tokenizer: transformers.PreTrainedTokenizer,
        data_args,
        image_processor,
        is_training=True,
):
    """Create the train/eval OpenWikiTable datasets and their shared collator.

    Both splits read the same document KB (``data_args.document_path``) and use
    the same tokenizer, image processor and ``concat_sec`` setting. They differ
    only in their query file. The eval split is always built with
    ``is_training=False``.

    Returns:
        dict with keys ``train_dataset``, ``eval_dataset`` and ``data_collator``.
    """
    common_kwargs = dict(
        document_path=data_args.document_path,
        tokenizer=tokenizer,
        data_args=data_args,
        image_processor=image_processor,
        concat_sec=data_args.concat_sec,
    )
    train_split = OpenWikiTableDocumentDataset(
        query_path=data_args.train_query_path,
        is_training=is_training,
        **common_kwargs,
    )
    eval_split = OpenWikiTableDocumentDataset(
        query_path=data_args.eval_query_path,
        is_training=False,
        **common_kwargs,
    )
    collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
    return {
        'train_dataset': train_split,
        'eval_dataset': eval_split,
        'data_collator': collator,
    }
