import torch
import random
from datasets import load_dataset
from torch.utils.data import Dataset
from torch.nn.utils.rnn import pad_sequence
from icecream import ic as pprint
import json


# Prompt instructions, keyed by the ``instruction_name`` passed to InferDataset.
# NOTE(review): 'base' and 'short' currently share identical text — confirm
# whether 'short' was meant to ask for a different answer style.
_ANSWER_ONLY_INSTRUCTION = (
    'Answer the Question based on the given text. '
    'Only give me the answer and do not output any other words.'
)
instructions_map = {
    'base': _ANSWER_ONLY_INSTRUCTION,
    'short': _ANSWER_ONLY_INSTRUCTION,
}

def format_document(document, tokenizer, max_tokens=None):
    """Render a retrieved document as text, optionally truncated to a token budget.

    The text is ``"<title> <text>"`` when the document has a title, otherwise
    just ``"<text>"``. It is tokenized without special tokens and decoded
    back, so the result is the tokenizer's own round-tripped form of the text.

    Args:
        document: Mapping with a ``'text'`` key and an optional ``'title'`` key.
            A ``'title'`` whose value is ``None`` is treated as absent.
        tokenizer: Callable as ``tokenizer(text, add_special_tokens=False)``
            returning a mapping with ``'input_ids'``, and exposing
            ``decode(ids)`` (presumably a Hugging Face tokenizer).
        max_tokens: Keep at most this many leading tokens; ``None`` keeps all.

    Returns:
        The decoded, possibly truncated, document text.
    """
    # Schema-unified datasets (e.g. JSON loaded via ``datasets``) may fill a
    # missing title with None instead of dropping the key; concatenating None
    # would raise TypeError, so treat it like a missing title.
    has_title = 'title' in document and document['title'] is not None
    text = document['title'] + ' ' + document['text'] if has_title else document['text']

    token_ids = tokenizer(text, add_special_tokens=False)['input_ids']
    # Slicing with None keeps every token, so one path serves both cases.
    return tokenizer.decode(token_ids[:max_tokens])

def trunc_text(text, tokenizer, max_tokens=None):
    """Round-trip *text* through *tokenizer*, keeping at most *max_tokens* tokens.

    Args:
        text: The string to truncate.
        tokenizer: Callable as ``tokenizer(text, add_special_tokens=False)``
            returning a mapping with ``'input_ids'``, and exposing ``decode``.
        max_tokens: Upper bound on kept tokens; ``None`` keeps every token.

    Returns:
        The decoded text of the kept tokens.
    """
    encoded = tokenizer(text, add_special_tokens=False)
    kept_ids = encoded['input_ids'][:max_tokens]  # [:None] is a full slice
    return tokenizer.decode(kept_ids)

class InferDataset(Dataset):
    """Map-style dataset that turns JSON QA records into answer-only prompts.

    Each record is read for ``'question'``, ``'ctxs'`` (a list of documents
    rendered with ``format_document``) and ``'answers'``.
    """

    def __init__(
        self,
        filepath,
        model,
        tokenizer,
        max_doc_tokens,
        instruction_name,
        max_num_documents=None,
        min_num_documents=None,
        random_num_documents=False,
        **kwargs,
    ):
        """Load records from the JSON file(s) at *filepath* and pick the instruction.

        Raises:
            KeyError: if *instruction_name* is not a key of ``instructions_map``.
        """
        self.dataset = load_dataset('json', data_files=filepath, split='train')

        self.tokenizer = tokenizer
        self.model = model
        # NOTE(review): the document-budget options below (and any extra
        # **kwargs) are stored/accepted but not read by any method in this
        # class — documents are rendered untruncated and all ctxs are used.
        # Confirm whether that is intentional (e.g. signature shared with a
        # sibling dataset) or a missing feature.
        self.max_doc_tokens = max_doc_tokens
        self.max_num_documents = max_num_documents
        self.min_num_documents = min_num_documents
        self.random_num_documents = random_num_documents

        self.instruction_text = instructions_map[instruction_name]

    def __len__(self):
        """Return the number of loaded records."""
        return len(self.dataset)

    def __getitem__(self, index):
        """Build the prompt for record *index*.

        Returns:
            dict with ``'question'`` and ``'answers'`` copied from the record and
            ``'prompt'``: instruction, newline-joined documents, then the question.
        """
        record = self.dataset[index]
        question = record['question']

        # Every context document is rendered in full, one per line.
        context = '\n'.join(
            [format_document(ctx, self.tokenizer) for ctx in record['ctxs']]
        )
        prompt = '\n\n'.join(
            [self.instruction_text, context, 'Question: ' + question + '\nAnswer:\n']
        )

        return {
            'question': question,
            'prompt': prompt,
            'answers': record['answers'],
        }

    def collate_fn(self, batch):
        """Transpose a list of items into a dict of parallel lists.

        An empty batch yields ``{}``; otherwise the keys are ``'question'``,
        ``'prompt'`` and ``'answers'``, each mapping to a list in batch order.
        """
        if len(batch) == 0:
            return {}

        return {
            field: [item[field] for item in batch]
            for field in ('question', 'prompt', 'answers')
        }