from datasets.saturation import Saturation
from torch_geometric.utils import to_dense_adj
from torch import Tensor
from datasets.infection import Infection
from torch_geometric.datasets import TUDataset
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.model_selection import ShuffleSplit, StratifiedKFold
from sklearn.utils import compute_class_weight
from torch_geometric.loader import DataLoader
from tqdm import trange
from torch.nn import Sequential, Linear, BatchNorm1d, ReLU, ModuleList
from torch_geometric.datasets import Planetoid
from dig.xgraph.dataset import SynGraphDataset, MoleculeDataset
from ogb.graphproppred import Evaluator
from ogb.graphproppred import PygGraphPropPredDataset



from torch_geometric.nn import GINConv, global_add_pool
import argparse

import os

# Module-wide default device: the current CUDA device when CUDA is available, else the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


### MLP with linear output
class MLP(nn.Module):
    """Multi-layer perceptron whose last layer is a plain linear projection.

    Args:
        num_layers: number of linear layers, not counting the input layer. With
            ``num_layers == 1`` the module is a single ``nn.Linear`` (no hidden
            layer, no batch norm, no non-linearity).
        input_dim: size of the input feature vector.
        hidden_dim: width shared by every hidden layer.
        output_dim: size of the output (number of classes for prediction).

    Raises:
        ValueError: if ``num_layers`` is smaller than 1.
    """

    def __init__(self, num_layers, input_dim, hidden_dim, output_dim):
        super().__init__()

        self.num_layers = num_layers
        if num_layers < 1:
            raise ValueError("number of layers should be positive!")

        # True -> a single linear map is used in forward().
        self.linear_or_not = num_layers == 1
        if self.linear_or_not:
            self.linear = nn.Linear(input_dim, output_dim)
            return

        # Layer widths: input -> hidden (x num_layers - 1) -> output.
        widths = [input_dim] + [hidden_dim] * (num_layers - 1) + [output_dim]
        self.linears = torch.nn.ModuleList(
            nn.Linear(d_in, d_out) for d_in, d_out in zip(widths[:-1], widths[1:])
        )
        # A batch norm follows every linear layer except the final one.
        self.batch_norms = torch.nn.ModuleList(
            nn.BatchNorm1d(hidden_dim) for _ in range(num_layers - 1)
        )

    def forward(self, x):
        """Apply the MLP to ``x`` and return the un-activated output of the final layer."""
        if self.linear_or_not:
            return self.linear(x)

        *hidden_layers, output_layer = self.linears
        h = x
        for linear, norm in zip(hidden_layers, self.batch_norms):
            h = F.relu(norm(linear(h)))
        return output_layer(h)


class GraphCNN(nn.Module):
    def __init__(self, num_layers, num_mlp_layers, input_dim, hidden_dim, output_dim, final_dropout, learn_eps,
                 graph_pooling_type, neighbor_pooling_type, device, use_pooling=None):
        '''
            GIN model: ``num_layers`` GINConv blocks, each followed by BatchNorm and ReLU. The raw
            input and the output of every block each go through their own linear read-out
            (``self.fcs``); the dropped-out read-outs are summed and passed through log_softmax.

            num_layers: number of GINConv blocks run in ``forward`` (must be >= 1)
            num_mlp_layers: not used by this implementation; every GINConv uses a fixed
                Linear-BatchNorm-ReLU-Linear MLP. Kept for signature compatibility.
            input_dim: dimensionality of input features
            hidden_dim: dimensionality of hidden units at ALL layers
            output_dim: number of classes for prediction
            final_dropout: dropout ratio applied to every per-layer linear read-out
            learn_eps: only read by the legacy neighbour-preprocessing helpers below
            graph_pooling_type: "average" -> mean graph read-out, anything else -> sum read-out
                (only relevant when graph-level pooling is active, see ``use_pooling``)
            neighbor_pooling_type: only read by the legacy ``next_layer*`` helpers below
            device: device the legacy helpers move the tensors they build to
            use_pooling: True -> pool node embeddings into one vector per graph before each
                read-out; False -> always return one prediction per node; None (default) ->
                pool only when ``data.batch`` is set and ``data.y`` holds one label per graph
        '''

        super(GraphCNN, self).__init__()

        self.final_dropout = final_dropout
        self.device = device
        self.num_layers = num_layers
        self.graph_pooling_type = graph_pooling_type
        self.neighbor_pooling_type = neighbor_pooling_type
        self.learn_eps = learn_eps
        self.use_pooling = use_pooling
        # Only the legacy next_layer_eps helper reads eps; forward() does not use it.
        self.eps = nn.Parameter(torch.zeros(self.num_layers - 1))

        self.convs = nn.ModuleList()
        self.bns = nn.ModuleList()
        self.fcs = nn.ModuleList()

        # First GIN block maps input_dim -> hidden_dim.
        self.convs.append(
            GINConv(nn.Sequential(nn.Linear(input_dim, hidden_dim), nn.BatchNorm1d(hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, hidden_dim))))
        self.bns.append(nn.BatchNorm1d(hidden_dim))
        # One read-out per representation: fcs[0] for the raw input, fcs[i + 1] for block i.
        self.fcs.append(nn.Linear(input_dim, output_dim))
        self.fcs.append(nn.Linear(hidden_dim, output_dim))

        for i in range(self.num_layers - 1):
            self.convs.append(
                GINConv(nn.Sequential(nn.Linear(hidden_dim, hidden_dim), nn.BatchNorm1d(hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, hidden_dim))))
            self.bns.append(nn.BatchNorm1d(hidden_dim))
            self.fcs.append(nn.Linear(hidden_dim, output_dim))

    def __preprocess_neighbors_maxpool(self, batch_graph):
        ###create padded_neighbor_list in concatenated graph
        # Legacy helper (not called by forward); expects graph objects exposing
        # ``x``, ``neighbors`` and ``max_neighbor``.

        # compute the maximum number of neighbors within the graphs in the current minibatch
        max_deg = max([graph.max_neighbor for graph in batch_graph])

        padded_neighbor_list = []
        start_idx = [0]

        for i, graph in enumerate(batch_graph):
            start_idx.append(start_idx[i] + len(graph.x))
            padded_neighbors = []
            for j in range(len(graph.neighbors)):
                # add off-set values to the neighbor indices
                pad = [n + start_idx[i] for n in graph.neighbors[j]]
                # padding, dummy data is assumed to be stored in -1
                pad.extend([-1] * (max_deg - len(pad)))

                # Add center nodes in the maxpooling if learn_eps is False, i.e., aggregate center nodes and neighbor nodes altogether.
                if not self.learn_eps:
                    pad.append(j + start_idx[i])

                padded_neighbors.append(pad)
            padded_neighbor_list.extend(padded_neighbors)

        return torch.LongTensor(padded_neighbor_list)

    def __preprocess_neighbors_sumavepool(self, batch_graph):
        ###create block diagonal sparse matrix
        # Legacy helper (not called by forward); expects graph objects whose ``edge_mat`` is a
        # 2 x E index tensor (it is offset and concatenated along dim 1 below).

        edge_mat_list = []
        start_idx = [0]
        for i, graph in enumerate(batch_graph):
            start_idx.append(start_idx[i] + len(graph.x))
            edge_mat_list.append(graph.edge_mat + start_idx[i])
        Adj_block_idx = torch.cat(edge_mat_list, 1)
        Adj_block_elem = torch.ones(Adj_block_idx.shape[1])

        # Add self-loops in the adjacency matrix if learn_eps is False, i.e., aggregate center nodes and neighbor nodes altogether.

        if not self.learn_eps:
            num_node = start_idx[-1]
            self_loop_edge = torch.LongTensor([range(num_node), range(num_node)])
            elem = torch.ones(num_node)
            Adj_block_idx = torch.cat([Adj_block_idx, self_loop_edge], 1)
            Adj_block_elem = torch.cat([Adj_block_elem, elem], 0)

        # torch.sparse_coo_tensor replaces the deprecated torch.sparse.FloatTensor constructor.
        Adj_block = torch.sparse_coo_tensor(Adj_block_idx, Adj_block_elem, torch.Size([start_idx[-1], start_idx[-1]]))

        return Adj_block.to(self.device)

    def __preprocess_graphpool(self, batch_graph):
        ###create sum or average pooling sparse matrix over entire nodes in each graph (num graphs x num nodes)
        # Legacy helper (not called by forward).

        start_idx = [0]

        # offsets of each graph's first node in the concatenated node list
        for i, graph in enumerate(batch_graph):
            start_idx.append(start_idx[i] + len(graph.x))

        idx = []
        elem = []
        for i, graph in enumerate(batch_graph):
            ###average pooling
            if self.graph_pooling_type == "average":
                elem.extend([1. / len(graph.x)] * len(graph.x))

            else:
                ###sum pooling
                elem.extend([1] * len(graph.x))

            idx.extend([[i, j] for j in range(start_idx[i], start_idx[i + 1], 1)])
        elem = torch.FloatTensor(elem)
        idx = torch.LongTensor(idx).transpose(0, 1)
        graph_pool = torch.sparse_coo_tensor(idx, elem, torch.Size([len(batch_graph), start_idx[-1]]))

        return graph_pool.to(self.device)

    def maxpool(self, h, padded_neighbor_list):
        ###Element-wise minimum will never affect max-pooling

        dummy = torch.min(h, dim=0)[0]
        h_with_dummy = torch.cat([h, dummy.reshape((1, -1)).to(self.device)])
        pooled_rep = torch.max(h_with_dummy[padded_neighbor_list], dim=1)[0]
        return pooled_rep

    def next_layer_eps(self, h, layer, padded_neighbor_list=None, Adj_block=None):
        ###pooling neighboring nodes and center nodes separately by epsilon reweighting.
        # NOTE(review): references self.mlps and self.batch_norms, which this class never
        # defines, so calling it raises AttributeError. Not used by forward().

        if self.neighbor_pooling_type == "max":
            ##If max pooling
            pooled = self.maxpool(h, padded_neighbor_list)
        else:
            # If sum or average pooling
            pooled = torch.spmm(Adj_block, h)
            if self.neighbor_pooling_type == "average":
                # If average pooling
                degree = torch.spmm(Adj_block, torch.ones((Adj_block.shape[0], 1)).to(self.device))
                pooled = pooled / degree

        # Reweights the center node representation when aggregating it with its neighbors
        pooled = pooled + (1 + self.eps[layer]) * h
        pooled_rep = self.mlps[layer](pooled)
        h = self.batch_norms[layer](pooled_rep)

        # non-linearity
        h = F.relu(h)
        return h

    def next_layer(self, h, layer, padded_neighbor_list=None, Adj_block=None):
        ###pooling neighboring nodes and center nodes altogether
        # NOTE(review): references self.mlps and self.batch_norms, which this class never
        # defines, so calling it raises AttributeError. Not used by forward().

        if self.neighbor_pooling_type == "max":
            ##If max pooling
            pooled = self.maxpool(h, padded_neighbor_list)
        else:
            # If sum or average pooling
            pooled = torch.spmm(Adj_block, h)
            if self.neighbor_pooling_type == "average":
                # If average pooling
                degree = torch.spmm(Adj_block, torch.ones((Adj_block.shape[0], 1)).to(self.device))
                pooled = pooled / degree

        # representation of neighboring and center nodes
        pooled_rep = self.mlps[layer](pooled)

        h = self.batch_norms[layer](pooled_rep)

        # non-linearity
        h = F.relu(h)
        return h

    def _graph_readout_enabled(self, data, batch):
        """Decide whether ``forward`` pools node embeddings into one row per graph."""
        # getattr: modules pickled before ``use_pooling`` existed do not carry the attribute.
        flag = getattr(self, 'use_pooling', None)
        if flag is not None:
            return bool(flag)
        y = getattr(data, 'y', None)
        if batch is None or y is None or y.dim() == 0 or batch.numel() == 0:
            return False
        num_graphs = int(batch.max()) + 1
        # Graph-level targets carry exactly one label per graph; node-level targets carry one per
        # node. The two only coincide when every graph has a single node, where pooling is a no-op.
        return y.size(0) == num_graphs

    def _graph_readout(self, h, batch):
        """Pool node embeddings ``h`` per graph: mean for "average", sum otherwise."""
        if self.graph_pooling_type == "average":
            from torch_geometric.nn import global_mean_pool
            return global_mean_pool(h, batch)
        return global_add_pool(h, batch)

    def forward(self, data):
        """Return log-probabilities per node, or per graph when graph-level pooling is active.

        ``data`` must provide ``x`` and ``edge_index``; ``batch`` and ``y`` are only read to
        decide on pooling (see ``use_pooling`` in ``__init__``).
        """
        x = data.x
        edge_index = data.edge_index
        batch = getattr(data, 'batch', None)
        pool = self._graph_readout_enabled(data, batch)

        outs = [x]
        for i in range(self.num_layers):
            x = self.convs[i](x, edge_index)
            x = self.bns[i](x)
            x = F.relu(x)
            outs.append(x)

        out = None
        for i, h in enumerate(outs):
            if pool:
                h = self._graph_readout(h, batch)
            h = F.dropout(self.fcs[i](h), p=self.final_dropout, training=self.training)
            # Out-of-place sum so no read-out tensor is mutated in place.
            out = h if out is None else out + h
        return F.log_softmax(out, dim=-1)

def score_model_mask(m, d, mask):
    """Return the accuracy of model ``m`` on the nodes of graph ``d`` selected by ``mask``.

    ``m`` is called as ``m(d.x, d.edge_index)`` and must return one row of class scores
    per node. ``mask`` may be an index tensor/array or a boolean mask over the nodes; the
    accuracy is computed over the selected nodes only. Returns NaN when ``mask`` selects
    no node.
    """
    m.eval()
    with torch.no_grad():
        pred = m(d.x, d.edge_index).argmax(dim=-1)
    selected_pred = pred[mask]
    # Denominator = number of selected nodes; len(mask) counts every node of a boolean mask.
    num_selected = selected_pred.numel()
    if num_selected == 0:
        return float('nan')
    return int((selected_pred == d.y[mask]).sum()) / num_selected


class CustomDataset:
    """Minimal sequence-backed dataset exposing the attributes the training code reads.

    ``num_node_features`` / ``num_features`` and ``num_classes`` are taken from the first
    element, so the wrapped sequence must be non-empty and its elements must expose
    ``num_node_features`` and ``num_classes``.
    """

    def __init__(self, dataset):
        self.dataset = dataset
        data = dataset[0]
        self.num_node_features = data.num_node_features
        self.num_features = data.num_node_features
        self.num_classes = data.num_classes

    def __len__(self):
        r"""The number of examples in the dataset."""
        return len(self.dataset)

    def __getitem__(
            self,
            idx,
    ):
        """Return one element for a scalar index, else a new ``CustomDataset`` of the selected elements.

        Scalar indices are Python/NumPy integers and 0-dimensional tensors or arrays; any
        other ``idx`` is iterated as a collection of indices.
        """
        if (isinstance(idx, (int, np.integer))
                or (isinstance(idx, Tensor) and idx.dim() == 0)
                # np.isscalar() is always False for ndarrays, so 0-d arrays are detected via ndim.
                or (isinstance(idx, np.ndarray) and idx.ndim == 0)):
            return self.dataset[idx]

        return CustomDataset([self.dataset[i] for i in idx])


class FlattenY:
    """Transform that casts node features to float32 and flattens labels to a 1-D int64 tensor.

    ``k`` is stored only so that ``__repr__`` can report it; it does not affect the transform.
    """

    def __init__(self, k=1):
        self.k = k

    def __call__(self, data):
        features = data.x
        data.x = features.to(torch.float32)
        labels = data.y
        data.y = labels.flatten().to(torch.long)
        return data

    def __repr__(self):
        return f'{type(self).__name__}(k={self.k})'

class TransformGin:
    """Transform that adds GIN-style aliases to a data object.

    Sets ``node_features`` to ``x`` and ``label`` to ``y``, and stores
    ``to_dense_adj(edge_index, batch=batch)`` as ``edge_mat``. ``k`` is only reported by
    ``__repr__``.

    NOTE(review): ``GraphCNN``'s legacy sparse helper treats ``edge_mat`` as a 2 x E index
    tensor, whereas this transform stores a dense adjacency — confirm which is intended.
    """

    def __init__(self, k=1):
        self.k = k

    def __call__(self, data):
        data.node_features, data.label = data.x, data.y
        data.edge_mat = to_dense_adj(data.edge_index, batch=data.batch)
        return data

    def __repr__(self):
        return f'{type(self).__name__}(k={self.k})'

class AddFeatures:
    """Transform that gives every node a single constant feature equal to 1.0.

    Overwrites ``data.x`` with a float32 tensor of shape ``(data.num_nodes, 1)``.
    ``k`` is only reported by ``__repr__``.
    """

    def __init__(self, k=1):
        self.k = k

    def __call__(self, data):
        # torch.ones keeps the (num_nodes, 1) shape even for a graph with zero nodes,
        # where building the tensor from an empty nested list would give shape (0,).
        data.x = torch.ones((data.num_nodes, 1), dtype=torch.float)
        return data

    def __repr__(self):
        return '{}(k={})'.format(self.__class__.__name__, self.k)


def train_eval_model_cv(dataset, dataset_name, params, val_size=0.1, use_pooling=True, debug=False, folds=10,
                         dataset_mask=False):
    """Evaluate ``train_model`` with stratified k-fold cross-validation and print a summary.

    Stratification labels:
      * ``use_pooling``: the first entry of each graph's ``y``;
      * else if ``dataset_mask``: the node labels of ``dataset[0]`` (folds are node index arrays);
      * otherwise: a constant label, so stratification has no effect.

    Each fold's result list from ``train_model`` is collected; the element-wise mean and
    standard deviation across folds are printed. Nothing is returned.
    """
    print(f'----- {dataset_name} -----')

    if use_pooling:
        labels = [graph.y[0] for graph in dataset]
    elif dataset_mask:
        labels = dataset[0].y.detach().numpy()
    else:
        labels = [0 for _ in dataset]

    splitter = StratifiedKFold(n_splits=folds, shuffle=True, random_state=42)
    fold_indices = [split for split in splitter.split(np.zeros(len(labels)), labels)]

    gnn_accs = []
    for fold_no, (train_idx, test_idx) in enumerate(fold_indices, start=1):
        print(f'----- Fold {fold_no}/{folds} -----')
        if dataset_mask:
            # Node-level split: hand the index arrays to train_model unchanged.
            train_part, test_part = train_idx, test_idx
        else:
            train_part, test_part = dataset[train_idx], dataset[test_idx]
        fold_accs = train_model(
            dataset,
            dataset_name,
            params,
            val_size=val_size,
            train_val_dataset=train_part,
            test_dataset=test_part,
            use_pooling=use_pooling,
            debug=debug,
            dataset_mask=dataset_mask)
        gnn_accs.append(fold_accs)

    print('---------- CV Results ----------')
    print('---------- GNN ----------')
    print("Acc.:", np.mean(gnn_accs, axis=0), "Std.:", np.std(gnn_accs, axis=0))


def train_model(dataset, dataset_name, params, test_size=0.1, val_size=0.1, train_val_dataset=None, test_dataset=None,
                use_pooling=True,
                debug=False, gumbel_noise=True, dataset_mask=False):
    """Train a ``GraphCNN`` on one train/validation/test split and evaluate it.

    Data modes:
      * ``dataset_mask=False``: ``train_val_dataset`` / ``test_dataset`` are datasets of
        graphs, iterated with ``DataLoader``s of ``params["batch_size"]``.
      * ``dataset_mask=True``: ``train_val_dataset`` / ``test_dataset`` are node index arrays
        into the single graph ``dataset[0]``.
    If ``test_dataset`` is None, a ``test_size`` fraction of ``dataset`` is held out first.
    A ``val_size`` fraction of the train part is always held out for validation. For
    ``dataset_name == "OGBA"`` the dataset's own index split replaces the computed masks.

    On every epoch that improves the validation loss, the whole module is pickled to
    ``datasets/data/model_checkpoints/<dataset_name>.pt``. Those best weights are restored
    from memory before the final evaluation; if no epoch improved, the final weights are used.

    Args:
        use_pooling: whether targets are per graph (True) or per node (False); controls the
            class-weight label collection and the accuracy normalisation.
        debug, gumbel_noise: accepted for compatibility; not used.

    Returns:
        list: ``[train_acc, val_acc, test_acc]`` of the evaluated model.
    """
    import copy

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    data = dataset[0]
    data = data.to(device)
    if test_dataset is None:
        # No outer split supplied: hold out test_size of the dataset.
        sss = ShuffleSplit(n_splits=1, test_size=test_size, random_state=41)
        X = [g.x for g in dataset]
        y = [g.y for g in dataset]
        sss.get_n_splits(X, y)
        train_index, test_index = next(sss.split(X, y))
        train_val_dataset = dataset[train_index]
        test_dataset = dataset[test_index]

    if dataset_mask:
        # Node-level split of dataset[0]: the "datasets" are node index arrays.
        sss = ShuffleSplit(n_splits=1, test_size=val_size, random_state=41)
        X = [data.x[index] for index in train_val_dataset]
        y = [data.y[index] for index in train_val_dataset]
        sss.get_n_splits(X, y)
        train_index, val_index = next(sss.split(X, y))
        train_mask = train_val_dataset[train_index]
        val_mask = train_val_dataset[val_index]
        test_mask = test_dataset
        print(len(train_mask), len(val_mask), len(test_mask))

    else:
        sss = ShuffleSplit(n_splits=1, test_size=val_size, random_state=41)
        X = [g.x for g in train_val_dataset]
        y = [g.y for g in train_val_dataset]
        sss.get_n_splits(X, y)
        train_index, val_index = next(sss.split(X, y))
        train_dataset = train_val_dataset[train_index]
        val_dataset = train_val_dataset[val_index]

        print("Train/Val/Test Size:", len(train_dataset), len(val_dataset), len(test_dataset))
        train_loader = DataLoader(train_dataset, batch_size=params["batch_size"])
        val_loader = DataLoader(val_dataset, batch_size=params["batch_size"])
        test_loader = DataLoader(test_dataset, batch_size=params["batch_size"])

    if dataset_name == "OGBA":
        split_idx = dataset.get_idx_split()
        train_mask, val_mask, test_mask = split_idx["train"], split_idx["valid"], split_idx["test"]
        print(len(train_mask), len(val_mask), len(test_mask))

    # Balanced class weights for the weighted NLL loss of train()/test().
    # NOTE(review): computed over the whole dataset, test split included.
    if use_pooling:
        y = [graph.y.cpu().detach().numpy()[0] for graph in dataset]
    else:
        y = []
        for graph in dataset:
            y += list(graph.y.cpu().detach().numpy())
    classes = np.unique(y)
    class_weights = np.ones(dataset.num_classes)
    computed_class_weights = compute_class_weight(class_weight='balanced', classes=classes, y=y)
    if len(computed_class_weights) < dataset.num_classes:
        # Classes absent from y keep weight 1.
        for i in range(len(classes)):
            class_weights[classes[i]] = computed_class_weights[i]
    else:
        class_weights = computed_class_weights
    class_weights = torch.tensor(class_weights, dtype=torch.float).to(device)

    model = GraphCNN(params["num_layers"], 2, dataset.num_features, params["hidden_units"], dataset.num_classes,
                     params["dropout"], learn_eps=False, graph_pooling_type='sum', neighbor_pooling_type='sum',
                     device=device).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=params["learning_rate"])

    def train_with_mask(mask):
        """One full-graph optimisation step on the nodes in ``mask``; returns the loss."""
        model.train()
        optimizer.zero_grad()
        out = model(data)
        loss = F.nll_loss(out[mask], data.y[mask])
        loss.backward()
        optimizer.step()
        return float(loss)

    @torch.no_grad()
    def test_with_mask(mask):
        """Return ``(accuracy, loss_tensor)`` on the nodes in ``mask``."""
        model.eval()
        out = model(data)
        pred = out.argmax(dim=-1)
        loss = F.nll_loss(out[mask], data.y[mask])
        target = data.y[mask]
        # Divide by the number of selected nodes (correct for index and boolean masks).
        acc = int((pred[mask] == target).sum()) / target.numel()
        return acc, loss

    def train(loader):
        """One epoch over ``loader``; returns the per-graph-weighted mean loss."""
        model.train()
        total_loss = 0
        for data in loader:
            data = data.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = F.nll_loss(output, data.y, weight=class_weights)
            loss.backward()
            optimizer.step()
            total_loss += float(loss) * data.num_graphs
        return total_loss / len(loader.dataset)

    @torch.no_grad()
    def test(loader):
        """Return ``(accuracy, loss)`` averaged over the graphs in ``loader``."""
        model.eval()
        total_correct = 0
        total_loss = 0
        for data in loader:
            data = data.to(device)
            output = model(data)
            loss = F.nll_loss(output, data.y, weight=class_weights)
            correct = int((output.argmax(-1) == data.y).sum())
            if not use_pooling:
                # Node-level targets: batch node accuracy weighted by the batch's number of
                # graphs, so dividing by len(loader.dataset) stays a proper average for any
                # batch size (identical to plain per-graph accuracy when batch_size == 1).
                correct = correct / len(data.y) * data.num_graphs
            total_correct += correct
            total_loss += float(loss) * data.num_graphs
        return total_correct / len(loader.dataset), total_loss / len(loader.dataset)

    @torch.no_grad()
    def score_model(m, d):
        """Accuracy of ``m`` over loader ``d``, or the OGB metric for OGB dataset names."""
        m.eval()
        total_correct = 0
        all_out = []
        all_true = []
        for batch_data in d:
            batch_data = batch_data.to(device)
            # GraphCNN.forward takes the whole Data/Batch object.
            out = m(batch_data)
            correct = int((out.argmax(-1) == batch_data.y).sum())
            if not use_pooling:
                correct = correct / len(batch_data.y) * batch_data.num_graphs
            total_correct += correct
            all_out.append(out)
            # OGB evaluators require 2-D (num_samples, num_tasks) inputs.
            all_true.append(batch_data.y.view(-1, 1))
        acc = total_correct / len(d.dataset)
        if dataset_name in ("OGB-molhiv", "OGB-code2", "OGB-ppa"):
            out_all = torch.cat(all_out, dim=0)
            y_true = torch.cat(all_true, dim=0)
            y_label = out_all.argmax(-1).view(-1, 1)
        if dataset_name == "OGB-molhiv":
            # ROC-AUC is computed from scores, not hard labels: use the class-1 probability.
            evaluator = Evaluator(name="ogbg-molhiv")
            input_dict = {"y_true": y_true, "y_pred": out_all[:, 1:2].exp()}
            acc = evaluator.eval(input_dict)["rocauc"]
        if dataset_name == "OGB-code2":
            # NOTE(review): the ogbg-code2 evaluator appears to expect sequence inputs and to
            # report F1, not "rocauc"; this branch likely fails — confirm before relying on it.
            evaluator = Evaluator(name="ogbg-code2")
            input_dict = {"y_true": y_true, "y_pred": y_label}
            acc = evaluator.eval(input_dict)["rocauc"]
        if dataset_name == "OGB-ppa":
            evaluator = Evaluator(name="ogbg-ppa")
            input_dict = {"y_true": y_true, "y_pred": y_label}
            acc = evaluator.eval(input_dict)["acc"]
        return acc

    early_stopping_enabled = True
    es_patience = params["es_patience"]
    es_counter = 0
    best_test_acc = 0.0
    best_val_loss = np.inf
    best_val_acc = 0
    best_state = None
    pbar = trange(1, params["epochs"] + 1)
    model_save_name = dataset_name
    checkpoint_dir = 'datasets/data/model_checkpoints'
    # torch.save does not create missing directories.
    os.makedirs(checkpoint_dir, exist_ok=True)
    checkpoint_path = f'{checkpoint_dir}/{model_save_name}.pt'

    for epoch in pbar:
        if dataset_mask:
            loss = train_with_mask(train_mask)
            train_acc, _ = test_with_mask(train_mask)
            val_acc, val_loss = test_with_mask(val_mask)
            test_acc, _ = test_with_mask(test_mask)
        else:
            loss = train(train_loader)
            train_acc, _ = test(train_loader)
            val_acc, val_loss = test(val_loader)
            test_acc, _ = test(test_loader)
        if val_loss < best_val_loss:
            es_counter = 0
            best_val_loss = val_loss
            best_val_acc = val_acc
            best_test_acc = test_acc
            best_state = copy.deepcopy(model.state_dict())
            torch.save(model, checkpoint_path)

        pbar.set_description(f'Epoch: {epoch:04d}, Loss: {loss:.3f} Train: {train_acc:.3f},'
                             f' Val: {val_acc:.3f}, Test: {test_acc:.3f},'
                             f' Best Val|Test Acc.:  {best_val_acc:.3f} | {best_test_acc:.3f}')

        if early_stopping_enabled and es_counter > es_patience:
            print("-- Early Stopping --")
            break
        es_counter += 1

    print('---------- GNN (Acc, Loss) ----------')
    # Restore the best weights from memory instead of torch.load-ing the pickled module:
    # torch>=2.6 defaults to weights_only=True, which rejects whole-module pickles, and a
    # reload could pick up a stale file from an earlier fold/run if no epoch improved here.
    if best_state is not None:
        model.load_state_dict(best_state)

    if dataset_mask:
        # Evaluate each split once and reuse the result for the return value and the printout.
        results = [test_with_mask(split) for split in (train_mask, val_mask, test_mask)]
        gnn_accs = [acc for acc, _ in results]
        print(f'Train: {results[0]}, Val: {results[1]}, Test: {results[2]}')
    else:
        loaders = (train_loader, val_loader, test_loader)
        results = [test(loader) for loader in loaders]
        gnn_accs = [acc for acc, _ in results]
        print(f'Train: {results[0]}, Val: {results[1]}, Test: {results[2]}')

        gnn_accs2 = [score_model(model, loader) for loader in loaders]
        print(f'Train: {gnn_accs2[0]}, Val: {gnn_accs2[1]}, Test: {gnn_accs2[2]}')

    return gnn_accs

# Experiment hyper-parameters. Entries read via os.getenv can be overridden by the
# environment variable of the same upper-case name (e.g. EPOCHS=200); the rest are fixed.
params = dict(
    learning_rate=0.01,
    epochs=int(os.getenv("EPOCHS", 1500)),
    folds=10,
    es_patience=int(os.getenv("ES_PATIENCE", 100)),
    batch_size=int(os.getenv("BATCH_SIZE", 1)),
    num_layers=int(os.getenv("NUM_LAYERS", 5)),
    hidden_units=int(os.getenv("HIDDEN_UNITS", 32)),
    skip_connection=os.getenv("SKIP_CONNECTION", "false") == "true",
    data_dir="datasets/data/",
    val_size=float(os.getenv("VAL_SIZE", 0.1)),
    dropout=float(os.getenv("DROPOUT", 0.5)),
)

parser = argparse.ArgumentParser(
    description='Train and evaluate GIN',
    formatter_class=argparse.RawDescriptionHelpFormatter
)
parser.add_argument("--dataset", default="MUTAG", help="Name of the dataset to run the experiment on")

args = parser.parse_args()
dataset_name = args.dataset

from ogb.nodeproppred import PygNodePropPredDataset
import torch_geometric.transforms as T

# Task defaults: graph-level targets evaluated through DataLoaders. Node-classification
# datasets switch use_pooling off and, for single-graph datasets, enable node-index masks.
use_pooling = True
dataset_mask = False
dataset = None

if dataset_name == 'BA_2Motifs':
    dataset = SynGraphDataset(params["data_dir"] + '/datasets', dataset_name)
    dataset.data.x = dataset.data.x.to(torch.float32)
    dataset.data.x = dataset.data.x[:, :1]
    use_pooling = True

elif dataset_name == 'MUTAG':
    dataset = MoleculeDataset(params["data_dir"] + '/datasets', dataset_name)
    use_pooling = True

elif dataset_name == 'PROTEINS':
    dataset = TUDataset(root=params["data_dir"] + '/datasets', name=dataset_name)
    use_pooling = True

elif dataset_name == 'IMDB-BINARY':
    dataset = TUDataset(root=params["data_dir"] + '/datasets', name=dataset_name, pre_transform=AddFeatures())
    use_pooling = True

elif dataset_name == 'REDDIT-BINARY':
    dataset = TUDataset(root=params["data_dir"] + '/datasets', name=dataset_name, pre_transform=AddFeatures())
    use_pooling = True

elif dataset_name == 'Mutagenicity':
    dataset = TUDataset(root=params["data_dir"] + '/datasets', name=dataset_name)
    use_pooling = True

elif dataset_name == 'BBBP':
    dataset = MoleculeDataset(params["data_dir"] + '/datasets', dataset_name, pre_transform=FlattenY())
    use_pooling = True

elif dataset_name == 'COLLAB':
    dataset = TUDataset(params["data_dir"] + '/datasets', name=dataset_name, pre_transform=AddFeatures())
    use_pooling = True

elif dataset_name == 'Infection':
    benchmark = Infection(num_layers=params["num_layers"])
    dataset = CustomDataset([benchmark.create_dataset(num_nodes=1000, edge_probability=0.004) for _ in range(10)])
    use_pooling = False

elif dataset_name == 'Saturation':
    benchmark = Saturation(sample_count=1, num_layers=params["num_layers"], concat_features=False, conv_type=None)
    dataset = CustomDataset([benchmark.create_dataset() for _ in range(10)])
    use_pooling = False

elif dataset_name == 'BA_shapes':
    dataset = SynGraphDataset(params["data_dir"] + '/datasets', name=dataset_name)
    dataset.data.x = dataset.data.x.to(torch.float32)
    dataset.data.x = dataset.data.x[:, :1]
    use_pooling = False
    dataset_mask = True

elif dataset_name == 'Tree_Cycle':
    dataset = SynGraphDataset(params["data_dir"] + '/datasets', name=dataset_name)
    dataset.data.x = dataset.data.x.to(torch.float32)
    dataset.data.x = dataset.data.x[:, :1]
    use_pooling = False
    dataset_mask = True

elif dataset_name == 'Tree_Grid':
    dataset = SynGraphDataset(params["data_dir"] + '/datasets', name=dataset_name)
    dataset.data.x = dataset.data.x.to(torch.float32)
    dataset.data.x = dataset.data.x[:, :1]
    use_pooling = False
    dataset_mask = True

elif dataset_name in ['Cora', 'CiteSeer', 'PubMed']:
    dataset = Planetoid(params["data_dir"] + '/datasets', name=dataset_name)
    dataset.data.x = dataset.data.x.to(torch.float32)
    use_pooling = False
    dataset_mask = True

elif dataset_name == "OGBA":
    dataset = PygNodePropPredDataset("ogbn-arxiv", root="datasets/", transform=T.ToUndirected())
    dataset.data.x = dataset.data.x.to(torch.float32)
    use_pooling = False
    dataset_mask = True
    dataset.data.y = dataset.data.y.squeeze(1)

elif dataset_name == "OGB-molhiv":
    dataset = PygGraphPropPredDataset("ogbg-molhiv", root="datasets/")
    dataset.data.x = dataset.data.x.to(torch.float32)
    use_pooling = True
    dataset_mask = False
    dataset.data.y = dataset.data.y.squeeze(1)

if dataset is None:
    # Fail early with a usage message for unsupported dataset names.
    parser.error(f"unknown dataset '{dataset_name}'")

train_eval_model_cv(dataset, dataset_name, params, val_size=params["val_size"], use_pooling=use_pooling,
                    folds=params["folds"], dataset_mask=dataset_mask)

