import torch
import numpy as np
from torch_geometric.utils import to_scipy_sparse_matrix

def _to_dense_torch(x, device=None):
    if isinstance(x, np.ndarray):
        t = torch.from_numpy(x).float()
    else:
        t = x if torch.is_tensor(x) else torch.tensor(x, dtype=torch.float32)
    if device is not None:
        t = t.to(device)
    return t

def split_batch_to_graphs(batch):
    """Split a PyG Batch into per-graph tensors: (x, edge_index, y).

    Graph ``g`` owns nodes ``ptr[g]:ptr[g+1]``. Its edges are the columns of
    ``batch.edge_index`` whose *source* node lies in that range, shifted so
    the graph's first node has index 0.

    Args:
        batch: Object exposing ``x``, ``edge_index``, ``y`` and optionally
            ``ptr`` (e.g. a ``torch_geometric.data.Batch``).

    Returns:
        Three parallel lists ``(X_list, E_list, y_list)``.

        * If ``batch`` has no ``ptr`` attribute, or ``ptr`` is ``None``, the
          batch is treated as one graph: ``[batch.x], [batch.edge_index],
          [batch.y]``.
        * Otherwise there is one entry per graph, and ``y_list[g]`` is
          ``batch.y[g]`` viewed as a 0-d tensor, or ``None`` when
          ``batch.y`` is ``None`` (unlabeled data).
    """
    ptr = getattr(batch, "ptr", None)
    if ptr is None:
        return [batch.x], [batch.edge_index], [batch.y]

    bounds = ptr.tolist()
    edge_index = batch.edge_index
    src = edge_index[0]  # hoisted: reused for every graph's mask
    y = batch.y

    X_list, E_list, y_list = [], [], []
    for g, (start, end) in enumerate(zip(bounds[:-1], bounds[1:])):
        # Only the source endpoint is tested; edges are assumed not to cross
        # graph boundaries (as in a PyG-collated batch).
        mask = (src >= start) & (src < end)
        X_list.append(batch.x[start:end])
        E_list.append(edge_index[:, mask] - start)
        # Scalar label per graph; .view(()) raises if y[g] has != 1 element.
        y_list.append(None if y is None else y[g].view(()))
    return X_list, E_list, y_list

def pyg_to_scipy_adj(edge_index, num_nodes):
    """Return ``edge_index`` as a SciPy sparse adjacency over ``num_nodes`` nodes.

    Thin wrapper around ``torch_geometric.utils.to_scipy_sparse_matrix``.
    ``edge_attr`` is not passed, so the library's default edge values apply.
    """
    adjacency = to_scipy_sparse_matrix(edge_index, num_nodes=num_nodes)
    return adjacency


def Uext_batch_from_tree_lists(
    X_list, edge_index_list, gnn_model,
    levels=5, ratio=0.3, temp=0.1, tau=0.5
):
    """Build the tree / Haar representation for every graph in a batch.

    For each ``(X_i, edge_index_i)`` pair this:

    1. converts ``edge_index_i`` into a SciPy sparse adjacency over
       ``X_i.shape[0]`` nodes;
    2. calls ``Make_tree_real1(X_i, adjacency, gnn_model, ...)``, forwarding
       ``levels``, ``ratio``, ``temp`` and ``tau`` unchanged. The call returns
       a tree graph and a list of assignment matrices;
    3. passes both to ``HaarGOB_with_Sassign``;
    4. unpacks ``extract_haar_basis_and_graph_info`` on that result as
       ``(U, num_nodes, num_edges, edge_index, features)``.

    Args:
        X_list: Sequence/iterable of per-graph node-feature tensors.
        edge_index_list: Sequence/iterable of per-graph ``edge_index``
            tensors, aligned one-to-one with ``X_list``.
        gnn_model: Model forwarded to ``Make_tree_real1``.
        levels, ratio, temp, tau: Hyper-parameters forwarded to
            ``Make_tree_real1``; see that function for their meaning.

    Returns:
        A 7-tuple of lists, one entry per input graph in input order::

            (U_batch, edge_index_list_batch, num_nodes_tree_batch,
             num_edges_tree_batch, features_list_batch, treeG_batch,
             S_assign_List)

        The first five come from ``extract_haar_basis_and_graph_info``. Note
        that edge_index is placed second here, although that function returns
        it fourth. ``treeG_batch`` holds the ``HaarGOB_with_Sassign`` outputs,
        and ``S_assign_List`` the assignment lists from ``Make_tree_real1``.

    Raises:
        ValueError: If ``X_list`` and ``edge_index_list`` differ in length.
            Previously ``zip`` silently dropped the unmatched tail.
    """
    # NOTE(review): Make_tree_real1, HaarGOB_with_Sassign and
    # extract_haar_basis_and_graph_info are not imported in the visible part
    # of this module -- confirm they are in scope at call time.
    features_in = list(X_list)
    edges_in = list(edge_index_list)
    if len(features_in) != len(edges_in):
        raise ValueError(
            f"X_list has {len(features_in)} graphs but edge_index_list has "
            f"{len(edges_in)}"
        )

    U_batch = []                 # Haar basis per graph
    edge_index_list_batch = []
    num_nodes_tree_batch = []
    num_edges_tree_batch = []
    features_list_batch = []
    treeG_batch = []
    S_assign_List = []           # one assignment list per graph

    for X_i, ei_i in zip(features_in, edges_in):
        adjacency = pyg_to_scipy_adj(ei_i, X_i.shape[0])
        treeG_i, S_assign_list = Make_tree_real1(
            X_i, adjacency, gnn_model,
            levels=levels, ratio=ratio, temp=temp, tau=tau,
        )
        treeG_i = HaarGOB_with_Sassign(treeG_i, S_assign_list)
        U_i, n_nodes_i, n_edges_i, eidx_i, feats_i = (
            extract_haar_basis_and_graph_info(treeG_i)
        )

        U_batch.append(U_i)
        edge_index_list_batch.append(eidx_i)
        num_nodes_tree_batch.append(n_nodes_i)
        num_edges_tree_batch.append(n_edges_i)
        features_list_batch.append(feats_i)
        treeG_batch.append(treeG_i)
        S_assign_List.append(S_assign_list)

    return (
        U_batch,
        edge_index_list_batch,
        num_nodes_tree_batch,
        num_edges_tree_batch,
        features_list_batch,
        treeG_batch,
        S_assign_List,
    )
