import argparse
import os
import json
import numpy as np
from tqdm import tqdm

from torch.utils.data import Dataset, DataLoader
import transformers

from llava.model.builder import load_retriever_model
from llava.utils import disable_torch_init
from llava.mm_utils import get_model_name_from_path, expand2square
from llava.constants import *
from llava.data.process import *


from PIL import Image
import math

def split_list(lst, n):
    """Split a list into at most n contiguous chunks of (roughly) equal size.

    Every chunk except possibly the last holds ``ceil(len(lst) / n)`` items, so
    fewer than ``n`` chunks are returned when ``lst`` is short or does not
    divide evenly (e.g. 9 items into 4 chunks gives chunks of 3, 3 and 3).
    An empty list yields no chunks.
    """
    # Ceiling division; clamp to 1 so an empty list does not produce a zero
    # range step (range() raises ValueError on step 0).
    chunk_size = max(1, math.ceil(len(lst) / n))
    return [lst[i:i + chunk_size] for i in range(0, len(lst), chunk_size)]


def get_chunk(lst, n, k):
    """Return the k-th chunk of ``split_list(lst, n)``.

    ``split_list`` can produce fewer than ``n`` chunks (short or unevenly
    divisible input). For an in-range index ``0 <= k < n`` that has no chunk,
    an empty list is returned, so that worker simply has nothing to process
    instead of crashing. Indices outside that range keep plain list-indexing
    semantics.
    """
    chunks = split_list(lst, n)
    if 0 <= k < n and k >= len(chunks):
        return []
    return chunks[k]

class OpenWikiTableRetDataset(Dataset):
    """Document-KB dataset yielding tokenized, interleaved sections per document.

    The KB JSON at ``document_path`` maps a document key (presumably the
    Wikipedia URL, judging by the variable name -- TODO confirm) to a document
    dict. Only the keys in the shard selected by ``data_args.num_chunks`` /
    ``data_args.chunk_idx`` are kept, so each process handles its own chunk.

    Each item is a dict with ``doc_input_ids`` (one token-id tensor per
    section), ``data_idx`` (position within this chunk), ``doc_num_section``
    and, when applicable, ``doc_image`` (a stack of 2x2 image mosaics).
    """

    def __init__(self, document_path: str, tokenizer: transformers.PreTrainedTokenizer,
                 data_args, image_processor, use_image=True, use_table=True):
        super(OpenWikiTableRetDataset, self).__init__()
        # Load document KB (context manager so the file handle is closed).
        with open(document_path, 'r') as f:
            document_kb = json.load(f)
        document_kb_keys = list(document_kb.keys())
        self.total_dataset_size = len(document_kb_keys)
        # Keep only this process's shard (for using multiple GPUs).
        self.document_kb_keys = get_chunk(document_kb_keys, data_args.num_chunks, data_args.chunk_idx)
        self.document_kb = {kb_key: document_kb[kb_key] for kb_key in self.document_kb_keys}

        self.tokenizer = tokenizer
        self.data_args = data_args
        # Halve the processor's target size so four sub-images tile into one
        # full-size 2x2 input image (see merge_images).
        # NOTE: this mutates the caller's image_processor in place.
        self.doc_image_processor = image_processor
        self.doc_image_processor.size = (self.doc_image_processor.size[0] // 2,
                                         self.doc_image_processor.size[1] // 2)
        # Switches for baselines that ignore images / tables in the document.
        self.use_image = use_image
        self.use_table = use_table

        if self.use_image:
            # Maps each image URL to its file name inside the 'images'
            # directory next to data_args.document_path.
            with open(data_args.image_url_to_id_path, 'r') as f:
                self.doc_image_url_to_id = json.load(f)
        else:
            self.doc_image_url_to_id = None

    def __len__(self):
        """Number of documents in this process's chunk."""
        return len(self.document_kb_keys)

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        """Tokenize the i-th document of this chunk into per-section inputs.

        Returns a dict with:
          * ``doc_input_ids`` -- list with one ``input_ids`` tensor per section
            produced by ``preprocess_interleaved_section``;
          * ``data_idx`` -- ``i``, the document's position within this chunk;
          * ``doc_num_section`` -- ``len(document['section_titles'])``;
          * ``doc_image`` -- a ``(num_merged, 3, 2N, 2M)`` stack of mosaics
            (N, M = halved processor size) when the document has images and
            ``use_image`` is set; ``None`` when it has none and
            ``data_args.is_multimodal`` is set; absent otherwise.
        """
        doc_kb_wiki_url = self.document_kb_keys[i]
        document = self.document_kb[doc_kb_wiki_url]

        num_sub_secs = len(document['section_titles'])

        # Shallow copy with every 'image*' field emptied; those fields are
        # refilled below with one entry per merged 2x2 image.
        merge_document = {k: [] if k.startswith('image') else v for k, v in document.items()}

        # Generate interleaved document
        if self.use_image and len(document['image_urls']) != 0:
            doc_merged_image_list = []
            images_dir = os.path.join(os.path.dirname(self.data_args.document_path), 'images')
            section_indices = document['image_section_indices']

            # Pending sub-images (and their descriptions) for the current mosaic.
            doc_image_list = []
            doc_image_descript_list = []
            for idx, (doc_image_descript, doc_image_url) in enumerate(
                    zip(document['image_reference_descriptions'],
                        document['image_urls'])):

                doc_image_file = self.doc_image_url_to_id[doc_image_url]
                # Context manager closes the file; convert() returns an
                # independent, fully loaded copy.
                with Image.open(os.path.join(images_dir, doc_image_file)) as raw_image:
                    doc_image = raw_image.convert('RGBA')
                if self.data_args.image_aspect_ratio == 'pad':
                    doc_image = expand2square(doc_image, tuple(int(x * 255) for x in self.doc_image_processor.image_mean))
                # NOTE(review): merge_images allocates 3 channels, so this
                # assumes the processor converts the RGBA input to RGB -- confirm.
                doc_image = self.doc_image_processor.preprocess(doc_image, return_tensors='pt')['pixel_values'][0]

                doc_image_list.append(doc_image)
                doc_image_descript_list.append(doc_image_descript)

                # True at the last entry of image_section_indices, or when the
                # next image belongs to a different section.
                section_will_change = (idx == len(section_indices) - 1
                                       or section_indices[idx] != section_indices[idx + 1])

                # Flush the pending sub-images (W//2 x H//2 each) into a single
                # W x H mosaic when four are pending or the section changes.
                if section_will_change or len(doc_image_list) == 4:
                    doc_merged_image_list.append(self.merge_images(doc_image_list))

                    merge_document['image_section_indices'].append(section_indices[idx])
                    merge_document['image_reference_descriptions'].append('|'.join(doc_image_descript_list))

                    doc_image_list = []
                    doc_image_descript_list = []

            doc_has_image = True
        else:
            doc_has_image = False

        doc_has_table = bool(self.use_table and len(document['tables']) != 0)

        sources = preprocess_interleaved_section(
            merge_document, is_multimodal=doc_has_image, is_tabular=doc_has_table)

        # Tokenize each section separately; [0] presumably drops a leading
        # batch dimension of the returned input_ids -- TODO confirm.
        doc_input_ids = [
            preprocess(
                source,
                self.tokenizer,
                has_image=doc_has_image,
                input_type='document',
            )["input_ids"][0]
            for source in sources
        ]

        data_dict = dict(doc_input_ids=doc_input_ids,
                         data_idx=i,
                         doc_num_section=num_sub_secs,
                         )

        if doc_has_image:
            data_dict['doc_image'] = torch.stack(doc_merged_image_list, dim=0)
        elif self.data_args.is_multimodal:
            data_dict['doc_image'] = None

        return data_dict

    def merge_images(self, image_list: List):
        """Tile up to four ``(3, N, M)`` images into one ``(3, 2N, 2M)`` image.

        Quadrants are filled top-left -> top-right -> bottom-left ->
        bottom-right, the order in which the visual encoder's Conv2d scans the
        image. Unused quadrants stay zero; any image past the fourth would
        overwrite the bottom-right quadrant (the caller flushes at four).
        """
        N, M = self.doc_image_processor.size[0], self.doc_image_processor.size[1]

        merged_img = torch.zeros((3, 2 * N, 2 * M), dtype=image_list[0].dtype)

        top, bottom = slice(None, N), slice(N, None)
        left, right = slice(None, M), slice(M, None)
        quadrants = [(top, left), (top, right), (bottom, left), (bottom, right)]
        for i, image in enumerate(image_list):
            rows, cols = quadrants[min(i, 3)]
            merged_img[:, rows, cols] = image

        return merged_img


@dataclass
class DataCollatorForSupervisedDataset(object):
    """Collate document examples into one padded batch for embedding extraction.

    All sections of all documents in the batch are flattened into a single
    list (documents in batch order) and padded at the end with the
    tokenizer's pad token. ``doc_num_sections`` records how many rows belong
    to each document.
    """

    tokenizer: transformers.PreTrainedTokenizer

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        # One row per section across the whole batch.
        doc_input_ids = [item for instance in instances for item in instance["doc_input_ids"]]
        doc_input_ids = torch.nn.utils.rnn.pad_sequence(
            doc_input_ids,
            batch_first=True,
            padding_value=self.tokenizer.pad_token_id)

        data_idx = torch.tensor([instance['data_idx'] for instance in instances], dtype=torch.int64)
        doc_num_sections = torch.tensor([instance['doc_num_section'] for instance in instances], dtype=torch.int64)

        batch = dict(
            doc_input_ids=doc_input_ids,
            # NOTE(review): any real token equal to pad_token_id is masked out too.
            doc_attention_mask=doc_input_ids.ne(self.tokenizer.pad_token_id),
            doc_num_sections=doc_num_sections,
            data_idx=data_idx,
        )

        # The dataset omits 'doc_image' entirely for image-less documents when
        # data_args.is_multimodal is False, so look it up with .get().
        doc_images = [instance.get('doc_image') for instance in instances]
        doc_images = [image for image in doc_images if image is not None]  # There could be no-images documents
        if doc_images:
            batch['doc_images'] = torch.cat(doc_images)  # Different amounts of image per document.

        return batch

# Dataloader
def create_data_loader(tokenizer, image_processor, data_args, num_workers=8):
    """Build an ordered DataLoader over this process's shard of the document KB.

    Shuffling is off and the last partial batch is kept, so every document in
    the shard is visited exactly once, in KB order.
    """
    kb_dataset = OpenWikiTableRetDataset(
        document_path=data_args.document_path,
        tokenizer=tokenizer,
        image_processor=image_processor,
        data_args=data_args,
        use_image=data_args.doc_use_image,
        use_table=data_args.doc_use_table,
    )
    loader_options = {
        'batch_size': data_args.batch_size,
        'num_workers': num_workers,
        'shuffle': False,
        'drop_last': False,
        'collate_fn': DataCollatorForSupervisedDataset(tokenizer=tokenizer),
    }
    return DataLoader(kb_dataset, **loader_options)


def extract_doc_embeds(args):
    """Encode every document in this process's KB chunk and save the results.

    Loads the retriever from ``args.model_path``, runs it over the chunk of the
    document KB selected by ``args.num_chunks`` / ``args.chunk_idx`` and writes
    two files to ``args.save_path``:

    * ``inter_doc_embed_{chunk_idx}.npy`` -- float16 array with one row per
      document in this chunk (NaNs replaced via ``np.nan_to_num``);
    * ``inter_doc_mapping_{chunk_idx}.json`` -- maps each document key to a
      one-element list holding its row index in the concatenation of all
      chunks (in chunk order).

    Args:
        args: parsed command-line namespace (see the ``__main__`` block).
    """
    # Model
    disable_torch_init()
    model_path = os.path.expanduser(args.model_path)
    model_name = get_model_name_from_path(model_path)
    tokenizer, model, image_processor, _ = load_retriever_model(model_path, args.model_base, model_name, args.doc_model_init)
    # Turn off the contrastive loss calculation of the model
    if hasattr(model.config, "inter_contrastive"):
        model.config.inter_contrastive = False
    if hasattr(model.config, "intra_contrastive"):
        model.config.intra_contrastive = False

    conversation_lib.default_conversation = conversation_lib.conv_templates[args.conv]

    # KB dataset
    data_loader = create_data_loader(tokenizer=tokenizer, image_processor=image_processor, data_args=args)
    dataset = data_loader.dataset

    # Initialize embeddings
    inter_doc_chunk_size = len(dataset)
    embed_size = model.config.hidden_size
    inter_doc_emb_array = np.zeros((inter_doc_chunk_size, embed_size), dtype='float16')

    # Global row offset of this chunk. Every chunk but the last holds
    # ceil(total / num_chunks) documents (see split_list); using len(dataset)
    # here would give a short last chunk indices that collide with earlier ones.
    nominal_chunk_size = math.ceil(dataset.total_dataset_size / args.num_chunks)
    chunk_offset = nominal_chunk_size * args.chunk_idx

    # inter-document id mapping function (will be used in retrieval)
    inter_doc_mapping = {
        doc_id: [idx + chunk_offset]
        for idx, doc_id in enumerate(dataset.document_kb_keys)
    }

    for batch in tqdm(data_loader, total=len(data_loader)):
        # Images go to fp16 on the GPU; other tensors keep their dtype.
        input_batch = {k: v.to(device='cuda', non_blocking=True) if k != "doc_images" else v.to(device='cuda', dtype=torch.float16, non_blocking=True)
                       for k, v in batch.items() if k != 'data_idx'}
        with torch.inference_mode():
            outputs = model(
                **input_batch,
                use_cache=False,
                output_hidden_states=True,
            )
            # .float() first: numpy has no bfloat16; the fp32 values are cast
            # to the float16 array on assignment (exact for fp16 features).
            inter_doc_emb_array[batch['data_idx'].numpy()] = outputs['inter_doc_feature'].detach().float().cpu().numpy()

    inter_doc_emb_array = np.nan_to_num(inter_doc_emb_array)

    os.makedirs(args.save_path, exist_ok=True)

    # Save embeddings
    inter_doc_embed_name = f'inter_doc_embed_{args.chunk_idx}.npy'
    np.save(os.path.join(args.save_path, inter_doc_embed_name), inter_doc_emb_array)

    # Save mapping functions
    inter_doc_mapping_name = f'inter_doc_mapping_{args.chunk_idx}.json'
    with open(os.path.join(args.save_path, inter_doc_mapping_name), 'w') as f:
        json.dump(inter_doc_mapping, f)

    print("Done!")

if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description="Extract inter-document embeddings for one chunk of a document KB.")
    parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
    parser.add_argument("--model-base", type=str, default=None)
    # Output directory for inter_doc_embed_<chunk>.npy / inter_doc_mapping_<chunk>.json.
    parser.add_argument("--save_path", type=str, required=True)
    parser.add_argument("--conv", type=str, default="qwen_1_5")
    # Document KB JSON; images are read from the 'images' directory next to it.
    parser.add_argument("--document_path", type=str, required=True)
    parser.add_argument("--image_url_to_id_path", type=str, default=None)
    parser.add_argument("--doc_use_image", action='store_true')
    parser.add_argument("--doc_use_table", action='store_true')
    parser.add_argument("--is_multimodal", action='store_true')
    parser.add_argument("--image_aspect_ratio", type=str, default="pad")
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--num-chunks", type=int, default=1)
    parser.add_argument("--chunk-idx", type=int, default=0)
    parser.add_argument("--doc_model_init", action='store_true', default=False)

    args = parser.parse_args()

    # Fail fast with a clear CLI error instead of deep inside data loading.
    if args.doc_use_image and args.image_url_to_id_path is None:
        parser.error("--image_url_to_id_path is required when --doc_use_image is set")
    if not 0 <= args.chunk_idx < args.num_chunks:
        parser.error("--chunk-idx must satisfy 0 <= chunk-idx < num-chunks")

    extract_doc_embeds(args)

