from unicodedata import bidirectional
from unittest import TestCase
import torch
import json
import torch.nn as nn

from model.tree_encoder import UniLMEncoder

# Miniature model configuration, stored as a JSON document in a string.
# The trailing backslashes inside the single-quoted literal are line
# continuations, so the value is one single-line JSON string. The test below
# parses it with json.loads and wraps the result in dotdict.
mini_r2d2_config = '{\
  "architectures": [\
    "Bert"\
  ],\
  "model_type": "bert",\
  "attention_probs_dropout_prob": 0.1,\
  "hidden_act": "gelu",\
  "hidden_dropout_prob": 0.1,\
  "embedding_dim": 64,\
  "hidden_size": 64,\
  "initializer_range": 0.02,\
  "intermediate_size": 256,\
  "max_role_embeddings": 4,\
  "num_attention_heads": 8,\
  "type_vocab_size": 2,\
  "max_positions":10,\
  "encoder_num_hidden_layers": 3\
}'

class dotdict(dict):
    """Dictionary whose items can also be read, set and removed as attributes."""

    def __getattr__(self, key, default=None):
        # Only reached when normal attribute lookup fails. A key that is not
        # present yields ``default`` (None) instead of raising AttributeError.
        return dict.get(self, key, default)

    def __setattr__(self, key, value):
        # Attribute assignment stores the value as a dictionary item.
        dict.__setitem__(self, key, value)

    def __delattr__(self, key):
        # Deleting a missing key raises KeyError, exactly as ``del d[key]``.
        dict.__delitem__(self, key)

    def hasattr(self, val):
        """Return True when ``val`` is a key of this dictionary."""
        return dict.__contains__(self, val)

class UniLMTestcase(TestCase):
    """Tests for ``UniLMEncoder`` built from the miniature JSON config."""

    def testUniLMEncoder(self):
        """Check that a sequence's encoding does not depend on its batch slot.

        A ragged batch of token-id lists is encoded twice: once in its
        original order and once with the last two entries swapped (the
        ``memory`` rows are permuted the same way). The output for the
        4-token sequence must match between both runs, for
        ``bidirectional_pos`` set to True and to False.
        """
        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        config = dotdict(json.loads(mini_r2d2_config))
        embedding = nn.Embedding(100, config.embedding_dim).to(device)

        encoder = UniLMEncoder(config)
        encoder.to(device)
        # Evaluation mode, so that two forward passes are comparable.
        encoder.eval()

        batch_size = 3
        ids_list = [[1, 2], [3, 4, 5, 6], [7, 8]]
        # Same sequences with entries 1 and 2 swapped.
        swapped_ids_list = [[1, 2], [7, 8], [3, 4, 5, 6]]
        swap_order = [0, 2, 1]
        longest = max(len(ids) for ids in ids_list)

        # NOTE: a comparison between tensor and list ``input_ids`` was
        # previously sketched here but disabled; it is not covered.
        for bidirectional_pos in (True, False):
            with self.subTest(bidirectional_pos=bidirectional_pos), torch.no_grad():
                memory = torch.rand(batch_size, 2, config.hidden_size, device=device)

                result1 = encoder(input_ids=ids_list, memory=memory,
                                  embeddings=embedding,
                                  bidirectional_pos=bidirectional_pos)
                # Second output dimension equals the longest input length (4).
                self.assertEqual(result1.shape[1], longest)

                # Permute memory rows to follow the swapped sequences.
                memory_ = memory[swap_order, :, :]
                result2 = encoder(input_ids=swapped_ids_list, memory=memory_,
                                  embeddings=embedding,
                                  bidirectional_pos=bidirectional_pos)

                # The 4-token sequence sits at index 1 in the first batch and
                # at index 2 in the swapped batch.
                delta = result1[1] - result2[2]
                distance = (delta * delta).sum().item()
                self.assertLess(distance, 0.01, f'actual distance: {distance}')