import numpy as np
from src.utils import config
from stellargraph.core import StellarGraph
import pandas as pd
import stellargraph as sg
import torch
import scipy.sparse as sp
from torch_geometric.data import Data

def accuracy(pred, true):
    """
    Fraction of rows whose arg-max position matches between pred and true.

    Parameters:
    - pred: Sequence of per-sample score vectors
    - true: Sequence of per-sample target vectors (e.g. one-hot rows)

    Returns:
    - Number of matching rows divided by len(pred), as a float
    """
    # zip stops at the shorter input, but the denominator is always len(pred).
    hits = sum(
        1
        for pred_row, true_row in zip(pred, true)
        if np.argmax(pred_row) == np.argmax(true_row)
    )
    return float(hits) / len(pred)


def accuracy_missing(pred, true):
    """
    Fraction of positions where pred and true agree after int() truncation.

    Parameters:
    - pred: Sequence of predicted scalar values
    - true: Sequence of target scalar values

    Returns:
    - Number of matching positions divided by len(pred), as a float
    """
    # Both sides go through int(), so e.g. 1.9 and 1 count as a match.
    hits = sum(
        1
        for pred_val, true_val in zip(pred, true)
        if int(pred_val) == int(true_val)
    )
    return float(hits) / len(pred)


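# Despite the parameter name `impaired_graph`, this function does not modify that graph.
# It builds a NEW graph from the nodes, features and edges of `original_graph` and adds
# the generated nodes and their edges to it; `impaired_graph` only supplies the ids of
# the nodes the generated neighbours are attached to.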
def fill_graph(impaired_graph: StellarGraph, original_graph: StellarGraph, missing, new_feats, feat_shape):
    """
    Build a new StellarGraph from original_graph plus generated neighbour nodes.

    Parameters:
    - impaired_graph: Graph whose node order indexes `missing` and `new_feats`;
      its node ids are used as the source of every added edge
    - original_graph: Graph whose nodes, node features and edges are copied
      into the result
    - missing: Indexable of per-node counts; entry i says how many nodes to
      attach to the i-th node of impaired_graph (capped at config.num_pred)
    - new_feats: NumPy array or torch tensor of generated features, reshaped
      here to (-1, config.num_pred, feat_shape)
    - feat_shape: Feature dimension

    Returns:
    - fill_nodes: sg.IndexedArray of all node features indexed by node id
    - fill_G: sg.StellarGraph built from fill_nodes and the combined edges
    """
    new_feats = new_feats.reshape((-1, config.num_pred, feat_shape))
    # Convert a tensor to NumPy once, up front. `.cpu()` is a no-op for CPU
    # tensors, so this is safe whatever device the features live on.
    if not isinstance(new_feats, np.ndarray):
        new_feats = new_feats.detach().cpu().numpy()

    node_ids = list(original_graph.nodes())
    anchor_ids = list(impaired_graph.nodes())

    org_feats = original_graph.node_features()
    node_feats = [
        np.asarray(org_feats[i].reshape(-1)) for i in range(len(node_ids))
    ]

    org_edges = np.copy(original_graph.edges())
    edge_sources = [edge[0] for edge in org_edges]
    edge_targets = [edge[1] for edge in org_edges]

    # Generated nodes get negative ids (-1, -2, ...); presumably chosen so they
    # cannot collide with existing ids -- TODO confirm ids are non-negative.
    next_id = -1
    for new_i in range(len(missing)):
        if int(missing[new_i]) <= 0:
            continue
        num_new = min(config.num_pred, int(missing[new_i]))
        for i_pred in range(num_new):
            new_id = next_id - i_pred
            node_ids.append(new_id)
            node_feats.append(np.asarray(new_feats[new_i][i_pred].reshape(-1)))
            edge_sources.append(anchor_ids[new_i])
            edge_targets.append(new_id)
        next_id -= num_new

    fill_edges = pd.DataFrame()
    fill_edges['source'] = np.asarray(edge_sources).reshape((-1))
    fill_edges['target'] = np.asarray(edge_targets).reshape((-1))
    fill_node_feats_np = np.asarray(node_feats).reshape((-1, feat_shape))
    fill_node_ids_np = np.asarray(node_ids).reshape(-1)

    fill_nodes = sg.IndexedArray(fill_node_feats_np, fill_node_ids_np)
    fill_G = sg.StellarGraph(nodes=fill_nodes, edges=fill_edges)
    return fill_nodes, fill_G




def get_adj(edges, node_len):
    """
    Build a symmetric, self-looped, row-normalized sparse adjacency tensor.

    Parameters:
    - edges: Edge list indexed as edges[:, 0] (row) and edges[:, 1] (column);
      `.cpu()` is called on it when config.cuda is set, so presumably a torch
      tensor of shape (num_edges, 2) -- TODO confirm
    - node_len: Number of nodes; the adjacency is (node_len, node_len)

    Returns:
    - Torch sparse COO tensor, moved to CUDA when config.cuda is set
    """
    if config.cuda:
        edges = edges.cpu()

    unit_weights = np.ones(edges.shape[0])
    adj = sp.coo_matrix(
        (unit_weights, (edges[:, 0], edges[:, 1])),
        shape=(node_len, node_len),
        dtype=np.float32,
    )

    # Symmetrize: wherever the transposed entry is larger, swap it in, so the
    # result holds the element-wise maximum of adj and adj.T.
    transposed_is_larger = adj.T > adj
    adj = adj + adj.T.multiply(transposed_is_larger) - adj.multiply(transposed_is_larger)

    # Add self-loops, row-normalize, then convert to a torch sparse tensor.
    with_self_loops = adj + sp.eye(adj.shape[0])
    adj = sparse_mx_to_torch_sparse_tensor(normalize(with_self_loops))
    return adj.cuda() if config.cuda else adj


def normalize(mx):
    """
    Row-normalize a matrix so that every non-empty row sums to 1.

    Parameters:
    - mx: Matrix supporting `.sum(1)` (e.g. a scipy sparse matrix)

    Returns:
    - diag(1 / rowsum) @ mx; rows whose sum is zero are left as all zeros
    """
    rowsum = np.asarray(mx.sum(1)).flatten()
    # Integer arrays cannot be raised to a negative power, so promote them.
    # Floating inputs keep their dtype.
    if not np.issubdtype(rowsum.dtype, np.inexact):
        rowsum = rowsum.astype(np.float64)
    # Zero row sums yield inf here on purpose; they are zeroed just below, so
    # the divide-by-zero warning is silenced.
    with np.errstate(divide="ignore"):
        r_inv = np.power(rowsum, -1)
    r_inv[np.isinf(r_inv)] = 0.
    r_mat_inv = sp.diags(r_inv)
    return r_mat_inv.dot(mx)


def sparse_mx_to_torch_sparse_tensor(sparse_mx):
    """
    Convert a scipy sparse matrix into a float32 torch sparse COO tensor.

    Parameters:
    - sparse_mx: scipy sparse matrix (any format providing `.tocoo()`)

    Returns:
    - torch sparse COO tensor with int64 indices, float32 values and the
      same shape as sparse_mx
    """
    coo = sparse_mx.tocoo().astype(np.float32)
    # Row 0 holds the row indices and row 1 the column indices.
    index_array = np.vstack((coo.row, coo.col)).astype(np.int64)
    return torch.sparse_coo_tensor(
        torch.from_numpy(index_array),
        torch.from_numpy(coo.data),
        torch.Size(coo.shape),
    )


def fill_graph_pytorch(impaired_graph_data, missing, new_feats, feat_shape):
    """
    PyTorch version of fill_graph that maintains gradient flow

    New node features are taken from `new_feats` by indexing and concatenation
    only, so gradients can flow back into `new_feats`. Every tensor created
    here is placed on the device of the tensor it extends; nothing is moved to
    CUDA unconditionally, so the function also works on CPU-only machines.

    Parameters:
    - impaired_graph_data: A PyTorch Data object with x (node features) and edge_index attributes
    - missing: Tensor indicating how many nodes to add per existing node
      (capped at config.num_pred per node)
    - new_feats: Tensor of features for new nodes, reshaped here to
      (-1, config.num_pred, feat_shape)
    - feat_shape: Feature dimension

    Returns:
    - filled_graph: A PyTorch Data object with added nodes and edges. It
      carries `added_mask` (True for added nodes), and `train_mask`,
      `test_mask` and `y` extended to the new node count when present on the
      input (added nodes are False in both masks and labelled -1); absent
      attributes are passed to Data as None.
    """
    new_feats = new_feats.reshape(-1, config.num_pred, feat_shape)

    # Get original graph data
    original_x = impaired_graph_data.x
    original_edge_index = impaired_graph_data.edge_index
    num_original = original_x.size(0)

    feat_chunks = []
    new_edges_source = []
    new_edges_target = []

    # Ids for new nodes continue from the last original node
    next_id = num_original

    for node_idx in range(len(missing)):
        if missing[node_idx] > 0:
            # How many nodes to add for this original node
            num_to_add = min(config.num_pred, int(missing[node_idx]))
            if num_to_add <= 0:
                continue

            # One slice per node keeps the autograd graph small while still
            # differentiating through new_feats.
            feat_chunks.append(new_feats[node_idx, :num_to_add])

            for new_id in range(next_id, next_id + num_to_add):
                # Both directions are added, so the new edge is undirected.
                new_edges_source.extend((node_idx, new_id))
                new_edges_target.extend((new_id, node_idx))

            next_id += num_to_add

    total_nodes = next_id

    # Mask identifying added nodes
    added_mask = torch.zeros(total_nodes, dtype=torch.bool, device=original_x.device)
    added_mask[num_original:] = True

    if feat_chunks:
        # Combine original and new features
        combined_x = torch.cat([original_x] + feat_chunks, dim=0)

        # Build the new edges on the same device as the existing ones
        new_edges_tensor = torch.tensor(
            [new_edges_source, new_edges_target],
            dtype=torch.long,
            device=original_edge_index.device,
        )
        combined_edge_index = torch.cat([original_edge_index, new_edges_tensor], dim=1)
    else:
        # No new nodes, keep original graph
        combined_x = original_x
        combined_edge_index = original_edge_index

    def _extend_rows(values, fill_value, dtype):
        # Pad a per-node tensor to total_nodes rows, keeping trailing dims and device.
        if values is None:
            return None
        padded = torch.full(
            (total_nodes,) + tuple(values.shape[1:]),
            fill_value,
            dtype=dtype,
            device=values.device,
        )
        padded[:num_original] = values
        return padded

    train_mask = getattr(impaired_graph_data, 'train_mask', None)
    test_mask = getattr(impaired_graph_data, 'test_mask', None)
    y = getattr(impaired_graph_data, 'y', None)

    # New nodes are excluded from the train/test sets and get the label -1.
    filled_graph = Data(
        x=combined_x,
        edge_index=combined_edge_index,
        train_mask=_extend_rows(train_mask, False, torch.bool),
        test_mask=_extend_rows(test_mask, False, torch.bool),
        y=_extend_rows(y, -1, y.dtype) if y is not None else None,
        added_mask=added_mask,  # Add the mask to identify added nodes
    )

    return filled_graph