import os
import torch
import json
import re
import random
import h5py
import pickle
from nltk.tokenize import word_tokenize

from random import random as rand
from metrics import _precision_score, bleu_metric

random.seed(42)

from torch.utils.data import DataLoader, Dataset

from itertools import cycle

import numpy as np
from metrics import f1_metric, bleu_metric
from nltk.tree import Tree
import nltk
from nltk.corpus import stopwords


def save_hparams(args, path):
    """Write every attribute of ``args`` to ``path`` as ``NAME=value`` lines.

    Attributes are emitted in sorted order of their original names; each name
    is upper-cased and each value rendered with ``str.format``. An existing
    file at ``path`` is overwritten.

    Args:
        args: any object supporting ``vars()`` (e.g. ``argparse.Namespace``).
        path: destination file, written as UTF-8.
    """
    with open(path, 'w', encoding='utf-8') as out:
        for name, val in sorted(vars(args).items()):
            out.write("{}={}\n".format(name.upper(), val))

def check_mem(cuda_device):
    """Return ``(total, used)`` memory of one GPU as reported by ``nvidia-smi``.

    Both values are the raw CSV fields (strings, possibly with surrounding
    whitespace); callers convert them with ``int()``. Units are whatever
    nvidia-smi prints with ``nounits`` (MiB on typical installs — TODO confirm).

    Args:
        cuda_device: physical GPU index (int or numeric string).
    """
    query = '"nvidia-smi" --query-gpu=memory.total,memory.used --format=csv,nounits,noheader'
    rows = os.popen(query).read().strip().split("\n")
    fields = rows[int(cuda_device)].split(',')
    total, used = fields
    return total, used

def check_mem_all():
    """Return one ``"total, used"`` CSV row (string) per GPU from ``nvidia-smi``."""
    query = '"nvidia-smi" --query-gpu=memory.total,memory.used --format=csv,nounits,noheader'
    output = os.popen(query).read()
    return output.strip().split("\n")

def occupy_mem_new(cuda_device_list, ratio=0.6, num_devices=8):
    """Wait for GPU(s) with enough free memory and allocate a large tensor on them.

    ``occupy`` is ``int(total * ratio)`` in nvidia-smi's units. A float32 tensor
    of shape ``(256, 1024, occupy)`` is ``occupy`` MiB, so the two agree when
    nvidia-smi reports MiB (TODO confirm). The tensor is deleted right away;
    presumably the point is that PyTorch's caching allocator keeps the block
    reserved for this process — verify.

    Modes:
      * Auto (``cuda_device_list`` is empty, or its first entry is an empty
        string): poll up to ``num_devices`` GPUs until one can take ``occupy``
        more while staying within 95% of ``total``. Expose the first such GPU
        via ``CUDA_VISIBLE_DEVICES``, allocate on ``cuda:0``, then block on
        ``input()``.
      * Explicit: expose ``cuda_device_list`` via ``CUDA_VISIBLE_DEVICES``.
        Then, for each listed GPU in order, retry until it has room and the
        allocation succeeds.

    Args:
        cuda_device_list: list of physical GPU index strings, e.g. ``['0', '3']``.
        ratio: fraction of each GPU's total memory to allocate.
        num_devices: upper bound on GPUs scanned in auto mode.
    """
    import time

    retry_delay = 2  # seconds between polls / failed allocation attempts

    if len(cuda_device_list) == 0 or len(cuda_device_list[0]) == 0:
        while True:
            devices_info = check_mem_all()
            available_devices = []
            occupys = []
            # Bound by the rows nvidia-smi actually returned, so a machine with
            # fewer than num_devices GPUs does not raise IndexError.
            for cuda_device in range(min(num_devices, len(devices_info))):
                total, used = devices_info[cuda_device].split(',')
                total = int(total)
                used = int(used)
                occupy = int(total * ratio)
                print("Device-{}: {}/{}/{}".format(cuda_device, total, used, occupy))
                if occupy + used <= total * 0.95:
                    print('Find device-{}!'.format(cuda_device))
                    available_devices.append(cuda_device)
                    occupys.append(occupy)
            if not available_devices:
                # Nothing free yet: back off instead of spinning on nvidia-smi.
                time.sleep(retry_delay)
                continue
            print(available_devices[0])
            # NOTE(review): CUDA_VISIBLE_DEVICES only takes effect before CUDA is
            # initialised in this process; on a retry it is likely ignored.
            os.environ['CUDA_VISIBLE_DEVICES'] = str(available_devices[0])
            try:
                x = torch.cuda.FloatTensor(256, 1024, occupys[0], device='cuda:0')
                del x
            except RuntimeError:
                print("Failed, continue...")
                time.sleep(retry_delay)
                continue
            break
        # NOTE(review): blocks for interactive confirmation; the explicit-list
        # branch has the equivalent call commented out.
        input(">>>>>")
    else:
        os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(cuda_device_list)
        for local_idx, cuda_device in enumerate(cuda_device_list):
            while True:
                total, used = check_mem(cuda_device)
                total = int(total)
                used = int(used)
                occupy = int(total * ratio)
                print("Device-{}: {}/{}/{}".format(cuda_device, total, used, occupy))
                if occupy + used > total * 0.95:
                    # Not enough room yet: wait before polling again.
                    time.sleep(retry_delay)
                    continue
                print('Find device-{}!'.format(cuda_device))
                try:
                    # After CUDA_VISIBLE_DEVICES is set above, 'cuda:<local_idx>'
                    # addresses the local_idx-th GPU of cuda_device_list.
                    x = torch.cuda.FloatTensor(256, 1024, occupy, device='cuda:{}'.format(local_idx))
                    del x
                except RuntimeError:
                    time.sleep(retry_delay)
                    continue
                break

def detokenize(tk_str):
    """Glue WordPiece continuation tokens (``##xx``) back onto the previous token.

    The input is split on whitespace. Any token starting with ``##`` is appended,
    minus the ``##``, to the preceding output token. A ``##`` token with nothing
    before it is kept verbatim. The result is joined with single spaces.
    """
    words = []
    for token in tk_str.split():
        if words and token.startswith('##'):
            words[-1] += token[2:]
            continue
        words.append(token)
    return ' '.join(words)

def truncate(str, num):
    """Keep only the last ``num`` whitespace-separated tokens of ``str``.

    The parameter names shadow the ``str``/``list`` builtins. They are kept
    unchanged so that keyword callers keep working.

    Args:
        str: input text.
        num: number of trailing tokens to keep; ``0`` yields ``""``.

    Returns:
        The kept tokens joined by single spaces.
    """
    words = str.split()
    keep_from = max(0, len(words) - num)
    return " ".join(words[keep_from:])


# ================================ BART ====================================== #
from transformers import BartTokenizer

class WizardDataset(Dataset):
    """JSON-lines dialogue dataset for knowledge-grounded generation.

    Each non-blank line of ``data_path`` is a JSON object with keys
    ``knowledge`` (list of str), ``history`` (list of str), ``user`` (anything
    ``np.array`` accepts) and ``response``.
    """

    def __init__(self, data_path):
        self._data = []
        with open(data_path, 'r', encoding='utf-8') as f:
            for line in f:
                # Skip blank lines (e.g. an extra newline at EOF); json.loads rejects them.
                if not line.strip():
                    continue
                self._data.append(json.loads(line))
        self._n_data = len(self._data)

    def __len__(self):
        return self._n_data

    def __getitem__(self, i):
        """Return ``(knowledge, history, user, response)``.

        The knowledge and history lists are each joined with blank lines
        (two newlines), and ``user`` is converted with ``np.array``.
        """
        sample = self._data[i]
        knowledge = sample['knowledge']
        history = sample['history']
        user = sample['user']
        response = sample['response']
        return ('\n\n'.join(knowledge), '\n\n'.join(history), np.array(user), response)

    @staticmethod
    def collate_fn(batch):
        """Transpose a list of ``__getitem__`` tuples into four parallel lists."""
        knowledge_list = [item[0] for item in batch]
        history_list = [item[1] for item in batch]
        user_list = [item[2] for item in batch]
        response_list = [item[3] for item in batch]
        return knowledge_list, history_list, user_list, response_list

class WizardRankDataset(Dataset):
    """JSON-lines dialogue dataset whose knowledge is re-ordered by external scores.

    Args:
        data_path: JSON-lines file (blank lines ignored). Each line is an object
            with ``knowledge``, ``history``, ``user`` and ``response`` keys.
        score_file: one line per example in ``data_path``. Each line holds
            tab-separated floats, one per knowledge passage, in ``knowledge``
            order. An empty line means the example has no knowledge.
    """

    def __init__(self, data_path, score_file):
        self._data = []
        self._score = []
        with open(data_path, 'r', encoding='utf-8') as f:
            for line in f:
                # json.loads rejects blank lines (e.g. an extra newline at EOF).
                if line.strip():
                    self._data.append(json.loads(line))
        with open(score_file, 'r', encoding='utf-8') as f:
            for line in f:
                self._score.append(line.strip())
        self._n_data = len(self._data)
        assert len(self._data) == len(self._score)

    def __len__(self):
        return self._n_data

    def __getitem__(self, i):
        """Return ``(knowledge, history, user, response)``.

        Knowledge is sorted by descending score, ties keeping file order. Both
        knowledge and history are joined with blank lines (two newlines).
        """
        sample = self._data[i]
        knowledge = sample['knowledge']
        history = sample['history']
        user = sample['user']
        response = sample['response']
        score_line = self._score[i]
        score = [float(s) for s in score_line.split('\t')] if score_line else []
        assert len(score) == len(knowledge)
        # Sort (passage, score) pairs directly. A passage->score dict would collapse
        # duplicate passages onto the last score seen. sorted() is stable (also with
        # reverse=True), so equal scores keep their original relative order.
        ranked = sorted(zip(knowledge, score), key=lambda pair: pair[1], reverse=True)
        knowledge = [passage for passage, _ in ranked]
        return ('\n\n'.join(knowledge), '\n\n'.join(history), np.array(user), response)

    @staticmethod
    def collate_fn(batch):
        """Transpose a list of ``__getitem__`` tuples into four parallel lists."""
        knowledge_list = [item[0] for item in batch]
        history_list = [item[1] for item in batch]
        user_list = [item[2] for item in batch]
        response_list = [item[3] for item in batch]
        return knowledge_list, history_list, user_list, response_list

class BartBatcher(object):
    """Convert raw dialogue batches into padded, column-trimmed BART tensors.

    One example's encoder source is laid out as::

        <s> <#K#> k_1 <#K#> k_2 ... <#Q#> h_1 <#Q#> h_2 ... </s>

    Token limits:
      * each knowledge passage is cut to ``knowledge_truncate`` sub-tokens;
      * each history turn is cut to ``text_truncate`` sub-tokens.

    Passages are packed in order until the budget left by ``max_source_length``
    after the history runs out; the passage that overflows is cut to fit. The
    whole sequence is then left-truncated to ``max_source_length``.

    Args:
        max_source_length: encoder length before batch trimming.
        max_target_length: decoder length including ``<s>`` and ``</s>``.
        knowledge_truncate: max sub-tokens kept per knowledge passage.
        text_truncate: max sub-tokens kept per history turn.
        bart_config: name or path passed to ``BartTokenizer.from_pretrained``.
        cuda: build tensors on ``'cuda'`` if True, else on ``'cpu'``.
    """

    def __init__(self, max_source_length, max_target_length, knowledge_truncate, text_truncate, bart_config, cuda=True):
        self.max_source_length = max_source_length
        self.max_target_length = max_target_length
        self.knowledge_truncate = knowledge_truncate
        self.text_truncate = text_truncate

        self.tokenizer = BartTokenizer.from_pretrained(bart_config, do_lower_case=True)

        # <#K#> precedes each knowledge passage, <#Q#> each history turn.
        # NOTE(review): the model's embeddings presumably must be resized for these
        # added tokens by the caller — confirm.
        SPECIAL_TOKENS_DICT = {'additional_special_tokens': ['<#K#>', '<#Q#>']}
        self.tokenizer.add_special_tokens(SPECIAL_TOKENS_DICT)
        self.bos_id = self.tokenizer.bos_token_id
        self.pad_id = self.tokenizer.pad_token_id
        self.eos_id = self.tokenizer.eos_token_id
        self.know_id = self.tokenizer.convert_tokens_to_ids('<#K#>')
        self.conv_id = self.tokenizer.convert_tokens_to_ids('<#Q#>')
        self.prefix = ""
        self.device = torch.device('cuda' if cuda else 'cpu')

    def encode_line(self, line, max_length, pad_to_max_length=True, return_tensors="pt"):
        """Tokenize a single string as a batch of one (with ``add_prefix_space=True``)."""
        return self.tokenizer(
            [line],
            max_length=max_length,
            padding="max_length" if pad_to_max_length else 'do_not_pad',
            truncation=True,
            return_tensors=return_tensors,
            add_prefix_space=True,
        )

    def tokenize(self, text, max_length=128):
        """Return the token ids of ``text`` without the surrounding ``<s>``/``</s>``."""
        return self.encode_line(text, max_length=max_length, pad_to_max_length=False)['input_ids'].squeeze().tolist()[1:-1]

    def trim_batch(self, input_ids, attention_mask=None):
        """Remove columns that are populated exclusively by pad_token_id"""
        keep_column_mask = input_ids.ne(self.pad_id).any(dim=0)
        if attention_mask is None:
            return input_ids[:, keep_column_mask]
        else:
            return (input_ids[:, keep_column_mask], attention_mask[:, keep_column_mask])

    def _to_tensor(self, rows):
        """Stack equal-length id lists into a ``torch.long`` tensor on ``self.device``."""
        return torch.tensor(rows, device=self.device, dtype=torch.long)

    def _build_source(self, know, hist):
        """Encode one example's encoder input.

        Args:
            know: iterable of knowledge strings, packed in order.
            hist: iterable of history-turn strings. NOTE(review): WizardDataset
                yields these joined into one string, so callers presumably split
                them first — iterating a string would tokenize single characters.

        Returns:
            ``(source_input, source_mask)``, each right-padded to
            ``max_source_length`` with ``pad_id`` / ``0``.
        """
        history_input = []
        for h in hist:
            history_input += [self.conv_id] + self.tokenize(h, max_length=999)[:self.text_truncate]

        # Room for knowledge once <s>, </s> and the full history are reserved.
        budget = self.max_source_length - len(history_input) - 2
        knowledge_input = []
        for k in know:
            tmp = [self.know_id] + self.tokenize(k, max_length=999)[:self.knowledge_truncate]
            remaining = budget - len(knowledge_input)
            if len(tmp) >= remaining:
                # This passage fills or overflows the budget: keep what fits, stop.
                # max(.., 0) avoids a negative slice when the history alone is over
                # budget; that knowledge would be cut by the left-truncation anyway.
                knowledge_input += tmp[:max(remaining, 0)]
                break
            knowledge_input += tmp

        source_input = [self.bos_id] + knowledge_input + history_input + [self.eos_id]
        # Left-truncate: if the history overflows, keep its trailing tokens.
        source_input = source_input[-self.max_source_length:]
        n_pad = self.max_source_length - len(source_input)
        source_mask = [1] * len(source_input) + [0] * n_pad
        source_input = source_input + [self.pad_id] * n_pad
        return source_input, source_mask

    def __call__(self, knowledge_list, history_list, user_list, response_list=None, training=True):
        """Build model kwargs for a batch.

        Args:
            knowledge_list: per example, an iterable of knowledge strings.
            history_list: per example, an iterable of history-turn strings.
            user_list: unused; accepted for interface compatibility.
            response_list: per example, the target string (required when training).
            training: if True, return teacher-forcing inputs and labels. Otherwise
                return encoder inputs plus generation settings.

        Returns:
            A dict of tensors/flags to pass to the BART model. All-pad columns
            are trimmed from the batch.
        """
        source_input_list, source_mask_list = [], []
        if training:
            target_input_list = []
            for know, hist, resp in zip(knowledge_list, history_list, response_list):
                source_input, source_mask = self._build_source(know, hist)
                target_input = [self.bos_id] + self.tokenize(resp, max_length=999)[:self.max_target_length - 2] + [self.eos_id]
                target_input = target_input + [self.pad_id] * (self.max_target_length - len(target_input))

                source_mask_list.append(source_mask)
                source_input_list.append(source_input)
                target_input_list.append(target_input)
            source_mask_list = self._to_tensor(source_mask_list)
            source_input_list = self._to_tensor(source_input_list)
            target_input_list = self._to_tensor(target_input_list)

            target_input_list = self.trim_batch(target_input_list)
            source_input_list, source_mask_list = self.trim_batch(source_input_list, attention_mask=source_mask_list)

            # Teacher forcing: the decoder reads <s> y_1 .. y_{n-1} and predicts y_1 .. </s>.
            label_list = target_input_list[:, 1:].clone()
            target_input_list = target_input_list[:, :-1].contiguous()
            return {
                "input_ids": source_input_list,
                "attention_mask": source_mask_list,
                "decoder_input_ids": target_input_list,
                "labels": label_list,
                "use_cache": False,
            }

        for know, hist in zip(knowledge_list, history_list):
            source_input, source_mask = self._build_source(know, hist)
            source_mask_list.append(source_mask)
            source_input_list.append(source_input)
        source_mask_list = self._to_tensor(source_mask_list)
        source_input_list = self._to_tensor(source_input_list)

        source_input_list, source_mask_list = self.trim_batch(source_input_list, attention_mask=source_mask_list)

        return {
            "input_ids": source_input_list,
            "attention_mask": source_mask_list,
            "decoder_start_token_id": self.bos_id,
            "use_cache": True,
        }



# ==================================== BERT ================================== #
from transformers import BertTokenizer

class RedditKsDataset_v0(Dataset):
    """Knowledge-selection pairs from an HDF5 file.

    The file holds ``history``, ``knowledge`` and ``label`` datasets. It stays
    open, because the datasets are indexed lazily in ``__getitem__``.
    """

    def __init__(self, data_path):
        h5 = h5py.File(data_path, 'r')
        self._history, self._knowledge, self._label = h5['history'], h5['knowledge'], h5['label']
        self._n_data = len(self._history)

    def __len__(self):
        return self._n_data

    def __getitem__(self, i):
        """Return ``(query, candidate, label)``.

        The ``<#Q#>``-separated history turns are stripped and re-joined with
        blank lines.
        """
        turns = [turn.strip() for turn in self._history[i].split('<#Q#>')]
        return '\n\n'.join(turns), self._knowledge[i], int(self._label[i])

    @staticmethod
    def collate_fn(batch):
        """Transpose ``(query, candidate, label)`` items into three parallel lists."""
        queries = [row[0] for row in batch]
        candidates = [row[1] for row in batch]
        labels = [row[2] for row in batch]
        return queries, candidates, labels


class RedditKsDataset(Dataset):
    """Reddit knowledge-selection pairs stored as ``src``/``tgt`` string datasets in HDF5.

    Each ``src`` entry looks like ``"<turn> <#Q#> <turn> ... <#Q2K#> <candidate>"``.
    The label is the last whitespace token of ``tgt``, parsed as an int.
    NOTE(review): h5py >= 3 may return variable-length strings as ``bytes``, which
    ``.split('<#Q2K#>')`` would reject — confirm how the files were written.
    """

    def __init__(self, data_path):
        h5 = h5py.File(data_path, 'r')
        self._src, self._tgt = h5['src'], h5['tgt']
        self._n_data = len(self._src)

    def preprocess(self, sentence):
        """Drop ``<#bleuN#>`` tags (N = 0..13), then merge ``##`` word pieces.

        A sentence left empty after filtering becomes ``'.'``.
        """
        bleu_tags = frozenset(['<#bleu0#>', '<#bleu1#>', '<#bleu2#>', '<#bleu3#>', '<#bleu4#>', '<#bleu5#>', '<#bleu6#>',
                               '<#bleu7#>', '<#bleu8#>', '<#bleu9#>', '<#bleu10#>', '<#bleu11#>', '<#bleu12#>', '<#bleu13#>'])
        kept = [tok for tok in sentence.split() if tok not in bleu_tags]
        return detokenize(' '.join(kept or ['.']))

    def __len__(self):
        return self._n_data

    def __getitem__(self, i):
        """Return ``(query, candidate, label)``.

        The preprocessed query turns are joined with blank lines.
        """
        parts = self._src[i].strip().split('<#Q2K#>')
        turns = parts[0].strip().split('<#Q#>')
        query = '\n\n'.join(self.preprocess(turn.strip()) for turn in turns)
        candidate = self.preprocess(parts[1].strip())
        label = int(self._tgt[i].strip().split()[-1])
        return query, candidate, label

    @staticmethod
    def collate_fn(batch):
        """Transpose ``(query, candidate, label)`` items into three parallel lists."""
        queries = [row[0] for row in batch]
        candidates = [row[1] for row in batch]
        labels = [row[2] for row in batch]
        return queries, candidates, labels

class RedditKsDataset_v3(Dataset):
    """Positive-only Reddit knowledge-selection data with in-batch negatives.

    Only examples whose ``tgt`` label equals 1 are kept; the label is the last
    whitespace token, parsed as an int. ``collate_fn`` pairs each query with its
    own candidate (label 1) and with up to three other candidates from the same
    batch (label 0).
    """

    # Tags removed from every sentence by preprocess().
    _BLEU_TAGS = frozenset(['<#bleu0#>', '<#bleu1#>', '<#bleu2#>', '<#bleu3#>', '<#bleu4#>', '<#bleu5#>', '<#bleu6#>',
                            '<#bleu7#>', '<#bleu8#>', '<#bleu9#>', '<#bleu10#>', '<#bleu11#>', '<#bleu12#>', '<#bleu13#>'])

    def __init__(self, data_path):
        dataset = h5py.File(data_path, 'r')
        self._src = dataset['src']
        self._tgt = dataset['tgt']
        # Plain loop, no progress bar: tqdm is not imported by this module.
        self.valid_indices = [
            i for i in range(len(self._tgt))
            if int(self._tgt[i].strip().split()[-1]) == 1
        ]
        self._n_data = len(self.valid_indices)

    def preprocess(self, sentence):
        """Drop ``<#bleuN#>`` tags, then merge ``##`` word pieces; empty input becomes ``'.'``."""
        tokens = [t for t in sentence.split() if t not in self._BLEU_TAGS]
        return detokenize(' '.join(tokens or ['.']))

    def __len__(self):
        return self._n_data

    def __getitem__(self, i):
        """Return ``(query, candidate)`` for the i-th positive example."""
        src = self._src[self.valid_indices[i]].strip()
        parts = src.split('<#Q2K#>')
        query = [self.preprocess(q.strip()) for q in parts[0].strip().split('<#Q#>')]
        candidate = self.preprocess(parts[1].strip())
        return '\n\n'.join(query), candidate

    @staticmethod
    def collate_fn(batch, num_negatives=3):
        """Expand ``(query, candidate)`` items into labelled pairs with in-batch negatives.

        For every item this emits ``(query, candidate, 1)``. It then adds up to
        ``num_negatives`` pairs ``(query, other, 0)``, drawn without replacement
        from the distinct candidates of the batch other than this item's own.

        Returns:
            ``(query_list, candidate_list, label_list)``.
        """
        query_list = [item[0] for item in batch]
        candidate_list = [item[1] for item in batch]
        knowledge_pool = set(candidate_list)
        query_list_new = []
        candidate_list_new = []
        label_list_new = []
        for query, cand in zip(query_list, candidate_list):
            query_list_new.append(query)
            candidate_list_new.append(cand)
            label_list_new.append(1)
            # {cand} removes the positive itself; set(cand) would remove its characters.
            # Sorting fixes the population order, so random.sample is reproducible
            # under random.seed regardless of string-hash randomisation.
            others = sorted(knowledge_pool - {cand})
            # Small (e.g. final) batches may not have num_negatives distinct others.
            for neg in random.sample(others, min(num_negatives, len(others))):
                query_list_new.append(query)
                candidate_list_new.append(neg)
                label_list_new.append(0)
        return query_list_new, candidate_list_new, label_list_new

class RedditKsDataset_v4(Dataset):
    """Positive-only Reddit knowledge-selection data with capped in-batch negatives.

    Only examples whose ``tgt`` label equals 1 are kept; the label is the last
    whitespace token, parsed as an int. ``collate_fn`` pairs each query with its
    own candidate (label 1) and up to seven other in-batch candidates (label 0).
    It stops once 32 pairs have been produced.
    """

    # Tags removed from every sentence by preprocess().
    _BLEU_TAGS = frozenset(['<#bleu0#>', '<#bleu1#>', '<#bleu2#>', '<#bleu3#>', '<#bleu4#>', '<#bleu5#>', '<#bleu6#>',
                            '<#bleu7#>', '<#bleu8#>', '<#bleu9#>', '<#bleu10#>', '<#bleu11#>', '<#bleu12#>', '<#bleu13#>'])

    def __init__(self, data_path):
        dataset = h5py.File(data_path, 'r')
        self._src = dataset['src']
        self._tgt = dataset['tgt']
        # Plain loop, no progress bar: tqdm is not imported by this module.
        self.valid_indices = [
            i for i in range(len(self._tgt))
            if int(self._tgt[i].strip().split()[-1]) == 1
        ]
        self._n_data = len(self.valid_indices)

    def preprocess(self, sentence):
        """Drop ``<#bleuN#>`` tags, then merge ``##`` word pieces; empty input becomes ``'.'``."""
        tokens = [t for t in sentence.split() if t not in self._BLEU_TAGS]
        return detokenize(' '.join(tokens or ['.']))

    def __len__(self):
        return self._n_data

    def __getitem__(self, i):
        """Return ``(query, candidate)`` for the i-th positive example."""
        src = self._src[self.valid_indices[i]].strip()
        parts = src.split('<#Q2K#>')
        query = [self.preprocess(q.strip()) for q in parts[0].strip().split('<#Q#>')]
        candidate = self.preprocess(parts[1].strip())
        return '\n\n'.join(query), candidate

    @staticmethod
    def collate_fn(batch, num_negatives=7, max_pairs=32):
        """Expand ``(query, candidate)`` items into labelled pairs with in-batch negatives.

        For every item this emits ``(query, candidate, 1)``. It then adds up to
        ``num_negatives`` pairs ``(query, other, 0)``, drawn without replacement
        from the distinct candidates of the batch other than this item's own.
        Processing stops after the item that brings the total to ``max_pairs``
        or more.

        Returns:
            ``(query_list, candidate_list, label_list)``.
        """
        query_list = [item[0] for item in batch]
        candidate_list = [item[1] for item in batch]
        knowledge_pool = set(candidate_list)
        query_list_new = []
        candidate_list_new = []
        label_list_new = []
        for query, cand in zip(query_list, candidate_list):
            query_list_new.append(query)
            candidate_list_new.append(cand)
            label_list_new.append(1)
            # {cand} removes the positive itself; set(cand) would remove its characters.
            # Sorting fixes the population order, so random.sample is reproducible
            # under random.seed regardless of string-hash randomisation.
            others = sorted(knowledge_pool - {cand})
            # Small batches may not have num_negatives distinct others.
            for neg in random.sample(others, min(num_negatives, len(others))):
                query_list_new.append(query)
                candidate_list_new.append(neg)
                label_list_new.append(0)
            if len(query_list_new) >= max_pairs:
                break
        return query_list_new, candidate_list_new, label_list_new


class WizardKsDataset(Dataset):
    """Knowledge-selection pairs built from a Wizard-of-Wikipedia style JSON-lines file.

    Each non-blank line provides ``history`` (list of str) and ``knowledge``
    (list of str). It expands into one ``(history, candidate, label)`` pair per
    knowledge entry, using at most ``max_candidate`` entries. The first entry is
    labelled 1 and the rest 0.

    Args:
        data_path: JSON-lines input file.
        bert_tokenize: if True, ``__getitem__`` round-trips text through the BERT
            tokenizer (word pieces merged back with ``detokenize``).
        max_candidate: maximum knowledge entries used per line.
        bert_config: optional tokenizer name/path. When None, the two built-in
            local paths are tried in order.
    """

    # Fallback tokenizer locations, tried in order when bert_config is None.
    _DEFAULT_BERT_PATHS = (
        "/home2/xxx/Data/pretrain-models/bert-base-uncased",
        "/home/xxx/Data/pretrain-models/bert-base-uncased",
    )

    def __init__(self, data_path, bert_tokenize=False, max_candidate=999, bert_config=None):
        self.bert_tokenize = bert_tokenize
        self.tokenizer = self._load_tokenizer(bert_config)

        self._query = []
        self._candidate = []
        self._label = []
        with open(data_path, 'r', encoding='utf-8') as f:
            for line in f:
                # json.loads rejects blank lines (e.g. an extra newline at EOF).
                if not line.strip():
                    continue
                data = json.loads(line)
                # One list shared by all pairs from this line; it is never mutated.
                query = list(data['history'])

                for i, cand in enumerate(data['knowledge'][:max_candidate]):
                    self._query.append(query)  # list[str]
                    self._candidate.append(cand)  # str
                    self._label.append(1 if i == 0 else 0)  # int
        self._n_data = len(self._query)

    @classmethod
    def _load_tokenizer(cls, bert_config):
        """Load the BERT tokenizer from ``bert_config``, else from the default paths."""
        if bert_config is not None:
            return BertTokenizer.from_pretrained(bert_config, do_lower_case=True)
        primary, fallback = cls._DEFAULT_BERT_PATHS
        try:
            return BertTokenizer.from_pretrained(primary, do_lower_case=True)
        except Exception:
            # Narrower than a bare except: KeyboardInterrupt/SystemExit still propagate.
            return BertTokenizer.from_pretrained(fallback, do_lower_case=True)

    def __len__(self):
        return self._n_data

    def __getitem__(self, i):
        """Return ``(query, candidate, label)``; history turns are joined with blank lines."""
        query = self._query[i]  # list[str]
        candidate = self._candidate[i]  # str
        label = self._label[i]  # int
        if self.bert_tokenize:
            query = [detokenize(' '.join(self.tokenizer.tokenize(q))) for q in query]
            candidate = detokenize(' '.join(self.tokenizer.tokenize(candidate)))
        return '\n\n'.join(query), candidate, label

    @staticmethod
    def collate_fn(batch):
        """Transpose ``(query, candidate, label)`` items into three parallel lists."""
        query_list = [item[0] for item in batch]
        candidate_list = [item[1] for item in batch]
        label_list = [item[2] for item in batch]
        return query_list, candidate_list, label_list


class BertBatcher(object):
    """Turn (query, candidate) pairs into fixed-length BERT cross-encoder inputs.

    Each pair is encoded as::

        [CLS] <#Q#> q_1 <#Q#> q_2 ... [SEP] <#K#> candidate [SEP] [PAD]...

    Truncation and typing rules:
      * each query turn is cut to ``text_truncate`` sub-tokens;
      * the candidate is cut to ``knowledge_truncate`` sub-tokens;
      * the concatenated query is cut from its end to fit ``block_size``;
      * ``token_type_ids`` are 0 up to and including the first ``[SEP]`` and
        1 for the candidate part.

    Args:
        block_size: padded sequence length.
        knowledge_truncate: max sub-tokens kept from the candidate.
        text_truncate: max sub-tokens kept per query turn.
        bert_config: name or path for ``BertTokenizer.from_pretrained``.
        cuda: build tensors on ``'cuda'`` if True, else on ``'cpu'``.
    """

    def __init__(self, block_size, knowledge_truncate, text_truncate, bert_config, cuda=True):
        self.block_size = block_size
        self.knowledge_truncate = knowledge_truncate
        self.text_truncate = text_truncate

        self.tokenizer = BertTokenizer.from_pretrained(bert_config, do_lower_case=True)

        SPECIAL_TOKENS_DICT = {'additional_special_tokens': ['<#K#>', '<#Q#>']}
        self.tokenizer.add_special_tokens(SPECIAL_TOKENS_DICT)
        self.cls_id = self.tokenizer.cls_token_id
        self.sep_id = self.tokenizer.sep_token_id
        self.pad_id = self.tokenizer.pad_token_id
        self.know_id = self.tokenizer.convert_tokens_to_ids('<#K#>')
        self.conv_id = self.tokenizer.convert_tokens_to_ids('<#Q#>')

        self.device = torch.device('cuda' if cuda else 'cpu')

    def tokenize(self, text):
        """Return token ids of ``text`` without the ``[CLS]``/``[SEP]`` the tokenizer adds."""
        return self.tokenizer(text, add_special_tokens=True)['input_ids'][1:-1]

    def _to_tensor(self, rows):
        """Build a ``torch.long`` tensor on ``self.device``."""
        return torch.tensor(rows, device=self.device, dtype=torch.long)

    def _encode_pair(self, query, cand):
        """Encode one pair.

        Args:
            query: iterable of query-turn strings.
            cand: candidate string.

        Returns:
            ``(input_ids, attention_mask, token_type_ids)``, each padded to
            ``block_size``.
        """
        query_input = []
        for q in query:
            query_input += [self.conv_id] + self.tokenize(q)[:self.text_truncate]

        candidate_input = [self.know_id] + self.tokenize(cand)[:self.knowledge_truncate]
        # Reserve [CLS] + 2x[SEP] + candidate. Clamp at 0 so an over-long candidate
        # drops the query instead of slicing with a negative index.
        query_input = query_input[:max(0, self.block_size - 3 - len(candidate_input))]
        input_ids = [self.cls_id] + query_input + [self.sep_id] + candidate_input + [self.sep_id]
        attention_mask = [1] * len(input_ids)
        token_type_ids = [0] * (len(query_input) + 2) + [1] * (len(candidate_input) + 1)

        n_pad = self.block_size - len(input_ids)
        input_ids = input_ids + [self.pad_id] * n_pad
        attention_mask = attention_mask + [0] * n_pad
        token_type_ids = token_type_ids + [0] * (self.block_size - len(token_type_ids))
        return input_ids, attention_mask, token_type_ids

    def __call__(self, query_list, candidate_list, label_list=None, training=True):
        """Encode a batch of pairs.

        Args:
            query_list: per example, an iterable of query-turn strings.
                NOTE(review): WizardKsDataset yields these joined into one string,
                so callers presumably split them first — confirm.
            candidate_list: per example, the candidate string.
            label_list: per example, an int label (required when ``training``).
            training: if True, also return ``labels``.

        Returns:
            A dict with ``input_ids``, ``attention_mask`` and ``token_type_ids``
            (shape ``(batch, block_size)``), plus ``labels`` when training.
        """
        if training:
            examples = zip(query_list, candidate_list, label_list)
        else:
            examples = ((query, cand, None) for query, cand in zip(query_list, candidate_list))

        input_ids_list, attention_mask_list, token_type_ids_list, labels = [], [], [], []
        for query, cand, label in examples:
            input_ids, attention_mask, token_type_ids = self._encode_pair(query, cand)
            input_ids_list.append(input_ids)
            attention_mask_list.append(attention_mask)
            token_type_ids_list.append(token_type_ids)
            labels.append(label)

        batch = {
            "input_ids": self._to_tensor(input_ids_list),
            "attention_mask": self._to_tensor(attention_mask_list),
            "token_type_ids": self._to_tensor(token_type_ids_list),
        }
        if training:
            batch["labels"] = self._to_tensor(labels)
        return batch


# ===================================== TinyBART(for PostM) ==================================================== #

class BartBatcherForPostM(object):
    """BART batcher variant that also returns a decoder attention mask.

    The source layout is ``<s> <#K#> k_1 ... <#Q#> h_1 ... </s>``. Unlike
    ``BartBatcher``, the concatenated history is first cut from its end to
    ``max_source_length - 3`` tokens, and the source is never left-truncated.
    The target is ``<s> response </s>`` and is returned unshifted as
    ``decoder_input_ids``.

    Args:
        max_source_length: encoder length before batch trimming.
        max_target_length: decoder length including ``<s>`` and ``</s>``.
        knowledge_truncate: max sub-tokens kept per knowledge passage.
        text_truncate: max sub-tokens kept per history turn.
        bart_config: name or path for ``BartTokenizer.from_pretrained``.
        cuda: build tensors on ``'cuda'`` if True, else on ``'cpu'``.
    """

    def __init__(self, max_source_length, max_target_length, knowledge_truncate, text_truncate, bart_config, cuda=True):
        self.max_source_length, self.max_target_length = max_source_length, max_target_length
        self.knowledge_truncate, self.text_truncate = knowledge_truncate, text_truncate

        tok = BartTokenizer.from_pretrained(bart_config, do_lower_case=True)
        self.tokenizer = tok
        # <#K#> marks knowledge passages, <#Q#> marks history turns.
        tok.add_special_tokens({'additional_special_tokens': ['<#K#>', '<#Q#>']})
        self.bos_id, self.pad_id, self.eos_id = tok.bos_token_id, tok.pad_token_id, tok.eos_token_id
        self.know_id = tok.convert_tokens_to_ids('<#K#>')
        self.conv_id = tok.convert_tokens_to_ids('<#Q#>')
        self.prefix = ""
        self.device = torch.device('cuda' if cuda else 'cpu')

    def encode_line(self, line, max_length, pad_to_max_length=True, return_tensors="pt"):
        """Tokenize one string as a batch of one, with ``add_prefix_space=True``."""
        padding = "max_length" if pad_to_max_length else 'do_not_pad'
        return self.tokenizer([line], max_length=max_length, padding=padding, truncation=True,
                              return_tensors=return_tensors, add_prefix_space=True)

    def tokenize(self, text, max_length=128):
        """Return the ids of ``text`` with the leading ``<s>`` and trailing ``</s>`` removed."""
        encoded = self.encode_line(text, max_length=max_length, pad_to_max_length=False)
        ids = encoded['input_ids'].squeeze().tolist()
        return ids[1:-1]

    def trim_batch(self, input_ids, attention_mask=None):
        """Remove columns that are populated exclusively by pad_token_id"""
        keep = input_ids.ne(self.pad_id).any(dim=0)
        trimmed = input_ids[:, keep]
        return trimmed if attention_mask is None else (trimmed, attention_mask[:, keep])

    def _pack_source(self, know, hist):
        """Return the unpadded encoder ids ``<s> knowledge history </s>`` for one example."""
        history_ids = []
        for turn in hist:
            history_ids.extend([self.conv_id] + self.tokenize(turn, max_length=999)[:self.text_truncate])
        history_ids = history_ids[:self.max_source_length - 3]

        knowledge_ids = []
        for passage in know:
            piece = [self.know_id] + self.tokenize(passage, max_length=999)[:self.knowledge_truncate]
            room = self.max_source_length - len(knowledge_ids) - len(history_ids) - 2
            if len(piece) >= room:
                # Exactly fills or overflows the remaining room: keep what fits, stop.
                knowledge_ids.extend(piece[:room])
                break
            knowledge_ids.extend(piece)
        return [self.bos_id] + knowledge_ids + history_ids + [self.eos_id]

    def __call__(self, knowledge_list, history_list, user_list, response_list):
        """Build encoder/decoder inputs and masks for a batch.

        ``user_list`` is accepted for interface parity and not used. All-pad
        columns are trimmed from both the source and the target tensors.
        """
        src_rows, src_masks, tgt_rows, tgt_masks = [], [], [], []
        for know, hist, resp in zip(knowledge_list, history_list, response_list):
            src_ids = self._pack_source(know, hist)
            tgt_ids = [self.bos_id] + self.tokenize(resp, max_length=999)[:self.max_target_length - 2] + [self.eos_id]

            src_pad = self.max_source_length - len(src_ids)
            tgt_pad = self.max_target_length - len(tgt_ids)
            src_masks.append([1] * len(src_ids) + [0] * src_pad)
            src_rows.append(src_ids + [self.pad_id] * src_pad)
            tgt_masks.append([1] * len(tgt_ids) + [0] * tgt_pad)
            tgt_rows.append(tgt_ids + [self.pad_id] * tgt_pad)

        def as_tensor(rows):
            return torch.tensor(rows, device=self.device, dtype=torch.long)

        src_mask_t, src_ids_t = as_tensor(src_masks), as_tensor(src_rows)
        tgt_mask_t, tgt_ids_t = as_tensor(tgt_masks), as_tensor(tgt_rows)

        tgt_ids_t, tgt_mask_t = self.trim_batch(tgt_ids_t, attention_mask=tgt_mask_t)
        src_ids_t, src_mask_t = self.trim_batch(src_ids_t, attention_mask=src_mask_t)

        return {
            "input_ids": src_ids_t,
            "attention_mask": src_mask_t,
            "decoder_input_ids": tgt_ids_t,
            "decoder_attention_mask": tgt_mask_t,
        }


##################################################################################################################


class KGH5Dataset(Dataset):
    """HDF5-backed knowledge-grounded dataset that picks one candidate per row.

    Each row of the ``src``, ``tgt`` and ``check`` datasets holds several
    candidates joined by ``'[SEP]'``. ``__getitem__`` scores every detokenized
    ``check`` candidate against the first response in ``tgt`` and returns the
    ``(src, tgt, check)`` candidate triple at the best-scoring position.
    """

    def __init__(self, data_path, metric='f1'):
        # metric: 'f1' (f1_metric), 'b1' / 'b2' (element 0 / 1 of bleu_metric's
        # result). Any other value raises NotImplementedError in __getitem__.
        self._metric = metric
        h5 = h5py.File(data_path, 'r')
        self._src, self._tgt, self._check = h5['src'], h5['tgt'], h5['check']
        self._n_data = len(self._src)

    def __len__(self):
        return self._n_data

    def __getitem__(self, i):
        src_row = self._src[i]
        tgt_row = self._tgt[i]
        check_row = self._check[i]

        candidates = [detokenize(piece) for piece in check_row.split('[SEP]')]
        # Reference: first '[SEP]' entry of tgt, up to the first tab.
        reference = detokenize(tgt_row.split('[SEP]')[0].strip().split('\t')[0])

        if self._metric == 'f1':
            scores = [f1_metric([c], [reference]) for c in candidates]
        elif self._metric in ('b1', 'b2'):
            bleu_idx = 0 if self._metric == 'b1' else 1
            scores = [bleu_metric([c], [reference])[bleu_idx] for c in candidates]
        else:
            raise NotImplementedError

        best = np.argmax(scores)
        return tuple(row.split('[SEP]')[best].strip() for row in (src_row, tgt_row, check_row))

class KGH5DynamicDataset(Dataset):
    """HDF5-backed dataset returning the raw ``(src, tgt, check)`` row unchanged.

    Unlike ``KGH5Dataset`` no candidate is selected here; the ``'[SEP]'``-joined
    strings are handed to the caller as stored.
    """

    def __init__(self, data_path):
        h5 = h5py.File(data_path, 'r')
        self._src = h5['src']
        self._tgt = h5['tgt']
        self._check = h5['check']
        self._n_data = len(self._src)

    def __len__(self):
        return self._n_data

    def __getitem__(self, i):
        return tuple(column[i] for column in (self._src, self._tgt, self._check))

class KGH5KSDataset(Dataset):
    """HDF5-backed dataset yielding raw ``(src, tgt)`` pairs."""

    def __init__(self, data_path):
        h5 = h5py.File(data_path, 'r')
        self._src = h5['src']
        self._tgt = h5['tgt']
        self._n_data = len(self._src)

    def __len__(self):
        return self._n_data

    def __getitem__(self, i):
        return tuple(column[i] for column in (self._src, self._tgt))

class KGH5TestDataset(Dataset):
    """HDF5-backed evaluation dataset yielding raw ``(src, tgt)`` pairs."""

    def __init__(self, data_path):
        h5 = h5py.File(data_path, 'r')
        self._src = h5['src']
        self._tgt = h5['tgt']
        self._n_data = len(self._src)

    def __len__(self):
        return self._n_data

    def __getitem__(self, i):
        pair = (self._src[i], self._tgt[i])
        return pair


class KGH5TestKSDataset(Dataset):
    """HDF5-backed dataset exposing only ``src``, returned as a 1-tuple."""

    def __init__(self, data_path):
        h5 = h5py.File(data_path, 'r')
        self._src = h5['src']
        self._n_data = len(self._src)

    def __len__(self):
        return self._n_data

    def __getitem__(self, i):
        # A 1-tuple keeps the sample shape consistent with the other datasets.
        return self._src[i],


def get_batch_loader(dataset, collate_fn, batch_size=2, num_workers=0, is_test=True):
    """Build a ``DataLoader`` over *dataset*.

    Args:
        dataset: any map-style dataset accepted by ``DataLoader``.
        collate_fn: function turning a list of samples into a batch.
        batch_size: samples per batch.
        num_workers: ``DataLoader`` worker processes.
        is_test: when True, return the plain, unshuffled loader (one pass).
            When False, the loader shuffles and is wrapped in an endless
            iterator for training.

    Returns:
        The ``DataLoader`` itself when ``is_test`` is True, otherwise an
        infinite iterator of batches (use ``next()``).
    """
    loader = DataLoader(
        dataset, batch_size=batch_size,
        shuffle=(not is_test), num_workers=num_workers, collate_fn=collate_fn,
    )
    if is_test:
        return loader
    return _repeat_loader(loader)


def _repeat_loader(loader):
    """Yield batches from *loader* forever, restarting it after every pass.

    The loader is re-iterated on each pass, so every epoch gets a fresh shuffle
    and fresh ``collate_fn`` calls. ``itertools.cycle`` behaves differently: it
    saves the first pass's batches and replays them, freezing the order and
    holding every batch in memory.
    """
    while True:
        produced = False
        for batch in loader:
            produced = True
            yield batch
        if not produced:
            # An empty loader would otherwise spin forever without yielding.
            return


# =============================== BART_FULL =================================== #

# Module-level whitespace tokenizer, needed when loading (unpickling) the TF-IDF model.
def space_tokenize(text):
    """Split *text* on runs of whitespace, dropping leading/trailing blanks."""
    tokens = text.split()
    return tokens

class SegmentPreprocessor(object):
    """Segment a constituency parse by its similarity to a knowledge text.

    Two strategies are provided:

    * ``segmentation``: cosine similarity between weighted means of rows of
      the BART checkpoint's ``model.shared.weight`` embedding matrix. Per-token
      weights come from a pickled TF-IDF model (``feature`` is ``'idf'``,
      ``'tfidf'``, or anything else for uniform weights).
    * ``segmentation_v2``: lexical ``_precision_score`` after removing the
      500 lowest-IDF vocabulary terms.

    Both walk the tree top-down. A child subtree is emitted as one segment when
    its score exceeds the threshold or it has at most ``min_len`` words.
    Otherwise its own children are examined.
    """

    _OOV_INDEX = 99999999  # vocab index used for words missing from the TF-IDF vocabulary

    def __init__(self, min_segment_len, threshold, tfidf_path, transformer_path, feature='idf'):
        self.min_segment_len = min_segment_len
        self.threshold = threshold

        # NOTE(review): pickle.load can execute arbitrary code -- only point
        # tfidf_path at trusted files.
        with open(tfidf_path, 'rb') as f:
            self.tfidf = pickle.load(f)
        self.idf = self.tfidf.idf_
        feature_names = self._feature_names(self.tfidf)
        self.vocab = {name: i for i, name in enumerate(feature_names)}
        # The 500 lowest-IDF (most document-frequent) terms serve as stop words.
        # A set keeps the membership test in remove_stop_words O(1) and matches
        # SegmentPreprocessor_v2 / _v3.
        self.stop_words = set(sorted(feature_names, key=lambda x: self.idf[self.vocab[x]])[:500])

        weight = torch.load(os.path.join(transformer_path, 'pytorch_model.bin'), map_location='cpu')['model.shared.weight']
        self.embeddings = torch.nn.Embedding(weight.size(0), weight.size(1))
        self.embeddings.weight.data[:] = weight

        self.tokenizer = BartTokenizer.from_pretrained(transformer_path, do_lower_case=True)
        # Byte-level BPE marker on a subword that starts a new word.
        self.special_char = 'Ġ'

        self.cos = torch.nn.CosineSimilarity(dim=0)
        self.feature = feature

    @staticmethod
    def _feature_names(vectorizer):
        """Return the vectorizer's vocabulary terms as a list of ``str``.

        scikit-learn 1.0 added ``get_feature_names_out`` and 1.2 removed
        ``get_feature_names``. Both are supported so old pickles keep loading.
        """
        getter = getattr(vectorizer, 'get_feature_names_out', None)
        if getter is not None:
            return [str(name) for name in getter()]
        return list(vectorizer.get_feature_names())

    def _vocab_index(self, pieces):
        """Vocabulary index of the word spelled by subword *pieces*, or ``_OOV_INDEX``."""
        word = self.tokenizer.convert_tokens_to_string(pieces).strip()
        return self.vocab.get(word, self._OOV_INDEX)

    def get_weights(self, text):
        """Tokenize *text* and give each subword a weight from its whole word.

        Subwords are grouped into words: a new word starts at each subword
        prefixed with ``special_char``. Every subword of a word shares the
        word's vocabulary index.

        Returns:
            ``(ids, weights)``. ``ids`` are the token ids with the first and
            last ids dropped. ``weights`` holds one entry per subword:
            the word's TF-IDF value in *text* (``feature='tfidf'``), its IDF
            (``feature='idf'``), or 1.0 otherwise. Words outside the
            vocabulary get weight 0 in the first two modes.

        NOTE(review): ``skip_special_tokens=True`` may yield fewer subwords
        than ``ids`` if *text* contains special tokens, misaligning weights.
        """
        ids = self.tokenizer(text, add_prefix_space=True)['input_ids'][1:-1]
        subwords = self.tokenizer.convert_ids_to_tokens(ids, skip_special_tokens=True)

        segment_index = []
        word = []
        for w in subwords:
            if w.startswith(self.special_char):
                if word:
                    segment_index.extend([self._vocab_index(word)] * len(word))
                word = [w]
            else:
                word.append(w)
        if word:
            segment_index.extend([self._vocab_index(word)] * len(word))

        if self.feature == 'tfidf':
            segment_tfidf = self.tfidf.transform([text])
            indices_to_data = {i: d for i, d in zip(segment_tfidf.indices, segment_tfidf.data)}
            weights = [indices_to_data.get(i, 0) for i in segment_index]
        elif self.feature == 'idf':
            weights = [self.idf[i] if i < len(self.idf) else 0 for i in segment_index]
        else:
            weights = [1. for i in segment_index]
        return ids, weights

    def get_embed(self, ids):
        """Look up embedding rows for a list of token ids; returns ``[len(ids), dim]``."""
        ids = torch.tensor(ids, dtype=torch.long)
        embed = self.embeddings(ids.unsqueeze(0)).squeeze(0)
        return embed

    def sim(self, segment_embed, segment_weights, knowledge_embed, knowledge_weights):
        """Cosine similarity of two weighted-mean embeddings, as a Python float.

        Each ``*_embed`` is ``[seq_len, dim]`` and each ``*_weights`` has
        ``seq_len`` entries. Rows are scaled by their weight and averaged.
        """
        segment_weights = torch.tensor(segment_weights)
        weighted_segment_embed = segment_embed * segment_weights.unsqueeze(1)
        segment_embed = torch.mean(weighted_segment_embed, dim=0)  # [dim]

        knowledge_weights = torch.tensor(knowledge_weights)
        weighted_knowledge_embed = knowledge_embed * knowledge_weights.unsqueeze(1)
        knowledge_embed = torch.mean(weighted_knowledge_embed, dim=0)  # [dim]

        similarity = self.cos(segment_embed, knowledge_embed)
        return similarity.item()

    def get_segments(self, tree, knowledge_embed, knowledge_weights, min_len=3, threshold=0.4):
        """Recursively split *tree* by embedding similarity to the knowledge.

        Returns a list of ``(leaves, similarity)`` pairs in left-to-right order.
        """
        ret = []
        for child in tree:
            text = ' '.join(child.leaves())
            text_ids, text_weights = self.get_weights(text)
            text_embed = self.get_embed(text_ids)
            similarity = self.sim(text_embed, text_weights, knowledge_embed, knowledge_weights)
            if similarity > threshold or len(text.split()) <= min_len:
                ret.append((child.leaves(), similarity))
            else:
                ret.extend(self.get_segments(child, knowledge_embed, knowledge_weights, min_len=min_len, threshold=threshold))
        return ret

    def segmentation(self, tree, knowledge):
        """Embedding-based segmentation of a ``ROOT`` tree with a single child."""
        assert tree.label() == 'ROOT' and len(tree) == 1
        tree_to_iter = tree[0]
        knowledge_ids, knowledge_weights = self.get_weights(knowledge)
        knowledge_embed = self.get_embed(knowledge_ids)
        segments = self.get_segments(tree_to_iter, knowledge_embed, knowledge_weights, min_len=self.min_segment_len, threshold=self.threshold)
        return segments

    def remove_stop_words(self, text):
        """Return *text* with stop words dropped and whitespace collapsed."""
        ret = [w for w in text.strip().split() if w not in self.stop_words]
        return ' '.join(ret)

    def get_segments_v2(self, tree, knowledge, min_len=3, threshold=0.4):
        """Lexical counterpart of ``get_segments`` using ``_precision_score``.

        Returns a list of ``(leaves, precision)`` pairs.
        """
        # knowledge is the same at every node, so clean it once up front.
        return self._segments_by_precision(tree, self.remove_stop_words(knowledge), min_len, threshold)

    def _segments_by_precision(self, tree, clean_knowledge, min_len, threshold):
        """Recursive worker for ``get_segments_v2``; *clean_knowledge* is already stop-word-free."""
        ret = []
        for child in tree:
            text = ' '.join(child.leaves())
            prec = _precision_score(self.remove_stop_words(text), [clean_knowledge])
            if prec > threshold or len(text.split()) <= min_len:
                ret.append((child.leaves(), prec))
            else:
                ret.extend(self._segments_by_precision(child, clean_knowledge, min_len, threshold))
        return ret

    def segmentation_v2(self, tree, knowledge):
        """Lexical segmentation of a ``ROOT`` tree with a single child."""
        assert tree.label() == 'ROOT' and len(tree) == 1
        tree_to_iter = tree[0]
        segments = self.get_segments_v2(tree_to_iter, knowledge, min_len=self.min_segment_len, threshold=self.threshold)
        return segments

class SegmentPreprocessor_v2(object):
    """Lexical tree segmenter whose stop words are read from a word list file.

    ``stop_words_path`` is a UTF-8 file with one ``word<TAB>score`` entry per
    line. The first ``stop_words_size`` distinct words become stop words; the
    score column is ignored and blank lines are skipped.

    ``segmentation`` walks the tree top-down. A child subtree is emitted as one
    segment when its ``_precision_score`` against the knowledge (both
    stop-word-free) exceeds ``threshold``, or when it has at most
    ``min_segment_len`` words. Otherwise its children are examined.
    """

    def __init__(self, min_segment_len, threshold, stop_words_path, stop_words_size):
        self.min_segment_len = min_segment_len
        self.threshold = threshold

        self.stop_words = set()
        with open(stop_words_path, encoding='utf-8') as f:
            for line in f:  # stream instead of readlines()
                # Check the limit before adding so stop_words_size <= 0 yields
                # an empty set instead of one stray word.
                if len(self.stop_words) >= stop_words_size:
                    break
                line = line.strip()
                if not line:
                    continue
                word = line.split('\t', 1)[0]
                self.stop_words.add(word)

    def remove_stop_words(self, text):
        """Return *text* with stop words dropped and whitespace collapsed."""
        ret = [w for w in text.strip().split() if w not in self.stop_words]
        return ' '.join(ret)

    def get_segments_v2(self, tree, knowledge, min_len=3, threshold=0.4):
        """Recursively split *tree*; returns a list of ``(leaves, precision)`` pairs."""
        # knowledge is the same at every node, so clean it once up front.
        return self._segments_by_precision(tree, self.remove_stop_words(knowledge), min_len, threshold)

    def _segments_by_precision(self, tree, clean_knowledge, min_len, threshold):
        """Recursive worker for ``get_segments_v2``; *clean_knowledge* is already stop-word-free."""
        ret = []
        for child in tree:
            text = ' '.join(child.leaves())
            prec = _precision_score(self.remove_stop_words(text), [clean_knowledge])
            if prec > threshold or len(text.split()) <= min_len:
                ret.append((child.leaves(), prec))
            else:
                ret.extend(self._segments_by_precision(child, clean_knowledge, min_len, threshold))
        return ret

    def segmentation(self, tree, knowledge):
        """Segment a ``ROOT`` tree with exactly one child against *knowledge*."""
        assert tree.label() == 'ROOT' and len(tree) == 1
        tree_to_iter = tree[0]
        segments = self.get_segments_v2(tree_to_iter, knowledge, min_len=self.min_segment_len, threshold=self.threshold)
        return segments

class SegmentPreprocessor_v3(object):
    """Lexical tree segmenter using NLTK's English stop-word list.

    ``segmentation`` walks the tree top-down. A child subtree is emitted as one
    segment when its ``_precision_score`` against the knowledge (both
    stop-word-free) exceeds ``threshold``, or when it has at most
    ``min_segment_len`` words. Otherwise its children are examined.
    """

    def __init__(self, min_segment_len, threshold):
        self.min_segment_len = min_segment_len
        self.threshold = threshold
        self.stop_words = set(stopwords.words('english'))

    def remove_stop_words(self, text):
        """Return *text* with stop words dropped and whitespace collapsed."""
        ret = [w for w in text.strip().split() if w not in self.stop_words]
        return ' '.join(ret)

    def get_segments_v2(self, tree, knowledge, min_len=3, threshold=0.4):
        """Recursively split *tree*; returns a list of ``(leaves, precision)`` pairs."""
        # knowledge is the same at every node, so clean it once up front.
        return self._segments_by_precision(tree, self.remove_stop_words(knowledge), min_len, threshold)

    def _segments_by_precision(self, tree, clean_knowledge, min_len, threshold):
        """Recursive worker for ``get_segments_v2``; *clean_knowledge* is already stop-word-free."""
        ret = []
        for child in tree:
            text = ' '.join(child.leaves())
            prec = _precision_score(self.remove_stop_words(text), [clean_knowledge])
            if prec > threshold or len(text.split()) <= min_len:
                ret.append((child.leaves(), prec))
            else:
                ret.extend(self._segments_by_precision(child, clean_knowledge, min_len, threshold))
        return ret

    def segmentation(self, tree, knowledge):
        """Segment a ``ROOT`` tree with exactly one child against *knowledge*."""
        assert tree.label() == 'ROOT' and len(tree) == 1
        tree_to_iter = tree[0]
        segments = self.get_segments_v2(tree_to_iter, knowledge, min_len=self.min_segment_len, threshold=self.threshold)
        return segments

class RedditDataset(Dataset):
    """HDF5-backed dialogue dataset yielding ``(history, knowledge, tree, check)``.

    ``knowledge`` and ``check`` are returned as stored. The original comments
    describe them as top-10 candidate lists; ``BartFullBatcher`` splits them on
    ``'[SEP]'``.
    """

    def __init__(self, data_path):
        h5 = h5py.File(data_path, 'r')
        self._history = h5['history']
        self._response = h5['response']  # opened but not returned by __getitem__
        self._knowledge = h5['knowledge']  # top-10
        self._check = h5['check']  # top-10
        self._tree = h5['tree']
        self._n_data = len(self._history)

    def __len__(self):
        return self._n_data

    def __getitem__(self, i):
        return self._history[i], self._knowledge[i], self._tree[i], self._check[i]

    @staticmethod
    def collate_fn(batch):
        """Transpose a list of 4-field samples into four parallel lists."""
        return tuple([sample[field] for sample in batch] for field in range(4))

class RedditSentiDataset(Dataset):
    """``RedditDataset`` variant that also yields a per-row ``segment_dict`` entry.

    Samples are ``(history, knowledge, tree, check, segment_dict)``, each read
    as stored from the HDF5 file.
    """

    def __init__(self, data_path):
        h5 = h5py.File(data_path, 'r')
        self._history = h5['history']
        self._response = h5['response']  # opened but not returned by __getitem__
        self._knowledge = h5['knowledge']  # top-10
        self._check = h5['check']  # top-10
        self._tree = h5['tree']
        self._segment_dict = h5['segment_dict']
        self._n_data = len(self._history)

    def __len__(self):
        return self._n_data

    def __getitem__(self, i):
        return (
            self._history[i],
            self._knowledge[i],
            self._tree[i],
            self._check[i],
            self._segment_dict[i],
        )

    @staticmethod
    def collate_fn(batch):
        """Transpose a list of 5-field samples into five parallel lists."""
        return tuple([sample[field] for sample in batch] for field in range(5))

class WizardDataset_v2(Dataset):
    """HDF5-backed dataset whose knowledge pieces are re-ordered by score.

    Each ``knowledge`` row holds pieces joined by ``'<#K#>'``. Each ``score``
    row holds tab-separated floats, one per piece. ``__getitem__`` returns
    ``(history, knowledge, response)`` with the pieces sorted by descending
    score and re-joined with ``' <#K#> '``.
    """

    def __init__(self, data_path):
        reader = h5py.File(data_path, 'r')
        self._history = reader['history']
        self._response = reader['response']
        self._knowledge = reader['knowledge']
        self._score = reader['score']
        self._n_data = len(self._history)

    def __len__(self):
        return self._n_data

    def __getitem__(self, i):
        history = self._history[i]
        response = self._response[i]
        pieces = [k.strip() for k in self._knowledge[i].strip().split('<#K#>')]
        scores = [float(s) for s in self._score[i].strip().split('\t')]
        if len(scores) < len(pieces):
            raise ValueError('sample {}: {} knowledge pieces but only {} scores'.format(
                i, len(pieces), len(scores)))
        # Sort positions rather than texts, so duplicate knowledge strings keep
        # their own scores. sorted() is stable even with reverse=True, so
        # equal scores keep their stored order.
        order = sorted(range(len(pieces)), key=lambda j: scores[j], reverse=True)
        knowledge = ' <#K#> '.join(pieces[j] for j in order)
        return history, knowledge, response

    @staticmethod
    def collate_fn(batch):
        """Transpose ``(history, knowledge, response)`` samples into three lists."""
        history_list = [item[0] for item in batch]
        knowledge_list = [item[1] for item in batch]
        response_list = [item[2] for item in batch]
        return history_list, knowledge_list, response_list

class WizardDataset_v3(Dataset):
    """HDF5-backed dataset returning knowledge pieces in their stored order.

    Each ``knowledge`` row holds pieces joined by ``'<#K#>'``. They are
    whitespace-normalized and re-joined with ``' <#K#> '``; unlike
    ``WizardDataset_v2`` they are not re-ordered.
    """

    def __init__(self, data_path):
        h5 = h5py.File(data_path, 'r')
        self._history = h5['history']
        self._response = h5['response']
        self._knowledge = h5['knowledge']
        self._score = h5['score']
        self._n_data = len(self._history)

    def __len__(self):
        return self._n_data

    def __getitem__(self, i):
        history, response = self._history[i], self._response[i]
        raw_knowledge, raw_score = self._knowledge[i], self._score[i]
        pieces = [piece.strip() for piece in raw_knowledge.strip().split('<#K#>')]
        # Scores are parsed but unused; float() still raises on malformed rows.
        _ = [float(s) for s in raw_score.strip().split('\t')]
        return history, ' <#K#> '.join(pieces), response

    @staticmethod
    def collate_fn(batch):
        """Transpose ``(history, knowledge, response)`` samples into three lists."""
        return tuple([sample[field] for sample in batch] for field in range(3))


class BartFullBatcher(object):
    def __init__(
            self, min_segment_len, split_threshold, know_threshold, copy_threshold, stop_words_path, stop_words_size,
            merge, percentage, drop_know, infilling, mlm, random_know, bleu_percent, reverse, use_sep, mix_know, add_prefix_space, test_knowledge_truncate, test_knowledge_num,
            max_source_length, max_target_length, text_truncate, knowledge_truncate, bart_config, full_knowledge_attn=False, cuda=True
    ):
        """Store batching options and load the BART tokenizer.

        Segmentation of training targets is delegated to a
        ``SegmentPreprocessor_v2`` built from ``min_segment_len``,
        ``split_threshold``, ``stop_words_path`` and ``stop_words_size``.
        The other options are read by ``__call__``. For example,
        ``know_threshold`` labels a segment as knowledge when its score reaches
        it. ``drop_know`` / ``mix_know`` / ``random_know`` are probabilities
        for perturbing the knowledge input, and ``mlm`` is the masking
        probability handed to ``DataCollatorForLanguageModeling``.
        ``copy_threshold`` is stored but only referenced by commented-out code
        in ``__call__``.
        """
        # Target-side segmentation settings.
        self.know_threshold = know_threshold
        self.copy_threshold = copy_threshold
        self.merge = merge
        self.percentage = percentage
        self.preprocessor = SegmentPreprocessor_v2(
            min_segment_len=min_segment_len,
            threshold=split_threshold,
            stop_words_path=stop_words_path,
            stop_words_size=stop_words_size,
        )

        # Length limits (in tokens).
        self.max_source_length = max_source_length
        self.max_target_length = max_target_length
        self.text_truncate = text_truncate
        self.knowledge_truncate = knowledge_truncate
        self.full_knowledge_attn = full_knowledge_attn

        # Tokenizer, extended with knowledge / conversation marker tokens.
        tokenizer = BartTokenizer.from_pretrained(bart_config, do_lower_case=True)
        self.tokenizer = tokenizer
        tokenizer.add_special_tokens({'additional_special_tokens': ['<#K#>', '<#Q#>']})
        self.bos_id = tokenizer.bos_token_id
        self.pad_id = tokenizer.pad_token_id
        self.eos_id = tokenizer.eos_token_id
        self.know_id = tokenizer.convert_tokens_to_ids('<#K#>')
        self.conv_id = tokenizer.convert_tokens_to_ids('<#Q#>')
        self.prefix = ""
        self.device = torch.device('cuda' if cuda else 'cpu')

        # Data-augmentation and input-layout switches used by __call__.
        self.drop_know = drop_know
        self.infilling = infilling
        self.mask_id = tokenizer.mask_token_id
        self.mlm = mlm
        self.random_know = random_know
        self.bleu_percent = bleu_percent
        self.reverse = reverse
        self.use_sep = use_sep
        self.mix_know = mix_know
        self.add_prefix_space = add_prefix_space
        self.test_knowledge_truncate = test_knowledge_truncate
        self.test_knowledge_num = test_knowledge_num
        self.data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=self.mlm)


    def tokenize(self, text, add_prefix_space=True):
        return self.tokenizer(text, add_prefix_space=add_prefix_space)['input_ids'][1:-1]

    def trim_batch(self, input_ids, attention_mask=None):
        """Remove columns that are populated exclusively by pad_token_id"""
        keep_column_mask = input_ids.ne(self.pad_id).any(dim=0)
        if attention_mask is None:
            return input_ids[:, keep_column_mask]
        else:
            return (input_ids[:, keep_column_mask], attention_mask[:, keep_column_mask])

    def __call__(self, history_list, knowledge_list, tree_list=None, check_list=None, training=True):
        if training:
            source_input_list, source_mask_list, history_mask_list, knowledge_mask_list = [], [], [], []
            target_input_list, label_list, label_m_list, label_z_list = [], [], [], []
            module_ratios = []
            bleu_list = []
            for hist, know, tree, check in zip(history_list, knowledge_list, tree_list, check_list):
                ##############################################################
                know = know.strip().split('[SEP]')
                check = check.strip().split('[SEP]')
                if self.random_know > 0:
                    chosen = random.choice(list(range(len(check))))
                    # chosen = random.choice([0, 1, 2])
                    if random.random() >= self.random_know:
                        chosen = 0
                else:
                    chosen = 0
                # second_check = check[1].strip()
                second_check = check[0].strip() # todo: del
                know = know[chosen].strip() # chosen == 0 achieve best performance !!!
                check = check[chosen].strip()
                ##############################################################

                hist = hist.strip().split('<#Q#>')
                hist = [h.strip() for h in hist]
                know = know.strip().split('<#K#>')
                know = [k.strip() for k in know]

                if self.mix_know > 0.0:
                    know_new = know.copy()
                    know_new[random.choice(list(range(len(know))))] = second_check
                    if random.random() < self.mix_know:
                        know = know_new.copy()

                # random.shuffle(know)  # todo: check

                ##########################
                ######## Decoder #########
                ##########################
                # get segments
                segments = self.preprocessor.segmentation(Tree.fromstring(tree.strip()), check)
                # tokenize
                target_input = []
                target_m = []
                target_z = []
                num_context, num_knowledge = 0, 0
                for seg in segments:
                    text, score = seg[0], seg[1]
                    text = ' '.join(text)
                    if self.add_prefix_space or len(target_input) == 0:
                        ids = self.tokenize(text)
                    else:
                        ids = self.tokenize(text, add_prefix_space=False)
                    target_input.extend(ids)
                    target_m.extend([0] * (len(ids) - 1) + [1])
                    if score < self.know_threshold:
                        target_z.extend([0] * len(ids)) # context
                        num_context += 1
                    else:
                        target_z.extend([1] * len(ids))
                        num_knowledge += 1
                    # elif score < self.copy_threshold:
                    #     target_z.extend([1] * len(ids)) # knowledge understanding
                    # else:
                    #     target_z.extend([2] * len(ids)) # knowledge copy
                module_ratios.append(num_context / (num_knowledge + 0.00001))
                bleu_list.append(bleu_metric([check], [' '.join(Tree.fromstring(tree.strip()).leaves())])[1])

                if self.merge:
                    for t in range(len(target_m) - 1):
                        if target_m[t] == 1 and target_z[t] == target_z[t+1]:
                            target_m[t] = 0

                target_input = target_input[:self.max_target_length-2]
                target_m = target_m[:self.max_target_length-2]
                target_z = target_z[:self.max_target_length-2]

                label = target_input.copy() + [self.eos_id]
                target_input = [self.bos_id] + target_input.copy()
                label_m = target_m[:-1] + [0, 1] # the token before [EOS] should generate 0, because it is not natural to treat [EOS] as a new segment
                label_z = target_z + target_z[-1:] # similarly, [EOS] and the last token belong to the same segment, and are generated by same module

                #########################
                ######## Encoder ########
                #########################
                history_input = []
                for h in hist:
                    history_input += ([self.conv_id] + self.tokenize(h))
                history_input = history_input[-self.text_truncate:]

                special_tokens = 4 if self.use_sep else 2

                knowledge_input = []
                for k in know:
                    tmp = [self.know_id] + self.tokenize(k)[:self.knowledge_truncate]
                    if k.strip() == check.strip() and self.drop_know > 0:
                        if random.random() < self.drop_know:
                            if self.infilling:
                                tmp = [self.know_id] + [self.mask_id]
                            else:
                                tmp = [self.know_id] + [self.mask_id] * (len(tmp) - 1)
                    if len(knowledge_input) + len(tmp) + len(history_input) + special_tokens < self.max_source_length:
                        knowledge_input += tmp
                    elif len(knowledge_input) + len(tmp) + len(history_input) + special_tokens == self.max_source_length:
                        knowledge_input += tmp
                        break
                    else:
                        tmp_truncate = self.max_source_length - len(knowledge_input) - len(history_input) - special_tokens
                        knowledge_input += tmp[:tmp_truncate]
                        break

                if self.use_sep:
                    source_input = [self.bos_id] + knowledge_input + [self.eos_id] + [self.eos_id] + history_input + [self.eos_id]
                    source_mask = [1] * len(source_input)
                    history_mask = [0] * (len(knowledge_input) + 3) + [1] * len(history_input) + [0]  # for context module
                    knowledge_mask = [0] + [1] * len(knowledge_input) + [0] * (len(history_input) + 3)  # for knowledge module
                else:
                    if self.reverse:
                        source_input = [self.bos_id] + history_input + knowledge_input + [self.eos_id]
                        source_mask = [1] * len(source_input)
                        history_mask = [0] + [1] * len(history_input) + [0] * (len(knowledge_input) + 1)
                        knowledge_mask = [0] * (len(history_input) + 1) + [1] * len(knowledge_input) + [0]
                    else:
                        source_input = [self.bos_id] + knowledge_input + history_input + [self.eos_id]
                        source_mask = [1] * len(source_input)
                        history_mask = [0] * (len(knowledge_input) + 1) + [1] * len(history_input) + [0] # for context module
                        knowledge_mask = [0] + [1] * len(knowledge_input) + [0] * (len(history_input) + 1) # for knowledge module

                #########################
                ######## PADDING ########
                #########################
                source_input = source_input + [self.pad_id] * (self.max_source_length - len(source_input))
                source_mask = source_mask + [0] * (self.max_source_length - len(source_mask))
                history_mask = history_mask + [0] * (self.max_source_length - len(history_mask))
                knowledge_mask = knowledge_mask + [0] * (self.max_source_length - len(knowledge_mask))
                target_input = target_input + [self.pad_id] * (self.max_target_length - len(target_input))
                label = label + [self.pad_id] * (self.max_target_length - len(label))
                label_m = label_m + [0] * (self.max_target_length - len(label_m))
                label_z = label_z + [0] * (self.max_target_length - len(label_z))

                source_input_list.append(source_input)
                source_mask_list.append(source_mask)
                history_mask_list.append(history_mask)
                knowledge_mask_list.append(knowledge_mask)
                target_input_list.append(target_input)
                label_list.append(label)
                label_m_list.append(label_m)
                label_z_list.append(label_z)

            if self.bleu_percent < 1:
                selected = sorted(range(len(bleu_list)), key=lambda x: bleu_list[x], reverse=True)[:int(len(bleu_list) * self.bleu_percent)]
            else:
                selected = sorted(range(len(module_ratios)), key=lambda x: module_ratios[x])[:int(len(module_ratios) * self.percentage)]
            # if self.bleu_percent < 1:
            #     selected = sorted(selected, key=lambda x: bleu_list[x], reverse=True)[:int(len(selected) * self.bleu_percent)]
            source_input_list = [x for idx, x in enumerate(source_input_list) if idx in selected]
            source_mask_list = [x for idx, x in enumerate(source_mask_list) if idx in selected]
            history_mask_list = [x for idx, x in enumerate(history_mask_list) if idx in selected]
            knowledge_mask_list = [x for idx, x in enumerate(knowledge_mask_list) if idx in selected]
            target_input_list = [x for idx, x in enumerate(target_input_list) if idx in selected]
            label_list = [x for idx, x in enumerate(label_list) if idx in selected]
            label_m_list = [x for idx, x in enumerate(label_m_list) if idx in selected]
            label_z_list = [x for idx, x in enumerate(label_z_list) if idx in selected]

            source_input_list = torch.tensor(source_input_list, device=self.device, dtype=torch.long)
            source_mask_list = torch.tensor(source_mask_list, device=self.device, dtype=torch.long)
            history_mask_list = torch.tensor(history_mask_list, device=self.device, dtype=torch.long)
            knowledge_mask_list = torch.tensor(knowledge_mask_list, device=self.device, dtype=torch.long)
            target_input_list = torch.tensor(target_input_list, device=self.device, dtype=torch.long)
            label_list = torch.tensor(label_list, device=self.device, dtype=torch.long)
            label_m_list = torch.tensor(label_m_list, device=self.device, dtype=torch.long)
            label_z_list = torch.tensor(label_z_list, device=self.device, dtype=torch.long)

            source_input_list, source_mask_list = self.trim_batch(source_input_list, attention_mask=source_mask_list)
            history_mask_list = history_mask_list[:, :source_mask_list.size(1)]
            knowledge_mask_list = knowledge_mask_list[:, :source_mask_list.size(1)]

            target_input_list = self.trim_batch(target_input_list)
            label_list = label_list[:, :target_input_list.size(1)]
            label_m_list = label_m_list[:, :target_input_list.size(1)]
            label_z_list = label_z_list[:, :target_input_list.size(1)]

            if self.mlm > 0:
                source_input_list, _ = self.data_collator.mask_tokens(inputs=source_input_list.to('cpu'))
                source_input_list = source_input_list.to(self.device)

            return {
                "input_ids": source_input_list, # encoder
                "attention_mask": source_mask_list, # encoder
                "decoder_input_ids": target_input_list, # decoder
                "history_attention_mask": history_mask_list, # context module
                "knowledge_attention_mask": source_mask_list.clone() if self.full_knowledge_attn else knowledge_mask_list, # knowledge module
                "labels": label_list.contiguous(),
                "labels_m": label_m_list,
                "labels_z": label_z_list,
                "use_cache": False,
            }
        else:
            source_input_list, source_mask_list, history_mask_list, knowledge_mask_list = [], [], [], []
            for hist, know in zip(history_list, knowledge_list):
                hist = hist.strip().split('<#Q#>')
                hist = [h.strip() for h in hist]
                know = know.strip().split('<#K#>')
                know = [k.strip() for k in know]

                history_input = []
                for h in hist:
                    history_input += ([self.conv_id] + self.tokenize(h))
                history_input = history_input[-self.text_truncate:]

                special_tokens = 4 if self.use_sep else 2
                knowledge_input = []
                for k in know[:self.test_knowledge_num]:
                    tmp = [self.know_id] + self.tokenize(k)[:self.test_knowledge_truncate]
                    if len(knowledge_input) + len(tmp) + len(history_input) + special_tokens < self.max_source_length:
                        knowledge_input += tmp
                    elif len(knowledge_input) + len(tmp) + len(history_input) + special_tokens == self.max_source_length:
                        knowledge_input += tmp
                        break
                    else:
                        tmp_truncate = self.max_source_length - len(knowledge_input) - len(history_input) - special_tokens
                        knowledge_input += tmp[:tmp_truncate]
                        break

                if self.use_sep:
                    source_input = [self.bos_id] + knowledge_input + [self.eos_id] + [self.eos_id] + history_input + [self.eos_id]
                    source_mask = [1] * len(source_input)
                    history_mask = [0] * (len(knowledge_input) + 3) + [1] * len(history_input) + [0]
                    knowledge_mask = [0] + [1] * len(knowledge_input) + [0] * (len(history_input) + 3)
                else:
                    if self.reverse:
                        source_input = [self.bos_id] + history_input + knowledge_input + [self.eos_id]
                        source_mask = [1] * len(source_input)
                        history_mask = [0] + [1] * len(history_input) + [0] * (len(knowledge_input) + 1)
                        knowledge_mask = [0] * (len(history_input) + 1) + [1] * len(knowledge_input) + [0]
                    else:
                        source_input = [self.bos_id] + knowledge_input + history_input + [self.eos_id]
                        source_mask = [1] * len(source_input)
                        history_mask = [0] * (len(knowledge_input) + 1) + [1] * len(history_input) + [0]  # for context module
                        knowledge_mask = [0] + [1] * len(knowledge_input) + [0] * (len(history_input) + 1)  # for knowledge module

                source_input = source_input + [self.pad_id] * (self.max_source_length - len(source_input))
                source_mask = source_mask + [0] * (self.max_source_length - len(source_mask))
                history_mask = history_mask + [0] * (self.max_source_length - len(history_mask))
                knowledge_mask = knowledge_mask + [0] * (self.max_source_length - len(knowledge_mask))

                source_input_list.append(source_input)
                source_mask_list.append(source_mask)
                history_mask_list.append(history_mask)
                knowledge_mask_list.append(knowledge_mask)

            source_input_list = torch.tensor(source_input_list, device=self.device, dtype=torch.long)
            source_mask_list = torch.tensor(source_mask_list, device=self.device, dtype=torch.long)
            history_mask_list = torch.tensor(history_mask_list, device=self.device, dtype=torch.long)
            knowledge_mask_list = torch.tensor(knowledge_mask_list, device=self.device, dtype=torch.long)

            source_input_list, source_mask_list = self.trim_batch(source_input_list, attention_mask=source_mask_list)
            history_mask_list = history_mask_list[:, :source_mask_list.size(1)]
            knowledge_mask_list = knowledge_mask_list[:, :source_mask_list.size(1)]

            return {
                "input_ids": source_input_list,  # encoder
                "attention_mask": source_mask_list,  # encoder
                "history_attention_mask": history_mask_list,  # context module
                "knowledge_attention_mask": source_mask_list.clone() if self.full_knowledge_attn else knowledge_mask_list,  # knowledge module
                "decoder_start_token_id": self.bos_id,
                "use_cache": True,
            }


class BartSentiBatcher(object):
    """Build BART encoder/decoder batches with segment-level module and sentiment labels.

    Encoder input layout: ``[BOS] knowledge history [EOS]`` right-padded to
    ``max_source_length``.
    - Each knowledge sentence is prefixed with the ``<#K#>`` id.
    - Each history turn is prefixed with the ``<#Q#>`` id.
    - ``history_attention_mask`` / ``knowledge_attention_mask`` cover only the history /
      knowledge tokens respectively.

    In training mode the response parse tree is split into segments by
    ``SegmentPreprocessor_v2.segmentation``. Each response token then receives:

    * ``labels_m``: 1 on the last token of a segment, 0 elsewhere.
    * ``labels_z``:
      - 0 for segments whose score is below ``know_threshold`` (context).
      - Otherwise 2 if the segment's sentiment score (looked up in ``segment_dict``)
        is ``>= 1 - senti_threshold``, 3 if it is ``<= senti_threshold``, and 1 for
        any other knowledge segment.

    The training batch is then filtered in two stages:
    1. Keep the ``percentage`` fraction of examples with the lowest
       context/knowledge segment ratio.
    2. Among those, keep the ``senti_percentage`` fraction with the lowest
       sentiment intensity.
    """

    def __init__(
            self, min_segment_len, split_threshold, know_threshold, stop_words_path, stop_words_size,
            percentage, test_knowledge_truncate, test_knowledge_num, senti_threshold, senti_percentage,
            max_source_length, max_target_length, text_truncate, knowledge_truncate, bart_config, cuda=True
    ):
        self.know_threshold = know_threshold
        self.percentage = percentage
        self.preprocessor = SegmentPreprocessor_v2(
            min_segment_len=min_segment_len, threshold=split_threshold, stop_words_path=stop_words_path, stop_words_size=stop_words_size
        )

        self.max_source_length = max_source_length
        self.max_target_length = max_target_length
        self.text_truncate = text_truncate
        self.knowledge_truncate = knowledge_truncate

        self.tokenizer = BartTokenizer.from_pretrained(bart_config, do_lower_case=True)

        SPECIAL_TOKENS_DICT = {'additional_special_tokens': ['<#K#>', '<#Q#>']}
        self.tokenizer.add_special_tokens(SPECIAL_TOKENS_DICT)
        self.bos_id = self.tokenizer.bos_token_id
        self.pad_id = self.tokenizer.pad_token_id
        self.eos_id = self.tokenizer.eos_token_id
        self.know_id = self.tokenizer.convert_tokens_to_ids('<#K#>')
        self.conv_id = self.tokenizer.convert_tokens_to_ids('<#Q#>')
        self.prefix = ""
        self.device = torch.device('cuda' if cuda else 'cpu')

        self.senti_threshold = senti_threshold
        self.senti_percentage = senti_percentage
        self.test_knowledge_truncate = test_knowledge_truncate
        self.test_knowledge_num = test_knowledge_num

    def tokenize(self, text, add_prefix_space=True):
        """Return BART token ids for *text*, without the leading <s> and trailing </s>."""
        return self.tokenizer(text, add_prefix_space=add_prefix_space)['input_ids'][1:-1]

    def trim_batch(self, input_ids, attention_mask=None):
        """Remove columns that are populated exclusively by pad_token_id"""
        keep_column_mask = input_ids.ne(self.pad_id).any(dim=0)
        if attention_mask is None:
            return input_ids[:, keep_column_mask]
        else:
            return (input_ids[:, keep_column_mask], attention_mask[:, keep_column_mask])

    def _encode_source(self, hist, know, knowledge_truncate, knowledge_num=None):
        """Build the padded encoder lists for one example.

        Args:
            hist: dialogue history, turns separated by ``'<#Q#>'``.
            know: knowledge, sentences separated by ``'<#K#>'``.
            knowledge_truncate: max tokens kept per knowledge sentence.
            knowledge_num: max number of knowledge sentences considered (``None`` = all).

        Returns:
            ``(source_input, source_mask, history_mask, knowledge_mask)``, each a list
            of length ``max_source_length`` (when the history fits the budget).
        """
        turns = [h.strip() for h in hist.strip().split('<#Q#>')]
        sentences = [k.strip() for k in know.strip().split('<#K#>')]

        # Keep the most recent history tokens.
        history_input = []
        for h in turns:
            history_input += ([self.conv_id] + self.tokenize(h))
        history_input = history_input[-self.text_truncate:]

        # Fill the remaining room ([BOS] and [EOS] take 2 slots) with knowledge;
        # the sentence that reaches the limit is truncated and filling stops.
        knowledge_input = []
        for k in sentences[:knowledge_num]:
            tmp = [self.know_id] + self.tokenize(k)[:knowledge_truncate]
            budget = self.max_source_length - len(history_input) - 2 - len(knowledge_input)
            if len(tmp) < budget:
                knowledge_input += tmp
            else:
                knowledge_input += tmp[:budget]
                break

        source_input = [self.bos_id] + knowledge_input + history_input + [self.eos_id]
        source_mask = [1] * len(source_input)
        history_mask = [0] * (len(knowledge_input) + 1) + [1] * len(history_input) + [0]  # for context module
        knowledge_mask = [0] + [1] * len(knowledge_input) + [0] * (len(history_input) + 1)  # for knowledge module

        padding = self.max_source_length - len(source_input)
        return (
            source_input + [self.pad_id] * padding,
            source_mask + [0] * padding,
            history_mask + [0] * padding,
            knowledge_mask + [0] * padding,
        )

    def _encode_target(self, tree, check, segment_dict):
        """Tokenize one response and derive its decoder inputs and per-token labels.

        Args:
            tree: bracketed parse-tree string of the response.
            check: value forwarded to ``SegmentPreprocessor_v2.segmentation``.
            segment_dict: dict mapping segment text (space-joined words) to a sentiment
                score (presumably in [0, 1] — see ``senti_threshold`` checks).

        Returns:
            A tuple ``((target_input, label, label_m, label_z), module_ratio, intensity)``:
            - The four lists are padded to ``max_target_length``.
            - ``module_ratio`` is the number of context segments divided by the number
              of knowledge segments.
            - ``intensity`` is the smallest distance of a sentiment segment's score
              from 1 (label 2) or 0 (label 3). It is 1.0 when the response has no
              such segment.
        """
        segments = self.preprocessor.segmentation(Tree.fromstring(tree.strip()), check)

        target_input, target_m, target_z = [], [], []
        num_context, num_knowledge = 0, 0
        intensity = 1.0
        for seg in segments:
            text, score = ' '.join(seg[0]), seg[1]
            ids = self.tokenize(text)
            target_input.extend(ids)
            # Mark the segment's last token as a boundary. A segment that tokenizes to
            # nothing must add nothing, otherwise target_m drifts out of alignment.
            if ids:
                target_m.extend([0] * (len(ids) - 1) + [1])
            if score < self.know_threshold:
                target_z.extend([0] * len(ids))  # context
                num_context += 1
                continue
            senti_score = segment_dict[text]
            if senti_score >= 1 - self.senti_threshold:
                target_z.extend([2] * len(ids))
                intensity = min(intensity, 1 - senti_score)
            elif senti_score <= self.senti_threshold:
                target_z.extend([3] * len(ids))
                intensity = min(intensity, senti_score)
            else:
                target_z.extend([1] * len(ids))
            num_knowledge += 1
        module_ratio = num_context / (num_knowledge + 1e-5)

        # Reserve one slot for [BOS] (decoder input) and one for [EOS] (labels).
        target_input = target_input[:self.max_target_length - 2]
        target_m = target_m[:self.max_target_length - 2]
        target_z = target_z[:self.max_target_length - 2]

        label = target_input + [self.eos_id]
        target_input = [self.bos_id] + target_input
        if target_m:
            # the token before [EOS] should generate 0, because it is not natural to
            # treat [EOS] as a new segment; [EOS] closes the last segment instead
            label_m = target_m[:-1] + [0, 1]
            # [EOS] and the last token belong to the same segment, same module
            label_z = target_z + target_z[-1:]
        else:
            # Empty response: the label is just [EOS], which closes the segment.
            label_m = [1]
            label_z = [0]

        def pad(seq, value):
            return seq + [value] * (self.max_target_length - len(seq))

        target = (pad(target_input, self.pad_id), pad(label, self.pad_id), pad(label_m, 0), pad(label_z, 0))
        return target, module_ratio, intensity

    def _select_indices(self, module_ratios, senti_intensity):
        """Return the set of example indices kept in a training batch.

        Keeps the ``percentage`` fraction with the lowest ``module_ratios`` and then,
        among those, the ``senti_percentage`` fraction with the lowest
        ``senti_intensity``. Each stage keeps at least one example when any are
        available, so a small batch is never reduced to an empty (un-trimmable) tensor.
        """
        selected = sorted(range(len(module_ratios)), key=lambda idx: module_ratios[idx])
        selected = selected[:max(1, int(len(selected) * self.percentage))]
        selected = sorted(selected, key=lambda idx: senti_intensity[idx])
        selected = selected[:max(1, int(len(selected) * self.senti_percentage))]
        return set(selected)

    def _as_tensors(self, rows, num_fields=4):
        """Turn rows of ``num_fields`` int lists into one long tensor per field."""
        return [
            torch.tensor([row[field] for row in rows], device=self.device, dtype=torch.long)
            for field in range(num_fields)
        ]

    def _source_tensors(self, sources):
        """Stack encoder rows, drop all-pad columns, and cut the masks to the same width."""
        source_input, source_mask, history_mask, knowledge_mask = self._as_tensors(sources)
        source_input, source_mask = self.trim_batch(source_input, attention_mask=source_mask)
        width = source_mask.size(1)
        return source_input, source_mask, history_mask[:, :width], knowledge_mask[:, :width]

    def __call__(self, history_list, knowledge_list, tree_list=None, check_list=None, segment_dict_list=None, training=True):
        """Encode one batch.

        Args:
            history_list: per-example dialogue history, turns separated by ``'<#Q#>'``.
            knowledge_list: per-example knowledge, sentences separated by ``'<#K#>'``.
            tree_list: (training only) response parse-tree strings.
            check_list: (training only) values forwarded to the segment preprocessor.
            segment_dict_list: (training only) JSON strings mapping segment text to a
                sentiment score.
            training: build teacher-forcing inputs and labels if True; otherwise build
                encoder inputs only.

        Returns:
            dict of model keyword arguments. In training mode the batch may contain
            fewer rows than the inputs because of knowledge/sentiment filtering.
        """
        if not training:
            sources = [
                self._encode_source(hist, know, self.test_knowledge_truncate, self.test_knowledge_num)
                for hist, know in zip(history_list, knowledge_list)
            ]
            source_input_list, source_mask_list, history_mask_list, knowledge_mask_list = self._source_tensors(sources)
            return {
                "input_ids": source_input_list,  # encoder
                "attention_mask": source_mask_list,  # encoder
                "history_attention_mask": history_mask_list,  # context module
                "knowledge_attention_mask": knowledge_mask_list,  # knowledge module
                "decoder_start_token_id": self.bos_id,
                "use_cache": True,
            }

        sources, targets, module_ratios, senti_intensity = [], [], [], []
        for hist, know, tree, check, segment_dict in zip(history_list, knowledge_list, tree_list, check_list, segment_dict_list):
            target, module_ratio, intensity = self._encode_target(tree, check, json.loads(segment_dict))
            targets.append(target)
            module_ratios.append(module_ratio)
            senti_intensity.append(intensity)
            sources.append(self._encode_source(hist, know, self.knowledge_truncate))

        # Keep only the filtered examples, preserving their original order.
        selected = self._select_indices(module_ratios, senti_intensity)
        sources = [row for idx, row in enumerate(sources) if idx in selected]
        targets = [row for idx, row in enumerate(targets) if idx in selected]

        source_input_list, source_mask_list, history_mask_list, knowledge_mask_list = self._source_tensors(sources)

        target_input_list, label_list, label_m_list, label_z_list = self._as_tensors(targets)
        target_input_list = self.trim_batch(target_input_list)
        width = target_input_list.size(1)
        label_list = label_list[:, :width]
        label_m_list = label_m_list[:, :width]
        label_z_list = label_z_list[:, :width]

        return {
            "input_ids": source_input_list, # encoder
            "attention_mask": source_mask_list, # encoder
            "decoder_input_ids": target_input_list, # decoder
            "history_attention_mask": history_mask_list, # context module
            "knowledge_attention_mask": knowledge_mask_list, # knowledge module
            "labels": label_list.contiguous(),
            "labels_m": label_m_list,
            "labels_z": label_z_list,
            "use_cache": False,
        }



################### Domain Adaptation ##########################
from transformers import DataCollatorForLanguageModeling

class RedditKsDataset_v2(Dataset):
    """Reddit knowledge-selection pairs read from an HDF5 file.

    The file holds parallel ``src`` and ``tgt`` datasets. Only rows whose ``tgt`` line
    ends with the integer label ``1`` are exposed. A ``src`` row has the form
    ``query <#Q2K#> candidate``, where the query's turns are separated by ``<#Q#>``.
    """

    def __init__(self, data_path):
        # The file is not closed here: self._src / self._tgt are indexed later in
        # __getitem__.
        h5 = h5py.File(data_path, 'r')
        self._src = h5['src']
        self._tgt = h5['tgt']
        self.valid_indices = [
            row for row in tqdm(range(len(self._tgt)))
            if int(self._tgt[row].strip().split()[-1]) == 1
        ]
        self._n_data = len(self.valid_indices)

    def preprocess(self, sentence):
        """Drop ``<#bleuN#>`` marker tokens and detokenize; an empty result becomes ``'.'``."""
        markers = {
            '<#bleu0#>', '<#bleu1#>', '<#bleu2#>', '<#bleu3#>', '<#bleu4#>', '<#bleu5#>', '<#bleu6#>',
            '<#bleu7#>', '<#bleu8#>', '<#bleu9#>', '<#bleu10#>', '<#bleu11#>', '<#bleu12#>', '<#bleu13#>',
        }
        kept = [tok for tok in sentence.strip().split() if tok not in markers]
        return detokenize(' '.join(kept) if kept else '.')

    def __len__(self):
        return self._n_data

    def __getitem__(self, i):
        """Return ``(query, candidate)``.

        ``query`` is the preprocessed turns joined by blank lines; ``candidate`` is the
        preprocessed candidate text.
        """
        parts = self._src[self.valid_indices[i]].strip().split('<#Q2K#>')
        turns = parts[0].strip().split('<#Q#>')
        query = '\n\n'.join(self.preprocess(turn.strip()) for turn in turns)
        candidate = self.preprocess(parts[1].strip())
        return query, candidate

    @staticmethod
    def collate_fn(batch):
        """Split a batch of ``(query, candidate)`` pairs into two parallel lists."""
        return [pair[0] for pair in batch], [pair[1] for pair in batch]


class BertBatcher_v2(object):
    """Build masked-LM batches laid out as ``[CLS] query [SEP] candidate [SEP]`` for BERT.

    Sequence construction:
    - Each query turn is prefixed with ``<#Q#>`` and cut to ``text_truncate`` tokens.
    - The candidate is prefixed with ``<#K#>`` and cut to ``knowledge_truncate`` tokens.
    - The query keeps its leading tokens so that the sequence fits in ``block_size``.

    Masking and MLM labels come from ``DataCollatorForLanguageModeling.mask_tokens``.
    """

    def __init__(self, block_size, knowledge_truncate, text_truncate, bert_config, mlm_prob=0.15, cuda=True):
        self.block_size = block_size
        self.knowledge_truncate = knowledge_truncate
        self.text_truncate = text_truncate

        tokenizer = BertTokenizer.from_pretrained(bert_config, do_lower_case=True)
        tokenizer.add_special_tokens({'additional_special_tokens': ['<#K#>', '<#Q#>']})
        self.tokenizer = tokenizer

        self.data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=mlm_prob)

        self.cls_id, self.sep_id, self.pad_id = tokenizer.cls_token_id, tokenizer.sep_token_id, tokenizer.pad_token_id
        self.know_id = tokenizer.convert_tokens_to_ids('<#K#>')
        self.conv_id = tokenizer.convert_tokens_to_ids('<#Q#>')

        self.device = torch.device('cuda' if cuda else 'cpu')

    def tokenize(self, text):
        """Return BERT token ids for *text* without the surrounding [CLS] / [SEP]."""
        encoded = self.tokenizer(text, add_special_tokens=True)
        return encoded['input_ids'][1:-1]

    def __call__(self, query_list, candidate_list):
        """Encode and randomly mask one batch; returns tensors on ``self.device``."""
        rows, masks, segments = [], [], []
        for query, cand in zip(query_list, candidate_list):
            context = []
            for turn in query:
                context.extend([self.conv_id] + self.tokenize(turn)[:self.text_truncate])
            knowledge = [self.know_id] + self.tokenize(cand)[:self.knowledge_truncate]
            context = context[:self.block_size - 3 - len(knowledge)]

            ids = [self.cls_id, *context, self.sep_id, *knowledge, self.sep_id]
            filler = self.block_size - len(ids)
            rows.append(ids + [self.pad_id] * filler)
            masks.append([1] * len(ids) + [0] * filler)
            # segment A = [CLS] query [SEP], segment B = candidate [SEP]
            segments.append([0] * (len(context) + 2) + [1] * (len(knowledge) + 1) + [0] * filler)

        masked_ids, mlm_labels = self.data_collator.mask_tokens(inputs=torch.tensor(rows, dtype=torch.long))
        return {
            "input_ids": masked_ids.to(self.device),
            "attention_mask": torch.tensor(masks, device=self.device, dtype=torch.long),
            "token_type_ids": torch.tensor(segments, device=self.device, dtype=torch.long),
            "labels": mlm_labels.to(self.device)
        }


############################ Roberta ######################################
from transformers import RobertaTokenizer

class RobertaBatcher(object):
    """Encode (dialogue query, knowledge candidate) pairs for a RoBERTa classifier.

    Each pair is laid out as ``<s> query </s></s> candidate </s>`` and right-padded to
    ``block_size``.
    - Every query turn is prefixed with ``<#Q#>`` and cut to ``text_truncate`` tokens.
    - The candidate is cut to ``knowledge_truncate`` tokens.
    - When the pair does not fit, the oldest query tokens are dropped first.
    """

    def __init__(self, block_size, knowledge_truncate, text_truncate, bert_config, cuda=True):
        self.block_size = block_size
        self.knowledge_truncate = knowledge_truncate
        self.text_truncate = text_truncate

        self.tokenizer = RobertaTokenizer.from_pretrained(bert_config, do_lower_case=True)

        SPECIAL_TOKENS_DICT = {'additional_special_tokens': ['<#K#>', '<#Q#>']}
        self.tokenizer.add_special_tokens(SPECIAL_TOKENS_DICT)
        self.cls_id = self.tokenizer.cls_token_id
        self.sep_id = self.tokenizer.sep_token_id
        self.pad_id = self.tokenizer.pad_token_id
        self.know_id = self.tokenizer.convert_tokens_to_ids('<#K#>')
        self.conv_id = self.tokenizer.convert_tokens_to_ids('<#Q#>')

        self.device = torch.device('cuda' if cuda else 'cpu')

    def tokenize(self, text):
        """Return RoBERTa token ids for *text* without the surrounding <s> / </s>."""
        return self.tokenizer(text, add_prefix_space=True)['input_ids'][1:-1]

    def _tokenize_pair(self, query, cand):
        """Tokenize one example into ``(query_ids, candidate_ids)`` before length fitting."""
        query_input = []
        for q in query:
            query_input += ([self.conv_id] + self.tokenize(q)[:self.text_truncate])
        candidate_input = self.tokenize(cand)[:self.knowledge_truncate]
        return query_input, candidate_input

    def _pack(self, query_input, candidate_input):
        """Lay out ``<s> query </s></s> candidate </s>`` and pad to ``block_size``.

        Returns:
            ``(input_ids, attention_mask)`` lists of length ``block_size`` (for
            ``block_size >= 4``).
        """
        room = max(self.block_size - 4, 0)  # 4 = <s> + </s></s> + </s>
        candidate_input = candidate_input[:room]
        budget = room - len(candidate_input)
        # Keep the most recent query tokens. The explicit check matters:
        # query_input[-0:] would return the WHOLE list instead of nothing.
        query_input = query_input[-budget:] if budget > 0 else []
        input_ids = [self.cls_id] + query_input + [self.sep_id, self.sep_id] + candidate_input + [self.sep_id]
        padding = self.block_size - len(input_ids)
        return input_ids + [self.pad_id] * padding, [1] * len(input_ids) + [0] * padding

    def __call__(self, query_list, candidate_list, label_list=None, training=True):
        """Encode one batch.

        Args:
            query_list: per-example iterables of query turns.
            candidate_list: per-example candidate strings.
            label_list: per-example int labels; required when ``training`` is True.
            training: when True, also return ``labels``.

        Returns:
            dict with ``input_ids`` and ``attention_mask`` (plus ``labels`` in training),
            all long tensors on ``self.device``.
        """
        if training:
            examples = zip(query_list, candidate_list, label_list)
        else:
            examples = ((query, cand, None) for query, cand in zip(query_list, candidate_list))

        input_ids_list, attention_mask_list, labels = [], [], []
        for query, cand, label in examples:
            input_ids, attention_mask = self._pack(*self._tokenize_pair(query, cand))
            input_ids_list.append(input_ids)
            attention_mask_list.append(attention_mask)
            labels.append(label)

        batch = {
            "input_ids": torch.tensor(input_ids_list, device=self.device, dtype=torch.long),
            "attention_mask": torch.tensor(attention_mask_list, device=self.device, dtype=torch.long),
        }
        if training:
            batch["labels"] = torch.tensor(labels, device=self.device, dtype=torch.long)
        return batch

class RobertaBatcher_V2(object):
    """Build masked-LM batches of ``<s> query </s></s> candidate </s>`` for RoBERTa.

    Sequence construction:
    - Every query turn is prefixed with ``<#Q#>`` and cut to ``text_truncate`` tokens.
    - The candidate is cut to ``knowledge_truncate`` tokens.
    - When the pair does not fit in ``block_size``, the oldest query tokens are
      dropped first.

    Masking and MLM labels come from ``DataCollatorForLanguageModeling.mask_tokens``.
    """

    def __init__(self, block_size, knowledge_truncate, text_truncate, bert_config, mlm_prob=0.15, cuda=True):
        self.block_size = block_size
        self.knowledge_truncate = knowledge_truncate
        self.text_truncate = text_truncate

        self.tokenizer = RobertaTokenizer.from_pretrained(bert_config, do_lower_case=True)

        SPECIAL_TOKENS_DICT = {'additional_special_tokens': ['<#K#>', '<#Q#>']}
        self.tokenizer.add_special_tokens(SPECIAL_TOKENS_DICT)

        self.data_collator = DataCollatorForLanguageModeling(tokenizer=self.tokenizer, mlm_probability=mlm_prob)

        self.cls_id = self.tokenizer.cls_token_id
        self.sep_id = self.tokenizer.sep_token_id
        self.pad_id = self.tokenizer.pad_token_id
        self.know_id = self.tokenizer.convert_tokens_to_ids('<#K#>')
        self.conv_id = self.tokenizer.convert_tokens_to_ids('<#Q#>')

        self.device = torch.device('cuda' if cuda else 'cpu')

    def tokenize(self, text):
        """Return RoBERTa token ids for *text* without the surrounding <s> / </s>."""
        return self.tokenizer(text, add_prefix_space=True)['input_ids'][1:-1]

    def _tokenize_pair(self, query, cand):
        """Tokenize one example into ``(query_ids, candidate_ids)`` before length fitting."""
        query_input = []
        for q in query:
            query_input += ([self.conv_id] + self.tokenize(q)[:self.text_truncate])
        candidate_input = self.tokenize(cand)[:self.knowledge_truncate]
        return query_input, candidate_input

    def _pack(self, query_input, candidate_input):
        """Lay out ``<s> query </s></s> candidate </s>`` and pad to ``block_size``.

        Returns:
            ``(input_ids, attention_mask)`` lists of length ``block_size`` (for
            ``block_size >= 4``).
        """
        room = max(self.block_size - 4, 0)  # 4 = <s> + </s></s> + </s>
        candidate_input = candidate_input[:room]
        budget = room - len(candidate_input)
        # Keep the most recent query tokens. The explicit check matters:
        # query_input[-0:] would return the WHOLE list instead of nothing.
        query_input = query_input[-budget:] if budget > 0 else []
        input_ids = [self.cls_id] + query_input + [self.sep_id, self.sep_id] + candidate_input + [self.sep_id]
        padding = self.block_size - len(input_ids)
        return input_ids + [self.pad_id] * padding, [1] * len(input_ids) + [0] * padding

    def __call__(self, query_list, candidate_list):
        """Encode and randomly mask one batch; returns tensors on ``self.device``."""
        input_ids_list, attention_mask_list = [], []
        for query, cand in zip(query_list, candidate_list):
            input_ids, attention_mask = self._pack(*self._tokenize_pair(query, cand))
            input_ids_list.append(input_ids)
            attention_mask_list.append(attention_mask)

        input_ids_list = torch.tensor(input_ids_list, dtype=torch.long)
        input_ids_list, labels_list = self.data_collator.mask_tokens(inputs=input_ids_list)
        return {
            "input_ids": input_ids_list.to(self.device),
            "attention_mask": torch.tensor(attention_mask_list, device=self.device, dtype=torch.long),
            "labels": labels_list.to(self.device)
        }

from data_generator.wizard_generator_v2 import data_generator
from data_generator.cmudog_generator_v2 import data_generator as cmudog_generator
from tqdm import tqdm
class WizardKsDataset_V2(Dataset):
    """Wizard-of-Wikipedia knowledge-selection examples, one per (history, candidate).

    For every dialogue produced by ``data_generator``, each knowledge sentence becomes
    its own example. The first sentence is labelled 1 and all others 0.
    """

    def __init__(self, data_path):

        self._query = []
        self._candidate = []
        self._label = []
        for history_line, _response_line, knowledge_line in tqdm(data_generator(data_path)):
            turns = [h.strip() for h in history_line.strip().split('<#Q#>')]
            for rank, sentence in enumerate(knowledge_line.strip().split('<#K#>')):
                self._label.append(1 if rank == 0 else 0)
                self._query.append(list(turns))
                self._candidate.append(sentence.strip())
        self._n_data = len(self._query)

    def __len__(self):
        return self._n_data

    def __getitem__(self, i):
        """Return ``(turns joined by blank lines, candidate str, int label)``."""
        return '\n\n'.join(self._query[i]), self._candidate[i], self._label[i]

    @staticmethod
    def collate_fn(batch):
        """Split a batch of ``(query, candidate, label)`` triples into three lists."""
        queries = [entry[0] for entry in batch]
        candidates = [entry[1] for entry in batch]
        labels = [entry[2] for entry in batch]
        return queries, candidates, labels

class CMUDoGKsDataset_V2(Dataset):
    """Knowledge-selection dataset over the tokenised CMU_DoG files.

    Reads ``src-<split>-tokenized.txt``, ``tgt-<split>-tokenized.txt`` and
    ``knl-<split>-tokenized.txt`` under ``data_path`` through
    ``cmudog_generator``.  Each (history, knowledge candidate) pair becomes one
    example; the first '<#K#>'-separated candidate of a line is labelled 1 and
    the others 0 (presumably the first one is the gold knowledge -- verify
    against the generator).

    Args:
        data_path: directory that holds the tokenised files.
        split: split tag used in the file names; defaults to ``'test'``.
    """

    def __init__(self, data_path, split='test'):

        self._query = []
        self._candidate = []
        self._label = []
        src_path, tgt_path, knl_path = self._split_paths(data_path, split)
        for history_line, response_line, knowledge_line in tqdm(cmudog_generator(src_path, tgt_path, knl_path)):
            knowledge_list = [k.strip() for k in knowledge_line.strip().split('<#K#>')]
            history = [h.strip() for h in history_line.strip().split('<#Q#>')]
            for i, k in enumerate(knowledge_list):
                self._label.append(1 if i == 0 else 0)
                # each example owns its own copy of the history turns
                self._query.append(history.copy())
                self._candidate.append(k)
        self._n_data = len(self._query)

    @staticmethod
    def _split_paths(data_path, split='test'):
        """Return the (source, target, knowledge) file paths for ``split``."""
        return tuple(
            os.path.join(data_path, '{}-{}-tokenized.txt'.format(prefix, split))
            for prefix in ('src', 'tgt', 'knl')
        )

    def __len__(self):
        return self._n_data

    def __getitem__(self, i):
        """Return ``(history_text, candidate, label)``; turns are joined with blank lines."""
        query = self._query[i]  # list[str]
        candidate = self._candidate[i]  # str
        label = self._label[i]  # int
        return '\n\n'.join(query), candidate, label

    @staticmethod
    def collate_fn(batch):
        """Transpose a list of ``(query, candidate, label)`` items into three parallel lists."""
        query_list = [item[0] for item in batch]
        candidate_list = [item[1] for item in batch]
        label_list = [item[2] for item in batch]
        return query_list, candidate_list, label_list

class BartConvBatcher(object):
    """Batch builder for BART response generation conditioned on dialogue history only.

    The encoder input is ``[BOS] <#Q#> turn_1 <#Q#> turn_2 ... [EOS]``; when the
    history is too long, the most recent tokens are kept.  In training mode the
    decoder is teacher-forced on the response recovered from the leaves of its
    parse tree.

    Args:
        mlm: masking probability for ``DataCollatorForLanguageModeling``; when
            > 0 the training encoder input is randomly masked.
        max_source_length: padded encoder length, including [BOS]/[EOS].
        max_target_length: padded decoder length.
        text_truncate, knowledge_truncate: stored but not used by this batcher.
        bart_config: name or path passed to ``BartTokenizer.from_pretrained``.
        cuda: place output tensors on ``cuda`` instead of ``cpu``.
    """

    def __init__(
            self, mlm, max_source_length, max_target_length, text_truncate, knowledge_truncate, bart_config, cuda=True
    ):

        self.max_source_length = max_source_length
        self.max_target_length = max_target_length
        self.text_truncate = text_truncate
        self.knowledge_truncate = knowledge_truncate

        self.tokenizer = BartTokenizer.from_pretrained(bart_config, do_lower_case=True)

        SPECIAL_TOKENS_DICT = {'additional_special_tokens': ['<#K#>', '<#Q#>']}
        self.tokenizer.add_special_tokens(SPECIAL_TOKENS_DICT)
        self.bos_id = self.tokenizer.bos_token_id
        self.pad_id = self.tokenizer.pad_token_id
        self.eos_id = self.tokenizer.eos_token_id
        self.know_id = self.tokenizer.convert_tokens_to_ids('<#K#>')
        self.conv_id = self.tokenizer.convert_tokens_to_ids('<#Q#>')
        self.prefix = ""
        self.device = torch.device('cuda' if cuda else 'cpu')

        self.mask_id = self.tokenizer.mask_token_id
        self.mlm = mlm
        self.data_collator = DataCollatorForLanguageModeling(tokenizer=self.tokenizer, mlm_probability=self.mlm)

    def tokenize(self, text):
        """Return token ids for ``text`` without the first and last id (the tokenizer's added special tokens)."""
        return self.tokenizer(text, add_prefix_space=True)['input_ids'][1:-1]

    def trim_batch(self, input_ids, attention_mask=None):
        """Remove columns that are populated exclusively by pad_token_id"""
        keep_column_mask = input_ids.ne(self.pad_id).any(dim=0)
        if attention_mask is None:
            return input_ids[:, keep_column_mask]
        else:
            return (input_ids[:, keep_column_mask], attention_mask[:, keep_column_mask])

    def _to_tensor(self, rows):
        """Convert equal-length lists of ids into a long tensor on ``self.device``."""
        return torch.tensor(rows, device=self.device, dtype=torch.long)

    def _encode_source(self, history_line):
        """Return padded ``(encoder_ids, attention_mask)`` for one '<#Q#>'-joined history."""
        history_input = []
        for turn in history_line.strip().split('<#Q#>'):
            history_input += [self.conv_id] + self.tokenize(turn.strip())
        # keep the most recent tokens, leaving room for [BOS] and [EOS]
        history_input = history_input[-(self.max_source_length - 2):]
        source_input = [self.bos_id] + history_input + [self.eos_id]
        pad = self.max_source_length - len(source_input)
        return source_input + [self.pad_id] * pad, [1] * len(source_input) + [0] * pad

    def _encode_target(self, tree):
        """Return padded ``(decoder_input_ids, labels)`` for the response stored as a bracketed parse tree."""
        resp = ' '.join(Tree.fromstring(tree.strip()).leaves())
        tokens = self.tokenize(resp)[:self.max_target_length - 2]
        target_input = [self.bos_id] + tokens
        label = tokens + [self.eos_id]
        # NOTE(review): labels are padded with pad_id (not -100); confirm the model's loss ignores pad_id.
        pad = self.max_target_length - len(target_input)
        return target_input + [self.pad_id] * pad, label + [self.pad_id] * pad

    def _inference_batch(self, history_list):
        """Build encoder-only inputs for generation."""
        encoded = [self._encode_source(hist) for hist in history_list]
        source_input, source_mask = self.trim_batch(
            self._to_tensor([ids for ids, _ in encoded]),
            attention_mask=self._to_tensor([mask for _, mask in encoded]),
        )
        return {
            "input_ids": source_input,  # encoder
            "attention_mask": source_mask,  # encoder
            "decoder_start_token_id": self.bos_id,
            "use_cache": True,
        }

    def __call__(self, history_list, tree_list=None, training=True):
        """Build a batch from '<#Q#>'-joined histories (and response parse trees when training).

        Raises:
            ValueError: if ``training`` is true and ``tree_list`` is missing.
        """
        if not training:
            return self._inference_batch(history_list)
        if tree_list is None:
            raise ValueError('tree_list is required when training=True')

        source_rows, mask_rows, target_rows, label_rows = [], [], [], []
        for hist, tree in zip(history_list, tree_list):
            source_input, source_mask = self._encode_source(hist)
            target_input, label = self._encode_target(tree)
            source_rows.append(source_input)
            mask_rows.append(source_mask)
            target_rows.append(target_input)
            label_rows.append(label)

        source_input, source_mask = self.trim_batch(
            self._to_tensor(source_rows), attention_mask=self._to_tensor(mask_rows))
        target_input = self.trim_batch(self._to_tensor(target_rows))
        # labels are position-aligned with decoder inputs, so trim to the same width
        label = self._to_tensor(label_rows)[:, :target_input.size(1)]

        if self.mlm > 0:
            # masking runs on a CPU copy; the MLM labels returned by the collator are discarded
            source_input, _ = self.data_collator.mask_tokens(inputs=source_input.to('cpu'))
            source_input = source_input.to(self.device)

        return {
            "input_ids": source_input,  # encoder
            "attention_mask": source_mask,  # encoder
            "decoder_input_ids": target_input,  # decoder
            "labels": label.contiguous(),
            "use_cache": False,
        }

class BartPreBatcher(object):
    """Batch builder for BART denoising pre-training on dialogue text.

    The history turns and the response (recovered from its parse tree) are
    tokenised and concatenated without separators, keeping the last
    ``max_source_length - 2`` tokens.  The encoder sees ``[BOS] tokens [EOS]``,
    optionally masked when ``mlm > 0``.  The decoder is teacher-forced to
    reconstruct the same tokens: input ``[BOS] tokens`` and labels
    ``tokens [EOS]``.  All sequences are padded to ``max_source_length``;
    ``max_target_length``, ``text_truncate`` and ``knowledge_truncate`` are
    stored but unused.
    """

    def __init__(
            self, mlm, max_source_length, max_target_length, text_truncate, knowledge_truncate, bart_config, cuda=True
    ):

        self.max_source_length = max_source_length
        self.max_target_length = max_target_length
        self.text_truncate = text_truncate
        self.knowledge_truncate = knowledge_truncate

        self.tokenizer = BartTokenizer.from_pretrained(bart_config, do_lower_case=True)

        SPECIAL_TOKENS_DICT = {'additional_special_tokens': ['<#K#>', '<#Q#>']}
        self.tokenizer.add_special_tokens(SPECIAL_TOKENS_DICT)
        self.bos_id = self.tokenizer.bos_token_id
        self.pad_id = self.tokenizer.pad_token_id
        self.eos_id = self.tokenizer.eos_token_id
        self.know_id = self.tokenizer.convert_tokens_to_ids('<#K#>')
        self.conv_id = self.tokenizer.convert_tokens_to_ids('<#Q#>')
        self.prefix = ""
        self.device = torch.device('cuda' if cuda else 'cpu')

        self.mask_id = self.tokenizer.mask_token_id
        self.mlm = mlm
        self.data_collator = DataCollatorForLanguageModeling(tokenizer=self.tokenizer, mlm_probability=self.mlm)

    def tokenize(self, text):
        """Return token ids for ``text`` without the first and last id (the tokenizer's added special tokens)."""
        return self.tokenizer(text, add_prefix_space=True)['input_ids'][1:-1]

    def trim_batch(self, input_ids, attention_mask=None):
        """Remove columns that are populated exclusively by pad_token_id"""
        keep_column_mask = input_ids.ne(self.pad_id).any(dim=0)
        if attention_mask is None:
            return input_ids[:, keep_column_mask]
        else:
            return (input_ids[:, keep_column_mask], attention_mask[:, keep_column_mask])

    def _build_example(self, history_line, response):
        """Return padded ``(source_input, source_mask, target_input, label)`` for one dialogue."""
        texts = [turn.strip() for turn in history_line.strip().split('<#Q#>')] + [response]
        tokens = [tok for text in texts for tok in self.tokenize(text)]
        tokens = tokens[-(self.max_source_length - 2):]  # leave room for [BOS]/[EOS]

        source_input = [self.bos_id] + tokens + [self.eos_id]
        target_input = [self.bos_id] + tokens
        label = tokens + [self.eos_id]

        src_pad = self.max_source_length - len(source_input)
        tgt_pad = self.max_source_length - len(target_input)
        return (
            source_input + [self.pad_id] * src_pad,
            [1] * len(source_input) + [0] * src_pad,
            target_input + [self.pad_id] * tgt_pad,
            label + [self.pad_id] * tgt_pad,
        )

    def __call__(self, history_list, tree_list):
        """Build a training batch from '<#Q#>'-joined histories and bracketed response parse trees."""
        rows = []
        for hist, tree in zip(history_list, tree_list):
            resp = ' '.join(Tree.fromstring(tree.strip()).leaves())
            rows.append(self._build_example(hist, resp))

        source_input, source_mask, target_input, label = [
            torch.tensor(list(column), device=self.device, dtype=torch.long) for column in zip(*rows)
        ]

        source_input, source_mask = self.trim_batch(source_input, attention_mask=source_mask)
        target_input = self.trim_batch(target_input)
        # labels are position-aligned with decoder inputs, so trim to the same width
        label = label[:, :target_input.size(1)]

        if self.mlm > 0:
            # masking runs on a CPU copy; the MLM labels returned by the collator are discarded
            source_input, _ = self.data_collator.mask_tokens(inputs=source_input.to('cpu'))
            source_input = source_input.to(self.device)

        return {
            "input_ids": source_input,  # encoder
            "attention_mask": source_mask,  # encoder
            "decoder_input_ids": target_input,  # decoder
            "labels": label.contiguous(),
            "use_cache": False,
        }

class BartFinetuneBatcher(object):
    """Batch builder for fine-tuning BART with separate context and knowledge modules.

    Encoder input (before padding)::

        [BOS] <#K#> knowledge_1 <#K#> knowledge_2 ... <#Q#> turn_1 <#Q#> turn_2 ... [EOS]

    The history keeps its last ``text_truncate`` tokens, and the knowledge pieces
    fill whatever room is left up to ``max_source_length``.  Besides the usual
    encoder/decoder tensors, a batch carries:

    * ``history_attention_mask`` / ``knowledge_attention_mask``: 1 over the
      history / knowledge span of the encoder input.  The knowledge mask is the
      full attention mask when ``full_knowledge_attn`` is set.
    * ``labels_m``: 1 on the last token of each response segment.
    * ``labels_z``: 0 for tokens of segments scored below ``know_threshold``
      (context), 1 otherwise (knowledge).

    Response segments come from ``SegmentPreprocessor_v3.segmentation``, applied
    to the response parse tree and the ``check`` string.  In training only the
    ``percentage`` fraction of the batch with the lowest
    context/knowledge segment ratio is kept (at least one example).
    ``copy_threshold`` is stored but not used by the current labelling.
    """

    # Rewrites applied, in this order, to each segment's text before tokenising
    # (e.g. -LRB-/-RRB- parse-tree escapes back to brackets).
    _TEXT_REPLACEMENTS = (
        ('-LRB-', '('),
        ('-RRB-', ')'),
        ('...', '. . .'),
        ('--', '–'),
        ('-LSB- ', '['),
        (' -RSB-', ']'),
    )

    def __init__(
            self, min_segment_len, split_threshold, know_threshold, copy_threshold,
            merge, percentage, test_knowledge_truncate, test_knowledge_num,
            max_source_length, max_target_length, text_truncate, knowledge_truncate, bart_config, full_knowledge_attn=False, cuda=True
    ):
        self.know_threshold = know_threshold
        self.copy_threshold = copy_threshold
        self.merge = merge
        self.percentage = percentage
        self.preprocessor = SegmentPreprocessor_v3(min_segment_len=min_segment_len, threshold=split_threshold)

        self.max_source_length = max_source_length
        self.max_target_length = max_target_length
        self.text_truncate = text_truncate
        self.knowledge_truncate = knowledge_truncate
        self.full_knowledge_attn = full_knowledge_attn

        self.tokenizer = BartTokenizer.from_pretrained(bart_config, do_lower_case=True)

        SPECIAL_TOKENS_DICT = {'additional_special_tokens': ['<#K#>', '<#Q#>']}
        self.tokenizer.add_special_tokens(SPECIAL_TOKENS_DICT)
        self.bos_id = self.tokenizer.bos_token_id
        self.pad_id = self.tokenizer.pad_token_id
        self.eos_id = self.tokenizer.eos_token_id
        self.know_id = self.tokenizer.convert_tokens_to_ids('<#K#>')
        self.conv_id = self.tokenizer.convert_tokens_to_ids('<#Q#>')
        self.prefix = ""
        self.device = torch.device('cuda' if cuda else 'cpu')

        self.mask_id = self.tokenizer.mask_token_id
        self.test_knowledge_truncate = test_knowledge_truncate
        self.test_knowledge_num = test_knowledge_num

    def tokenize(self, text, add_prefix_space=True):
        """Return token ids for ``text`` without the first and last id (the tokenizer's added special tokens)."""
        return self.tokenizer(text, add_prefix_space=add_prefix_space)['input_ids'][1:-1]

    def trim_batch(self, input_ids, attention_mask=None):
        """Remove columns that are populated exclusively by pad_token_id"""
        keep_column_mask = input_ids.ne(self.pad_id).any(dim=0)
        if attention_mask is None:
            return input_ids[:, keep_column_mask]
        else:
            return (input_ids[:, keep_column_mask], attention_mask[:, keep_column_mask])

    def _to_tensor(self, rows):
        """Convert equal-length lists of ids into a long tensor on ``self.device``."""
        return torch.tensor(list(rows), device=self.device, dtype=torch.long)

    @staticmethod
    def _split_knowledge(know):
        """Split a '<#K#>'-joined knowledge string into stripped pieces."""
        return [piece.strip() for piece in know.strip().split('<#K#>')]

    def _encode_history(self, history_line):
        """Tokenise '<#Q#>'-separated turns (each prefixed by <#Q#>) and keep the last ``text_truncate`` tokens."""
        history_input = []
        for turn in history_line.strip().split('<#Q#>'):
            history_input += [self.conv_id] + self.tokenize(turn.strip())
        # Use an explicit start index: ``x[-0:]`` would keep everything, whereas
        # text_truncate == 0 must mean "no history tokens".
        return history_input[max(0, len(history_input) - self.text_truncate):]

    def _pack_knowledge(self, knowledge_pieces, history_len, per_item_truncate):
        """Concatenate <#K#>-prefixed knowledge pieces into the room the history leaves free.

        Each piece is cut to ``per_item_truncate`` tokens; packing stops at the
        first piece that does not fit whole, which is truncated to the remaining room.
        """
        budget = self.max_source_length - history_len - 2  # 2 = [BOS] + [EOS]
        knowledge_input = []
        for piece in knowledge_pieces:
            remaining = budget - len(knowledge_input)
            if remaining <= 0:
                break
            tmp = [self.know_id] + self.tokenize(piece)[:per_item_truncate]
            knowledge_input += tmp[:remaining]
            if len(tmp) >= remaining:
                break
        return knowledge_input

    def _build_source(self, history_line, knowledge_pieces, per_item_truncate):
        """Return padded ``(source_input, source_mask, history_mask, knowledge_mask)`` for one example."""
        history_input = self._encode_history(history_line)
        knowledge_input = self._pack_knowledge(knowledge_pieces, len(history_input), per_item_truncate)

        source_input = [self.bos_id] + knowledge_input + history_input + [self.eos_id]
        source_mask = [1] * len(source_input)
        history_mask = [0] * (len(knowledge_input) + 1) + [1] * len(history_input) + [0]  # for context module
        knowledge_mask = [0] + [1] * len(knowledge_input) + [0] * (len(history_input) + 1)  # for knowledge module

        pad = self.max_source_length - len(source_input)  # all four lists share this length
        return (
            source_input + [self.pad_id] * pad,
            source_mask + [0] * pad,
            history_mask + [0] * pad,
            knowledge_mask + [0] * pad,
        )

    def _build_target(self, tree, check):
        """Segment one response and build its padded decoder-side sequences.

        Returns:
            ``((target_input, label, label_m, label_z), module_ratio)`` where each
            sequence is padded to ``max_target_length`` and ``module_ratio`` is
            ``#context segments / (#knowledge segments + 1e-5)``.

        Raises:
            ValueError: if the tree cannot be parsed or segmented.
        """
        try:
            segments = self.preprocessor.segmentation(Tree.fromstring(tree.strip()), check)
        except Exception as err:
            # fail loudly: continuing would use undefined or stale segments
            raise ValueError('cannot segment response tree: {!r}'.format(tree)) from err

        target_input, target_m, target_z = [], [], []
        num_context, num_knowledge = 0, 0
        for seg in segments:
            text, score = ' '.join(seg[0]), seg[1]
            for old, new in self._TEXT_REPLACEMENTS:
                text = text.replace(old, new)
            ids = self.tokenize(text)
            if not ids:
                # an empty segment would add a boundary flag without a token,
                # misaligning target_m with target_input
                continue
            module = 0 if score < self.know_threshold else 1  # 0 = context, 1 = knowledge
            target_input.extend(ids)
            target_m.extend([0] * (len(ids) - 1) + [1])
            target_z.extend([module] * len(ids))
            if module == 0:
                num_context += 1
            else:
                num_knowledge += 1
        module_ratio = num_context / (num_knowledge + 0.00001)

        if self.merge:
            # drop the boundary between two adjacent context segments
            for t in range(len(target_m) - 1):
                if target_m[t] == 1 and target_z[t] == 0 and target_z[t + 1] == 0:
                    target_m[t] = 0

        keep = self.max_target_length - 2
        target_input, target_m, target_z = target_input[:keep], target_m[:keep], target_z[:keep]

        label = target_input + [self.eos_id]
        target_input = [self.bos_id] + target_input
        # The token before [EOS] gets 0 and [EOS] gets 1: [EOS] closes the last
        # segment instead of starting a new one.  An empty response is just [EOS].
        label_m = target_m[:-1] + [0, 1] if target_m else [1]
        # [EOS] belongs to the last segment and is generated by the same module
        label_z = target_z + target_z[-1:]

        length = self.max_target_length
        return (
            (
                target_input + [self.pad_id] * (length - len(target_input)),
                label + [self.pad_id] * (length - len(label)),
                label_m + [0] * (length - len(label_m)),
                label_z + [0] * (length - len(label_z)),
            ),
            module_ratio,
        )

    def _inference_batch(self, history_list, knowledge_list):
        """Build encoder-only inputs for generation using the test-time knowledge limits."""
        rows = [
            self._build_source(hist, self._split_knowledge(know)[:self.test_knowledge_num], self.test_knowledge_truncate)
            for hist, know in zip(history_list, knowledge_list)
        ]
        source_input, source_mask, history_mask, knowledge_mask = [self._to_tensor(col) for col in zip(*rows)]

        source_input, source_mask = self.trim_batch(source_input, attention_mask=source_mask)
        width = source_mask.size(1)
        history_mask = history_mask[:, :width]
        knowledge_mask = knowledge_mask[:, :width]

        return {
            "input_ids": source_input,  # encoder
            "attention_mask": source_mask,  # encoder
            "history_attention_mask": history_mask,  # context module
            "knowledge_attention_mask": source_mask.clone() if self.full_knowledge_attn else knowledge_mask,  # knowledge module
            "decoder_start_token_id": self.bos_id,
            "use_cache": True,
        }

    def __call__(self, history_list, knowledge_list, tree_list=None, check_list=None, training=True):
        """Build a batch.

        ``history_list`` holds '<#Q#>'-joined histories and ``knowledge_list``
        '<#K#>'-joined knowledge.  In training, ``knowledge_list`` and
        ``check_list`` entries may hold several '[SEP]'-separated candidates, of
        which only the first is used, and ``tree_list`` holds bracketed response
        parse trees.

        Raises:
            ValueError: if training without ``tree_list``/``check_list``, or a
                response tree cannot be segmented.
        """
        if not training:
            return self._inference_batch(history_list, knowledge_list)
        if tree_list is None or check_list is None:
            raise ValueError('tree_list and check_list are required when training=True')

        examples, module_ratios = [], []
        for hist, know, tree, check in zip(history_list, knowledge_list, tree_list, check_list):
            # only the first '[SEP]' candidate is used (per the authors, index 0 performs best)
            know = know.strip().split('[SEP]')[0].strip()
            check = check.strip().split('[SEP]')[0].strip()

            target_fields, ratio = self._build_target(tree, check)
            source_fields = self._build_source(hist, self._split_knowledge(know), self.knowledge_truncate)
            examples.append(source_fields + target_fields)
            module_ratios.append(ratio)

        # Keep the `percentage` share with the lowest context/knowledge ratio; at
        # least one, since int(len * percentage) can be 0 for small batches.
        keep = max(1, int(len(module_ratios) * self.percentage))
        selected = set(sorted(range(len(module_ratios)), key=lambda idx: module_ratios[idx])[:keep])
        kept = [ex for idx, ex in enumerate(examples) if idx in selected]

        (source_input, source_mask, history_mask, knowledge_mask,
         target_input, label, label_m, label_z) = [self._to_tensor(col) for col in zip(*kept)]

        source_input, source_mask = self.trim_batch(source_input, attention_mask=source_mask)
        src_width = source_mask.size(1)
        history_mask = history_mask[:, :src_width]
        knowledge_mask = knowledge_mask[:, :src_width]

        target_input = self.trim_batch(target_input)
        tgt_width = target_input.size(1)
        label = label[:, :tgt_width]
        label_m = label_m[:, :tgt_width]
        label_z = label_z[:, :tgt_width]

        return {
            "input_ids": source_input,  # encoder
            "attention_mask": source_mask,  # encoder
            "decoder_input_ids": target_input,  # decoder
            "history_attention_mask": history_mask,  # context module
            "knowledge_attention_mask": source_mask.clone() if self.full_knowledge_attn else knowledge_mask,  # knowledge module
            "labels": label.contiguous(),
            "labels_m": label_m,
            "labels_z": label_z,
            "use_cache": False,
        }


class BartMLMBatcher(object):
    """Turn one sentence into a single BART denoising example (plain Python lists).

    The encoder input is ``[BOS] tokens [EOS]``, optionally corrupted with
    80/10/10 masking.  The decoder input is ``[BOS] tokens`` and the labels
    ``tokens [EOS]``.  Every sequence is padded to ``max_source_length``;
    ``max_target_length`` is stored but not used.
    """

    def __init__(
            self, max_source_length, max_target_length, bart_config, cuda=True
    ):
        self.max_source_length = max_source_length
        self.max_target_length = max_target_length

        tokenizer = BartTokenizer.from_pretrained(bart_config, do_lower_case=True)
        self.tokenizer = tokenizer
        self.bos_id = tokenizer.bos_token_id
        self.pad_id = tokenizer.pad_token_id
        self.eos_id = tokenizer.eos_token_id
        self.prefix = ""
        self.device = torch.device('cpu' if not cuda else 'cuda')
        self.mask_id = tokenizer.mask_token_id

    def tokenize(self, text):
        """Return token ids for ``text`` without the first and last id (the tokenizer's added special tokens)."""
        encoded = self.tokenizer(text, add_prefix_space=True)
        return encoded['input_ids'][1:-1]

    def __call__(self, sentence, mlm=0.15):
        """Build the example dict for ``sentence``; ``mlm`` is the per-token corruption probability."""
        length = self.max_source_length
        body = self.tokenize(sentence)[:(length - 2)]

        encoder_ids = [self.bos_id, *body, self.eos_id]
        filled = len(encoder_ids)
        encoder_ids.extend([self.pad_id] * (length - filled))
        encoder_mask = [1] * filled + [0] * (length - filled)

        decoder_ids = [self.bos_id, *body]
        decoder_ids.extend([self.pad_id] * (length - len(decoder_ids)))
        targets = [*body, self.eos_id]
        targets.extend([self.pad_id] * (length - len(targets)))

        if mlm > 0:
            protected = (self.pad_id, self.bos_id, self.eos_id)
            for pos, tok in enumerate(encoder_ids):
                # special/padding positions are never corrupted and consume no random draw
                if tok in protected or not random.random() < mlm:
                    continue
                if random.random() < 0.8:
                    encoder_ids[pos] = self.mask_id  # 80%: mask token
                elif random.random() < 0.5:
                    encoder_ids[pos] = random.randint(0, len(self.tokenizer) - 1)  # ~10%: random id
                # remaining ~10%: keep the original token

        return {
            'input_ids': encoder_ids,
            'attention_mask': encoder_mask,
            'decoder_input_ids': decoder_ids,
            'labels': targets,
        }




if __name__ == '__main__':
    knowledge = '''these are sometimes " mercy bookings " intended to get the homeless mentally ill off the street , a warm meal , etc .'''
    segment_1 = "most mentally ill were in hospitalsor homes"
    # segment_1 = "most most mentally ill were in hospitalsor homes"
    segment_2 = "very few homeless"
    segment_3 = "on the street"


    preprocessor = SegmentPreprocessor('debug_data/tfidf.pkl', '/home2/xxx/Data/pretrain-models/facebook/bart-large')
    sim = preprocessor.sim(segment_1, knowledge)