import os

import networkx as nx
import numpy as np
from pygsp import graphs
from spektral.datasets import Citation
from spektral.utils import load_off
from networkx import barbell_graph
from magni.src.modules.transforms import normalize_point_cloud
import numpy as np
#from torch_geometric.datasets import TUDataset
from spektral.data import Dataset, Graph
from torch_geometric.utils import to_dense_adj

# Representative ModelNet40 meshes keyed by display name. Each value is the
# keyword-argument dict forwarded to ``make_modelnet`` (class folder, split
# folder and zero-padded sample index of the .off file).
MODELNET_CONFIG = {
    display_name: {"classname": class_dir, "split": "train", "sample": sample_idx}
    for display_name, class_dir, sample_idx in (
        ("Airplane", "airplane", 151),
        ("Car", "car", 79),
        ("Guitar", "guitar", 38),
        ("Person", "person", 83),
    )
}


def make_dataset(name, **kwargs):
    """Load a dataset by name and return a three-tuple ``(A, x, third)``.

    Dispatch order:
      1. PyGSP / synthetic graphs (any name in ``graphs.__all__`` or one of the
         lowercase names below) -> ``make_cloud(name)``; ``third`` is a label
         array.
      2. Keys of ``MODELNET_CONFIG`` -> ``make_modelnet``; ``third`` is the
         class name string.
      3. Names in ``Citation.available_datasets()`` -> ``make_citation``;
         ``third`` is the dataset name string.

    Parameters
    ----------
    name : str
        Dataset identifier.
    **kwargs
        Only ``seed`` is consumed: when present it seeds NumPy's global RNG
        before loading. Any other keyword arguments are ignored.

    Raises
    ------
    ValueError
        If ``name`` matches none of the known datasets (previously this
        silently returned ``None``).
    """
    if "seed" in kwargs:
        np.random.seed(kwargs.pop("seed"))

    cloud_names = {
        "grid2d", "ring", "bunny", "airfoil", "minnesota", "sensor",
        "community", "barabasialbert", "davidsensornet", "erdosrenyi",
        "torus", "barbell",
    }
    if name in graphs.__all__ or name.lower() in cloud_names:
        return make_cloud(name)
    if name in MODELNET_CONFIG:
        return make_modelnet(**MODELNET_CONFIG[name])
    if name in Citation.available_datasets():
        return make_citation(name)
    raise ValueError("Unknown dataset: {}".format(name))


def make_cloud(name):
    """Build a named benchmark graph and return it with a shuffled node order.

    Parameters
    ----------
    name : str
        Case-insensitive graph name: ``grid2d``, ``ring``, ``bunny``,
        ``airfoil``, ``minnesota``, ``sensor``, ``community``,
        ``barabasialbert``, ``davidsensornet``, ``erdosrenyi``, ``torus`` or
        ``barbell``.

    Returns
    -------
    A
        The graph's weight matrix ``G.W`` with rows and columns permuted;
        boolean matrices are cast to integer.
    x : np.ndarray (float32)
        Node coordinates, permuted consistently with ``A``.
    y : np.ndarray
        All-zero placeholder labels, one per node.

    Raises
    ------
    ValueError
        If ``name`` is not one of the graphs above.
    """
    builders = {
        "grid2d": lambda: graphs.Grid2d(N1=8, N2=8),
        "ring": lambda: graphs.Ring(N=64),
        "bunny": lambda: graphs.Bunny(),
        "airfoil": lambda: graphs.Airfoil(),
        "minnesota": lambda: graphs.Minnesota(),
        "sensor": lambda: graphs.Sensor(N=64),
        "community": lambda: graphs.Community(N=64),
        "barabasialbert": lambda: graphs.BarabasiAlbert(N=64),
        "davidsensornet": lambda: graphs.DavidSensorNet(N=64),
        "erdosrenyi": lambda: graphs.ErdosRenyi(N=64),
        "torus": lambda: graphs.Torus(8, 8),
        "barbell": lambda: networkx_to_pygsp(barbell_graph(20, 24, create_using=None)),
    }
    key = name.lower()
    if key not in builders:
        raise ValueError("Unknown dataset: {}".format(name))
    G = builders[key]()

    # Graphs without coordinates get a spring layout. Checking for None as
    # well as a missing attribute avoids calling .astype on None.
    if getattr(G, "coords", None) is None:
        G.set_coordinates(kind="spring")

    x = G.coords.astype(np.float32)
    y = np.zeros(x.shape[0])

    # Deterministic node shuffle. A private RandomState(42) yields exactly the
    # permutation that np.random.seed(42) + np.random.permutation produced,
    # but it no longer resets NumPy's global RNG. That reset used to override
    # any seed the caller had set.
    perm = np.random.RandomState(42).permutation(x.shape[0])
    A = G.W[perm, :][:, perm]
    x = x[perm, :]

    if A.dtype.kind == "b":
        A = A.astype("i")

    return A, x, y


def networkx_to_pygsp(nx_graph):
    """Return a PyGSP ``graphs.Graph`` built from *nx_graph*'s adjacency.

    The adjacency is extracted with ``nx.to_scipy_sparse_array`` as float32.
    It is converted to a dense array before being passed to ``graphs.Graph``.
    """
    sparse_adjacency = nx.to_scipy_sparse_array(nx_graph, dtype=np.float32)
    dense_adjacency = sparse_adjacency.toarray()
    return graphs.Graph(dense_adjacency)

def make_modelnet(classname="airplane", split="train", sample=151):
    """Load one ModelNet40 mesh from the local spektral cache.

    Reads
    ``~/.spektral/datasets/ModelNet/40/<classname>/<split>/<classname>_<NNNN>.off``
    (``sample`` zero-padded to four digits) with ``load_off``. The vertex
    array is then passed through ``normalize_point_cloud``.

    Returns
    -------
    tuple
        ``(adjacency, normalized_points, classname)``.
    """
    off_file = os.path.expanduser(
        "~/.spektral/datasets/ModelNet/40/{0}/{1}/{0}_{2:04d}.off".format(
            classname, split, sample
        )
    )
    mesh = load_off(off_file)
    points, adjacency = mesh.x, mesh.a
    return adjacency, normalize_point_cloud(points), classname


def make_citation(name):
    """Load spektral citation dataset *name* with layout positions as features.

    The dataset's own node features are not used. ``x`` is replaced by the
    positions computed by ``nx.spring_layout`` on the citation graph, with
    rows ordered by node index 0..N-1.

    Returns
    -------
    tuple
        ``(adjacency, positions, name)``.
    """
    adjacency = Citation(name)[0].a
    layout = nx.spring_layout(nx.Graph(adjacency))
    n_nodes = adjacency.shape[0]
    positions = np.array([layout[node] for node in range(n_nodes)])
    return adjacency, positions, name


def edge_index_to_adj_matrix(edge_index, num_nodes):
    """Build a dense 0/1 adjacency matrix from a COO edge list.

    Parameters
    ----------
    edge_index : pair of index sequences
        ``edge_index[0]`` holds source nodes and ``edge_index[1]`` holds
        target nodes. Anything ``np.asarray`` accepts works (lists, NumPy
        arrays, CPU tensors).
    num_nodes : int
        Side length of the square output matrix.

    Returns
    -------
    np.ndarray of int, shape (num_nodes, num_nodes)
        ``A[src, dst] == 1`` for every listed edge and 0 elsewhere. The
        matrix is not symmetrised, so undirected graphs must list both
        directions.
    """
    adj_matrix = np.zeros((num_nodes, num_nodes), dtype=int)
    src = np.asarray(edge_index[0])
    dst = np.asarray(edge_index[1])
    # Use a single vectorised scatter instead of a Python loop per edge. The
    # scatter is skipped for an empty edge list, because np.asarray([]) is
    # float64 and NumPy rejects float arrays as indices.
    if src.size:
        adj_matrix[src, dst] = 1
    return adj_matrix

class TorchDataset(Dataset):
    """Expose an indexable PyTorch Geometric dataset as a spektral ``Dataset``.

    Each item is read through ``edge_index``, ``x``, ``y`` and ``num_nodes``.
    These are presumably ``torch_geometric.data.Data`` objects (TODO confirm
    against callers). Every item becomes a spektral ``Graph`` with a dense
    NumPy adjacency matrix and no edge features.
    """

    def __init__(self, dataset, **kwargs):
        # Keep this assignment before super().__init__(). spektral's Dataset
        # constructor is expected to invoke read(), and read() uses
        # self.dataset.
        self.dataset = dataset
        super().__init__(**kwargs)

    @staticmethod
    def _to_numpy(tensor):
        """Return ``tensor.numpy()``, or ``None`` when the field is absent."""
        return None if tensor is None else tensor.numpy()

    def read(self):
        """Convert every item of ``self.dataset`` into a spektral ``Graph``."""
        converted = []
        for i in range(len(self.dataset)):
            # Fetch item i. A constant index here would repeat the first graph
            # for every entry.
            data = self.dataset[i]

            # Size the dense adjacency by the item's node count. Otherwise
            # to_dense_adj infers the size from edge_index, which drops
            # trailing isolated nodes and leaves A smaller than x.
            adj = to_dense_adj(data.edge_index, max_num_nodes=data.num_nodes)[0].numpy()

            converted.append(
                Graph(
                    x=self._to_numpy(data.x),
                    a=adj,
                    e=None,
                    y=self._to_numpy(data.y),
                )
            )
        return converted