import torch
import networkx as nx
from torch_geometric.utils import from_networkx
from torch_geometric.utils import negative_sampling
import pickle
import random
import numpy as np
from scipy.linalg import expm


def get_pyg_data(args):
    """
    Load a graph dataset and convert it to PyTorch Geometric format.

    Args:
        args: Config object. Reads ``args.graph``, which must be one of
            "community2", "community3", "protein" or "citeseer". For the
            pickled multi-graph datasets it also reads ``args.graph_num``,
            the number of graphs to sample at random (``random.sample``).

    Returns:
        graphs (list[networkx.Graph]): NetworkX graphs with nodes relabeled
            to consecutive integers 0..n-1, in each graph's node order.
        data_list (list[torch_geometric.data.Data]): One PyG Data object per
            graph, each given an all-ones node feature matrix ``x`` of shape
            (num_nodes, 1).

    Raises:
        ValueError: If ``args.graph`` is not a supported dataset name.
    """
    graph = args.graph
    data_dir = "./data"

    # Datasets stored as a pickled collection of NetworkX graphs.
    pickled_files = {
        "community2": "two_block_com.pkl",
        "community3": "three_block_com.pkl",
        "protein": "protein.pkl",
    }

    if graph in pickled_files:
        file = data_dir + "/" + pickled_files[graph]
        # NOTE: pickle.load can execute arbitrary code; only load trusted files.
        with open(file, "rb") as f:
            graphs = pickle.load(f)
        graphs = random.sample(graphs, args.graph_num)
    elif graph == "citeseer":
        file = data_dir + "/citeseer/cs.edgelist"
        graphs = [nx.read_edgelist(file, nodetype=int)]
    else:
        # Previously fell through and crashed with UnboundLocalError at return.
        raise ValueError(
            f"Unknown graph dataset {graph!r}; expected one of "
            f"{sorted([*pickled_files, 'citeseer'])}"
        )

    # Relabel nodes to consecutive integers 0..n-1 in node iteration order.
    graphs = [
        nx.relabel_nodes(g, {old: new for new, old in enumerate(g.nodes())})
        for g in graphs
    ]
    data_list = [from_networkx(g) for g in graphs]
    for data in data_list:
        # Constant (featureless) node inputs.
        data.x = torch.ones((data.num_nodes, 1))
    return graphs, data_list


class CustomDataset(torch.utils.data.Dataset):
    """Edge-classification dataset of positive edges plus sampled negatives.

    Each item is an ``(edge, label)`` pair. ``edge`` is one column of the
    combined edge index. ``label`` is 1.0 for an edge taken from
    ``edge_index`` and 0.0 for an edge produced by ``negative_sampling``.

    Args:
        edge_index: Positive edges, laid out along dim 1 (presumably a
            (2, E) index tensor — TODO confirm against callers).
        num_nodes: Node count forwarded to ``negative_sampling``.
        neg_ratio: Negatives requested per positive edge; the request size
            is ``int(E * neg_ratio)``.
    """

    def __init__(self, edge_index, num_nodes, neg_ratio=1):
        num_pos = edge_index.size(1)

        self.num_nodes = num_nodes
        self.edge_index = edge_index
        self.pos_edge_index = edge_index
        self.neg_edge_index = negative_sampling(
            edge_index=edge_index,
            num_nodes=num_nodes,
            num_neg_samples=int(num_pos * neg_ratio),
        )

        # Label lengths follow the stored tensors, so they match even if the
        # sampler returns a different number of negatives than requested.
        pos_labels = torch.ones(num_pos, dtype=torch.float)
        neg_labels = torch.zeros(self.neg_edge_index.size(1), dtype=torch.float)

        self.edges = torch.cat((self.pos_edge_index, self.neg_edge_index), dim=1)
        self.labels = torch.cat((pos_labels, neg_labels), dim=0)

    def __len__(self):
        # One sample per column of the combined edge tensor.
        return self.edges.shape[1]

    def __getitem__(self, idx):
        return self.edges[:, idx], self.labels[idx]
    

def heat_diffusion(L, num_signals, diffusion_rate=0.1, sample_from="normal"):
    """
    Simulate heat diffusion with the kernel expm(-diffusion_rate * L).

    A random initial heat vector h_0 is drawn. It is then propagated
    repeatedly with K = expm(-diffusion_rate * L), so h_i = K @ h_{i-1}.

    Args:
        L: Square (n, n) matrix, e.g. a graph Laplacian. Dense arrays are
            used as-is. Objects with a ``toarray`` method (SciPy sparse
            matrices) are densified, because ``scipy.linalg.expm`` expects
            dense input.
        num_signals: Number of snapshots to return, including h_0.
        diffusion_rate: Step size multiplying L inside the matrix exponential.
        sample_from: Distribution of the initial heat (uses ``np.random``):
            "normal"   -- i.i.d. standard normal;
            "discrete" -- each node 0 with prob. 0.8, 100 with prob. 0.2;
            "bimodal"  -- n//2 draws from N(0, 0.2^2), the rest from
                          N(1, 0.2^2), shuffled together.

    Returns:
        np.ndarray of shape (num_signals, n); row i is h_i.

    Raises:
        ValueError: If ``sample_from`` is not recognised or ``num_signals``
            is negative.
    """
    if num_signals < 0:
        raise ValueError(f"num_signals must be non-negative, got {num_signals}")
    if hasattr(L, "toarray"):
        L = L.toarray()
    n = L.shape[0]

    if sample_from == "normal":
        initial_heat = np.random.randn(n)
    elif sample_from == "discrete":
        initial_heat = np.random.choice([0, 100], size=n, p=[0.8, 0.2])
    elif sample_from == "bimodal":
        mean1, std1 = 0, 0.2
        mean2, std2 = 1, 0.2
        samples1 = np.random.normal(mean1, std1, n // 2)
        samples2 = np.random.normal(mean2, std2, n - len(samples1))
        initial_heat = np.concatenate([samples1, samples2])
        np.random.shuffle(initial_heat)
    else:
        # Previously left initial_heat unbound -> UnboundLocalError.
        raise ValueError(
            f"Unknown sample_from {sample_from!r}; "
            "expected 'normal', 'discrete' or 'bimodal'"
        )

    heats = np.zeros((num_signals, n))
    if num_signals == 0:
        # Nothing to record; previously heats[0] raised IndexError here.
        return heats
    heats[0] = initial_heat
    diffusion_matrix = expm(-diffusion_rate * L)
    for i in range(1, num_signals):
        heats[i] = diffusion_matrix @ heats[i - 1]
    return heats