import torch
import numpy as np
import scipy.sparse as sparse
from torch_geometric.utils import to_dense_adj, to_dense_batch
import networkx as nx

from typing import List


from easydict import EasyDict
import yaml

##Geometric Scattering

# Generates all non-empty combinations of wavelet indices
def all_index_combinations(k: int) -> List[List[int]]:
    """
    Return all NON-EMPTY subsets of {0, 1, ..., k-1}.

    Subsets are ordered by the integer value of their binary inclusion mask
    (bit ``i`` set  <=>  ``i`` is in the subset), and each subset lists its
    indices in ascending order. The empty subset (mask 0) is deliberately
    excluded, so the result has ``2**k - 1`` entries; ``k == 0`` gives ``[]``.

    Example: ``all_index_combinations(2) == [[0], [1], [0, 1]]``.

    Raises:
        ValueError: if ``k`` is negative.
    """
    if k < 0:
        raise ValueError(f"k must be non-negative, got {k}")
    # Masks run 1 .. 2**k - 1; mask 0 (the empty subset) is skipped on purpose.
    return [[i for i in range(k) if (mask >> i) & 1] for mask in range(1, 1 << k)]



def edge_index_to_sparse_adj(edge_index: torch.LongTensor, num_nodes: int) -> torch.Tensor:
    """Build a coalesced sparse COO adjacency matrix from an edge list.

    Args:
        edge_index: ``[2, E]`` integer tensor; column ``e`` is the directed
            edge ``edge_index[0, e] -> edge_index[1, e]``.
        num_nodes: number of nodes ``N``; the result has shape ``[N, N]``.

    Returns:
        Coalesced ``float32`` sparse tensor on the same device as
        ``edge_index``, holding a 1 for every listed edge. ``coalesce()`` sums
        duplicates, so an edge listed twice gets the value 2.

    Note:
        Edges are used exactly as given. For an undirected graph whose
        ``edge_index`` lists each edge only once, add the reverse edges first,
        e.g. ``torch.cat([edge_index, edge_index.flip(0)], dim=1)``.
    """
    row, col = edge_index
    # Allocate the values on edge_index's device: sparse_coo_tensor requires
    # indices and values to live on the same device (e.g. both on CUDA).
    values = torch.ones(row.size(0), dtype=torch.float32, device=edge_index.device)
    return torch.sparse_coo_tensor(
        torch.stack([row, col], dim=0),
        values,
        (num_nodes, num_nodes),
    ).coalesce()

def torch_adj_to_nx(adj):
    """Convert a torch adjacency matrix into an undirected ``networkx.Graph``.

    Args:
        adj: square ``[N, N]`` adjacency tensor, either sparse COO (coalesced
            or not) or dense. For a sparse tensor every stored entry becomes
            an edge (explicitly stored zeros included). For a dense tensor
            every non-zero entry does.

    Returns:
        ``nx.Graph`` with nodes ``0 .. N-1``. Edge weights and direction are
        dropped (``(u, v)`` and ``(v, u)`` collapse into one edge), and
        self-loops are skipped.
    """
    if adj.is_sparse:
        # .indices() raises on an uncoalesced tensor, so coalesce first.
        rows, cols = adj.coalesce().indices()
    else:
        # Dense input has no .indices(); use the positions of non-zero entries.
        rows, cols = adj.nonzero(as_tuple=True)
    graph = nx.Graph()
    graph.add_nodes_from(range(adj.size(0)))
    graph.add_edges_from(
        (u, v) for u, v in zip(rows.tolist(), cols.tolist()) if u != v
    )
    return graph
    


import gc, psutil, os
from custom import *

def log_cpu(name=""):
    """Print this process's resident set size (RSS), tagged with ``name``.

    The size is reported in decimal megabytes (bytes / 1e6) with one decimal.

    Args:
        name: label printed in square brackets at the start of the line.
    """
    this_process = psutil.Process(os.getpid())
    rss = this_process.memory_info().rss / 1e6
    print(f"[{name}] RSS: {rss:.1f} MB")

def get_lap(adj):
    """Return ``L = D - A`` for the adjacency matrix ``adj``.

    ``D`` is the diagonal matrix built from the column sums (``dim=0``) of
    ``adj``. For a symmetric adjacency these equal the row sums, which gives
    the combinatorial graph Laplacian. ``adj`` may be dense or sparse; ``D``
    is always materialised as a dense ``N x N`` matrix.
    """
    column_degrees = adj.to_dense().sum(dim=0)
    return torch.diag(column_degrees) - adj

def get_diffusion(adj):
    """Return the column-normalised diffusion operator ``A @ D^{-1}``.

    ``D`` is the diagonal matrix of column sums (``dim=0``) of ``adj``, so
    every column with non-zero degree sums to 1.

    Args:
        adj: square ``[N, N]`` adjacency matrix, dense or sparse torch tensor.

    Returns:
        Dense ``[N, N]`` tensor. Columns whose degree is zero (isolated nodes)
        are left all-zero instead of making ``D`` singular. Non-floating input
        is promoted to ``float32``.
    """
    dense = adj.to_dense()
    if not dense.is_floating_point():
        dense = dense.float()
    degree = dense.sum(dim=0)
    # Invert the diagonal element-wise: O(N) instead of an O(N^3) matrix
    # inverse. Zero-degree columns keep a factor of 0 rather than raising.
    inv_degree = torch.zeros_like(degree)
    nonzero = degree != 0
    inv_degree[nonzero] = 1.0 / degree[nonzero]
    # Broadcasting over rows scales column j by 1 / degree[j], i.e. A @ D^{-1}.
    return dense * inv_degree


def enumerate_labels(labels):
    """ Converts labels from their original (e.g. string) form to the
        integers [0:MaxLabels-1].

        Each label's integer is its position in the sorted list of distinct
        labels, so the mapping is reproducible across runs. Iterating a
        ``set`` of strings is not reproducible, because string hashing is
        randomised per process. If the labels cannot be ordered (e.g. a mix
        of ints and strings), order of first appearance is used instead.

        Args:
            labels: iterable of hashable labels (may be a generator).

        Returns:
            np.ndarray of ints with one entry per input label.
    """
    labels = list(labels)  # iterated twice below; materialise generators
    try:
        unique = sorted(set(labels))
    except TypeError:
        # Unorderable mix of types: fall back to first-appearance order.
        unique = list(dict.fromkeys(labels))
    index_of = {label: idx for idx, label in enumerate(unique)}  # O(1) lookup
    return np.array([index_of[label] for label in labels], dtype=int)


def normalize_adjacency(adj):
    """ Symmetrically normalises an adjacency matrix with self-loops, as in
        Kipf & Welling (https://arxiv.org/pdf/1609.02907.pdf):

            D^{-1/2} (A + I) D^{-1/2}

        where D is the diagonal matrix of row sums of (A + I). Entries of
        D^{-1/2} that are not finite are set to 0: a zero row sum gives inf,
        and a negative one gives nan.

        Args:
            adj: square adjacency matrix, presumably a SciPy sparse matrix
                (it must support ``+ sparse.eye(n)`` and ``.sum(1)``).

        Returns:
            The normalised matrix; the D^{-1/2} factors are float32 ``diags``.
    """
    a_tilde = adj + sparse.eye(adj.shape[0])

    inv_sqrt_deg = np.power(np.array(a_tilde.sum(1)), -0.5).flatten()
    inv_sqrt_deg[~np.isfinite(inv_sqrt_deg)] = 0.0
    d_inv_sqrt = sparse.diags(inv_sqrt_deg, dtype=np.float32)

    return d_inv_sqrt @ a_tilde @ d_inv_sqrt

def normalize_by_batch(x, batch):
    """Scale features so each group's feature columns have (near) unit L2 norm.

    For every group id ``g`` in ``batch`` and every feature column ``j``, the
    entries ``x[batch == g, j]`` are divided by
    ``sqrt(sum(x[batch == g, j] ** 2) + 1e-6)``. The epsilon keeps all-zero
    columns from dividing by zero.

    Args:
        x: 2-D tensor ``[num_nodes, num_features]``. The ``index_add`` along
           dim 0 into a ``[groups, num_features]`` buffer requires this layout.
        batch: integer tensor ``[num_nodes]`` of group ids. The number of
           groups is taken as ``batch.max() + 1``.

    Returns:
        Tensor with the same shape, dtype and device as ``x``.
    """
    n_groups = int(batch.max()) + 1
    group_sq_sums = torch.zeros(
        (n_groups, x.shape[-1]), dtype=x.dtype, device=x.device
    ).index_add(0, batch, x.pow(2))
    group_norms = torch.sqrt(group_sq_sums + 1e-6)
    # Gather each node's group norm row and divide element-wise.
    return x / group_norms[batch]

def orthogonalize_by_batch(x, batch, max_nodes=50):
    """Orthonormalise the feature columns of each graph in a batch via QR.

    Nodes are scattered into a zero-padded ``[B, m, k]`` tensor with
    ``to_dense_batch`` (``m`` is the size of the largest graph). A batched
    ``torch.linalg.qr`` is taken, and the ``Q`` rows of real (non-padding)
    nodes are gathered back into a flat tensor using the returned mask.

    Args:
        x: node features, presumably ``[num_nodes, k]`` — TODO confirm.
        batch: graph id per node, as required by ``to_dense_batch``
            (which expects it to be sorted).
        max_nodes: unused; kept only for call-site compatibility.

    Returns:
        The gathered ``Q`` factors. ``R`` is discarded.

    NOTE(review): ``torch.linalg.qr`` defaults to ``mode="reduced"``, so ``Q``
    has ``min(m, k)`` columns. When every graph in the batch has fewer than
    ``k`` nodes, the output is narrower than ``x``. Padding rows also take
    part in the factorisation, so for rank-deficient graphs the kept rows
    may not be exactly orthonormal. Confirm both are acceptable.
    """
    dense_x, node_mask = to_dense_batch(x, batch)
    q_factor, _ = torch.linalg.qr(dense_x)
    return q_factor[node_mask]

def convert_scipy_to_torch_sparse(matrix):
    """Convert a SciPy sparse matrix into a ``float32`` torch sparse COO tensor.

    Args:
        matrix: any SciPy sparse matrix/array that supports ``.tocoo()``.

    Returns:
        ``torch`` sparse COO tensor with the same shape and stored values,
        cast to ``float32``, with ``int64`` indices. It is not coalesced;
        call ``.coalesce()`` if sorted, unique indices are required.
    """
    coo = matrix.tocoo().astype(np.float32)
    # sparse_coo_tensor needs int64 indices of shape [2, nnz]; SciPy usually
    # stores row/col as int32, so cast explicitly.
    indices = torch.from_numpy(np.vstack((coo.row, coo.col)).astype(np.int64))
    values = torch.from_numpy(coo.data)
    # torch.sparse.FloatTensor(...) is a deprecated legacy constructor;
    # torch.sparse_coo_tensor is its supported replacement.
    return torch.sparse_coo_tensor(
        indices, values, torch.Size(coo.shape), dtype=torch.float32
    )


def merge_configs(default: EasyDict, override: EasyDict) -> EasyDict:
    """
    Recursively merge ``override`` into ``default`` in place and return it.

    When ``default`` already holds an ``EasyDict`` under a key and the
    override value for that key is a ``dict`` (an ``EasyDict`` included),
    the two sections are merged key by key. Any other key is simply set to
    the override value, which adds it if it was missing. ``default`` itself
    is mutated and returned.
    """
    for key, new_value in override.items():
        # Membership + indexing (not .get): EasyDict exposes keys as
        # attributes, so a config key named "get" could shadow the method.
        both_sections = (
            key in default
            and isinstance(default[key], EasyDict)
            and isinstance(new_value, dict)
        )
        if not both_sections:
            # Leaf value, new key, or type mismatch: the override wins.
            default[key] = new_value
            continue
        # Both sides are mappings: descend and merge the nested section.
        default[key] = merge_configs(default[key], EasyDict(new_value))
    return default