import torch.nn as nn
import torch.nn.functional as F
import torch
import numpy as np
import scipy.sparse as sp
import math
import config
from torch.nn.parameter import Parameter
from torch.nn.modules.module import Module

from utils import make_gnn_model, device

class Sampling(nn.Module):
    """Perturb a tensor with additive standard-normal noise.

    Fresh noise is drawn on every forward pass; the module does not look at
    ``self.training``, so the perturbation is applied in eval mode as well.
    """

    def forward(self, inputs):
        """Return ``inputs`` plus element-wise N(0, 1) noise.

        Args:
            inputs: tensor to perturb (any shape).

        Returns:
            A tensor with the shape of ``inputs``, living on the same device.
        """
        # Draw the noise directly on the input's device. This avoids a
        # host-to-device copy and keeps the layer correct even when the
        # input's placement disagrees with a global CUDA flag.
        noise = torch.randn(inputs.shape, device=inputs.device)
        return inputs + noise

class Reshape(nn.Module):
    """Module wrapper that applies ``Tensor.view`` with a fixed target shape."""

    def __init__(self, shape):
        super().__init__()
        # Stored as given and handed verbatim to ``view`` on every call.
        self.shape = shape

    def forward(self, x):
        """Return a view of ``x`` with the shape configured at construction."""
        target_shape = self.shape
        return x.view(target_shape)


class MendGraph(nn.Module):
    """Attach generated neighbour nodes to an existing graph.

    Every original node ``i`` owns ``num_pred`` candidate generated nodes whose
    ids are ``node_len + i * num_pred + j`` for ``j`` in ``[0, num_pred)``.
    The generated feature rows for all candidates are appended to the feature
    matrix, and the first ``k_i`` candidates of node ``i`` are wired to it in
    both directions, where ``k_i`` is the predicted degree truncated toward
    zero and clamped to ``[0, num_pred]``.

    The module has no learnable parameters.
    """

    def __init__(self, node_len, num_pred, feat_shape):
        super(MendGraph, self).__init__()
        self.num_pred = num_pred
        self.feat_shape = feat_shape
        self.org_node_len = node_len

    def mend_graph(self, org_feats, org_edges, pred_degree, gen_feats):
        """Build the mended feature matrix and edge index.

        Args:
            org_feats: original node features; detached before use.
            org_edges: original edge index of shape ``(2, E)`` (it is
                concatenated column-wise with the new ``(2, E_new)`` edges).
            pred_degree: predicted number of missing neighbours, one value per
                original node (flattened before use, so ``(N,)`` and
                ``(N, 1)`` are both accepted).
            gen_feats: generated features holding ``num_pred * feat_shape``
                values per original node.

        Returns:
            ``(fill_feats, fill_edges)``: the original features stacked on top
            of the generated ones, and the original edges followed by the new
            bidirectional edges. ``fill_edges`` is always a new tensor.
        """
        # The 3-D reshape checks that gen_feats really holds whole groups of
        # num_pred * feat_shape values before it is flattened to rows.
        gen_feats = gen_feats.reshape(-1, self.num_pred, self.feat_shape)
        fill_feats = torch.vstack(
            (org_feats.detach(), gen_feats.reshape(-1, self.feat_shape))
        )

        # One Python int per node: clamp to [0, num_pred], then truncate.
        counts = (
            pred_degree.detach()
            .reshape(-1)
            .clamp(min=0, max=self.num_pred)
            .to(torch.long)
            .tolist()
        )

        # Sources and targets are collected in two parallel lists so the
        # result is already in (2, E) layout. Reshaping an (E, 2) list of
        # pairs into (2, E) would scramble the endpoints.
        src, dst = [], []
        for node in range(self.org_node_len):
            first_new = self.org_node_len + node * self.num_pred
            for new_id in range(first_new, first_new + counts[node]):
                src.extend((node, new_id))
                dst.extend((new_id, node))

        if not src:
            return fill_feats, torch.clone(org_edges)

        new_edges = torch.tensor(
            [src, dst], dtype=torch.long, device=org_edges.device
        )
        return fill_feats, torch.hstack((org_edges, new_edges))

    def forward(self, org_feats, org_edges, pred_missing, gen_feats):
        """Alias for :meth:`mend_graph`; ``pred_missing`` is the predicted degree."""
        return self.mend_graph(org_feats, org_edges, pred_missing, gen_feats)


class Gen(nn.Module):
    """Feature generator: noisy latent vector -> ``num_pred`` feature vectors.

    The input is perturbed by the ``Sampling`` layer, passed through two
    ReLU-activated linear layers (256 and 2048 units), dropped out, and
    projected to ``num_pred * feat_shape`` values squashed by ``tanh``.
    """

    def __init__(self, latent_dim, dropout, num_pred, feat_shape):
        super().__init__()
        self.num_pred = num_pred
        self.feat_shape = feat_shape
        self.sample = Sampling()

        # Layers are created in this order so parameter registration (and
        # therefore initialisation under a fixed seed) is unchanged.
        self.fc1 = nn.Linear(latent_dim, 256)
        self.fc2 = nn.Linear(256, 2048)
        self.fc_flat = nn.Linear(2048, self.num_pred * self.feat_shape)

        # Dropout probability, applied functionally in forward().
        self.dropout = dropout

    def forward(self, x):
        """Return generated features whose last dimension is ``num_pred * feat_shape``."""
        noisy = self.sample(x)
        hidden = F.relu(self.fc1(noisy))
        hidden = F.relu(self.fc2(hidden))
        # Dropout is only active while the module is in training mode.
        hidden = F.dropout(hidden, self.dropout, training=self.training)
        return torch.tanh(self.fc_flat(hidden))

class RegModel(nn.Module):
    """Single linear layer with ReLU that maps a latent vector to one value."""

    def __init__(self, latent_dim):
        super().__init__()
        self.reg_1 = nn.Linear(latent_dim, 1)

    def forward(self, x):
        """Return the non-negative scalar prediction for each row of ``x``."""
        raw = self.reg_1(x)
        return F.relu(raw)


class FedSage_Plus(nn.Module):
    """Pipeline of encoder, degree regressor, feature generator, graph mender and classifier.

    ``args`` is read for the keys ``"architecture"``, ``"enc_hidden_channels"``,
    ``"dropout"`` and ``"num_pred"``. The encoder and classifier are built by
    ``make_gnn_model`` and moved to the module-level ``device``.
    """

    def __init__(self, feat_shape, node_len, n_classes, args):
        super().__init__()
        architecture = args["architecture"]
        hidden_channels = args["enc_hidden_channels"]
        num_pred = args["num_pred"]

        # Sub-modules are built in a fixed order so parameter registration
        # stays stable.
        self.encoder_model = make_gnn_model(
            architecture=architecture,
            in_channels=feat_shape,
            num_classes=hidden_channels,
        ).to(device)
        self.reg_model = RegModel(latent_dim=hidden_channels)
        self.gen = Gen(
            latent_dim=hidden_channels,
            dropout=args["dropout"],
            num_pred=num_pred,
            feat_shape=feat_shape,
        )
        self.mend_graph = MendGraph(
            node_len=node_len,
            num_pred=num_pred,
            feat_shape=feat_shape,
        )
        self.classifier = make_gnn_model(
            architecture=architecture,
            in_channels=feat_shape,
            num_classes=n_classes,
        ).to(device)
        self.mend_graph.requires_grad_(False)

    def forward(self, feat, edges):
        """Run the full pipeline on one graph.

        Returns:
            ``(degree, gen_feat, nc_pred, out[2])`` where ``degree`` comes from
            the regressor, ``gen_feat`` from the generator, ``nc_pred`` is the
            ReLU of the classifier's first output on the mended graph, and
            ``out[2]`` is the classifier's third output passed through as-is.
            NOTE(review): the meaning of the elements returned by
            ``make_gnn_model`` models is not visible here; confirm in utils.
        """
        encoded = self.encoder_model(feat, edges)[0]
        latent = F.relu(encoded)
        degree = self.reg_model(latent)
        gen_feat = self.gen(latent)
        mend_feats, mend_edges = self.mend_graph(feat, edges, degree, gen_feat)
        out = self.classifier(mend_feats, mend_edges)
        nc_pred = F.relu(out[0])
        return degree, gen_feat, nc_pred, out[2]
