## Requirements
Our codebase builds on H-CAST (https://github.com/pseulki/HCAST) and CHMatch (https://github.com/sailist/image-classification).
- Python: 3.10
- CUDA: 12.1
- PyTorch: 2.1.2
- DGL: 2.4.0 (for H-CAST)
- GCC: 11.2.0 (for H-CAST; recommended to avoid errors when running DGL)

Create a conda environment with the following commands:
```
# create conda env
conda create -n hcast python=3.10
conda activate hcast
pip install -r requirements.txt
pip install torch==2.1.2 torchvision==0.16.2 --index-url https://download.pytorch.org/whl/cu121

# install dgl (https://www.dgl.ai/pages/start.html)
pip install dgl -f https://data.dgl.ai/wheels/torch-2.1/cu121/repo.html
```
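
After installation, it can help to confirm that the environment matches the versions listed above. The following is a minimal sketch, not part of this repository: `check_prefix` and `report_env` are hypothetical helper names, and the check assumes that `python` on the `PATH` is the interpreter of the activated `hcast` environment.

```shell
# Sketch only: check_prefix and report_env are hypothetical helpers,
# not scripts shipped with this repository.

# usage: check_prefix <name> <actual-version> <expected-prefix>
# Succeeds when the actual version starts with the expected prefix,
# so a build suffix such as "+cu121" is tolerated.
check_prefix() {
  case "$2" in
    "$3"*) echo "OK $1 $2"; return 0 ;;
    *)     echo "MISMATCH $1: got '${2:-none}', expected $3"; return 1 ;;
  esac
}

# Query the active interpreter and compare against the versions listed above.
# A package that fails to import yields an empty version and a MISMATCH line.
report_env() {
  env_status=0
  check_prefix "Python" \
    "$(python -c 'import platform; print(platform.python_version())' 2>/dev/null)" \
    "3.10" || env_status=1
  check_prefix "PyTorch" \
    "$(python -c 'import torch; print(torch.__version__)' 2>/dev/null)" \
    "2.1.2" || env_status=1
  check_prefix "CUDA (PyTorch build)" \
    "$(python -c 'import torch; print(torch.version.cuda)' 2>/dev/null)" \
    "12.1" || env_status=1
  check_prefix "DGL" \
    "$(python -c 'import dgl; print(dgl.__version__)' 2>/dev/null)" \
    "2.4.0" || env_status=1
  return "$env_status"
}
```

Running `report_env` inside the activated environment prints one line per component and returns a non-zero status if any version differs from the list above.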

## Data Preparation
- Download the [ImageNet (2012) dataset](https://www.image-net.org/download.php).
- Place the `imagenet-OC-train.txt` and `imagenet-OH-val.txt` files in the root of the ImageNet directory (`/data/ImageNet`).
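
Before starting a long training run, the placement of the two list files can be checked up front. This is a minimal sketch under the assumption that the dataset root is `/data/ImageNet` as above; `check_imagenet_lists` is a hypothetical helper, not a script from this repository, and it only checks that the files exist, not their contents.

```shell
# Sketch only: check_imagenet_lists is a hypothetical helper.
# usage: check_imagenet_lists [imagenet-root]   (defaults to /data/ImageNet)
check_imagenet_lists() {
  lists_root="${1:-/data/ImageNet}"
  lists_missing=0
  for list_file in imagenet-OC-train.txt imagenet-OH-val.txt; do
    if [ -f "$lists_root/$list_file" ]; then
      echo "found   $lists_root/$list_file"
    else
      # Report missing files on stderr and remember the failure.
      echo "missing $lists_root/$list_file" >&2
      lists_missing=1
    fi
  done
  return "$lists_missing"
}
```

For example, `check_imagenet_lists /data/ImageNet && echo "data layout looks fine"` prints the confirmation only when both files are present.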

## Training
- `--texts`: file of text descriptions pre-extracted with Llama-3.2-11B-Vision (https://huggingface.co/meta-llama/Llama-3.2-11B-Vision)

#### Text-Attr (H-CAST)
```
export PYTHONPATH=deit/:$PYTHONPATH

torchrun --nproc_per_node=4 deit/main_suppix_partial_cap.py \
   --model cast_small \
   --batch-size 256 \
   --epochs 200 \
   --num-superpixels 196 --num_workers 12 \
   --globalkl --gk_weight 0.5 \
   --data-set IMNET-H-SUPERPIXEL-CAP \
   --data-path /data/ImageNet \
   --output_dir ./output/text_hcast \
   --texts 'imagenethier2_modified.txt' --sim_loss_weight 1 \
   --distributed 
```

#### Text-Attr (H-ViT)
```
export PYTHONPATH=deit/:$PYTHONPATH

python deit/main_hier_partial.py \
  --model deit_small_patch16_224 \
  --batch-size 256 \
  --epochs 200 \
  --num_workers 8 \
  --data-set IMNET-H \
  --data-path /data/ImageNet \
  --output_dir ./output/text_hvit \
  --texts 'imagenethier2_modified.txt' --sim_loss_weight 1 
```


#### Taxon-SSL 
```
export PYTHONPATH=deit/:$PYTHONPATH

python deit/main_taxon_ssl.py \
  --model deit_small_patch16_224 \
  --batch-size 256 \
  --epochs 200 \
  --num_workers 8 \
  --lr 0.001 \
  --momentum 0.9 \
  --weight-decay 0.0005 \
  --data-set IMNET-H \
  --data-path /data/ImageNet \
  --output_dir ./output/taxon_ssl 
```

#### Taxon-SSL + Text-Attr
```
export PYTHONPATH=deit/:$PYTHONPATH

python deit/main_taxon_ssl.py \
  --model deit_small_patch16_224 \
  --batch-size 256 \
  --epochs 200 \
  --num_workers 8 \
  --lr 0.001 \
  --momentum 0.9 \
  --weight-decay 0.0005 \
  --data-set IMNET-H \
  --data-path /data/ImageNet \
  --output_dir ./output/taxon_ssl_text \
  --texts 'imagenethier2_modified.txt' --text_loss_weight 1.0
```
