import torch
import torch.nn.functional as F
import torch.nn.init as init
from torch import nn
from torch.nn import Parameter
from torch.nn import Linear
from torch_geometric.nn import GCNConv, SAGEConv, GATConv
from torch.nn.utils import spectral_norm
from scipy.sparse import coo_matrix
from torch_geometric.nn.conv.gcn_conv import gcn_norm
from utils import *
from .classifier import *
from torch_geometric.data import Data


class MLP(nn.Module):
    """Feature-only encoder built from two linear layers.

    The stack is ``Linear -> Dropout -> ReLU -> Linear``; both linear layers
    output ``hidden_dim`` features.

    Args:
        args: Configuration object; only ``args.dropout`` is read here (it is
            used as the dropout probability). The object is also kept on
            ``self.args``.
        input_dim: Size of each input feature vector.
        hidden_dim: Size of the hidden and output representations.
    """

    def __init__(self, args, input_dim, hidden_dim):
        super().__init__()
        self.args = args

        # Built in forward order so that submodule indices (and therefore
        # state-dict keys such as ``lin.0.weight`` / ``lin.3.weight``) are fixed.
        stages = [
            Linear(input_dim, hidden_dim),
            nn.Dropout(p=args.dropout),
            nn.ReLU(),
            Linear(hidden_dim, hidden_dim),
        ]
        self.lin = nn.Sequential(*stages)

    def forward(self, x, edge_index=None):
        """Encode ``x``; ``edge_index`` is accepted but never used."""
        return self.lin(x)


class GCN(nn.Module):
    """A ``GCN_Body`` feature extractor followed by a linear output layer.

    Args:
        nfeat: Number of input features per node.
        nhid: Width of the body's output / the linear layer's input.
        nclass: Number of outputs produced by the final linear layer.
        dropout: Forwarded unchanged to ``GCN_Body``.
    """

    def __init__(self, nfeat, nhid, nclass, dropout):
        super().__init__()
        self.body = GCN_Body(nfeat, nhid, dropout)
        self.fc = nn.Linear(nhid, nclass)

        # Re-initialise every nn.Linear found anywhere in the module tree.
        for module in self.modules():
            self.weights_init(module)

    def weights_init(self, m):
        """Xavier-uniform the weight and zero the bias of ``m`` if it is an
        ``nn.Linear``; any other module is left untouched."""
        if not isinstance(m, nn.Linear):
            return
        torch.nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:
            torch.nn.init.zeros_(m.bias)

    def forward(self, x, edge_index):
        """Run the body on ``(x, edge_index)`` and apply the linear layer."""
        hidden = self.body(x, edge_index)
        return self.fc(hidden)


class GCN_Body(nn.Module):
    """Single ``GCNConv`` layer mapping ``nfeat`` features to ``nhid``.

    Args:
        nfeat: Number of input features per node.
        nhid: Number of output features per node.
        dropout: Accepted as part of the constructor signature but not used
            by this class.
    """

    def __init__(self, nfeat, nhid, dropout):
        super().__init__()
        self.gc1 = GCNConv(nfeat, nhid)

    def forward(self, x, edge_index):
        """Return the output of the convolution; no activation is applied."""
        return self.gc1(x, edge_index)



class GNN(nn.Module):
    """Stack of graph convolutions whose per-layer outputs are averaged.

    The input is first projected to ``hidden_dim`` with a linear layer, then
    passed through ``args.n_layers`` convolutions of the selected type.

    Args:
        args: Configuration object; ``args.dropout`` and ``args.n_layers`` are
            read here, and the object is kept on ``self.args``.
        input_dim: Size of each input feature vector.
        hidden_dim: Width of the projection and of every convolution.
        encoder_type: ``"gcn"``, ``"sage"`` or ``"gat"`` (case-insensitive).

    Raises:
        NotImplementedError: If ``encoder_type`` names none of the above.
    """

    def __init__(self, args, input_dim, hidden_dim, encoder_type):
        super().__init__()
        self.args = args
        self.dropout = nn.Dropout(args.dropout)
        self.feat_proj = nn.Linear(input_dim, hidden_dim)

        # Factories are lazy so that only the selected convolution class is
        # ever instantiated.
        conv_factories = {
            "gcn": lambda: GCNConv(hidden_dim, hidden_dim),
            "sage": lambda: SAGEConv(hidden_dim, hidden_dim),
            "gat": lambda: GATConv(hidden_dim, hidden_dim, heads=1),
        }
        make_conv = conv_factories.get(encoder_type.lower())
        if make_conv is None:
            raise NotImplementedError("Invalid GNN type. Choose from 'GCN', 'SAGE', or 'GAT'.")
        self.gnn_layers = nn.ModuleList([make_conv() for _ in range(args.n_layers)])

    def forward(self, feat, edge_index):
        """Return the mean of the projected input and every layer's output."""
        hidden = self.feat_proj(feat)
        per_layer = [hidden]
        last_idx = len(self.gnn_layers) - 1
        for depth, conv in enumerate(self.gnn_layers):
            hidden = F.relu(conv(hidden, edge_index))
            # Dropout is applied after every layer except the final one.
            if depth < last_idx:
                hidden = self.dropout(hidden)
            per_layer.append(hidden)

        # Multi-layer fusion: stack along a new axis and average over it.
        return torch.stack(per_layer, dim=1).mean(dim=1)

class Encoder(nn.Module):
    """Backbone selected by name plus a ``Classifier`` head.

    Args:
        args: Configuration object. This class reads ``args.num_features``,
            ``args.hidden_dim``, ``args.dropout`` and ``args.device``; the
            object is also handed to the selected body (except ``GCN_Body``)
            and to ``Classifier``.
        encoder_type: Backbone name, matched case-insensitively. ``"mlp"``
            selects ``MLP``, ``"vanilla"`` selects ``GCN_Body``; any other
            value is passed to ``GNN``, which raises ``NotImplementedError``
            for names it does not recognise.
    """

    def __init__(self, args, encoder_type):
        super().__init__()
        # GNN lowercases its encoder_type, so do the same here; otherwise
        # "mlp" or "Vanilla" would fall through to GNN and fail with a
        # misleading "Invalid GNN type" error.
        kind = encoder_type.lower()
        if kind == "mlp":
            self.body = MLP(args, args.num_features, args.hidden_dim)
        elif kind == "vanilla":
            self.body = GCN_Body(args.num_features, args.hidden_dim, args.dropout)
        else:
            self.body = GNN(args, args.num_features, args.hidden_dim, encoder_type)
        self.fc = Classifier(args, input_dim=args.hidden_dim, num_cls=1)

        # Re-initialise every nn.Linear in the body and the head.
        for module in self.modules():
            self.weights_init(module)

        self.to(args.device)

    def forward(self, feat, edge_index):
        """Return ``(embeddings, cls)``: the body output and the head's
        output computed from it."""
        embeddings = self.body(feat, edge_index)
        cls = self.fc(embeddings)
        return embeddings, cls

    def weights_init(self, m):
        """Xavier-uniform the weight and zero the bias of ``m`` if it is an
        ``nn.Linear``; any other module is left untouched."""
        if not isinstance(m, nn.Linear):
            return
        torch.nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:
            torch.nn.init.zeros_(m.bias)


