import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import scipy.sparse as sp
from torch import Tensor
import platform
from models.utils import adj_to_symmetric_norm, scipy_sparse_mat_to_torch_sparse_tensor, adj_to_symmetric_norm_tensor, \
                        csr_sparse_dense_matmul
    
class LogisticRegression(nn.Module):
    """Linear output head for node-level or edge-level prediction.

    Node mode (``task_level != "edge"``): a single linear map from
    ``feat_dim`` to ``output_dim``.

    Edge mode (``task_level == "edge"``): features are first projected to
    ``edge_dim``. When ``query_edges`` is set, the projected rows of both
    endpoints are concatenated, passed through dropout, and mapped to
    ``output_dim``.
    """

    def __init__(self, feat_dim, edge_dim, output_dim, dropout, task_level):
        super().__init__()
        # Not read anywhere in this class; kept as a public attribute.
        self.adj = None
        # Set externally to switch forward() into pairwise (edge) scoring.
        self.query_edges = None
        self.dropout = nn.Dropout(dropout)

        edge_mode = task_level == "edge"
        projected_dim = edge_dim if edge_mode else output_dim
        # fc_node_edge is created before linear so that parameter init order
        # and state_dict key order stay fixed.
        self.fc_node_edge = nn.Linear(feat_dim, projected_dim)
        if edge_mode:
            self.linear = nn.Linear(2 * edge_dim, output_dim)

    def forward(self, feature):
        """Project ``feature``; when ``query_edges`` is set, score node pairs.

        ``query_edges`` is indexed as ``[:, 0]`` / ``[:, 1]``, so it is
        presumably an integer tensor of shape (num_edges, 2) — confirm with
        callers. Edge scoring uses ``self.linear``, which only exists in
        edge mode.
        """
        node_repr = self.fc_node_edge(feature)
        if self.query_edges is None:
            return node_repr

        src_idx = self.query_edges[:, 0]
        dst_idx = self.query_edges[:, 1]
        pair_repr = torch.cat((node_repr[src_idx], node_repr[dst_idx]), dim=-1)
        return self.linear(self.dropout(pair_repr))

class SGC(nn.Module):
    """Simplified Graph Convolution: precomputed propagation + linear head.

    ``preprocess`` propagates the node features ``prop_steps`` times over the
    symmetrically normalized adjacency and caches the result. ``forward``
    then only applies the ``LogisticRegression`` head to (a subset of) the
    cached features.
    """

    def __init__(self, prop_steps, feat_dim, output_dim, dropout, task_level):
        super(SGC, self).__init__()

        self.prop_steps = prop_steps
        # NOTE(review): edge_dim is always None here, so task_level="edge"
        # presumably fails inside LogisticRegression (nn.Linear(feat_dim, None)).
        self.base_model = LogisticRegression(feat_dim=feat_dim, edge_dim=None, output_dim=output_dim, dropout=dropout, task_level=task_level)
        # Filled by preprocess(): all propagation steps and the final one.
        self.processed_feat_list = None
        self.processed_feature = None

    def propagate(self, adj, feature):
        """Return ``[X, A X, A^2 X, ..., A^k X]`` as float32 tensors.

        ``A`` is ``adj_to_symmetric_norm(adj, 0.5)`` and ``k`` is
        ``self.prop_steps``. The normalized matrix is stored on ``self._adj``.

        Args:
            adj: scipy CSR adjacency matrix.
            feature: numpy feature matrix with ``adj.shape[1]`` rows.

        Raises:
            TypeError: if ``adj`` is not ``sp.csr_matrix`` or ``feature`` is
                not ``np.ndarray``.
            ValueError: if ``adj`` and ``feature`` dimensions do not match.
        """
        # Validate before normalizing: bad inputs should get the clear error
        # below rather than fail inside the normalization helper, and
        # self._adj must not be overwritten by a call that then raises.
        if not isinstance(adj, sp.csr_matrix):
            raise TypeError("The adjacency matrix must be a scipy csr sparse matrix!")
        if not isinstance(feature, np.ndarray):
            raise TypeError("The feature matrix must be a numpy.ndarray!")
        if adj.shape[1] != feature.shape[0]:
            raise ValueError("Dimension mismatch detected for the adjacency and the feature matrix!")

        self._adj = adj_to_symmetric_norm(adj, 0.5).tocsr()

        prop_feat_list = [feature]
        for _ in range(self.prop_steps):
            prop_feat_list.append(self._adj.dot(prop_feat_list[-1]))
        # torch.tensor copies, so the cached features never alias the input.
        return [torch.tensor(feat, dtype=torch.float32) for feat in prop_feat_list]

    def preprocess(self, adj, feature):
        """Run ``propagate`` once and cache the results on the module.

        ``feature`` may be a torch Tensor (any device, with or without grad)
        or a numpy array.
        """
        if isinstance(feature, Tensor):
            # .numpy() alone fails on CUDA tensors and on tensors requiring grad.
            feature = feature.detach().cpu().numpy()
        self.processed_feat_list = self.propagate(adj, feature)
        self.processed_feature = self.processed_feat_list[-1]

    def model_forward(self, idx, device):
        """Alias for ``forward``."""
        return self.forward(idx, device)

    def forward(self, idx, device):
        """Apply the head to cached features at rows ``idx``.

        If ``idx`` is None, the head is applied to all rows. The selected
        features are moved to ``device`` first.

        Raises:
            RuntimeError: if ``preprocess`` has not been called yet.
        """
        if self.processed_feature is None:
            raise RuntimeError("SGC.preprocess() must be called before forward().")
        if idx is not None:
            processed_feature = self.processed_feature[idx].to(device)
        else:
            processed_feature = self.processed_feature.to(device)

        return self.base_model(processed_feature)