import os
import random
import subprocess as sp
import tempfile
import warnings

import networkx as nx
import numpy as np
import torch
import torch.nn.functional as F
from scipy.special import comb
from torch_geometric.datasets import TUDataset
from torch_geometric.nn import GINConv, global_add_pool
from torch_geometric.utils import to_networkx, degree

from arguments import arg_parse

# Select the compute device: the CUDA device with index 2 when CUDA is
# usable, otherwise the CPU. The choice is reported on stdout.
if torch.cuda.is_available():
    device = torch.device('cuda:2')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")



# --- User-Provided Helper Functions ---
# <<< --- IMPORTANT: SET THIS PATH TO YOUR ORCA EXECUTABLE DIRECTORY --- >>>
ORCA_DIR = 'orca/'
# <<< ----------------------------------------------------------------- >>>

# Fail fast at import time if the ORCA binary is missing or not executable.
# The process is terminated with a non-zero status so that shells and job
# schedulers see the failure; SystemExit is raised directly instead of calling
# the interactive exit() helper, which is injected by the site module and is
# not guaranteed to exist in every interpreter configuration.
ORCA_EXEC_PATH = os.path.join(ORCA_DIR, 'orca')
if not os.path.isdir(ORCA_DIR) or not os.path.isfile(ORCA_EXEC_PATH):
    print(f"Error: ORCA directory ('{ORCA_DIR}') or executable ('{ORCA_EXEC_PATH}') not found.")
    print("Please ensure ORCA is downloaded/compiled and ORCA_DIR is set correctly.")
    raise SystemExit(1)
if not os.access(ORCA_EXEC_PATH, os.X_OK):
    print(f"Error: ORCA executable at '{ORCA_EXEC_PATH}' does not have execute permissions.")
    print("You might need to run: chmod +x " + ORCA_EXEC_PATH)
    raise SystemExit(1)


def edge_list_reindexed(G):
    """Return the edges of ``G`` with endpoints renumbered from 0.

    Each node is assigned its position in the iteration order of
    ``G.nodes()``; nodes are identified by their ``str()`` form.

    Args:
        G: Object exposing ``nodes()`` and ``edges()``, where ``edges()``
            yields ``(u, v)`` pairs of nodes.

    Returns:
        A list of ``(int, int)`` tuples, one per edge, in the order the
        edges are yielded.
    """
    position = {str(node): rank for rank, node in enumerate(G.nodes())}
    return [(position[str(src)], position[str(dst)]) for (src, dst) in G.edges()]


def orca(graph, orca_dir=ORCA_DIR):
    """Run the external ORCA binary on ``graph`` and return its orbit counts.

    The graph is serialised as a text file whose first line is
    ``<number of nodes> <number of edges>`` followed by one re-indexed
    ``u v`` pair per edge. ORCA is then invoked for graphlets of size 4 and
    the file it writes is parsed into an integer array.

    Scratch files are kept in a private temporary directory created inside
    ``orca_dir``. This guarantees unique file names across concurrent runs,
    removes the files on every exit path (including failures while writing
    the input), and leaves the global ``random`` state untouched.

    Args:
        graph: Graph exposing ``number_of_nodes()`` and ``number_of_edges()``
            and accepted by ``edge_list_reindexed``.
        orca_dir: Directory that contains the ``orca`` executable; the
            scratch directory is created inside it.

    Returns:
        A numpy integer array with one row per line of ORCA output, or
        ``None`` if ORCA exits with a non-zero status.
    """
    orca_executable = os.path.join(orca_dir, 'orca')

    with tempfile.TemporaryDirectory(prefix='orca-tmp-', dir=orca_dir) as scratch_dir:
        input_path = os.path.join(scratch_dir, 'graph.txt')
        output_path = os.path.join(scratch_dir, 'orbits.txt')

        with open(input_path, 'w') as f:
            f.write(f'{graph.number_of_nodes()} {graph.number_of_edges()}\n')
            f.writelines(f'{u} {v}\n' for u, v in edge_list_reindexed(graph))

        try:
            # Execute ORCA: count orbits for graphlets of size 4
            sp.run([orca_executable, '4', input_path, output_path],
                   check=True, capture_output=True, text=True)
        except sp.CalledProcessError as e:
            print(f"ORCA execution failed for graph. Error: {e.stderr}")
            return None

        # Read the result before the scratch directory is removed.
        with open(output_path, 'r') as f:
            output = f.read()

    # One line per row, space-separated integers.
    return np.array([list(map(int, node_cnts.strip().split(' ')))
                     for node_cnts in output.strip('\n').split('\n')])


def count2density(node_orbit_counts, graph_size):
    """Turn per-node orbit counts into a 9-entry motif density vector.

    The first 15 orbit columns are summed over nodes and grouped into nine
    motifs (1 on two nodes, 2 on three nodes, 6 on four nodes). Each group
    total is divided by the motif's node count, then by a per-motif
    normalizer times the number of ways to choose that many nodes out of
    ``graph_size``.

    Args:
        node_orbit_counts: 2-D array-like with one row per node and at least
            15 orbit columns.
        graph_size: Number of nodes in the graph.

    Returns:
        A float numpy array of length 9. It is all zeros when
        ``graph_size < 4`` or when fewer than 15 orbit columns are supplied.
    """
    n_motifs = 9

    # Graphs with fewer than four nodes are reported as an all-zero vector.
    if graph_size < 4:
        return np.zeros(n_motifs)

    # Number of k-node subsets of the graph, for each motif size k.
    placements = {k: comb(graph_size, k, exact=True) for k in (2, 3, 4)}

    orbits_per_motif = [1, 2, 1, 2, 2, 1, 3, 2, 1]
    motif_nodes = np.array([2] + [3]*2 + [4]*6)
    # NOTE(review): presumably the number of distinct labelled placements of
    # each motif on its node set; confirm against the method's definition.
    rewirings = [1., 3., 1., 12., 4., 3., 12., 6., 1.]

    orbit_totals = np.sum(node_orbit_counts, axis=0)

    # All 15 orbit columns are required.
    if len(orbit_totals) < 15:
        return np.zeros(n_motifs)

    # Add up the consecutive orbit columns that belong to each motif.
    grouped = np.zeros(n_motifs)
    offset = 0
    for motif, width in enumerate(orbits_per_motif):
        grouped[motif] = sum(orbit_totals[offset:offset + width])
        offset += width

    # Presumably each occurrence is counted once per participating node, so
    # dividing by the node count yields the number of distinct occurrences.
    occurrences = grouped / motif_nodes

    density = np.zeros(n_motifs)
    for motif in range(n_motifs):
        scale = rewirings[motif] * placements[motif_nodes[motif]]
        # Guard against a zero normalizer.
        density[motif] = occurrences[motif] / scale if scale > 0 else 0.0
    return density


def count2density2(node_orbit_counts, graph_size):
    """Turn per-node orbit counts into a 30-entry motif density vector.

    Orbit columns are summed over nodes and grouped into 30 motifs (1 on two
    nodes, 2 on three, 6 on four, 21 on five; 73 orbit columns in total).
    Each group total is divided by the motif's node count, then by a
    per-motif normalizer times the number of ways to choose that many nodes
    out of ``graph_size``.

    When no placement of a motif is possible (for example 5-node motifs in a
    graph with fewer than five nodes) the normalizer is zero; the density is
    then reported as 0.0 instead of dividing by zero, mirroring
    ``count2density``.

    Args:
        node_orbit_counts: 2-D array-like with one row per node and one
            column per orbit. Missing trailing columns contribute zero.
        graph_size: Number of nodes in the graph.

    Returns:
        A float numpy array of length 30 containing only finite values for
        finite inputs.
    """
    n_motifs = 30
    all_possible_motifs = {size: comb(graph_size, size, exact=True) for size in (2, 3, 4, 5)}
    map_loc2motif = [1, 2, 1, 2, 2, 1, 3, 2, 1, 3, 4, 2, 3, 4, 3, 1, 4, 4, 2, 4, 2, 3, 2, 3, 3, 3, 3, 2, 2, 1]
    node_size = 1*[2] + 2*[3] + 6*[4] + 21*[5]
    rewiring_normalizer = [ 1.,  3.,  1., 12.,  4.,  3., 12.,  6.,  1., 60., 60.,  5., 60., 60.,
        30., 12., 60., 60., 15., 60., 10., 60., 10., 20., 60., 30., 30., 15.,
        10.,  1.]

    count_over_nodes = np.sum(node_orbit_counts, axis=0)

    # Add up the consecutive orbit columns that belong to each motif, keeping
    # a running offset instead of recomputing the prefix sum every iteration.
    non_unique_count = np.zeros(n_motifs)
    start_idx = 0
    for i, width in enumerate(map_loc2motif):
        non_unique_count[i] = sum(count_over_nodes[start_idx:start_idx + width])
        start_idx += width

    unique_count = non_unique_count / np.array(node_size)

    density = np.zeros(n_motifs)
    for i in range(n_motifs):
        denominator = rewiring_normalizer[i] * all_possible_motifs[node_size[i]]
        # Leave the density at 0.0 when the motif cannot fit in the graph.
        if denominator > 0:
            density[i] = unique_count[i] / denominator
    return density



# --- Main Analysis Script ---
# Computes one motif density vector per graph of a TU dataset and stores the
# stacked result as a .npy file.

# Suppress specific warnings
warnings.filterwarnings("ignore", category=UserWarning, module='torch_geometric.utils.convert')


# Load the desired dataset (name taken from the command-line arguments)
args = arg_parse()

dataset_name = args.DS
dataset_root = '/tmp/TUDatasets_motif_analysis'
print(f"\n--- Loading Dataset: {dataset_name} ---")

# The dataset is loaded as-is; no transform is passed.
dataset = TUDataset(
    root=dataset_root,
    name=dataset_name,
)


all_densities = []

# NOTE(review): graphs with no nodes, or for which ORCA fails, are skipped,
# so row k of the saved array is not guaranteed to correspond to dataset[k].
# Confirm that downstream consumers do not rely on index alignment.
for data in dataset:
    # Convert to NetworkX graph for ORCA
    graph_nx = to_networkx(data, to_undirected=True, remove_self_loops=True)
    if graph_nx.number_of_nodes() == 0:
        continue

    # Run ORCA and calculate density
    orbit_counts = orca(graph_nx, orca_dir=ORCA_DIR)
    if orbit_counts is None:
        continue  # Skip if ORCA failed

    density_vector = count2density(orbit_counts, graph_nx.number_of_nodes())
    all_densities.append(density_vector)


print(f"Finished density calculation for {len(all_densities)} graphs.")

X = np.array(all_densities)

# Clean up any potential NaN/inf values
if np.isnan(X).any() or np.isinf(X).any():
    print("Warning: NaN or Inf values found in density features. Replacing with 0.")
    X = np.nan_to_num(X, nan=0.0, posinf=0.0, neginf=0.0)

# Save X as a numpy array. The output directory is created first so that a
# missing folder cannot discard the results after all ORCA runs have finished.
output_path = 'Models/motif_densities_' + dataset_name + '.npy'
os.makedirs(os.path.dirname(output_path), exist_ok=True)
np.save(output_path, X)