import numpy as np
import scipy.sparse as sp
import torch
import torch.nn as nn
import torch.nn.functional as F
from HoloConvLayer import HoloConv


def adj_matrix_to_edge_idx(adj_matrix: sp.spmatrix, device) -> torch.LongTensor:
    """Convert a SciPy sparse adjacency matrix into a (2, nnz) index tensor.

    Row 0 of the result holds the row indices and row 1 the column indices of
    the matrix's stored entries, in the order the COO conversion yields them.
    The tensor has dtype ``torch.long`` and is placed on ``device``.
    """
    coo = adj_matrix.tocoo()
    # Stack along a new leading axis -> shape (2, nnz).
    pairs = np.stack((coo.row, coo.col))
    return torch.as_tensor(pairs, dtype=torch.long, device=device)






class ChebModel(torch.nn.Module):
    """Chebyshev-convolution graph model with a 1x1-convolution output head.

    Applies two ``ChebConv`` layers (three when ``layer == 3``), each followed
    by ReLU. It then optionally applies dropout and maps the hidden features
    to ``out_dim`` channels with a kernel-size-1 ``nn.Conv1d``. It returns
    log-softmax scores for the nodes selected by ``idx``.

    NOTE(review): ``ChebConv`` is not imported in this module; presumably it is
    ``torch_geometric.nn.ChebConv`` -- confirm and add the import, otherwise
    constructing this model raises ``NameError``.
    """

    def __init__(self, input_dim, out_dim, edge_index, filter_num, K, dropout=False,
                 layer=2):
        """
        Args:
            input_dim: input feature size of the first ChebConv layer.
            out_dim: number of output channels of the 1x1 conv head (the
                dimension the log-softmax is taken over).
            edge_index: graph connectivity passed to every ChebConv call;
                presumably a (2, E) long tensor such as the one returned by
                ``adj_matrix_to_edge_idx`` -- confirm against callers.
            filter_num: hidden width of every ChebConv layer.
            K: filter-size argument forwarded to each ChebConv.
            dropout: dropout probability applied before the head; a value that
                is not ``> 0`` (including the default ``False``) disables it.
            layer: ``3`` adds a third ChebConv layer; any other value yields
                the two-layer model.
        """
        super().__init__()

        self.dropout = dropout
        # Stored as a plain attribute, not a registered buffer, so
        # ``model.to(device)`` does not move it; the caller must place it on
        # the right device.
        self.edge_index = edge_index
        self.conv1 = ChebConv(input_dim, filter_num, K)
        self.conv2 = ChebConv(filter_num, filter_num, K)
        # Kernel-size-1 conv over the node axis == a per-node linear map
        # from filter_num to out_dim channels.
        self.Conv = nn.Conv1d(filter_num, out_dim, kernel_size=1)

        self.layer = layer
        if layer == 3:
            self.conv3 = ChebConv(filter_num, filter_num, K)
        # Snapshot of every parameter registered above (including conv3 when
        # present); presumably consumed by the training loop for
        # regularization -- verify against callers.
        self.reg_params = list(self.parameters())

    def forward(self, data, idx):
        """Compute log-softmax scores for the nodes selected by ``idx``.

        Args:
            data: node feature matrix fed to the first ChebConv; presumably
                shaped (num_nodes, input_dim) -- confirm against callers.
            idx: index or mask selecting which nodes' scores to return.

        Returns:
            Log-probabilities over the ``out_dim`` channels for the selected
            nodes, shaped (num_selected, out_dim).
        """
        x, edge_index = data, self.edge_index

        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        if self.layer == 3:
            x = F.relu(self.conv3(x, edge_index))

        if self.dropout > 0:
            x = F.dropout(x, self.dropout, training=self.training)

        # Conv1d expects (batch, channels, length): add a batch dim and move
        # the feature axis to the channel position, nodes become "length".
        x = x.unsqueeze(0).permute((0, 2, 1))   # (1, filter_num, num_nodes)
        x = self.Conv(x)                         # (1, out_dim, num_nodes)
        # Drop only the batch dim; a bare squeeze() would also remove the node
        # axis when num_nodes == 1 (or the class axis when out_dim == 1) and
        # break the dim=1 log_softmax below.
        x = x.permute((0, 2, 1)).squeeze(0)      # (num_nodes, out_dim)

        return F.log_softmax(x[idx], dim=1)



