// ---------------------------------------------------------------------------
// Inputs read from Jsonnet external variables (std.extVar).
// ---------------------------------------------------------------------------
local task = std.extVar("task");
local dataset = std.extVar("dataset");
local model = std.extVar("model");

// Downsampling factor used by the training dataset options (dataset_obj).
local downsample = 1;
// NOTE(review): `n_layer` is not referenced anywhere else in this file.
local n_layer = 2;

// Per-task sequence lengths. "[TSK]" looks like a template placeholder key
// rather than a real task name — TODO confirm.
local seq_lens = {
    "Rest": 375,
    "MID": 403,
    "SST": 437,
    "nBack": 362,
    "[TSK]": 1,
};
// Sequence length for the selected task, or null if `task` is not a key above.
// NOTE(review): `seq_len` is not referenced anywhere else in this file.
local seq_len = if std.objectHasAll(seq_lens, task) then seq_lens[task] else null;

// HDF5 input file for this dataset/task pair.
local data_path = "anonymous/%s/%s/aal3v1.h5" % [dataset, task];

// Dataset options for the `train` step's data module. `truncate` is left as
// null (presumably meaning "no truncation" — confirm against the dataset class).
local dataset_obj = {
    type: "brain",
    downsample: downsample,
    truncate: null,
    feature: "correlation",
};

// Builds a reference to the output of another step declared in `steps`.
local step_ref(name) = { type: "ref", ref: name };

// Data module for the `compute_entropy` step. It uses its own dataset settings
// (downsample 5, truncate 1, no `feature`) rather than dataset_obj.
// NOTE(review): confirm this divergence from the training dataset is intended.
local entropy_data_module = {
    type: "brain",
    file_path: data_path,
    dataset: {
        type: "brain",
        downsample: 5,
        truncate: 1,
    },
    batch_size: 512,
    num_workers: 10,
};

{
    steps: {
        logger: {
            type: "logger",
            group_name: model,
        },
        compute_entropy: {
            type: "compute_entropy",
            data_module: entropy_data_module,
        },
        train: {
            type: "train",
            load_name: null,
            seed: std.parseInt(std.extVar("seed")),
            trainer: {
                max_steps: 10000,
                check_val_every_n_epoch: 10,
                accelerator: "auto",
                // fast_dev_run: 50
            },
            model: {
                type: "regressor",
                model: {
                    type: model,
                    n_mask: std.parseInt(std.extVar("n_mask")),
                    tasks: task,
                    hidden_size: 1024,
                    transformer_type: "bernoulli",
                    agg_strategy: "sum",
                },
                learning_rate: 1e-3,
                // Consumes the output of the compute_entropy step.
                node_entropy: step_ref("compute_entropy"),
                // Parsed as JSON, so the external variable may be any JSON value.
                beta: std.parseJson(std.extVar("beta")),
            },
            data_module: {
                type: "brain",
                file_path: data_path,
                dataset: dataset_obj,
                batch_size: 64,
                num_workers: 10,
            },
            logger: step_ref("logger"),
        },
        cpm: {
            type: "cpm",
            run_name: step_ref("train"),
            datasets: dataset,
            tasks: task,
            // "[RUN]" stays literal here; presumably the cpm step substitutes the
            // run name at runtime — confirm.
            // NOTE(review): this path is under anonymous/src/results/, unlike the
            // two paths below (anonymous/results/) — confirm intended.
            path: "anonymous/src/results/hyperedges/[RUN]_%s_%s.mat" % [dataset, task],
            pairwise_path: "anonymous/results/pairwise3/%s_%s_aal3v1.mat" % [dataset, task],
            train_test_split_path: "anonymous/results/pairwise/%s_%s_traintest.mat" % [dataset, task],
        },
    },
}
