# SpASTRA: Sparse training with ASTRA and IHT

This package is still under development and far from mature. It is currently used to implement ASTRA and SASTRA on ResNets (Conv2D layers) and Linear layers.

This repo provides training scripts and utilities for sparse training using ASTRA and iterative hard-thresholding (IHT). It uses Hydra configs for experiment composition, PyTorch for training, and Weights & Biases for logging.


Key entry points:
- ASTRA training: [train_resnet_astra.py](train_resnet_astra.py)
- IHT training: [train_resnet_IHT.py](train_resnet_IHT.py)

Core APIs used by the training scripts:
- Config plumbing: [`spastra.configs.get_model`](spastra/configs.py), [`spastra.configs.get_optimizer`](spastra/configs.py), [`spastra.configs.get_lr_scheduler`](spastra/configs.py), [`spastra.configs.get_sparsity_specs`](spastra/configs.py), [`spastra.configs.get_sparsity_groups`](spastra/configs.py), [`spastra.configs.get_ema`](spastra/configs.py), [`spastra.configs.get_lambdas`](spastra/configs.py), [`spastra.configs.get_alphas`](spastra/configs.py)
- Data: [`spastra.data.get_dataloaders`](spastra/data/datasets.py)
- Algorithms: [`spastra.astra.SASTRA`](spastra/astra.py), [`spastra.astra.IHTSparsifier`](spastra/astra.py)
- Eval & stats: [`spastra.evaluate.evaluate_accuracy`](spastra/evaluate.py), [`spastra.evaluate.get_model_sparsity`](spastra/evaluate.py), [`spastra.stats.StatsCollector`](spastra/stats.py)

## Installation

- With Poetry:
  - `poetry install`
  - `poetry run python train_resnet_astra.py …`

- With pip (editable):
  - `pip install -e .`
  - `python train_resnet_astra.py …`

## Datasets

The training scripts read the dataset root from the environment variable whose name is given by the config key `data_dir_env`. Set that variable before running.

Example:
```sh
export DATASETS_DIR=/path/to/datasets
# If your config uses another env var name (see configs/config.yaml), set that instead.
```
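
The lookup described above can be sketched in a few lines of Python. This is only an illustration: `resolve_data_root` is a hypothetical helper, and treating the config as a plain mapping with a `data_dir_env` key is an assumption; the actual scripts may resolve the path differently.

```python
import os
from pathlib import Path


def resolve_data_root(cfg: dict) -> Path:
    """Return the dataset root named by the env var in cfg["data_dir_env"].

    Hypothetical helper for illustration; not part of spastra.
    """
    env_name = cfg["data_dir_env"]
    root = os.environ.get(env_name)
    if not root:
        # Fail early with a clear message instead of a confusing path error later.
        raise RuntimeError(
            f"Environment variable {env_name!r} is not set; "
            "export it before launching training."
        )
    return Path(root).expanduser()


if __name__ == "__main__":
    os.environ.setdefault("DATASETS_DIR", "/path/to/datasets")
    print(resolve_data_root({"data_dir_env": "DATASETS_DIR"}))
```

Failing early when the variable is unset tends to give a clearer error than a missing-file exception raised deep inside a dataloader.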

Hydra base config: [configs/config.yaml](configs/config.yaml)  
Dataset configs: 
- CIFAR-10: [configs/dataset/cifar10.yaml](configs/dataset/cifar10.yaml)  
- CIFAR-100: [configs/dataset/cifar100.yaml](configs/dataset/cifar100.yaml)

## Run a quick experiment

Recommended experiments live in [configs/experiment/](configs/experiment/). Use Hydra overrides at the CLI.

- ASTRA on CIFAR-10 with ResNet-32:
```sh
python train_resnet_astra.py experiment=cifar10_resnet32
```

- IHT on CIFAR-10 with ResNet-32:
```sh
python train_resnet_IHT.py experiment=cifar10_resnet32
```

- CIFAR-100 variant:
```sh
python train_resnet_astra.py experiment=cifar100_resnet32
```

If you prefer flat experiment files (without the `experiment=` group), examples are also provided at the top level:
- [configs/cifar10_resnet32.yaml](configs/cifar10_resnet32.yaml)
- [configs/cifar100_resnet32.yaml](configs/cifar100_resnet32.yaml)
- [configs/cifar100_resnet32-filter.yaml](configs/cifar100_resnet32-filter.yaml)

To run one of these, select it as the top-level config with Hydra's `--config-name` flag:
```sh
python train_resnet_astra.py --config-name cifar10_resnet32
```

## Choosing sparsity structure and level

Pick a sparsifier config from [configs/sparsifier/](configs/sparsifier/):
- Unstructured: [configs/sparsifier/unstructured.yaml](configs/sparsifier/unstructured.yaml)
- Filter/channel structured: [configs/sparsifier/filter.yaml](configs/sparsifier/filter.yaml), [configs/sparsifier/conv2d_in_channel.yaml](configs/sparsifier/conv2d_in_channel.yaml)

Override at the CLI:
```sh
# Unstructured 90% sparsity
python train_resnet_astra.py experiment=cifar10_resnet32 sparsifier=unstructured sparsifier.sparsity=0.90

# Filter-wise 95% sparsity
python train_resnet_astra.py experiment=cifar10_resnet32 sparsifier=filter sparsifier.sparsity=0.95
```
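
To make the difference between the two structures concrete, here is a minimal pure-Python sketch of magnitude-based hard thresholding at a target sparsity. It is not the code in `spastra.astra`: the function names are hypothetical, and ranking weights by absolute value (unstructured) or filters by L2 norm (filter-wise) is an assumption made for illustration.

```python
def hard_threshold(values, sparsity):
    """Unstructured: zero the smallest-magnitude entries of a flat weight list."""
    n_prune = int(round(sparsity * len(values)))
    # Indices sorted by ascending magnitude; the first n_prune are zeroed.
    order = sorted(range(len(values)), key=lambda i: abs(values[i]))
    pruned = set(order[:n_prune])
    return [0.0 if i in pruned else v for i, v in enumerate(values)]


def filter_threshold(filters, sparsity):
    """Filter-wise: zero whole filters, ranked by their L2 norm."""
    norms = [sum(w * w for w in f) ** 0.5 for f in filters]
    n_prune = int(round(sparsity * len(filters)))
    order = sorted(range(len(filters)), key=lambda i: norms[i])
    pruned = set(order[:n_prune])
    return [[0.0] * len(f) if i in pruned else list(f)
            for i, f in enumerate(filters)]


if __name__ == "__main__":
    print(hard_threshold([0.5, -2.0, 0.1, 3.0], 0.5))
    print(filter_threshold([[1.0, 1.0], [0.1, 0.0], [3.0, 4.0], [0.2, 0.2]], 0.5))
```

Unstructured pruning removes individual weights anywhere in a layer, while the filter-wise variant removes entire filters at once, which is why the same `sparsifier.sparsity` value can affect accuracy quite differently under the two settings.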

ASTRA-specific knobs (all resolved via [`spastra.configs`](spastra/configs.py)):
- EMA of gradients: [`spastra.configs.get_ema`](spastra/configs.py)
- Lambda schedules: [`spastra.configs.get_lambdas`](spastra/configs.py)
- Per-group alphas: [`spastra.configs.get_alphas`](spastra/configs.py)
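
For readers unfamiliar with the first knob, an exponential moving average of gradients follows the usual recurrence `ema = beta * ema + (1 - beta) * grad`. The sketch below shows that generic update on plain lists; the name `ema_update` and the decay parameter `beta` are assumptions for illustration, and the object returned by `get_ema` may be configured or applied differently.

```python
def ema_update(ema, grad, beta=0.9):
    """Return the elementwise update beta * ema + (1 - beta) * grad."""
    return [beta * m + (1.0 - beta) * g for m, g in zip(ema, grad)]


if __name__ == "__main__":
    ema = [0.0, 0.0]
    for grad in ([1.0, -2.0], [1.0, -2.0], [1.0, -2.0]):
        ema = ema_update(ema, grad)
    # With a constant gradient, the average drifts toward [1.0, -2.0].
    print(ema)
```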

IHT freeze/warmup knobs (read by the scripts):
- `sparsifier.warmup`: fraction of the total epochs trained before sparsification starts
- `sparsifier.freeze`: fraction of the total epochs after which the sparsity support (the set of nonzero weights) is frozen
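
One way to read these two knobs is as boundaries that split training into three phases. The sketch below is an interpretation, not the scripts' actual logic: `iht_phase` is a hypothetical helper, and it assumes both values are fractions of `num_epochs` measured from the start of training, with `warmup <= freeze`.

```python
def iht_phase(epoch: int, num_epochs: int, warmup: float, freeze: float) -> str:
    """Classify a 0-based epoch as 'dense', 'sparsify', or 'frozen'.

    Assumption: both knobs are fractions of num_epochs counted from epoch 0.
    """
    if not 0.0 <= warmup <= freeze <= 1.0:
        raise ValueError("expected 0 <= warmup <= freeze <= 1")
    progress = epoch / num_epochs
    if progress < warmup:
        return "dense"      # no thresholding yet
    if progress < freeze:
        return "sparsify"   # support may still change between steps
    return "frozen"         # support is fixed; only surviving weights train


if __name__ == "__main__":
    for epoch in (0, 19, 20, 159, 160, 199):
        print(epoch, iht_phase(epoch, num_epochs=200, warmup=0.1, freeze=0.8))
```

Under these assumptions, a 200-epoch run with `warmup=0.1` and `freeze=0.8` would train densely for 20 epochs, update the support for the next 140, and keep it fixed for the final 40.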

## Typical overrides

```sh
# Change batch size, epochs, device
python train_resnet_astra.py experiment=cifar10_resnet32 batch_size=256 num_epochs=200 device=cuda

# Change optimizer or LR schedule through config
python train_resnet_astra.py experiment=cifar10_resnet32 optimizer.name=sgd optimizer.lr=0.1 lr_scheduler.name=cosine

# Change the number of dataloader workers (see dataset config)
python train_resnet_astra.py experiment=cifar10_resnet32 dataset.num_workers=8
```

The training loop prints the learning rate, loss, and accuracy, and logs metrics to W&B. Accuracy and sparsity are computed by:
- Accuracy: [`spastra.evaluate.evaluate_accuracy`](spastra/evaluate.py)
- Sparsity: [`spastra.evaluate.get_model_sparsity`](spastra/evaluate.py)
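
As a rough guide to what these two metrics mean, the sketch below computes them on plain Python lists. It is an illustration only: `zero_fraction` and `top1_accuracy` are hypothetical names, and the real functions operate on the model and dataloaders and may differ, for example in which parameters they count.

```python
from typing import Iterable, Sequence


def zero_fraction(tensors: Iterable[Sequence[float]]) -> float:
    """Fraction of entries that are exactly zero across flat weight arrays."""
    zeros = 0
    total = 0
    for t in tensors:
        zeros += sum(1 for w in t if w == 0.0)
        total += len(t)
    return zeros / total if total else 0.0


def top1_accuracy(predictions: Sequence[int], labels: Sequence[int]) -> float:
    """Fraction of predictions that match their labels."""
    correct = sum(int(p == y) for p, y in zip(predictions, labels))
    return correct / len(labels)


if __name__ == "__main__":
    print(zero_fraction([[0.0, 1.0, 0.0, 2.0], [0.0, 0.0]]))
    print(top1_accuracy([1, 0, 2, 2], [1, 1, 2, 0]))
```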

## Run provided shell scripts

- CIFAR-10 95%: [scripts/resnt32_cifar10_95.sh](scripts/resnt32_cifar10_95.sh)  
- CIFAR-100 95%: [scripts/resnt32_cifar100_95.sh](scripts/resnt32_cifar100_95.sh)

```sh
bash scripts/resnt32_cifar10_95.sh
```

## W&B sweeps

Predefined sweep configs live in [sweeps/](sweeps/):
- [sweeps/cifar10_resnet_90.yaml](sweeps/cifar10_resnet_90.yaml)
- [sweeps/cifar100_resnet_90.yaml](sweeps/cifar100_resnet_90.yaml)

Example flow:
```sh
# Create the sweep
wandb sweep sweeps/cifar10_resnet_90.yaml

# Then launch an agent (replace SWEEP_ID with the printed ID)
wandb agent YOUR_ENTITY/YOUR_PROJECT/SWEEP_ID
```

Alternatively, see [launch_sweep.sh](launch_sweep.sh).

## Outputs and logs

- Hydra output directories: `outputs/…`
- W&B metrics include LR, train/test accuracy, the measured model sparsity, and stats collected by [`spastra.stats.StatsCollector`](spastra/stats.py).
- Sparsity specs and groups are built internally by:
  - [`spastra.configs.get_sparsity_specs`](spastra/configs.py)
  - [`spastra.configs.get_sparsity_groups`](spastra/configs.py)

## Extending

- Add a model in [spastra/models](spastra/models) and wire it in [`spastra.configs.get_model`](spastra/configs.py).
- Add new sparsity specs/groups through the YAMLs in [configs/sparsifier/](configs/sparsifier/).


