import os, sys
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
import matplotlib.pyplot as plt
import torch.optim as optim
## r=1
from torch_geometric.nn import global_max_pool

from torch.nn import Sequential as Seq, Linear as Lin, ReLU, BatchNorm1d as BN, LeakyReLU as LRU
from torch.nn import Sequential as Seq, Dropout, Linear as Lin

try:
    from src import graphOps as GO
    from src.graphOps import getConnectivity
    from mpl_toolkits.mplot3d import Axes3D
    from src.utils import saveMesh, h_swish
    from src.inits import glorot, identityInit

except:
    import graphOps as GO
    from graphOps import getConnectivity
    from mpl_toolkits.mplot3d import Axes3D
    from utils import saveMesh, h_swish
    from inits import glorot, identityInit

from torch_geometric.nn.conv.gcn_conv import gcn_norm
from torch_geometric.nn import GCN2Conv
from torch_scatter import scatter_add


def conv2(X, Kernel, groups=1):
    """2-D convolution with zero padding ``(k - 1) // 2`` on each spatial axis.

    For odd kernel sizes the spatial size of the output equals that of the
    input; height and width padding are computed independently so
    non-square kernels are handled as well.

    Args:
        X: input of shape (batch, in_channels, H, W).
        Kernel: weights of shape (out_channels, in_channels // groups, kH, kW).
        groups: forwarded to ``F.conv2d``.
    """
    pad_h = (Kernel.shape[-2] - 1) // 2
    pad_w = (Kernel.shape[-1] - 1) // 2
    return F.conv2d(X, Kernel, padding=(pad_h, pad_w), groups=groups)


def conv1(X, Kernel, groups=1):
    """1-D convolution with zero padding ``(k - 1) // 2``.

    For odd kernel sizes the output length equals the input length.

    Args:
        X: input of shape (batch, in_channels, length).
        Kernel: weights of shape (out_channels, in_channels // groups, k).
        groups: forwarded to ``F.conv1d``.
    """
    return F.conv1d(X, Kernel, padding=(Kernel.shape[-1] - 1) // 2, groups=groups)


def conv1T(X, Kernel, groups=1):
    """1-D transposed convolution with padding ``(k - 1) // 2``.

    For odd kernel sizes the output length equals the input length, so this
    pairs with ``conv1`` as its adjoint-shaped counterpart.

    Args:
        X: input of shape (batch, in_channels, length).
        Kernel: weights of shape (in_channels, out_channels // groups, k).
        groups: forwarded to ``F.conv_transpose1d``.
    """
    return F.conv_transpose1d(X, Kernel, padding=(Kernel.shape[-1] - 1) // 2, groups=groups)


def conv2T(X, Kernel, groups=1):
    """2-D transposed convolution with padding ``(k - 1) // 2`` per spatial axis.

    Height and width padding are computed independently, so non-square odd
    kernels also preserve the spatial size.

    Args:
        X: input of shape (batch, in_channels, H, W).
        Kernel: weights of shape (in_channels, out_channels // groups, kH, kW).
        groups: forwarded to ``F.conv_transpose2d``.
    """
    pad_h = (Kernel.shape[-2] - 1) // 2
    pad_w = (Kernel.shape[-1] - 1) // 2
    return F.conv_transpose2d(X, Kernel, padding=(pad_h, pad_w), groups=groups)


# Default device string: the first CUDA GPU when one is visible, else the CPU.
if torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'


def tv_norm(X, eps=1e-3):
    """Center ``X`` along dim 1 and rescale it to (roughly) unit L2 norm there.

    ``eps`` is added under the square root so slices that are constant along
    dim 1 (zero after centering) map to zeros instead of NaNs.
    """
    centered = X - X.mean(dim=1, keepdim=True)
    scale = (centered.pow(2).sum(dim=1, keepdim=True) + eps).sqrt()
    return centered / scale


def diffX(X):
    """Forward difference along dim 1: ``out[:, j] = X[:, j + 1] - X[:, j]``.

    Every singleton dimension of ``X`` is squeezed away first, so the
    squeezed tensor must be at least 2-D; dim 1 shrinks by one.
    """
    squeezed = X.squeeze()
    return torch.diff(squeezed, dim=1)


def diffXT(X):
    """Transpose (adjoint) of ``diffX`` along dim 1.

    For ``y`` of length M along dim 1 the result has length M + 1:
    ``[-y0, y0 - y1, ..., y_{M-2} - y_{M-1}, y_{M-1}]``. Singleton
    dimensions are squeezed away first.
    """
    y = X.squeeze()
    head = -y[:, :1]
    middle = y[:, :-1] - y[:, 1:]
    tail = y[:, -1:]
    return torch.cat((head, middle, tail), dim=1)


def doubleLayer(x, K1, K2):
    """Pointwise conv (K1) -> layer norm -> ReLU -> pointwise conv (K2).

    ``K1`` and ``K2`` are 2-D (out, in) matrices used as kernel-size-1 conv
    weights. The layer norm is taken over ``x.shape``, i.e. over every
    dimension of the intermediate tensor, batch dimension included.
    """
    hidden = F.conv1d(x, K1[..., None])
    hidden = F.relu(F.layer_norm(hidden, hidden.shape))
    return F.conv1d(hidden, K2[..., None])


###################################################################################pdegcn


def MLP(channels, batch_norm=True):
    """Build a stack of ``Linear -> [BatchNorm1d] -> ReLU`` blocks.

    Args:
        channels: layer widths; block ``i`` maps ``channels[i]`` to
            ``channels[i + 1]``, giving ``len(channels) - 1`` blocks.
        batch_norm: when False, the BatchNorm1d layers are omitted
            (default True keeps Linear -> BatchNorm1d -> ReLU).

    Returns:
        ``Sequential`` of per-layer ``Sequential`` blocks.
    """
    def block(n_in, n_out):
        layers = [Lin(n_in, n_out)]
        if batch_norm:
            layers.append(BN(n_out))
        layers.append(ReLU())
        return Seq(*layers)

    return Seq(*[block(n_in, n_out) for n_in, n_out in zip(channels, channels[1:])])


class graphNetwork_nodesOnly(nn.Module):
    """Graph PDE network that evolves node features only (wave / heat / mixed).

    Node features are lifted with an opening pointwise layer (``K1Nopen``),
    advanced for ``nlayer`` explicit time steps of size ``h`` using the
    ``Graph`` object's gradient / divergence / Laplacian operators, and then
    projected with a closing pointwise convolution (``KNclose``). The output
    head is chosen by the ``modelnet`` / ``faust`` / ``PPI`` flags; otherwise
    a node-classification log-softmax is returned.

    ``nNclose`` and ``GCNII`` are accepted but never read, and ``dense`` /
    ``diffOrder`` are stored but not used by this class.
    """

    def __init__(self, nNin, nopen, nhid, nNclose, nlayer, h=0.1, dense=False, varlet=False, wave=True,
                 diffOrder=1, num_output=1024, dropOut=False, modelnet=False, faust=False, GCNII=False,
                 graphUpdate=None, PPI=False, gated=False, realVarlet=False, mixDyamics=False, doubleConv=False,
                 tripleConv=False):
        """Create the network parameters.

        Args:
            nNin: number of input node channels (in-channels of ``K1Nopen``).
            nopen: width of the propagated node features.
            nhid: hidden width of the per-layer kernels ``KN1`` / ``KN2``.
            nNclose: unused.
            nlayer: number of time steps (layers).
            h: time-step size.
            varlet: node update of the form div(kernel(grad(x))).
            wave: second-order (wave) stepping; otherwise first-order (heat).
            num_output: output channels of the closing layer / head.
            dropOut: dropout probability; a value <= 0 (or False) disables it.
            modelnet, faust, PPI: select the task-specific head in ``forward``.
            graphUpdate: if not None, rebuild edge weights from the node
                features every ``graphUpdate`` layers.
            gated: gated Laplacian update (only used together with ``varlet``).
            realVarlet: staggered node/edge update that carries edge features.
            mixDyamics: learned blend of wave and heat stepping (keyword
                spelling kept for backward compatibility).
            doubleConv: use ``finalDoubleLayer`` with ``KN1`` and ``KN2``.
            tripleConv: apply an extra ``KN3`` layer after the divergence.
        """
        super(graphNetwork_nodesOnly, self).__init__()
        self.wave = wave
        self.realVarlet = realVarlet
        self.heat = not wave
        self.mixDynamics = mixDyamics
        self.h = h
        self.varlet = varlet
        self.dense = dense
        self.diffOrder = diffOrder
        self.num_output = num_output
        self.graphUpdate = graphUpdate
        self.doubleConv = doubleConv
        self.tripleConv = tripleConv
        self.gated = gated
        self.faust = faust
        self.PPI = PPI
        # ``False`` doubles as "dropout disabled" throughout forward().
        self.dropout = dropOut if dropOut > 0.0 else False
        self.nlayers = nlayer

        # NOTE(review): a 1e-1 scale for faust/PPI was once tried, but every
        # configuration currently initializes with 1e-2.
        stdv = 1e-2
        stdvp = 1e-2

        # Parameters are created in a fixed order so seeded RNG draws (and
        # therefore initial weights) stay reproducible.
        # Opening pointwise kernels, (out, in).
        self.K1Nopen = nn.Parameter(torch.randn(nopen, nNin) * stdv)
        self.K2Nopen = nn.Parameter(torch.randn(nopen, nopen) * stdv)
        # Allocated per layer but not read by forward().
        self.convs1x1 = nn.Parameter(torch.randn(nlayer, nopen, nopen) * stdv)
        self.modelnet = modelnet

        # Closing pointwise kernel, (out, in); its shape depends on the head.
        if self.modelnet:
            self.KNclose = nn.Parameter(torch.randn(1024, num_output) * stdv)
        elif not self.faust:
            self.KNclose = nn.Parameter(torch.randn(num_output, nopen) * stdv)
        else:
            self.KNclose = nn.Parameter(torch.randn(nopen, nopen) * stdv)

        Nfeatures = nopen

        # The first random draw only feeds identityInit, which is
        # project-defined (presumably an identity-like init of the same
        # shape -- confirm in src.inits); small noise is added on top.
        self.KN1 = nn.Parameter(torch.rand(nlayer, Nfeatures, nhid) * stdvp)
        rrnd = torch.rand(nlayer, Nfeatures, nhid) * (1e-3)
        self.KN1 = nn.Parameter(identityInit(self.KN1) + rrnd)

        if self.realVarlet:
            self.KN1 = nn.Parameter(torch.rand(nlayer, nhid, 2 * Nfeatures) * stdvp)
            self.KE1 = nn.Parameter(torch.rand(nlayer, nhid, 2 * Nfeatures) * stdvp)

        if self.mixDynamics:
            # sigmoid(0) = 0.5: start from an even wave/heat blend.
            self.alpha = nn.Parameter(torch.zeros(1, 1))

        self.KN2 = nn.Parameter(torch.rand(nlayer, nhid, 1 * nhid) * stdvp)
        self.KN2 = nn.Parameter(identityInit(self.KN2))

        if self.tripleConv:
            self.KN3 = nn.Parameter(torch.rand(nlayer, nopen, 1 * nhid) * stdvp)
            self.KN3 = nn.Parameter(identityInit(self.KN3))

        if self.faust:
            self.lin1 = torch.nn.Linear(nopen, nopen)
            self.lin2 = torch.nn.Linear(nopen, num_output)

        if self.modelnet:
            self.mlp = Seq(
                MLP([64, 128]), Dropout(0.5), MLP([128, 64]), Dropout(0.5),
                Lin(64, 10))

    def reset_parameters(self):
        """Re-initialize the opening/closing kernels (and optional parts) with ``glorot``."""
        glorot(self.K1Nopen)
        glorot(self.K2Nopen)
        glorot(self.KNclose)
        if self.realVarlet:
            glorot(self.KE1)
        if self.modelnet:
            # NOTE(review): glorot receives a Sequential module here; confirm
            # that src.inits.glorot handles modules and not only tensors.
            glorot(self.mlp)

    def edgeConv(self, xe, K, groups=1):
        """Convolve ``xe`` with weights ``K`` over its trailing axes.

        A 2-D ``K`` (out, in) is used as a size-1 (pointwise) kernel. A
        higher-rank ``K`` is used as a full kernel with zero padding
        ``(k - 1) // 2`` per spatial axis, which preserves the size for odd
        kernels. Inputs that are neither 3-D nor 4-D are returned unchanged.
        """
        if xe.dim() == 4:
            if K.dim() == 2:
                xe = F.conv2d(xe, K.unsqueeze(-1).unsqueeze(-1), groups=groups)
            else:
                padding = ((K.shape[-2] - 1) // 2, (K.shape[-1] - 1) // 2)
                xe = F.conv2d(xe, K, padding=padding, groups=groups)
        elif xe.dim() == 3:
            if K.dim() == 2:
                xe = F.conv1d(xe, K.unsqueeze(-1), groups=groups)
            else:
                xe = F.conv1d(xe, K, padding=(K.shape[-1] - 1) // 2, groups=groups)
        return xe

    def singleLayer(self, x, K, relu=True, norm=False, groups=1, openclose=False):
        """Apply one kernel layer.

        ``openclose=True``: conv(K) -> optional instance norm -> ReLU (or tanh).
        Otherwise the symmetric form ``K^T act(K x)``: conv(K) -> ReLU (or
        tanh) -> optional ``norm(x) * tv_norm(x)`` rescaling -> conv(K^T);
        this path uses ``K.t()`` and therefore needs a 2-D ``K``.
        """
        if openclose:
            x = self.edgeConv(x, K, groups=groups)
            if norm:
                x = F.instance_norm(x)
            x = F.relu(x) if relu else torch.tanh(x)
        else:
            x = self.edgeConv(x, K, groups=groups)
            x = F.relu(x) if relu else torch.tanh(x)
            if norm:
                beta = torch.norm(x)
                x = beta * tv_norm(x)
            x = self.edgeConv(x, K.t(), groups=groups)
        return x

    def finalDoubleLayer(self, x, K1, K2):
        """Symmetric two-kernel layer: tanh-separated convs with K1, K2, K2^T, K1^T."""
        x = torch.tanh(x)
        x = self.edgeConv(x, K1)
        x = torch.tanh(x)
        x = self.edgeConv(x, K2)
        x = torch.tanh(x)
        x = self.edgeConv(x, K2.t())
        x = torch.tanh(x)
        x = self.edgeConv(x, K1.t())
        x = torch.tanh(x)
        return x

    def savePropagationImage(self, xn, Graph, i=0, minv=None, maxv=None):
        """Plot ``xn`` as a 32x32 image and save it to ``plots/layer{i}.jpg``.

        ``Graph`` is unused. The reshape requires ``xn`` to hold exactly 1024
        values. ``minv``/``maxv`` fix the color range only when both are given.
        """
        plt.figure()
        img = xn.clone().detach().squeeze().reshape(32, 32).cpu().numpy()
        if (maxv is not None) and (minv is not None):
            plt.imshow(img, vmax=maxv, vmin=minv)
        else:
            plt.imshow(img)
        plt.colorbar()
        # Save before show(): on interactive backends show() can release the
        # figure, leaving savefig() to write an empty image.
        plt.savefig('plots/layer' + str(i) + '.jpg')
        plt.show()
        plt.close()

    def updateGraph(self, Graph, features=None):
        """Rebuild ``Graph`` with new edge weights on the same edge set.

        With ``features``: weights are ``exp(-2 * d / std(d))`` where ``d`` is
        the pairwise squared L2 distance between feature columns (presumably
        one column per node after squeezing -- confirm), read at the edges.
        Without ``features``: edges and weights come from ``gcn_norm`` (which
        by default also adds self loops).

        Returns:
            ``(new_graph, edge_index)``.
        """
        N = Graph.nnodes
        I = Graph.iInd
        J = Graph.jInd
        edge_index = torch.cat([I.unsqueeze(0), J.unsqueeze(0)], dim=0)
        if features is not None:
            features = features.squeeze()
            sq = torch.sum(features ** 2, dim=0, keepdim=True)
            # relu clips small negative values caused by round-off.
            D = torch.relu(sq + sq.t() - 2 * features.t() @ features)
            D = D / D.std()
            D = torch.exp(-2 * D)
            w = D[I, J]
            Graph = GO.graph(I, J, N, W=w, pos=None, faces=None)
        else:
            edge_index, edge_weights = gcn_norm(edge_index)  # Pre-process GCN normalization.
            I = edge_index[0, :]
            J = edge_index[1, :]
            Graph = GO.graph(I, J, N, W=edge_weights, pos=None, faces=None)

        return Graph, edge_index

    def forward(self, xn, Graph, data=None, xe=None):
        """Run the opening layer, ``nlayer`` PDE time steps and the output head.

        Args:
            xn: node features, [B, C, N].
            Graph: graph-operator object (``nodeGrad``, ``edgeDiv``,
                ``nodeLap``, ...).
            data: batch object; only ``data.batch`` is read (``modelnet``).
            xe: edge features, [B, C, N, N] or [B, C, E]; read for ``faust``.

        Returns:
            modelnet: log-probabilities from ``self.mlp`` after global max pooling.
            faust: ``(log_softmax(...), sigmoid(alpha))``; the second item is
                ``None`` when ``mixDyamics`` is off (no ``alpha`` exists then).
            PPI: ``(xn, Graph)`` with the raw closing-layer output.
            otherwise: ``(log_softmax(xn, dim=1), Graph)``.
        """
        if not self.faust:
            Graph, _ = self.updateGraph(Graph)
        else:
            xn = torch.cat([xn, Graph.edgeDiv(xe)], dim=1)

        # Developer switch for per-layer visualisations; all names used by the
        # debug paths are defined up front so enabling it cannot NameError.
        debug = False
        image = False  # True: 32x32 image plots; False: meshes via saveMesh.
        vmin = vmax = minv = maxv = None
        if debug:
            xnnorm = torch.norm(xn, dim=1)
            vmin = xnnorm.min().detach().cpu().numpy()
            vmax = xnnorm.max().detach().cpu().numpy()
            saveMesh(xn.squeeze().t(), Graph.faces, Graph.pos, -1, vmax=vmax, vmin=vmin)

        # Opening layer(s).
        if self.realVarlet:
            xe = Graph.nodeGrad(xn)
            if self.dropout:
                xe = F.dropout(xe, p=self.dropout, training=self.training)
            xe = self.singleLayer(xe, self.K2Nopen, relu=True)

        if self.dropout:
            xn = F.dropout(xn, p=self.dropout, training=self.training)
        xn = self.singleLayer(xn, self.K1Nopen, relu=True, openclose=True, norm=False)

        x0 = xn.clone()
        if debug:
            if image:
                plt.figure()
                print("xn shape:", xn.shape)
                img = xn.clone().detach().squeeze().cpu().numpy().reshape(32, 32)
                minv = img.min()
                maxv = img.max()
                plt.imshow(img, vmax=maxv, vmin=minv)
                plt.colorbar()
                plt.savefig('plots/img_xn_norm_layer_verlet' + str(1) + 'order_nodeDeriv' + str(0) + '.jpg')
                plt.show()
                plt.close()
            else:
                saveMesh(xn.squeeze().t(), Graph.faces, Graph.pos, 0, vmax=vmax, vmin=vmin)

        xn_old = x0
        for i in range(self.nlayers):
            if self.graphUpdate is not None and i % self.graphUpdate == self.graphUpdate - 1:
                # Periodically re-weight edges from feature-space distances.
                Graph, _ = self.updateGraph(Graph, features=xn)

            if self.realVarlet:
                # Staggered update: edge features from nodes, then nodes from edges.
                gradX = Graph.nodeGrad(xn)
                intX = Graph.nodeAve(xn)
                dxe = torch.cat([intX, gradX], dim=1)
                if self.dropout:
                    dxe = F.dropout(dxe, p=self.dropout, training=self.training)
                dxe = self.singleLayer(dxe, self.KE1[i], relu=False)
                xe = xe + self.h * dxe

                divE = Graph.edgeDiv(xe)
                aveE = Graph.edgeAve(xe, method='ave')
                dxn = torch.cat([aveE, divE], dim=1)
                if self.dropout:
                    dxn = F.dropout(dxn, p=self.dropout, training=self.training)
                dxn = torch.tanh(self.singleLayer(dxn, self.KN1[i], relu=False))
                xn = xn + self.h * dxn
                continue_debug = True
            else:
                continue_debug = True
                if self.varlet:
                    gradX = Graph.nodeGrad(xn)
                    if self.dropout:
                        gradX = F.dropout(gradX, p=self.dropout, training=self.training)

                # Right-hand side dxn of the node PDE.
                if self.varlet and not self.gated:
                    if not self.doubleConv:
                        dxn = self.singleLayer(gradX, self.KN1[i], norm=False, relu=True, groups=1)
                    else:
                        dxn = self.finalDoubleLayer(gradX, self.KN1[i], self.KN2[i])
                    dxn = Graph.edgeDiv(dxn)
                    if self.tripleConv:
                        dxn = self.singleLayer(dxn, self.KN3[i], norm=False, relu=False)
                elif self.varlet and self.gated:
                    W = torch.tanh(Graph.nodeGrad(self.singleLayer(xn, self.KN2[i], relu=False)))
                    lapX = Graph.nodeLap(xn)
                    dxn = torch.tanh(lapX + Graph.edgeDiv(W * Graph.nodeGrad(xn)))
                else:
                    # Laplacian-based update; lapX must be computed on this path too.
                    lapX = Graph.nodeLap(xn)
                    dxn = torch.tanh(self.singleLayer(lapX, self.KN1[i], relu=False))

                # Time stepping.
                if self.mixDynamics:
                    # Solves beta*(x+ - 2x + x-) + alpha*(x+ - x) = -dxn for x+,
                    # where beta = sigmoid(self.alpha)/h^2 weighs the
                    # second-order (wave) term and alpha the first-order one.
                    tmp_xn = xn.clone()
                    beta = torch.sigmoid(self.alpha)
                    alpha = (1 - beta) / self.h
                    beta = beta / (self.h ** 2)
                    xn = (2 * beta * xn - beta * xn_old + alpha * xn - dxn) / (beta + alpha)
                    xn_old = tmp_xn
                elif self.wave:
                    # Central-difference step for x_tt = -dxn.
                    tmp_xn = xn.clone()
                    xn = 2 * xn - xn_old - (self.h ** 2) * dxn
                    xn_old = tmp_xn
                else:
                    # Explicit Euler step for x_t = -dxn.
                    tmp = xn.clone()
                    xn = xn - self.h * dxn
                    xn_old = tmp

            if debug and continue_debug:
                if image:
                    self.savePropagationImage(xn, Graph, i + 1, minv=minv, maxv=maxv)
                else:
                    saveMesh(xn.squeeze().t(), Graph.faces, Graph.pos, i + 1, vmax=vmax, vmin=vmin)

        # Closing pointwise projection.
        if self.dropout:
            xn = F.dropout(xn, p=self.dropout, training=self.training)
        xn = F.conv1d(xn, self.KNclose.unsqueeze(-1))

        xn = xn.squeeze().t()
        if self.modelnet:
            out = global_max_pool(xn, data.batch)
            out = self.mlp(out)
            return F.log_softmax(out, dim=-1)

        if self.faust:
            x = F.elu(self.lin1(xn))
            if self.dropout:
                x = F.dropout(x, p=self.dropout, training=self.training)
            x = self.lin2(x)
            # alpha is only created with mixDyamics; avoid an AttributeError otherwise.
            mix = torch.sigmoid(self.alpha) if self.mixDynamics else None
            return F.log_softmax(x, dim=1), mix

        if self.PPI:
            return xn, Graph

        # Otherwise: citation-graph node classification.
        return F.log_softmax(xn, dim=1), Graph
