import torch
import random
from datasets import load_dataset
from torch.utils.data import Dataset
from torch.nn.utils.rnn import pad_sequence
from icecream import ic as pprint
import json

# Prompt fragments, looked up by name. The dataset classes in this module
# index this mapping with their ``instruction_name`` argument and with the
# ``prefix_type`` keyword argument.
instructions_map = dict(
    base=(
        'Write a high-quality answer for the given question using only '
        'the provided search results(some of which might be irrelevant).\n\n'
    ),
    short=(
        'Answer the Question based on the given text. '
        'Only give me the answer and do not output any other words.\nText:'
    ),
    repeat='Restate the aforementioned Text.',
    ft_prefix='Search results:',
    rs_prefix='Text:',
    pwc='Answer the Question based on the given Text. \nText:',
)

# Alternative prompt set for the code-completion setting (disabled; kept for reference):
# instructions_map = {
#     'base': 'Write a high-quality answer for the given question using only the provided search results(some of which might be irrelevant).\n',
#     # 'short': 'Answer the Question based on the given Text. Only give me the answer and do not output any other words.\nText:',
#     'short': 'Please complete the code given below. \n',
#     'summary': 'Summarize the passage according to the question provided.\n\n',
#     'restatement':'Restate the aforementioned Text.\n\n',  # restate / paraphrase
#     'icae':'Write a summary for the above text. Your summary should not exceed 100 words but should include as much information of the original text as possible.\n\n',
#     # 'ft_prefix': 'Text:',
#     # 'rs_prefix': 'Text:',
#     'ft_prefix': 'Please complete the code given below. \n',
#     'rs_prefix': 'Please complete the code given below. \n',
#     'pwc': 'Answer the Question based on the given Text. \nText:'
# }

def format_document(document, tokenizer, max_tokens=None):
    """Render a retrieved document as text, optionally truncated by tokens.

    The text is ``"<title> <text>"`` when the document carries a title and
    just ``<text>`` otherwise. It is encoded with ``tokenizer`` (without
    special tokens), cut to the first ``max_tokens`` ids, and decoded back to
    a string.

    Args:
        document: Mapping with a ``'text'`` entry and an optional ``'title'``
            entry.
        tokenizer: Callable returning a mapping with ``'input_ids'`` and
            offering ``decode``.
        max_tokens: Maximum number of token ids to keep; ``None`` keeps all.

    Returns:
        The decoded (possibly truncated) document string.
    """
    # A title key that is present but set to None (e.g. records where only
    # some documents have titles) is treated like a missing title instead of
    # failing on ``None + ' '``.
    title = document['title'] if 'title' in document else None
    text = document['text'] if title is None else title + ' ' + document['text']

    token_ids = tokenizer(text, add_special_tokens=False)['input_ids']
    # ``[:None]`` is the full slice, so a single code path covers both the
    # truncating and the non-truncating case.
    return tokenizer.decode(token_ids[:max_tokens])

def trunc_text(text, tokenizer, max_tokens=None):
    """Cut ``text`` down to its first ``max_tokens`` tokens.

    The string is encoded with ``tokenizer`` (no special tokens), the id list
    is sliced, and the kept ids are decoded back into a string. With
    ``max_tokens=None`` the slice keeps every id.
    """
    encoded = tokenizer(text, add_special_tokens=False)
    kept_ids = encoded['input_ids'][:max_tokens]
    return tokenizer.decode(kept_ids)

class TrainDataset(Dataset):
    """Training examples for a context compressor paired with an LLM.

    Every JSON record supplies a ``question`` and a list of ``ctxs``
    documents. ``collate_fn`` produces

    * compressor-tokenizer batches for the joined documents and the question,
    * LLM-tokenizer batches for the prefix and the "restate" prompt, and
    * one LLM sequence laid out as
      ``instruction | memory placeholders | question | target`` whose labels
      cover only the target. The target is the same truncated document text
      that is given to the compressor.
    """

    def __init__(
        self,
        filepath,
        model,
        max_doc_tokens,
        instruction_name,
        max_num_documents=None,
        min_num_documents=None,
        random_num_documents=False,
        **kwargs,
    ):
        """Load the JSON data and configure both tokenizers.

        Args:
            filepath: Passed to ``load_dataset('json', data_files=...)``.
            model: Object exposing ``tokenizer`` (compressor tokenizer),
                ``llm_tokenizer``, ``mem_size`` and
                ``compute_num_segments(length)``.
            max_doc_tokens: Token budget (compressor tokenizer) for the
                joined documents of one example.
            instruction_name: Key into ``instructions_map`` for the LLM
                instruction text.
            max_num_documents: Number of documents kept per example; ``None``
                keeps all. Also the upper bound when sampling.
            min_num_documents: Lower bound when sampling the document count.
            random_num_documents: If true, the document count is drawn per
                batch with ``random.randint(min_num_documents,
                max_num_documents)``; both bounds must then be integers.
            **kwargs: Must contain ``prefix_type``, a key into
                ``instructions_map``.

        Raises:
            KeyError: If ``prefix_type`` is missing or either name is not in
                ``instructions_map``.
        """
        self.dataset = load_dataset('json', data_files=filepath, split='train')
        self.max_doc_tokens = max_doc_tokens
        self.model = model
        self.cmp_tokenizer = model.tokenizer
        self.llm_tokenizer = model.llm_tokenizer
        self.max_num_documents = max_num_documents
        self.min_num_documents = min_num_documents
        self.random_num_documents = random_num_documents
        self.prefix_type = kwargs['prefix_type']

        for tokenizer in (self.cmp_tokenizer, self.llm_tokenizer):
            self._prepare_tokenizer(tokenizer)

        self.instruction_text = instructions_map[instruction_name]
        self.prefix_text = instructions_map[self.prefix_type]

    @staticmethod
    def _prepare_tokenizer(tokenizer):
        """Switch ``tokenizer`` to left padding and ensure it has a pad token."""
        tokenizer.padding_side = 'left'
        if tokenizer.pad_token is None:
            # NOTE(review): adding a token grows the vocabulary; the matching
            # embedding resize is presumably handled by the model code.
            tokenizer.add_special_tokens({"pad_token": "[PAD]"})
            tokenizer.pad_token = "[PAD]"
            # The id must come from this tokenizer's own vocabulary. The LLM
            # tokenizer's pad id used to be looked up in the compressor
            # tokenizer, which is only right if both share a vocabulary.
            tokenizer.pad_token_id = tokenizer.convert_tokens_to_ids("[PAD]")

    @staticmethod
    def _right_pad(sequences, padding_value):
        """Stack variable-length integer lists into one right-padded tensor."""
        return pad_sequence(
            # Explicit dtype: torch.tensor([]) would otherwise be float.
            [torch.tensor(seq, dtype=torch.long) for seq in sequences],
            batch_first=True,
            padding_value=padding_value,
        )

    @staticmethod
    def _build_llm_inputs(ins_ids, ins_mask, que_ids, que_mask,
                          tgt_ids, tgt_mask, total_mem_size):
        """Concatenate instruction, memory slots, question and target.

        Returns:
            ``(input_ids, attention_mask, labels)``, each of width
            ``ins + total_mem_size + que + tgt``. ``labels`` is -100 (the
            default ``ignore_index`` of ``torch.nn.CrossEntropyLoss``)
            everywhere except on non-padded target positions.
        """
        batch_size = ins_ids.size(0)
        # Placeholder ids for the memory slots; presumably replaced by the
        # compressed representation inside the model (TODO confirm). They
        # share the integer dtype of the token ids: torch.zeros/torch.ones
        # default to float, which would turn the concatenated ids, mask and
        # labels into float tensors.
        memory_ids = torch.zeros((batch_size, total_mem_size), dtype=ins_ids.dtype)
        memory_mask = torch.ones((batch_size, total_mem_size), dtype=ins_mask.dtype)

        input_ids = torch.cat((ins_ids, memory_ids, que_ids, tgt_ids), dim=1)
        attention_mask = torch.cat((ins_mask, memory_mask, que_mask, tgt_mask), dim=1)

        labels = torch.full_like(input_ids, -100)
        # Use an explicit start index: ``[:, -n:]`` would select the whole
        # row when the target width n is 0.
        tgt_start = labels.size(1) - tgt_ids.size(1)
        labels[:, tgt_start:] = tgt_ids.masked_fill(~tgt_mask.bool(), -100)
        return input_ids, attention_mask, labels

    def __len__(self):
        """Return the number of records in the loaded dataset."""
        return len(self.dataset)

    def __getitem__(self, index):
        """Return the question and the formatted documents of one record.

        Documents are passed through ``format_document`` with the compressor
        tokenizer and no token limit; truncation happens in ``collate_fn``.
        """
        example = self.dataset[index]
        documents = [
            format_document(document, self.cmp_tokenizer)
            for document in example['ctxs']
        ]
        return {
            'question': example['question'],
            'documents': documents,
        }

    def collate_fn(self, batch):
        """Turn a list of ``__getitem__`` results into model-ready tensors.

        Returns:
            ``{}`` for an empty batch, otherwise a dict with the keys
            ``enc_doc_ids/mask``, ``enc_que_ids/mask``, ``llm_input_ids``,
            ``llm_input_mask``, ``llm_ins_ids/mask``, ``labels``,
            ``enc_prefix_ids/mask`` and ``enc_repeat_ids/mask``.
        """
        if not batch:
            return {}

        # One document count per batch, not per example.
        num_documents = (
            random.randint(self.min_num_documents, self.max_num_documents)
            if self.random_num_documents else self.max_num_documents
        )

        enc_documents = [
            trunc_text(
                "\n".join(instance["documents"][:num_documents]),
                self.cmp_tokenizer,
                self.max_doc_tokens,
            )
            for instance in batch
        ]
        enc_questions = [instance['question'] for instance in batch]
        # NOTE(review): InferDataset appends '?' after the question, this
        # class does not; confirm the train/inference prompt difference is
        # intended.
        llm_questions = ['\nQuestion:' + question + '\nAnswer:' for question in enc_questions]
        # The LLM target is the very text handed to the compressor, so the
        # truncated strings are reused instead of being recomputed.
        llm_targets = enc_documents

        batch_size = len(batch)
        llm_instructions = [self.instruction_text] * batch_size
        llm_prefix_tokens = [self.prefix_text] * batch_size
        repeat_tokens = ['Restate the aforementioned Text.'] * batch_size

        # NOTE(review): the prefix is tokenized with special tokens here but
        # without them in InferDataset; confirm which one is intended.
        enc_prefix_outputs = self.llm_tokenizer(llm_prefix_tokens, return_tensors='pt', padding=True)
        enc_repeat_outputs = self.llm_tokenizer(repeat_tokens, return_tensors='pt', padding=True, add_special_tokens=False)
        enc_que_outputs = self.cmp_tokenizer(enc_questions, return_tensors='pt', padding=True, add_special_tokens=False)
        enc_doc_outputs = self.cmp_tokenizer(enc_documents, return_tensors='pt', padding=True, add_special_tokens=False)

        # Memory width is derived from the padded document length.
        total_mem_size = self.model.mem_size * self.model.compute_num_segments(len(enc_doc_outputs.input_ids[0]))

        llm_ins_outputs = self.llm_tokenizer(llm_instructions, return_tensors='pt', padding=True)
        llm_que_outputs = self.llm_tokenizer(llm_questions, return_tensors='pt', padding=True, add_special_tokens=False)

        # Targets are right-padded (unlike the left-padded tokenizer batches)
        # so they directly follow the question tokens.
        llm_tgt_outputs = [
            self.llm_tokenizer(target, add_special_tokens=False).input_ids
            for target in llm_targets
        ]
        llm_tgt_tokens = self._right_pad(llm_tgt_outputs, self.llm_tokenizer.pad_token_id)
        llm_tgt_mask = self._right_pad([[1] * len(ids) for ids in llm_tgt_outputs], 0)

        llm_ins_mask = llm_ins_outputs.attention_mask
        llm_input_ids, llm_attention_mask, llm_labels = self._build_llm_inputs(
            llm_ins_outputs["input_ids"],
            llm_ins_mask,
            llm_que_outputs["input_ids"],
            llm_que_outputs.attention_mask,
            llm_tgt_tokens,
            llm_tgt_mask,
            total_mem_size,
        )

        return {
            'enc_doc_ids': enc_doc_outputs.input_ids,
            'enc_doc_mask': enc_doc_outputs.attention_mask,
            'enc_que_ids': enc_que_outputs.input_ids,
            'enc_que_mask': enc_que_outputs.attention_mask,
            'llm_input_ids': llm_input_ids,
            'llm_input_mask': llm_attention_mask,
            'llm_ins_ids': llm_ins_outputs.input_ids,
            'llm_ins_mask': llm_ins_mask,
            'labels': llm_labels,
            'enc_prefix_ids': enc_prefix_outputs.input_ids,
            'enc_prefix_mask': enc_prefix_outputs.attention_mask,
            'enc_repeat_ids': enc_repeat_outputs.input_ids,
            'enc_repeat_mask': enc_repeat_outputs.attention_mask
        }
    

class InferDataset(Dataset):
    """Inference examples: question, documents and reference answers.

    ``collate_fn`` tokenizes the joined documents and the question with the
    compressor tokenizer, and the instruction, question prompt, prefix and
    "restate" prompt with the LLM tokenizer. Reference answers are passed
    through untouched.
    """

    def __init__(
        self,
        filepath,
        cmp_tokenizer,
        llm_tokenizer,
        max_doc_tokens,
        instruction_name,
        max_num_documents=None,
        **kwargs,
    ):
        """Load the JSON data and configure both tokenizers.

        Args:
            filepath: Passed to ``load_dataset('json', data_files=...)``.
            cmp_tokenizer: Tokenizer used for documents and questions on the
                compressor side.
            llm_tokenizer: Tokenizer used for the LLM-side prompts.
            max_doc_tokens: Stored on the instance; not applied anywhere in
                this class.
            instruction_name: Key into ``instructions_map`` for the LLM
                instruction text.
            max_num_documents: Stored on the instance; not applied anywhere
                in this class.
            **kwargs: Must contain ``prefix_type``, a key into
                ``instructions_map``.

        Raises:
            KeyError: If ``prefix_type`` is missing or either name is not in
                ``instructions_map``.
        """
        self.dataset = load_dataset('json', data_files=filepath, split='train')
        self.max_doc_tokens = max_doc_tokens
        self.cmp_tokenizer = cmp_tokenizer
        self.llm_tokenizer = llm_tokenizer
        self.max_num_documents = max_num_documents
        self.prefix_type = kwargs['prefix_type']

        for tokenizer in (self.cmp_tokenizer, self.llm_tokenizer):
            self._prepare_tokenizer(tokenizer)

        self.instruction_text = instructions_map[instruction_name]
        self.prefix_text = instructions_map[self.prefix_type]

    @staticmethod
    def _prepare_tokenizer(tokenizer):
        """Switch ``tokenizer`` to left padding and ensure it has a pad token."""
        tokenizer.padding_side = 'left'
        if tokenizer.pad_token is None:
            # NOTE(review): adding a token grows the vocabulary; the matching
            # embedding resize is presumably handled by the model code.
            tokenizer.add_special_tokens({"pad_token": "[PAD]"})
            tokenizer.pad_token = "[PAD]"
            # The id must come from this tokenizer's own vocabulary. The LLM
            # tokenizer's pad id used to be looked up in the compressor
            # tokenizer, which is only right if both share a vocabulary.
            tokenizer.pad_token_id = tokenizer.convert_tokens_to_ids("[PAD]")

    def __len__(self):
        """Return the number of records in the loaded dataset."""
        return len(self.dataset)

    def __getitem__(self, index):
        """Return question, formatted documents and answers of one record.

        Every document in ``ctxs`` is formatted with the compressor tokenizer
        and no token limit.
        """
        example = self.dataset[index]
        documents = [
            format_document(document, self.cmp_tokenizer, None)
            for document in example['ctxs']
        ]
        return {
            'question': example['question'],
            'documents': documents,
            'answers': example['answers'],
        }

    def collate_fn(self, batch):
        """Turn a list of ``__getitem__`` results into model-ready tensors.

        Returns:
            ``{}`` for an empty batch, otherwise a dict with the keys
            ``enc_doc_ids/mask``, ``enc_que_ids/mask``, ``llm_ins_ids/mask``,
            ``llm_que_ids/mask``, ``answers``, ``enc_prefix_ids/mask`` and
            ``enc_repeat_ids/mask``.
        """
        if not batch:
            return {}

        # NOTE(review): all documents are joined without applying
        # max_num_documents or max_doc_tokens; confirm that is intended.
        enc_documents = ['\n'.join(instance['documents']) for instance in batch]
        enc_questions = [instance['question'] for instance in batch]
        llm_questions = ['\nQuestion:' + question + '?\nAnswer:' for question in enc_questions]
        answers = [instance['answers'] for instance in batch]

        batch_size = len(batch)
        llm_instructions = [self.instruction_text] * batch_size
        llm_prefix_tokens = [self.prefix_text] * batch_size
        repeat_tokens = ['Restate the aforementioned Text.'] * batch_size

        enc_prefix_outputs = self.llm_tokenizer(llm_prefix_tokens, return_tensors='pt', padding=True, add_special_tokens=False)
        enc_repeat_outputs = self.llm_tokenizer(repeat_tokens, return_tensors='pt', padding=True, add_special_tokens=False)

        enc_que_outputs = self.cmp_tokenizer(enc_questions, return_tensors='pt', padding=True, add_special_tokens=False)
        enc_doc_outputs = self.cmp_tokenizer(enc_documents, return_tensors='pt', padding=True, add_special_tokens=False)
        # Only the instruction is tokenized with the tokenizer's default
        # special-token handling; every other piece disables it.
        llm_ins_outputs = self.llm_tokenizer(llm_instructions, return_tensors='pt', padding=True)
        llm_que_outputs = self.llm_tokenizer(llm_questions, return_tensors='pt', padding=True, add_special_tokens=False)

        return {
            'enc_doc_ids': enc_doc_outputs.input_ids,
            'enc_que_ids': enc_que_outputs.input_ids,
            'enc_doc_mask': enc_doc_outputs.attention_mask,
            'enc_que_mask': enc_que_outputs.attention_mask,
            'llm_ins_ids': llm_ins_outputs.input_ids,
            'llm_que_ids': llm_que_outputs.input_ids,
            'llm_ins_mask': llm_ins_outputs.attention_mask,
            'llm_que_mask': llm_que_outputs.attention_mask,
            'answers': answers,
            'enc_prefix_ids': enc_prefix_outputs.input_ids,
            'enc_prefix_mask': enc_prefix_outputs.attention_mask,
            'enc_repeat_ids': enc_repeat_outputs.input_ids,
            'enc_repeat_mask': enc_repeat_outputs.attention_mask
        }