import torch
from torch import nn
from torch.nn import functional as F
import numpy as np


def permute_node_ordering_and_compute_covariance_matrix(X, A):
    """Shuffle node order, standardize ``X`` and build its covariance matrix.

    One permutation of ``range(X.shape[1])`` is drawn from NumPy's global RNG
    (so it is controlled by ``np.random.seed``, not ``torch.manual_seed``).
    It is applied to dimension 1 of ``X`` and to dimensions 1 and 2 of ``A``,
    so both inputs keep the same node labelling.

    Args:
        X: 3-D tensor. Dim 1 indexes nodes; dim 2 is the axis that is
            standardized and contracted over. (Presumably laid out as
            ``(batch, nodes, observations)``; TODO confirm with callers.)
        A: 3-D tensor whose dims 1 and 2 both index nodes (presumably one
            adjacency matrix per batch element).

    Returns:
        A tuple ``(X, C, A, perm)``:

        * ``X``: the permuted input. Each ``X[b, i, :]`` is shifted to zero
          mean and divided by its unbiased standard deviation plus ``1e-3``.
        * ``C``: ``X @ X^T / (X.shape[2] - 1)``, the empirical covariance of
          those standardized rows (therefore close to a correlation matrix).
        * ``A``: the input ``A`` with rows and columns permuted.
        * ``perm``: the NumPy index array that was applied.
    """
    perm = np.random.permutation(X.shape[1])

    # Relabel nodes consistently: X along its node axis, A along both axes.
    X_shuffled = X[:, perm, :]
    A_shuffled = A[:, perm, :]
    A_shuffled = A_shuffled[:, :, perm]

    # Standardize every row over dim 2. The 1e-3 keeps constant rows finite
    # (they become all zeros instead of dividing by zero).
    center = X_shuffled.mean(dim=2, keepdim=True)
    scale = X_shuffled.std(dim=2, keepdim=True) + 1e-3
    Z = (X_shuffled - center) / scale

    # Empirical covariance across nodes, with the unbiased (n - 1) normalizer.
    dof = Z.shape[2] - 1
    cov = torch.matmul(Z, Z.transpose(1, 2)) / dof

    return Z, cov, A_shuffled, perm

class DeepGraph(torch.nn.Module):
    """Dilated-convolution network that maps a square matrix to a same-size matrix.

    Layer ``k`` applies two convolutions with dilation ``2**k`` to the current
    feature map:

    * ``all_conv_layers[k]`` is used as-is.
    * ``all_conv_layers_d[k]`` keeps only its main diagonal ``D``. That
      diagonal is broadcast so output entry ``(i, j)`` additionally receives
      ``D[i] + D[j]``.

    The sum goes through batch-norm and ReLU. A final 1x1 convolution then
    reduces the channels to one.

    The network has ``max(1, ceil(log2(max_num_nodes)))`` layers. With
    ``kernel_size=3`` the receptive field is at least
    ``2 * max_num_nodes - 1``. Every output entry can therefore depend on
    every input entry of a ``max_num_nodes x max_num_nodes`` matrix.

    Args:
        max_num_nodes: Largest matrix side length to support. Only used to
            pick the number of layers.
        num_filters: Number of channels in every hidden layer.
        kernel_size: Odd int (or pair of odd ints) for the convolution
            kernels. Padding is derived from it so that every layer preserves
            the spatial size of its input.

    Raises:
        ValueError: If a kernel dimension is not a positive odd integer. No
            symmetric zero padding keeps the size unchanged for an even kernel.
    """

    def __init__(self, max_num_nodes, num_filters, kernel_size=3):
        super(DeepGraph, self).__init__()
        if isinstance(kernel_size, (tuple, list)):
            kernel_hw = tuple(kernel_size)
        else:
            kernel_hw = (kernel_size, kernel_size)
        if len(kernel_hw) != 2 or any(size < 1 or size % 2 == 0 for size in kernel_hw):
            raise ValueError(
                f"kernel_size must be a positive odd integer (or a pair of them), got {kernel_size!r}"
            )

        self.max_num_nodes = max_num_nodes
        self.num_filters = num_filters
        self.kernel_size = kernel_size
        self._kernel_hw = kernel_hw
        self.num_layers = int(np.ceil(np.log2(max_num_nodes)))

        self.all_conv_layers = nn.ModuleList()
        self.all_conv_layers_d = nn.ModuleList()
        self.all_batch_norm = nn.ModuleList()

        # Layer 0 reads the single input channel with dilation 1. Layer k > 0
        # maps num_filters -> num_filters with dilation 2**k. At least one
        # layer is always built. Modules are created per layer in the order
        # conv, diagonal conv, batch-norm.
        for k in range(max(self.num_layers, 1)):
            in_channels = 1 if k == 0 else num_filters
            dilation = 2 ** k
            self.all_conv_layers.append(self._same_size_conv(in_channels, dilation))
            self.all_conv_layers_d.append(self._same_size_conv(in_channels, dilation))
            self.all_batch_norm.append(torch.nn.BatchNorm2d(num_filters))

        self.output_conv = torch.nn.Conv2d(num_filters, 1, 1, stride=1, padding=0, dilation=1, bias=True,
                                           padding_mode='zeros')

    def _same_size_conv(self, in_channels, dilation):
        """Build a ``num_filters``-channel Conv2d that preserves the input's spatial size."""
        # A dilated kernel spans dilation * (size - 1) + 1 cells. Padding
        # dilation * (size - 1) / 2 zeros on each side keeps the output size
        # equal to the input size; this is exact for the odd sizes enforced in
        # __init__.
        padding = tuple(int(dilation * (size - 1) // 2) for size in self._kernel_hw)
        return torch.nn.Conv2d(in_channels, self.num_filters, self.kernel_size, stride=1, padding=padding,
                               dilation=dilation, bias=True, padding_mode='zeros')

    def forward(self, C):
        """Apply the network to a batch of square matrices.

        Args:
            C: Tensor of shape ``(n_samples, n_nodes, n_nodes)``.

        Returns:
            Tensor of shape ``(n_samples, n_nodes, n_nodes)``, one value per
            input entry.
        """
        C = C.unsqueeze(1)  # n_samples, 1, n_nodes, n_nodes

        layers = zip(self.all_conv_layers, self.all_conv_layers_d, self.all_batch_norm)
        for conv, conv_d, batch_norm in layers:
            C1 = conv(C)
            # Keep only the diagonal of the second branch: n_samples, num_filters, n_nodes.
            C2_diag = torch.diagonal(conv_d(C), offset=0, dim1=2, dim2=3)
            # Broadcast so entry (i, j) gets C2_diag[..., i] + C2_diag[..., j].
            C = C1 + C2_diag.unsqueeze(2) + C2_diag.unsqueeze(3)
            C = F.relu(batch_norm(C))

        return self.output_conv(C).squeeze(1)  # n_samples, n_nodes, n_nodes