import torch
import torch.nn.functional as F
from typing import Sequence
from torch import Tensor, nn
from torch_sparse import SparseTensor
from torch_sparse import sum as sparse_sum
from torch_sparse import mean as sparse_mean
from torch_sparse import max as sparse_max


class GNNSparseLayer(nn.Module):
    """Edge-gated GNN layer over a sparse edge list (``edge_index``).

    With ``asym=False`` call ``forward``; with ``asym=True`` call ``asym_forward``.
    Each mode only creates the linear layers its own pass reads.
    """

    def __init__(
        self, 
        hidden_dim: int, 
        aggregation: str = "sum", 
        norm: str = "batch",
        learn_norm: bool = True,
        track_norm: bool = False,
        asym: bool = False
    ):
        super(GNNSparseLayer, self).__init__()
        self.hidden_dim = hidden_dim
        self.aggregation = aggregation

        # Linear layers. The creation order inside each mode is kept fixed so that
        # seeded initialisation and `parameters()` ordering stay unchanged.
        if asym:
            self.U0 = nn.Linear(hidden_dim, hidden_dim, bias=True)
            self.U1 = nn.Linear(hidden_dim, hidden_dim, bias=True)
            self.V = nn.Linear(hidden_dim, hidden_dim, bias=True)
            self.E1 = nn.Linear(hidden_dim, hidden_dim, bias=True)
            self.E2 = nn.Linear(hidden_dim, hidden_dim, bias=True)
            self.A1 = nn.Linear(hidden_dim, hidden_dim, bias=True)
            self.A2 = nn.Linear(hidden_dim, hidden_dim, bias=True)
            self.C = nn.Linear(hidden_dim, hidden_dim, bias=True)
            # Fuses [node, summed incoming-edge, summed outgoing-edge] features.
            self.X = nn.Linear(3 * hidden_dim, hidden_dim, bias=True)
        else:
            self.U = nn.Linear(hidden_dim, hidden_dim, bias=True)
            self.V = nn.Linear(hidden_dim, hidden_dim, bias=True)
            self.A = nn.Linear(hidden_dim, hidden_dim, bias=True)
            self.B = nn.Linear(hidden_dim, hidden_dim, bias=True)
            self.C = nn.Linear(hidden_dim, hidden_dim, bias=True)

        # Normalization for nodes and edges
        if norm == "batch":
            self.norm_x = nn.BatchNorm1d(hidden_dim, affine=learn_norm, track_running_stats=track_norm)
            self.norm_e = nn.BatchNorm1d(hidden_dim, affine=learn_norm, track_running_stats=track_norm)
        else:
            self.norm_x = nn.LayerNorm(hidden_dim, elementwise_affine=learn_norm)
            self.norm_e = nn.LayerNorm(hidden_dim, elementwise_affine=learn_norm)

    def forward(self, x: Tensor, e: Tensor, edge_index: Tensor) -> Sequence[Tensor]:
        """Symmetric layer pass (requires ``asym=False``).

        Args:
            x: (V, H) Node features; e: (E, H) Edge features
            edge_index: (2, E) Tensor with edges representing connections from source to target nodes.
        Returns:
            Updated x (V, H) and e (E, H) after one layer of GNN.
        """
        nodes_num = x.shape[0]  # total number of nodes

        Ux: Tensor = self.U(x)  # (V, H)

        # Neighbour message for each edge, taken from its edge_index[1] endpoint.
        Vx = self.V(x[edge_index[1]])  # (E, H)

        Ax = self.A(x)  # (V, H), gathered at edge_index[0]
        Bx = self.B(x)  # (V, H), gathered at edge_index[1]

        # Edge update. NOTE(review): A is gathered at edge_index[0] here, whereas
        # GNNDenseLayer.forward broadcasts A over the column (j) index; confirm
        # which convention is intended before mixing sparse/dense checkpoints.
        Ce = self.C(e)  # (E, H)
        e = Ax[edge_index[0]] + Bx[edge_index[1]] + Ce  # (E, H)

        gates = torch.sigmoid(e)  # (E, H)

        # Gated messages reduced into each edge_index[0] (row) node.
        x = Ux + self.aggregate(Vx, gates, edge_index, nodes_num)  # (V, H)

        x = F.relu(self.norm_x(x))  # (V, H)
        e = F.relu(self.norm_e(e))  # (E, H)
        return x, e

    def asym_forward(self, x: Tensor, e: Tensor, d: Tensor, edge_index: Tensor) -> Sequence[Tensor]:
        """Asymmetric layer pass (requires ``asym=True``).

        Sparse counterpart of ``GNNDenseLayer.asym_forward``. The dense
        ``E1e.sum(1)`` (sum over sources per target) and ``E2e.sum(2)`` (sum over
        targets per source) become scatter-sums over ``edge_index[1]`` and
        ``edge_index[0]`` respectively. Dense ``A1x.unsqueeze(1)`` and
        ``A2x.unsqueeze(2)`` become gathers at the target and source endpoints.

        Args:
            x: (V, H) Node embedding
            e: (E, H) (mixed) Edge embedding
            d: not read; accepted for signature parity with the dense layer.
            edge_index: (2, E) Tensor with edges representing connections from source to target nodes.
        Returns:
            Updated x (V, H) and e (E, H) after one layer of GNN.
        """
        nodes_num = x.shape[0]
        src, dst = edge_index[0], edge_index[1]

        # Node refresh from its own features plus incoming / outgoing edge messages.
        U0x: Tensor = self.U0(x)  # (V, H)
        E1e: Tensor = self.E1(e)  # (E, H)
        E2e: Tensor = self.E2(e)  # (E, H)
        incoming = self._scatter_sum(E1e, dst, nodes_num)  # (V, H)
        outgoing = self._scatter_sum(E2e, src, nodes_num)  # (V, H)
        x = self.X(torch.cat([U0x, incoming, outgoing], dim=-1))  # (V, H)
        U1x = self.U1(x)  # (V, H)

        # Neighbour message for each edge, taken from its target endpoint.
        V1x: Tensor = self.V(x[dst])  # (E, H)

        A1x: Tensor = self.A1(x)  # (V, H), gathered at the target endpoint
        A2x: Tensor = self.A2(x)  # (V, H), gathered at the source endpoint

        # Edge update: W x_j + W x_i + W e_ij
        Ce = self.C(e)  # (E, H)
        e = A1x[dst] + A2x[src] + Ce  # (E, H)

        gates = torch.sigmoid(e)  # (E, H)

        # Gated messages reduced into each source node.
        x = U1x + self.aggregate(V1x, gates, edge_index, nodes_num)  # (V, H)

        x = F.relu(self.norm_x(x))  # (V, H)
        e = F.relu(self.norm_e(e))  # (E, H)
        return x, e

    @staticmethod
    def _scatter_sum(src: Tensor, index: Tensor, nodes_num: int) -> Tensor:
        """Sum rows of ``src`` (E, H) into ``nodes_num`` buckets selected by ``index`` (E,)."""
        out = src.new_zeros((nodes_num, src.shape[-1]))
        return out.index_add(0, index, src)

    def aggregate(
        self, Vx: Tensor, gates: Tensor, edge_index: Tensor, nodes_num: int
    ) -> Tensor:
        """Reduce gated per-edge messages into their ``edge_index[0]`` (row) node.

        Args:
            Vx: (E, H); gates: (E, H); edge_index: (2, E)

        Returns:
            node feature: (V, H). Uses mean / max when ``self.aggregation`` is
            "mean" / "max", otherwise sum.
        """
        sparseVh = SparseTensor(
            row=edge_index[0],
            col=edge_index[1],
            value=Vx * gates,
            sparse_sizes=(nodes_num, nodes_num)
        )
        if self.aggregation == "mean":
            return sparse_mean(sparseVh, dim=1)
        elif self.aggregation == "max":
            return sparse_max(sparseVh, dim=1)
        else:
            return sparse_sum(sparseVh, dim=1)
        
        
class GNNSparseBlock(nn.Module):
    """Residual stack of ``GNNSparseLayer``s operating on an edge list.

    Node outputs are added straight back onto the node inputs. Edge outputs first
    pass through a per-layer head (LayerNorm -> SiLU -> zero-initialised Linear)
    before being added to the edge inputs.
    """

    def __init__(
        self, 
        num_layers: int, 
        hidden_dim: int, 
        aggregation: str = "sum", 
        norm: str = "layer",
        learn_norm: bool = True, 
        track_norm: bool = False,
        asym: bool = False
    ):
        super(GNNSparseBlock, self).__init__()

        # Message-passing layers (all created before the heads, as before).
        gnn_layers = []
        for _ in range(num_layers):
            gnn_layers.append(
                GNNSparseLayer(hidden_dim, aggregation, norm, learn_norm, track_norm, asym)
            )
        self.layers = nn.ModuleList(gnn_layers)

        # Per-layer edge output heads.
        edge_heads = []
        for _ in range(num_layers):
            edge_heads.append(
                nn.Sequential(
                    nn.LayerNorm(hidden_dim, elementwise_affine=learn_norm),
                    nn.SiLU(),
                    zero_module(nn.Linear(hidden_dim, hidden_dim)),
                )
            )
        self.per_layer_out = nn.ModuleList(edge_heads)

    def _residual_stack(self, x: Tensor, e: Tensor, run_layer) -> Sequence[Tensor]:
        """Run every layer through ``run_layer(layer, x, e)`` using the block's residual wiring."""
        for gnn_layer, edge_head in zip(self.layers, self.per_layer_out):
            new_x, new_e = run_layer(gnn_layer, x, e)
            x, e = new_x + x, e + edge_head(new_e)
        return x, e

    def forward(self, x: Tensor, e: Tensor, edge_index: Tensor) -> Sequence[Tensor]:
        """Apply every layer's symmetric ``forward``.

        Args:
            x: (V, H) Node features;
            e: (E, H) Edge features;
            edge_index: (2, E) Tensor with edges representing connections from source to target nodes.

        Return:
            updated features. x: (V, H); e: (E, H);
        """
        return self._residual_stack(
            x, e, lambda gnn_layer, xs, es: gnn_layer(xs, es, edge_index)
        )

    def asym_forward(
        self, x: Tensor, e: Tensor, d: Tensor, edges_feature: Tensor, edge_index: Tensor, nodes_num: int 
    ) -> Sequence[Tensor]:
        """Apply every layer's ``asym_forward``.

        Args:
            x: (V, H) Node embedding;
            e: (E, H) Decision Variables (edge);
            d: forwarded unchanged to each layer's ``asym_forward``;
            edges_feature: (E, H) Edge features (not read here);
            edge_index: (2, E) Tensor with edges representing connections from source to target nodes.
            nodes_num: not read here.

        Return:
            updated features x and e.
        """
        return self._residual_stack(
            x, e, lambda gnn_layer, xs, es: gnn_layer.asym_forward(xs, es, d, edge_index)
        )
    

class GNNDenseLayer(nn.Module):
    """Edge-gated GNN layer on dense (B, V, V, H) edge tensors.

    ``forward`` is used with ``asym=False`` and ``asym_forward`` with ``asym=True``;
    each mode only creates the linear layers its own pass reads.
    """

    def __init__(
        self, 
        hidden_dim: int, 
        aggregation: str = "sum", 
        norm: str = "batch",
        learn_norm: bool = True,
        track_norm: bool = False,
        asym: bool = False
    ):
        super(GNNDenseLayer, self).__init__()
        self.hidden_dim = hidden_dim
        self.aggregation = aggregation

        # Linear layers. The creation order inside each mode is kept fixed so that
        # seeded initialisation and `parameters()` ordering stay unchanged.
        if asym:
            self.U0 = nn.Linear(hidden_dim, hidden_dim, bias=True)
            self.U1 = nn.Linear(hidden_dim, hidden_dim, bias=True)
            self.V = nn.Linear(hidden_dim, hidden_dim, bias=True)
            self.E1 = nn.Linear(hidden_dim, hidden_dim, bias=True)
            self.E2 = nn.Linear(hidden_dim, hidden_dim, bias=True)
            self.A1 = nn.Linear(hidden_dim, hidden_dim, bias=True)
            self.A2 = nn.Linear(hidden_dim, hidden_dim, bias=True)
            self.C = nn.Linear(hidden_dim, hidden_dim, bias=True)
            # Fuses [node, edge-sum over dim 1, edge-sum over dim 2] features.
            self.X = nn.Linear(3 * hidden_dim, hidden_dim, bias=True)
        else:
            self.U = nn.Linear(hidden_dim, hidden_dim, bias=True)
            self.V = nn.Linear(hidden_dim, hidden_dim, bias=True)
            self.A = nn.Linear(hidden_dim, hidden_dim, bias=True)
            self.B = nn.Linear(hidden_dim, hidden_dim, bias=True)
            self.C = nn.Linear(hidden_dim, hidden_dim, bias=True)

        # Normalization for nodes and edges. In the asym + batch-norm case norm_e
        # is registered before norm_x; that order is kept so `parameters()`
        # ordering is unchanged.
        if norm == "batch":
            if asym:
                self.norm_e = nn.BatchNorm1d(hidden_dim, affine=learn_norm, track_running_stats=track_norm)
                self.norm_x = nn.BatchNorm1d(hidden_dim, affine=learn_norm, track_running_stats=track_norm)
            else:
                self.norm_x = nn.BatchNorm1d(hidden_dim, affine=learn_norm, track_running_stats=track_norm)
                self.norm_e = nn.BatchNorm1d(hidden_dim, affine=learn_norm, track_running_stats=track_norm)
        else:
            self.norm_x = nn.LayerNorm(hidden_dim, elementwise_affine=learn_norm)
            self.norm_e = nn.LayerNorm(hidden_dim, elementwise_affine=learn_norm)

    def forward(self, x: Tensor, e: Tensor, graph: Tensor) -> Sequence[Tensor]:
        """Symmetric layer pass (requires ``asym=False``).

        Args:
            x: (B, V, H) Node features; 
            e: (B, V, V, H) Edge features
            graph: (B, V, V) Graph adjacency matrices
        Returns:
            Updated x (B, V, H) and e (B, V, V, H) after one layer of GNN.
        """
        batch_size, nodes_num, hidden_dim = x.shape

        Ux: Tensor = self.U(x)  # (B, V, H)

        # Neighbour messages: Vx[b, i, j] = V(x)[b, j]
        Vx: Tensor = self.V(x)  # (B, V, H)
        Vx = Vx.unsqueeze(1).expand(-1, nodes_num, -1, -1)  # (B, V, V, H)

        Ax: Tensor = self.A(x)  # (B, V, H), broadcast over the j index
        Bx: Tensor = self.B(x)  # (B, V, H), broadcast over the i index

        # Edge update: e_ij = A x_j + B x_i + C e_ij
        Ce = self.C(e)  # (B, V, V, H)
        e = Ax.unsqueeze(dim=1) + Bx.unsqueeze(dim=2) + Ce  # (B, V, V, H)

        gates = torch.sigmoid(e)  # (B, V, V, H)

        # Gated messages reduced over j into node i.
        x = Ux + self.aggregate(Vx, gates, graph)  # (B, V, H)

        # Norms expect 2-D (N, H) input, so flatten the leading dims around them.
        x = x.view(batch_size * nodes_num, hidden_dim)  # (B*V, H)
        x = F.relu(self.norm_x(x)).view(batch_size, nodes_num, hidden_dim)  # (B, V, H)
        e = e.view(batch_size * nodes_num * nodes_num, hidden_dim)
        e = F.relu(self.norm_e(e)).view(batch_size, nodes_num, nodes_num, hidden_dim)  # (B, V, V, H)

        return x, e

    def asym_forward(self, x: Tensor, e: Tensor, d: Tensor, graph: Tensor) -> Sequence[Tensor]:
        """Asymmetric layer pass (requires ``asym=True``).

        Args:
            x: (B, V, H) Node features; 
            e: (B, V, V, H) Decision Variables (edge)
            d: (B, V, V) Edge features (distance matrix); not read by this layer
            graph: (B, V, V) Graph adjacency matrices
        Returns:
            Updated x (B, V, H) and e (B, V, V, H) after one layer of GNN.
        """
        batch_size, nodes_num, _, hidden_dim = e.shape

        # Node refresh from its own features plus edge sums along both axes.
        U0x: Tensor = self.U0(x)  # (B, V, H)
        E1e: Tensor = self.E1(e)  # (B, V, V, H)
        E2e: Tensor = self.E2(e)  # (B, V, V, H)
        x = self.X(torch.cat([U0x, E1e.sum(1), E2e.sum(2)], dim=-1))  # (B, V, H)
        U1x = self.U1(x)

        # Neighbour messages: V1x[b, i, j] = V(x)[b, j]
        V1x: Tensor = self.V(x)  # (B, V, H)
        V1x = V1x.unsqueeze(1).expand(-1, nodes_num, -1, -1)  # (B, V, V, H)

        A1x: Tensor = self.A1(x)  # (B, V, H), broadcast over the j index
        A2x: Tensor = self.A2(x)  # (B, V, H), broadcast over the i index

        # Edge update: e_ij = A1 x_j + A2 x_i + C e_ij
        Ce = self.C(e)  # (B, V, V, H)
        e = A1x.unsqueeze(dim=1) + A2x.unsqueeze(dim=2) + Ce  # (B, V, V, H)

        gates = torch.sigmoid(e)  # (B, V, V, H)

        # Gated messages reduced over j into node i.
        x = U1x + self.aggregate(V1x, gates, graph)  # (B, V, H)

        # Norms expect 2-D (N, H) input, so flatten the leading dims around them.
        x = x.view(batch_size * nodes_num, hidden_dim)  # (B*V, H)
        x = F.relu(self.norm_x(x)).view(batch_size, nodes_num, hidden_dim)  # (B, V, H)
        e = e.view(batch_size * nodes_num * nodes_num, hidden_dim)
        e = F.relu(self.norm_e(e)).view(batch_size, nodes_num, nodes_num, hidden_dim)  # (B, V, V, H)

        return x, e

    def aggregate(self, Vx: Tensor, gates: Tensor, graph: Tensor) -> Tensor:
        """Gate per-pair messages and reduce over neighbours j for each node i.

        Only pairs with ``graph[b, i, j] != 0`` contribute. For a fully connected
        (all-ones) graph the result is identical to an unmasked reduction.

        Args:
            Vx: messages broadcastable to ``gates`` (callers pass (B, V, V, H));
            gates: (B, V, V, H); graph: (B, V, V) adjacency.

        Returns:
            node feature: (B, V, H). Nodes without neighbours get 0 for every mode.
        """
        gated = Vx * gates  # (B, V, V, H)
        adjacency = graph.unsqueeze(-1) != 0  # (B, V, V, 1)

        if self.aggregation == "max":
            # Non-edges must never win the max.
            gated = gated.masked_fill(~adjacency, float("-inf"))
            pooled = torch.max(gated, dim=2)[0]  # (B, V, H)
            has_neighbour = adjacency.any(dim=2)  # (B, V, 1)
            return pooled.masked_fill(~has_neighbour, 0.0)

        summed = torch.sum(gated.masked_fill(~adjacency, 0.0), dim=2)  # (B, V, H)
        if self.aggregation == "mean":
            degree = torch.sum(graph, dim=2).unsqueeze(-1).type_as(summed)  # (B, V, 1)
            # Avoid 0/0 for isolated nodes; nonzero degrees are left untouched.
            degree = degree.masked_fill(degree == 0, 1.0)
            return summed / degree
        return summed
        
class GNNDenseBlock(nn.Module):
    """Residual stack of ``GNNDenseLayer``s on dense (B, V, V, H) edge tensors.

    Node outputs are added straight back onto the node inputs. Edge outputs first
    pass through a per-layer head (LayerNorm -> SiLU -> zero-initialised Linear)
    before being added to the edge inputs.
    """

    def __init__(
        self, 
        num_layers: int, 
        hidden_dim: int, 
        aggregation: str = "sum", 
        norm: str = "layer",
        learn_norm: bool = True, 
        track_norm: bool = False,
        asym: bool = False
    ):
        super(GNNDenseBlock, self).__init__()

        # gnn layer
        self.layers = nn.ModuleList([
            GNNDenseLayer(hidden_dim, aggregation, norm, learn_norm, track_norm, asym)
            for _ in range(num_layers)
        ])

        # per layer out
        self.per_layer_out = nn.ModuleList([
            nn.Sequential(
                nn.LayerNorm(hidden_dim, elementwise_affine=learn_norm),
                nn.SiLU(),
                zero_module(nn.Linear(hidden_dim, hidden_dim)),
            ) for _ in range(num_layers)
        ])

        # Per-layer time projections. `forward` does not call the node variant
        # (node focus raises NotImplementedError).
        self.node_time_layers = nn.ModuleList([
            nn.Sequential(
                nn.ReLU(),
                nn.Linear(hidden_dim, hidden_dim),
            ) for _ in range(num_layers)
        ])

        self.edge_time_layers = nn.ModuleList([
            nn.Sequential(
                nn.ReLU(),
                nn.Linear(hidden_dim, hidden_dim),
            ) for _ in range(num_layers)
        ])

    @staticmethod
    def _expand_time_embedding(t_emb: Tensor) -> Tensor:
        """Make a per-sample embedding broadcastable against (B, V, V, H) edges.

        A 2-D (B, H) tensor becomes (B, 1, 1, H). Any other rank (e.g. an already
        expanded (B, 1, 1, H) tensor) is returned unchanged.
        """
        if t_emb.dim() == 2:
            return t_emb[:, None, None, :]
        return t_emb

    def forward(self, x: Tensor, e: Tensor, d: Tensor, t: Tensor, edge_index, 
                focus_on_node: bool, focus_on_edge: bool) -> Sequence[Tensor]:
        """Apply every layer's symmetric ``forward`` with residual connections.

        Args:
            x: (B, V, H) Node features; 
            e: (B, V, V, H) Edge features;
            d: not read here;
            t: time embedding fed to ``edge_time_layers`` when ``focus_on_edge``;
               presumably (B, H) — a 2-D result is expanded to (B, 1, 1, H)
               before being added to ``e``.
            edge_index: None
            focus_on_node: node-side time conditioning; not implemented.
            focus_on_edge: add the time projection to the edge features.

        Return:
            updated features. x: (B, V, H); e: (B, V, V, H);

        Raises:
            NotImplementedError: if ``focus_on_node`` is set without ``focus_on_edge``.
        """
        batch_size, nodes_num, _ = x.shape
        # Dense layers treat every node pair as connected.
        graph = torch.ones(size=(batch_size, nodes_num, nodes_num), device=x.device)

        for layer, out_layer, edge_time_layer in zip(
            self.layers, self.per_layer_out, self.edge_time_layers
        ):
            x_in, e_in = x, e
            x, e = layer(x, e, graph)
            if focus_on_edge:
                e = e + self._expand_time_embedding(edge_time_layer(t))
            elif focus_on_node:
                raise NotImplementedError()
            x = x + x_in
            e = e_in + out_layer(e)

        return x, e

    def asym_forward(
        self, x: Tensor, e: Tensor, d: Tensor, edges_feature: Tensor, edge_index: Tensor, nodes_num: int 
    ) -> Sequence[Tensor]:
        """Apply every layer's ``asym_forward`` with residual connections.

        Args:
            x: (B, V, H) Node features; 
            e: (B, V, V, H) Decision Variables (edge);
            d: Edge features, forwarded unchanged to each layer;
            edges_feature: not read here;
            edge_index: None
            nodes_num: not read; recomputed from ``e``.

        Return:
            updated features. x: (B, V, H); e: (B, V, V, H);
        """
        batch_size, nodes_num, _, _ = e.shape
        # Dense layers treat every node pair as connected.
        graph = torch.ones(size=(batch_size, nodes_num, nodes_num), device=e.device)

        for layer, out_layer in zip(self.layers, self.per_layer_out):
            layer: GNNDenseLayer
            x_in, e_in = x, e
            x, e = layer.asym_forward(x, e, d, graph)
            x = x + x_in
            e = e_in + out_layer(e)

        return x, e
    

def zero_module(module: nn.Module):
    """Set every parameter of ``module`` to zero in place and return the same module.

    Handy for wrapping a freshly built layer so that its initial output is zero.
    """
    with torch.no_grad():
        for param in module.parameters():
            param.zero_()
    return module


class GNNDenseLayer_V2(nn.Module):
    """Dense edge-gated layer that aggregates along both edge directions.

    Only ``asym_forward`` is defined. The ``asym`` constructor flag is accepted
    for signature parity with ``GNNDenseLayer`` but is not read.
    """

    def __init__(
        self, 
        hidden_dim: int, 
        aggregation: str = "sum", 
        norm: str = "batch",
        learn_norm: bool = True,
        track_norm: bool = False,
        asym: bool = False
    ):
        super(GNNDenseLayer_V2, self).__init__()
        self.hidden_dim = hidden_dim
        self.aggregation = aggregation

        # Linear Layer for nodes
        self.U = nn.Linear(hidden_dim, hidden_dim, bias=True)
        self.A = nn.Linear(hidden_dim, hidden_dim, bias=True)
        self.B = nn.Linear(hidden_dim, hidden_dim, bias=True)
        self.V1 = nn.Linear(hidden_dim, hidden_dim, bias=True)
        self.V2 = nn.Linear(hidden_dim, hidden_dim, bias=True)
        # Fuses the two directional node aggregates.
        self.X = nn.Linear(2 * hidden_dim, hidden_dim, bias=True)
        # Linear Layer for edges
        self.C = nn.Linear(hidden_dim, hidden_dim, bias=True)
        # Projects the (feature, transposed feature) pair of each edge.
        self.D = nn.Linear(2, hidden_dim, bias=True)

        # Normalization for nodes and edges
        if norm == "batch":
            self.norm_x = nn.BatchNorm1d(hidden_dim, affine=learn_norm, track_running_stats=track_norm)
            self.norm_e = nn.BatchNorm1d(hidden_dim, affine=learn_norm, track_running_stats=track_norm)
        else:
            self.norm_x = nn.LayerNorm(hidden_dim, elementwise_affine=learn_norm)
            self.norm_e = nn.LayerNorm(hidden_dim, elementwise_affine=learn_norm)

    def asym_forward(self, x: Tensor, e: Tensor, edges_feature: Tensor, graph: Tensor) -> Sequence[Tensor]:
        """One bidirectional layer pass.

        Args:
            x: (B, V, H) Node features; 
            e: (B, V, V, H) noised solution (edge)
            edges_feature: (B, V, V) Edge features (distance matrix)
            graph: (B, V, V) Graph adjacency matrices
        Returns:
            Updated x (B, V, H) and e (B, V, V, H) after one layer of GNN.
        """
        batch_size, nodes_num, hidden_dim = x.shape
        # Pair each edge feature with its transpose: [f_ij, f_ji].
        edges_feature = torch.stack([edges_feature, edges_feature.transpose(1, 2)], dim=3)  # (B, V, V, 2)

        # Linear transformation for input embeddings (x, e, edges_feature)
        Ux: Tensor = self.U(x)  # (B, V, H)
        Ce: Tensor = self.C(e)  # (B, V, V, H)
        De: Tensor = self.D(edges_feature)  # (B, V, V, H)

        # Pairwise messages for both directions.
        V1x: Tensor = self.V1(x)  # (B, V, H)
        V2x: Tensor = self.V2(x)  # (B, V, H)
        W1x = V1x.unsqueeze(1) + V2x.unsqueeze(2)  # (B, V, V, H): V1 x_j + V2 x_i
        W2x = V1x.unsqueeze(2) + V2x.unsqueeze(1)  # (B, V, V, H): V1 x_i + V2 x_j

        # Sigmoid gates; gates_T[b, i, j] gates edge (j, i).
        gates = torch.sigmoid(e)  # (B, V, V, H)
        gates_T = torch.sigmoid(e.transpose(1, 2))  # (B, V, V, H)

        # Directional node aggregates. The transposed gates index edge (j, i), so
        # they are masked with the transposed adjacency.
        x1 = Ux + self.aggregate(W1x, gates, graph)  # (B, V, H)
        x2 = Ux + self.aggregate(W2x, gates_T, graph.transpose(1, 2))  # (B, V, H)
        x = self.X(torch.cat([x1, x2], dim=-1))  # (B, V, 2H) -> (B, V, H)

        # Message passing from nodes to edges
        Ax: Tensor = self.A(x1)  # (B, V, H), broadcast over the j index
        Bx: Tensor = self.B(x2)  # (B, V, H), broadcast over the i index

        # Edge update
        e = Ax.unsqueeze(1) + Bx.unsqueeze(2) + Ce + De  # (B, V, V, H)

        # Norms expect 2-D (N, H) input, so flatten the leading dims around them.
        x = x.view(batch_size * nodes_num, hidden_dim)  # (B*V, H)
        x = F.relu(self.norm_x(x)).view(batch_size, nodes_num, hidden_dim)  # (B, V, H)
        e = e.view(batch_size * nodes_num * nodes_num, hidden_dim)
        e = F.relu(self.norm_e(e)).view(batch_size, nodes_num, nodes_num, hidden_dim)  # (B, V, V, H)

        return x, e

    def aggregate(self, Vx: Tensor, gates: Tensor, graph: Tensor) -> Tensor:
        """Gate per-pair messages and reduce over neighbours j for each node i.

        Only pairs with ``graph[b, i, j] != 0`` contribute. For a fully connected
        (all-ones) graph the result is identical to an unmasked reduction.

        Args:
            Vx: messages broadcastable to ``gates`` (callers pass (B, V, V, H));
            gates: (B, V, V, H); graph: (B, V, V) adjacency.

        Returns:
            node feature: (B, V, H). Nodes without neighbours get 0 for every mode.
        """
        gated = Vx * gates  # (B, V, V, H)
        adjacency = graph.unsqueeze(-1) != 0  # (B, V, V, 1)

        if self.aggregation == "max":
            # Non-edges must never win the max.
            gated = gated.masked_fill(~adjacency, float("-inf"))
            pooled = torch.max(gated, dim=2)[0]  # (B, V, H)
            has_neighbour = adjacency.any(dim=2)  # (B, V, 1)
            return pooled.masked_fill(~has_neighbour, 0.0)

        summed = torch.sum(gated.masked_fill(~adjacency, 0.0), dim=2)  # (B, V, H)
        if self.aggregation == "mean":
            degree = torch.sum(graph, dim=2).unsqueeze(-1).type_as(summed)  # (B, V, 1)
            # Avoid 0/0 for isolated nodes; nonzero degrees are left untouched.
            degree = degree.masked_fill(degree == 0, 1.0)
            return summed / degree
        return summed
