import codecs
import itertools
import os
import pickle
from unittest import TestCase

import torch
from transformers import AutoConfig

from model.topdown_parser import TransformerParser
from utils.model_loader import load_model
from utils.tree_utils import get_tree_from_merge_trajectory
from utils.vocab_builder import build_vocab, build_vocab_multiprocess, build_vocab_worker


class VocabBuilderTestcase(TestCase):
    def testVocabBuilder(self):
        device = torch.device('cpu')
        if torch.cuda.is_available():
            device = torch.device('cuda:0')
        config = AutoConfig.from_pretrained('data/en_config/fast_r2d2_tran_span.json')
        vocab_path = 'data/en_config/vocab.txt'
        corpus_path = 'data/en_wiki/wiki.span.200.ids'
        parser = TransformerParser(config)
        load_model(parser, 'data/wiki103_bitfm_span_4/parser14.bin', strict=True)
        parser.eval()
        parser.to(device)

        batch_size = 50
        base_vocab = []
        with codecs.open(vocab_path, mode='r', encoding='utf-8') as f_in:
            for token in f_in:
                base_vocab.append(token.strip())
        inputs = []
        with codecs.open(corpus_path, mode='r') as f_in:
            current_batch = 0
            ids_list = []
            atom_span_batch = []
            lines_read = 0
            for _line in f_in:
                lines_read += 1
                parts = _line.strip().split('|')
                token_ids = [int(t_id) for t_id in parts[0].split()]
                ids_list.append(token_ids)
                if len(parts) > 1:
                    spans = parts[1].split(';')
                    atom_spans = []
                    for span in spans:
                        vals = span.split(',')
                        if len(vals) == 2:
                            atom_spans.append([int(vals[0]), int(vals[1])])
                    atom_span_batch.append(atom_spans)
                else:
                    atom_span_batch.append([])
                current_batch += 1
                if current_batch % batch_size == 0:
                    max_token_ids_len = max(map(len, ids_list))
                    attn_mask = []
                    seq_lens = []
                    padding_ids = []
                    for ids in ids_list:
                        seq_lens.append(len(ids))
                        attn_mask.append([1] * len(ids) + \
                                        [0] * (max_token_ids_len - len(ids)))
                        padding_ids.append(ids + [0] * (max_token_ids_len - len(ids)))
                    input_ids = torch.tensor(padding_ids, device=device)
                    attn_mask = torch.tensor(attn_mask, device=device)
                    with torch.no_grad():
                        s_indices = parser(input_ids, attn_mask, atom_spans=atom_span_batch)
                    s_indices = s_indices.cpu().data.numpy()

                    for sent_i in range(s_indices.shape[0]):
                        tree = get_tree_from_merge_trajectory(s_indices[sent_i], seq_lens[sent_i], None)
                        if tree is not None:
                            inputs.append([tree, ids_list[sent_i]])
                    ids_list = []
                    atom_span_batch = []
                if lines_read > 10000:
                    break

        
        with open('data/tmp/input_tree_ids.pkl', mode='wb') as pickle_file:
            pickle.dump(inputs, pickle_file)
        with open('data/tmp/vocab.pkl', mode='wb') as file_out:
            pickle.dump(base_vocab, file_out)
        build_vocab_multiprocess('data/tmp/vocab.pkl', None, ['data/tmp/input_tree_ids.pkl'],
                                 output_prefix='data/tmp/worker_iter1_proc',
                                 vocab_output_path='data/tmp/iter1.pkl')
        for iter in range(20):
            build_vocab_multiprocess('data/tmp/vocab.pkl', f'data/tmp/iter{iter + 1}.pkl', ['data/tmp/input_tree_ids.pkl'],
                                    output_prefix=f'data/tmp/worker_iter{iter + 2}_proc',
                                    vocab_output_path=f'data/tmp/iter{iter+2}.pkl')