import torch
import torch.nn as nn
import torch.nn.functional as F
import scipy.sparse as sp
from models.utils import adj_to_symmetric_norm, scipy_sparse_mat_to_torch_sparse_tensor, adj_to_symmetric_norm_tensor, \
                            normalize_tensor
    
class GraphConvolution(nn.Module):
    """One- or two-layer graph convolution: propagate with ``adj @ x``, then a linear map.

    ``task_level`` selects the architecture and output:

    * ``"one_layer"`` -- only ``fc1_node_edge`` is built; ``forward`` returns
      the first layer's ReLU + dropout activations.
    * ``"node"`` -- a second propagate/linear layer is applied and its output
      is passed through ``log_softmax`` over dim 1.
    * any other value -- the second layer's raw output is returned.

    When ``query_edges`` is set, the second propagation's per-node rows are
    gathered at ``query_edges[:, 0]`` and ``query_edges[:, 1]`` and
    concatenated before the second linear layer.
    """

    def __init__(self, feat_dim, hidden_dim, output_dim, dropout, task_level):
        """Build the layers.

        Args:
            feat_dim: input feature width (in_features of the first Linear).
            hidden_dim: width of the first layer's output.
            output_dim: out_features of the second Linear (unused for "one_layer").
            dropout: drop probability passed to ``nn.Dropout``.
            task_level: "one_layer", "node", or another label (raw output).
        """
        super(GraphConvolution, self).__init__()
        # Optional index tensor (columns 0/1 index rows of the node embeddings);
        # presumably (num_edges, 2) -- set externally, e.g. by GCN.forward.
        self.query_edges = None
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)
        self.fc1_node_edge = nn.Linear(feat_dim, hidden_dim)
        self.task_level = task_level
        if task_level != "one_layer":
            self.fc2_node_edge = nn.Linear(hidden_dim, output_dim)

    def reset_parameters(self):
        """Re-initialize the Linear layers that were actually built."""
        self.fc1_node_edge.reset_parameters()
        # fc2_node_edge only exists when task_level != "one_layer"; calling it
        # unconditionally raised AttributeError for one-layer models.
        fc2 = getattr(self, "fc2_node_edge", None)
        if fc2 is not None:
            fc2.reset_parameters()

    def forward(self, feature, adj):
        """Run the propagate-then-transform layers.

        Args:
            feature: node feature matrix, right operand of ``torch.mm(adj, ...)``.
            adj: adjacency matrix, left operand of ``torch.mm`` (dense or sparse COO).

        Returns:
            For "one_layer": first-layer activations.
            For "node": ``log_softmax`` (dim 1) of the second layer's output.
            Otherwise: the second layer's raw output.
        """
        x = torch.mm(adj, feature)
        x = self.relu(self.fc1_node_edge(x))
        # nn.Dropout is always truthy, so the former ``if self.dropout`` guard
        # never skipped this; apply it directly.
        x = self.dropout(x)
        if self.task_level == "one_layer":
            return x

        # Both node- and edge-level paths propagate once more before fc2.
        x = torch.mm(adj, x)
        if self.query_edges is not None:
            # NOTE(review): concatenating two endpoint rows yields 2 * hidden_dim
            # features, but fc2_node_edge is Linear(hidden_dim, output_dim), so
            # this path raises a shape mismatch as written. Resizing fc2 would
            # change checkpointed parameter shapes -- owner to confirm the fix.
            x = torch.cat((x[self.query_edges[:, 0]], x[self.query_edges[:, 1]]), dim=-1)
        output = self.fc2_node_edge(x)

        if self.task_level == "node":
            return F.log_softmax(output, dim=1)
        return output

class GCN(nn.Module):
    """Wrapper that caches a normalized adjacency and features for GraphConvolution.

    Call :meth:`preprocess` first to normalize the adjacency and store the node
    features; :meth:`forward` / :meth:`model_forward` then move the cached
    tensors to ``device`` and run the wrapped :class:`GraphConvolution`.
    """

    def __init__(self, feat_dim, hidden_dim, output_dim, dropout, task_level):
        """Build the wrapped GraphConvolution; arguments are forwarded unchanged."""
        super(GCN, self).__init__()

        self.base_model = GraphConvolution(
            feat_dim=feat_dim,
            hidden_dim=hidden_dim,
            output_dim=output_dim,
            dropout=dropout,
            task_level=task_level,
        )
        # Both are populated by preprocess(); forward() refuses to run before that.
        self.processed_feature = None
        self.processed_adj = None

    def preprocess(self, adj, feature, type="default", tensor_feat=True):
        """Normalize ``adj`` and cache it together with ``feature``.

        Args:
            adj: adjacency matrix. With ``type == "tensor"`` a torch tensor
                (densified if sparse) normalized by
                ``adj_to_symmetric_norm_tensor``; otherwise presumably a scipy
                sparse matrix, normalized by ``adj_to_symmetric_norm(adj, 0.5)``
                and converted to a torch sparse tensor.
            feature: node features, or None to keep the previously cached ones.
            type: "tensor" or anything else (see ``adj``). Name kept for
                backward compatibility even though it shadows the builtin.
            tensor_feat: when True, convert ``feature`` to a float32 tensor;
                when False, store it as-is.
        """
        if type == "tensor":
            if adj.is_sparse:
                adj = adj.to_dense()
            adj_norm = adj_to_symmetric_norm_tensor(adj)
        else:
            adj_norm = adj_to_symmetric_norm(adj, 0.5)
            adj_norm = scipy_sparse_mat_to_torch_sparse_tensor(adj_norm)
        self.processed_adj = adj_norm

        if feature is None:
            return
        if not tensor_feat:
            self.processed_feature = feature
        elif isinstance(feature, torch.Tensor):
            # The legacy torch.FloatTensor(...) constructor rejects tensors that
            # are not already CPU float32 (e.g. float64 or CUDA); .float()
            # converts them and is a no-op for float32.
            self.processed_feature = feature.float()
        else:
            self.processed_feature = torch.FloatTensor(feature)

    def model_forward(self, idx, device, ori=None):
        """Alias for :meth:`forward` with ``ori`` defaulting to None."""
        return self.forward(idx, device, ori)

    def forward(self, idx, device, ori=None):
        """Run the wrapped model on the cached adjacency and features.

        Args:
            idx: index applied to the model output, or None for the full output.
            device: device the cached tensors are moved to for this call.
            ori: optional edge-index tensor stored as ``base_model.query_edges``.

        Raises:
            RuntimeError: if :meth:`preprocess` has not cached both the
                adjacency and the features.
        """
        if self.processed_feature is None or self.processed_adj is None:
            raise RuntimeError(
                "GCN.preprocess() must cache an adjacency and features before forward()."
            )
        processed_feature = self.processed_feature.to(device)
        processed_adj = self.processed_adj.to(device)
        if ori is not None:
            # NOTE(review): query_edges is sticky -- a later call with ori=None
            # reuses the previous edges. Kept as-is because callers may set
            # base_model.query_edges directly; confirm whether that is intended.
            self.base_model.query_edges = ori

        output = self.base_model(processed_feature, processed_adj)
        return output if idx is None else output[idx]