import torch
import torch.nn as nn
import torch.nn.functional as F
from config import args
from utils import sparse_mx_to_torch_sparse_tensor, homo_adj_to_symmetric_norm, hete_adj_to_symmetric_norm


class NLMLP(nn.Module):
    """Two-layer MLP combined with a non-local aggregation branch.

    The MLP (``lin1`` -> ReLU -> ``lin2``) produces a per-node representation
    ``x1``. The non-local branch projects every node to a scalar score, sorts
    the nodes by that score, scales the sorted representations by their
    scores, runs two 1-D convolutions along the sorted node axis and finally
    restores the original node order to obtain ``x2``. ``x1`` and ``x2`` are
    concatenated and mapped back to ``output_dim`` by ``lin``.

    Args:
        feat_dim: Size of the input features consumed by ``lin1``.
        hidden_dim: Size of the hidden layer.
        output_dim: Size of the representation returned per node.
        dropout: Dropout probability used throughout ``forward``.
        bn: Stored on the instance; not used by this class.
        ln: Stored on the instance; not used by this class.
        kernel: Kernel size of both 1-D convolutions. Must be odd so that the
            ``(kernel - 1) // 2`` padding keeps the number of nodes unchanged.

    Raises:
        ValueError: If ``kernel`` is even.
    """

    def __init__(self, feat_dim, hidden_dim, output_dim, dropout=0.5, bn=False, ln=False, kernel=5):
        super(NLMLP, self).__init__()
        # With an even kernel each convolution shortens the node axis by one,
        # so re-indexing with the inverse permutation in forward() would fail.
        if kernel % 2 == 0:
            raise ValueError(
                f"kernel must be odd to preserve the number of nodes, got {kernel}")

        self.feat_dim = feat_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        self.kernel = kernel
        self.dropout = dropout
        self.bn = bn
        self.ln = ln

        self.use_graph_op = False
        self.pre_graph_op = None
        # Declared up front so forward()/postprocess() never hit a missing
        # attribute; preprocess() or external code fills these in.
        self.pre_msg_op = None
        self.pre_msg_learnable = False
        self.processed_feature = None
        self.processed_feat_list = None

        padding = (kernel - 1) // 2
        self.lin1 = nn.Linear(feat_dim, hidden_dim)
        self.lin2 = nn.Linear(hidden_dim, output_dim)
        self.proj = nn.Linear(output_dim, 1)
        self.conv1d = nn.Conv1d(output_dim, output_dim, kernel, padding=padding)
        self.conv1d_2 = nn.Conv1d(output_dim, output_dim, kernel, padding=padding)
        self.lin = nn.Linear(2 * output_dim, output_dim)

        self.post_graph_op = None
        self.post_msg_op = None

    def reset_parameters(self):
        """Re-initialise the parameters of every learnable layer."""
        self.lin1.reset_parameters()
        self.lin2.reset_parameters()
        self.proj.reset_parameters()
        self.conv1d.reset_parameters()
        self.conv1d_2.reset_parameters()
        self.lin.reset_parameters()

    def preprocess(self, adj, feature):
        """Store ``feature`` as the model input; ``adj`` is not used here."""
        self.pre_msg_learnable = False
        self.processed_feature = feature

    def postprocess(self, adj, output):
        """Optionally propagate the softmaxed output with ``post_graph_op``.

        When ``post_graph_op`` is ``None`` the output is returned untouched.
        """
        if self.post_graph_op is not None:
            output = F.softmax(output, dim=1)
            # .cpu() is required before .numpy() when the output lives on GPU.
            output = output.detach().cpu().numpy()
            output = self.post_graph_op.propagate(adj, output)
            output = self.post_msg_op.aggregate(output)

        return output

    # a wrapper of the forward function
    def model_forward(self, idx, device):
        return self.forward(idx, device)

    def _non_local(self, x1):
        """Return the score-sorted, convolved view of ``x1`` in original order."""
        g_score = self.proj(x1)
        g_score_sorted, sort_idx = torch.sort(g_score, dim=0)
        # Squeeze only the score axis: a bare squeeze() would also drop the
        # node axis (single node) or the channel axis (output_dim == 1).
        order = sort_idx.squeeze(1)
        inverse_order = torch.argsort(order)

        sorted_x = g_score_sorted * x1[order]
        # Conv1d expects (batch, channels, length): channels are the output
        # dimensions and length is the sorted node axis.
        sorted_x = sorted_x.t().unsqueeze(0)
        sorted_x = F.relu(self.conv1d(sorted_x))
        sorted_x = F.dropout(sorted_x, p=self.dropout, training=self.training)
        sorted_x = self.conv1d_2(sorted_x)
        sorted_x = sorted_x.squeeze(0).t()
        return sorted_x[inverse_order]

    def forward(self, idx, device):
        """Compute representations for all nodes and return the rows ``idx``.

        Args:
            idx: Index (e.g. a tensor of node ids) applied to the output rows.
            device: Device the stored features are moved to before computing.

        Raises:
            RuntimeError: If no features were stored via ``preprocess``.
        """
        if self.pre_msg_learnable is False:
            if self.processed_feature is None:
                raise RuntimeError("preprocess() must be called before forward()")
            processed_feature = self.processed_feature.to(device)
        else:
            transferred_feat_list = [
                feat.to(device) for feat in self.processed_feat_list]
            processed_feature = self.pre_msg_op.aggregate(transferred_feat_list)

        x = F.dropout(processed_feature, p=self.dropout, training=self.training)
        x = F.relu(self.lin1(x))
        x = F.dropout(x, p=self.dropout, training=self.training)
        x1 = self.lin2(x)

        x2 = self._non_local(x1)

        out = self.lin(torch.cat([x1, x2], dim=1))
        return out[idx]

class NLGCN(nn.Module):
    """Two-layer GCN combined with a non-local aggregation branch.

    Two ``GCNConv`` layers produce a per-node representation ``x1``. The
    non-local branch projects every node to a scalar score, sorts the nodes by
    that score, scales the sorted representations by their scores, runs two
    1-D convolutions along the sorted node axis and restores the original
    node order to obtain ``x2``. ``x1`` and ``x2`` are concatenated and mapped
    back to ``output_dim`` by ``lin``.

    Args:
        feat_dim: Size of the input features consumed by the first GCN layer.
        hidden_dim: Size of the hidden layer.
        output_dim: Size of the representation returned per node.
        dropout: Dropout probability used in ``forward``.
        bn: Stored on the instance; not used by this class.
        ln: Stored on the instance; not used by this class.
        kernel: Kernel size of both 1-D convolutions. Must be odd so that the
            ``(kernel - 1) // 2`` padding keeps the number of nodes unchanged.

    Raises:
        ValueError: If ``kernel`` is even.
    """

    def __init__(self, feat_dim, hidden_dim, output_dim, dropout=0.5, bn=False, ln=False, kernel=5):
        super(NLGCN, self).__init__()
        # With an even kernel each convolution shortens the node axis by one,
        # so re-indexing with the inverse permutation in forward() would fail.
        if kernel % 2 == 0:
            raise ValueError(
                f"kernel must be odd to preserve the number of nodes, got {kernel}")

        self.feat_dim = feat_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        self.kernel = kernel
        self.dropout = dropout
        self.bn = bn
        self.ln = ln

        self.use_graph_op = False
        self.pre_graph_op = None
        # Declared up front so forward()/postprocess() never hit a missing
        # attribute; preprocess() or external code fills these in.
        self.pre_msg_op = None
        self.pre_msg_learnable = False
        self.processed_feature = None
        self.processed_feat_list = None
        self.adj = None

        padding = (self.kernel - 1) // 2
        self.conv1 = GCNConv(feat_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, output_dim)
        self.proj = nn.Linear(output_dim, 1)
        self.conv1d = nn.Conv1d(output_dim, output_dim, self.kernel, padding=padding)
        self.conv1d_2 = nn.Conv1d(output_dim, output_dim, self.kernel, padding=padding)
        self.lin = nn.Linear(2 * output_dim, output_dim)

        self.post_graph_op = None
        self.post_msg_op = None

    def reset_parameters(self):
        """Re-initialise the parameters of every learnable layer."""
        # Reset through the wrapped nn.Linear: the graph convolution layers
        # keep their only parameters in their ``linear`` attribute.
        self.conv1.linear.reset_parameters()
        self.conv2.linear.reset_parameters()
        self.proj.reset_parameters()
        self.conv1d.reset_parameters()
        self.conv1d_2.reset_parameters()
        self.lin.reset_parameters()

    def preprocess(self, adj, feature, homo=args.homo):
        """Store the features and the normalised adjacency used by ``forward``.

        Args:
            adj: Adjacency matrix handed to the normalisation helper.
            feature: Node features stored as the model input.
            homo: Selects ``homo_adj_to_symmetric_norm`` when truthy, otherwise
                ``hete_adj_to_symmetric_norm`` (both called with ``r=0.5``).
                NOTE: the default is read from ``args.homo`` once, when the
                class is defined.
        """
        self.pre_msg_learnable = False
        self.processed_feature = feature

        normalize = homo_adj_to_symmetric_norm if homo else hete_adj_to_symmetric_norm
        adj = normalize(adj, r=0.5)

        self.adj = sparse_mx_to_torch_sparse_tensor(adj)

    def postprocess(self, adj, output):
        """Optionally propagate the softmaxed output with ``post_graph_op``.

        When ``post_graph_op`` is ``None`` the output is returned untouched.
        """
        if self.post_graph_op is not None:
            output = F.softmax(output, dim=1)
            # .cpu() is required before .numpy() when the output lives on GPU.
            output = output.detach().cpu().numpy()
            output = self.post_graph_op.propagate(adj, output)
            output = self.post_msg_op.aggregate(output)

        return output

    # a wrapper of the forward function
    def model_forward(self, idx, device):
        return self.forward(idx, device)

    def _non_local(self, x1):
        """Return the score-sorted, convolved view of ``x1`` in original order."""
        g_score = self.proj(x1)
        g_score_sorted, sort_idx = torch.sort(g_score, dim=0)
        # Squeeze only the score axis: a bare squeeze() would also drop the
        # node axis (single node) or the channel axis (output_dim == 1).
        order = sort_idx.squeeze(1)
        inverse_order = torch.argsort(order)

        sorted_x = g_score_sorted * x1[order]
        # Conv1d expects (batch, channels, length): channels are the output
        # dimensions and length is the sorted node axis.
        sorted_x = sorted_x.t().unsqueeze(0)
        sorted_x = F.relu(self.conv1d(sorted_x))
        sorted_x = F.dropout(sorted_x, p=self.dropout, training=self.training)
        sorted_x = self.conv1d_2(sorted_x)
        sorted_x = sorted_x.squeeze(0).t()
        return sorted_x[inverse_order]

    def forward(self, idx, device):
        """Compute representations for all nodes and return the rows ``idx``.

        Args:
            idx: Index (e.g. a tensor of node ids) applied to the output rows.
            device: Device the stored features and adjacency are moved to.

        Raises:
            RuntimeError: If ``preprocess`` has not stored the adjacency, or
                the features on the non-learnable path.
        """
        if self.pre_msg_learnable is False:
            if self.processed_feature is None:
                raise RuntimeError("preprocess() must be called before forward()")
            processed_feature = self.processed_feature.to(device)
        else:
            transferred_feat_list = [
                feat.to(device) for feat in self.processed_feat_list]
            processed_feature = self.pre_msg_op.aggregate(transferred_feat_list)

        if self.adj is None:
            raise RuntimeError("preprocess() must be called before forward()")
        adj = self.adj.to(device)

        x = F.relu(self.conv1(processed_feature, adj))
        x = F.dropout(x, p=self.dropout, training=self.training)
        x1 = self.conv2(x, adj)

        x2 = self._non_local(x1)

        out = self.lin(torch.cat([x1, x2], dim=1))
        return out[idx]

class GCNConv(nn.Module):
    """Graph convolution: a linear transform followed by adjacency mixing.

    Args:
        in_features: Size of each input node feature vector.
        out_features: Size of each output node feature vector.
    """

    def __init__(self, in_features, out_features):
        super(GCNConv, self).__init__()
        self.linear = nn.Linear(in_features, out_features)

    def reset_parameters(self):
        """Re-initialise the wrapped linear layer.

        Provided so that models owning this layer can reset it the same way
        they reset built-in ``torch.nn`` layers.
        """
        self.linear.reset_parameters()

    def forward(self, x, adjacency_hat):
        """Return ``adjacency_hat @ linear(x)``.

        Args:
            x: 2-D node feature matrix whose last dimension is ``in_features``.
            adjacency_hat: 2-D matrix multiplied from the left via
                ``torch.mm``; presumably the normalised adjacency supplied by
                the caller.
        """
        x = self.linear(x)
        x = torch.mm(adjacency_hat, x)
        return x