import os
import h5py
import torch
import numpy as np
import pdb
from tqdm import tqdm
import networkx as nx
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from matplotlib.colors import Normalize
from scipy.linalg import eigh
import torch_geometric
# from add_positional_encoding import AddLaplacianEigenvectorPE, AddRandomWalkPE
import gc
# from multiprocessing import Pool
import time
from matplotlib.offsetbox import OffsetImage, AnnotationBbox
from matplotlib.collections import LineCollection
import math
import argparse

import warnings
warnings.filterwarnings("ignore", message="Sparse CSR tensor support is in beta state.*")

from torch_geometric.utils import (
    get_laplacian,
    get_self_loop_attr,
    is_torch_sparse_tensor,
    scatter,
    to_edge_index,
    to_scipy_sparse_matrix,
    to_torch_coo_tensor,
    to_torch_csr_tensor,
)
import numpy as np
from torch import Tensor

# Run tensor work on the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
# DEVICE = 'cpu'

def convert_to_networkx_graph(edge_index, edge_attr=None):
    """Build an undirected networkx graph from a [2, num_edges] edge list.

    Args:
        edge_index: Indexable as ``edge_index[0][e]`` / ``edge_index[1][e]``
            (source / target of edge ``e``); must expose ``.shape[1]`` and
            elements with ``.item()``.
        edge_attr: Optional per-edge values; when given, ``edge_attr[e]`` is
            stored as the ``weight`` attribute of edge ``e``.

    Returns:
        ``nx.Graph`` containing only the nodes that appear in some edge.
    """
    graph = nx.Graph()
    weighted = edge_attr is not None
    for col in range(edge_index.shape[1]):
        src = edge_index[0][col].item()
        dst = edge_index[1][col].item()
        if weighted:
            graph.add_edge(src, dst, weight=edge_attr[col].item())
        else:
            graph.add_edge(src, dst)
    return graph

def normalize_coordinates(coords):
    """Min-max scale *coords* along axis 0 into the range [0, 1].

    Args:
        coords: Array-like of numbers; the minimum and maximum are taken
            along axis 0 (per column for a 2-D input).

    Returns:
        numpy array shaped like *coords*. A column whose values are all
        identical is mapped to 0 instead of NaN.
    """
    coords = np.asarray(coords)
    min_coords = coords.min(axis=0)
    span = coords.max(axis=0) - min_coords
    # A constant column has zero span; dividing by it would yield NaN/inf.
    # Its numerator is already 0, so dividing by 1 gives a clean 0.
    safe_span = np.where(span == 0, 1, span)
    return (coords - min_coords) / safe_span

def get_laplacian_eigenvector_pe(edge_index, k=64):
    """Laplacian eigenvalues/eigenvectors of the graph given by *edge_index*.

    Eigenvalues with absolute value below 1e-10 are skipped before the
    leading entries are taken.

    Returns:
        (eigenvalues, eigenvectors).
        * Connected graph: up to ``k`` eigenvalues (1-D) and the matching
          eigenvector columns.
        * Otherwise: eigenvalues is a zero-padded ``(num_components, k)``
          array and eigenvectors a zero-padded ``(num_graph_nodes, k)``
          array, filled component by component.

    NOTE(review): rows follow the node iteration order of the networkx graph
    (edge insertion order), not necessarily the numeric node index - confirm
    this is what downstream consumers expect.
    """
    graph = convert_to_networkx_graph(edge_index)
    components = [graph.subgraph(nodes).copy() for nodes in nx.connected_components(graph)]

    if len(components) == 1:
        eig_vals, eig_vecs = eigh(nx.laplacian_matrix(graph).todense())
        n_zero = np.sum(np.abs(eig_vals) < 1e-10)
        keep = slice(n_zero, k + n_zero)
        return eig_vals[keep], eig_vecs[:, keep]

    # Position of every node in the graph's own iteration order.
    row_of_node = {node: row for row, node in enumerate(graph.nodes())}
    padded_vectors = np.zeros((len(graph.nodes()), k))
    padded_values = np.zeros((len(components), k))

    for comp_idx, comp in enumerate(components):
        eig_vals, eig_vecs = eigh(nx.laplacian_matrix(comp).todense())
        n_zero = np.sum(np.abs(eig_vals) < 1e-10)

        # Never take more columns than the component actually has.
        n_keep = min(k - n_zero, len(eig_vals) - n_zero)
        kept_values = eig_vals[n_zero:n_keep + n_zero]
        kept_vectors = eig_vecs[:, n_zero:n_keep + n_zero]

        for local_row, node in enumerate(comp.nodes()):
            padded_vectors[row_of_node[node], :n_keep] = kept_vectors[local_row, :]
        padded_values[comp_idx, :n_keep] = kept_values

    return np.array(padded_values), padded_vectors


def get_random_walk_eigenvector_pe(edge_index, num_nodes, walk_length=16, add_self_loops=False):
    """Diagonals of the first *walk_length* powers of the row-normalised adjacency.

    Each edge (row, col) gets weight 1 / (number of edges leaving ``row``),
    so entry ``[i, s]`` of the result is the diagonal element ``i`` of the
    ``(s + 1)``-th power of that matrix.

    Args:
        edge_index: numpy array of shape [2, num_edges].
        num_nodes: Number of nodes N (matrix is N x N).
        walk_length: Number of powers to evaluate.
        add_self_loops: Append an (i, i) edge for every node first.

    Returns:
        numpy array with one row per node and one column per power.
    """
    edge_index = torch.from_numpy(edge_index).to(DEVICE)
    N = num_nodes
    if add_self_loops:
        diag = torch.arange(N, device=edge_index.device)
        edge_index = torch.cat([edge_index, torch.stack([diag, diag], dim=0)], dim=1)

    row, col = edge_index
    unit = torch.ones(edge_index.size(1), device=row.device)
    # Edge count per source node (floored at 1), gathered back per edge.
    per_edge_degree = scatter(unit, row, dim_size=N, reduce='sum').clamp(min=1)[row]
    value = 1.0 / per_edge_degree

    if N <= 2_000:
        # Small graphs: a dense matrix is faster to multiply.
        adj = torch.zeros((N, N), device=row.device)
        adj[row, col] = value
        loop_index = torch.arange(N, device=row.device)
    else:
        adj = to_torch_coo_tensor(edge_index, value, size=(num_nodes, num_nodes))

    def get_pe(out: Tensor) -> Tensor:
        # Diagonal of the current power, for either storage layout.
        if not is_torch_sparse_tensor(out):
            return out[loop_index, loop_index]
        return get_self_loop_attr(*to_edge_index(out), num_nodes=N)

    out = adj
    pe_list = [get_pe(out)]
    for _ in range(walk_length - 1):
        out = out @ adj
        pe_list.append(get_pe(out))

    pe = torch.stack(pe_list, dim=-1).cpu().detach().numpy()

    # Drop device tensors eagerly so the memory can be reused right away.
    del edge_index, row, col, unit, per_edge_degree, value, adj, out, pe_list
    gc.collect()
    if DEVICE == 'cuda':
        torch.cuda.empty_cache()

    return pe

def _draw_weighted_graph(graph, ax, pos, node_size, title):
    """Draw *graph* on *ax* with edges coloured by their ``weight`` (viridis, min-max normalised)."""
    weights = [graph[u][v]['weight'] for u, v in graph.edges()]
    if weights:
        norm = Normalize(vmin=min(weights), vmax=max(weights))
        edge_colors = [cm.viridis(norm(w)) for w in weights]
    else:
        # No edges: nothing to normalise, and min()/max() would raise on an empty list.
        edge_colors = 'k'
    nx.draw(graph, pos=pos, node_size=node_size, node_color="gray",
            edge_color=edge_colors, with_labels=False, ax=ax, width=0.7, alpha=0.7)
    ax.set_title(title)
    ax.title.set_fontsize(50)


def visualize_graphs(base_adj_edge_index, base_adj_edge_attr, feat_adj_edge_index, feat_adj_edge_attr, feat_filtered_adj_edge_index, feat_filtered_adj_edge_attr, 
                     node_coordinates, node_size=6, file_name='', visualization_save_path='./', k=8, scale_factor=1.0):
    """Save a 2x2 figure comparing the base, feature and two filtered-feature graphs.

    Panels: base graph (top-left), feature graph (top-right), and the
    filtered-feature graphs stored under keys ``(0, 0)`` and ``(2, 1)``
    (bottom row). Edges are coloured by weight.

    Args:
        base_adj_edge_index, base_adj_edge_attr: Edge list / weights of the base graph.
        feat_adj_edge_index, feat_adj_edge_attr: Edge list / weights of the feature graph.
        feat_filtered_adj_edge_index, feat_filtered_adj_edge_attr: Dicts keyed by
            ``(sigma, freq)``; each value is indexed by ``k``.
        node_coordinates: Dict node -> (x, y) used as drawing positions.
        node_size: Marker size passed to networkx.
        file_name: Prefix of the output file ``<file_name>_graphs.png``.
        visualization_save_path: Output directory.
        k: Which neighbour count to pick from the filtered dicts.
        scale_factor: Multiplier applied to every coordinate.
    """
    # Scale node coordinates
    pos = {node: (scale_factor * x, scale_factor * y) for node, (x, y) in node_coordinates.items()}

    fig, axes = plt.subplots(2, 2, figsize=(16*2, 16*2))

    _draw_weighted_graph(convert_to_networkx_graph(base_adj_edge_index, base_adj_edge_attr),
                         axes[0, 0], pos, node_size, 'Base Adjacency Matrix')
    _draw_weighted_graph(convert_to_networkx_graph(feat_adj_edge_index, feat_adj_edge_attr),
                         axes[0, 1], pos, node_size, 'Feature Adjacency Matrix')

    # Only the two filtered graphs that are drawn get converted; building a
    # networkx graph is an edge-by-edge Python loop, so unused ones are skipped.
    for col, (sigma, freq) in enumerate(((0, 0), (2, 1))):
        filtered_graph = convert_to_networkx_graph(feat_filtered_adj_edge_index[(sigma, freq)][k],
                                                   feat_filtered_adj_edge_attr[(sigma, freq)][k])
        freq_title = 'LPF' if freq == 0 else 'HPF'
        _draw_weighted_graph(filtered_graph, axes[1, col], pos, node_size,
                             f'sigma={(sigma+1)*10}, filter={freq_title}')

    # Save the plot to a file
    plt.savefig(os.path.join(visualization_save_path, f"{file_name}_graphs.png"), dpi=100)
    plt.close(fig)
    del fig, axes
    gc.collect()


def load_node_images_from_h5(file_path):
    """Read the ``10.0x_coords`` and ``10.0x_patches`` datasets of an HDF5 file.

    Returns:
        (coords, patches), both fully loaded into memory.
    """
    with h5py.File(file_path, 'r') as h5:
        coords = h5['10.0x_coords'][:]
        patches = h5['10.0x_patches'][:]
    return coords, patches


def visualize_graph_with_image(edge_index, edge_attr, node_coordinates, node_size=6, file_name='', image_path='./', visualization_save_path='./', note='', scale_factor=1.0):
    """Draw a weighted graph on top of the per-node image patches and save it as PNG.

    Patches and their coordinates are read from ``<image_path>/<file_name>.h5``.
    The function prints a message and returns without drawing if any node has
    no patch or if a node's coordinates differ from its patch coordinates.

    Args:
        edge_index, edge_attr: Edge list and per-edge weights (used for edge colours).
        node_coordinates: Dict node -> (x, y).
        node_size: Base marker size (doubled when drawn).
        file_name: Stem of the HDF5 image file and of the output PNG.
        image_path: Directory holding the HDF5 image file.
        visualization_save_path: Output directory.
        note: Optional tag inserted into the output file name.
        scale_factor: Multiplier applied to all coordinates.
    """
    scaled_node_coordinates = {node: (scale_factor * x, scale_factor * y) for node, (x, y) in node_coordinates.items()}

    image_file_path = os.path.join(image_path, file_name + '.h5')
    node_image_coords, node_image_patches = load_node_images_from_h5(image_file_path)
    # Combine the two coordinate pairs stored per patch: first pair * 2 + second pair.
    node_image_coords = node_image_coords[:, :2] * 2 + node_image_coords[:, 2:4]
    scaled_node_image_coords = {i: (scale_factor * x, scale_factor * y) for i, (x, y) in enumerate(node_image_coords)}
    node_image_patches = {i: patch for i, patch in enumerate(node_image_patches)}

    # Sanity check: every graph node must have a patch at exactly the same position.
    for key in scaled_node_coordinates.keys():
        if key not in scaled_node_image_coords.keys():
            print(f'Node {key} does not have image patch.')
            return
        elif scaled_node_coordinates[key] != scaled_node_image_coords[key]:
            print(f'Node {key} has different coordinates between node and image patch.')
            return

    # Choose the figure's column/row multipliers from the x/y extent ratio.
    max_coords_x, max_coords_y = np.max(list(scaled_node_coordinates.values()), axis=0)
    ratio = max_coords_x/max_coords_y
    num_cols, num_rows = 3, 3
    if ratio > 2:
        num_cols = 4.5
    elif ratio > 1.8:
        num_cols = 4
    elif ratio < 0.7:
        num_rows = 4
    # Smaller layouts get larger patch thumbnails.
    if max(max_coords_x, max_coords_y) < 80:
        zoom_factor = 0.4
    else:
        zoom_factor = 0.2
    print(f'Maximum coordinates: ({max_coords_x}, {max_coords_y}), Ratio: {ratio:.4f}, Zoom Factor: {zoom_factor:.2f}')


    fig, ax = plt.subplots(figsize=(16 * num_cols *1.2, 16 * num_rows *1.4))

    cmap = cm.viridis

    G = convert_to_networkx_graph(edge_index, edge_attr)

    def draw_graph_with_images(G, ax, scaled_node_image_coords, node_image_patches, zoom_factor):
        # Layering is controlled by zorder: patches (1), edges (2), node markers (3).
        norm = Normalize(vmin=min(nx.get_edge_attributes(G, 'weight').values()), vmax=max(nx.get_edge_attributes(G, 'weight').values()))
        edge_colors = [cmap(norm(G[u][v]['weight'])) for u, v in G.edges()]
        
        edge_lines = []
        for u, v in G.edges():
            x1, y1 = scaled_node_coordinates[u]
            x2, y2 = scaled_node_coordinates[v]
            edge_lines.append(((x1, y1), (x2, y2)))
        

        for node, (x, y) in tqdm(scaled_node_coordinates.items()):
            if node in scaled_node_image_coords.keys() and scaled_node_coordinates[node] == scaled_node_image_coords[node]:
                # Reverse the axis order and flip vertically before display
                # (presumably patches are stored channel-first - TODO confirm).
                img_array = np.transpose(node_image_patches[node], (2, 1, 0))
                img_array = np.flipud(img_array)
                img = OffsetImage(img_array, zoom=zoom_factor *2, alpha=0.8 +0.0)
                ab = AnnotationBbox(img, (x, y), frameon=False, zorder=1)
                ax.add_artist(ab)
            else:
                print(f'Node {node} ({x},{y}) does not have image patch.')

        lc = LineCollection(edge_lines, colors=edge_colors, linewidths=3.0, alpha=0.8-0.2, zorder=2)
        ax.add_collection(lc)
        node_x = [pos[0] for node, pos in scaled_node_coordinates.items()]
        node_y = [pos[1] for node, pos in scaled_node_coordinates.items()]
        ax.scatter(node_x, node_y, s=node_size*2, c="red", alpha=0.8-0.2, zorder=3)

    print(f'Drawing {file_name}_{note}_graph_image.png')
    draw_graph_with_images(G, ax, scaled_node_image_coords, node_image_patches, zoom_factor)
    ax.axis('off')
    ax.set_title('Graph with Image Patches')
    ax.title.set_fontsize(50)
    # From here on file_name holds the full output path, not the stem.
    if note == '':
        file_name = os.path.join(visualization_save_path, f"{file_name}_graph_image.png")
    else:
        file_name = os.path.join(visualization_save_path, f"{file_name}_{note}_graph_image.png")
    plt.savefig(file_name, dpi=50)
    plt.close()
    del fig, ax
    gc.collect()


def visualize_positional_encoding(edge_index, node_coordinates, pe, dim=8, node_size=6, file_name='', visualization_save_path='./', pe_type='LapPE', scale_factor=1.0):
    """Plot the first *dim* positional-encoding channels as node colours and save a PNG.

    One panel per channel is laid out on a two-row grid and written to
    ``<visualization_save_path>/<file_name>_<pe_type>.png``. Edges are drawn
    with zero width, so only the coloured nodes are visible.

    Args:
        edge_index: [2, num_edges] edge list defining the drawn nodes.
        node_coordinates: Dict node -> (x, y).
        pe: 2-D array; column ``i`` supplies the node colours of panel ``i``.
        dim: Number of channels (panels) to draw.
        node_size: Marker size; tripled when there are fewer than 1000 nodes.
        file_name, visualization_save_path, pe_type: Output location / naming.
        scale_factor: Multiplier applied to all coordinates.

    NOTE(review): colours are handed to networkx positionally, so the rows of
    *pe* are presumably expected in the graph's node iteration order - confirm
    against the producer of *pe*.
    """
    G = convert_to_networkx_graph(edge_index)

    scaled_node_coordinates = {node: (scale_factor * x, scale_factor * y) for node, (x, y) in node_coordinates.items()}
    if len(node_coordinates) < 1000:
        node_size = node_size * 3

    num_rows = 2
    # Round up so an odd `dim` still fits; the previous `dim // 2` made the
    # last channel index a third, non-existent row.
    num_cols = max(1, (dim + 1) // 2)
    # squeeze=False keeps `axes` 2-D even with a single column.
    fig, axes = plt.subplots(num_rows, num_cols, figsize=(16*num_cols, 16*num_rows), squeeze=False)
    flat_axes = axes.ravel()  # row-major, i.e. panel i sits at (i // num_cols, i % num_cols)

    for i in range(dim):
        ax = flat_axes[i]
        nx.draw(G, pos=scaled_node_coordinates, node_size=node_size, node_color=pe[:, i], cmap=cm.inferno,
                with_labels=False, ax=ax, width=0.0, alpha=1.0)
        ax.set_title(f'[{i}-th PE]')
        ax.title.set_fontsize(50)

    # Blank out grid cells that have no channel to show.
    for ax in flat_axes[dim:]:
        ax.axis('off')

    plt.savefig(os.path.join(visualization_save_path, f"{file_name}_{pe_type}.png"), dpi=100)
    plt.close(fig)
    del fig, axes
    gc.collect()


def visualize_positional_encoding_LSPE(edge_index, edge_index_LSPE, node_coordinates, pe, dim=8, node_size=6, file_name='', visualization_save_path='./', pe_type='LapPE', scale_factor=1.0):
    """Plot neighbourhood-summed positional encodings as node colours and save a PNG.

    For every channel, each node's colour is the sum of that channel over the
    node itself and all nodes connected to it in *edge_index_LSPE* (treated as
    undirected). The graph drawn is the one given by *edge_index*. Output goes
    to ``<visualization_save_path>/<file_name>_<pe_type>_LSPE.png``.

    Args:
        edge_index: [2, num_edges] edge list of the drawn graph.
        edge_index_LSPE: [2, num_edges] integer edge list used for aggregation.
        node_coordinates: Dict node -> (x, y); its length sets the matrix size.
        pe: 2-D array with one row per node and at least *dim* columns.
        dim: Number of channels (panels) to draw.
        node_size: Marker size; tripled when there are fewer than 1000 nodes.
        file_name, visualization_save_path, pe_type: Output location / naming.
        scale_factor: Multiplier applied to all coordinates.
    """
    G = convert_to_networkx_graph(edge_index)
    scaled_node_coordinates = {node: (scale_factor * x, scale_factor * y) for node, (x, y) in node_coordinates.items()}
    num_nodes = len(node_coordinates)
    if num_nodes < 1000:
        node_size = node_size * 3

    # The aggregation matrix does not depend on the channel, so build it once
    # (it used to be rebuilt edge by edge for every panel).
    if torch.is_tensor(edge_index_LSPE):
        edge_index_LSPE = edge_index_LSPE.cpu().numpy()
    edge_array = np.asarray(edge_index_LSPE)
    neighbourhood = np.zeros((num_nodes, num_nodes))
    neighbourhood[edge_array[0], edge_array[1]] = 1
    # Symmetrise, add self connections, and binarise.
    neighbourhood = ((neighbourhood + neighbourhood.T + np.eye(num_nodes)) > 0).astype(int)

    num_rows = 2
    # Round up so an odd `dim` fits; squeeze=False keeps `axes` 2-D for one column.
    num_cols = max(1, (dim + 1) // 2)
    fig, axes = plt.subplots(num_rows, num_cols, figsize=(16*num_cols, 16*num_rows), squeeze=False)
    flat_axes = axes.ravel()  # row-major, i.e. panel i sits at (i // num_cols, i % num_cols)

    for i in range(dim):
        ax = flat_axes[i]
        node_colors = np.matmul(neighbourhood, pe[:, i])
        nx.draw(G, pos=scaled_node_coordinates, node_size=node_size, node_color=node_colors, cmap=cm.inferno,
                with_labels=False, ax=ax, width=0.3, alpha=0.8)
        ax.set_title(f'[{i}-th PE]')
        ax.title.set_fontsize(50)

    # Blank out grid cells that have no channel to show.
    for ax in flat_axes[dim:]:
        ax.axis('off')

    plt.savefig(os.path.join(visualization_save_path, f"{file_name}_{pe_type}_LSPE.png"), dpi=100)
    plt.close(fig)
    del fig, axes
    gc.collect()

def get_base_distance_matrix(coords, threshold):
    """Pairwise Euclidean distances between *coords* plus a thresholded mask.

    The computation runs in torch on ``DEVICE``; results come back as numpy.

    Returns:
        (base_distance, base_distance_mask): the full distance matrix, and a
        0/1 matrix that is 1 where the distance is <= *threshold*, with the
        diagonal forced to 0.
    """
    points = torch.from_numpy(coords).to(DEVICE).float()
    pairwise = torch.cdist(points, points)
    # 1 everywhere except on the diagonal, so self-pairs are masked out.
    off_diagonal = (1 - torch.eye(len(points))).to(DEVICE)
    within_threshold = (pairwise <= threshold) * off_diagonal

    base_distance = pairwise.cpu().numpy()
    base_distance_mask = within_threshold.cpu().numpy()

    del points
    gc.collect()
    if DEVICE == 'cuda':
        torch.cuda.empty_cache()
    return base_distance, base_distance_mask
    # else:
    #     base_distance = np.zeros((len(coords), len(coords)))
    #     for i in range(len(coords)):
    #         for j in range(i+1, len(coords)):
    #             distance = np.linalg.norm(coords[i]-coords[j])
    #             base_distance[i][j] = distance
    #             base_distance[j][i] = distance
    #     base_distance_mask = base_distance <= threshold
    #     base_distance_mask = base_distance_mask * (1-np.eye(len(coords)))
        
    #     return base_distance, base_distance_mask

def get_base_adj(coords, base_k, base_distance, base_distance_mask):
    """Symmetric k-nearest-neighbour adjacency restricted by a distance mask.

    For every node the ``base_k + 1`` smallest entries of its distance row are
    marked (the extra one leaves room for the node itself), the result is
    symmetrised, and then multiplied by *base_distance_mask*.

    Args:
        coords: Only its length (number of nodes) is used.
        base_k: Number of neighbours per node.
        base_distance: numpy distance matrix.
        base_distance_mask: numpy 0/1 matrix applied to the adjacency.

    Returns:
        (base_adj, base_adj_edge_index, base_adj_edge_attr): the adjacency as a
        torch tensor on ``DEVICE``, the [2, num_edges] numpy edge list, and the
        numpy distances of those edges.
    """
    num_nodes = len(coords)
    dist = torch.from_numpy(base_distance).to(DEVICE)
    mask = torch.from_numpy(base_distance_mask).to(DEVICE)

    knn = torch.zeros((num_nodes, num_nodes), device=DEVICE)
    for node in range(num_nodes):
        nearest = torch.argsort(dist[node])[:base_k + 1]
        knn[node, nearest] = 1

    # An edge exists if either endpoint selected the other.
    symmetric = ((knn + knn.T) > 0).int()
    # The mask applies the distance threshold and removes self-loops.
    base_adj = symmetric * mask

    base_adj_edge_index = torch.stack(torch.where(base_adj)).cpu().numpy()  # [2, num_edges]
    base_adj_edge_attr = dist[base_adj.bool()].cpu().numpy()

    del dist, mask
    gc.collect()
    if DEVICE == 'cuda':
        torch.cuda.empty_cache()

    return base_adj, base_adj_edge_index, base_adj_edge_attr


def get_feature_similarity_adj(coords, feats, feat_k, base_distance_mask):
    """Cosine-distance k-nearest-neighbour graphs, one per value in *feat_k*.

    For each ``k`` the ``k + 1`` entries with the smallest cosine distance in
    every node's row are marked (the extra one leaves room for the node
    itself), the result is symmetrised and multiplied by *base_distance_mask*.

    Args:
        coords: Only its length (number of nodes) is used.
        feats: numpy array with one feature row per node.
        feat_k: Iterable of neighbour counts, in any order.
        base_distance_mask: numpy 0/1 matrix applied to each adjacency.

    Returns:
        Three dicts keyed by ``k``: adjacency (torch tensor on ``DEVICE``),
        [2, num_edges] numpy edge list, and numpy cosine distance per edge.
    """
    feature_adj, feature_adj_edge_index, feature_adj_edge_attr = {}, {}, {}
    feat_k = list(feat_k)
    if not feat_k:
        return feature_adj, feature_adj_edge_index, feature_adj_edge_attr

    num_nodes = len(coords)
    feats = torch.from_numpy(feats).to(DEVICE).float()
    # Clamp the norm so an all-zero feature row yields zeros instead of NaN.
    norms = torch.norm(feats, dim=1, keepdim=True).clamp_min(1e-12)
    normalized_feats = feats / norms
    similarity_matrix = torch.matmul(normalized_feats, normalized_feats.T)
    feature_distance = 1 - similarity_matrix
    # Upload the mask once rather than once per k.
    mask = torch.from_numpy(base_distance_mask).to(DEVICE)

    for k in feat_k:
        feature_adj[k] = torch.zeros((num_nodes, num_nodes), device=DEVICE)

    # Rank once per node up to the largest k and slice a prefix for each k.
    # The slice is taken from the full ranking every time, so the result no
    # longer depends on feat_k being sorted in descending order.
    max_k = max(feat_k)
    for node in range(num_nodes):
        ranked = torch.argsort(feature_distance[node])[:max_k + 1]
        for k in feat_k:
            feature_adj[k][node, ranked[:k + 1]] = 1

    for k in feat_k:
        symmetric = ((feature_adj[k] + feature_adj[k].T) > 0).int()
        # The mask applies the spatial threshold and removes self-loops.
        feature_adj[k] = symmetric * mask
        edges = torch.where(feature_adj[k])
        feature_adj_edge_index[k] = torch.stack(edges).cpu().numpy()   # [2, num_edges]
        feature_adj_edge_attr[k] = feature_distance[edges].cpu().numpy()  # [num_edges]

    del feats, norms, normalized_feats, similarity_matrix, feature_distance, mask
    gc.collect()
    if DEVICE == 'cuda':
        torch.cuda.empty_cache()

    return feature_adj, feature_adj_edge_index, feature_adj_edge_attr


def _replace_dataset(hdf5_file, name, data):
    """Write *data* as dataset *name*, deleting any existing dataset of that name first."""
    if name in hdf5_file:
        del hdf5_file[name]
    hdf5_file.create_dataset(name, data=data)


def save_graphs(f,threshold,base_k, feat_k, LapPE_k_10x, RWPE_walk_length, image_path='./', visualization_save_path='./', mode='visualize'):
    """Build the graphs and positional encodings for one HDF5 file, then plot or store them.

    Reads ``10.0x_coords``, ``10.0x_patches`` and ``10.0x_filtered_patches``
    from *f*, builds the base (spatial) graph, the feature graphs for every
    ``k`` in *feat_k* (raw and per ``(sigma, freq)`` filtered features), the
    random-walk encoding and the Laplacian encoding.

    Args:
        f: Path of the HDF5 feature file.
        threshold: Distance threshold of the base graph; also part of every
            dataset name written in ``'save'`` mode.
        base_k: Neighbour count of the base graph.
        feat_k: Neighbour counts of the feature graphs.
        LapPE_k_10x: Number of Laplacian eigenvectors requested.
        RWPE_walk_length: Number of random-walk steps.
        image_path: Directory with the image HDF5 files (``'visualize'`` only).
        visualization_save_path: Output directory for figures (``'visualize'`` only).
        mode: ``'visualize'`` writes PNGs; ``'save'`` appends datasets to *f*.
            Files whose node count is outside 2000-4000 are skipped in
            ``'visualize'`` mode; files that already contain
            ``10.0x_base_LapPE_<threshold>`` are skipped in ``'save'`` mode.
    """
    file_name = f.split('/')[-1].split('.')[0]
    start_time = time.time()
    with h5py.File(f, 'r') as hdf5_file:
        # This dataset is written last in 'save' mode, so its presence means
        # the file was fully processed before.
        if mode == 'save' and f'10.0x_base_LapPE_{threshold}' in hdf5_file:
            print(f'{file_name} already has graph information')
            return

        coords_10x = hdf5_file['10.0x_coords'][:]                   # [num_patches, dim (4)]
        num_nodes = len(coords_10x)

        if mode == 'visualize' and not 2000 <= num_nodes <= 4000:
            # The old message always claimed "too many nodes", even for small files.
            print(f'{file_name} has {num_nodes} nodes, outside the 2000-4000 range used for visualization. Skip constructing graphs.')
            return
        print(f'{file_name} has {num_nodes} nodes. Constructing graphs...')

        original_feats_10x = hdf5_file['10.0x_patches'][:]          # [num_patches, dim (1024)]
        filtered_feats_10x = hdf5_file['10.0x_filtered_patches'][:] # [num_patches, sigma 10/20/30 (3), LF/HF (2), dim (1024)]

    # Combine the two coordinate pairs stored per patch into one (x, y).
    coords_10x = coords_10x[:, :2]*2 + coords_10x[:, 2:4]

    base_distance_10x, base_distance_10x_mask = get_base_distance_matrix(coords_10x, threshold)
    base_adj_10x, base_adj_edge_index_10x, base_adj_edge_attr_10x = get_base_adj(coords_10x, base_k, base_distance_10x, base_distance_10x_mask)

    feat_adj_10x, feat_adj_edge_index_10x, feat_adj_edge_attr_10x = get_feature_similarity_adj(coords_10x, original_feats_10x, feat_k, base_distance_10x_mask)
    feat_filtered_adj_10x, feat_filtered_adj_edge_index_10x, feat_filtered_adj_edge_attr_10x = {}, {}, {}
    num_sigma, num_freq = filtered_feats_10x.shape[1], filtered_feats_10x.shape[2]
    for sigma in range(num_sigma):
        for freq in range(num_freq):
            (feat_filtered_adj_10x[sigma, freq],
             feat_filtered_adj_edge_index_10x[sigma, freq],
             feat_filtered_adj_edge_attr_10x[sigma, freq]) = get_feature_similarity_adj(
                coords_10x, filtered_feats_10x[:, sigma, freq], feat_k, base_distance_10x_mask)

    base_RWPE_10x = get_random_walk_eigenvector_pe(base_adj_edge_index_10x, len(coords_10x), walk_length=RWPE_walk_length, add_self_loops=True)
    eig_vals_10x, base_LapPE_10x = get_laplacian_eigenvector_pe(base_adj_edge_index_10x, k=LapPE_k_10x)

    node_coordinates = {i: (x, y) for i, (x, y) in enumerate(coords_10x)}

    if mode == 'visualize':
        for k in feat_k:
            visualization_save_path_k = os.path.join(visualization_save_path, f'k_{k}')
            os.makedirs(visualization_save_path_k, exist_ok=True)
            visualize_graphs(base_adj_edge_index_10x, base_adj_edge_attr_10x, feat_adj_edge_index_10x[k], feat_adj_edge_attr_10x[k], feat_filtered_adj_edge_index_10x, feat_filtered_adj_edge_attr_10x, node_coordinates, node_size=6, file_name=file_name, visualization_save_path=visualization_save_path_k, k=k, scale_factor=1.0)

            visualize_graph_with_image(feat_adj_edge_index_10x[k], feat_adj_edge_attr_10x[k], node_coordinates, node_size=6, file_name=file_name, image_path=image_path, visualization_save_path=visualization_save_path_k, note='feat')

            visualize_graph_with_image(feat_filtered_adj_edge_index_10x[0,0][k], feat_filtered_adj_edge_attr_10x[0, 0][k], node_coordinates, node_size=6, file_name=file_name, image_path=image_path, visualization_save_path=visualization_save_path_k, note=f'feat_filtered_10_LF')
            visualize_graph_with_image(feat_filtered_adj_edge_index_10x[2,1][k], feat_filtered_adj_edge_attr_10x[2, 1][k], node_coordinates, node_size=6, file_name=file_name, image_path=image_path, visualization_save_path=visualization_save_path_k, note=f'feat_filtered_30_HF')

        visualize_positional_encoding(base_adj_edge_index_10x, node_coordinates, base_LapPE_10x, dim=12, node_size=64, file_name=file_name, visualization_save_path=visualization_save_path, pe_type='LapPE', scale_factor=1.0)
        visualize_positional_encoding(base_adj_edge_index_10x, node_coordinates, base_RWPE_10x, dim=12, node_size=64, file_name=file_name, visualization_save_path=visualization_save_path, pe_type='RWPE', scale_factor=1.0)
        visualize_graph_with_image(base_adj_edge_index_10x, base_adj_edge_attr_10x, node_coordinates, node_size=6, file_name=file_name, image_path=image_path, visualization_save_path=visualization_save_path, note='base')

    elif mode == 'save':
        with h5py.File(f, 'a') as hdf5_file:
            for k in feat_k:
                _replace_dataset(hdf5_file, f'10.0x_feat_adj_edge_index_{threshold}_{k}', feat_adj_edge_index_10x[k])
                _replace_dataset(hdf5_file, f'10.0x_feat_adj_edge_attr_{threshold}_{k}', feat_adj_edge_attr_10x[k])
                for sigma in range(num_sigma):
                    for freq in range(num_freq):
                        sigma_title = str((sigma+1)*10)
                        freq_title = 'LPF' if freq == 0 else 'HPF'
                        _replace_dataset(hdf5_file, f'10.0x_feat_filtered_adj_edge_index_{sigma_title}_{freq_title}_{threshold}_{k}', feat_filtered_adj_edge_index_10x[sigma, freq][k])
                        _replace_dataset(hdf5_file, f'10.0x_feat_filtered_adj_edge_attr_{sigma_title}_{freq_title}_{threshold}_{k}', feat_filtered_adj_edge_attr_10x[sigma, freq][k])

            _replace_dataset(hdf5_file, f'10.0x_base_adj_edge_index_{threshold}', base_adj_edge_index_10x)
            _replace_dataset(hdf5_file, f'10.0x_base_adj_edge_attr_{threshold}', base_adj_edge_attr_10x)
            _replace_dataset(hdf5_file, f'10.0x_base_RWPE_{threshold}', base_RWPE_10x)
            _replace_dataset(hdf5_file, f'10.0x_eig_vals_{threshold}', eig_vals_10x)
            # Keep this one last: the "already processed" check above looks for it.
            _replace_dataset(hdf5_file, f'10.0x_base_LapPE_{threshold}', base_LapPE_10x)

    del base_adj_edge_index_10x, base_adj_edge_attr_10x, feat_adj_edge_index_10x, feat_adj_edge_attr_10x, feat_filtered_adj_edge_index_10x, feat_filtered_adj_edge_attr_10x, base_RWPE_10x, eig_vals_10x, base_LapPE_10x
    gc.collect()
    print(f'{file_name}_{threshold}_{feat_k} done! Elapsed time: {(time.time()-start_time)/60:.2f} mins')

def check_num_patches(all_files, mag='10.0x'):
    """Print patch-count statistics over *all_files* and return the counts.

    Args:
        all_files: Paths of HDF5 files; paths that do not exist are skipped.
        mag: Magnification prefix; the row count of dataset ``<mag>_coords``
            is used as the file's patch count.

    Returns:
        List with the patch count of every file that exists (empty when none do).
    """
    per_file_cap = 600  # counts above this are capped when summing the total
    num_patches = []
    for f in all_files:
        if not os.path.exists(f):
            continue
        with h5py.File(f, 'r') as hdf5_file:
            # Only the number of rows is needed, so read the shape instead of
            # loading the whole dataset.
            num_patches.append(hdf5_file[f'{mag}_coords'].shape[0])

    if not num_patches:
        # min()/max() and the mean below are undefined for an empty list.
        print(f'[{mag}] No existing files found; nothing to summarize.')
        return num_patches

    total_num_patches_ours = sum(min(count, per_file_cap) * 4 for count in num_patches)
    print(f'[{mag}] Total: {len(num_patches)} / Min: {min(num_patches)} / Max: {max(num_patches)} / Mean: {np.mean(num_patches)}')
    print(f'[{mag}] Total number of patches at high magnification for our model: {total_num_patches_ours} / Mean: {total_num_patches_ours/len(num_patches)}')
    return num_patches

def wrapper_save_graphs(args):
    """Call ``save_graphs`` with the positional arguments packed in *args*.

    Takes a single tuple so it can be mapped over an iterable of argument
    tuples (it is passed to ``Pool.imap_unordered`` in the script section).
    """
    return save_graphs(*args)

if __name__ == '__main__':

    # ---- command-line options ------------------------------------------------
    parser = argparse.ArgumentParser(description='Configurations for WSI Training')
    parser.add_argument('--threshold', type=int, default=3, help='Euclidean distance threshold')
    parser.add_argument('--base_k', type=int, default=8, help='Top K for base adjacency matrix')
    parser.add_argument('--mode', type=str, default='visualize', help='visualize or save')
    parser.add_argument('--machine', type=str, default='4GPU', help='CPU or 1GPU or 4GPU')
    parser.add_argument('--num_cpus', type=int, default=16, help='Number of CPUs for multiprocessing')
    parser.add_argument('--mag', type=str, default='10.0x', help='Magnification for constructing graphs')
    args = parser.parse_args()

    feature_path = '/path/to/feature' 
    image_path = '/path/to/image'

    # ---- input files -----------------------------------------------------------
    all_files = [os.path.join(feature_path, name)
                 for name in os.listdir(feature_path) if name.endswith('.h5')]

    if args.mode == 'visualize':
        # Visualise a random sample of at most 100 files.
        import random
        random.shuffle(all_files)
        all_files = all_files[:100]
    elif args.mode == 'save':
        all_files.sort()

    # ---- graph / encoding settings ---------------------------------------------
    threshold = args.threshold      # Euclidean distance threshold
    base_k = args.base_k            # Top K for base adjacency matrix
    feat_k = sorted([8, 7, 6, 5, 4], reverse=True)   # Top K values for feature graphs, descending
    num_patches_10x = check_num_patches(all_files, mag='10.0x')
    LapPE_k_10x = 226
    RWPE_walk_length = 24

    visualization_save_path = f'./visualization_threshold_{threshold}/'
    if not os.path.exists(visualization_save_path):
        os.makedirs(visualization_save_path)

    # ---- processing ------------------------------------------------------------
    run_serially = args.mode == 'visualize' or args.num_cpus == 1
    if run_serially:
        if args.machine == '4GPU':
            import random
            random.shuffle(all_files)

        for idx, f in enumerate(all_files):
            file_name = f.split('/')[-1].split('.')[0]
            print(f'\n[ {idx+1} / {len(all_files)} ] Processing {file_name}...')
            save_graphs(f, threshold, base_k, feat_k, LapPE_k_10x, RWPE_walk_length,
                        image_path, visualization_save_path, mode=args.mode)
    elif args.mode == 'save':
        from multiprocessing import Pool

        num_cpus = args.num_cpus
        # One positional-argument tuple per file, in save_graphs parameter order.
        args_list = [
            (path, threshold, base_k, feat_k, LapPE_k_10x, RWPE_walk_length,
             image_path, visualization_save_path, args.mode)
            for path in all_files
        ]
        with Pool(num_cpus) as p:
            # Drain the iterator only to drive the tqdm progress bar.
            for _ in tqdm(p.imap_unordered(wrapper_save_graphs, args_list), total=len(all_files)):
                pass
            p.close()
            p.join()

    print('Done!')
