"""
Unit tests for ``models.graph.hgt.HGT``.

author: Anonymous
"""
import os
import sys
import inspect

# Prepend the parent of this file's directory (presumably the project root) to
# sys.path so the project-local ``models`` / ``scheduling`` imports resolve
# when this file is run directly.
currentdir = os.path.dirname(
    os.path.abspath(inspect.getfile(inspect.currentframe()))
)
parentdir = os.path.split(currentdir)[0]
sys.path[:0] = [parentdir]

import torch
import numpy as np

import unittest

import dgl

import models.graph.hgt as hgt

from scheduling.environment import SchedulingEnvironment

class TestHGT(unittest.TestCase):
    """Smoke tests for constructing ``hgt.HGT`` on a tiny heterograph.

    The fixture graph has two node types, ``node`` and ``output`` (10 nodes
    each), joined by two relations, ``edge`` and ``attention``. Each relation
    wires ``node`` i to ``output`` i for i in 0..9. Only the ``attention``
    relation is given edge features.
    """

    def setUp(self) -> None:
        """Build the heterograph fixture, per-type dims and all-ones features."""
        # Schema metadata. Not read by the test below; kept for other tests.
        self.nodes = ['node', 'output']
        self.edges = [('node', 'edge', 'output'), ('node', 'attention', 'output')]
        # NOTE(review): presumably maps a feature-carrying edge type to an
        # associated node type -- confirm against models.graph.hgt.
        self.edge_features = {'attention': 'node'}
        self.output_nodes = ['output']

        # Per-type feature widths (input / hidden / output). Keys cover both
        # node types and the feature-carrying edge type.
        self.in_dims = {'node': 3, 'output': 5, 'attention': 5}
        self.hid_dims = {'node': 8, 'output': 8, 'attention': 8}
        self.out_dims = {'node': 7, 'output': 1, 'attention': 1}

        # One-to-one wiring for both relations: (src ids, dst ids).
        self.data_dict = {
            ('node', 'edge', 'output'): (list(range(10)), list(range(10))),
            ('node', 'attention', 'output'): (list(range(10)), list(range(10))),
        }

        self.num_nodes_dict = {
            'node': 10,
            'output': 10,
        }
        self.g = dgl.heterograph(self.data_dict, self.num_nodes_dict, idtype=torch.int64)

        # All-ones features: rows = number of nodes/edges of that type,
        # columns = the matching in_dims entry.
        self.node_feat = {
            'node': torch.ones((10, 3)),
            'output': torch.ones((10, 5)),
        }
        self.edge_feat = {
            'attention': torch.ones((10, 5)),
        }

    def test_HGT_input_output(self):
        """Check the fixture is self-consistent, then construct an ``HGT``."""
        n_heads = 2
        n_layers = 3
        use_norm = True

        # Validate the fixture first so a mismatch is reported here rather
        # than as an opaque shape error inside the model.
        for ntype, feat in self.node_feat.items():
            self.assertEqual(
                feat.shape, (self.g.num_nodes(ntype), self.in_dims[ntype])
            )
        for etype, feat in self.edge_feat.items():
            self.assertEqual(
                feat.shape, (self.g.num_edges(etype), self.in_dims[etype])
            )

        # Construction must not raise; this is currently the model-level check.
        model = hgt.HGT(self.g, self.node_feat, self.edge_feat, self.in_dims, self.hid_dims, self.out_dims, n_layers, n_heads, use_norm=use_norm)

        # TODO: assert on the model's configuration / forward output once the
        # HGT attribute API is settled. Earlier sketched checks (in_dim,
        # out_dim, num_types, num_relations, n_heads, d_k, dropout, use_norm)
        # relied on attributes not confirmed to exist, and a dropout value
        # was declared but never passed to the constructor.
        del model
        

if __name__ == "__main__":
    # Run the test suite when this file is executed directly as a script.
    unittest.main()