import torch
import random
from datasets import load_dataset
from torch.utils.data import Dataset
from torch.nn.utils.rnn import pad_sequence
from icecream import ic as pprint
import json

# Prompt templates keyed by a short name. TrainDataset looks entries up with
# both `instruction_name` and `prefix_type`; InferDataset with `instruction_name`.
# ins, target, no question, position of instruction, info
instructions_map = {
    'base': 'Write a high-quality answer for the given question using only the provided search results(some of which might be irrelevant).\n',
    'short': 'Answer the Question based on the given Text. Only give me the answer and do not output any other words.\nText:',
    'summary': 'Summarize the passage according to the question provided.\n\n',
    'restatement':'Restate the aforementioned Text.\n\n',  # restatement
    'icae':'Write a summary for the above text. Your summary should not exceed 100 words but should include as much information of the original text as possible.\n\n',
    'ft_prefix': 'Text:',
    'rs_prefix': 'Text:',
    'pwc': 'Answer the Question based on the given Text. \nText:'
}

def format_document(
    document,
    tokenizer,
    max_tokens=None
):
    """Round-trip ``document`` through ``tokenizer``, optionally truncating it.

    The text is tokenized without special tokens, cut to the first
    ``max_tokens`` token ids when a limit is given, and decoded back to text.

    Args:
        document: Text to normalise / truncate.
        tokenizer: Callable returning a mapping with an ``'input_ids'`` entry
            and exposing ``decode``.
        max_tokens: Maximum number of token ids to keep; ``None`` keeps all.

    Returns:
        The decoded (possibly truncated) text.
    """
    token_ids = tokenizer(document, add_special_tokens=False)['input_ids']
    if max_tokens is not None:
        token_ids = token_ids[:max_tokens]
    return tokenizer.decode(token_ids)

class TrainDataset(Dataset):
    """Training examples (question, document, answer) read from JSON, plus batch collation.

    Each item is a dict of raw strings; all tokenization happens in
    ``collate_fn``, which is meant to be passed to a ``DataLoader``.
    Two tokenizers are used: ``cmp_tokenizer`` for the encoder-side document
    ids and ``llm_tokenizer`` for every prompt / target fed to the LLM.
    """

    def __init__(
        self,
        filepath,
        model,
        max_doc_tokens,
        instruction_name,
        lm_ratio=0,
        leave_len=0,
        pad_token_id=0,
        **kwargs,
    ):
        """Load the JSON data and configure the tokenizers.

        Args:
            filepath: Passed to ``load_dataset('json', data_files=...)``.
            model: Stored on the instance as ``self.model``; not used here.
            max_doc_tokens: Token limit applied to each document in
                ``__getitem__`` (``None`` disables truncation).
            instruction_name: Key into ``instructions_map`` for the
                instruction text.
            lm_ratio: Probability that a document is split into a
                (prefix, continuation) pair instead of being used as its own
                target; ``0`` disables the splitting path entirely.
            leave_len: Minimum number of tokens kept for the continuation
                when a document is split.
            pad_token_id: Fallback padding id, used only when
                ``cmp_tokenizer.pad_token_id`` is ``None``.
            **kwargs: Must contain ``cmp_tokenizer``, ``llm_tokenizer`` and
                ``prefix_type`` (a key into ``instructions_map``).

        Raises:
            KeyError: If a required kwarg or an ``instructions_map`` key is
                missing.
        """
        self.dataset = load_dataset('json', data_files=filepath, split='train')
        self.max_doc_tokens = max_doc_tokens
        self.model = model
        self.cmp_tokenizer = kwargs['cmp_tokenizer']
        self.llm_tokenizer = kwargs['llm_tokenizer']
        self.lm_ratio = lm_ratio
        self.leave_len = leave_len
        # Prefer the tokenizer's own pad id; fall back to the argument so that
        # padding never receives ``None`` as its fill value.
        tokenizer_pad_id = self.cmp_tokenizer.pad_token_id
        self.pad_token_id = pad_token_id if tokenizer_pad_id is None else tokenizer_pad_id
        self.prefix_type = kwargs['prefix_type']

        # NOTE: this mutates the tokenizer object that was passed in.
        self.llm_tokenizer.padding_side = 'left'
        if self.llm_tokenizer.pad_token is None:
            self.llm_tokenizer.pad_token = self.llm_tokenizer.unk_token
            self.llm_tokenizer.pad_token_id = self.llm_tokenizer.unk_token_id

        self.instruction_text = instructions_map[instruction_name]
        self.prefix_text = instructions_map[self.prefix_type]

    def __len__(self):
        """Return the number of examples."""
        return len(self.dataset)

    def __getitem__(self, index):
        """Return the raw strings of example ``index``.

        Reads the ``prompt``, ``input`` and ``answer`` fields; the document
        (``input``) is truncated to ``max_doc_tokens`` cmp-tokenizer tokens.
        """
        example = self.dataset[index]
        return {
            'question': example['prompt'],
            'document': format_document(example['input'], self.cmp_tokenizer, self.max_doc_tokens),
            'answer': example['answer'],
        }

    def random_choice(self, l):
        """Return one uniformly chosen element of the sequence ``l``."""
        return random.choice(l)

    def text_extraction(
        self,
        input_ids,  # dim : [seq_len]
    ):
        """Pick the (encoder input, target) pair for one tokenized document.

        With probability ``1 - lm_ratio`` the document is its own target.
        Otherwise it is cut at a random position: the encoder sees the prefix
        and the target is the remainder, which keeps at least ``leave_len``
        tokens.

        Documents with ``leave_len`` tokens or fewer cannot be split that way
        (the original ``randint`` call raised ``ValueError`` on them); they
        fall back to the document-as-target pair.

        Returns:
            Tuple ``(enc_ids, target_ids)``.
        """
        input_len = len(input_ids)
        # ``random.random()`` is evaluated first so the random stream matches
        # the previous implementation for every document that can be split.
        if random.random() >= self.lm_ratio or input_len <= self.leave_len:
            return input_ids, input_ids

        split_at = random.randint(0, input_len - self.leave_len - 1)
        return input_ids[: split_at + 1], input_ids[split_at + 1:]

    def dynamic_padding(
        self,
        sequences,
        fill_value=-100
    ):
        """Right-pad ``sequences`` to a common length and build their mask.

        The mask is derived from the sequence lengths, not by comparing
        against ``fill_value``, so real tokens whose id equals the padding
        value remain attended.

        Args:
            sequences: List/tuple of tensors with the sequence dim first.
            fill_value: Value written into padded positions.

        Returns:
            ``(padded_sequences, padded_attention_mask)`` where the mask is a
            long tensor with 1 on real positions and 0 on padding.
        """
        padded_sequences = pad_sequence(sequences, batch_first=True, padding_value=fill_value)
        padded_attention_mask = pad_sequence(
            [torch.ones_like(seq, dtype=torch.long) for seq in sequences],
            batch_first=True,
            padding_value=0,
        )
        return padded_sequences, padded_attention_mask

    def process_batch(self, batch):
        """Tokenize documents and split each one via ``text_extraction``.

        Returns:
            ``(enc_doc_ids, enc_doc_mask, target_doc_ids, target_doc_mask)``,
            all padded with ``self.pad_token_id``.
        """
        # NOTE(review): unlike the lm_ratio == 0 path in collate_fn, this call
        # leaves add_special_tokens at the tokenizer default, and the targets
        # are cmp-tokenizer ids — confirm that both are intended.
        batch_input_ids = [
            self.cmp_tokenizer(instance['document'], return_tensors='pt').input_ids[0]
            for instance in batch
        ]

        enc_doc_ids_list, target_doc_ids_list = zip(
            *[self.text_extraction(input_ids) for input_ids in batch_input_ids]
        )

        enc_doc_ids, enc_doc_mask = self.dynamic_padding(enc_doc_ids_list, self.pad_token_id)
        target_doc_ids, target_doc_mask = self.dynamic_padding(target_doc_ids_list, self.pad_token_id)

        return enc_doc_ids, enc_doc_mask, target_doc_ids, target_doc_mask

    def collate_fn(self, batch):
        """Tokenize a list of ``__getitem__`` dicts into padded tensors.

        Returns:
            ``{}`` for an empty batch, otherwise a dict with ``*_ids`` /
            ``*_mask`` tensor pairs for: the encoder question (``enc_que``),
            encoder document (``enc_doc``), target document (``target_doc``),
            instruction (``llm_ins``), prefix (``enc_prefix``), restate prompt
            (``enc_repeat``), continue prompt (``enc_continue``), LLM question
            prompt (``llm_que``) and answer (``llm_answer``).
        """
        if not batch:
            return {}

        llm_kwargs = dict(return_tensors='pt', padding=True)

        if self.lm_ratio != 0:
            enc_doc_ids, enc_doc_mask, target_doc_ids, target_doc_mask = self.process_batch(batch)
        else:
            enc_batch_documents = [instance['document'] for instance in batch]
            eos_token = self.llm_tokenizer.decode([self.llm_tokenizer.eos_token_id])
            # The target is the full document followed by the LLM's EOS text.
            target_batch_documents = [document + eos_token for document in enc_batch_documents]

            tokenize_docs = self.cmp_tokenizer(
                enc_batch_documents, add_special_tokens=False, **llm_kwargs
            )
            enc_doc_ids = tokenize_docs.input_ids
            enc_doc_mask = tokenize_docs.attention_mask

            target_tokenize_docs = self.llm_tokenizer(
                target_batch_documents, add_special_tokens=False, **llm_kwargs
            )
            target_doc_ids = target_tokenize_docs.input_ids
            target_doc_mask = target_tokenize_docs.attention_mask

        batch_size = len(batch)
        enc_questions = [instance['question'] for instance in batch]
        llm_questions = ['Question:' + question + '\nAnswer:' for question in enc_questions]

        # NOTE(review): assumes instance['answer'] is a list of candidate
        # answers; a plain string would yield one random character. The EOS
        # marker is hard-coded as "</s>" here, whereas the document target
        # above uses the LLM tokenizer's own EOS — confirm both.
        llm_answer = [self.random_choice(instance['answer']) + "</s>" for instance in batch]

        llm_instructions = [self.instruction_text] * batch_size
        llm_prefix_tokens = [self.prefix_text] * batch_size
        repeat_tokens = ['Restate the aforementioned Text.'] * batch_size
        continue_tokens = ['Continue writing the aforementioned Text.'] * batch_size

        enc_continue_outputs = self.llm_tokenizer(continue_tokens, add_special_tokens=False, **llm_kwargs)
        # NOTE(review): the question ids under 'enc_que_*' come from the LLM
        # tokenizer here — confirm this is intended.
        enc_questions_outputs = self.llm_tokenizer(enc_questions, add_special_tokens=False, **llm_kwargs)
        # Prefix and instruction keep the tokenizer's default special tokens.
        enc_prefix_outputs = self.llm_tokenizer(llm_prefix_tokens, **llm_kwargs)
        enc_repeat_outputs = self.llm_tokenizer(repeat_tokens, add_special_tokens=False, **llm_kwargs)
        llm_ins_outputs = self.llm_tokenizer(llm_instructions, **llm_kwargs)
        llm_que_outputs = self.llm_tokenizer(llm_questions, add_special_tokens=False, **llm_kwargs)
        llm_answer_outputs = self.llm_tokenizer(llm_answer, add_special_tokens=False, **llm_kwargs)

        return {
            'enc_que_ids': enc_questions_outputs.input_ids,
            'enc_que_mask': enc_questions_outputs.attention_mask,
            'enc_doc_ids': enc_doc_ids,
            'enc_doc_mask': enc_doc_mask,
            'target_doc_ids': target_doc_ids,
            'target_doc_mask': target_doc_mask,
            'llm_ins_ids': llm_ins_outputs.input_ids,
            'llm_ins_mask': llm_ins_outputs.attention_mask,
            'enc_prefix_ids': enc_prefix_outputs.input_ids,
            'enc_prefix_mask': enc_prefix_outputs.attention_mask,
            'enc_repeat_ids': enc_repeat_outputs.input_ids,
            'enc_repeat_mask': enc_repeat_outputs.attention_mask,
            'llm_answer_ids': llm_answer_outputs.input_ids,
            'llm_answer_mask': llm_answer_outputs.attention_mask,
            'llm_que_ids': llm_que_outputs.input_ids,
            'llm_que_mask': llm_que_outputs.attention_mask,
            'enc_continue_ids': enc_continue_outputs.input_ids,
            'enc_continue_mask': enc_continue_outputs.attention_mask
        }
    

class InferDataset(Dataset):
    """Inference examples read from JSON, with collation into padded tensors.

    Items carry an empty question, the (optionally truncated) ``input`` text
    as the document, and the untruncated ``input`` text as the answer.
    """

    def __init__(
        self,
        filepath,
        cmp_tokenizer,
        llm_tokenizer,
        max_doc_tokens,
        instruction_name,
        max_num_documents=None,
        **kwargs,
    ):
        """Load the JSON data and prepare the tokenizers.

        Args:
            filepath: Passed to ``load_dataset('json', data_files=...)``.
            cmp_tokenizer: Tokenizer for the encoder-side document/question.
            llm_tokenizer: Tokenizer for LLM-side prompts; switched to left
                padding, with the unk token as pad token if none is set.
            max_doc_tokens: Token limit applied to each document.
            instruction_name: Key into ``instructions_map``.
            max_num_documents: Accepted but not used.
            **kwargs: Accepted but not used.
        """
        self.dataset = load_dataset('json', data_files=filepath, split='train')
        self.max_doc_tokens = max_doc_tokens
        self.cmp_tokenizer = cmp_tokenizer
        self.llm_tokenizer = llm_tokenizer

        tok = self.llm_tokenizer
        tok.padding_side = 'left'
        if tok.pad_token is None:
            tok.pad_token = llm_tokenizer.unk_token
            tok.pad_token_id = llm_tokenizer.unk_token_id

        self.instruction_text = instructions_map[instruction_name]

    def __len__(self):
        """Return the number of examples."""
        return len(self.dataset)

    def __getitem__(self, index):
        """Return the raw strings of example ``index``."""
        record = self.dataset[index]
        source_text = record['input']
        truncated = format_document(source_text, self.cmp_tokenizer, self.max_doc_tokens)
        return {
            # The question is always left empty at inference time.
            'question': '',
            'document': truncated,
            # The reference output is the untruncated source text.
            'answer': source_text,
        }

    def collate_fn(self, batch):
        """Tokenize a list of ``__getitem__`` dicts.

        Returns:
            ``{}`` for an empty batch, otherwise a dict of tokenizer outputs
            (ids and masks) plus the raw ``answers`` list.
        """
        if len(batch) == 0:
            return {}

        batch_size = len(batch)
        documents = [item['document'] for item in batch]
        questions = [item['question'] for item in batch]
        question_prompts = ['\nQuestion:' + item['question'] + '\nAnswer:' for item in batch]
        references = [item['answer'] for item in batch]

        # Shared tokenizer settings: padded tensors without special tokens.
        plain = {'return_tensors': 'pt', 'padding': True, 'add_special_tokens': False}

        prefix_enc = self.llm_tokenizer(['Text:'] * batch_size, **plain)
        repeat_enc = self.llm_tokenizer(
            ['Restate the aforementioned Text.'] * batch_size, **plain
        )
        continue_enc = self.llm_tokenizer(
            ['Continue writing the aforementioned Text.'] * batch_size, **plain
        )

        question_enc = self.cmp_tokenizer(questions, **plain)
        document_enc = self.cmp_tokenizer(documents, **plain)
        # The instruction keeps the tokenizer's default special tokens.
        instruction_enc = self.llm_tokenizer(
            [self.instruction_text] * batch_size, return_tensors='pt', padding=True
        )
        prompt_enc = self.llm_tokenizer(question_prompts, **plain)

        return {
            'enc_doc_tokens': document_enc.input_ids,
            'enc_que_tokens': question_enc.input_ids,
            'enc_doc_mask': document_enc.attention_mask,
            'enc_que_mask': question_enc.attention_mask,
            'llm_ins_tokens': instruction_enc.input_ids,
            'llm_que_tokens': prompt_enc.input_ids,
            'llm_ins_mask': instruction_enc.attention_mask,
            'llm_que_mask': prompt_enc.attention_mask,
            'answers': references,
            'enc_prefix_ids': prefix_enc.input_ids,
            'enc_prefix_mask': prefix_enc.attention_mask,
            'enc_repeat_ids': repeat_enc.input_ids,
            'enc_repeat_mask': repeat_enc.attention_mask,
            'enc_continue_ids': continue_enc.input_ids,
            'enc_continue_mask': continue_enc.attention_mask
        }