from genericpath import exists
import os
import pandas as pd
import numpy as np
import random
import networkx as nx
import matplotlib.pyplot as plt
import pandas as pd
import networkit as nk
import dgl
from torch_geometric.datasets import Planetoid

# Create PyTorch Geometric dataset for link prediction
import torch
from torch_geometric.data import Data
from torch_geometric.utils import train_test_split_edges
from torch_geometric.transforms import RandomLinkSplit
from torch_geometric.utils import to_networkx
from torch_geometric.utils import to_dgl
import torch.nn.functional as F
import concurrent.futures
import threading

# Set random seeds for reproducibility
def set_random_seed(seed=42):
    """Seed every random number generator used in this module.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (CPU, the current CUDA device and all CUDA devices), then puts cuDNN into
    deterministic, non-benchmarking mode so repeated runs match.

    Args:
        seed: integer seed shared by all generators (default 42).
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

    # Disable cuDNN autotuning: picking kernels by benchmark is non-deterministic.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False



def build_edge_label(pos_edges, neg_edges):
    """Join positive and negative edges and build matching binary labels.

    Args:
        pos_edges: ``[2, P]`` edge index of positive (existing) edges.
        neg_edges: ``[2, N]`` edge index of negative (non-)edges.

    Returns:
        ``(edge_index, labels)``. ``edge_index`` is ``[2, P + N]`` with the
        positives first. ``labels`` is a float tensor of ``P`` ones followed
        by ``N`` zeros.
    """
    n_pos = pos_edges.shape[1]
    n_neg = neg_edges.shape[1]
    label_vec = torch.cat((torch.ones(n_pos), torch.zeros(n_neg)))
    combined = torch.cat((pos_edges, neg_edges), dim=1)
    return combined, label_vec


def sampling_neg_edges(data, g, method='random', wl=50, p=10, q=0.01):
    """Replace the negative half of an edge-label index with walk-based negatives.

    ``data`` is assumed to be a ``[2, 2k]`` edge-label index whose first ``k``
    columns are positive edges and whose remaining columns are negatives. That
    is the layout the callers in this file get from ``RandomLinkSplit`` with
    ``neg_sampling_ratio=1.0``. The existing negatives are discarded.

    For each positive edge ``(u, v)``, a node2vec random walk of length ``wl``
    starts at ``u`` on ``g``. The last node the walk reaches becomes the
    negative partner of ``u``. The new negative pairs are shuffled. With the
    default ``p``/``q`` (rarely return, strongly prefer moving outward), the
    walk tends to drift away from ``u``, which yields "hard" but distant
    negatives.

    NOTE(review): an endpoint is not checked against the real edge set. A walk
    can end on a true neighbour of ``u``, giving a false negative. A walk from a
    node with no successors ends on ``u`` itself, giving a self-loop.

    Args:
        data: ``[2, 2k]`` long tensor, positives first.
        g: DGL graph the walks run on.
        method: unused; kept so existing call sites stay valid.
        wl: number of walk steps.
        p: node2vec return parameter (default matches the previous hard-coded 10).
        q: node2vec in-out parameter (default matches the previous hard-coded 0.01).

    Returns:
        ``[2, 2k]`` tensor: the original positives followed by the new negatives.
        The column count is unchanged, so any existing ``edge_label`` stays aligned.
    """
    pos_idx = data.shape[1] // 2
    pos_edges = data[:, :pos_idx]
    src = pos_edges[0]

    walks = dgl.sampling.node2vec_random_walk(g, src, p=p, q=q, walk_length=wl)

    # DGL pads walks that terminate early (no successors) with -1. Reading
    # walks[:, -1] would then return -1, which torch indexing silently treats
    # as "the last node". Use the last node actually visited instead. Padding
    # is contiguous at the tail and column 0 is always the seed, so the count
    # of valid entries minus one is that node's position.
    last_pos = (walks >= 0).sum(dim=1) - 1
    neg_dst = walks[torch.arange(walks.shape[0], device=walks.device), last_pos]

    neg_edges = torch.stack([src, neg_dst], dim=0)
    neg_edges = neg_edges[:, torch.randperm(neg_edges.shape[1])]

    return torch.cat([pos_edges, neg_edges], dim=1)



def get_binarized_network_v1(ratio=0.1, path='cropped_aligned_stack'):
    """Load the binarized stack network and build link-prediction splits.

    Reads ``{path}/Binarized/network_edgelist.csv`` and
    ``{path}/Binarized/network_node_positions.csv`` (one header row each) and
    drops self-loops. It then builds a PyG ``Data`` graph and splits it with
    ``RandomLinkSplit``. Finally, every split's ``edge_label_index`` negatives
    are replaced by walk-based negatives from ``sampling_neg_edges``.

    Args:
        ratio: fraction of edges used for validation and, separately, for test.
        path: dataset root directory (default keeps the previous hard-coded path).

    Returns:
        ``(data, train_data, val_data, test_data, dgl_g)``. ``dgl_g`` is the
        full graph converted with ``to_dgl``.
    """
    edge_list = np.loadtxt(f'{path}/Binarized/network_edgelist.csv',
                           delimiter=',', skiprows=1, ndmin=2)
    node_coor = np.loadtxt(f'{path}/Binarized/network_node_positions.csv',
                           delimiter=',', skiprows=1, ndmin=2)

    # Remove self-loops from edge_list
    edge_list = edge_list[edge_list[:, 0] != edge_list[:, 1]]

    # PyG expects a [2, num_edges] long tensor.
    edge_index = torch.tensor(edge_list[:, :2].T, dtype=torch.long)
    print(f"Total edges in original graph: {edge_index.shape}")

    # NOTE(review): get_nanowire_network reads the same network_node_positions.csv
    # layout but drops column 0 (node_coor[:, 1:]). Here every column becomes a
    # feature. Confirm column 0 is meant to be a feature.
    x = torch.tensor(node_coor, dtype=torch.float)
    data = Data(x=x, edge_index=edge_index)

    transform = RandomLinkSplit(
        num_val=ratio, num_test=ratio,
        is_undirected=True,
        add_negative_train_samples=True,
        neg_sampling_ratio=1.0,
    )
    train_data, val_data, test_data = transform(data)

    print('Node feature shape:', data.x.shape)
    print('Train edges:', train_data.edge_index.shape)

    # Print the size of the graph viewed as undirected (reverse duplicates
    # collapsed). This used to build a networkx graph only for this print.
    # Node count = declared nodes plus any ids referenced only by edges.
    edges = np.unique(np.sort(edge_index.t().numpy(), axis=1), axis=0)
    num_graph_nodes = np.union1d(np.arange(data.num_nodes), edges.ravel()).size
    print(f"Graph has {num_graph_nodes} nodes and {edges.shape[0]} edges")

    dgl_g = to_dgl(data)

    # Same column count in and out, so each split's edge_label stays aligned.
    train_data.edge_label_index = sampling_neg_edges(train_data.edge_label_index, dgl_g, method='hops')
    val_data.edge_label_index = sampling_neg_edges(val_data.edge_label_index, dgl_g, method='hops')
    test_data.edge_label_index = sampling_neg_edges(test_data.edge_label_index, dgl_g, method='hops')

    print('Train edges:', train_data.edge_label_index.shape)
    print('Val edges:', val_data.edge_label_index.shape)
    print('Test edges:', test_data.edge_label_index.shape)

    return data, train_data, val_data, test_data, dgl_g


def get_nanowire_network(path='FinalIsotropicData'):
    """Load every isotropic nanowire network as a PyG ``Data`` graph.

    Each ``{path}/<Rx folder>/<sample>/Binarized/`` directory holds
    ``network_edgelist.csv`` and ``network_node_positions.csv`` (one header row
    each). Self-loops are dropped. Node features are the position columns after
    column 0. The graph label ``y`` is the Rx value of the enclosing folder.
    The Ry values in the folder names (79, 91, 169, 229) are not used.

    Args:
        path: dataset root directory (default keeps the previous hard-coded path).

    Returns:
        List of ``Data`` objects, ordered by folder and then by sorted sample name.
    """
    sub_folder = ['Rx=72Ry=79', 'Rx=95Ry=91', 'Rx=159Ry=169', 'Rx=244Ry=229']
    lbl_x = [72, 95, 159, 244]

    data_list = []
    for folder, rx in zip(sub_folder, lbl_x):
        folder_dir = os.path.join(path, folder)
        # Sort so the order does not depend on the filesystem (os.listdir order
        # is arbitrary). Skip hidden entries such as .DS_Store and stray files.
        subs = sorted(
            name for name in os.listdir(folder_dir)
            if not name.startswith('.')
            and os.path.isdir(os.path.join(folder_dir, name))
        )

        for sub in subs:
            graph_dir = os.path.join(folder_dir, sub, 'Binarized')
            # ndmin=2 keeps single-row files two-dimensional for the slicing below.
            edge_list = np.loadtxt(os.path.join(graph_dir, 'network_edgelist.csv'),
                                   delimiter=',', skiprows=1, ndmin=2)
            node_coor = np.loadtxt(os.path.join(graph_dir, 'network_node_positions.csv'),
                                   delimiter=',', skiprows=1, ndmin=2)

            # Remove self-loops from edge_list
            edge_list = edge_list[edge_list[:, 0] != edge_list[:, 1]]

            edge_index = torch.tensor(edge_list[:, :2].T, dtype=torch.long)
            x = torch.tensor(node_coor[:, 1:], dtype=torch.float)

            data = Data(x=x, edge_index=edge_index)
            data.y = torch.tensor([rx], dtype=torch.float)
            data_list.append(data)

    return data_list


def get_polymer_network(path='Polymer_network_fracture_data/ensemble_dataset'):
    """Load the polymer fracture ensemble as PyG graphs with per-edge targets.

    Each ``{path}/<sample>/`` directory contains:
      * ``edge_list.txt`` (tab-separated, 1 header row): node ids in columns
        0-1 are shifted down by one to 0-based indices; column 2 is an edge id.
      * ``nodes.txt`` (tab-separated, 4 header rows): columns after column 0
        become the node features.
      * ``P_break_x.txt`` (tab-separated, 1 header row): column 0 is an edge
        id, column 1 its break probability.

    ``y`` holds one break probability per edge, in ``edge_list`` order.

    Args:
        path: dataset root directory (default keeps the previous hard-coded path).

    Returns:
        List of ``Data`` objects, ordered by sorted sample directory name.

    Raises:
        ValueError: if an edge id in ``edge_list.txt`` has no entry in
            ``P_break_x.txt``.
    """
    # Sort so the order does not depend on the filesystem; skip hidden entries
    # (e.g. .DS_Store) and stray files.
    subs = sorted(
        name for name in os.listdir(path)
        if not name.startswith('.') and os.path.isdir(os.path.join(path, name))
    )

    data_list = []
    for sub in subs:
        sample_dir = os.path.join(path, sub)
        # ndmin=2 keeps single-row files two-dimensional for the slicing below.
        edge_list = np.loadtxt(os.path.join(sample_dir, 'edge_list.txt'),
                               delimiter='\t', skiprows=1, ndmin=2)
        node_coor = np.loadtxt(os.path.join(sample_dir, 'nodes.txt'),
                               delimiter='\t', skiprows=4, ndmin=2)
        lbl = np.loadtxt(os.path.join(sample_dir, 'P_break_x.txt'),
                         delimiter='\t', skiprows=1, ndmin=2)[:, :2]

        # Shift node ids to 0-based indices (presumably 1-based in the file).
        edge_list[:, :2] -= 1

        edge_index = torch.tensor(edge_list[:, :2].T, dtype=torch.long)
        x = torch.tensor(node_coor[:, 1:], dtype=torch.float)
        data = Data(x=x, edge_index=edge_index)

        # One target per edge, looked up by edge id. This replaces a per-edge
        # np.where scan (O(E^2)) that assigned 1-element arrays into scalar
        # slots (deprecated since NumPy 1.25). The old code also sized the
        # result by the label-file length instead of the edge count.
        break_prob = {int(e_id): p for e_id, p in lbl}
        try:
            p_break = np.array(
                [break_prob[int(e_id)] for e_id in edge_list[:, 2]], dtype=float
            )
        except KeyError as err:
            raise ValueError(
                f'{sub}: no break probability for edge id {err.args[0]}'
            ) from err

        data.y = torch.tensor(p_break, dtype=torch.float)
        data_list.append(data)

    return data_list

