from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

import torch
import numpy as np
import scipy.sparse as sp

import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data

from federatedscope.gfl.model import SAGE_Net
"""
https://proceedings.neurips.cc//paper/2021/file/34adeb8e3242824038aa65460a47c29e-Paper.pdf
Fedsageplus models from the "Subgraph Federated Learning with Missing Neighbor Generation" (FedSage+) paper, in NeurIPS'21
Source: https://github.com/zkhku/fedsage
"""


class Sampling(nn.Module):
    def __init__(self):
        super(Sampling, self).__init__()

    def forward(self, inputs):
        rand = torch.normal(0, 1, size=inputs.shape)

        return inputs + rand.to(inputs.device)


class FeatGenerator(nn.Module):
    def __init__(self, latent_dim, dropout, num_pred, feat_shape):
        super(FeatGenerator, self).__init__()
        self.num_pred = num_pred
        self.feat_shape = feat_shape
        self.dropout = dropout
        self.sample = Sampling()
        self.fc1 = nn.Linear(latent_dim, 256)
        self.fc2 = nn.Linear(256, 2048)
        self.fc_flat = nn.Linear(2048, self.num_pred * self.feat_shape)

    def forward(self, x):
        x = self.sample(x)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.dropout(x, self.dropout, training=self.training)
        x = torch.tanh(self.fc_flat(x))

        return x


class NumPredictor(nn.Module):
    def __init__(self, latent_dim):
        self.latent_dim = latent_dim
        super(NumPredictor, self).__init__()
        self.reg_1 = nn.Linear(self.latent_dim, 1)

    def forward(self, x):
        x = F.relu(self.reg_1(x))
        return x


# Mend the graph via NeighGen
class MendGraph(nn.Module):
    def __init__(self, num_pred):
        super(MendGraph, self).__init__()
        self.num_pred = num_pred
        for param in self.parameters():
            param.requires_grad = False

    def mend_graph(self, x, edge_index, pred_degree, gen_feats):
        device = gen_feats.device
        num_node, num_feature = x.shape
        new_edges = []
        gen_feats = gen_feats.view(-1, self.num_pred, num_feature)

        if pred_degree.device.type != 'cpu':
            pred_degree = pred_degree.cpu()
        pred_degree = torch._cast_Int(torch.round(pred_degree)).detach()
        x = x.detach()
        fill_feats = torch.vstack((x, gen_feats.view(-1, num_feature)))

        for i in range(num_node):
            for j in range(min(self.num_pred, max(0, pred_degree[i]))):
                new_edges.append(
                    np.asarray([i, num_node + i * self.num_pred + j]))

        new_edges = torch.tensor(np.asarray(new_edges).reshape((-1, 2)),
                                 dtype=torch.int64).T
        new_edges = new_edges.to(device)
        if len(new_edges) > 0:
            fill_edges = torch.hstack((edge_index, new_edges))
        else:
            fill_edges = torch.clone(edge_index)
        return fill_feats, fill_edges

    def forward(self, x, edge_index, pred_missing, gen_feats):
        fill_feats, fill_edges = self.mend_graph(x, edge_index, pred_missing,
                                                 gen_feats)

        return fill_feats, fill_edges


class LocalSage_Plus(nn.Module):
    def __init__(self,
                 in_channels,
                 out_channels,
                 hidden,
                 gen_hidden,
                 dropout=0.5,
                 num_pred=5):
        super(LocalSage_Plus, self).__init__()

        self.encoder_model = SAGE_Net(in_channels=in_channels,
                                      out_channels=gen_hidden,
                                      hidden=hidden,
                                      max_depth=2,
                                      dropout=dropout)
        self.reg_model = NumPredictor(latent_dim=gen_hidden)
        self.gen = FeatGenerator(latent_dim=gen_hidden,
                                 dropout=dropout,
                                 num_pred=num_pred,
                                 feat_shape=in_channels)
        self.mend_graph = MendGraph(num_pred)

        self.classifier = SAGE_Net(in_channels=in_channels,
                                   out_channels=out_channels,
                                   hidden=hidden,
                                   max_depth=2,
                                   dropout=dropout)

    def forward(self, data):
        x = self.encoder_model(data)
        degree = self.reg_model(x)
        gen_feat = self.gen(x)
        mend_feats, mend_edge_index = self.mend_graph(data.x, data.edge_index,
                                                      degree, gen_feat)
        nc_pred = self.classifier(
            Data(x=mend_feats, edge_index=mend_edge_index))
        return degree, gen_feat, nc_pred[:data.num_nodes]

    def inference(self, impared_data, raw_data):
        x = self.encoder_model(data)
        degree = self.reg_model(x)
        gen_feat = self.gen(x)
        mend_feats, mend_edge_index = self.mend_graph(raw_data.x,
                                                      raw_data.edge_index,
                                                      degree, gen_feat)
        nc_pred = self.classifier(
            Data(x=mend_feats, edge_index=mend_edge_index))
        return degree, gen_feat, nc_pred[:raw_data.num_nodes]


class FedSage_Plus(nn.Module):
    def __init__(self, local_graph: LocalSage_Plus):
        super(FedSage_Plus, self).__init__()
        self.encoder_model = local_graph.encoder_model
        self.reg_model = local_graph.reg_model
        self.gen = local_graph.gen
        self.mend_graph = local_graph.mend_graph
        self.classifier = local_graph.classifier
        self.encoder_model.requires_grad_(False)
        self.reg_model.requires_grad_(False)
        self.mend_graph.requires_grad_(False)
        self.classifier.requires_grad_(False)

    def forward(self, data):
        x = self.encoder_model(data)
        degree = self.reg_model(x)
        gen_feat = self.gen(x)
        mend_feats, mend_edge_index = self.mend_graph(data.x, data.edge_index,
                                                      degree, gen_feat)
        nc_pred = self.classifier(
            Data(x=mend_feats, edge_index=mend_edge_index))
        return degree, gen_feat, nc_pred[:data.num_nodes]