from typing import Optional
import numpy as np
import scipy.sparse as sp
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, GATConv, SAGEConv,  GINConv, ChebConv
from .HoloConvLayer import HoloConv
import os.path as osp
from torch.nn import Linear
# import scipy.sparse as sp
from .BernConv import Bern_prop




from torch_geometric.typing import OptTensor


def adj_matrix_to_edge_idx(adj_matrix: sp.spmatrix, device) -> torch.LongTensor:
    """Convert a SciPy sparse adjacency matrix into a ``(2, nnz)`` edge-index tensor.

    Row 0 holds the source (row) indices and row 1 the target (column) indices
    of the stored entries, in the order produced by ``adj_matrix.tocoo()``.

    Args:
        adj_matrix: any SciPy sparse matrix.
        device: torch device the result is placed on.

    Returns:
        A ``torch.long`` tensor of shape ``(2, nnz)`` on ``device``.
    """
    coo = adj_matrix.tocoo()
    index_pairs = np.vstack((coo.row, coo.col))
    edge_index = torch.from_numpy(index_pairs)
    return edge_index.to(dtype=torch.long, device=device)


from .res_layer_via_matmul import ResolventConvLayerViaMatMul, ReLUWithIdx





class ResConvModel_via_matmul(torch.nn.Module):
    """Stack of resolvent convolution layers followed by a per-node linear head.

    Each entry of ``hidden_channel_list`` adds one ``ResolventConvLayerViaMatMul``
    (``c_in -> c_out``) followed by ReLU. Optional dropout is applied to the
    last hidden features. A kernel-size-1 ``Conv1d`` then maps them to
    ``out_dim`` scores, and ``log_softmax`` is returned for the rows in ``idx``.

    Args:
        input_dim: number of input features per node.
        out_dim: number of output classes.
        adj_matrix, K_minus, zero_order, singular_value_normalization, omega,
        normalizing_factor_minus, bias, resolvent_path: forwarded unchanged to
            every ``ResolventConvLayerViaMatMul``.
        hidden_channel_list: output widths of the stacked conv layers. Any
            iterable of ints is accepted; the default is ``(128,)``.
        dropout: dropout probability on the last hidden features. A falsy
            value (``False`` / ``0``) disables it.
    """

    def __init__(self, input_dim, out_dim, adj_matrix, K_minus, zero_order, singular_value_normalization,
                 hidden_channel_list=(128,), omega=-1, dropout=False,
                 normalizing_factor_minus=None, bias=True, resolvent_path: Optional[str] = None):
        super(ResConvModel_via_matmul, self).__init__()
        self.dropout = dropout
        self.graph_conv_layers = nn.ModuleList()

        # Immutable default (previously a shared mutable list); list() copies so
        # any iterable of widths works.
        dimensions = [input_dim] + list(hidden_channel_list)
        for c_in, c_out in zip(dimensions[:-1], dimensions[1:]):
            self.graph_conv_layers.append(ResolventConvLayerViaMatMul(
                c_in, c_out, adj_matrix=adj_matrix, K_minus=K_minus, zero_order=zero_order,
                omega=omega, singular_value_normalization=singular_value_normalization,
                normalizing_factor_minus=normalizing_factor_minus, bias=bias,
                resolvent_path=resolvent_path))

        # A kernel-size-1 Conv1d over the node axis acts as one Linear layer
        # shared by every node.
        self.fc_layer = nn.Conv1d(dimensions[-1], out_dim, kernel_size=1)

        self.reg_params = list(self.parameters())

    def forward(self, x, idx):
        """Return log-probabilities for the rows selected by ``idx``.

        NOTE(review): ``x`` is presumably (num_nodes, input_dim). The head below
        requires the conv-stack output to be 2-D (num_nodes, channels).
        """
        for convlayer in self.graph_conv_layers:
            x = F.relu(convlayer(x))

        if self.dropout > 0:
            x = F.dropout(x, self.dropout, training=self.training)

        # (N, C) -> (1, C, N) for Conv1d -> (1, out_dim, N) -> (N, out_dim).
        x = x.unsqueeze(0).permute((0, 2, 1))
        x = self.fc_layer(x)
        x = x.permute((0, 2, 1)).squeeze(0)

        return F.log_softmax(x[idx], dim=1)



########################################################################################################################################################################################################################### 



class ResConv_predict_first_via_matmul(torch.nn.Module):
    """Predict-then-propagate model using a single resolvent convolution.

    A two-layer MLP (``lin1`` -> ReLU -> ``lin2``) maps node features to
    ``out_dim`` logits. Optional dropout with rate ``dprate`` is applied, and the
    logits are smoothed by ``prop1`` (an ``out_dim -> out_dim``
    ``ResolventConvLayerViaMatMul``). ``log_softmax`` is returned for the rows in
    ``idx``.

    Args:
        input_dim: number of input features per node.
        out_dim: number of output classes.
        hidden: width of the MLP hidden layer.
        dropout: dropout probability applied before ``lin1`` and before ``lin2``.
        dprate: dropout probability applied to the logits before propagation.
        adj_matrix, K_minus, zero_order, singular_value_normalization, omega,
        normalizing_factor_minus, bias, resolvent_path: forwarded unchanged to
            ``ResolventConvLayerViaMatMul``.
    """

    def __init__(self, input_dim, out_dim, adj_matrix, K_minus, zero_order, hidden, singular_value_normalization,
                 omega=-1, dropout=0.0, dprate=0.0,
                 normalizing_factor_minus=None, bias=True, resolvent_path: Optional[str] = None):
        super().__init__()

        # Registration order (lin1, lin2, prop1) is kept: it fixes parameter
        # order and the sequence of random initialisations.
        self.lin1 = Linear(input_dim, hidden)
        self.lin2 = Linear(hidden, out_dim)
        self.prop1 = ResolventConvLayerViaMatMul(
            out_dim, out_dim,
            adj_matrix=adj_matrix,
            K_minus=K_minus,
            zero_order=zero_order,
            omega=omega,
            singular_value_normalization=singular_value_normalization,
            normalizing_factor_minus=normalizing_factor_minus,
            bias=bias,
            resolvent_path=resolvent_path,
        )

        self.dprate = dprate
        self.dropout = dropout

        self.reg_params = list(self.parameters())

    def reset_parameters(self):
        """Delegate to ``prop1.reset_parameters()``; ``lin1``/``lin2`` are untouched."""
        self.prop1.reset_parameters()

    def forward(self, x, idx):
        """Return log-probabilities for the rows selected by ``idx``."""
        hidden = F.relu(self.lin1(F.dropout(x, p=self.dropout, training=self.training)))
        logits = self.lin2(F.dropout(hidden, p=self.dropout, training=self.training))

        if self.dprate != 0.0:
            logits = F.dropout(logits, p=self.dprate, training=self.training)

        smoothed = self.prop1(logits)
        return F.log_softmax(smoothed[idx], dim=1)


####################################################################################################################################################################################################################

class HoloConv_predict_first(torch.nn.Module):
    """Predict-then-propagate model using a single ``HoloConv`` layer.

    A two-layer MLP (``lin1`` -> ReLU -> ``lin2``) maps node features to
    ``out_dim`` logits. Optional dropout with rate ``dprate`` is applied, and the
    logits are passed through ``prop1`` (an ``out_dim -> out_dim`` ``HoloConv``
    built with ``edge_index``). ``log_softmax`` is returned for the rows in
    ``idx``.

    Args:
        input_dim: number of input features per node.
        out_dim: number of output classes.
        edge_index: graph passed to ``HoloConv`` and stored on ``self.edge_index``.
        hidden: width of the MLP hidden layer.
        dropout: dropout probability applied before ``lin1`` and before ``lin2``.
        dprate: dropout probability applied to the logits before propagation.
        edge_weights: accepted for interface compatibility; not used here.
        K_plus, K_minus, omega, L_singular_value_normalization,
        R_singular_value_normalization, normalizing_factor_plus,
        normalizing_factor_minus, bias: forwarded unchanged to ``HoloConv``.
    """

    def __init__(self, input_dim, out_dim, edge_index, K_plus, K_minus, hidden, omega=-1, dropout=0.0, dprate=0.0,
                 edge_weights=None, L_singular_value_normalization=True, R_singular_value_normalization=False,
                 normalizing_factor_plus=None, normalizing_factor_minus=None, bias=True):
        super().__init__()

        # Registration order (lin1, lin2, prop1) is kept: it fixes parameter
        # order and the sequence of random initialisations.
        self.lin1 = Linear(input_dim, hidden)
        self.lin2 = Linear(hidden, out_dim)
        self.prop1 = HoloConv(
            out_dim, out_dim, K_plus, K_minus, edge_index, omega,
            L_singular_value_normalization=L_singular_value_normalization,
            R_singular_value_normalization=R_singular_value_normalization,
            normalizing_factor_plus=normalizing_factor_plus,
            normalizing_factor_minus=normalizing_factor_minus,
            bias=bias,
        )

        self.edge_index = edge_index
        self.dprate = dprate
        self.dropout = dropout
        self.reg_params = list(self.parameters())

    def reset_parameters(self):
        """Delegate to ``prop1.reset_parameters()``; ``lin1``/``lin2`` are untouched."""
        self.prop1.reset_parameters()

    def forward(self, data, idx):
        """Return log-probabilities for the rows of ``data`` selected by ``idx``."""
        hidden = F.relu(self.lin1(F.dropout(data, p=self.dropout, training=self.training)))
        logits = self.lin2(F.dropout(hidden, p=self.dropout, training=self.training))

        if self.dprate != 0.0:
            logits = F.dropout(logits, p=self.dprate, training=self.training)

        # NOTE(review): prop1 is called without edge_index here (it received the
        # graph at construction), while HoloConvModel calls HoloConv layers as
        # layer(x, edge_index); confirm HoloConv.forward supports both forms.
        smoothed = self.prop1(logits)
        return F.log_softmax(smoothed[idx], dim=1)



class HoloConvModel(torch.nn.Module):
    """Stack of ``HoloConv`` layers followed by a per-node linear head.

    Each entry of ``hidden_channel_list`` adds one ``HoloConv`` (``c_in -> c_out``)
    followed by ReLU. Optional dropout is applied to the last hidden features.
    A kernel-size-1 ``Conv1d`` then maps them to ``out_dim`` scores, and
    ``log_softmax`` is returned for the rows in ``idx``.

    Args:
        input_dim: number of input features per node.
        out_dim: number of output classes.
        edge_index: graph passed to every ``HoloConv`` (at construction and in
            ``forward``).
        hidden_channel_list: output widths of the stacked conv layers. Any
            iterable of ints is accepted; the default is ``(128,)``.
        dropout: dropout probability on the last hidden features. A falsy
            value (``False`` / ``0``) disables it.
        edge_weights: stored on ``self.edge_weights``; not used in ``forward``.
        K_plus, K_minus, omega, L_singular_value_normalization,
        R_singular_value_normalization, normalizing_factor_plus,
        normalizing_factor_minus, bias: forwarded unchanged to every ``HoloConv``.
    """

    def __init__(self, input_dim, out_dim, edge_index, K_plus, K_minus, hidden_channel_list=(128,), omega=-1,
                 dropout=False, edge_weights=None, L_singular_value_normalization=True,
                 R_singular_value_normalization=False, normalizing_factor_plus=None, normalizing_factor_minus=None,
                 bias=True):
        super(HoloConvModel, self).__init__()
        self.dropout = dropout
        self.edge_index = edge_index
        self.edge_weights = edge_weights
        self.graph_conv_layers = nn.ModuleList()

        # Immutable default (previously a shared mutable list); list() copies so
        # any iterable of widths works.
        dimensions = [input_dim] + list(hidden_channel_list)
        for c_in, c_out in zip(dimensions[:-1], dimensions[1:]):
            self.graph_conv_layers.append(HoloConv(
                c_in, c_out, K_plus, K_minus, edge_index, omega,
                L_singular_value_normalization=L_singular_value_normalization,
                R_singular_value_normalization=R_singular_value_normalization,
                normalizing_factor_plus=normalizing_factor_plus,
                normalizing_factor_minus=normalizing_factor_minus,
                bias=bias))

        # A kernel-size-1 Conv1d over the node axis acts as one Linear layer
        # shared by every node.
        self.fc_layer = nn.Conv1d(dimensions[-1], out_dim, kernel_size=1)

        self.reg_params = list(self.parameters())

    def forward(self, data, idx):
        """Return log-probabilities for the rows of ``data`` selected by ``idx``."""
        x, edge_index = data, self.edge_index
        for convlayer in self.graph_conv_layers:
            x = F.relu(convlayer(x, edge_index))

        if self.dropout > 0:
            x = F.dropout(x, self.dropout, training=self.training)

        # (N, C) -> (1, C, N) for Conv1d -> (1, out_dim, N) -> (N, out_dim).
        x = x.unsqueeze(0).permute((0, 2, 1))
        x = self.fc_layer(x)
        x = x.permute((0, 2, 1)).squeeze(0)

        return F.log_softmax(x[idx], dim=1)



########################################################################################################################################################################################################################### 



###########################################################################################################################################################################################################################





        

class HoloConv_laplacian_ppnp_propagate(torch.nn.Module):
    """Predict-then-propagate model with a caller-supplied propagation module.

    A two-layer MLP (``lin1`` -> ReLU -> ``lin2``) maps node features to
    ``out_dim`` logits. Optional dropout with rate ``dprate`` is applied, and
    ``propagation(logits, idx)`` is called. Its result goes straight into
    ``log_softmax``, so the propagation module is responsible for selecting the
    rows in ``idx``.

    Args:
        input_dim: number of input features per node.
        out_dim: number of output classes.
        edge_index: stored on ``self.edge_index``; not used in ``forward``.
        hidden: width of the MLP hidden layer.
        propagation: module called as ``propagation(x, idx)``.
        dropout: dropout probability applied before ``lin1`` and before ``lin2``.
        dprate: dropout probability applied to the logits before propagation.
        bias: accepted for interface compatibility; unused.
    """

    def __init__(self, input_dim, out_dim, edge_index, hidden, propagation: nn.Module, dropout=0.0, dprate=0.0,
                 bias: bool = False):
        super().__init__()

        # Registration order (lin1, lin2, m, prop1) is kept: it fixes parameter
        # order and the sequence of random initialisations.
        self.lin1 = Linear(input_dim, hidden)
        self.lin2 = Linear(hidden, out_dim)
        # Not used in forward; kept so parameters/state_dict keys are unchanged.
        self.m = torch.nn.BatchNorm1d(out_dim)
        self.prop1 = propagation

        self.edge_index = edge_index
        self.dprate = dprate
        self.dropout = dropout

        self.reg_params = list(self.parameters())

    def reset_parameters(self):
        """Delegate to ``prop1.reset_parameters()``; the MLP is untouched."""
        self.prop1.reset_parameters()

    def forward(self, data, idx):
        """Return ``log_softmax`` of ``self.prop1(logits, idx)`` along dim 1."""
        hidden = F.relu(self.lin1(F.dropout(data, p=self.dropout, training=self.training)))
        logits = self.lin2(F.dropout(hidden, p=self.dropout, training=self.training))

        if self.dprate != 0.0:
            logits = F.dropout(logits, p=self.dprate, training=self.training)

        return F.log_softmax(self.prop1(logits, idx), dim=1)




########################### Baselines ################################
from torch_geometric.nn import ARMAConv





class ARMAModel(torch.nn.Module):
    """Baseline: stacked ``ARMAConv`` layers followed by a per-node linear head.

    Two (or three when ``layer == 3``) ``ARMAConv`` layers with ReLU are applied,
    then optional dropout. A kernel-size-1 ``Conv1d`` maps the features to
    ``out_dim`` scores, and ``log_softmax`` is returned for the rows in ``idx``.

    Args:
        input_dim: number of input features per node.
        out_dim: number of output classes.
        edge_index: graph used by every conv layer.
        filter_num: hidden width of every conv layer.
        num_stacks, num_layers: forwarded to ``conv1`` and ``conv2``.
        dropout: passed to ``conv1``/``conv2``. Also used as the dropout
            probability before the head; a falsy value disables that dropout.
        layer: 3 adds a third conv layer; any other value gives two.
    """

    def __init__(self, input_dim, out_dim, edge_index, filter_num, num_stacks=2, num_layers=1, dropout=False,
                 layer=2):
        super(ARMAModel, self).__init__()
        self.dropout = dropout
        self.edge_index = edge_index
        self.conv1 = ARMAConv(input_dim, filter_num, num_stacks=num_stacks, num_layers=num_layers, dropout=dropout)
        self.conv2 = ARMAConv(filter_num, filter_num, num_stacks=num_stacks, num_layers=num_layers, dropout=dropout)
        self.Conv = nn.Conv1d(filter_num, out_dim, kernel_size=1)

        self.layer = layer
        if layer == 3:
            # NOTE(review): unlike conv1/conv2, this layer uses ARMAConv's default
            # num_stacks/num_layers/dropout; confirm that is intended.
            self.conv3 = ARMAConv(filter_num, filter_num)

        self.reg_params = list(self.parameters())

    def forward(self, data, idx):
        """Return log-probabilities for the rows of ``data`` selected by ``idx``."""
        x, edge_index = data, self.edge_index

        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        if self.layer == 3:
            x = F.relu(self.conv3(x, edge_index))

        if self.dropout > 0:
            x = F.dropout(x, self.dropout, training=self.training)

        # (N, C) -> (1, C, N) for Conv1d -> (1, out_dim, N) -> (N, out_dim).
        # squeeze(0) removes only the batch axis added above. A bare squeeze()
        # would also drop N or out_dim when either equals 1, breaking the
        # dim=1 log_softmax / row indexing below.
        x = x.unsqueeze(0).permute((0, 2, 1))
        x = self.Conv(x)
        x = x.permute((0, 2, 1)).squeeze(0)

        return F.log_softmax(x[idx], dim=1)



################################################################################################





#  input_dim, out_dim,

class BernNet(torch.nn.Module):
    """BernNet-style baseline: an MLP on node features, then ``Bern_prop``.

    A two-layer MLP (``lin1`` -> ReLU -> ``lin2``) maps node features to
    ``out_dim`` logits. Optional dropout with rate ``dprate`` is applied, the
    logits are propagated with ``prop1(x, edge_index)``, and ``log_softmax`` is
    returned for the rows in ``idx``.

    Args:
        input_dim: number of input features per node.
        out_dim: number of output classes.
        edge_index: graph passed to ``prop1`` on every forward call.
        hidden: width of the MLP hidden layer.
        K: forwarded to ``Bern_prop``.
        dprate: dropout probability applied to the logits before propagation.
        dropout: dropout probability applied before ``lin1`` and before ``lin2``.
    """

    def __init__(self, input_dim, out_dim, edge_index, hidden, K, dprate, dropout):
        super(BernNet, self).__init__()
        self.lin1 = Linear(input_dim, hidden)
        self.lin2 = Linear(hidden, out_dim)
        # Not used in forward; kept so parameters/state_dict keys are unchanged.
        self.m = torch.nn.BatchNorm1d(out_dim)
        self.prop1 = Bern_prop(K)

        self.edge_index = edge_index

        self.dprate = dprate
        self.dropout = dropout

        self.reg_params = list(self.parameters())

    def reset_parameters(self):
        """Delegate to ``prop1.reset_parameters()``; ``lin1``/``lin2`` are untouched."""
        self.prop1.reset_parameters()

    def forward(self, data, idx):
        """Return log-probabilities for the rows of ``data`` selected by ``idx``."""
        x, edge_index = data, self.edge_index

        x = F.dropout(x, p=self.dropout, training=self.training)
        x = F.relu(self.lin1(x))
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.lin2(x)

        # The previous if/else duplicated the propagate-and-return code; the
        # only difference between the branches was this extra dropout.
        if self.dprate != 0.0:
            x = F.dropout(x, p=self.dprate, training=self.training)
        x = self.prop1(x, edge_index)
        return F.log_softmax(x[idx], dim=1)








################################################################################################




class GATModel(torch.nn.Module):
    """Baseline: stacked ``GATConv`` layers followed by a per-node linear head.

    Two (or three when ``layer == 3``) ``GATConv`` layers with ReLU are applied,
    then optional dropout. A kernel-size-1 ``Conv1d`` maps the
    ``filter_num * heads`` features to ``out_dim`` scores, and ``log_softmax``
    is returned for the rows in ``idx``.

    Args:
        input_dim: number of input features per node.
        out_dim: number of output classes.
        edge_index: graph used by every conv layer.
        heads: attention heads per layer.
        filter_num: per-head output width of every conv layer.
        dropout: dropout probability before the head; a falsy value disables it.
        layer: 3 adds a third conv layer; any other value gives two.
    """

    def __init__(self, input_dim, out_dim, edge_index, heads, filter_num, dropout=False,
                 layer=2):
        super(GATModel, self).__init__()
        self.dropout = dropout
        self.edge_index = edge_index
        self.conv1 = GATConv(input_dim, filter_num, heads=heads)
        self.conv2 = GATConv(filter_num * heads, filter_num, heads=heads)
        self.Conv = nn.Conv1d(filter_num * heads, out_dim, kernel_size=1)
        self.layer = layer
        if layer == 3:
            self.conv3 = GATConv(filter_num * heads, filter_num, heads=heads)
        self.reg_params = list(self.parameters())

    def forward(self, data, idx):
        """Return log-probabilities for the rows of ``data`` selected by ``idx``."""
        x, edge_index = data, self.edge_index

        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        if self.layer == 3:
            x = F.relu(self.conv3(x, edge_index))

        if self.dropout > 0:
            x = F.dropout(x, self.dropout, training=self.training)

        # (N, C) -> (1, C, N) for Conv1d -> (1, out_dim, N) -> (N, out_dim).
        # squeeze(0) removes only the batch axis added above. A bare squeeze()
        # would also drop N or out_dim when either equals 1.
        x = x.unsqueeze(0).permute((0, 2, 1))
        x = self.Conv(x)
        x = x.permute((0, 2, 1)).squeeze(0)

        return F.log_softmax(x[idx], dim=1)


class SAGEModel(torch.nn.Module):
    """Baseline: stacked ``SAGEConv`` layers followed by a per-node linear head.

    Two (or three when ``layer == 3``) ``SAGEConv`` layers with ReLU are applied,
    then optional dropout. A kernel-size-1 ``Conv1d`` maps the features to
    ``out_dim`` scores, and ``log_softmax`` is returned for the rows in ``idx``.

    Args:
        input_dim: number of input features per node.
        out_dim: number of output classes.
        edge_index: graph used by every conv layer.
        filter_num: hidden width of every conv layer.
        dropout: dropout probability before the head; a falsy value disables it.
        layer: 3 adds a third conv layer; any other value gives two.
    """

    def __init__(self, input_dim, out_dim, edge_index, filter_num, dropout=False, layer=2):
        super(SAGEModel, self).__init__()
        self.dropout = dropout
        self.edge_index = edge_index
        self.conv1 = SAGEConv(input_dim, filter_num)
        self.conv2 = SAGEConv(filter_num, filter_num)
        self.Conv = nn.Conv1d(filter_num, out_dim, kernel_size=1)

        self.layer = layer
        if layer == 3:
            self.conv3 = SAGEConv(filter_num, filter_num)
        self.reg_params = list(self.parameters())

    def forward(self, data, idx):
        """Return log-probabilities for the rows of ``data`` selected by ``idx``."""
        x, edge_index = data, self.edge_index

        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        if self.layer == 3:
            x = F.relu(self.conv3(x, edge_index))

        if self.dropout > 0:
            x = F.dropout(x, self.dropout, training=self.training)

        # (N, C) -> (1, C, N) for Conv1d -> (1, out_dim, N) -> (N, out_dim).
        # squeeze(0) removes only the batch axis added above. A bare squeeze()
        # would also drop N or out_dim when either equals 1.
        x = x.unsqueeze(0).permute((0, 2, 1))
        x = self.Conv(x)
        x = x.permute((0, 2, 1)).squeeze(0)

        return F.log_softmax(x[idx], dim=1)


class GCNModel(torch.nn.Module):
    """Baseline: stacked ``GCNConv`` layers followed by a per-node linear head.

    Two (or three when ``layer == 3``) ``GCNConv`` layers with ReLU are applied,
    then optional dropout. A kernel-size-1 ``Conv1d`` maps the features to
    ``out_dim`` scores, and ``log_softmax`` is returned for the rows in ``idx``.

    Args:
        input_dim: number of input features per node.
        out_dim: number of output classes.
        edge_index: graph used by every conv layer.
        filter_num: hidden width of every conv layer.
        dropout: dropout probability before the head; a falsy value disables it.
        layer: 3 adds a third conv layer; any other value gives two.
    """

    def __init__(self, input_dim, out_dim, edge_index, filter_num, dropout=False, layer=2):
        super(GCNModel, self).__init__()
        self.dropout = dropout
        self.edge_index = edge_index
        self.conv1 = GCNConv(input_dim, filter_num)
        self.conv2 = GCNConv(filter_num, filter_num)
        self.Conv = nn.Conv1d(filter_num, out_dim, kernel_size=1)

        self.layer = layer
        if layer == 3:
            self.conv3 = GCNConv(filter_num, filter_num)
        self.reg_params = list(self.parameters())

    def forward(self, data, idx):
        """Return log-probabilities for the rows of ``data`` selected by ``idx``."""
        x, edge_index = data, self.edge_index

        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        if self.layer == 3:
            x = F.relu(self.conv3(x, edge_index))

        if self.dropout > 0:
            x = F.dropout(x, self.dropout, training=self.training)

        # (N, C) -> (1, C, N) for Conv1d -> (1, out_dim, N) -> (N, out_dim).
        # squeeze(0) removes only the batch axis added above. A bare squeeze()
        # would also drop N or out_dim when either equals 1.
        x = x.unsqueeze(0).permute((0, 2, 1))
        x = self.Conv(x)
        x = x.permute((0, 2, 1)).squeeze(0)

        # log_softmax is row-wise, so selecting rows first gives the same values
        # without normalising every node.
        return F.log_softmax(x[idx], dim=1)


class ChebModel(torch.nn.Module):
    """Baseline: stacked ``ChebConv`` layers followed by a per-node linear head.

    Two (or three when ``layer == 3``) ``ChebConv`` layers with ReLU are applied,
    then optional dropout. A kernel-size-1 ``Conv1d`` maps the features to
    ``out_dim`` scores, and ``log_softmax`` is returned for the rows in ``idx``.

    Args:
        input_dim: number of input features per node.
        out_dim: number of output classes.
        edge_index: graph used by every conv layer.
        filter_num: hidden width of every conv layer.
        K: forwarded to every ``ChebConv``.
        dropout: dropout probability before the head; a falsy value disables it.
        layer: 3 adds a third conv layer; any other value gives two.
    """

    def __init__(self, input_dim, out_dim, edge_index, filter_num, K, dropout=False,
                 layer=2):
        super(ChebModel, self).__init__()
        self.dropout = dropout
        self.edge_index = edge_index
        self.conv1 = ChebConv(input_dim, filter_num, K)
        self.conv2 = ChebConv(filter_num, filter_num, K)
        self.Conv = nn.Conv1d(filter_num, out_dim, kernel_size=1)

        self.layer = layer
        if layer == 3:
            self.conv3 = ChebConv(filter_num, filter_num, K)
        self.reg_params = list(self.parameters())

    def forward(self, data, idx):
        """Return log-probabilities for the rows of ``data`` selected by ``idx``."""
        x, edge_index = data, self.edge_index

        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        if self.layer == 3:
            x = F.relu(self.conv3(x, edge_index))

        if self.dropout > 0:
            x = F.dropout(x, self.dropout, training=self.training)

        # (N, C) -> (1, C, N) for Conv1d -> (1, out_dim, N) -> (N, out_dim).
        # squeeze(0) removes only the batch axis added above. A bare squeeze()
        # would also drop N or out_dim when either equals 1.
        x = x.unsqueeze(0).permute((0, 2, 1))
        x = self.Conv(x)
        x = x.permute((0, 2, 1)).squeeze(0)

        return F.log_softmax(x[idx], dim=1)


class GINModel(torch.nn.Module):
    """Baseline: stacked ``GINConv`` layers followed by a per-node linear head.

    Each ``GINConv`` wraps a single ``nn.Linear`` (``line1``, ``line2`` and, when
    ``layer == 3``, ``line3``). Conv layers are followed by ReLU and then optional
    dropout. A kernel-size-1 ``Conv1d`` maps the features to ``out_dim`` scores,
    and ``log_softmax`` is returned for the rows in ``idx``.

    Args:
        input_dim: number of input features per node.
        out_dim: number of output classes.
        edge_index: graph used by every conv layer.
        filter_num: hidden width of every conv layer.
        dropout: dropout probability before the head; a falsy value disables it.
        layer: 3 adds a third conv layer; any other value gives two.
    """

    def __init__(self, input_dim, out_dim, edge_index, filter_num, dropout=False, layer=2):
        super().__init__()
        self.dropout = dropout
        self.edge_index = edge_index
        self.line1 = nn.Linear(input_dim, filter_num)
        self.line2 = nn.Linear(filter_num, filter_num)

        self.conv1 = GINConv(self.line1)
        self.conv2 = GINConv(self.line2)

        self.Conv = nn.Conv1d(filter_num, out_dim, kernel_size=1)
        self.layer = layer
        if layer == 3:
            self.line3 = nn.Linear(filter_num, filter_num)
            self.conv3 = GINConv(self.line3)
        self.reg_params = list(self.parameters())

    def forward(self, data, idx):
        """Return log-probabilities for the rows of ``data`` selected by ``idx``."""
        x, edge_index = data, self.edge_index

        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        if self.layer == 3:
            x = F.relu(self.conv3(x, edge_index))

        if self.dropout > 0:
            x = F.dropout(x, self.dropout, training=self.training)

        # (N, C) -> (1, C, N) for Conv1d -> (1, out_dim, N) -> (N, out_dim).
        # squeeze(0) removes only the batch axis added above. A bare squeeze()
        # would also drop N or out_dim when either equals 1.
        x = x.unsqueeze(0).permute((0, 2, 1))
        x = self.Conv(x)
        x = x.permute((0, 2, 1)).squeeze(0)

        return F.log_softmax(x[idx], dim=1)
