import numpy as np
import torch
from torch_geometric.datasets import Actor, Planetoid, WebKB, WikipediaNetwork


def read_dataset(dataset_name, device):
    """Load a named graph dataset and move its first graph to ``device``.

    ``dataset_name`` selects one of the entries in the table below; each entry
    pairs a ``torch_geometric.datasets`` class with the keyword arguments used
    to instantiate it (every dataset is rooted under ``/tmp/<Name>``).

    Args:
        dataset_name: One of ``'Wisconsin'``, ``'Texas'``, ``'Cornell'``,
            ``'Chameleon'``, ``'Squirrel'``, ``'Film'``, ``'Cora'``,
            ``'CiteSeer'`` or ``'PubMed'``.
        device: Target passed to ``.to(...)`` on the first dataset element.

    Returns:
        A tuple ``(data, num_features, num_classes)`` where ``data`` is
        ``dataset[0]`` moved to ``device``, ``num_features`` is
        ``dataset.num_features`` and ``num_classes`` is the largest value in
        ``data.y`` plus one.

    Raises:
        KeyError: If ``dataset_name`` is not one of the supported names.
    """
    registry = {
        'Wisconsin': (WebKB, dict(root='/tmp/Wisconsin', name='Wisconsin')),
        'Texas': (WebKB, dict(root='/tmp/Texas', name='Texas')),
        'Cornell': (WebKB, dict(root='/tmp/Cornell', name='Cornell')),
        'Chameleon': (WikipediaNetwork, dict(root='/tmp/Chameleon', name='chameleon')),
        'Squirrel': (WikipediaNetwork, dict(root='/tmp/Squirrel', name='squirrel')),
        'Film': (Actor, dict(root='/tmp/Film')),
        'Cora': (Planetoid, dict(root='/tmp/Cora', name='Cora')),
        'CiteSeer': (Planetoid, dict(root='/tmp/CiteSeer', name='CiteSeer')),
        'PubMed': (Planetoid, dict(root='/tmp/PubMed', name='PubMed')),
    }

    loader, loader_kwargs = registry[dataset_name]
    dataset = loader(**loader_kwargs)

    # Only the first element of the dataset is used.
    data = dataset[0].to(device)
    num_features = dataset.num_features

    # Class count is derived from the labels: largest label value + 1.
    highest_label = data.y.max().item()
    num_classes = int(highest_label) + 1

    return data, num_features, num_classes


def split_setter(dataset_name,data,split_id,device):
    """Attach a pre-computed train/val/test split to ``data``.

    The masks are read from
    ``splits/<dataset_name lower-cased>_split_0.6_0.2_<split_id>.npz``
    (a path relative to the current working directory) and stored on ``data``
    as the attributes ``train_mask``, ``val_mask`` and ``test_mask``, each a
    tensor moved to ``device``.

    Args:
        dataset_name: Dataset name; it is lower-cased to build the file name.
        data: Object that receives the three mask attributes (mutated in
            place).
        split_id: Index of the split to load, an integer in ``[0, 9]``.
            Python ints and NumPy integer scalars are accepted; booleans are
            rejected.
        device: Target passed to ``Tensor.to``.

    Returns:
        The same ``data`` object, with the three masks set.

    Raises:
        ValueError: If ``split_id`` is not an integer in ``[0, 9]``.
    """
    # Check the type before the range: comparing e.g. a str or None against
    # an int raises TypeError, which would mask the intended ValueError.
    # bool is a subclass of int but would format as "True"/"False" in the
    # file name, so it is rejected explicitly.
    is_integer = (
        isinstance(split_id, (int, np.integer))
        and not isinstance(split_id, bool)
    )
    if not is_integer or not 0 <= split_id <= 9:
        raise ValueError("split_id should be integer and in range [0,9]")

    split_path = f'splits/{dataset_name.lower()}_split_0.6_0.2_{split_id}.npz'

    # np.load returns an NpzFile holding an open file handle for .npz
    # archives; the context manager makes sure it is closed.
    with np.load(split_path) as split:
        for mask_name in ('train_mask', 'val_mask', 'test_mask'):
            mask = torch.tensor(split[mask_name]).to(device)
            setattr(data, mask_name, mask)

    return data


