
import os
import sys
import inspect

# Put the directory one level above this file's folder at the front of
# sys.path so the project imports that follow resolve when this file is run
# directly (presumably that is where the `models` package lives).
_this_file = os.path.abspath(inspect.getfile(inspect.currentframe()))
currentdir = os.path.dirname(_this_file)
parentdir = os.path.split(currentdir)[0]
sys.path.insert(0, parentdir)

import torch
import dgl

from models.graph.hetgat_layer import HetGATLayer, MultiHetGATLayer

import unittest
class TestHetGAT(unittest.TestCase):
    """Shape-level tests for ``HetGATLayer`` and ``MultiHetGATLayer``.

    Each test runs one or two forward passes on the small heterograph built in
    ``setUp`` and checks the shapes of the returned node / edge feature
    tensors. No numerical values are checked.
    """

    # Number of nodes of each node type and number of edges of each edge type.
    NUM_ITEMS = 10
    # Eighth positional constructor argument of the layers under test
    # (presumably the head count: outputs carry a matching dimension of 8).
    NUM_HEADS = 8

    def setUp(self) -> None:
        """Build the shared layer configuration, graph and input features."""
        n = self.NUM_ITEMS

        self.nodes = ['node', 'output']
        self.edges = [('node', 'edge', 'output'), ('node', 'attention', 'output')]
        self.edge_features = {'attention': 'node'}
        self.output_nodes = ['output']

        self.in_dims = {'node': 3, 'output': 5, 'attention': 5}
        self.out_dims = {'node': 7, 'output': 1, 'attention': 1}

        # Every relation links node i to output i, for i in 0..n-1.
        self.data_dict = {
            etype: (list(range(n)), list(range(n))) for etype in self.edges
        }
        self.num_nodes_dict = {ntype: n for ntype in self.nodes}
        self.g = dgl.heterograph(
            self.data_dict, self.num_nodes_dict, idtype=torch.int64
        )

        # All-ones inputs whose widths follow ``in_dims``.
        self.node_feat = {
            ntype: torch.ones((n, self.in_dims[ntype])) for ntype in self.nodes
        }
        self.edge_feat = {
            'attention': torch.ones(n, self.in_dims['attention'])
        }

    def _build(self, layer_cls, in_dims=None, **options):
        """Instantiate ``layer_cls`` with the fixture configuration.

        Args:
            layer_cls: ``HetGATLayer`` or ``MultiHetGATLayer``.
            in_dims: Input widths; defaults to ``self.in_dims``.
            **options: Extra keyword arguments forwarded to the constructor
                (e.g. ``mode='mixed'`` or ``merge='linear'``).
        """
        return layer_cls(
            self.nodes,
            self.edges,
            self.edge_features,
            self.edge_features,
            self.output_nodes,
            self.in_dims if in_dims is None else in_dims,
            self.out_dims,
            self.NUM_HEADS,
            **options,
        )

    def _forward(self, layer, node_feat=None, edge_feat=None):
        """Run ``layer`` on the fixture graph and return ``(nodes, edges)``."""
        return layer(
            self.g,
            self.node_feat if node_feat is None else node_feat,
            self.edge_feat if edge_feat is None else edge_feat,
            'output',
        )

    def _assert_shape(self, tensor, *dims):
        """Assert ``tensor.shape`` equals ``dims`` (reports both on failure)."""
        self.assertEqual(tensor.shape, torch.Size(dims))

    def test_hetgat(self):
        """A single HetGATLayer returns one edge tensor of shape (10, 8, 1)."""
        layer = self._build(HetGATLayer)

        _, edge_out = self._forward(layer)

        self.assertEqual(edge_out.keys(), {'attention'})
        self._assert_shape(edge_out['attention'], self.NUM_ITEMS, self.NUM_HEADS, 1)

    def test_multihetgat(self):
        """Two MultiHetGATLayer instances can be chained.

        The first layer's outputs are fed to a second layer whose input
        widths are the per-key output widths multiplied by 8, and both layers
        are expected to produce the same output shapes.
        """
        n, heads = self.NUM_ITEMS, self.NUM_HEADS
        # Input widths for the second layer: 7*8, 1*8, 1*8.
        stacked_in_dims = {key: dim * heads for key, dim in self.out_dims.items()}

        first = self._build(MultiHetGATLayer)
        nodes1, edges1 = self._forward(first)

        self._assert_shape(nodes1['node'], n, 7 * heads)
        self._assert_shape(nodes1['output'], n, 1 * heads)
        self._assert_shape(edges1['attention'], n, 1 * heads)

        second = self._build(MultiHetGATLayer, in_dims=stacked_in_dims)
        nodes2, edges2 = self._forward(second, nodes1, edges1)

        self._assert_shape(nodes2['node'], n, 7 * heads)
        self._assert_shape(nodes2['output'], n, 1 * heads)
        self._assert_shape(edges2['attention'], n, 1 * heads)

    def test_multihetgat_linear(self):
        """MultiHetGATLayer with ``merge='linear'`` yields ``out_dims`` widths.

        The linear merge follows the attention combination used in DGL's HAN
        example:
        https://github.com/dmlc/dgl/blob/master/examples/pytorch/han/model_hetero.py
        """
        n = self.NUM_ITEMS
        layer = self._build(MultiHetGATLayer, merge='linear')

        node_out, edge_out = self._forward(layer)

        # assertEqual (not assertTrue(a == b)) so a failure shows both shapes.
        self._assert_shape(node_out['node'], n, 7)
        self._assert_shape(node_out['output'], n, 1)
        self._assert_shape(edge_out['attention'], n, 1)

    def test_hetgat_mixed_activate(self):
        """HetGATLayer with ``mode='mixed'`` yields a middle dimension of 24.

        With the default mode the same configuration gives 8 for the edge
        output (see ``test_hetgat``).
        """
        n = self.NUM_ITEMS
        layer = self._build(HetGATLayer, mode='mixed')

        node_out, edge_out = self._forward(layer)

        self._assert_shape(node_out['node'], n, 24, 7)
        self._assert_shape(node_out['output'], n, 24, 1)
        self._assert_shape(edge_out['attention'], n, 24, 1)

    def test_multihetgat_mixed(self):
        """MultiHetGATLayer with ``mode='mixed'`` runs and returns per-item rows.

        NOTE(review): the exact merged widths for mixed mode are not pinned
        here; the shapes are printed for inspection and only the leading
        dimension is asserted. Tighten once the expected widths are confirmed.
        """
        layer = self._build(MultiHetGATLayer, mode='mixed')

        results_node, results_edge = self._forward(layer)

        node_shapes = [(key, value.shape) for key, value in results_node.items()]
        edge_shapes = [(key, value.shape) for key, value in results_edge.items()]
        print("\nMultiHetGAT Mixed Activation Function")
        print(f"result_node_shapes: {node_shapes}")
        print(f"result_edge_shapes: {edge_shapes}")

        # Minimal sanity checks so the test can fail on something other than
        # an exception: non-empty outputs with one row per node / edge.
        for label, shapes in (('node', node_shapes), ('edge', edge_shapes)):
            self.assertTrue(shapes, msg=f"no {label} outputs returned")
            for key, shape in shapes:
                with self.subTest(kind=label, key=key):
                    self.assertEqual(shape[0], self.NUM_ITEMS)

if __name__ == '__main__':
    # Script entry point: pin the torch and stdlib RNG seeds to a fixed value,
    # then hand control to unittest's command-line runner.
    import random

    seed = 10
    for seed_rng in (torch.manual_seed, random.seed):
        seed_rng(seed)
    unittest.main()
