import math, time
import torch
import torch_sparse
import numpy as np
from torch_scatter import scatter_max
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter
import MegaGNN.graphgym.register as register
from MegaGNN.graphgym.config import cfg
from torch_geometric.data import HeteroData
from torch_geometric.nn.inits import glorot, zeros, ones, reset
from torch_geometric.nn import Linear


class MEGAGTLayer(nn.Module):
    r"""Graph Transformer layer for heterogeneous graphs.

    Every node attends over its incoming edges with multi-head attention
    normalised by a softmax over the edges that share a destination node.
    Query/key/value and edge projections are type-specific (one linear map
    per node type / edge type); the attention itself is computed on the
    homogeneous view of the graph. When edge features are present they are
    added to the attention logits, gate the messages through a sigmoid, and
    are themselves updated. Optional pre-normalization, a fixed residual
    connection and a type-specific feed-forward block follow, depending on
    ``cfg.gt``.

    Config keys read: ``cfg.gt.act``, ``cfg.gt.edge_weight``,
    ``cfg.gt.dropout``, ``cfg.gt.attn_dropout``, ``cfg.gt.ffn`` and
    ``cfg.gt.residual``.

    Note:
        Several sub-modules mix the dimensions. Normalization is built with
        ``dim_h`` but applied to the raw ``dim_in`` inputs. The residual adds
        ``dim_out`` outputs to ``dim_in`` inputs. The FFN maps ``dim_h`` to
        ``dim_h`` on ``dim_out`` features. So in practice ``dim_in``,
        ``dim_h`` and ``dim_out`` must be equal whenever those options are
        enabled.

    Args:
        dim_in (int): Input dimension of node (and edge) features.
        dim_h (int): Hidden dimension for attention computation; split evenly
            across heads.
        dim_out (int): Output dimension of node/edge features.
        metadata (tuple): ``(node_types, edge_types)`` of the heterogeneous
            graph; edge types are tuples of strings.
        num_heads (int, optional): Number of attention heads. Defaults to 1.
        layer_norm (bool, optional): Use layer normalization. Defaults to False.
        batch_norm (bool, optional): Use batch normalization; takes precedence
            over ``layer_norm`` when both are set. Defaults to False.
        return_attention (bool, optional): Also return attention scores from
            ``forward``. Defaults to False.
        **kwargs: Accepted for interface compatibility and ignored.

    Raises:
        ValueError: If ``dim_h`` is not divisible by ``num_heads``.
    """
    def __init__(self, dim_in, dim_h, dim_out, metadata, num_heads=1,
                 layer_norm=False, batch_norm=False, return_attention=False, **kwargs):
        super(MEGAGTLayer, self).__init__()

        # Heads split dim_h evenly; a remainder would make the (N, H, D) views
        # in forward() misalign node rows, so reject it up front.
        if dim_h % num_heads != 0:
            raise ValueError(
                f"dim_h ({dim_h}) must be divisible by num_heads ({num_heads})"
            )

        self.dim_in = dim_in
        self.dim_h = dim_h
        self.dim_out = dim_out
        self.num_heads = num_heads
        self.layer_norm = layer_norm
        self.batch_norm = batch_norm
        self.activation = register.act_dict[cfg.gt.act]
        self.metadata = metadata
        self.return_attention = return_attention

        # Attention projection matrices, keyed by node type / joined edge type.
        self.k_lin = torch.nn.ModuleDict()
        self.q_lin = torch.nn.ModuleDict()
        self.v_lin = torch.nn.ModuleDict()
        self.e_lin = torch.nn.ModuleDict()
        self.g_lin = torch.nn.ModuleDict()
        self.oe_lin = torch.nn.ModuleDict()
        self.o_lin = torch.nn.ModuleDict()

        # Projection matrices for different node types.
        for node_type in metadata[0]:
            self.k_lin[node_type] = Linear(dim_in, dim_h)
            self.q_lin[node_type] = Linear(dim_in, dim_h)
            self.v_lin[node_type] = Linear(dim_in, dim_h)
            self.o_lin[node_type] = Linear(dim_h, dim_out)

        # Projection matrices for different edge types.
        for edge_type in metadata[1]:
            edge_type = '__'.join(edge_type)
            self.e_lin[edge_type] = Linear(dim_in, dim_h)
            # The gate reads the same raw edge features as e_lin and fills a
            # dim_h-wide buffer in forward(), hence dim_in -> dim_h.
            self.g_lin[edge_type] = Linear(dim_in, dim_h)
            self.oe_lin[edge_type] = Linear(dim_h, dim_out)

        # Per-edge-type, per-head (D x D) key transforms, laid out in
        # metadata[1] order.
        H, D = self.num_heads, self.dim_h // self.num_heads
        if cfg.gt.edge_weight:
            self.edge_weights = nn.Parameter(torch.empty(len(metadata[1]), H, D, D))
            # NOTE(review): msg_weights is initialised but never read in forward().
            self.msg_weights = nn.Parameter(torch.empty(len(metadata[1]), H, D, D))
            nn.init.xavier_uniform_(self.edge_weights)
            nn.init.xavier_uniform_(self.msg_weights)

        # Pre-attention normalization per node type (BatchNorm overrides
        # LayerNorm when both flags are set).
        self.norm1_global = torch.nn.ModuleDict()
        self.norm2_ffn = torch.nn.ModuleDict()
        for node_type in metadata[0]:
            if self.layer_norm:
                self.norm1_global[node_type] = nn.LayerNorm(dim_h)
            if self.batch_norm:
                self.norm1_global[node_type] = nn.BatchNorm1d(dim_h)

        # Pre-attention normalization per edge type. norm2_edge_ffn is
        # registered but never populated or used.
        self.norm1_edge_global = torch.nn.ModuleDict()
        self.norm2_edge_ffn = torch.nn.ModuleDict()
        for edge_type in metadata[1]:
            edge_type = "__".join(edge_type)
            if self.layer_norm:
                self.norm1_edge_global[edge_type] = nn.LayerNorm(dim_h)
            if self.batch_norm:
                self.norm1_edge_global[edge_type] = nn.BatchNorm1d(dim_h)

        self.dropout_global = nn.Dropout(cfg.gt.dropout)
        self.dropout_attn = nn.Dropout(cfg.gt.attn_dropout)

        # Pre-FFN normalization per node type.
        for node_type in metadata[0]:
            if self.layer_norm:
                self.norm2_ffn[node_type] = nn.LayerNorm(dim_h)
            if self.batch_norm:
                self.norm2_ffn[node_type] = nn.BatchNorm1d(dim_h)

        # Type-specific feed-forward blocks (dim_h -> 2*dim_h -> dim_h).
        if cfg.gt.ffn == 'Type':
            self.ff_linear1_type = torch.nn.ModuleDict()
            self.ff_linear2_type = torch.nn.ModuleDict()
            for node_type in metadata[0]:
                self.ff_linear1_type[node_type] = nn.Linear(dim_h, dim_h * 2)
                self.ff_linear2_type[node_type] = nn.Linear(dim_h * 2, dim_h)
            self.ff_linear1_edge_type = torch.nn.ModuleDict()
            self.ff_linear2_edge_type = torch.nn.ModuleDict()
            for edge_type in metadata[1]:
                edge_type = "__".join(edge_type)
                self.ff_linear1_edge_type[edge_type] = nn.Linear(dim_h, dim_h * 2)
                self.ff_linear2_edge_type[edge_type] = nn.Linear(dim_h * 2, dim_h)

        self.ff_dropout1 = nn.Dropout(cfg.gt.dropout)
        self.ff_dropout2 = nn.Dropout(cfg.gt.dropout)

    def forward(self, batch):
        """Run one attention + feed-forward update on ``batch``.

        Args:
            batch: A ``HeteroData`` object, or a homogeneous ``Data`` object.
                For ``Data`` input the layer uses node type ``'node_type'``
                and edge type ``('node_type', 'edge_type', 'node_type')``, so
                ``metadata`` must contain those keys.

        Returns:
            ``batch`` with node features (``x``) and, when edge features are
            present, edge features (``edge_attr``) replaced. If
            ``return_attention`` is True, returns ``(batch, scores)`` where
            ``scores`` has shape ``(num_heads, num_edges, 1)`` and holds the
            softmax weights after attention dropout.
        """
        # ---- Collect per-type features and the homogeneous graph structure.
        edge_attr_dict = {}
        if isinstance(batch, HeteroData):
            node_types, edge_types = batch.node_types, batch.edge_types
            h_dict = batch.collect('x')
            has_edge_attr = bool(sum(batch.num_edge_features.values()))
            if has_edge_attr:
                edge_attr_dict = batch.collect('edge_attr')
            # to_homogeneous() numbers node/edge types by their position in
            # batch.node_types / batch.edge_types; the masks below rely on it.
            homo_data = batch.to_homogeneous()
            edge_index = homo_data.edge_index
            node_type_tensor = homo_data.node_type
            edge_type_tensor = homo_data.edge_type
        else:
            node_key = 'node_type'
            edge_key = ('node_type', 'edge_type', 'node_type')
            node_types, edge_types = [node_key], [edge_key]
            h_dict = {node_key: batch.x}
            # Data.num_edge_features is a plain int (no .values()).
            has_edge_attr = bool(batch.num_edge_features)
            if has_edge_attr:
                edge_attr_dict = {edge_key: batch.edge_attr}
            edge_index = batch.edge_index
            node_type_tensor = torch.zeros(
                batch.x.size(0), dtype=torch.long, device=edge_index.device)
            edge_type_tensor = torch.zeros(
                edge_index.size(1), dtype=torch.long, device=edge_index.device)

        h_in_dict = h_dict
        edge_attr_in_dict = dict(edge_attr_dict)

        # ---- Pre-normalization.
        if self.layer_norm or self.batch_norm:
            h_dict = {
                node_type: self.norm1_global[node_type](h_dict[node_type])
                for node_type in node_types
            }
            if has_edge_attr:
                edge_attr_dict = {
                    edge_type: self.norm1_edge_global['__'.join(edge_type)](edge_attr_dict[edge_type])
                    for edge_type in edge_types
                }

        H, D = self.num_heads, self.dim_h // self.num_heads
        num_nodes = node_type_tensor.numel()
        num_edges = edge_index.size(1)
        src_nodes, dst_nodes = edge_index
        # Work buffers follow the input features' dtype/device so the masked
        # assignments below also work for non-float32 models.
        x_ref = h_in_dict[node_types[0]]

        # ---- Type-specific projections scattered into flat (N, dim_h) buffers.
        q = x_ref.new_empty((num_nodes, self.dim_h))
        k = x_ref.new_empty((num_nodes, self.dim_h))
        v = x_ref.new_empty((num_nodes, self.dim_h))
        for idx, node_type in enumerate(node_types):
            mask = node_type_tensor == idx
            h = h_dict[node_type]
            q[mask] = self.q_lin[node_type](h)
            k[mask] = self.k_lin[node_type](h)
            v[mask] = self.v_lin[node_type](h)

        if has_edge_attr:
            edge_attr = x_ref.new_empty((num_edges, self.dim_h))
            edge_gate = x_ref.new_empty((num_edges, self.dim_h))
            for idx, edge_type in enumerate(edge_types):
                key = '__'.join(edge_type)
                mask = edge_type_tensor == idx
                e = edge_attr_dict[edge_type]
                edge_attr[mask] = self.e_lin[key](e)
                edge_gate[mask] = self.g_lin[key](e)
            # (E, dim_h) -> (H, E, D)
            edge_attr = edge_attr.view(-1, H, D).transpose(0, 1)
            edge_gate = edge_gate.view(-1, H, D).transpose(0, 1)

        # (N, dim_h) -> (H, N, D)
        q = q.view(-1, H, D).transpose(0, 1)
        k = k.view(-1, H, D).transpose(0, 1)
        v = v.view(-1, H, D).transpose(0, 1)

        # ---- Per-edge query at the destination, key/value at the source.
        edge_q = q[:, dst_nodes, :]  # (H, E, D)
        edge_k = k[:, src_nodes, :]
        edge_v = v[:, src_nodes, :]

        if hasattr(self, 'edge_weights'):
            # edge_weights is laid out in metadata[1] order; map the batch's
            # edge-type ids onto that order instead of assuming they match.
            meta_names = ['__'.join(et) for et in self.metadata[1]]
            type_to_param = torch.tensor(
                [meta_names.index('__'.join(et)) for et in edge_types],
                dtype=torch.long, device=edge_type_tensor.device,
            )
            edge_weight = self.edge_weights[type_to_param[edge_type_tensor]]  # (E, H, D, D)
            edge_weight = edge_weight.transpose(0, 1)  # (H, E, D, D)
            edge_k = torch.matmul(edge_weight, edge_k.unsqueeze(-1)).squeeze(-1)  # (H, E, D)

        # ---- Attention logits.
        edge_scores = edge_q * edge_k  # (H, E, D)
        if has_edge_attr:
            edge_scores = edge_scores + edge_attr
            edge_v = edge_v * torch.sigmoid(edge_gate)
            # The per-dimension logits become the new edge representation.
            edge_attr = edge_scores

        edge_scores = edge_scores.sum(dim=-1) / math.sqrt(D)  # (H, E)
        edge_scores = torch.clamp(edge_scores, min=-5, max=5)

        # ---- Numerically stable softmax over each destination's incoming
        # edges, independently per head.
        dst_per_head = dst_nodes.repeat(H, 1)  # (H, E)
        max_scores, _ = scatter_max(edge_scores, dst_per_head, dim=1, dim_size=num_nodes)
        exp_scores = torch.exp(edge_scores - max_scores.gather(1, dst_per_head))
        sum_exp_scores = exp_scores.new_zeros((H, num_nodes))
        sum_exp_scores.scatter_add_(1, dst_per_head, exp_scores)
        edge_scores = exp_scores / sum_exp_scores.gather(1, dst_per_head)
        edge_scores = self.dropout_attn(edge_scores.unsqueeze(-1))  # (H, E, 1)
        saved_scores = edge_scores

        # ---- Weighted sum of messages into each destination node.
        out = edge_v.new_zeros((H, num_nodes, D))
        out.scatter_add_(1, dst_nodes.unsqueeze(-1).expand((H, num_edges, D)), edge_scores * edge_v)
        out = out.transpose(0, 1).reshape(-1, H * D)  # (N, dim_h)

        # ---- Type-specific output projections. No squeeze(): a node type
        # with a single node must keep its (1, dim_out) shape.
        h_dict = {}
        for idx, node_type in enumerate(node_types):
            mask = node_type_tensor == idx
            h_dict[node_type] = self.dropout_global(self.o_lin[node_type](out[mask]))

        if has_edge_attr:
            edge_attr = edge_attr.transpose(0, 1).reshape(-1, H * D)  # (E, dim_h)
            edge_attr_dict = {}
            for idx, edge_type in enumerate(edge_types):
                mask = edge_type_tensor == idx
                edge_attr_dict[edge_type] = self.oe_lin['__'.join(edge_type)](edge_attr[mask])

        # ---- Residual connection.
        if cfg.gt.residual == 'Fixed':
            h_dict = {
                node_type: h_dict[node_type] + h_in_dict[node_type]
                for node_type in node_types
            }
            if has_edge_attr:
                edge_attr_dict = {
                    edge_type: edge_attr_dict[edge_type] + edge_attr_in_dict[edge_type]
                    for edge_type in edge_types
                }

        # ---- Pre-FFN normalization (node features only).
        if self.layer_norm or self.batch_norm:
            h_dict = {
                node_type: self.norm2_ffn[node_type](h_dict[node_type])
                for node_type in node_types
            }

        # ---- Type-specific feed-forward block with its own residual.
        if cfg.gt.ffn == 'Type':
            h_dict = {
                node_type: h_dict[node_type] + self._ff_block_type(h_dict[node_type], node_type)
                for node_type in node_types
            }
            if has_edge_attr:
                edge_attr_dict = {
                    edge_type: edge_attr_dict[edge_type] + self._ff_block_edge_type(edge_attr_dict[edge_type], edge_type)
                    for edge_type in edge_types
                }

        # ---- Write results back onto the batch.
        if isinstance(batch, HeteroData):
            for node_type in node_types:
                batch[node_type].x = h_dict[node_type]
            if has_edge_attr:
                for edge_type in edge_types:
                    batch[edge_type].edge_attr = edge_attr_dict[edge_type]
        else:
            batch.x = h_dict['node_type']
            if has_edge_attr:
                batch.edge_attr = edge_attr_dict[('node_type', 'edge_type', 'node_type')]

        if self.return_attention:
            return batch, saved_scores
        return batch

    def _ff_block_type(self, x, node_type):
        """Feed-forward block for one node type.

        Args:
            x (torch.Tensor): Node features of ``node_type``.
            node_type (str): Key into ``ff_linear1_type`` / ``ff_linear2_type``.

        Returns:
            torch.Tensor: ``dropout(linear2(dropout(act(linear1(x)))))``.
        """
        x = self.ff_dropout1(self.activation(self.ff_linear1_type[node_type](x)))
        return self.ff_dropout2(self.ff_linear2_type[node_type](x))

    def _ff_block(self, x):
        """Untyped feed-forward block.

        NOTE(review): ``ff_linear1`` / ``ff_linear2`` are never created in
        ``__init__``, so calling this raises ``AttributeError``; it is not
        called anywhere in this class.
        """
        x = self.ff_dropout1(self.activation(self.ff_linear1(x)))
        return self.ff_dropout2(self.ff_linear2(x))

    def _ff_block_edge_type(self, x, edge_type):
        """Feed-forward block for one edge type.

        Args:
            x (torch.Tensor): Edge features of ``edge_type``.
            edge_type (tuple[str, str, str]): Edge type; joined with ``'__'``
                to form the ModuleDict key.

        Returns:
            torch.Tensor: ``dropout(linear2(dropout(act(linear1(x)))))``.
        """
        edge_type = "__".join(edge_type)
        x = self.ff_dropout1(self.activation(self.ff_linear1_edge_type[edge_type](x)))
        return self.ff_dropout2(self.ff_linear2_edge_type[edge_type](x))
