"""
ablation_studies.py

Run ablation studies by turning off specific penalties or modules.
This script uses package-relative imports so it functions as a submodule.
"""

# Resolve the dataset helpers dynamically from the package name set below.
import importlib
pkg = "Code"  # package directory name in this repo; adjust if you rename the folder
datasets = importlib.import_module(f"{pkg}.datasets")
climate_agriculture = getattr(datasets, "climate_agriculture")
healthcare_sparse   = getattr(datasets, "healthcare_sparse")

import os, argparse, sys
import copy
import torch
import torch.nn as nn

# package-relative import of synthetic_experiments.main
try:
    from .synthetic_experiments import main as synth_main
except Exception:
    from experiments.synthetic_experiments import main as synth_main

def run_ablation_cases(args):
    """Run the baseline experiment followed by each ablation case.

    Every case is launched through ``synth_main`` with its own deep copy of
    ``args``, so the caller's namespace is never modified and nothing a run
    changes on its arguments can carry over into the next run.

    NOTE(review): the ablations are currently *simulated* -- every case
    receives exactly the same settings as the baseline, because no switch
    for disabling the modular penalty or the counterfactuals is applied here.

    Args:
        args: Parsed command-line options (e.g. an ``argparse.Namespace``),
            forwarded to ``synth_main`` for every case.

    Returns:
        None. Progress messages are printed to stdout.
    """
    announcements = (
        "Running baseline...",
        "Running ablation: no_modular_penalty (simulated)...",
        "Running ablation: no_counterfactuals (simulated)...",
    )
    for announcement in announcements:
        print(announcement)
        # Fresh copy per case: each run starts from identical settings even
        # if a previous run modified its arguments in place.
        synth_main(copy.deepcopy(args))
    print("Ablation runs finished. Inspect results/ folder for details.")

def main(argv=None):
    """Parse command-line options and run the ablation cases.

    Args:
        argv: Optional list of argument strings to parse instead of the
            process arguments. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``, which matches the previous behaviour.

    Raises:
        SystemExit: Raised by argparse for ``--help`` or invalid arguments.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', choices=['blobs', 'spirals', 'climate'], default='blobs')
    parser.add_argument('--n_samples', type=int, default=500)
    parser.add_argument('--batch_size', type=int, default=64)
    parser.add_argument('--epochs', type=int, default=30)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--seed', type=int, default=0)
    # Forwarding argv lets callers (and tests) drive the CLI programmatically.
    args = parser.parse_args(argv)
    run_ablation_cases(args)

# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == '__main__':
    main()
