import torch
import random
from datasets import load_dataset
from torch.utils.data import Dataset
from torch.nn.utils.rnn import pad_sequence
from icecream import ic as pprint
import json

# Instruction text placed at the start of every prompt, keyed by instruction
# name. 'base', 'short' and 'original' share the context-grounded wording;
# 'closed_book' asks the model to answer without referring to a given text.
_GROUNDED_QA_INSTRUCTION = (
    'Answer the Question based on the given text. '
    'Only give me the answer and do not output any other words.'
)
_CLOSED_BOOK_QA_INSTRUCTION = (
    'Answer the following question to the best of your ability. '
    'Only give me the answer and do not output any other words.'
)

instructions_map = dict(
    base=_GROUNDED_QA_INSTRUCTION,
    short=_GROUNDED_QA_INSTRUCTION,
    closed_book=_CLOSED_BOOK_QA_INSTRUCTION,
    original=_GROUNDED_QA_INSTRUCTION,
)

def format_document(document, tokenizer, max_tokens=None):
    """Render a document as plain text, optionally truncated to a token budget.

    When the document has a non-empty ``title``, it is prefixed to ``text``
    with a single space. The result is tokenized without special tokens, cut
    to the first ``max_tokens`` token ids and decoded back to a string.

    Args:
        document: Mapping with a ``'text'`` entry and an optional ``'title'``.
        tokenizer: Callable as ``tokenizer(text, add_special_tokens=False)``
            returning a mapping with ``'input_ids'``, and exposing
            ``decode(ids)``.
        max_tokens: Maximum number of tokens to keep; ``None`` keeps all.

    Returns:
        The decoded, possibly truncated, document string.
    """
    title = document['title'] if 'title' in document else None
    # A 'title' key may exist with a None or empty value (e.g. when a loader
    # unifies the schema across records). Concatenating None would raise, and
    # an empty title would only add a stray leading space, so both are skipped.
    content = title + ' ' + document['text'] if title else document['text']

    token_ids = tokenizer(content, add_special_tokens=False)['input_ids']
    # Slicing with max_tokens=None keeps the whole sequence, so a single code
    # path serves both the truncated and the untruncated case.
    return tokenizer.decode(token_ids[:max_tokens])

def trunc_text(text, tokenizer, max_tokens=None):
    """Shorten ``text`` to at most ``max_tokens`` tokens and return it as text.

    The text is tokenized without special tokens, the leading ``max_tokens``
    ids are kept (all of them when ``max_tokens`` is ``None``), and the kept
    ids are decoded back into a string.
    """
    encoded = tokenizer(text, add_special_tokens=False)
    kept_ids = encoded['input_ids'][slice(None, max_tokens)]
    return tokenizer.decode(kept_ids)

class InferDataset(Dataset):
    """Inference-time QA dataset that turns JSON records into text prompts.

    Records are read with ``load_dataset('json', ...)``. Each record must
    provide ``question`` and ``answers``. Unless ``baseline_type`` is
    ``'closed_book'``, it must also provide ``ctxs``: documents rendered with
    ``format_document`` and placed, one per line, before the question.

    Args:
        filepath: Passed as ``data_files`` to ``load_dataset('json', ...)``.
        model: Stored as ``self.model``; not used inside this class.
        tokenizer: Used to render each context document.
        max_doc_tokens: Stored as ``self.max_doc_tokens``. NOTE(review): it is
            not applied when rendering documents here; confirm whether
            per-document truncation was intended.
        instruction_name: Key into ``instructions_map`` that selects the
            instruction text placed at the start of every prompt.
        max_num_documents, min_num_documents, random_num_documents: Stored
            only; not read by this class.
        baseline_type: ``'closed_book'`` builds a question-only prompt; any
            other value includes every document from ``ctxs``.
        **kwargs: Accepted and ignored.

    Raises:
        KeyError: If ``instruction_name`` is not a key of ``instructions_map``.
    """

    def __init__(
        self,
        filepath,
        model,
        tokenizer,
        max_doc_tokens,
        instruction_name='base',
        max_num_documents=None,
        min_num_documents=None,
        random_num_documents=False,
        baseline_type='original',
        **kwargs,
    ):
        # Load every record from the JSON file(s) as the 'train' split.
        self.dataset = load_dataset('json', data_files=filepath, split='train')

        # Settings used while building prompts.
        self.tokenizer = tokenizer
        self.baseline_type = baseline_type
        self.model = model

        # Document-selection settings, kept as attributes for callers.
        self.max_doc_tokens = max_doc_tokens
        self.max_num_documents = max_num_documents
        self.min_num_documents = min_num_documents
        self.random_num_documents = random_num_documents

        # Resolved after loading, so an unknown name raises KeyError here.
        self.instruction_text = instructions_map[instruction_name]

    def __len__(self):
        """Return the number of loaded records."""
        return len(self.dataset)

    def __getitem__(self, index):
        """Build the prompt for record ``index``.

        Returns:
            A dict with keys ``'question'``, ``'prompt'`` and ``'answers'``.
        """
        record = self.dataset[index]
        question = record['question']

        if self.baseline_type != 'closed_book':
            # Open-book: every document in 'ctxs', one per line, comes
            # between the instruction and the question.
            rendered = [format_document(doc, self.tokenizer) for doc in record['ctxs']]
            header = self.instruction_text + "\n\n" + '\n'.join(rendered) + "\n"
        else:
            # Closed-book: the question follows the instruction directly.
            header = self.instruction_text + "\n\n"

        prompt = header + "Question: " + question + "\nAnswer:"

        return {
            'question': question,
            'prompt': prompt,
            'answers': record['answers'],
        }

    def collate_fn(self, batch):
        """Collate items produced by ``__getitem__`` into a dict of lists.

        Returns ``{}`` for an empty batch. Otherwise returns a dict with keys
        ``'question'``, ``'prompt'`` and ``'answers'``, each mapping to a list
        in batch order. No tensors are created.
        """
        if len(batch) == 0:
            return {}

        fields = ('question', 'prompt', 'answers')
        return {field: [item[field] for item in batch] for field in fields}