import os
import json


class Tokenizer(object):
    """Convert token sequences into padded sequences of integer ids.

    Id 0 always belongs to the ``'<mask>'`` entry, which is also the default
    padding value used by :meth:`padding`.
    """

    def get_dict(self, seqs):
        """Build a token -> id vocabulary from ``seqs``.

        Ids are assigned in first-occurrence order starting at 1, so the same
        data always produces the same mapping.  (Enumerating a ``set`` would
        tie the ids to hash order, which for string tokens differs from one
        interpreter run to the next.)

        ``'<mask>'`` is always id 0.  Seeding the dictionary with it also
        keeps ids contiguous when that literal token occurs in the data, so
        ``len(dic)`` is always a free id for a later ``'<unk>'`` entry.

        Args:
            seqs: iterable of iterables of hashable tokens.

        Returns:
            dict mapping each distinct token (and ``'<mask>'``) to an int id.
        """
        ret = {'<mask>': 0}
        for seq in seqs:
            for elem in seq:
                if elem not in ret:
                    # len(ret) is the next unused id because ids are dense.
                    ret[elem] = len(ret)
        return ret

    def convert_sequence(self, seqs, dic):
        """Replace every token in ``seqs`` by its id in ``dic``.

        Tokens missing from ``dic`` are mapped to the id of ``'<unk>'``.
        NOTE: ``dic`` is mutated -- the ``'<unk>'`` entry is added with id
        ``len(dic)`` the first time an unknown token is met, so it only
        exists if some converted sequence contained an unknown token.

        Args:
            seqs: iterable of iterables of tokens.
            dic: token -> id mapping (may gain an ``'<unk>'`` key).

        Returns:
            A new list of lists of ids, one inner list per input sequence.
        """
        unk = '<unk>'
        result = []
        for seq in seqs:
            ids = []
            for elem in seq:
                if elem not in dic:
                    if unk not in dic:
                        dic[unk] = len(dic)
                    elem = unk
                ids.append(dic[elem])
            result.append(ids)
        return result

    def padding(self, seqs, el, pad=0):
        """Pad every sequence in ``seqs`` in place up to length ``el``.

        Sequences already at least ``el`` long are left untouched.

        Args:
            seqs: list of lists; each inner list is extended in place.
            el: target length.
            pad: value appended as padding (default 0, the ``'<mask>'`` id).

        Returns:
            ``(seqs, lengths)`` where ``seqs`` is the same (now padded)
            object and ``lengths[i]`` is the unpadded length of sequence
            ``i`` plus one.
        """
        lengths = []
        for seq in seqs:
            # The recorded length counts one element past the real tokens
            # (presumably so the first pad acts as an end marker -- confirm
            # with the consumer of these lengths).
            lengths.append(len(seq) + 1)
            # A non-positive count yields an empty list, i.e. no padding.
            seq.extend([pad] * (el - len(seq)))
        return seqs, lengths

    def initialize_basic(self, X, Y, X_test, Y_test):
        """Build vocabularies from the training data and convert all splits.

        Both vocabularies come from the training sequences only; tokens seen
        only in the test sequences are converted to ``'<unk>'``.

        Returns:
            ``(x_out, y_out, x_test_out, y_test_out, voc, act)`` -- the four
            id-converted splits followed by the input and output
            vocabularies.
        """
        voc = self.get_dict(X)
        act = self.get_dict(Y)

        x_out = self.convert_sequence(X, voc)
        y_out = self.convert_sequence(Y, act)
        x_test_out = self.convert_sequence(X_test, voc)
        y_test_out = self.convert_sequence(Y_test, act)
        return x_out, y_out, x_test_out, y_test_out, voc, act

    def get_maximum_length(self, train, test):
        """Return the longest sequence length across both splits, plus one.

        The extra slot guarantees that :meth:`padding` appends at least one
        pad value to every sequence.  An empty split contributes length 0
        instead of raising ``ValueError``.
        """
        longest = max(
            [0] + [len(x) for x in train] + [len(x) for x in test])
        return longest + 1

    def initialize(self, X, Y, X_test, Y_test):
        """Tokenize and pad the train/test inputs and outputs.

        Args:
            X, Y: training input / output token sequences.
            X_test, Y_test: test input / output token sequences.

        Returns:
            ``(samples, dicts, lengths, maxs)`` where
            ``samples = (X, Y, X_test, Y_test)`` are the padded id sequences,
            ``dicts = (voc, act)`` are the input / output vocabularies,
            ``lengths = (X_len, Y_len, X_test_len, Y_test_len)`` are the
            per-sequence lengths reported by :meth:`padding`, and
            ``maxs = (max_input, max_output)`` are the padded lengths.
        """
        X, Y, X_test, Y_test, voc, act = self.initialize_basic(
            X, Y, X_test, Y_test)
        max_input = self.get_maximum_length(X, X_test)
        max_output = self.get_maximum_length(Y, Y_test)
        X, X_len = self.padding(X, max_input)
        Y, Y_len = self.padding(Y, max_output)
        X_test, X_test_len = self.padding(X_test, max_input)
        Y_test, Y_test_len = self.padding(Y_test, max_output)
        samples = X, Y, X_test, Y_test
        dicts = voc, act
        lengths = X_len, Y_len, X_test_len, Y_test_len
        maxs = max_input, max_output
        return samples, dicts, lengths, maxs


def load_list(fn):
    """Read the text file ``fn`` and return its lines as a list.

    Leading and trailing whitespace (including the line terminator) is
    stripped from every line; blank lines are kept as empty strings.
    """
    with open(fn, 'r') as handle:
        return [raw.strip() for raw in handle]


def load_processed(folder="cfq_data"):
    """Load the preprocessed input and output lines from ``folder``.

    Args:
        folder: directory holding ``input.txt`` and ``output.txt``.
            Defaults to ``"cfq_data"``, resolved against the current
            working directory, which is the location this function
            previously hard-coded.

    Returns:
        ``(inputs, outputs)`` -- two lists of stripped lines, one per file.
    """
    inputs = load_list(os.path.join(folder, 'input.txt'))
    outputs = load_list(os.path.join(folder, 'output.txt'))
    return inputs, outputs


def load_split(fn):
    """Load the JSON document stored in the split file ``fn``.

    The whole file is parsed as one JSON value, so both the single-line
    layout and pretty-printed (multi-line) JSON are accepted.  Validation is
    left to the JSON parser rather than an ``assert``, which would be
    stripped when Python runs with ``-O``.

    Args:
        fn: path to the JSON split file.

    Returns:
        The decoded JSON value.

    Raises:
        ValueError: if the file does not contain valid JSON.
    """
    with open(fn, 'r') as f:
        return json.load(f)


def get_terms(line):
    """Split ``line`` on single spaces and return the non-empty pieces.

    Only the space character separates terms; other whitespace such as tabs
    stays inside a term.  Runs of spaces produce no empty terms.
    """
    # filter(None, ...) drops the empty strings left by consecutive spaces.
    return list(filter(None, line.split(' ')))


def filters(inputs, outputs, index):
    """Select the lines at positions ``index`` and split them into terms.

    Args:
        inputs: indexable collection of input lines.
        outputs: indexable collection of output lines.
        index: positions to keep; it is walked once for ``inputs`` and then
            once more for ``outputs``.

    Returns:
        ``(inputs, outputs)`` -- two lists of term lists, in ``index`` order.
    """
    def tokenized_subset(lines):
        picked = []
        for position in index:
            picked.append(get_terms(lines[position]))
        return picked

    return tokenized_subset(inputs), tokenized_subset(outputs)


def get_cfq_data(split_file):
    """Load the processed data, apply the split, then tokenize and pad it.

    Args:
        split_file: path to a JSON split file providing the ``'trainIdxs'``
            and ``'testIdxs'`` index lists.

    Returns:
        The ``(samples, dicts, lengths, maxs)`` tuple produced by
        ``Tokenizer.initialize`` for the selected train and test examples.
    """
    split = load_split(split_file)
    all_inputs, all_outputs = load_processed()

    # Training rows are selected before test rows, as in the split file keys.
    train_pair = filters(all_inputs, all_outputs, split['trainIdxs'])
    test_pair = filters(all_inputs, all_outputs, split['testIdxs'])

    samples, dicts, lengths, maxs = Tokenizer().initialize(
        train_pair[0], train_pair[1], test_pair[0], test_pair[1])
    return samples, dicts, lengths, maxs


if __name__ == '__main__':
    # Smoke run: build the dataset for the mcd1 split and show the padded
    # input/output lengths (the last element of the returned tuple).
    print(get_cfq_data("../../data/cfq/splits/mcd1.json")[3])
