from unittest import TestCase
from data_structure.const_tree import SpanTree
from experiments.baseline_models import BertForDPClassification, BertForDPClassificationMeanPooling
import torch


class BertDPClassification(TestCase):
    """Batch-order invariance test for ``BertForDPClassificationMeanPooling``.

    The same two (token ids, attention mask, span tree) samples are fed to the
    model twice, with their batch order swapped on the second call.  The root
    representations, the loss and the predicted labels of each sample must not
    depend on the sample's position in the batch.
    """

    # Absolute tolerance used when comparing floating-point outputs of two runs.
    _TOLERANCE = 0.001

    # "Full" sample: ten tokens, none masked out.
    _FULL_IDS = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
    _FULL_MASK = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
    # "Padded" sample: ten tokens, the last five masked out.
    _PADDED_IDS = [11, 12, 13, 14, 15, 16, 17, 18, 19, 20]
    _PADDED_MASK = [1, 1, 1, 1, 1, 0, 0, 0, 0, 0]

    @staticmethod
    def _build_trees():
        """Return freshly built ``(full_root, padded_root)`` span trees.

        ``full_root`` spans tokens 0-9 with children 0-4 and 5-9;
        ``padded_root`` spans tokens 0-4 with children 0-2 and 3-4.
        """
        full_root = SpanTree(0, 9, [SpanTree(0, 4), SpanTree(5, 9)])
        padded_root = SpanTree(0, 4, [SpanTree(0, 2), SpanTree(3, 4)])
        return full_root, padded_root

    @staticmethod
    def _forward(model, device, samples, labels=None):
        """Run one inference-only forward pass over ``samples``.

        Args:
            model: the model under test.
            device: torch device the input tensors are created on.
            samples: list of ``(input_ids, attention_mask, tree)`` triples,
                in batch order.
            labels: optional per-sample label lists; the ``labels`` keyword is
                only passed to the model when this is not ``None``.

        Returns:
            Whatever the model returns (indexed like a dict by the caller).
        """
        kwargs = {
            'input_ids': torch.tensor([ids for ids, _, _ in samples], device=device),
            'attention_mask': torch.tensor([mask for _, mask, _ in samples], device=device),
            'trees': [tree for _, _, tree in samples],
        }
        if labels is not None:
            kwargs['labels'] = labels
        # Only outputs are compared, so no autograd graph is needed.
        with torch.no_grad():
            return model(**kwargs)

    @staticmethod
    def _root_representations(result):
        """Select the rows of ``result['logits']`` addressed by each root's ``cache_id``."""
        cache_ids = [root.cache_id for root in result['roots']]
        return result['logits'][cache_ids]  # presumably (batch, dim) -- TODO confirm

    def test_dp_classification(self):
        """Outputs for a sample must be the same whichever batch slot it occupies."""
        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        # make sure pytorch_model.bin, config.json, vocab.txt are in 'data/pretrain_bert'
        model = BertForDPClassificationMeanPooling('data/pretrain_bert', 1000)
        model.to(device)
        model.eval()

        # --- With labels: compare root representations and loss. ---
        full_root, padded_root = self._build_trees()
        full = (self._FULL_IDS, self._FULL_MASK, full_root)
        padded = (self._PADDED_IDS, self._PADDED_MASK, padded_root)

        result1 = self._forward(model, device, [full, padded],
                                labels=[[1, 5], [2, 3]])
        loss1 = result1['loss']
        # Read the representations before the second pass: it reuses the same
        # tree objects, whose cache_id may be reassigned (NOTE(review): confirm).
        root_repr1 = self._root_representations(result1)

        result2 = self._forward(model, device, [padded, full],
                                labels=[[2, 3], [1, 5]])
        loss2 = result2['loss']
        root_repr2 = self._root_representations(result2)

        self.assertLess(torch.dist(root_repr1[0], root_repr2[1]).item(), self._TOLERANCE)
        self.assertLess(torch.dist(root_repr1[1], root_repr2[0]).item(), self._TOLERANCE)
        # Floating-point reductions are not guaranteed bit-identical when the
        # batch order changes, so compare the losses with a tolerance.
        self.assertAlmostEqual(float(loss1), float(loss2), delta=self._TOLERANCE)

        # --- Without labels: compare predictions, using freshly built trees. ---
        full_root, padded_root = self._build_trees()
        full = (self._FULL_IDS, self._FULL_MASK, full_root)
        padded = (self._PADDED_IDS, self._PADDED_MASK, padded_root)

        pred_labels1 = self._forward(model, device, [full, padded])['predict']
        pred_labels2 = self._forward(model, device, [padded, full])['predict']

        self.assertEqual(pred_labels1[0], pred_labels2[1])
        self.assertEqual(pred_labels1[1], pred_labels2[0])