import argparse
import os
import json
import copy
import time
import numpy as np
import pandas as pd
import pytrec_eval
from llava.eval.beir.faiss_index import FaissIndex
from llava.eval.beir.custom_metrics import mrr, recall_cap, hole, top_k_accuracy
from tqdm import tqdm

from torch.utils.data import Dataset, DataLoader
import transformers

from llava.model.builder import load_reranker_model
from llava.utils import disable_torch_init
from llava.mm_utils import get_model_name_from_path, expand2square

from typing import List, Dict, Optional, Sequence, Tuple
from collections import OrderedDict
from llava.constants import *
from llava.data.process import *

from PIL import Image
import math


def split_list(lst, n):
    """Split a list into at most ``n`` contiguous chunks of (roughly) equal size.

    Every chunk except possibly the last holds ``ceil(len(lst) / n)`` items,
    so fewer than ``n`` chunks can be returned (e.g. 5 items with ``n=4``
    give 3 chunks).  An empty input yields an empty list of chunks.

    Args:
        lst: sequence to split (anything supporting ``len`` and slicing).
        n: requested number of chunks; must be positive.

    Returns:
        A list of consecutive slices of ``lst`` covering all of its items.

    Raises:
        ValueError: if ``n`` is not positive.
    """
    if n <= 0:
        raise ValueError(f"n must be a positive integer, got {n}")
    if len(lst) == 0:
        # A zero chunk size would be an invalid step for range().
        return []
    chunk_size = math.ceil(len(lst) / n)  # ceiling division
    return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)]

def get_chunk(lst, n, k):
    """Return chunk ``k`` of ``lst`` when it is split into ``n`` chunks.

    ``split_list`` uses a ceiling chunk size, so it may produce fewer than
    ``n`` chunks (e.g. 100 items with ``n=16`` give only 15).  A worker whose
    index falls in that gap gets an empty chunk instead of an ``IndexError``.

    Args:
        lst: sequence to split.
        n: total number of chunks / workers.
        k: index of the chunk to return.

    Returns:
        The ``k``-th chunk, or ``[]`` when ``k`` is a valid worker index
        (``k < n``) for which no items are left.

    Raises:
        IndexError: if ``k`` is outside the range of chunks and not below ``n``.
    """
    # An empty input has no chunks at all; skip the split entirely.
    chunks = split_list(lst, n) if len(lst) > 0 else []
    if len(chunks) <= k < n:
        return []
    return chunks[k]


class EncyclopedicRerankDataset(Dataset):
    """Dataset of (query, document) pairs to be scored by the re-ranker.

    Each item combines one row of the query CSV with one entry of the
    document knowledge base and returns the tokenized query, a list of
    tokenized document sources and, when enabled, preprocessed images.

    Args:
        query_doc_list: list of ``(query_row_index, document_id)`` pairs.
        query_path: path of the query CSV file.
        document_path: path of the JSON document knowledge base.
        tokenizer: tokenizer handed to ``preprocess``.
        data_args: namespace providing ``dataset_id_to_path``,
            ``query_use_image``, ``doc_use_image``, ``doc_use_table``,
            ``image_url_to_id_path``, ``image_folder``,
            ``image_aspect_ratio``, ``document_path`` and ``is_multimodal``.
        image_processor: image processor used for query images; a deep copy
            with halved ``size`` is used for document images.
    """

    def __init__(self, query_doc_list: List, query_path: str, document_path: str, tokenizer: "transformers.PreTrainedTokenizer",
                 data_args, image_processor):
        super().__init__()
        # Load query dataset
        self.query_df = pd.read_csv(query_path)
        # Load id to path mapping dictionary
        self.query_id_to_path = self._load_json(data_args.dataset_id_to_path)

        # Load document KB
        self.document_kb = self._load_json(document_path)

        self.query_doc_list = query_doc_list
        self.total_dataset_size = len(query_doc_list)

        self.tokenizer = tokenizer
        self.data_args = data_args
        # Image processor
        self.query_image_processor = image_processor
        # Halve the doc image processing size so that four document images
        # can be tiled 2x2 into one full-size input image.
        # NOTE(review): assumes ``image_processor.size`` is an indexable pair — confirm.
        self.doc_image_processor = copy.deepcopy(image_processor)
        self.doc_image_processor.size = (self.doc_image_processor.size[0] // 2, self.doc_image_processor.size[1] // 2)

        # query and doc modality
        self.query_use_image = data_args.query_use_image
        self.doc_use_image = data_args.doc_use_image
        self.doc_use_table = data_args.doc_use_table

        if self.doc_use_image:
            # Load image_url to image_id mapping dictionary
            self.doc_image_url_to_id = self._load_json(data_args.image_url_to_id_path)
        else:
            self.doc_image_url_to_id = None

    @staticmethod
    def _load_json(path):
        """Read a JSON file, closing the file handle deterministically."""
        with open(path, 'r') as f:
            return json.load(f)

    def _load_image_tensor(self, image_path, processor):
        """Open an image as RGB, optionally pad it to a square, and preprocess it to a tensor."""
        image = Image.open(image_path).convert('RGB')
        if self.data_args.image_aspect_ratio == 'pad':
            # Pad with the processor's mean colour (scaled to 0-255).
            image = expand2square(image, tuple(int(x * 255) for x in processor.image_mean))
        return processor.preprocess(image, return_tensors='pt')['pixel_values'][0]

    def __len__(self):
        return len(self.query_doc_list)

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        """Build the model inputs for the ``i``-th (query, document) pair.

        Returns a dict with ``query_input_ids``, ``doc_input_ids`` (a list of
        token-id tensors, one per source returned by
        ``preprocess_interleaved_section``), ``data_idx`` and
        ``doc_num_section``.  ``query_image`` / ``doc_image`` hold tensors
        when an image is used, ``None`` when no image is used but
        ``data_args.is_multimodal`` is set, and are absent otherwise.
        """
        query_id, doc_id = self.query_doc_list[i]

        query = self.query_df.iloc[query_id]
        document = self.document_kb[doc_id]

        num_sub_secs = len(document['section_titles'])

        # Generate image + text query pair
        query_has_image = bool(self.query_use_image and 'dataset_image_ids' in query)
        if query_has_image:
            query_image_file = self.query_id_to_path[str(query['dataset_image_ids'])]
            query_image = self._load_image_tensor(
                os.path.join(self.data_args.image_folder, query_image_file),
                self.query_image_processor,
            )
        question = query['question'] if query_has_image else query['question_original']
        sources = preprocess_query(copy.deepcopy(question), is_multimodal=query_has_image)
        query_dict = preprocess(
            sources,
            self.tokenizer,
            has_image=query_has_image,
            input_type='query_sec',
        )

        # Copy of the document whose image-related fields start empty; they
        # are refilled below with one entry per merged (tiled) image.
        merge_document = {k: [] if k.startswith('image') else v for k, v in document.items()}

        # Generate interleaved document
        doc_has_image = bool(self.doc_use_image and len(document['image_urls']) != 0)
        doc_merged_image_list = []
        if doc_has_image:
            image_root = os.path.join(os.path.dirname(self.data_args.document_path), 'AToMiC-Images-v0.2/data')
            section_indices = document['image_section_indices']
            last_idx = len(section_indices) - 1

            pending_images = []
            pending_descriptions = []
            for idx, (doc_image_descript, doc_image_url) in enumerate(
                    zip(document['image_reference_descriptions'],
                        document['image_urls'])):

                doc_image_file = self.doc_image_url_to_id[doc_image_url]
                pending_images.append(self._load_image_tensor(
                    os.path.join(image_root, doc_image_file), self.doc_image_processor))
                pending_descriptions.append(doc_image_descript)

                # Check if the section idx will change or the end has arrived.
                section_will_change = (idx == last_idx
                                       or section_indices[idx] != section_indices[idx + 1])

                # Merge up to four sub-images (W//2 x H//2) into a single image (W x H).
                # This happens when
                # 1. there are four images in the pending list
                # 2. the next image belongs to a different section
                # 3. the for loop ends.
                if section_will_change or len(pending_images) == 4:
                    doc_merged_image_list.append(self.merge_images(pending_images))

                    merge_document['image_section_indices'].append(section_indices[idx])
                    merge_document['image_reference_descriptions'].append('|'.join(pending_descriptions))

                    pending_images = []
                    pending_descriptions = []

        doc_has_table = bool(self.doc_use_table and len(document['tables']) != 0)

        sources = preprocess_interleaved_section(
            merge_document, is_multimodal=doc_has_image, is_tabular=doc_has_table)

        doc_input_ids = []
        for source in sources:
            section_dict = preprocess(
                source,
                self.tokenizer,
                has_image=doc_has_image,
                input_type='document',
            )
            doc_input_ids.append(section_dict["input_ids"][0])

        data_dict = dict(query_input_ids=query_dict["input_ids"][0],
                         doc_input_ids=doc_input_ids,
                         data_idx=i,
                         doc_num_section=num_sub_secs,
                         )

        # Image entries: a tensor when an image is used, a ``None`` placeholder
        # when running multimodal without an image, and no key otherwise.
        if query_has_image:
            data_dict['query_image'] = query_image
        elif self.data_args.is_multimodal:
            data_dict['query_image'] = None

        if doc_has_image:
            data_dict['doc_image'] = torch.stack(doc_merged_image_list, dim=0)
        elif self.data_args.is_multimodal:
            data_dict['doc_image'] = None

        return data_dict

    def merge_images(self, image_list: List):
        """Tile up to four half-size images into one 2x2 image.

        Tiles are placed top-left, top-right, bottom-left, bottom-right, in
        list order; unused quadrants stay zero.  The output has 3 channels
        and the dtype of the first tile.

        Raises:
            ValueError: if more than four images are given.
        """
        if len(image_list) > 4:
            raise ValueError(f"merge_images accepts at most 4 images, got {len(image_list)}")

        # Note that the Conv2d of visual encoder would scan the image starting
        # from the top-left corner -> top-right -> bottom-left -> bottom-right.
        N, M = self.doc_image_processor.size[0], self.doc_image_processor.size[1]

        merged_img = torch.zeros((3, 2*N, 2*M), dtype=image_list[0].dtype)

        for slot, image in enumerate(image_list):
            row, col = divmod(slot, 2)
            merged_img[:, row*N:(row+1)*N, col*M:(col+1)*M] = image

        return merged_img

@dataclass
class DataCollatorForSupervisedDataset(object):
    """Collate (query, document) examples into one padded batch.

    Queries and document sequences are padded separately with the
    tokenizer's pad token id, because queries are much shorter than
    documents.  Document sequences of all instances are flattened into a
    single batch dimension; ``doc_num_sections`` carries each instance's
    ``doc_num_section`` value.
    """

    tokenizer: "transformers.PreTrainedTokenizer"

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        """Build the batch dict from a sequence of dataset items.

        Args:
            instances: items with ``data_idx`` and optionally
                ``query_input_ids``, ``doc_input_ids`` + ``doc_num_section``,
                ``query_image`` and ``doc_image`` (image entries may be
                ``None`` or missing).

        Returns:
            Dict with ``query_input_ids``, ``doc_input_ids``, their attention
            masks, ``doc_num_sections`` and ``data_idx``; ``query_images`` /
            ``doc_images`` are added only when at least one instance has one.
            Entries whose source key is absent are ``None``.
        """
        pad_id = self.tokenizer.pad_token_id

        # Process query text to make batch
        # Since the query is way shorter than the document, we process them
        # separately rather than adding a bunch of pads to the query.
        if 'query_input_ids' in instances[0]:
            query_input_ids = torch.nn.utils.rnn.pad_sequence(
                [instance["query_input_ids"] for instance in instances],
                batch_first=True,
                padding_value=pad_id)
            query_attention_mask = query_input_ids.ne(pad_id)
        else:
            query_input_ids, query_attention_mask = None, None

        # Process list of section text to make batch
        # No truncation to tokenizer.model_max_length here, since the
        # sections are split into passages later.
        if 'doc_input_ids' in instances[0]:
            doc_input_ids = torch.nn.utils.rnn.pad_sequence(
                [item for instance in instances for item in instance["doc_input_ids"]],
                batch_first=True,
                padding_value=pad_id)
            doc_attention_mask = doc_input_ids.ne(pad_id)
            doc_num_sections = torch.tensor([instance['doc_num_section'] for instance in instances], dtype=torch.int64)
        else:
            doc_input_ids, doc_attention_mask, doc_num_sections = None, None, None

        data_idx = torch.tensor([instance['data_idx'] for instance in instances], dtype=torch.int64)
        batch = dict(
            query_input_ids=query_input_ids,
            doc_input_ids=doc_input_ids,
            query_attention_mask=query_attention_mask,
            doc_attention_mask=doc_attention_mask,
            doc_num_sections=doc_num_sections,
            data_idx=data_idx,
        )

        # The image keys may be missing entirely (text-only run), so use
        # ``.get`` rather than indexing.
        query_images = [instance['query_image'] for instance in instances
                        if instance.get('query_image') is not None]  # For the mixed modality query
        if query_images:
            batch['query_images'] = torch.stack(query_images)  # At most a single image per query.

        doc_images = [instance['doc_image'] for instance in instances
                      if instance.get('doc_image') is not None]  # There could be no-image documents
        if doc_images:
            batch['doc_images'] = torch.cat(doc_images)  # Different amounts of image per document.

        return batch


# Dataloader
def create_data_loader(query_doc_list, tokenizer, image_processor, data_args, num_workers=8):
    """Wrap the (query, document) pairs in a non-shuffled, order-preserving DataLoader.

    The dataset reads its query and document files from ``data_args``
    (``query_path``, ``document_path``); batches have ``data_args.batch_size``
    items and the final partial batch is kept.
    """
    rerank_dataset = EncyclopedicRerankDataset(
        query_doc_list=query_doc_list,
        query_path=data_args.query_path,
        document_path=data_args.document_path,
        tokenizer=tokenizer,
        image_processor=image_processor,
        data_args=data_args,
    )
    loader_options = dict(
        batch_size=data_args.batch_size,
        num_workers=num_workers,
        shuffle=False,
        drop_last=False,
        collate_fn=DataCollatorForSupervisedDataset(tokenizer=tokenizer),
    )
    return DataLoader(rerank_dataset, **loader_options)


def rerank_document(args):
    """Score the top-k retrieved documents of every query and save the scores as JSON.

    Steps:
      1. Read the query CSV and the retrieval result JSON, keep the
         ``args.document_top_k`` highest-scored documents per query, and
         select this worker's chunk of (query, document) pairs.
      2. Load the re-ranker and run it over the pairs on CUDA.
      3. Store one score per document section under the id
         ``<doc_id>_section_<NN>`` and write
         ``<args.rerank_result_save_path>_<args.chunk_idx>.json``.

    Args:
        args: parsed command-line namespace; this function reads
            ``query_path``, ``ret_result_path``, ``document_top_k``,
            ``num_chunks``, ``chunk_idx``, ``model_path``, ``model_base``,
            ``conv`` and ``rerank_result_save_path`` and passes ``args`` on
            as the dataset's ``data_args``.
    """
    test_queries = pd.read_csv(args.query_path)

    # Perform reranking
    # Here, we collect top-k document retrieval information and conduct the reranking.
    with open(args.ret_result_path, 'r') as f:
        retrieval_result = json.load(f)
    # Collect query-document pair to be evaluated.
    eval_q_d_list = []
    for q_idx in range(len(test_queries)):
        query_ret_results = retrieval_result[str(q_idx)]
        doc_sorted = sorted(query_ret_results, key=query_ret_results.get, reverse=True)[:args.document_top_k]
        eval_q_d_list.extend([q_idx, doc_id] for doc_id in doc_sorted)

    eval_q_d_list = get_chunk(eval_q_d_list, args.num_chunks, args.chunk_idx)

    # Load the trained re-ranker
    disable_torch_init()
    model_path = os.path.expanduser(args.model_path)
    model_name = get_model_name_from_path(model_path)
    tokenizer, model, image_processor, context_len = load_reranker_model(model_path, args.model_base, model_name)

    conversation_lib.default_conversation = conversation_lib.conv_templates[args.conv]

    # Construct test dataset
    data_loader = create_data_loader(query_doc_list=eval_q_d_list, tokenizer=tokenizer,
                                     image_processor=image_processor, data_args=args)

    sec_rerank_results = {}
    # Relevancy scoring
    for batch in tqdm(data_loader, total=len(data_loader)):

        # Move everything except the bookkeeping index to the GPU; only the
        # document images are cast to float16.
        input_batch = {k: v.to(device='cuda', non_blocking=True) if k!="doc_images" else v.to(device='cuda', dtype=torch.float16, non_blocking=True)
                       for k, v in batch.items() if k != 'data_idx'}
        with torch.inference_mode():
            outputs = model(
                **input_batch,
                use_cache=False,
                output_hidden_states=True,
            )

            batch_score = outputs['predictions']

            # ``batch`` itself stays on the CPU; convert the bookkeeping
            # tensors to plain ints once.
            data_indices = batch['data_idx'].tolist()
            num_sections_per_doc = batch['doc_num_sections'].tolist()

            # Scores are consumed document by document:
            # ``num_account_sample`` is the offset of the current
            # document's first section score.
            num_account_sample = 0
            for data_idx, num_sections in zip(data_indices, num_sections_per_doc):
                query_idx, doc_idx = eval_q_d_list[data_idx]

                query_scores = sec_rerank_results.setdefault(str(query_idx), {})

                for i in range(num_sections):
                    sec_score = batch_score[num_account_sample + i]
                    sec_id = doc_idx + f'_section_{i:02d}'
                    query_scores[sec_id] = sec_score.item()

                num_account_sample += num_sections

    # ``os.makedirs('')`` raises, so only create the directory when the save
    # path actually has a directory component.
    save_dir = os.path.dirname(args.rerank_result_save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    with open(args.rerank_result_save_path + f'_{args.chunk_idx}.json', 'w') as f:
        json.dump(sec_rerank_results, f)

# If you are unfamiliar with the retrieval metrics and pytrec_eval module, refer to the link below:
# https://weaviate.io/blog/retrieval-evaluation-metrics

# We refer to the beir module to save the query embeddings and iteratively compute the document embeddings.
# Since our input is much more complicated than the beir default input,
# we modify the structure but follow the same retrieval flow:
# 1. Unlike beir, we pre-compute the query embeddings and document embeddings. With multiple GPUs, the extraction is faster.
# 2. Then, using the embeddings and the faiss module, we retrieve the top-k similar pairs.
# 3. For the intra-document retrieval, we load the test query csv file to extract the ground-truth section ids.
if __name__ == '__main__':
    # Command-line entry point: parse the options and run the re-ranking.
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
    parser.add_argument("--model-base", type=str, default=None)
    parser.add_argument("--conv", type=str, default="qwen_1_5")
    # The paths below are always opened during a run, so they are required:
    # a missing one is reported by argparse instead of failing later with a
    # TypeError on ``None``.
    parser.add_argument("--document_path", type=str, required=True)
    parser.add_argument("--image_url_to_id_path", type=str, default=None)
    parser.add_argument("--doc_use_image", action='store_true')
    parser.add_argument("--doc_use_table", action='store_true')
    parser.add_argument("--query_path", type=str, required=True)
    parser.add_argument("--dataset_id_to_path", type=str, required=True)
    parser.add_argument("--query_use_image", action='store_true')
    parser.add_argument("--image_folder", type=str, default=None)
    parser.add_argument("--is_multimodal", action='store_true')
    parser.add_argument("--image_aspect_ratio", type=str, default="pad")
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--num-chunks", type=int, default=1)
    parser.add_argument("--chunk-idx", type=int, default=0)
    parser.add_argument("--document_top_k", type=int, default=25)
    parser.add_argument("--ret_result_path", type=str, required=True)
    parser.add_argument("--rerank_result_save_path", type=str, required=True)
    args = parser.parse_args()

    # The url->id mapping is loaded whenever document images are enabled.
    if args.doc_use_image and args.image_url_to_id_path is None:
        parser.error("--image_url_to_id_path is required when --doc_use_image is set")
    if args.num_chunks < 1:
        parser.error("--num-chunks must be at least 1")
    if args.chunk_idx >= args.num_chunks:
        parser.error("--chunk-idx must be smaller than --num-chunks")

    rerank_document(args)
