import torch
from torch import nn

from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree

class LinearMapping(nn.Module):
    """Fixed linear map from the concatenation [x, s] to one scalar per row (sx -> y).

    The forward pass evaluates the linear layer under ``torch.no_grad()``, so
    the returned tensor carries no autograd history.

    Args:
        input_dim: number of columns in ``x``.
        sensitive_size: number of columns in ``s``.
        mapping_channels: not used by this module (presumably accepted so the
            constructor matches ``LinearMappingNeighbor``'s — confirm with callers).
    """

    def __init__(self, input_dim, sensitive_size, mapping_channels):
        super().__init__()
        joint_width = input_dim + sensitive_size
        self.fc1 = nn.Linear(joint_width, 1)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-draw the weights and bias of the single linear layer."""
        self.fc1.reset_parameters()

    def forward(self, x, s, edge_index):
        """Map each row of [x, s] to a scalar.

        Args:
            x: feature tensor, concatenated with ``s`` along dim 1
                (assumed shape [N, input_dim] — confirm).
            s: sensitive-attribute tensor (assumed shape [N, sensitive_size]).
            edge_index: ignored; present for interface parity with the
                graph-aware mapping.

        Returns:
            The layer output flattened to 1-D (length N for 2-D inputs).
        """
        joined = torch.cat((x, s), dim=1)
        with torch.no_grad():
            scores = self.fc1(joined)
        return scores.flatten()
    
class LinearMappingNeighbor(MessagePassing):
    """Fixed mapping from node features plus neighbor-aggregated sensitive attributes to a scalar.

    For every node, the sensitive vector ``s`` is linearly projected, summed
    over the node's neighborhood (self-loops included) with symmetric degree
    normalization, shifted by a bias, concatenated with the node's own
    features ``x``, and passed through a single linear layer. The whole
    forward pass runs under ``torch.no_grad()``, so the output carries no
    autograd history.
    """

    def __init__(self, in_channels, sensitive_size, mapping_channels):
        """
        Args:
            in_channels: number of columns in the node feature matrix ``x``.
            sensitive_size: number of columns in the sensitive matrix ``s``.
            mapping_channels: width ``s`` is projected to before aggregation.
        """
        super().__init__(aggr='add')  # Sum incoming messages per node.
        self.lin = nn.Linear(sensitive_size, mapping_channels, bias=False)
        # torch.empty replaces the legacy torch.Tensor(size) constructor; the
        # contents are overwritten (zeroed) by reset_parameters() below.
        self.bias = nn.Parameter(torch.empty(mapping_channels))
        self.fc1 = nn.Linear(in_channels + mapping_channels, 1)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize both linear layers and zero the aggregation bias."""
        self.lin.reset_parameters()
        self.bias.data.zero_()
        self.fc1.reset_parameters()

    def forward(self, x, s, edge_index):
        """Compute one scalar per node from x and the neighborhood's sensitive attributes.

        Args:
            x: node features (assumed shape [N, in_channels] — confirm).
            s: sensitive attributes (assumed shape [N, sensitive_size], same
                N as ``x``).
            edge_index: graph connectivity with shape [2, E].

        Returns:
            The output of ``fc1`` flattened to 1-D (length N for 2-D inputs).
        """
        with torch.no_grad():
            # Step 1: Add self-loops so each node also aggregates its own s.
            edge_index, _ = add_self_loops(edge_index, num_nodes=x.shape[0])

            # Step 2: Project sensitive attributes to mapping_channels.
            s = self.lin(s)

            # Step 3: Per-edge symmetric normalization 1/sqrt(deg(row) * deg(col)),
            # with degree counted on the target index (col).
            row, col = edge_index
            deg = degree(col, s.shape[0], dtype=s.dtype)
            deg_inv_sqrt = deg.pow(-0.5)
            # Zero-degree nodes give inf; map them to 0. With the self-loops
            # above this should not fire as long as x and s have the same N.
            deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
            norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

            # Steps 4-5: Send normalized messages and sum them per node.
            aggr_s = self.propagate(edge_index, x=s, norm=norm)

            # Step 6: Add the bias parameter to the aggregated attributes.
            aggr_s += self.bias

            # Step 7: Combine each node's own features with the aggregate
            # and map to a single output channel.
            aggr_x = torch.cat((x, aggr_s), dim=1)
            out = self.fc1(aggr_x)

            return torch.flatten(out)

    def message(self, x_j, norm):
        """Scale each source node's projected attributes by its edge's norm.

        Args:
            x_j: projected attributes gathered per edge, shape [E, mapping_channels].
            norm: per-edge normalization coefficients, shape [E].
        """
        return norm.view(-1, 1) * x_j
    
    