import argparse
import os
import json
import copy
import numpy as np
import pandas as pd
from tqdm import tqdm

from torch.utils.data import Dataset, DataLoader
import transformers

from llava.model.builder import load_retriever_model
from llava.utils import disable_torch_init
from llava.mm_utils import get_model_name_from_path, expand2square
from llava.constants import *
from llava.data.process import *


from PIL import Image
import math

def split_list(lst, n):
    """Split a sliceable sequence into at most ``n`` contiguous chunks.

    Every chunk except possibly the last holds ``ceil(len(lst) / n)`` items,
    so fewer than ``n`` chunks are returned when ``len(lst)`` is small
    relative to ``n``.

    Args:
        lst: Any object supporting ``len()`` and slice indexing.
        n: Requested number of chunks; must be positive.

    Returns:
        A list of slices of ``lst``; empty when ``lst`` is empty.

    Raises:
        ValueError: If ``n`` is not a positive integer.
    """
    if n <= 0:
        raise ValueError(f"n must be a positive integer, got {n}")
    # Ceiling division; clamp to 1 so an empty input does not hand range()
    # a zero step (which raises ValueError).
    chunk_size = max(1, math.ceil(len(lst) / n))
    return [lst[i:i + chunk_size] for i in range(0, len(lst), chunk_size)]


def get_chunk(lst, n, k):
    """Return the chunk at position ``k`` after splitting ``lst`` via ``split_list(lst, n)``."""
    return split_list(lst, n)[k]

class ViquaeQueryDataset(Dataset):
    """Queries loaded from a CSV file, restricted to one chunk of the rows.

    The CSV at ``query_path`` is read with ``pandas.read_csv`` and split with
    ``get_chunk`` using ``data_args.num_chunks`` / ``data_args.chunk_idx`` so
    that several processes can each handle one slice.

    Columns read per row: ``evidence_section_id`` always; ``question`` when an
    image is attached, otherwise ``question_original``; ``dataset_image_ids``
    (optional) as the image file name relative to ``data_args.image_folder``.

    Args:
        query_path: Path of the query CSV file.
        tokenizer: Tokenizer forwarded to ``preprocess``.
        data_args: Namespace providing ``num_chunks``, ``chunk_idx``,
            ``image_folder``, ``image_aspect_ratio`` and ``is_multimodal``.
        image_processor: Object exposing ``image_mean`` and ``preprocess``.
        use_image: When false, images are never loaded even if the CSV has an
            image column (text-only baselines).
    """

    def __init__(self, query_path: str, tokenizer: transformers.PreTrainedTokenizer,
                 data_args, image_processor, use_image=True):
        super().__init__()
        # Load query dataset
        query_df = pd.read_csv(query_path)
        # Size of the full CSV, kept so callers can see the pre-chunking row count.
        self.total_dataset_size = len(query_df)
        # Only this process' slice of the rows is kept (multi-GPU sharding).
        self.query_df = get_chunk(query_df, data_args.num_chunks, data_args.chunk_idx)

        self.tokenizer = tokenizer
        self.data_args = data_args
        # Image processor
        self.query_image_processor = image_processor
        # For other baselines that do not use images in the document.
        self.use_image = use_image

    def __len__(self):
        """Return the number of queries in this chunk."""
        return len(self.query_df)

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        """Build the model inputs for the ``i``-th query of this chunk.

        Returns:
            A dict with ``query_input_ids``, ``query_evidence_section_id`` and
            ``data_idx`` (the chunk-local index ``i``). ``query_image`` holds
            the processed pixel values when an image was loaded, ``None`` when
            no image was loaded but ``data_args.is_multimodal`` is set, and is
            absent otherwise.
        """
        query = self.query_df.iloc[i]

        # Generate image + text query pair
        if self.use_image and 'dataset_image_ids' in query:
            image_path = os.path.join(self.data_args.image_folder, query['dataset_image_ids'])
            # convert() returns an independent, fully loaded image, so the
            # underlying file can be closed deterministically right away.
            with Image.open(image_path) as raw_image:
                query_image = raw_image.convert('RGBA')
            if self.data_args.image_aspect_ratio == 'pad':
                # Pad with the processor's mean colour scaled to 0-255.
                pad_colour = tuple(int(x * 255) for x in self.query_image_processor.image_mean)
                query_image = expand2square(query_image, pad_colour)
            query_image = self.query_image_processor.preprocess(
                query_image, return_tensors='pt')['pixel_values'][0]
            query_has_image = True
        else:
            query_has_image = False

        # The image-bearing prompt comes from 'question'; text-only queries
        # use 'question_original'.
        question = copy.deepcopy(query['question']) if query_has_image else query['question_original']
        sources = preprocess_query(question, is_multimodal=query_has_image)
        query_dict = preprocess(
            sources,
            self.tokenizer,
            has_image=query_has_image,
            input_type='query',
        )

        data_dict = dict(query_input_ids=query_dict["input_ids"][0],
                         query_evidence_section_id=query['evidence_section_id'],
                         data_idx=i,
                         )

        # Attach the image when one was loaded; in multimodal mode a text-only
        # query gets an explicit None placeholder instead.
        if query_has_image:
            data_dict['query_image'] = query_image
        elif self.data_args.is_multimodal:
            data_dict['query_image'] = None

        return data_dict


@dataclass
class DataCollatorForSupervisedDataset(object):
    """Collate per-query samples into one padded batch.

    Each instance must provide ``query_input_ids`` (1-D tensor),
    ``query_evidence_section_id`` and ``data_idx``; ``query_image`` is
    optional and may be ``None``.

    The returned batch contains ``query_input_ids`` right-padded with the
    tokenizer's ``pad_token_id``, ``query_attention_mask`` (true wherever the
    id differs from the pad id), ``query_evidence_section_labels`` and
    ``data_idx`` as int64 tensors, and ``query_images`` only when at least
    one instance carries an image.
    """

    tokenizer: "transformers.PreTrainedTokenizer"

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        pad_token_id = self.tokenizer.pad_token_id
        query_input_ids = torch.nn.utils.rnn.pad_sequence(
            [instance["query_input_ids"] for instance in instances],
            batch_first=True,
            padding_value=pad_token_id)

        data_idx = torch.tensor([instance['data_idx'] for instance in instances], dtype=torch.int64)
        query_evidence_section_labels = torch.tensor(
            [instance['query_evidence_section_id'] for instance in instances], dtype=torch.int64)

        batch = dict(
            query_input_ids=query_input_ids,
            query_attention_mask=query_input_ids.ne(pad_token_id),
            query_evidence_section_labels=query_evidence_section_labels,
            data_idx=data_idx,
        )

        # Samples may omit 'query_image' entirely (text-only runs), so use
        # .get() rather than indexing, which would raise KeyError.
        query_images = [
            image
            for image in (instance.get('query_image') for instance in instances)
            if image is not None
        ]  # For the mixed modality query
        if query_images:
            batch['query_images'] = torch.stack(query_images)  # At most a single image per query.

        return batch

# Dataloader
def create_data_loader(tokenizer, image_processor, data_args, num_workers=8):
    """Build the sequential (unshuffled, nothing dropped) loader over the query chunk.

    Reads ``query_path``, ``query_use_image`` and ``batch_size`` from
    ``data_args``; the same ``data_args`` object is also handed to the dataset.
    """
    query_dataset = ViquaeQueryDataset(
        query_path=data_args.query_path,
        tokenizer=tokenizer,
        image_processor=image_processor,
        data_args=data_args,
        use_image=data_args.query_use_image,
    )
    collate_fn = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
    return DataLoader(
        query_dataset,
        batch_size=data_args.batch_size,
        num_workers=num_workers,
        shuffle=False,
        drop_last=False,
        collate_fn=collate_fn,
    )


def extract_query_embeds(args):
    """Embed one chunk of queries and save the result as a ``.npy`` file.

    Loads the retriever via ``load_retriever_model``, runs every batch of the
    chunk selected by ``args.num_chunks`` / ``args.chunk_idx`` through the
    model, and writes the ``query_feature`` outputs as float16 to
    ``<args.save_path>/query_embed_<chunk_idx>.npy``. Row ``j`` of the array
    corresponds to the ``j``-th query of the chunk.

    Args:
        args: Namespace with ``model_path``, ``model_base``, ``doc_model_init``,
            ``conv``, ``save_path``, ``chunk_idx`` plus everything
            ``create_data_loader`` and the dataset read from it.

    Returns:
        The saved embedding array, one row per query in this chunk.
    """
    # Create the output directory first so an unusable save_path fails before
    # the model is loaded and the whole chunk is embedded.
    os.makedirs(args.save_path, exist_ok=True)

    # Model
    disable_torch_init()
    model_path = os.path.expanduser(args.model_path)
    model_name = get_model_name_from_path(model_path)
    tokenizer, model, image_processor, context_len = load_retriever_model(
        model_path, args.model_base, model_name, args.doc_model_init)
    # Turn off the contrastive loss calculation of the model
    for flag in ("inter_contrastive", "intra_contrastive"):
        if hasattr(model.config, flag):
            setattr(model.config, flag, False)

    conversation_lib.default_conversation = conversation_lib.conv_templates[args.conv]

    # Test query dataset
    data_loader = create_data_loader(tokenizer=tokenizer, image_processor=image_processor, data_args=args)

    # Size the buffer to the queries actually in this chunk. The last chunk
    # can be shorter than ceil(total / num_chunks); sizing by that bound would
    # leave trailing all-zero rows that belong to no query.
    num_queries = len(data_loader.dataset)
    embed_size = model.config.hidden_size
    emb_array = np.zeros((num_queries, embed_size), dtype='float16')

    for batch in tqdm(data_loader, total=len(data_loader)):
        # data_idx stays on the CPU (it only addresses emb_array); images are
        # additionally cast to float16 while being moved to the GPU.
        input_batch = {}
        for key, value in batch.items():
            if key == 'data_idx':
                continue
            if key == "query_images":
                input_batch[key] = value.to(device='cuda', dtype=torch.float16, non_blocking=True)
            else:
                input_batch[key] = value.to(device='cuda', non_blocking=True)

        with torch.inference_mode():
            outputs = model(
                **input_batch,
                use_cache=False,
                output_hidden_states=True,
            )
        row_ids = batch['data_idx'].numpy()
        # .float() keeps the conversion working for dtypes NumPy cannot
        # represent (e.g. bfloat16); the assignment casts back to float16.
        emb_array[row_ids] = outputs['query_feature'].detach().float().cpu().numpy()

    emb_array = np.nan_to_num(emb_array)

    embed_name = f'query_embed_{args.chunk_idx}.npy'
    np.save(os.path.join(args.save_path, embed_name), emb_array)

    print("Done!")
    return emb_array

if __name__ == '__main__':
    # Command-line entry point: parse the options and embed one query chunk.
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
    parser.add_argument("--model-base", type=str, default=None)
    # Required: without it the run would only fail when creating the output
    # directory (os.makedirs(None) raises TypeError).
    parser.add_argument("--save_path", type=str, required=True)
    parser.add_argument("--conv", type=str, default="qwen_1_5")
    # Required: pandas.read_csv cannot read from None.
    parser.add_argument("--query_path", type=str, required=True)
    parser.add_argument("--query_use_image", action='store_true', default=False)
    # Only read when --query_use_image is set and the CSV has an image column.
    parser.add_argument("--image_folder", type=str, default=None)
    parser.add_argument("--is_multimodal", action='store_true', default=False)
    parser.add_argument("--image_aspect_ratio", type=str, default="pad")
    parser.add_argument("--batch_size", type=int, default=1)
    # argparse exposes the hyphenated options as args.num_chunks / args.chunk_idx.
    parser.add_argument("--num-chunks", type=int, default=1)
    parser.add_argument("--chunk-idx", type=int, default=0)
    parser.add_argument("--doc_model_init", action='store_true', default=False)


    args = parser.parse_args()

    query_embeds = extract_query_embeds(args)

