# Structured Approximations
Using structured approximations to dense layers.


### Installation
```
source setup.sh
```
This creates two conda envs named `gpt` and `imagenet`. 

### ImageNet setup
```
echo "export IMAGENET_DIR=<path to imagenet folder>" >> ~/.bashrc
echo "export IMAGENET_FFCV_DIR=<path to imagenet ffcv dataset folder>" >> ~/.bashrc
cd imagenet
./write_imagenet.sh 500 0.50 90
```

### WANDB setup
```
echo "export WANDB_API_KEY=<wandb api key>" >> ~/.bashrc
echo "export WANDB_DIR=<dir for wandb logs>" >> ~/.bashrc
echo "export WANDB__SERVICE_WAIT=300" >> ~/.bashrc
```

### Helper functions to add to `.bashrc` for automating sweeps on a multi-GPU machine
```bash
get_free_gpu() {
    # Print the index of the first GPU whose used memory is below a threshold,
    # polling until one becomes available.
    #
    # Arguments:
    #   $1 - GPU indices to check, separated by spaces or commas
    #        (default: every GPU index reported by nvidia-smi)
    #   $2 - maximum used memory, in MiB, for a GPU to count as free (default: 1000)
    #   $3 - seconds to sleep between polling rounds (default: 10)
    # Outputs:
    #   the index of the first free GPU on stdout
    # Returns:
    #   0 once a free GPU is found (blocks until then);
    #   1 if there are no GPU ids to check
    local ids=${1:-$(nvidia-smi --query-gpu=index --format=csv,noheader,nounits | tr '\n' ' ')}
    local max_mem_mib=${2:-1000}
    local poll_secs=${3:-10}
    local -a id_list
    local id mem_used

    # Accept "0 1 2" as well as CUDA_VISIBLE_DEVICES-style "0,1,2".
    read -r -a id_list <<< "${ids//,/ }"

    # Without any ids the poll below could never succeed; fail fast instead.
    if (( ${#id_list[@]} == 0 )); then
        echo "get_free_gpu: no GPU ids to check" >&2
        return 1
    fi

    # Poll in a loop rather than recursing, so long waits do not grow the stack.
    while true; do
        for id in "${id_list[@]}"; do
            mem_used=$(nvidia-smi --id="$id" --query-gpu=memory.used --format=csv,nounits,noheader)
            mem_used=${mem_used//[[:space:]]/}
            # Skip unparsable readings (e.g. "[N/A]" or empty output on error)
            # instead of feeding them to an integer comparison.
            if [[ "$mem_used" =~ ^[0-9]+$ ]] && (( mem_used < max_mem_mib )); then
                echo "$id"
                return 0
            fi
        done
        sleep "$poll_secs"
    done
}

kg() {
    # Kill the current user's compute processes on the given GPU(s).
    #
    # Arguments:
    #   $@ - one or more GPU indices, as accepted by `nvidia-smi -i`
    # Outputs:
    #   progress messages on stdout; usage text and kill failures on stderr
    # Returns:
    #   1 if called without arguments, 0 otherwise
    if [[ "$#" -eq 0 ]]; then
        echo "Usage: kg <GPU_ID_1> [<GPU_ID_2> ...]" >&2
        return 1
    fi

    local username my_uid gpu_id pids pid owner_uid owner
    username=$(whoami)
    # Ownership is decided by numeric UID: `ps -o user=` may truncate long
    # user names (e.g. "longusern+"), which would make a name comparison fail.
    my_uid=$(id -u)

    for gpu_id in "$@"; do
        echo "Checking for processes on GPU $gpu_id by user $username..."

        # PIDs of compute apps on this GPU; lines containing "Not" are dropped
        # (presumably nvidia-smi placeholders such as "[Not Supported]").
        pids=$(nvidia-smi --query-compute-apps=pid --format=csv,noheader,nounits -i "$gpu_id" | grep -v Not)

        for pid in $pids; do
            owner_uid=$(ps -o uid= -p "$pid")
            owner_uid=${owner_uid//[[:space:]]/}
            if [[ -z "$owner_uid" ]]; then
                # Process exited in the meantime, or lives in another PID namespace.
                echo "Process $pid on GPU $gpu_id is not visible to ps; skipping."
            elif [[ "$owner_uid" == "$my_uid" ]]; then
                if kill -9 "$pid"; then  # Forcefully kill the process
                    echo "Killed process $pid on GPU $gpu_id"
                else
                    echo "Failed to kill process $pid on GPU $gpu_id" >&2
                fi
            else
                owner=$(ps -o user= -p "$pid")
                echo "Process $pid on GPU $gpu_id is owned by $owner, not $username; skipping."
            fi
        done
    done
}
```


### Commands

To run a quick diagnostic (CIFAR-10, no W&B logging):
```shell
py train_mlp.py --no-wandb --dataset=cifar10 --model=MLP --width=256 --depth=3 --lr=5e-5 --batch_size=1024 --resolution=32 --struct=fixed_kron_sum --layers=all_but_last --kron_mult=1 --scheduler=cosine
```

```shell
py train_cola_mlp.py --no-wandb --dataset=cifar10 --model=MLP --width=8192 --depth=9 --lr=5e-5 --batch_size=1024 --resolution=32 --struct=scaled_block_tt --tt_dim=2 --tt_rank=128 --layers=all_but_last --scheduler=cosine
```

### Example runs

```shell
python3 imagenet_train.py /datasets/imagenet --ffn_struct block_tt --ffn_tt_dim 2 --ffn_tt_rank 1 --attention_struct block_tt --attention_tt_dim 2 --attention_tt_rank 1 --arch vit_huge_patch16 --input_resolution 224 --batch=4
```

```shell
py train_mlp.py --no-wandb --dataset=imagenet --resolution=64 --model=MLP --width=1024 --depth=9 --lr=5e-5 --batch_size=1024 --struct=low_rank --rank_frac=0.1 --layers=all_but_last --kron_mult=1 --scheduler=cosine --mixup=0
```
