from unittest import TestCase
from model.topdown_parser import TransformerParser, LSTMParser
from transformers import AutoConfig
import torch


class ParserUnittest(TestCase):
    def test_transformer_parser(self):
        device = torch.device('cpu')
        if torch.cuda.is_available():
            device = torch.device('cuda:0')
        config = AutoConfig.from_pretrained('data/en_config/fast_r2d2_transformer_parser.json')
        parser = TransformerParser(config)
        parser.to(device)
        parser.eval()
        input_ids = [[1,2,3,4,5,6,7,8,9,10],
                     [11,12,13,14,15,16,17,18,19,20],
                     [21,22,23,24,25,26,27,28,29,30]]
        attn_mask = [[1,1,1,1,1,1,1,1,1,1],
                     [1,1,1,1,1,1,0,0,0,0],
                     [1,1,1,1,1,1,1,1,0,0]]
        atom_spans = [[[1,3], [6,7]], [[0,1]], [[2,4]]]
        input_ids = torch.tensor(input_ids, device=device)
        attn_mask = torch.tensor(attn_mask, device=device)
        s_indices1 = parser(input_ids, attn_mask, atom_spans=atom_spans)
        print(s_indices1)
        self.assertTrue(s_indices1[0][0] in [1,2,6])
        self.assertTrue(s_indices1[0][1] in [1,2,6])
        self.assertTrue(s_indices1[0][2] in [1,2,6])

        input_ids = [[1,2,3,4,5,6,7,8,9,10,11],
                     [11,12,13,14,15,16,17,18,19,20,21],
                     [21,22,23,24,25,26,27,28,29,30,31]]
        attn_mask = [[1,1,1,1,1,1,1,1,1,1,0],
                     [1,1,1,1,1,1,0,0,0,0,0],
                     [1,1,1,1,1,1,1,1,0,0,0]]
        atom_spans = [[[1,3], [6,7]], [[0,1]], [[2,4]]]
        input_ids = torch.tensor(input_ids, device=device)
        attn_mask = torch.tensor(attn_mask, device=device)
        s_indices2 = parser(input_ids, attn_mask, atom_spans=atom_spans)
        print(s_indices2)


        input_ids = [[11,12,13,14,15,16,17,18,19,20,21],
                     [1,2,3,4,5,6,7,8,9,10,11],
                     [21,22,23,24,25,26,27,28,29,30,31]]
        attn_mask = [[1,1,1,1,1,1,0,0,0,0,0],
                     [1,1,1,1,1,1,1,1,1,1,0],
                     [1,1,1,1,1,1,1,1,0,0,0]]
        atom_spans = [[[0,1]], [[1,3], [6,7]], [[2,4]]]
        input_ids = torch.tensor(input_ids, device=device)
        attn_mask = torch.tensor(attn_mask, device=device)
        s_indices3 = parser(input_ids, attn_mask, atom_spans=atom_spans)
        print(s_indices3)

        self.assertTrue(torch.all(s_indices1[0][:9] == s_indices2[0][:9]))
        self.assertTrue(torch.all(s_indices2[0][:9] == s_indices3[1][:9]))

    def test_lstm_parser(self):
        pass