import json
import random
from unittest import TestCase
import numpy as np
import torch

from model.r2d2_cuda import R2D2Cuda
from model.topdown_parser import LSTMParser

class dotdict(dict):
    """A ``dict`` whose items can also be read, written and deleted as attributes.

    Reading an attribute that is not a key yields ``None`` (it is looked up with
    ``dict.get``) instead of raising ``AttributeError``. Deleting a missing
    attribute raises ``KeyError``.
    """

    def __getattr__(self, name):
        # Only reached when normal attribute lookup fails; fall back to the items.
        return dict.get(self, name)

    def __setattr__(self, name, value):
        # Attribute assignment stores into the mapping, never into __dict__.
        dict.__setitem__(self, name, value)

    def __delattr__(self, name):
        dict.__delitem__(self, name)

    def hasattr(self, val):
        """Return True when ``val`` is one of the stored keys."""
        return self.__contains__(val)


# Small JSON configuration for R2D2Cuda, parsed with json.loads and wrapped in a
# dotdict by TestSpanLoss. A triple-quoted literal is used rather than backslash
# line continuations: with continuations, any trailing whitespace after a `\`
# turns the literal into an unterminated-string SyntaxError, and the whole
# document ends up on a single physical line.
mini_r2d2_config = '''{
  "architectures": [
    "Bert"
  ],
  "model_type": "bert",
  "attention_probs_dropout_prob": 0.1,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "embedding_dim": 64,
  "hidden_size": 64,
  "initializer_range": 0.02,
  "intermediate_size": 256,
  "max_positions": 10,
  "num_attention_heads": 8,
  "type_vocab_size": 2,
  "vocab_size": 30522,
  "encoder_num_hidden_layers": 3,
  "pad_token_id": 0,
  "bos_token_id": 4,
  "eos_token_id": 5,
  "cls_token_id": 101,
  "sum_token_id": 7,
  "mask_token_id": 103,
  "nsp_token_id": 8,
  "lr_token_id": 9,
  "rr_token_id": 10,
  "eot_token_id": 11,
  "tree_mask_token_id": 12,
  "policy_token_id": 6,
  "window_size": 4,
  "tie_decoder": false,
  "parser_hidden_dim": 64,
  "parser_input_dim": 32,
  "parser_num_layers": 2,
  "unilm": true,
  "loss": [{"name":"span_loss", "params":{"max_span":2}}]
}'''

def set_seed(seed: int):
    """Seed Python's ``random``, NumPy's global RNG and torch's CPU/CUDA generators.

    Args:
        seed: value passed unchanged to every seeding function, in order.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,  # seeds the generator of every CUDA device
    )
    for seed_fn in seeders:
        seed_fn(seed)

class TestSpanLoss(TestCase):
    """Checks that the model's loss is unchanged when the batch order is reversed."""

    @staticmethod
    def _forward_loss(model, device, input_ids, attn_mask, atom_spans):
        """Run one forward pass of ``model`` and return the ``'loss'`` entry of its output.

        ``input_ids`` and ``attn_mask`` are converted to tensors on ``device``.
        ``atom_spans`` is passed as a fresh nested-list copy, so the caller's data
        stays reusable even if the model were to modify the spans in place.
        """
        results = model(input_ids=torch.tensor(input_ids, device=device),
                        atom_spans=[[list(span) for span in spans] for spans in atom_spans],
                        attention_mask=torch.tensor(attn_mask, device=device))
        return results['loss']

    def test_span_loss(self):
        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

        set_seed(403)

        config = dotdict(json.loads(mini_r2d2_config))
        model = R2D2Cuda(config)
        model.to(device)
        # Eval mode, presumably so dropout (see *_dropout_prob in the config) is
        # inactive and the two forward passes below are comparable.
        model.eval()

        input_ids = [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
                     [11, 12, 13, 14, 15, 16, 17, 18, 19, 20]]
        # The second sample has only its first 5 positions set in the mask.
        attn_mask = [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
                     [1, 1, 1, 1, 1, 0, 0, 0, 0, 0]]
        atom_spans = [[[1, 2], [3, 4]], [[1, 2]]]

        loss1 = self._forward_loss(model, device, input_ids, attn_mask, atom_spans)
        # Same samples with the batch order reversed, derived from the data above
        # so the two batches cannot drift apart.
        loss2 = self._forward_loss(model, device,
                                   input_ids[::-1], attn_mask[::-1], atom_spans[::-1])

        squared_delta = ((loss1 - loss2) ** 2).sum().item()
        self.assertLess(squared_delta, 1e-7,
                        msg=f'loss changed when the batch order was reversed: '
                            f'{loss1} vs {loss2}')