from utils.config import cfg
from ogb.graphproppred import PygGraphPropPredDataset
from torch_geometric.datasets import TUDataset
from torch_geometric.datasets import LRGBDataset
from torch.utils.data import random_split
from rewiring.expander import ExpanderTransform
from rewiring.fa import FullyAdjacentTransform

# Dataset names that are loaded through torch_geometric's TUDataset and stored
# under the 'tu' sub-folder of the dataset directory.
tu_datasets = ['MUTAG', 'ENZYMES', 'PROTEINS', 'COLLAB', 'IMDB-BINARY', 'REDDIT-BINARY']
# Dataset names that are loaded through torch_geometric's LRGBDataset and
# stored under the 'lrgb' sub-folder of the dataset directory.
lrgb_datasets = ['Peptides-func']

def load_datasets():
    """Load the dataset selected by ``cfg.dataset`` and return its splits.

    A rewiring transform is built first (``FullyAdjacentTransform`` when
    ``cfg.use_fa`` is truthy, ``ExpanderTransform`` otherwise) and handed to
    the format-specific loader, which passes it on as ``pre_transform``.

    Returns:
        Whatever the selected loader returns: a 4-tuple
        ``(train_dataset, validation_dataset, test_dataset, dataset)``.
        The last element is ``None`` for LRGB datasets.

    Raises:
        ValueError: If ``cfg.dataset.format`` is neither ``'OGB'`` nor
            ``'PyG'``, or if the format is ``'PyG'`` but
            ``cfg.dataset.name`` is listed in neither ``tu_datasets`` nor
            ``lrgb_datasets``.
    """
    transform = FullyAdjacentTransform() if cfg.use_fa else ExpanderTransform()

    if cfg.dataset.format == 'OGB':
        return load_ogb_dataset(transform)

    if cfg.dataset.format == 'PyG':
        if cfg.dataset.name in tu_datasets:
            return load_tu_dataset(transform)
        if cfg.dataset.name in lrgb_datasets:
            return load_lrgb_dataset(transform)
        # Fail loudly here instead of implicitly returning None, which the
        # caller would only notice later when unpacking the result.
        raise ValueError(
            f"Dataset does not exist: unsupported PyG dataset '{cfg.dataset.name}'"
        )

    raise ValueError('Dataset does not exist')

def load_ogb_dataset(pre_transform=None):
    """Load an OGB graph-property-prediction dataset and its official splits.

    The dataset name comes from ``cfg.dataset.name`` and the storage root from
    ``make_dir_root()``. ``cfg.gnn.output_dim`` is filled from the dataset's
    ``num_tasks`` when it has not been set yet.

    Args:
        pre_transform: Optional transform forwarded to
            ``PygGraphPropPredDataset`` as ``pre_transform``.

    Returns:
        ``(train_dataset, validation_dataset, test_dataset, dataset)`` where
        the first three are the full dataset indexed by the ``"train"``,
        ``"valid"`` and ``"test"`` entries of ``get_idx_split()``.
    """
    full_dataset = PygGraphPropPredDataset(
        name=cfg.dataset.name,
        root=make_dir_root(),
        pre_transform=pre_transform,
    )
    splits = full_dataset.get_idx_split()

    set_output_dim_if_required(full_dataset.num_tasks)

    # Index the full dataset once per split, in train / valid / test order.
    subsets = [full_dataset[splits[key]] for key in ("train", "valid", "test")]
    return (*subsets, full_dataset)

def load_tu_dataset(pre_transform=None):
    """Load a TUDataset and split it randomly into train/validation/test.

    80% of the graphs (rounded down) go to training, 10% (rounded down) to
    validation, and the remainder to test. ``cfg.gnn.input_dim`` and
    ``cfg.gnn.output_dim`` are filled from the dataset's ``num_features`` and
    ``num_classes`` when they have not been set yet.

    Args:
        pre_transform: Optional transform forwarded to ``TUDataset`` as
            ``pre_transform``.

    Returns:
        ``(train_dataset, validation_dataset, test_dataset, dataset)``.
    """
    full_dataset = TUDataset(
        name=cfg.dataset.name,
        root=make_dir_root(),
        pre_transform=pre_transform,
    )

    set_input_dim_if_required(full_dataset.num_features)
    set_output_dim_if_required(full_dataset.num_classes)

    total = len(full_dataset)
    # Train and validation sizes are truncated; test absorbs the remainder so
    # the three sizes always sum to the dataset size.
    lengths = [int(fraction * total) for fraction in (0.8, 0.1)]
    lengths.append(total - lengths[0] - lengths[1])

    # NOTE(review): no generator is passed to random_split, so the partition
    # presumably depends on the global torch RNG state - confirm seeding.
    train_dataset, validation_dataset, test_dataset = random_split(full_dataset, lengths)

    return train_dataset, validation_dataset, test_dataset, full_dataset

def load_lrgb_dataset(pre_transform=None):
    """Load the train/val/test splits of an LRGB dataset.

    Each split is constructed as its own ``LRGBDataset``. ``cfg.gnn.output_dim``
    is filled from the training split's ``num_classes`` when it has not been
    set yet.

    Args:
        pre_transform: Optional transform forwarded to every ``LRGBDataset``
            as ``pre_transform``.

    Returns:
        ``(train_dataset, validation_dataset, test_dataset, None)``; the last
        slot is ``None`` because no combined dataset object is built here.
    """
    def _load_split(split_name):
        # One dataset object per split, each sharing the same root and name.
        return LRGBDataset(
            root=make_dir_root(),
            name=cfg.dataset.name,
            split=split_name,
            pre_transform=pre_transform,
        )

    train_dataset = _load_split("train")
    validation_dataset = _load_split("val")
    test_dataset = _load_split("test")

    set_output_dim_if_required(train_dataset.num_classes)

    return train_dataset, validation_dataset, test_dataset, None

def set_input_dim_if_required(input_dim):
    """Set ``cfg.gnn.input_dim`` to *input_dim* unless it already has a value.

    Args:
        input_dim: Value to store when ``cfg.gnn.input_dim`` is ``None``.
    """
    if cfg.gnn.input_dim is not None:
        # An existing value takes precedence over the supplied one.
        return
    cfg.gnn.input_dim = input_dim

def set_output_dim_if_required(output_dim):
    """Set ``cfg.gnn.output_dim`` to *output_dim* unless it already has a value.

    Args:
        output_dim: Value to store when ``cfg.gnn.output_dim`` is ``None``.
    """
    if cfg.gnn.output_dim is not None:
        # An existing value takes precedence over the supplied one.
        return
    cfg.gnn.output_dim = output_dim

def make_dir_root() -> str:
    """Build the directory path used as ``root`` for the configured dataset.

    The path has the form
    ``{cfg.dataset.dir}/{family}/{rewiring}{-fa}{-has-sl}`` where ``family``
    is ``'ogb'``, ``'tu'`` or ``'lrgb'`` and ``rewiring`` encodes the expander
    settings (or ``'base'`` when ``cfg.expander.type`` is ``None``), so that
    different rewiring configurations map to different directories.

    Returns:
        The dataset root directory as a string.

    Raises:
        ValueError: If the format is not ``'OGB'`` and ``cfg.dataset.name`` is
            listed in neither ``tu_datasets`` nor ``lrgb_datasets``.
    """
    if cfg.dataset.format == 'OGB':
        folder_name = 'ogb'
    elif cfg.dataset.name in tu_datasets:
        folder_name = 'tu'
    elif cfg.dataset.name in lrgb_datasets:
        folder_name = 'lrgb'
    else:
        raise ValueError(
            f"Cannot determine dataset directory: unknown dataset "
            f"'{cfg.dataset.name}' (format '{cfg.dataset.format}')"
        )

    expander_sl = '-has-sl' if cfg.expander.input_edge_index_has_self_loops else ''
    # NOTE(review): the two labels disagree ('node' vs 'edge' embeddings); the
    # strings are kept as-is so existing on-disk directories keep their names.
    expander_fixed = 'fixed-node_embeddings' if cfg.expander.zero_edge_embeddings else 'non-fixed-edge_embeddings'
    expander_shuffle_nodes = '-shuffle' if cfg.expander.shuffle_nodes else ''

    if cfg.expander.type is None:
        expander_folder_name = 'base'
    else:
        expander_folder_name = f'{cfg.expander.type.lower()}-{cfg.expander.variant}-{expander_fixed}{expander_sl}{expander_shuffle_nodes}'

    fa_name = '-fa' if cfg.use_fa else ''

    # NOTE(review): expander_sl is appended here as well as inside
    # expander_folder_name, so '-has-sl' appears twice when an expander type
    # is set. Left unchanged to keep existing directory names stable - confirm
    # whether the duplication is intended.
    return f'{cfg.dataset.dir}/{folder_name}/{expander_folder_name}{fa_name}{expander_sl}'
