from laplacian import *
from ogb.graphproppred import PygGraphPropPredDataset
from torch_geometric.datasets import TUDataset

import os.path as osp
import torch

import warnings

# Turn RuntimeWarning (and its subclasses) into errors so they are raised as exceptions
warnings.simplefilter("error", RuntimeWarning)


def create_CC(edge_index, num_of_nodes):
    """Build the set of cells (nodes and edges) of a graph as index tuples.

    Every node ``i`` in ``range(num_of_nodes)`` contributes the 1-tuple
    ``(i,)``; every column ``c`` of ``edge_index`` contributes the 2-tuple
    ``(edge_index[0, c], edge_index[1, c])`` converted to Python scalars.

    Args:
        edge_index: 2-D array indexed as ``edge_index[row, column]`` whose
            elements provide ``.item()``; rows 0 and 1 hold the two
            endpoints of each edge.
        num_of_nodes: Number of nodes in the graph.

    Returns:
        A ``set`` of tuples; duplicate edges collapse into one entry.
    """
    node_cells = {(node,) for node in range(num_of_nodes)}
    edge_cells = {
        (edge_index[0, col].item(), edge_index[1, col].item())
        for col in range(edge_index.shape[1])
    }
    return node_cells | edge_cells

def generate_HKS_for_dataset_single_thread(dataset, dir_path):
    """Sequentially append heat-kernel signatures to node features and save them.

    For each sample the cell set is built with ``create_CC``, turned into a
    Laplacian by ``real_laplacian_matrix``, and ``heat_kernel`` is evaluated
    at 100 evenly spaced timesteps from 1 to 100. The output of
    ``eval_kernels`` is converted to a tensor and concatenated to the
    sample's ``x`` along dim 1. The result is written with ``torch.save`` to
    ``{dir_path}/{counter}.pt`` where ``counter`` counts saved samples from 0.

    Processing stops at the first sample that raises: the error is printed
    and no further samples are handled.

    Args:
        dataset: Iterable of samples supporting lookup of the keys ``'x'``,
            ``'edge_index'``, ``'num_nodes'`` and ``'y'``.
        dir_path: Directory the ``.pt`` files are written to.
    """
    timesteps = np.linspace(1, 100, num=100)
    counter = 0
    for index, data in enumerate(dataset):
        try:
            features = data['x']
            cells = create_CC(data['edge_index'], data['num_nodes'])
            laplacian = real_laplacian_matrix(cells, data['num_nodes'])
            kernels = [heat_kernel(laplacian, t) for t in timesteps]
            signatures = torch.tensor(eval_kernels(kernels))
            payload = {
                "x": torch.cat([features, signatures], dim=1),
                "y": data['y'],
                'num_nodes': data['num_nodes'],
            }
            torch.save(payload, f"{dir_path}/{counter}.pt")
            counter += 1
            print(f'Counter: {counter}')
        except Exception as e:
            print(f"Error processing index {index}: {e}")
            # Abort on the first failure; nothing follows the loop.
            return
    
def debug_generate_HKS_for_dataset_single_thread(dataset, dir_path, index=4):
    """Run the HKS pipeline on a single sample with no error handling.

    Unlike the looping variant, nothing is caught here, so any exception
    (including RuntimeWarnings escalated to errors) propagates with its full
    traceback, which is the point of this debugging helper.

    Args:
        dataset: Indexable collection of samples supporting lookup of the
            keys ``'x'``, ``'edge_index'``, ``'num_nodes'`` and ``'y'``.
        dir_path: Directory the output file is written to.
        index: Position of the sample to process. Defaults to 4, the sample
            that was previously hard-coded.

    Side effects:
        Writes ``{dir_path}/0.pt`` via ``torch.save`` and prints
        ``Counter: 1``.
    """
    timesteps = np.linspace(1, 100, num=100)
    counter = 0

    data = dataset[index]
    x = data['x']
    CC = create_CC(data['edge_index'], data['num_nodes'])
    laplacian1 = real_laplacian_matrix(CC, data['num_nodes'])
    hk1 = [heat_kernel(laplacian1, t) for t in timesteps]
    signatures_1 = torch.tensor(eval_kernels(hk1))
    x = torch.cat([x, signatures_1], dim=1)
    torch.save({"x": x, "y": data['y'], 'num_nodes': data['num_nodes']}, f"{dir_path}/{counter}.pt")
    counter += 1
    print(f'Counter: {counter}')
      


import numpy as np
import torch
from concurrent.futures import ThreadPoolExecutor, as_completed

def process_data(index, data, timesteps, dir_path, counter_lock):
    """Compute heat-kernel signatures for one sample and save the result.

    The number of nodes is taken from ``x.shape[0]``. The cell set from
    ``create_CC`` is passed to ``real_laplacian_matrix``; ``heat_kernel`` is
    evaluated for every value in ``timesteps``; the output of
    ``eval_kernels`` is converted to a tensor and concatenated to ``x``
    along dim 1. The augmented sample is written with ``torch.save`` to
    ``{dir_path}/{index}.pt``.

    Any exception raised along the way (including the NaN check) is caught,
    reported on stdout together with ``index``, and the sample is skipped.

    Args:
        index: Position of the sample in the dataset; also used as the
            output file name.
        data: Sample supporting lookup of the keys ``'x'``,
            ``'edge_index'`` and ``'y'``.
        timesteps: Iterable of diffusion times passed to ``heat_kernel``.
        dir_path: Directory the ``.pt`` file is written to.
        counter_lock: Lock (context manager) held while saving and printing.

    Returns:
        None, on success as well as on failure.
    """
    try:
        x = data['x']
        num_nodes = x.shape[0]
        CC = create_CC(data['edge_index'], num_nodes)
        laplacian1 = real_laplacian_matrix(CC, num_nodes)
        hk1 = [heat_kernel(laplacian1, t) for t in timesteps]
        signatures_1 = torch.tensor(eval_kernels(hk1))
        # Concatenation along dim 1 leaves the node count (dim 0) unchanged.
        x = torch.cat([x, signatures_1], dim=1)
        # Explicit raise instead of `assert`, so the check survives `python -O`.
        if torch.isnan(x).any():
            raise ValueError("Tensor contains NaN values!")
        # The file name derives from `index`, so no counter is shared between
        # workers; the lock only serializes the save and the progress output.
        with counter_lock:
            torch.save({"x": x, "y": data['y'], 'num_nodes': num_nodes}, f"{dir_path}/{index}.pt")
            print(f'Counter: {index}')
    except Exception as e:
        # Best-effort: report the failing sample and let the other workers continue.
        print(f"Error processing index {index}: {e}")
    return None

def generate_HKS_for_dataset(dataset, dir_path):
    """Process every sample of ``dataset`` with ``process_data`` on a thread pool.

    The timesteps handed to each worker are the natural logarithm of 100
    evenly spaced values from 1 to 100. One task is submitted per sample,
    all sharing a single lock, and the function returns once every task has
    finished.

    Args:
        dataset: Iterable of samples; each is forwarded to ``process_data``
            together with its enumeration index.
        dir_path: Directory forwarded to ``process_data`` for the output.
    """
    from threading import Lock

    timesteps = np.log(np.linspace(1, 100, num=100))
    # Single lock shared by all workers.
    shared_lock = Lock()

    with ThreadPoolExecutor() as pool:
        # Queue one task per sample.
        pending = [
            pool.submit(process_data, index, data, timesteps, dir_path, shared_lock)
            for index, data in enumerate(dataset)
        ]

        # Block until every task is done; result() re-raises anything a
        # worker did not handle itself.
        for finished in as_completed(pending):
            finished.result()

# Example usage
# generate_HKS_for_dataset(dataset, 'path/to/save')


# Directory containing this script.
# NOTE(review): `path` is not used by any statement visible in this file.
path = osp.dirname(osp.realpath(__file__))
# Earlier variant that loaded PROTEINS through torch_geometric's TUDataset;
# kept commented out for reference.
# dataset = TUDataset(
#                 root=osp.realpath(__file__)[:-25],
#                 use_node_attr=True,
#                 cleaned=True,
#                 name='PROTEINS',
#                 transform=None,
#                 pre_transform = None
#             )
# print(len(dataset))
from tu_datasets import PTG_LegacyTUDataset


# Load PROTEINS through the project-local legacy loader. The root is this
# file's resolved path with its last 25 characters removed -- presumably
# this strips a fixed "<subdir>/<script name>" suffix; TODO confirm, since
# the slice depends on the exact length of that suffix.
dataset = PTG_LegacyTUDataset(
    root=osp.realpath(__file__)[:-25],
    # use_node_attr=self.has_node_attributes,
    # cleaned=cleaned,
    name='PROTEINS',
    transform=None
)
# self.node_attributes = dataset[0].x.shape[1]
# debug_generate_HKS_for_dataset_single_thread(dataset, '/scratch/work/krahnm1/PhDProjects/DeepHKS/TopNets/RePHINE/datasets/laplacemolhiv_test')
print(len(dataset))

# Runs whenever this module is executed or imported (there is no
# `__main__` guard): writes one .pt file per sample into the hard-coded
# output directory below.
generate_HKS_for_dataset(dataset, '/scratch/work/krahnm1/PhDProjects/DeepHKS/TopNets/RePHINE/datasets/protein_debugged')



# try:
#     # Your code here that produces the warnings
#     debug_generate_HKS_for_dataset_single_thread(dataset, '/scratch/work/krahnm1/PhDProjects/DeepHKS/TopNets/RePHINE/datasets/laplacemolhiv_test')
# except RuntimeWarning as e:
#     print(f"RuntimeWarning caught: {e}")
#     import traceback
#     traceback.print_exc()  # Print the full traceback