import os
import json
from typing import Any
import networkx as nx
#import gdown

import pandas as pd
import numpy as np
from numpy.linalg import eigh
from scipy.sparse.linalg import eigsh

import scipy.sparse as sp

import torch
from torch_geometric.data import Data
from torch_geometric.datasets import Planetoid, Actor, WikipediaNetwork
import torch_geometric.transforms as T
from torch_geometric.utils import convert, to_undirected, add_self_loops
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from torch_geometric.transforms import NormalizeFeatures
from torch_geometric.data import Data
from utils import normalize_adj
from torch_geometric.utils import add_self_loops
from torch_geometric.utils import to_undirected

def get_dataset(name, normalize_features=False, transform=None):
    """Load a Planetoid citation dataset and return its graph.

    Args:
        name: Dataset name. Must be exactly "cora", "citeseer" or "pubmed";
            the check is case-sensitive.
        normalize_features: When truthy, ``T.NormalizeFeatures()`` is applied.
        transform: Optional extra PyG transform. If ``normalize_features`` is
            also set, this transform runs after feature normalization.

    Returns:
        ``dataset[0]``, with ``num_classes`` copied from the dataset and
        ``num_nodes`` set to the number of rows of ``x``.

    Raises:
        NotImplementedError: If ``name`` is not one of the supported datasets.
    """
    if name not in ("cora", "citeseer", "pubmed"):
        raise NotImplementedError

    dataset = Planetoid(f"./data/{name}", name)

    # Collect the requested transforms in application order. They are attached
    # before indexing because PyG applies ``Dataset.transform`` on access, so
    # ``dataset[0]`` comes back already transformed.
    pipeline = []
    if normalize_features:
        pipeline.append(T.NormalizeFeatures())
    if transform is not None:
        pipeline.append(transform)

    if len(pipeline) > 1:
        dataset.transform = T.Compose(pipeline)
    elif pipeline:
        dataset.transform = pipeline[0]
    # With no transform requested, ``dataset.transform`` is left untouched.

    graph = dataset[0]
    graph.num_classes = dataset.num_classes
    graph.num_nodes = graph.x.shape[0]
    return graph


def get_dataset2(name):
    """Load ``./data/{name}/{name}/{name}.pt`` and apply per-dataset preprocessing.

    The preprocessing depends on ``name``:

    - "arxiv_topic", "arxiv_year", "arxiv_topic_s", "arxiv_year_s", "arxiv_s":
      node features ``x`` are standardized with a ``StandardScaler`` fitted on
      all rows and stored as a float32 tensor. Self-loops are added to
      ``edge_index`` and the result is made undirected.
    - "hm_regre", "hm_regre_s": ``y_mean``, ``y_std`` and the z-scored target
      ``y_regre_std`` are attached. ``y`` itself is left unchanged.
    - Any other name (including "hm_class" / "hm_class_s"): the object is
      returned exactly as loaded.

    Args:
        name: Dataset name, also used to build the file path.

    Returns:
        The loaded object, modified in place as described above. It is
        presumably a PyG ``Data`` with ``x`` / ``edge_index`` / ``y``
        attributes (TODO confirm against the saved files).
    """
    # NOTE(review): torch.load unpickles the file, so only load trusted data.
    # PyTorch >= 2.6 defaults to weights_only=True, which may refuse to load
    # non-tensor objects such as a PyG Data. Confirm for the torch version in use.
    dataset = torch.load(f"./data/{name}/{name}/{name}.pt")

    arxiv_names = {
        "arxiv_topic",
        "arxiv_year",
        "arxiv_topic_s",
        "arxiv_year_s",
        "arxiv_s",
    }
    regression_names = {"hm_regre", "hm_regre_s"}

    if name in arxiv_names:
        features = dataset.x
        scaler = StandardScaler().fit(features)
        dataset.x = torch.tensor(scaler.transform(features), dtype=torch.float32)

        looped_edges, _ = add_self_loops(dataset.edge_index)
        dataset.edge_index = to_undirected(looped_edges)
    elif name in regression_names:
        # Z-score the regression target. The statistics stay on the object,
        # presumably so predictions can be mapped back to the original scale.
        target = dataset.y
        dataset.y_mean = torch.mean(target)
        dataset.y_std = torch.std(target)
        dataset.y_regre_std = (target - dataset.y_mean) / dataset.y_std

    return dataset




def get_dataset3(name):
    """Return whatever object is stored in ``./data/{name}.pt``.

    No preprocessing is applied; the result of ``torch.load`` is returned
    as-is.

    NOTE(review): torch.load unpickles the file, so only load trusted data.
    """
    return torch.load(f"./data/{name}.pt")

def get_dataset_condensed(method, name, rate):
    """Load a condensed graph produced by ``method`` at a given ``rate``.

    The file is read from
    ``./condensed_data/{name}/{method}_{name}_{rate}.pt``, where each argument
    is formatted with ``str()`` (so ``rate=0.5`` gives ``..._0.5.pt``). The
    object is returned exactly as stored; no feature normalization or scaling
    is applied here.

    NOTE(review): torch.load unpickles the file, so only load trusted data.
    """
    file_name = f"{method}_{name}_{rate}.pt"
    return torch.load(f"./condensed_data/{name}/{file_name}")