from unittest import TestCase
import tqdm
import numpy as np
import json
import torch

from model.r2d2_cuda import R2D2Cuda

class dotdict(dict):
    """Dictionary whose items can also be accessed with dot notation.

    ``d.key`` reads ``d['key']``, ``d.key = v`` stores ``d['key'] = v`` and
    ``del d.key`` removes ``d['key']``.  Reading a key that is not present
    returns ``None`` instead of raising, so optional settings can be probed
    with a plain attribute access.
    """

    # Attribute writes and deletes go straight to the mapping.  Deleting a
    # missing key raises KeyError, exactly like ``del d['key']``.
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__

    def __getattr__(self, name):
        """Return the item stored under ``name``, or ``None`` when absent.

        Python only calls this after normal attribute lookup has failed, so
        real ``dict`` methods are never shadowed.

        Raises:
            AttributeError: if ``name`` is a dunder name (``__x__``) that is
                not a stored key.
        """
        # Protocol probes (copy, pickle, hasattr checks for optional hooks)
        # must see "no such attribute"; answering None makes them believe the
        # hook exists and then try to call None.
        if name.startswith('__') and name.endswith('__') and name not in self:
            raise AttributeError(name)
        return self.get(name)

    def hasattr(self, val):
        """Return True when ``val`` is a key of this dictionary."""
        return val in self


# JSON model configuration used by the tests in this file, which parse it with
# json.loads and wrap the result in dotdict before passing it to R2D2Cuda.
# It has the same entries as default_loss_config plus two extra ones:
# "max_context_length" and an explicit "loss" list.
# NOTE(review): despite the variable name, the only loss listed is
# "generative_loss" -- confirm the intended naming.
# Every line ends with a backslash so the literal continues onto the next
# source line without embedding a newline in the string.
span_loss_config = '{\
  "architectures": [\
    "Bert"\
  ],\
  "model_type": "bert",\
  "attention_probs_dropout_prob": 0.1,\
  "hidden_act": "gelu",\
  "hidden_dropout_prob": 0.1,\
  "embedding_dim": 64,\
  "hidden_size": 64,\
  "initializer_range": 0.02,\
  "intermediate_size": 256,\
  "max_role_embeddings": 4,\
  "num_attention_heads": 8,\
  "type_vocab_size": 2,\
  "vocab_size": 30522,\
  "encoder_num_hidden_layers": 3,\
  "decoder_num_hidden_layers": 1,\
  "pad_token_id": 0,\
  "bos_token_id": 4,\
  "eos_token_id": 5,\
  "cls_token_id": 101,\
  "sum_token_id": 7,\
  "mask_token_id": 103,\
  "nsp_token_id": 8,\
  "lr_token_id": 9,\
  "rr_token_id": 10,\
  "eot_token_id": 11,\
  "tree_mask_token_id": 12,\
  "policy_token_id": 6,\
  "window_size": 4,\
  "tie_decoder": false,\
  "parser_hidden_dim": 64,\
  "parser_input_dim": 32,\
  "parser_num_layers": 2,\
  "unilm": true,\
  "max_context_length": 5,\
  "loss":[{"name":"generative_loss"}]\
}'

# JSON model configuration without a "loss" entry and without
# "max_context_length"; all other entries match span_loss_config.
# test_loss_regression parses it with json.loads, wraps it in dotdict and
# builds the second R2D2Cuda model from it.
# Every line ends with a backslash so the literal continues onto the next
# source line without embedding a newline in the string.
default_loss_config = '{\
  "architectures": [\
    "Bert"\
  ],\
  "model_type": "bert",\
  "attention_probs_dropout_prob": 0.1,\
  "hidden_act": "gelu",\
  "hidden_dropout_prob": 0.1,\
  "embedding_dim": 64,\
  "hidden_size": 64,\
  "initializer_range": 0.02,\
  "intermediate_size": 256,\
  "max_role_embeddings": 4,\
  "num_attention_heads": 8,\
  "type_vocab_size": 2,\
  "vocab_size": 30522,\
  "encoder_num_hidden_layers": 3,\
  "decoder_num_hidden_layers": 1,\
  "pad_token_id": 0,\
  "bos_token_id": 4,\
  "eos_token_id": 5,\
  "cls_token_id": 101,\
  "sum_token_id": 7,\
  "mask_token_id": 103,\
  "nsp_token_id": 8,\
  "lr_token_id": 9,\
  "rr_token_id": 10,\
  "eot_token_id": 11,\
  "tree_mask_token_id": 12,\
  "policy_token_id": 6,\
  "window_size": 4,\
  "tie_decoder": false,\
  "parser_hidden_dim": 64,\
  "parser_input_dim": 32,\
  "parser_num_layers": 2,\
  "unilm": true\
}'


class TestFastR2D2(TestCase):
    """Loss consistency checks for ``R2D2Cuda`` on a two-sentence batch."""

    @staticmethod
    def _pick_device():
        """Return ``cuda:0`` when CUDA is available, otherwise the CPU."""
        return torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    @staticmethod
    def _make_model(config_json, device):
        """Build an ``R2D2Cuda`` from a JSON config string.

        The model is moved to ``device`` and switched to eval mode before it
        is returned.
        """
        net = R2D2Cuda(dotdict(json.loads(config_json)))
        net.to(device)
        net.eval()
        return net

    def _check_same_loss(self, first, second):
        """Fail unless the summed squared difference of the losses is < 1e-7."""
        gap = first - second
        squared_gap = (gap * gap).sum()
        self.assertTrue(squared_gap < 1e-7,
                        f'actual loss: {squared_gap}')

    def test_loss_regression(self):
        """Both configs must produce the same loss when the weights are shared.

        One model is built from ``span_loss_config`` and one from
        ``default_loss_config``; the second receives the first one's state
        dict, and both are run on the same batch.
        """
        device = self._pick_device()
        span_model = self._make_model(span_loss_config, device)
        default_model = self._make_model(default_loss_config, device)
        default_model.load_state_dict(span_model.state_dict())

        # Row 0 is fully attended; row 1 only attends to its first 5 tokens.
        token_rows = [list(range(1, 11)), list(range(11, 21))]
        mask_rows = [[1] * 10, [1] * 5 + [0] * 5]

        losses = []
        for net in (span_model, default_model):
            output = net(input_ids=torch.tensor(token_rows, device=device),
                         attention_mask=torch.tensor(mask_rows, device=device))
            losses.append(output['loss'])

        self._check_same_loss(*losses)

    def test_generative_loss(self):
        """Swapping the two batch rows must leave the loss unchanged.

        The same model is run twice on the same two examples (token ids,
        attention mask and atom spans), once in each row order.
        """
        device = self._pick_device()
        net = self._make_model(span_loss_config, device)

        def fresh_examples():
            # Rebuilt for every forward pass so each call gets its own lists.
            full = (list(range(1, 11)), [1] * 10, [[1, 3], [7, 9]])
            half = (list(range(11, 21)), [1] * 5 + [0] * 5, [[2, 4]])
            return [full, half]

        losses = []
        for swap_rows in (False, True):
            rows = fresh_examples()
            if swap_rows:
                rows.reverse()
            ids = [row[0] for row in rows]
            mask = [row[1] for row in rows]
            spans = [row[2] for row in rows]
            output = net(input_ids=torch.tensor(ids, device=device),
                         atom_spans=spans,
                         attention_mask=torch.tensor(mask, device=device))
            losses.append(output['loss'])

        self._check_same_loss(*losses)