#!/usr/bin/python3
import numpy as np
import random
import torch

import hydra
from omegaconf import OmegaConf
import wandb

# Set Backends
# Select deterministic cuDNN algorithms, disable cuDNN autotuning and turn off
# TF32 for matmul/conv so that repeated runs produce matching results.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
# Set Seed
# Seed every RNG this script touches at import time. Model construction later
# draws from these streams, so parameter initialisation depends on the order in
# which modules are created.
torch.manual_seed(0)
torch.cuda.manual_seed_all(0)
np.random.seed(0)
random.seed(0)

import os, sys, time

from torch.nn import LeakyReLU, Linear, Module, ModuleList, Parameter, ReLU
import torch.nn.functional as F
import torch.optim as optim
from torch_geometric.datasets import WebKB, Planetoid
import torch_geometric.transforms as T

from torch_geometric.nn import GCNConv, GCN2Conv
from sct_gnn import SCT, SCT_Resid, SCT_Scalar
from sct_gnn.smoothness import NodeFeatureSmoothness
from sct_gnn.kernel_vectors import kernel_vectors, ker_lapl
sys.path.append('/root/workspace/sct-gnn/baselines/')
from egnn import EGNNConv

import matplotlib.pyplot as plt
# Global style for every figure this script saves: ggplot theme, 16x9 figures,
# 45pt text, and black ticks/axes.
plt.style.use('ggplot')
plt.rcParams["figure.figsize"] = (16,9)
plt.rcParams["font.size"] = 45
# # plt.rcParams["font.weight"] = 'bold'
plt.rcParams["xtick.color"] = 'black'
plt.rcParams["ytick.color"] = 'black'
plt.rcParams["axes.edgecolor"] = 'black'
plt.rcParams["axes.linewidth"] = 1


"""
Pyg Planetoid Public Splits
    Datasets: 
        Planetoid: Cora, Citeseer, PubMed
"""

import torch.nn as nn
class SReLU(nn.Module):
    """Shifted ReLU with a learnable per-channel threshold.

    Computes ``relu(x - b) + b`` (i.e. the element-wise ``max(x, b)``), where
    ``b`` is a trainable vector of length ``nc`` broadcast along the last
    dimension of ``x``.

    Args:
        nc: number of channels, i.e. the length of the threshold vector.
        bias: initial value written into every entry of the threshold.
    """

    def __init__(self, nc, bias):
        super().__init__()
        threshold = torch.empty(nc)
        self.srelu_bias = nn.Parameter(threshold)
        # In-place is safe: it only overwrites the temporary ``x - b``.
        self.srelu_relu = nn.ReLU(inplace=True)
        nn.init.constant_(self.srelu_bias, bias)

    def forward(self, x):
        shifted = x - self.srelu_bias
        return self.srelu_relu(shifted) + self.srelu_bias

#----------------------------------------------------------------------------------------------------------------------------------------------------
# Models
#----------------------------------------------------------------------------------------------------------------------------------------------------

class fixed_architecture(Module):
    """Encoder -> stack of graph layers -> decoder backbone shared by the models.

    Subclasses replace ``self.gcn_list`` with their graph convolutions and then
    call :meth:`reset_parameters` so the optimizer parameter groups match.

    Args:
        in_channels: input feature dimension.
        hidden_channels: width of the encoder output and of every hidden layer.
        hidden_layers: number of graph layers applied in :meth:`forward`.
        out_channels: output dimension of the decoder.
        dropout: dropout probability applied before every linear/graph layer.
        act_fn: one of ``'relu'``, ``'srelu'`` or ``'leaky'``.
        config: model config; unused here, consumed by subclasses.

    Raises:
        ValueError: if ``act_fn`` is not one of the supported names.
    """

    def __init__(self, in_channels, hidden_channels, hidden_layers, out_channels, dropout, act_fn, config):
        super().__init__()
        self.fc_enc = Linear(in_channels, hidden_channels)
        self.gcn_list = ModuleList([])
        self.fc_dec = Linear(hidden_channels, out_channels)
        if act_fn == 'relu':
            self.act = ReLU()
        elif act_fn == 'srelu':
            self.act = SReLU(hidden_channels, -10)
        elif act_fn == 'leaky':
            self.act = LeakyReLU()
        else:
            # Fail fast: previously ``self.act`` was simply left unset and the
            # error surfaced later as an AttributeError inside forward().
            raise ValueError(f"unknown act_fn {act_fn!r}; expected 'relu', 'srelu' or 'leaky'")
        self.dropout = dropout
        self.layers = hidden_layers
        # Extra loss term added by train(); subclasses may overwrite it in forward().
        self.reg = 0
        # One list per training step, filled by x_grad() from backward hooks.
        self.grad_listener = []
        # One list per training step of per-layer activation norms (filled by subclasses).
        self.x_norm = []

    def reset_parameters(self):
        """Collect the optimizer parameter groups (weights are NOT re-initialised).

        ``graph_params`` holds the graph-layer parameters and ``fc_params`` the
        encoder/decoder parameters; run_training() gives each group its own
        weight decay.
        """
        self.graph_params = list(self.gcn_list.parameters())
        self.fc_params = list(self.fc_enc.parameters()) + list(self.fc_dec.parameters())

    def forward(self, x, edge_index, edge_weight):
        """Encode, apply ``self.layers`` graph layers, decode; returns unnormalised scores.

        Every activated intermediate (encoder output first) is kept in
        ``self._layers`` for later inspection.
        """
        _layers = []
        x = F.dropout(x, self.dropout, training=self.training)
        x = self.act(self.fc_enc(x))
        _layers.append(x)
        for i in range(self.layers):
            x = F.dropout(x, self.dropout, training=self.training)
            ax = self.gcn_list[i](x=x, edge_index=edge_index, edge_weight=edge_weight)
            x = self.act(ax)
            _layers.append(x)
        x = F.dropout(x, self.dropout, training=self.training)
        x = self.fc_dec(x)
        self._layers = _layers
        return x

    def x_grad(self, grads):
        """Backward hook: record the L2 norm of ``grads`` for the current training step.

        Returns None, so the gradient itself is left unchanged.
        """
        self.grad_listener[-1].append(torch.norm(grads).item())

class gcn(fixed_architecture):
    """Plain GCN: encoder, ``hidden_layers`` bias-free GCNConv layers, decoder.

    ``data`` is accepted only so that all models share one constructor
    signature. A backward hook on each conv weight records its gradient norm
    via the inherited ``x_grad``.
    """
    def __init__(self, in_channels, hidden_channels, hidden_layers, out_channels, dropout, act_fn, config, data):
        super().__init__(in_channels, hidden_channels, hidden_layers, out_channels, dropout, act_fn, config)
        # normalize=False: load() already attaches GCN-normalised edge weights (T.GCNNorm()).
        self.gcn_list = ModuleList([GCNConv(hidden_channels, hidden_channels, normalize=False, cached=True, bias=False) for _ in range(hidden_layers)])
        for mod in self.gcn_list:
            mod.lin.weight.register_hook(self.x_grad)
        self.reset_parameters()

    def forward(self, x, edge_index, edge_weight):
        """Encode, apply the GCN layers, decode; returns unnormalised scores.

        In training mode a fresh slot is opened in ``grad_listener`` (filled by
        the backward hooks) and in ``x_norm``, which receives the norm of every
        layer's activation, matching the other models in this file.
        """
        _layers = []

        x = F.dropout(x, self.dropout, training=self.training)
        x = self.act(self.fc_enc(x))
        _layers.append(x)

        if self.training:
            self.grad_listener.append([])
            self.x_norm.append([])

        for i in range(self.layers):
            x = F.dropout(x, self.dropout, training=self.training)
            ax = self.gcn_list[i](x, edge_index, edge_weight=edge_weight)
            x = self.act(ax)
            if self.training:
                # Previously the slot opened above was never filled for this model.
                self.x_norm[-1].append(torch.norm(x.detach()).item())
            _layers.append(x)

        x = F.dropout(x, self.dropout, training=self.training)
        x = self.fc_dec(x)
        self._layers = _layers
        self.reg = 0
        return x

class gcnii(fixed_architecture):
    """GCNII: encoder, ``hidden_layers`` GCN2Conv layers, decoder.

    Every GCN2Conv receives the encoder output as its initial residual
    (``x_0``). Gradient norms of each layer's ``weight1`` are recorded through
    the inherited ``x_grad`` backward hook.
    """
    def __init__(self, in_channels, hidden_channels, hidden_layers, out_channels, dropout, act_fn, config, data):
        super().__init__(in_channels, hidden_channels, hidden_layers, out_channels, dropout, act_fn, config)
        # NOTE(review): the result of this call is never used by this model. It is
        # kept so the construction sequence is unchanged — confirm whether it can go.
        kernel_vectors(data.edge_index, edge_weight=data.edge_weight, largest=True)
        convs = []
        for depth in range(1, hidden_layers + 1):
            convs.append(GCN2Conv(hidden_channels, alpha=config['alpha'], theta=config['theta'], layer=depth, normalize=False, cached=True))
        self.gcn_list = ModuleList(convs)
        for conv in self.gcn_list:
            conv.weight1.register_hook(self.x_grad)
        self.reset_parameters()

    def forward(self, x, edge_index, edge_weight):
        """Encode, run the GCNII stack with initial residual, decode; returns raw scores."""
        training = self.training
        h = F.dropout(x, self.dropout, training=training)
        h0 = self.act(self.fc_enc(h))
        history = [h0]
        if training:
            self.grad_listener.append([])
            self.x_norm.append([])

        h = h0
        for depth in range(self.layers):
            h = F.dropout(h, self.dropout, training=training)
            h = self.act(self.gcn_list[depth](x=h, x_0=h0, edge_index=edge_index, edge_weight=edge_weight))
            if training:
                self.x_norm[-1].append(torch.norm(h.detach()).item())
            history.append(h)

        h = F.dropout(h, self.dropout, training=training)
        logits = self.fc_dec(h)
        self._layers = history
        self.reg = 0
        return logits

class egnn(fixed_architecture):
    """EGNN: encoder, ``hidden_layers`` EGNNConv layers, decoder, plus a weight regulariser.

    In training mode forward() sets ``self.reg`` to ``loss_weight`` times the
    sum of Frobenius distances ``||W_k - T_k||``. The target ``T_k`` is
    ``sqrt(c_max) * I`` for the first conv and ``I`` for the others.

    Config keys: ``c_max``, ``beta``, ``theta``; optional ``loss_weight``
    (default 20, the value previously hard-coded).
    """
    def __init__(self, in_channels, hidden_channels, hidden_layers, out_channels, dropout, act_fn, config, data):
        super().__init__(in_channels, hidden_channels, hidden_layers, out_channels, dropout, act_fn, config)
        self.gcn_list = ModuleList([EGNNConv(hidden_channels, hidden_channels, c_max=config['c_max'], normalize=False, cached=True) for _ in range(hidden_layers)])
        self.beta = config['beta']
        self.theta = config['theta']

        # Scale of the weight regulariser in the training loss.
        self.loss_weight = config.get('loss_weight', 20)
        # Fixed regularisation targets. They are frozen Parameters, so they follow
        # model.to(device) and appear in the state dict.
        self.weight_standard = Parameter(torch.eye(hidden_channels), requires_grad=False)
        self.weight_first_layer = Parameter(torch.eye(hidden_channels) * np.sqrt(config['c_max']), requires_grad=False)
        for module in self.gcn_list:
            module.weight.register_hook(self.x_grad)

        self.reset_parameters()

    def _orthogonal_penalty(self):
        """Sum of ``||W_k - T_k||_F`` over all conv layers (0.0 when there are none)."""
        penalty = 0.
        for k, conv in enumerate(self.gcn_list):
            target = self.weight_first_layer if k == 0 else self.weight_standard
            penalty = penalty + torch.norm(conv.weight - target)
        return penalty

    def forward(self, x, edge_index, edge_weight):
        """Encode, run the EGNN stack with initial residual, decode; returns raw scores."""
        _layers = []
        x = F.dropout(x, self.dropout, training=self.training)
        x = self.act(self.fc_enc(x))
        _layers.append(x)
        if self.training:
            self.grad_listener.append([])
            self.x_norm.append([])

        for i in range(self.layers):
            x = F.dropout(x, self.dropout, training=self.training)
            ax = self.gcn_list[i](x=x, x_0=_layers[0], edge_index=edge_index, edge_weight=edge_weight, beta=self.beta, residual_weight=self.theta)
            x = self.act(ax)
            if self.training:
                self.x_norm[-1].append(torch.norm(x.detach()).item())
            _layers.append(x)

        x = F.dropout(x, self.dropout, training=self.training)
        x = self.fc_dec(x)
        self._layers = _layers

        if self.training:
            # Indexing gcn_list[0] directly used to raise IndexError for hidden_layers == 0.
            self.reg = self.loss_weight * self._orthogonal_penalty()

        return x

#----------------------------------------------------------------------------------------------------------------------------------------------------

class fixed_bias_architecture(fixed_architecture):
    """Backbone variant whose graph layers are paired with modules in ``bias_list``.

    Parameters of the bias modules are optimised in the ``fc_params`` group,
    together with the encoder and decoder.
    """
    def __init__(self, in_channels, hidden_channels, hidden_layers, out_channels, dropout, act_fn, config):
        super().__init__(in_channels, hidden_channels, hidden_layers, out_channels, dropout, act_fn, config)
        # Populated by subclasses.
        self.bias_list = ModuleList()

    def reset_parameters(self):
        """Collect the optimizer parameter groups (weights are NOT re-initialised)."""
        self.graph_params = [*self.gcn_list.parameters()]
        dense_modules = (self.fc_enc, self.fc_dec, self.bias_list)
        self.fc_params = [p for module in dense_modules for p in module.parameters()]

class gcns(fixed_bias_architecture):
    """GCN with a per-layer SCT bias; the last graph layer produces the class scores.

    Layers ``0 .. hidden_layers-2`` map ``hidden_channels -> hidden_channels``
    and the last maps ``hidden_channels -> out_channels``; ``fc_dec`` is not
    used. Each layer adds an SCT term computed from the conv output and the
    encoder output. The activation is applied to hidden layers only, so
    forward() returns raw logits like the other models.

    Raises:
        ValueError: if ``hidden_layers < 1`` (an output graph layer is required).
    """
    def __init__(self, in_channels, hidden_channels, hidden_layers, out_channels, dropout, act_fn, config, data):
        if hidden_layers < 1:
            # With 0 layers the output would be the hidden_channels-wide encoding.
            raise ValueError(f'gcns needs hidden_layers >= 1, got {hidden_layers}')
        super().__init__(in_channels, hidden_channels, hidden_layers, out_channels, dropout, act_fn, config)
        indicators, _deg, ker_vecs = kernel_vectors(data.edge_index, edge_weight=data.edge_weight, return_all=True)
        gcn_list = [GCNConv(hidden_channels, hidden_channels, normalize=False, cached=True, bias=False) for _ in range(hidden_layers - 1)]
        gcn_list.append(GCNConv(hidden_channels, out_channels, normalize=False, cached=True, bias=False))
        self.gcn_list = ModuleList(gcn_list)
        bias_list = [SCT(ker_vecs.shape[0], hidden_channels, ker_vecs=ker_vecs, indicators=indicators, cached=True) for _ in range(hidden_layers - 1)]
        bias_list.append(SCT(ker_vecs.shape[0], out_channels, ker_vecs=ker_vecs, indicators=indicators, cached=True))
        self.bias_list = ModuleList(bias_list)
        for mod in self.gcn_list:
            mod.lin.weight.register_hook(self.x_grad)
        self.reset_parameters()

    def forward(self, x, edge_index, edge_weight):
        """Encode, then apply the biased GCN layers; the last layer's output is returned."""
        _layers = []

        x = F.dropout(x, self.dropout, training=self.training)
        x = self.act(self.fc_enc(x))
        _layers.append(x)
        if self.training:
            self.grad_listener.append([])
            self.x_norm.append([])

        last = self.layers - 1
        for i in range(self.layers):
            x = F.dropout(x, self.dropout, training=self.training)
            ax = self.gcn_list[i](x, edge_index, edge_weight=edge_weight)
            b = self.bias_list[i](x=ax, x0=_layers[0], edge_index=edge_index, edge_weight=edge_weight)
            x = ax + b
            if i < last:
                # Output layer stays linear. Activating it clipped the logits and, for
                # 'srelu', broadcast a hidden_channels-sized threshold over out_channels.
                x = self.act(x)
            if self.training:
                self.x_norm[-1].append(torch.norm(x.detach()).item())
            _layers.append(x)

        self._layers = _layers
        self.reg = 0
        return x

class gcniis(fixed_bias_architecture):
    """GCNII with an added SCT bias term in every layer.

    Layer ``k`` outputs ``act(GCN2Conv_k(h, h0) + SCT_k(h, h0))``, where ``h0``
    is the encoder output. A decoder follows the stack.
    """
    def __init__(self, in_channels, hidden_channels, hidden_layers, out_channels, dropout, act_fn, config, data):
        super().__init__(in_channels, hidden_channels, hidden_layers, out_channels, dropout, act_fn, config)
        ker_vecs = kernel_vectors(data.edge_index, edge_weight=data.edge_weight, largest=False)
        self.gcn_list = ModuleList([
            GCN2Conv(hidden_channels, alpha=config['alpha'], theta=config['theta'], layer=depth, normalize=False, cached=True)
            for depth in range(1, hidden_layers + 1)
        ])
        self.bias_list = ModuleList([
            SCT(ker_vecs.shape[0], hidden_channels, ker_vecs=ker_vecs, indicators=None, cached=True)
            for _ in range(hidden_layers)
        ])
        for conv in self.gcn_list:
            conv.weight1.register_hook(self.x_grad)
        self.reset_parameters()

    def forward(self, x, edge_index, edge_weight):
        """Encode, run the biased GCNII stack, decode; returns raw scores."""
        training = self.training
        h = F.dropout(x, self.dropout, training=training)
        h0 = self.act(self.fc_enc(h))
        history = [h0]
        if training:
            self.grad_listener.append([])
            self.x_norm.append([])

        h = h0
        for k in range(self.layers):
            h = F.dropout(h, self.dropout, training=training)
            conv_out = self.gcn_list[k](x=h, x_0=h0, edge_index=edge_index, edge_weight=edge_weight)
            shift = self.bias_list[k](x=h, x0=h0, edge_index=edge_index, edge_weight=edge_weight)
            h = self.act(conv_out + shift)
            if training:
                self.x_norm[-1].append(torch.norm(h.detach()).item())
            history.append(h)

        h = F.dropout(h, self.dropout, training=training)
        logits = self.fc_dec(h)
        self._layers = history
        self.reg = 0
        return logits

class egnns(fixed_bias_architecture):
    """EGNN with an added SCT_Resid bias term per layer, plus the EGNN weight regulariser.

    In training mode forward() sets ``self.reg`` to ``loss_weight`` times the
    sum of Frobenius distances ``||W_k - T_k||``. The target ``T_k`` is
    ``sqrt(c_max) * I`` for the first conv and ``I`` for the others.

    Config keys: ``alpha``, ``theta``, ``c_max``, ``beta``; optional
    ``loss_weight`` (default 20, the value previously hard-coded).
    """
    def __init__(self, in_channels, hidden_channels, hidden_layers, out_channels, dropout, act_fn, config, data):
        super().__init__(in_channels, hidden_channels, hidden_layers, out_channels, dropout, act_fn, config)
        indicators, _deg, ker_vecs = kernel_vectors(data.edge_index, edge_weight=data.edge_weight, return_all=True)
        # Built before gcn_list; this order is kept so seeded initialisation is unchanged.
        self.bias_list = ModuleList([SCT_Resid(ker_vecs.shape[0], hidden_channels, alpha=config['alpha'], theta=config['theta'], layer=i + 1, ker_vecs=ker_vecs, indicators=indicators, cached=True) for i in range(hidden_layers)])
        self.gcn_list = ModuleList([EGNNConv(hidden_channels, hidden_channels, c_max=config['c_max'], normalize=False, cached=True) for _ in range(hidden_layers)])
        self.beta = config['beta']
        self.theta = config['theta']

        # Scale of the weight regulariser in the training loss.
        self.loss_weight = config.get('loss_weight', 20)
        # Fixed regularisation targets (frozen Parameters: moved with the model, saved in the state dict).
        self.weight_standard = Parameter(torch.eye(hidden_channels), requires_grad=False)
        self.weight_first_layer = Parameter(torch.eye(hidden_channels) * np.sqrt(config['c_max']), requires_grad=False)
        for module in self.gcn_list:
            module.weight.register_hook(self.x_grad)

        self.reset_parameters()

    def _orthogonal_penalty(self):
        """Sum of ``||W_k - T_k||_F`` over all conv layers (0.0 when there are none)."""
        penalty = 0.
        for k, conv in enumerate(self.gcn_list):
            target = self.weight_first_layer if k == 0 else self.weight_standard
            penalty = penalty + torch.norm(conv.weight - target)
        return penalty

    def forward(self, x, edge_index, edge_weight):
        """Encode, run the biased EGNN stack, decode; returns raw scores."""
        _layers = []
        x = F.dropout(x, self.dropout, training=self.training)
        x = self.act(self.fc_enc(x))
        _layers.append(x)
        if self.training:
            self.grad_listener.append([])
            self.x_norm.append([])

        for i in range(self.layers):
            x = F.dropout(x, self.dropout, training=self.training)
            ax = self.gcn_list[i](x=x, x_0=_layers[0], edge_index=edge_index, edge_weight=edge_weight, beta=self.beta, residual_weight=self.theta)
            b = self.bias_list[i](x=x, x0=_layers[0], edge_index=edge_index, edge_weight=edge_weight)
            x = self.act(ax + b)
            if self.training:
                self.x_norm[-1].append(torch.norm(x.detach()).item())
            _layers.append(x)

        x = F.dropout(x, self.dropout, training=self.training)
        x = self.fc_dec(x)
        self._layers = _layers

        if self.training:
            # Indexing gcn_list[0] directly used to raise IndexError for hidden_layers == 0.
            self.reg = self.loss_weight * self._orthogonal_penalty()

        return x

#----------------------------------------------------------------------------------------------------------------------------------------------------
# Helper
#----------------------------------------------------------------------------------------------------------------------------------------------------

def accuracy(output, labels):
    """Return the fraction of rows of ``output`` whose arg-max equals ``labels``.

    Args:
        output: score matrix with one row per example and classes along dim 1.
        labels: integer class labels, one per row.

    Returns:
        0-dim float64 tensor holding the accuracy.
    """
    _, predicted = torch.max(output, dim=1)
    hits = (predicted.type_as(labels) == labels).double()
    return hits.sum() / len(labels)


def _quantile_ticks(n, fractions):
    """Return sorted 0-based tick positions at the given fractions of ``n``, always including 0."""
    return sorted({0, *(max(int(round(n * f)) - 1, 0) for f in fractions)})


def plot_grads(cfg, model, *, vmin=19.9, vmax=20.1):
    """Save a heat-map of the gradient norms recorded during training.

    Column ``j`` is training step ``j`` (one per epoch in run_training()) and
    row ``k`` is the ``k``-th entry of that step's ``model.grad_listener`` list.
    The figure goes to ``/root/workspace/out/sct_gnn/<model name>_grads.pdf``.
    Unlike before, this function no longer terminates the process.

    Args:
        cfg: Hydra config; ``cfg.model["name"]`` names the output file.
        model: trained model whose ``grad_listener`` lists all have equal length.
        vmin, vmax: colour-scale limits. The defaults are the values that were
            hard-coded for EGNN; earlier runs used ``0..3.1e-4`` for GCNII and
            ``0..1e-10`` for GCN.

    NOTE(review): entries are appended in the order the backward hooks fire,
    which need not be input-to-output layer order. Confirm that row 0 really
    is layer 1 before trusting the y-axis labels.
    """
    grads = np.array(model.grad_listener).T
    n_layers, n_steps = grads.shape

    plt.imshow(grads, vmin=vmin, vmax=vmax, origin='lower')
    plt.colorbar(fraction=0.015, pad=0.05)

    # Derive ticks from the data size. For 100 steps x 32 layers this reproduces the
    # previously hard-coded [0,24,49,74,99] / [0,15,31]. Fixed ticks beyond the data
    # used to stretch the axes when training stopped early.
    x_ticks = _quantile_ticks(n_steps, (0.25, 0.5, 0.75, 1.0))
    y_ticks = _quantile_ticks(n_layers, (0.5, 1.0))
    plt.xticks(x_ticks, [t + 1 for t in x_ticks])
    plt.yticks(y_ticks, [t + 1 for t in y_ticks])
    plt.xlabel('Epoch ($n$)', color='black')
    plt.ylabel('Layer ($k$)', color='black')

    out_dir = '/root/workspace/out/sct_gnn'
    os.makedirs(out_dir, exist_ok=True)
    plt.savefig(f'{out_dir}/{cfg.model["name"]}_grads.pdf', format='pdf', bbox_inches='tight')
    plt.close()


#----------------------------------------------------------------------------------------------------------------------------------------------------
# Config/Model/Dataset
#----------------------------------------------------------------------------------------------------------------------------------------------------

def setup(cfg):
    # Set device
    args = cfg.setup
    cfg['setup']['device'] = args['device'] if torch.cuda.is_available() else 'cpu'
    os.environ["WANDB_DIR"] = os.path.abspath(args['wandb_dir'])
    # Change file name for sweeping *Prior to setting seed*
    if args['sweep']:
        run_id = wandb.run.id
        cfg['load']['checkpoint_path']=cfg['load']['checkpoint_path'][:-3]+str(run_id)+'.pt'
    # Set Seed
    return 1

#----------------------------------------------------------------------------------------------------------------------------------------------------

def load(data_name):
    """Load a PyG node-classification dataset by (lower-case) name.

    Supported names: WebKB ``'cornell'``, ``'texas'``, ``'wisconsin'`` and
    Planetoid ``'citeseer'``, ``'cora'``, ``'pubmed'`` (public split). The
    transform applies ``T.NormalizeFeatures()`` then ``T.GCNNorm()``. Data is
    stored under ``/root/workspace/data/<data_name>``.

    Raises:
        ValueError: for any other ``data_name``. Previously an unknown name
            fell through both branches and raised UnboundLocalError.
    """
    webkb_names = ('cornell', 'texas', 'wisconsin')
    planetoid_names = ('citeseer', 'cora', 'pubmed')
    if data_name not in webkb_names and data_name not in planetoid_names:
        raise ValueError(f'unknown dataset {data_name!r}; expected one of {webkb_names + planetoid_names}')

    transform = T.Compose([T.NormalizeFeatures(), T.GCNNorm()])
    root = '/root/workspace/data/' + data_name

    if data_name in webkb_names:
        print(f'WebKB: {data_name}')
        return WebKB(root=root, name=data_name, transform=transform)

    print(f'Planetoid: {data_name}')
    return Planetoid(root=root, name=data_name, split='public', transform=transform)

#----------------------------------------------------------------------------------------------------------------------------------------------------
# Train/Validate/Test
#----------------------------------------------------------------------------------------------------------------------------------------------------

def train(cfg, data, model, optimizer):
    """Run one full-batch optimisation step on the nodes in ``data.train_mask``.

    The objective is the NLL of the log-softmax outputs on the training nodes
    plus ``model.reg``, which the model sets during its forward pass.

    Returns:
        (objective as a Python float, training accuracy as a 0-dim tensor).
    """
    model.train()
    optimizer.zero_grad()
    log_probs = F.log_softmax(model(data.x, data.edge_index, data.edge_weight), dim=1)
    mask = data.train_mask
    targets = data.y[mask]
    objective = F.nll_loss(log_probs[mask], targets) + model.reg
    objective.backward()
    optimizer.step()
    train_acc = accuracy(log_probs[mask], targets)
    return objective.item(), train_acc


@torch.no_grad()
def validate(cfg, data, model):
    """Evaluate ``model`` on ``data.val_mask`` without gradient tracking.

    Returns:
        (validation NLL as a Python float, validation accuracy as a 0-dim tensor).
    """
    model.eval()
    log_probs = F.log_softmax(model(data.x, data.edge_index, data.edge_weight), dim=1)
    val_scores = log_probs[data.val_mask]
    val_targets = data.y[data.val_mask]
    val_loss = F.nll_loss(val_scores, val_targets)
    return val_loss.item(), accuracy(val_scores, val_targets)


@torch.no_grad()
def test(cfg, data, model):
    """Reload the best checkpoint into ``model`` and evaluate it on ``data.test_mask``.

    The checkpoint at ``cfg.load['checkpoint_path']`` is mapped onto
    ``cfg.setup['device']`` (where the model lives), so a checkpoint written on
    GPU can be evaluated on a CPU-only machine. If ``cfg.setup['plot']`` is
    set, a per-layer feature-smoothness heat-map is saved as a PDF.

    Returns:
        (test NLL as a Python float, test accuracy as a 0-dim tensor).
    """
    checkpoint = torch.load(cfg.load['checkpoint_path'], map_location=cfg.setup['device'])
    # strict=False: missing or unexpected keys are ignored instead of raising.
    model.load_state_dict(checkpoint['model_state_dict'], strict=False)
    model.eval()
    output = model(data.x, data.edge_index, data.edge_weight)
    output = F.log_softmax(output, dim=1)
    loss = F.nll_loss(output[data.test_mask], data.y[data.test_mask])

    acc = accuracy(output[data.test_mask], data.y[data.test_mask])

    if cfg.setup['plot']:
        measure = NodeFeatureSmoothness(cached=True)
        # One measurement per layer output stored by the forward pass above.
        lyr_msr = np.array([measure(y, data.edge_index, edge_weight=data.edge_weight) for y in model._layers])
        # NOTE(review): assumes each measurement is >= 2-D so [:, :, 0] is valid, and the
        # tick positions assume 16 features x 32 layers — confirm for other configs.
        plt.imshow(lyr_msr[:, :, 0], vmin=0, vmax=1)
        plt.colorbar(fraction=.03)
        plt.xlabel('Feature ($n$)', color='black')
        plt.xticks([0, 7, 15], [1, 8, 16])
        plt.yticks([0, 15, 31], [1, 16, 32])
        plt.ylabel('Layer ($k$)', color='black')
        name = cfg.load['dataset'] + '-' + cfg.model['name'] + '-' + str(cfg.model['hidden_layers'])
        out_dir = '/root/workspace/out/sct_gnn'
        os.makedirs(out_dir, exist_ok=True)
        plt.savefig(f'{out_dir}/{name}.pdf', format='pdf', bbox_inches='tight')
        # Release the figure so later plots in this process start from a clean canvas.
        plt.close()

    return loss.item(), acc

#----------------------------------------------------------------------------------------------------------------------------------------------------
# Main/Hydra/Fold/Train
#----------------------------------------------------------------------------------------------------------------------------------------------------

def run_training(cfg, data, model):
    """Train ``model`` with Adam and early stopping on validation loss.

    Graph-layer parameters get weight decay ``cfg.train['wd1']`` and the
    remaining trainable parameters ``cfg.train['wd2']``. Whenever the
    validation loss improves, a checkpoint is written to
    ``cfg.load['checkpoint_path']``. Training stops after ``cfg.train['epochs']``
    epochs, or once the validation loss has not improved for more than
    ``cfg.train['patience']`` consecutive epochs.

    The gradient-norm heat-map (plot_grads) is opt-in via
    ``cfg.setup.plot_grads``. It used to run unconditionally and be followed by
    ``exit()``, which stopped the process before test evaluation.

    Returns:
        (validation loss, validation accuracy) of the best epoch.
    """
    args = cfg.train
    optimizer = optim.Adam([{'params': model.graph_params, 'weight_decay': args['wd1']},
                            {'params': model.fc_params, 'weight_decay': args['wd2']}],
                           lr=args['lr'])

    # inf rather than 1e8 so the first finite validation loss always checkpoints.
    best = float('inf')
    acc, loss = 0, float('inf')
    bad_itr = 0

    for epoch in range(args['epochs']):
        start = time.time()
        train_loss, train_acc = train(cfg, data, model, optimizer)
        val_loss, val_acc = validate(cfg, data, model)

        # Model-selection criterion (lower is better).
        perf_metric = val_loss

        if perf_metric < best:
            best = perf_metric
            acc, loss = val_acc, val_loss
            bad_itr = 0
            torch.save({'epoch': epoch,
                        'model_state_dict': model.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'loss': val_loss,
                        'acc': val_acc},
                       cfg.load['checkpoint_path'])
        else:
            bad_itr += 1

        t_time = time.time() - start
        wandb.log({'epoch': epoch,
                   'train_loss': train_loss,
                   'train_acc': train_acc,
                   'val_loss': val_loss,
                   'val_acc': val_acc,
                   'perf_metric': perf_metric,
                   'best': best,
                   'time': t_time})
        print(f'Epoch({epoch:03d}) | train({train_loss:.4f},{100*train_acc:.2f}) '
              f'| val({val_loss:.4f},{100*val_acc:.2f}) ' +
              f'| perf({perf_metric:.3f}) | time({t_time:.2f})')

        if bad_itr > args['patience']:
            break

    if cfg.setup.get('plot_grads', False):
        plot_grads(cfg, model)
    return loss, acc

#----------------------------------------------------------------------------------------------------------------------------------------------------

def run_folds(cfg):
    """Load the dataset, build the configured model, optionally train it, then test it.

    ``cfg.model['name']`` must name one of the model classes defined in this
    module (``gcn``, ``gcnii``, ``egnn``, ``gcns``, ``gcniis``, ``egnns``). If a
    split mask is 2-D, column ``cfg.load['split']`` (default 0) is used; PyG's
    WebKB datasets ship their splits that way.

    Raises:
        ValueError: if the model name does not resolve to a model class.

    Returns:
        1 (kept for backward compatibility).
    """
    dataset = load(cfg.load['dataset'])
    data = dataset[0].to(cfg.setup['device'])

    # 2-D masks (one column per split) cannot index the (N, C) output directly.
    split = cfg.load.get('split', 0)
    for mask_name in ('train_mask', 'val_mask', 'test_mask'):
        mask = getattr(data, mask_name)
        if mask.dim() == 2:
            setattr(data, mask_name, mask[:, split])

    n_train = int(data.train_mask.sum())
    n_val = int(data.val_mask.sum())
    n_test = int(data.test_mask.sum())
    print(f'Splits: '
          f'train({n_train})'
          f'\tval({n_val})'
          f'\ttest({n_test})'
          f'\ttrain_pc({n_train / dataset.num_classes:.2f})')

    model_name = cfg.model['name']
    model_cls = globals().get(model_name)
    # Only accept concrete model classes; globals() would otherwise happily return
    # None or any other module-level object.
    if (not isinstance(model_cls, type)
            or not issubclass(model_cls, fixed_architecture)
            or model_cls in (fixed_architecture, fixed_bias_architecture)):
        raise ValueError(f'unknown model {model_name!r}')

    dropout = cfg.model['dropout']
    hidden_channels = cfg.model['hidden_channels']
    hidden_layers = cfg.model['hidden_layers']
    model = model_cls(dataset.num_features, hidden_channels, hidden_layers, dataset.num_classes, dropout, cfg.model['act_fn'], cfg.model, data)
    model.to(cfg.setup['device'])
    print(model)

    # Train
    if cfg.setup['train']:
        val_loss, val_acc = run_training(cfg, data, model)
        print(f'val({100*val_acc:2.2f})')

    # Test
    test_loss, test_acc = test(cfg, data, model)
    print(f'test({100*test_acc:2.2f})')

    if cfg.setup['train']:  # log if training
        wandb.log({'val_loss': val_loss,
                   'val_acc': val_acc,
                   'test_loss': test_loss,
                   'test_acc': test_acc})

    return 1

#----------------------------------------------------------------------------------------------------------------------------------------------------

@hydra.main(version_base=None, config_path="/root/workspace/sct-gnn/config/", config_name="pyg-planetoid")
def run_synthetic(cfg):
    """Hydra entry point: open a (disabled) wandb run, finalise ``cfg``, run the experiment.

    The wandb run is named ``<model>-<act_fn>-<dataset>``.
    """
    run_config = OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True)
    model_name = cfg.model['name']
    dataset_name = cfg.load['dataset']
    run_name = '-'.join([model_name, cfg.model['act_fn'], dataset_name])

    wandb.init(
        config=run_config,
        entity='',
        mode='disabled',
        name=run_name,
        project='sct-gnn',
        tags=['test', model_name, dataset_name, 'public'],
    )

    setup(cfg)
    print(OmegaConf.to_yaml(cfg))
    run_folds(cfg)

if __name__ == '__main__':
    # No arguments: the @hydra.main decorator builds ``cfg`` from the config files and command line.
    run_synthetic()
