# SA-FixMatch

Our code is based on the [FixMatch-pytorch](https://github.com/kekmodel/FixMatch-pytorch) repository, which we cite in our paper.

## Overview

This repository contains code for training and evaluating models with supervised learning (SL), FixMatch, and SA-FixMatch. Code for integrating SA-CutOut with other deep semi-supervised learning methods builds on the [Semi-supervised-learning](https://github.com/microsoft/Semi-supervised-learning) repository and will be released once finalized.


## Installation

To install the required packages, run:
```bash
pip install -r requirements.txt
```

## Data Preparation

To prepare the data for training and evaluation, follow these steps:

1. Download the ImageNet dataset from the [official website](https://image-net.org).
2. Create a directory named `data` in the root of the repository:
    ```bash
    mkdir data
    ```
3. Place the downloaded ImageNet dataset in `data/imagenet`, with separate `train` and `val` splits. The structure should look like this:
    ```
    your-repository/
    ├── data/
    │   └── imagenet/
    │       ├── train/
    │       └── val/
    ├── train.py
    ├── evaluate.py
    └── ...
    ```
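Before launching a long run, it can help to confirm that the data sits where the tree above expects it. The function below is a hypothetical helper, not part of this repository; it only assumes the `data/imagenet/{train,val}` layout shown above and does not validate class folders or image files.

```bash
# Hypothetical helper (not part of this repository): check that an ImageNet root
# contains non-empty train/ and val/ directories, as in the tree above.
# It does not validate class folders or image files.
check_imagenet_layout() {
  local root="${1:-data/imagenet}"
  local split
  for split in train val; do
    if [ ! -d "$root/$split" ]; then
      echo "missing directory: $root/$split" >&2
      return 1
    fi
    # An empty split usually means the archive was not extracted in place.
    if [ -z "$(ls -A "$root/$split")" ]; then
      echo "empty directory: $root/$split" >&2
      return 1
    fi
  done
  echo "ImageNet layout looks OK under $root"
}
```

Run it from the repository root as `check_imagenet_layout data/imagenet`; a non-zero exit status points to the first missing or empty split.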

## Training

### Training with Supervised Learning

To train the model using SL, use the following command:

```bash
CUDA_VISIBLE_DEVICES='0,1' python -m torch.distributed.launch --nproc_per_node 2 ./train.py \
 --num_labeled 100000 --batch_size 64 --lr 0.03 --wdecay 0.0003 --expand_labels --seed 0 \
 --out results/imagenet@100k_sl --num_workers 8 --eval_step 2048 --total_steps 307200 \
 --supervised
```
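The schedule flags are easier to read in epochs. Assuming this codebase follows FixMatch-pytorch, where one epoch is `--eval_step` optimizer steps and `--batch_size` is the per-process batch under `torch.distributed.launch`, the shared settings work out as below. Treat this as a back-of-the-envelope sketch, not a statement about `train.py` internals.

```bash
# Values copied from the training commands in this README.
TOTAL_STEPS=307200   # --total_steps
EVAL_STEP=2048       # --eval_step
BATCH_SIZE=64        # --batch_size
NPROC=2              # --nproc_per_node (one process per visible GPU)

# Assumption: one epoch = EVAL_STEP optimizer steps (as in FixMatch-pytorch).
EPOCHS=$(( TOTAL_STEPS / EVAL_STEP ))
# Assumption: --batch_size is per process, so the global labeled batch scales with NPROC.
GLOBAL_LABELED_BATCH=$(( BATCH_SIZE * NPROC ))

echo "epochs=$EPOCHS global_labeled_batch=$GLOBAL_LABELED_BATCH"
```

If you change the number of GPUs, keeping `GLOBAL_LABELED_BATCH` fixed (by adjusting `--batch_size`) is one way to stay close to the original setting.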

### Training FixMatch

To train the model using FixMatch, use the following command:

```bash
CUDA_VISIBLE_DEVICES='0,1' python -m torch.distributed.launch --nproc_per_node 2 ./train.py \
 --num_labeled 100000 --batch_size 64 --lr 0.03 --wdecay 0.0003 --expand_labels --seed 0 \
 --out results/imagenet@100k_fixmatch --num_workers 8 --eval_step 2048 --total_steps 307200
```
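The SL and FixMatch commands share every flag except `--out` and `--supervised`, and SA-FixMatch adds only `--cam` and `--resume`. A minimal sketch, assuming bash, of a hypothetical `run_train` wrapper that keeps the shared flags in one array; the function name, `COMMON_ARGS`, and the `DRY_RUN` switch are illustrative and not part of this repository.

```bash
# Flags shared by the SL, FixMatch, and SA-FixMatch commands in this README.
COMMON_ARGS=(--num_labeled 100000 --batch_size 64 --lr 0.03 --wdecay 0.0003
             --expand_labels --seed 0 --num_workers 8 --eval_step 2048
             --total_steps 307200)

# Hypothetical wrapper: run_train OUT_DIR [EXTRA_FLAGS...]
# Set DRY_RUN=1 to print the command instead of launching it.
run_train() {
  local out="$1"; shift
  local cmd=(python -m torch.distributed.launch --nproc_per_node 2 ./train.py
             "${COMMON_ARGS[@]}" --out "$out" "$@")
  if [ "${DRY_RUN:-0}" = "1" ]; then
    printf '%s\n' "${cmd[*]}"
  else
    CUDA_VISIBLE_DEVICES='0,1' "${cmd[@]}"
  fi
}
```

For example, `run_train results/imagenet@100k_sl --supervised` mirrors the SL command, and `run_train results/imagenet@100k_fixmatch` mirrors the FixMatch one.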

### Training SA-FixMatch

To train the model using SA-FixMatch, use the following command. SA-CutOut is applied only during the last 32 epochs, so training resumes from the FixMatch checkpoint `checkpoint_118.pth.tar` instead of starting from scratch:

```bash
CUDA_VISIBLE_DEVICES='0,1' python -m torch.distributed.launch --nproc_per_node 2 ./train.py \
 --num_labeled 100000 --batch_size 64 --lr 0.03 --wdecay 0.0003 --expand_labels --seed 0 \
 --out results/imagenet@100k_safixmatch --num_workers 8 --eval_step 2048 --total_steps 307200 \
 --cam --resume results/imagenet@100k_fixmatch/checkpoint_118.pth.tar
```
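The 32-epoch figure follows from the schedule flags: 307200 / 2048 = 150 epochs in total, minus the 118 epochs already covered by the FixMatch checkpoint. A small pre-flight sketch, assuming the number in `checkpoint_118.pth.tar` counts completed epochs, that reproduces this arithmetic and checks that the checkpoint exists before SA-FixMatch is launched:

```bash
TOTAL_STEPS=307200   # --total_steps
EVAL_STEP=2048       # --eval_step
RESUME_EPOCH=118     # from the checkpoint name checkpoint_118.pth.tar
CKPT="results/imagenet@100k_fixmatch/checkpoint_${RESUME_EPOCH}.pth.tar"

# 150 total epochs - 118 completed epochs = 32 epochs trained with SA-CutOut.
REMAINING=$(( TOTAL_STEPS / EVAL_STEP - RESUME_EPOCH ))
echo "epochs with SA-CutOut after resuming: $REMAINING"

# Warn early if the FixMatch run has not produced the checkpoint yet.
if [ -f "$CKPT" ]; then
  echo "found $CKPT"
else
  echo "missing $CKPT; run the FixMatch command first" >&2
fi
```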

## Evaluation

To evaluate the trained model, use the following command:

```bash
CUDA_VISIBLE_DEVICES='0' python evaluate.py --checkpoint results/imagenet@100k_safixmatch/model_best.pth.tar
```
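To compare the SL, FixMatch, and SA-FixMatch runs side by side, the same command can be looped over every run directory. This sketch assumes each run under `results/` writes a `model_best.pth.tar` like the SA-FixMatch run above; `eval_all` and `DRY_RUN` are hypothetical names, not part of this repository.

```bash
# Hypothetical helper: evaluate model_best.pth.tar from every imagenet@100k_* run.
# Set DRY_RUN=1 to print the commands instead of running them.
eval_all() {
  local results_dir="${1:-results}"
  local ckpt
  for ckpt in "$results_dir"/imagenet@100k_*/model_best.pth.tar; do
    # Skip runs without a best checkpoint (or an unmatched glob).
    [ -f "$ckpt" ] || continue
    if [ "${DRY_RUN:-0}" = "1" ]; then
      echo "python evaluate.py --checkpoint $ckpt"
    else
      CUDA_VISIBLE_DEVICES='0' python evaluate.py --checkpoint "$ckpt"
    fi
  done
}
```

Calling `eval_all` with no argument scans `results/`; pass another directory to evaluate runs stored elsewhere.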
