import torch

from torch_geometric.data import Data
from torch_geometric.transforms import BaseTransform
from torch_geometric.utils import to_networkx, from_networkx

import networkx as nx
import numpy as np

from federatedscope.core.configs.config import global_cfg


class HideGraph(BaseTransform):
    r"""
    Generate impaired graph with labels and features to train NeighGen,
    hide Node from validation set from raw graph.

    Arguments:
        hidden_portion (float): fraction of the validation nodes to hide.
        num_pred (int): hyperparameters which limit the maximum value of the
            prediction; it is also the fixed number of rows of the
            per-node "x_missing" attribute.

    :returns:
        impaired graph (the hidden nodes are removed) whose nodes carry the
        attributes "num_missing" and "x_missing"
    :rtype:
        the object produced by ``from_networkx`` (a PyG data object)
    """
    def __init__(self, hidden_portion=0.5, num_pred=5):
        self.hidden_portion = hidden_portion
        self.num_pred = num_pred

    def _missing_features(self, x, ids_missing):
        r"""Stack the features of at most ``num_pred`` hidden neighbours.

        Arguments:
            x: node feature matrix of the raw graph (``data.x``)
            ids_missing (list): ids of the hidden neighbours of one node
        :returns:
            array of shape (num_pred, x.shape[1]); unused rows are zero
        :rtype:
            np.ndarray
        """
        # One code path for "none", "few" and "too many" hidden neighbours so
        # every node gets the same container type, shape and dtype.
        feats = np.zeros((self.num_pred, x.shape[1]))
        kept = ids_missing[:self.num_pred]
        if len(kept) > 0:
            feats[:len(kept)] = np.asarray(x[kept])
        return feats

    def __call__(self, data):
        r"""Hide a random part of the validation nodes of ``data``.

        Note that ``data.ids_missing`` is assigned on the input object as
        scratch storage before the graph is converted to networkx.
        """
        # Randomly pick the validation nodes which will be removed.
        val_ids = torch.where(data.val_mask == True)[0]
        hide_ids = np.random.choice(val_ids,
                                    int(len(val_ids) * self.hidden_portion),
                                    replace=False)
        remaining_mask = torch.ones(data.num_nodes, dtype=torch.bool)
        remaining_mask[hide_ids] = False
        remaining_nodes = torch.where(remaining_mask == True)[0].numpy()

        # One (initially empty) list of hidden neighbours per node.
        data.ids_missing = [[] for _ in range(data.num_nodes)]

        G = to_networkx(data,
                        node_attrs=[
                            'x', 'y', 'train_mask', 'val_mask', 'test_mask',
                            'index_orig', 'ids_missing'
                        ],
                        to_undirected=True)

        # Tell every neighbour of a hidden node which node it is losing.
        for missing_node in hide_ids:
            for neighbor in G.neighbors(missing_node):
                G.nodes[neighbor]['ids_missing'].append(missing_node)

        # Replace the id lists by the training targets of NeighGen:
        # the number of hidden neighbours and their (padded) features.
        for i in G.nodes:
            ids_missing = G.nodes[i]['ids_missing']
            del G.nodes[i]['ids_missing']
            # The count is NOT clipped to num_pred, only the features are.
            G.nodes[i]['num_missing'] = np.array([len(ids_missing)],
                                                 dtype=np.float32)
            G.nodes[i]['x_missing'] = self._missing_features(
                data.x, ids_missing)

        return from_networkx(nx.subgraph(G, remaining_nodes))

    def __repr__(self):
        return f'{self.__class__.__name__}({self.hidden_portion})'


def FillGraph(impaired_data, original_data, pred_missing, pred_feats,
              num_pred):
    r"""Mend the original graph with the generated neighbors.

    For every node of the impaired graph whose rounded predicted number of
    missing neighbors is positive, up to ``num_pred`` new nodes carrying the
    predicted features are appended and linked from the matching node of the
    original graph (matched through ``index_orig``).

    Arguments:
        impaired_data (PyG.Data): impaired graph, provides ``index_orig``
        original_data (PyG.Data): raw graph, moved to CPU before mending
        pred_missing (torch.Tensor): predicted number of missing neighbors
            of every node of the impaired graph
        pred_feats (torch.Tensor): predicted features, reshaped here to
            (-1, num_pred, num_node_features)
        num_pred (int): maximum number of generated neighbors per node
    :returns:
        filled_data : graph with attributes x, edge_index, y, train_idx,
        valid_idx and test_idx; labels of the generated nodes are 0
    :rtype:
        PyG.Data
    """
    # Mend the original data
    original_data = original_data.detach().cpu()
    pred_missing = pred_missing.detach().cpu().numpy()
    pred_feats = pred_feats.detach().cpu().reshape(
        (-1, num_pred, original_data.num_node_features))

    # Collect the pieces and concatenate once after the loop instead of
    # re-copying the growing tensors on every iteration.
    feature_chunks = [original_data.x]
    edge_chunks = [original_data.edge_index.T]

    start_id = original_data.num_nodes
    for node in range(len(pred_missing)):
        num_fill_node = np.around(pred_missing[node]).astype(np.int32).item()
        if num_fill_node <= 0:
            continue
        num_new = min(num_pred, num_fill_node)

        # Use a plain Python scalar for the lookup: ``impaired_data`` may
        # live on another device than the CPU copy of ``original_data``.
        org_id = impaired_data.index_orig[node]
        if torch.is_tensor(org_id):
            org_id = org_id.item()
        org_node = torch.where(
            original_data.index_orig == org_id)[0].item()

        # Rows of [org_node, new_id]; transposed to edge_index layout below.
        new_ids = torch.arange(start_id,
                               start_id + num_new,
                               dtype=torch.int64)
        edge_chunks.append(
            torch.stack((torch.full_like(new_ids, org_node), new_ids),
                        dim=1))
        feature_chunks.append(pred_feats[node][:num_new])
        start_id += num_new

    if len(feature_chunks) > 1:
        new_features = torch.vstack(feature_chunks)
        new_edge_index = torch.vstack(edge_chunks)
    else:
        # Nothing was generated: keep the original tensors untouched.
        new_features = feature_chunks[0]
        new_edge_index = edge_chunks[0]

    new_y = torch.zeros(new_features.shape[0], dtype=torch.int64)
    new_y[:original_data.num_nodes] = original_data.y
    filled_data = Data(
        x=new_features,
        edge_index=new_edge_index.T,
        train_idx=torch.where(original_data.train_mask == True)[0],
        valid_idx=torch.where(original_data.val_mask == True)[0],
        test_idx=torch.where(original_data.test_mask == True)[0],
        y=new_y,
    )
    return filled_data


@torch.no_grad()
def GraphMender(model, impaired_data, original_data):
    r"""Run the generation model on the impaired graph and mend the raw graph

    The model is moved to the device of ``impaired_data.x`` and called once
    without gradient tracking; its train/eval mode is left as the caller set
    it. The number of generated neighbors per node is bounded by
    ``global_cfg.fedsageplus.num_pred``.

    Arguments:
        model (torch.nn.module): trained generation model, must return
            three values of which the first two are used
        impaired_data (PyG.Data): impaired graph
        original_data (PyG.Data): raw graph
    :returns:
        filled_data : Graph after Data Enhancement
    :rtype:
        PyG.data
    """
    generator = model.to(impaired_data.x.device)
    # Third output of the generator is not needed for mending.
    missing_counts, generated_feats, _ = generator(impaired_data)

    max_new_neighbors = global_cfg.fedsageplus.num_pred
    return FillGraph(impaired_data, original_data, missing_counts,
                     generated_feats, max_new_neighbors)