from typing import Tuple, Optional, Literal
import os
import pickle
import copy

from requests import get

import torch
from torch_geometric.utils import subgraph, to_edge_index
import torch_geometric.transforms as T
from torch_geometric.data import Data, Batch
from torch_geometric.datasets import (Planetoid, WikiCS, Coauthor, Amazon,
                                      GNNBenchmarkDataset, Yelp, Flickr,
                                      Reddit2, PPI, SNAPDataset)
import scipy.sparse as sp
from torch_geometric.nn.conv.gcn_conv import gcn_norm
from ogb.nodeproppred import PygNodePropPredDataset
from torch_sparse import SparseTensor

import networkx as nx
import numpy as np

from loguru import logger

from utils.others import index2mask, gen_masks

def save_edgelist_as_csr(edgelist_path, csr_save_path, num_nodes=None):
    """
    Convert an edge-list file into a SciPy CSR adjacency matrix and save it as ``.npz``.

    Steps:
        1) Read the edge list with ``nx.read_edgelist`` (integer node IDs).
        2) Remap node IDs to ``0..n-1`` in ascending ID order.
        3) For an undirected graph, store every edge in both directions; a
           self-loop is stored once, so every stored entry has weight 1.
        4) Save with ``scipy.sparse.save_npz``.

    Args:
        edgelist_path (str): Path to the edge-list file (e.g. 'edgelist.txt').
        csr_save_path (str): Path to the output .npz file (e.g. 'graph_csr.npz').
        num_nodes (int, optional): Side length of the saved square matrix. Node IDs
            are always remapped to a consecutive range, so gaps in the original IDs
            are removed; a value larger than the number of nodes read only appends
            empty (isolated) rows/columns after the remapped nodes. Defaults to the
            number of distinct node IDs found in the file.

    Raises:
        ValueError: If ``num_nodes`` is smaller than the number of nodes read.
    """
    # Load graph with NetworkX
    G = nx.read_edgelist(edgelist_path, nodetype=int)

    # Remap (possibly non-0-based / non-consecutive) IDs to 0..n-1, ascending.
    mapping = {node: i for i, node in enumerate(sorted(G.nodes()))}

    if num_nodes is None:
        num_nodes = len(mapping)
    elif num_nodes < len(mapping):
        raise ValueError(
            f'num_nodes={num_nodes} is smaller than the {len(mapping)} nodes '
            f"found in '{edgelist_path}'.")

    # One (src, dst) row per stored edge; reshape keeps shape (0, 2) for an empty graph.
    edges = np.array([(mapping[u], mapping[v]) for u, v in G.edges()],
                     dtype=np.int64).reshape(-1, 2)
    row, col = edges[:, 0], edges[:, 1]

    if not G.is_directed():
        # Mirror every edge except self-loops: csr_matrix sums duplicate entries,
        # so mirroring (u, u) as well would give the diagonal a weight of 2.
        off_diag = row != col
        row, col = (np.concatenate([row, col[off_diag]]),
                    np.concatenate([col, row[off_diag]]))

    data = np.ones(len(row), dtype=np.float32)
    adjacency_csr = sp.csr_matrix((data, (row, col)), shape=(num_nodes, num_nodes))

    sp.save_npz(csr_save_path, adjacency_csr)
    logger.info(f"Saved CSR adjacency to '{csr_save_path}'.")

def load_edgelist_as_pyg_data(path_to_edgelist):
    """
    Load an edge-list text file into a PyTorch Geometric ``Data`` object.

    The file must be readable by ``nx.read_edgelist`` with integer node IDs, e.g.::

        0 1
        1 2
        2 3

    Node IDs are remapped to ``0..n-1`` in ascending ID order, so a file that is
    already 0-based and consecutive keeps its IDs unchanged. Because the file is
    read into an undirected ``nx.Graph``, every edge is emitted in both directions
    (self-loops once), matching ``save_edgelist_as_csr``.

    Returns:
        data (Data): ``Data`` with ``edge_index`` of shape ``[2, num_edges]`` and
        ``num_nodes`` set to the number of distinct node IDs read.
    """
    G = nx.read_edgelist(path_to_edgelist, nodetype=int)

    # Sorted (not insertion) order: insertion order would renumber nodes by their
    # first appearance in the file, permuting IDs even for 0-based inputs.
    mapping = {node: i for i, node in enumerate(sorted(G.nodes()))}

    # view(-1, 2) keeps an empty edge list at shape (0, 2) -> edge_index (2, 0).
    edge_index = torch.tensor(
        [(mapping[u], mapping[v]) for u, v in G.edges()], dtype=torch.long
    ).view(-1, 2).t()

    if not G.is_directed():
        # Add the reverse of every non-self-loop edge so message passing sees
        # the graph as undirected.
        not_loop = edge_index[0] != edge_index[1]
        edge_index = torch.cat([edge_index, edge_index[:, not_loop].flip(0)], dim=1)

    return Data(edge_index=edge_index.contiguous(), num_nodes=len(mapping))

def preprocess_adj(split: Literal['train', 'valid', 'test'],
                   adj: SparseTensor, add_self_loops: Optional[bool] = True,
                   root: Optional[str]=None, dataset: Optional[str]=None) -> SparseTensor:
    r"""GCN-normalise ``adj`` with ``gcn_norm``, caching the result on disk when possible.

    Args:
        split: Split name; only used to build the cache file name.
        adj: Adjacency matrix to normalise.
        add_self_loops: Forwarded to ``gcn_norm``.
        root: Directory of the cache file. If ``root`` or ``dataset`` is None,
            nothing is cached and the matrix is computed in memory.
        dataset: Dataset name; part of the cache file name.

    Returns:
        The GCN-normalised adjacency matrix.
    """
    logger.info('Preprocess adjacency matrix... (GCN Norm)')

    if root is None or dataset is None:
        # Without a cache location the old code wrote to 'None/None_<split>_gcn_adj.pt'.
        return gcn_norm(adj, add_self_loops=add_self_loops)

    # The self-loop flag changes the result, so it must be part of the cache key.
    # The default (True) keeps the historical file name so existing caches still hit.
    suffix = '' if add_self_loops else '_noselfloops'
    filepath = f'{root}/{dataset}_{split}_gcn{suffix}_adj.pt'

    if os.path.exists(filepath):
        logger.info('Load GCN-normalized adjacency matrix...')
        return torch.load(filepath)

    logger.info('GCN-normalized adjacency matrix not found... generate it...')
    gcn_adj = gcn_norm(adj, add_self_loops=add_self_loops)
    torch.save(gcn_adj, filepath)
    logger.info('GCN-normalized adjacency matrix generated and saved...')
    return gcn_adj

def _inductive_split(data: Data, mask: torch.Tensor, edge_index: torch.Tensor,
                     edge_attr: torch.Tensor, transform) -> Data:
    """Return a shallow copy of ``data`` reduced to the node-induced subgraph selected by ``mask``."""
    split_data = copy.copy(data)
    # Keep only edges with both endpoints selected, relabelled to 0..k-1.
    split_data.edge_index, _ = subgraph(mask, edge_index, edge_attr, relabel_nodes=True)
    split_data.x = data.x[mask]
    split_data.y = data.y[mask]
    # ToSparseTensor rebuilds adj_t from the new edge_index, replacing the
    # full-graph adj_t inherited through the shallow copy.
    return transform(split_data)


def preprocess_inductive(data: Data) -> Tuple[Data, Data, Data]:
    r"""Split ``data`` into train / val / test node-induced subgraphs for inductive training.

    Each subgraph keeps only the nodes of its mask (``train_mask``, ``val_mask``,
    ``test_mask``), the edges among them (relabelled), and the matching rows of
    ``x`` and ``y``; its adjacency is stored as ``adj_t``.

    Returns:
        (train_data, val_data, test_data)

    Raises:
        AssertionError: If the three subgraph node counts do not sum to ``data.num_nodes``.
    """
    logger.info('Preprocess inductive training...')
    transform = T.Compose([T.ToSparseTensor()])
    # adj_t is the transposed adjacency (rows = targets). Transpose it back so the
    # recovered edge_index is (source, target); ToSparseTensor then reproduces the
    # original orientation instead of reversing every edge of a directed graph.
    ori_edge_index, ori_edge_attr = to_edge_index(data.adj_t.t())

    train_data = _inductive_split(data, data.train_mask, ori_edge_index, ori_edge_attr, transform)
    logger.info(f'Train Data: #Nodes: {train_data.num_nodes} | #Edges: {train_data.num_edges}')

    val_data = _inductive_split(data, data.val_mask, ori_edge_index, ori_edge_attr, transform)
    logger.info(f'Val.  Data: #Nodes: {val_data.num_nodes} | #Edges: {val_data.num_edges}')

    test_data = _inductive_split(data, data.test_mask, ori_edge_index, ori_edge_attr, transform)
    logger.info(f'Test  Data: #Nodes: {test_data.num_nodes} | #Edges: {test_data.num_edges}')

    assert data.num_nodes == (train_data.num_nodes + val_data.num_nodes + test_data.num_nodes), 'Missing some dataset...'

    return train_data, val_data, test_data

def get_planetoid(root: str, name: str) -> Tuple[Data, int, int]:
    """Load one Planetoid graph (used for 'cora', 'citeseer' and 'pubmed').

    ``NormalizeFeatures`` followed by ``ToSparseTensor`` is applied as the
    on-access transform, so the returned graph exposes ``adj_t``.

    Returns:
        (data, num_features, num_classes)
    """
    planetoid = Planetoid(f'{root}/Planetoid', name,
                          transform=T.Compose([T.NormalizeFeatures(),
                                               T.ToSparseTensor()]))
    graph = planetoid[0]
    return graph, planetoid.num_features, planetoid.num_classes


def get_wikics(root: str) -> Tuple[Data, int, int]:
    """Load WikiCS with a symmetrised ``adj_t``.

    The dataset's ``stopping_mask`` replaces ``val_mask``, and ``stopping_mask``
    itself is then cleared.

    Returns:
        (data, num_features, num_classes)
    """
    wikics = WikiCS(f'{root}/WIKICS', transform=T.ToSparseTensor())
    graph = wikics[0]
    graph.adj_t = graph.adj_t.to_symmetric()
    # Right-hand side is evaluated first, then val_mask and stopping_mask are assigned in order.
    graph.val_mask, graph.stopping_mask = graph.stopping_mask, None
    return graph, wikics.num_features, wikics.num_classes


def get_coauthor(root: str, name: str) -> Tuple[Data, int, int]:
    """Load a Coauthor graph and attach train/val/test masks built by ``gen_masks``.

    The global torch RNG is reseeded with 12345 immediately before ``gen_masks``
    (presumably so the masks are identical across calls, assuming ``gen_masks``
    draws from the global torch RNG; note this also resets that RNG for later code).

    Returns:
        (data, num_features, num_classes)
    """
    coauthor = Coauthor(f'{root}/Coauthor', name, transform=T.ToSparseTensor())
    graph = coauthor[0]
    torch.manual_seed(12345)
    split_masks = gen_masks(graph.y, 20, 30, 20)
    graph.train_mask, graph.val_mask, graph.test_mask = split_masks
    return graph, coauthor.num_features, coauthor.num_classes


def get_amazon(root: str, name: str) -> Tuple[Data, int, int]:
    """Load an Amazon co-purchase graph and attach train/val/test masks built by ``gen_masks``.

    The global torch RNG is reseeded with 12345 immediately before ``gen_masks``
    (presumably so the masks are identical across calls, assuming ``gen_masks``
    draws from the global torch RNG; note this also resets that RNG for later code).

    Returns:
        (data, num_features, num_classes)
    """
    amazon = Amazon(f'{root}/Amazon', name, transform=T.ToSparseTensor())
    graph = amazon[0]
    torch.manual_seed(12345)
    split_masks = gen_masks(graph.y, 20, 30, 20)
    graph.train_mask, graph.val_mask, graph.test_mask = split_masks
    return graph, amazon.num_features, amazon.num_classes


def get_arxiv(root: str) -> Tuple[Data, int, int]:
    """Load ogbn-arxiv with a symmetrised ``adj_t``, flat labels and boolean split masks.

    ``node_year`` is dropped, ``y`` is flattened to 1-D, and the index split from
    ``get_idx_split()`` is converted into ``train_mask`` / ``val_mask`` / ``test_mask``.

    Returns:
        (data, num_features, num_classes)
    """
    arxiv = PygNodePropPredDataset('ogbn-arxiv', f'{root}/OGB',
                                   pre_transform=T.ToSparseTensor())
    graph = arxiv[0]
    graph.adj_t = graph.adj_t.to_symmetric()
    graph.node_year = None
    graph.y = graph.y.view(-1)

    splits = arxiv.get_idx_split()
    for mask_key, split_key in (('train_mask', 'train'),
                                ('val_mask', 'valid'),
                                ('test_mask', 'test')):
        graph[mask_key] = index2mask(splits[split_key], graph.num_nodes)
    return graph, arxiv.num_features, arxiv.num_classes

def get_papers(root: str) -> Tuple[Data, int, int]:
    """Load ogbn-papers100M, dropping references to large intermediates early.

    Feature and class counts are read before the dataset object is deleted. The
    split indices are turned into boolean masks and then deleted too, after
    which ``node_year`` is dropped, ``y`` is flattened and ``adj_t`` symmetrised.

    Returns:
        (data, num_features, num_classes)
    """
    papers = PygNodePropPredDataset('ogbn-papers100M', f'{root}/OGB',
                                    pre_transform=T.ToSparseTensor())
    graph = papers[0]
    feature_count = papers.num_features
    class_count = papers.num_classes

    splits = papers.get_idx_split()
    # Release the dataset reference before allocating the masks.
    del papers
    for mask_key, split_key in (('train_mask', 'train'),
                                ('val_mask', 'valid'),
                                ('test_mask', 'test')):
        graph[mask_key] = index2mask(splits[split_key], graph.num_nodes)
    del splits

    graph.node_year = None
    graph.y = graph.y.view(-1)
    graph.adj_t = graph.adj_t.to_symmetric()
    return graph, feature_count, class_count

def get_products(root: str) -> Tuple[Data, int, int]:
    """V = 2.45M, E = 123.7M, I = 100, O = 47

    Load ogbn-products with flat labels and boolean masks built from the
    dataset's ``get_idx_split()`` indices.
    """
    products = PygNodePropPredDataset('ogbn-products', f'{root}/OGB',
                                      pre_transform=T.ToSparseTensor())
    graph = products[0]
    graph.y = graph.y.view(-1)

    splits = products.get_idx_split()
    for mask_key, split_key in (('train_mask', 'train'),
                                ('val_mask', 'valid'),
                                ('test_mask', 'test')):
        graph[mask_key] = index2mask(splits[split_key], graph.num_nodes)
    return graph, products.num_features, products.num_classes


def get_yelp(root: str) -> Tuple[Data, int, int]:
    """V = 0.716M, E = 13.9M, I = 300, O = 100

    Load Yelp and z-score standardise every feature column of ``x``.

    Returns:
        (data, num_features, num_classes)
    """
    dataset = Yelp(f'{root}/YELP', pre_transform=T.ToSparseTensor())
    data = dataset[0]
    std = data.x.std(dim=0)
    # A constant column has std == 0 and would become NaN after division;
    # dividing it by 1 instead leaves it all zeros after centring.
    std = torch.where(std > 0, std, torch.ones_like(std))
    data.x = (data.x - data.x.mean(dim=0)) / std
    return data, dataset.num_features, dataset.num_classes


def get_flickr(root: str) -> Tuple[Data, int, int]:
    """Load the Flickr graph, with ``ToSparseTensor`` applied as a pre-transform.

    Returns:
        (data, num_features, num_classes)
    """
    flickr = Flickr(f'{root}/Flickr', pre_transform=T.ToSparseTensor())
    graph = flickr[0]
    return graph, flickr.num_features, flickr.num_classes


def get_reddit(root: str) -> Tuple[Data, int, int]:
    """V = 0.233M, E = 23.2M, I = 602, O = 41

    Load Reddit2 and z-score standardise every feature column of ``x``.

    Returns:
        (data, num_features, num_classes)
    """
    dataset = Reddit2(f'{root}/Reddit2', pre_transform=T.ToSparseTensor())
    data = dataset[0]
    std = data.x.std(dim=0)
    # A constant column has std == 0 and would become NaN after division;
    # dividing it by 1 instead leaves it all zeros after centring.
    std = torch.where(std > 0, std, torch.ones_like(std))
    data.x = (data.x - data.x.mean(dim=0)) / std
    return data, dataset.num_features, dataset.num_classes


def get_ppi(root: str, split: str = 'train') -> Tuple[Data, int, int]:
    """Concatenate every graph of one PPI split into a single graph.

    The graphs are merged with ``Batch.from_data_list``. The ``batch``/``ptr``
    vectors are cleared, and a ``{split}_mask`` that selects every node is attached.

    Returns:
        (data, num_features, num_classes)
    """
    ppi = PPI(f'{root}/PPI', split=split, pre_transform=T.ToSparseTensor())
    merged = Batch.from_data_list(ppi)
    merged.batch, merged.ptr = None, None
    merged[f'{split}_mask'] = torch.ones(merged.num_nodes, dtype=torch.bool)
    return merged, ppi.num_features, ppi.num_classes


def get_sbm(root: str, name: str) -> Tuple[Data, int, int]:
    """Concatenate the training graphs of a GNNBenchmark dataset into a single graph.

    The graphs are merged with ``Batch.from_data_list``, and the ``batch``/``ptr``
    vectors are cleared.

    Returns:
        (data, num_features, num_classes)
    """
    sbm = GNNBenchmarkDataset(f'{root}/SBM', name, split='train',
                              pre_transform=T.ToSparseTensor())
    merged = Batch.from_data_list(sbm)
    merged.batch, merged.ptr = None, None
    return merged, sbm.num_features, sbm.num_classes


def get_igb(root: str, name: str, num_classes: int = 19) -> Tuple[Data, int, int]:
    """V = 1M, E = 12M, I = 1024, O = 19

    Load a pre-converted IGB graph from ``{root}/igb-pyg/{name}.pt``.

    Args:
        root: Data root directory.
        name: File stem, e.g. ``'igb-tiny'``.
        num_classes: Class count to report for the labels stored in the file.
            Defaults to 19; the label variant with 2983 classes needs
            ``num_classes=2983``.

    Returns:
        (data, num_features, num_classes)
    """
    data = torch.load(f'{root}/igb-pyg/{name}.pt')
    return data, data.num_features, num_classes


def get_livejournal(root: str) -> Tuple[Data, int, int]:
    """Load SNAP soc-LiveJournal1 as a single graph with a symmetrised ``adj_t``.

    The dataset's graphs are merged with ``Batch.from_data_list``, and the
    ``batch``/``ptr`` vectors are cleared.

    NOTE(review): ``ToUndirected`` is passed as the on-access transform, but the
    explicit ``to_symmetric()`` below is what symmetrises ``adj_t``. It is unclear
    from here whether ``ToUndirected`` has any effect once ``edge_index`` has been
    replaced by ``adj_t`` — confirm. The SNAP graph may also carry no node
    features/labels, so check what ``num_features``/``num_classes`` report.

    Returns:
        (data, num_features, num_classes)
    """
    dataset = SNAPDataset(f'{root}/SNAP', 'soc-livejournal1',
                          pre_transform=T.ToSparseTensor(), transform=T.ToUndirected())
    data = Batch.from_data_list(dataset)
    data.batch = None
    data.ptr = None
    data.adj_t = data.adj_t.to_symmetric()
    return data, dataset.num_features, dataset.num_classes


def get_kronecker_synthetic(
        root: str, name: str,
        feature_dim: int = 128,
        num_classes: int = 10,
        train_ratio: float = 0.8,
        val_ratio: float = 0.1,
        seed: Optional[int] = None,
        ) -> Tuple[Data, int, int]:
    """
    Build a node-classification dataset on top of a pre-computed Kronecker graph.

    The graph structure is loaded from ``{root}/{name}_csr.npz`` (a SciPy CSR
    adjacency, e.g. produced by ``save_edgelist_as_csr``). Random node features,
    random labels and a random train/val/test split are attached, and the
    adjacency is stored transposed in ``data.adj_t`` (PyG's convention). The
    result is cached at ``{root}/{name}.pt`` and reloaded on later calls.

    Args:
        root (str):           Directory holding ``{name}_csr.npz`` and the ``.pt`` cache.
        name (str):           One of ``kron22-10`` ... ``kron26-10``.
        feature_dim (int):    Size of the random node feature vector.
        num_classes (int):    Number of classes for the random node labels.
        train_ratio (float):  Fraction of nodes used for training.
        val_ratio (float):    Fraction of nodes used for validation; the rest is test.
        seed (int, optional): If given, features, labels and the split are drawn
                              from a dedicated generator seeded with it. Otherwise
                              the global torch RNG is used, as before.

    Returns:
        (data, num_features, num_classes), where ``data`` has
            - x: [num_nodes, feature_dim]
            - y: [num_nodes]
            - train_mask / val_mask / test_mask: [num_nodes] (bool)
            - adj_t: torch_sparse SparseTensor
        On a cache hit, ``num_features`` is read from the cached ``x``.

    Raises:
        AssertionError: If ``name`` is not a supported Kronecker graph.
        ValueError: If the split ratios are negative or sum to more than 1.
    """
    valid_names = ('kron22-10', 'kron23-10', 'kron24-10', 'kron25-10', 'kron26-10')
    assert name in valid_names, 'Invalid synthetic graph name...'

    cache_path = f'{root}/{name}.pt'
    csr_path = f'{root}/{name}_csr.npz'

    # if it is already saved, load it
    if os.path.exists(cache_path):
        logger.info(f'Load synthetic data: {cache_path}')
        data = torch.load(cache_path)
        # Report the feature size actually stored; the cache may have been built
        # with a different feature_dim than the one passed now.
        return data, data.x.size(-1), num_classes

    if train_ratio < 0 or val_ratio < 0 or train_ratio + val_ratio > 1:
        raise ValueError(
            f'Invalid split ratios: train_ratio={train_ratio}, val_ratio={val_ratio}')

    # generator=None falls back to the global torch RNG.
    generator = None if seed is None else torch.Generator().manual_seed(seed)

    # 1) Graph structure from the stored CSR adjacency.
    adjacency_coo = sp.load_npz(csr_path).tocoo()
    num_nodes = adjacency_coo.shape[0]
    row = torch.from_numpy(adjacency_coo.row).long()
    col = torch.from_numpy(adjacency_coo.col).long()

    # 2) Random normal features and uniform random labels in [0, num_classes).
    x = torch.randn(num_nodes, feature_dim, generator=generator)
    y = torch.randint(0, num_classes, (num_nodes,), generator=generator)

    # 3) Random train/val/test split by train_ratio / val_ratio / remainder.
    perm = torch.randperm(num_nodes, generator=generator)
    train_end = int(train_ratio * num_nodes)
    val_end = train_end + int(val_ratio * num_nodes)

    train_mask = torch.zeros(num_nodes, dtype=torch.bool)
    val_mask = torch.zeros(num_nodes, dtype=torch.bool)
    test_mask = torch.zeros(num_nodes, dtype=torch.bool)
    train_mask[perm[:train_end]] = True
    val_mask[perm[train_end:val_end]] = True
    test_mask[perm[val_end:]] = True

    # 4) adj_t is the transposed adjacency: entry (dst, src) for every edge
    #    src -> dst, hence (col, row). For a symmetric CSR this equals the
    #    untransposed matrix.
    adj_t = SparseTensor.from_edge_index(
        torch.stack([col, row], dim=0),
        sparse_sizes=(num_nodes, num_nodes),
    ).coalesce()

    data = Data(
        x=x,
        y=y,
        train_mask=train_mask,
        val_mask=val_mask,
        test_mask=test_mask,
    )
    data.adj_t = adj_t

    torch.save(data, cache_path)
    logger.info(f'Save synthetic data: {cache_path}')

    return data, feature_dim, num_classes


def get_data(root: str, name: str) -> Tuple[Data, int, int]:
    """Dispatch a (case-insensitive) dataset name to its loader.

    Returns:
        (data, num_features, num_classes) from the selected loader.

    Raises:
        NotImplementedError: If ``name`` matches no known dataset.
    """
    key = name.lower()
    if key in ('cora', 'citeseer', 'pubmed'):
        return get_planetoid(root, name)
    if key in ('coauthorcs', 'coauthorphysics'):
        return get_coauthor(root, name[8:])
    if key in ('amazoncomputers', 'amazonphoto'):
        return get_amazon(root, name[6:])
    if key == 'wikics':
        return get_wikics(root)
    if key in ('cluster', 'pattern'):
        # GNNBenchmarkDataset only accepts the upper-case names 'CLUSTER' / 'PATTERN'.
        return get_sbm(root, name.upper())
    if key == 'reddit':
        return get_reddit(root)
    if key == 'ppi':
        return get_ppi(root)
    if key == 'flickr':
        return get_flickr(root)
    if key == 'yelp':
        return get_yelp(root)
    if key in ('ogbn-arxiv', 'arxiv'):
        return get_arxiv(root)
    # Compare against the lower-cased spelling; 'ogbn-papers100M' could never match `key`.
    if key in ('ogbn-papers100m', 'papers'):
        return get_papers(root)
    if key in ('ogbn-products', 'products'):
        return get_products(root)
    if key in ('igb-tiny', 'igb-small', 'igb-medium', 'igb-large', 'igb-full'):
        return get_igb(root, key)
    if key == 'livejournal':
        return get_livejournal(root)
    if 'kron' in key:
        # The synthetic loader validates against lower-case names.
        return get_kronecker_synthetic(root, key)
    raise NotImplementedError(f'Unknown dataset: {name}')


def get_preprocessed_data(root: str, name:str) -> Tuple[Data, int, int]:
    """Return ``(data, num_features, num_classes)`` from the on-disk cache, building it on a miss.

    The cache has two files: ``{root}/{name}.pt`` (the ``Data`` object, written
    with ``torch.save``) and ``{root}/{name}.pickle`` (a dict holding
    ``num_features`` and ``num_classes``). Both files must exist for a cache hit.
    Otherwise the data is rebuilt with ``get_data`` and both files are rewritten.
    """
    data_path = f'{root}/{name}.pt'
    meta_path = f'{root}/{name}.pickle'

    # Require both files. A lone .pt (e.g. written by another loader at the same
    # path, or left by an interrupted save) previously crashed on the pickle open.
    if os.path.exists(data_path) and os.path.exists(meta_path):
        logger.info(f'Load preprocessed data: {data_path}')
        data = torch.load(data_path)
        # NOTE: pickle can execute arbitrary code; only load cache files you created.
        with open(meta_path, 'rb') as f:
            meta_data = pickle.load(f)
        return data, meta_data['num_features'], meta_data['num_classes']

    logger.info('Tried to load preprocessed data, but not exist...')
    data, num_features, num_classes = get_data(root, name)
    # Write the .pt first: the .pickle existing then implies the .pt save completed.
    torch.save(data, data_path)
    meta_data = {'num_features': num_features, 'num_classes': num_classes}
    with open(meta_path, 'wb') as f:
        pickle.dump(meta_data, f)
    logger.info(f'Preprocessed data saved: {data_path}')
    return data, num_features, num_classes

if __name__ == "__main__":
    # Build (or reload from its .pt cache) the Kronecker synthetic dataset for
    # each scale listed below. Its CSR input {root}/kron<scale>-10_csr.npz can be
    # produced from the raw edge list with
    #     save_edgelist_as_csr(f'{root}/kron<scale>-10.txt', f'{root}/kron<scale>-10_csr.npz')
    # and the raw edge list can be read as a plain PyG graph with
    #     load_edgelist_as_pyg_data(f'{root}/kron<scale>-10.txt')
    data_root = '/small_data/grinnder'
    for scale in (22,):
        get_kronecker_synthetic(data_root, f'kron{scale}-10')