import torch
from unittest import TestCase
from transformers import AutoConfig
from model.fast_r2d2_interpretable import FastR2D2ClassificationAttn
from experiments.fast_r2d2_miml import FastR2D2WithAttention


class RevDecoderUnittest(TestCase):
    """Batch-consistency checks for two causal-decode model variants.

    Each test feeds a two-row batch whose second row has only its first five
    positions enabled in the attention mask, then feeds those same five token
    ids alone as a batch of one, and asserts that both runs give (nearly) the
    same result for that sequence.
    """

    # Config file shared by both tests.
    _CONFIG_PATH = "data/en_config/fast_r2d2_rev.json"

    # Two-row batch; row 1 is masked after its fifth token
    # (presumably mask value 0 marks padding -- confirm against the models).
    _BATCH_IDS = [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
                  [11, 12, 13, 14, 15, 16, 17, 18, 19, 20]]
    _BATCH_MASK = [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
                   [1, 1, 1, 1, 1, 0, 0, 0, 0, 0]]

    # The unmasked prefix of batch row 1, fed on its own.
    _SINGLE_IDS = [[11, 12, 13, 14, 15]]
    _SINGLE_MASK = [[1, 1, 1, 1, 1]]

    # Upper bound on the summed squared difference between the two runs.
    _TOLERANCE = 0.01

    @staticmethod
    def _pick_device():
        """Return the first CUDA device when available, otherwise the CPU."""
        return torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    @staticmethod
    def _prepare(model, device):
        """Move ``model`` to ``device``, switch it to eval mode and return it."""
        model.to(device)
        model.eval()
        return model

    @staticmethod
    def _forward(model, device, input_ids, attn_mask, **extra):
        """Run ``model`` on list-valued inputs and return its result.

        Args:
            model: Callable model accepting ``input_ids`` and
                ``attention_mask`` keyword tensors.
            device: Device on which the input tensors are created.
            input_ids: Nested list of token ids.
            attn_mask: Nested list with the same layout as ``input_ids``.
            **extra: Additional keyword arguments forwarded to the model.
        """
        # eval() alone does not switch autograd off; these calls are
        # inference-only, so skip building a graph.
        with torch.no_grad():
            return model(input_ids=torch.tensor(input_ids, device=device),
                         attention_mask=torch.tensor(attn_mask, device=device),
                         **extra)

    def _assert_close(self, batched, single):
        """Fail unless the summed squared difference is below the tolerance."""
        delta = batched - single
        # .item() yields a plain Python number so assertTrue receives a bool.
        squared_distance = (delta * delta).sum().item()
        self.assertTrue(squared_distance < self._TOLERANCE,
                        f'actual distance: {squared_distance}')

    def testBatchCorrect(self):
        """Logits for the masked batch row must match the stand-alone run."""
        device = self._pick_device()
        config = AutoConfig.from_pretrained(self._CONFIG_PATH)
        model = self._prepare(
            FastR2D2ClassificationAttn(config, label_num=2, enable_mha=True,
                                       causal_decode=True),
            device)

        batched = self._forward(model, device, self._BATCH_IDS,
                                self._BATCH_MASK, keep_logits=True)
        single = self._forward(model, device, self._SINGLE_IDS,
                               self._SINGLE_MASK, keep_logits=True)

        # Row 1 of the batch corresponds to row 0 of the single-sequence run.
        self._assert_close(batched['logits'][1], single['logits'][0])

    def testRevAttentionCorrect(self):
        """Attention output for the masked batch row must match the stand-alone run."""
        device = self._pick_device()
        config = AutoConfig.from_pretrained(self._CONFIG_PATH)
        model = self._prepare(
            FastR2D2WithAttention(config, enable_mha=False, causal_decode=True),
            device)

        batched = self._forward(model, device, self._BATCH_IDS,
                                self._BATCH_MASK, sample_trees=10)
        single = self._forward(model, device, self._SINGLE_IDS,
                               self._SINGLE_MASK, sample_trees=10)

        # Only the first five rows of the batched output are compared,
        # matching the five enabled mask positions of batch row 1.
        self._assert_close(batched['attn_output'][1][:5, :],
                           single['attn_output'][0])