## Representation Learning for Mendelian Randomization 

This folder contains code for the simulations in the paper. 

## Installation

First, install the required dependencies:

```bash
pip install -r requirements.txt
```

## Training Example

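The `conf/config.yaml` file contains the default settings for the training script `train.py`. Any of these settings can be overridden on the command line with `key=value` arguments, where nested keys use dot notation (e.g. `loss.lam1=5`). For example, the following Slurm array job trains over a 3 × 4 grid of the loss weights `loss.lam1` and `loss.lam2`, with one array task per combination: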

```bash
#!/bin/bash
#SBATCH --job-name=mdcrl_grid
#SBATCH --array=0-11
#SBATCH --time=2:00:00
#... other sbatch args ...

# --- Hyperparameter Arrays ---
lam1_vals=(1 5 10)
lam2_vals=(0 1 5 10)

# --- Array indices for the grid ---
# Task ID 0 -> lam1=1, lam2=0
# Task ID 1 -> lam1=1, lam2=1
# Task ID 2 -> lam1=1, lam2=5
# Task ID 3 -> lam1=1, lam2=10
# Task ID 4 -> lam1=5, lam2=0
# Task ID 5 -> lam1=5, lam2=1
# Task ID 6 -> lam1=5, lam2=5
# Task ID 7 -> lam1=5, lam2=10
# Task ID 8 -> lam1=10, lam2=0
# Task ID 9 -> lam1=10, lam2=1
# Task ID 10 -> lam1=10, lam2=5
# Task ID 11 -> lam1=10, lam2=10

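# Map the flat task ID to (lam1, lam2) indices in row-major order
idx=${SLURM_ARRAY_TASK_ID:-0}  # defaults to 0 when run outside Slurm
len2=${#lam2_vals[@]}          # number of lam2 values (4)
idx1=$((idx / len2))
idx2=$((idx % len2))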

L1=${lam1_vals[$idx1]}
L2=${lam2_vals[$idx2]}

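# --- Execute with Overrides ---
# All other settings are read from environment variables (DATA_SEED, MIX_SEED,
# MIX_TYPE, ...), which must be set before submission, e.g. via
# `sbatch --export=ALL,...`. Variables written as ${VAR:-default} fall back to
# the default shown. RESUME_ARGS is optional and left unquoted, so it can
# expand to zero or more arguments.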
python train.py \
    data_seed="${DATA_SEED}" \
    mix_seed="${MIX_SEED}" \
    data.mixing_type="${MIX_TYPE}" \
    data.polymix_degree="${POLYMIX_DEGREE}" \
    data.invmlp_actfun="${INVMLP_ACTFUN:-leaky_relu}" \
    data.dim_v_true="${DV_TRUE}" \
    data.dim_w_true="${DW_TRUE}" \
    data.dim_z="${DZ}" \
    data.n_pop="${NPOP}" \
    data.n_train="${NTRAIN}" \
    encoder="${ENC_TYPE}" \
    decoder="${DEC_TYPE}" \
    model.dim_v="${DV}" \
    model.dim_w="${DW}" \
    loss.inv_loss_type="${INV_LOSS_TYPE:-poly}" \
    loss.inv_ker_poly_degree="${INV_KER_POLY_DEGREE:-2}" \
    loss.ind_loss_type="${IND_LOSS_TYPE:-poly}" \
    loss.ind_ker_poly_degree="${IND_KER_POLY_DEGREE:-2}" \
    loss.lam1="${L1}" \
    loss.lam2="${L2}" \
    trainer.max_epochs="${MAX_EPOCHS}" \
    exp_id="${EXP_NAME}" \
    sim_id="${idx}" \
    ${RESUME_ARGS}
```
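To check the grid before submitting, the index arithmetic of the job script can be reproduced locally. The sketch below prints the hyperparameters each array task would receive and the matching `--array` range, which helps keep the `#SBATCH --array` line and the comment table in sync when the value arrays change. The `grid_params` helper is hypothetical and not part of this repository; the arrays must be kept identical to those in the job script.

```bash
#!/bin/bash
# Dry-run sketch: print the (lam1, lam2) pair each Slurm array task would use,
# without submitting anything. Keep these arrays identical to the job script.
lam1_vals=(1 5 10)
lam2_vals=(0 1 5 10)

# Hypothetical helper: map a flat task ID to "lam1 lam2" using the same
# row-major indexing (idx / len2, idx % len2) as the job script.
grid_params() {
    local idx=$1
    local len2=${#lam2_vals[@]}
    echo "${lam1_vals[$((idx / len2))]} ${lam2_vals[$((idx % len2))]}"
}

n_tasks=$(( ${#lam1_vals[@]} * ${#lam2_vals[@]} ))
echo "Use: #SBATCH --array=0-$((n_tasks - 1))"
for ((i = 0; i < n_tasks; i++)); do
    read -r l1 l2 <<< "$(grid_params "$i")"
    echo "Task ID $i -> lam1=$l1, lam2=$l2"
done
```

With the arrays above this prints `--array=0-11` followed by the same twelve task mappings listed in the job script's comments.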

## Evaluation

To evaluate the trained models, run `summarize_seeds.py`, which summarizes the results of each experiment over the given data seeds. For example:

```bash
#!/bin/bash

CKPT_STRATEGY="last" # checkpoint strategy (last or best)
METRIC_KEY="val/inv_loss" # metric to choose the best model
EXP_GRP="new_normalclamppolymix3" # folder name to group results
EXCLUDE_SIM_IDS=(0 4 8) # sim IDs to skip; here the lam2 = 0 runs of the training grid
DATA_SEEDS=({42..61}) # data seeds 42 to 61 (20 seeds)

# List of experiment IDs
EXP_IDS=(
    "new_normalclamppolymix3_mlpnormenc_meanvarinv_polyind_ms100"
    "new_normalclamppolymix3_mlpnormenc_poly2inv_polyind_ms100"
    "new_normalclamppolymix3_mlpnormenc_poly3inv_polyind_ms100"
    "new_normalclamppolymix3_mlpnormenc_rbfinv_polyind_ms100"
)

# Loop through each experiment ID and run the summary script
for EXP_ID in "${EXP_IDS[@]}"; do
    echo "Summarizing seeds for: $EXP_ID"

    args=(
        --exp_id "$EXP_ID"
        --exp_grp "$EXP_GRP"
        --ckpt_strategy "$CKPT_STRATEGY"
        --metric_key "$METRIC_KEY"
        --data_seeds "${DATA_SEEDS[@]}"
        --exclude_sim_ids "${EXCLUDE_SIM_IDS[@]}"
    )

    # Run the command using the array
    python summarize_seeds.py "${args[@]}"
        
    echo "--------------------------------------"
done
```
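Because the training script passes the array task ID as `sim_id`, the IDs in `EXCLUDE_SIM_IDS` can be computed from the training grid rather than typed by hand. The sketch below assumes the same `lam1_vals`/`lam2_vals` arrays and row-major layout as the training example; `exclude_lam2` is a hypothetical variable introduced here for illustration.

```bash
#!/bin/bash
# Sketch: derive the sim IDs to exclude from the training grid instead of
# hard-coding them. sim_id equals the Slurm array task ID, laid out row-major
# as idx = i1 * len2 + i2 (same as the training script).
lam1_vals=(1 5 10)
lam2_vals=(0 1 5 10)
exclude_lam2=0  # hypothetical: drop every run with this lam2 value

len2=${#lam2_vals[@]}
EXCLUDE_SIM_IDS=()
for ((i1 = 0; i1 < ${#lam1_vals[@]}; i1++)); do
    for ((i2 = 0; i2 < len2; i2++)); do
        if [[ "${lam2_vals[$i2]}" == "$exclude_lam2" ]]; then
            EXCLUDE_SIM_IDS+=("$((i1 * len2 + i2))")
        fi
    done
done

echo "EXCLUDE_SIM_IDS: ${EXCLUDE_SIM_IDS[*]}"
```

For the grid above this yields `0 4 8`, matching the hard-coded value in the evaluation script.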

## Figures

The figures in the paper are produced by `make_plots.py`. 
