This is the PyTorch implementation of our paper: Contrastive Conditional Transport for Representation Learning

1. Requirements:
pytorch==1.6.0;
tensorboard==2.3.0;
detectron2==0.3;
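To confirm that an environment matches these pins, a small standard-library check like the one below can help. It is a sketch under a few assumptions. PyTorch is published on PyPI as `torch`. `tensorboard` and `detectron2` are assumed to be installed under those distribution names. `importlib.metadata` requires Python 3.8 or later. The helper `check_pins` is hypothetical and not part of this repository.

```python
from importlib.metadata import PackageNotFoundError, version

# Pins from the Requirements list above; "pytorch" is published on PyPI as "torch".
PINS = {"torch": "1.6.0", "tensorboard": "2.3.0", "detectron2": "0.3"}


def check_pins(pins, lookup=version):
    """Return {package: problem} for every pin that is missing or mismatched.

    `lookup` maps a distribution name to its installed version string; it
    defaults to importlib.metadata.version and can be replaced for testing.
    """
    problems = {}
    for name, wanted in pins.items():
        try:
            installed = lookup(name)
        except PackageNotFoundError:
            problems[name] = "not installed"
            continue
        # Drop local build tags such as "+cu101" before comparing.
        base = installed.split("+", 1)[0]
        if base != wanted:
            problems[name] = f"found {installed}, expected {wanted}"
    return problems


if __name__ == "__main__":
    issues = check_pins(PINS)
    for name, problem in issues.items():
        print(f"{name}: {problem}")
    if not issues:
        print("All pinned packages match.")
```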

2. Notes
(1) The code for CIFAR-10, CIFAR-100, STL-10, and ImageNet is in the "examples" folder.
(2) The CACR loss used for the small-scale datasets is in the "losses" folder.
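The implementation in "losses" is the reference. The block below is only a framework-free sketch of the per-anchor idea of an attraction term over positives and a repulsion term over negatives. It rests on our own assumptions. Embeddings are L2-normalized and the cost is squared Euclidean distance. The conditional weights are softmaxes over plus or minus a temperature times that distance. `tau_pos`, `tau_neg`, `alpha`, and `beta` correspond to the command-line flags of the same names. All function names are hypothetical. A real implementation would be a batched PyTorch operation with gradients.

```python
import math


def normalize(v):
    """Scale a vector to unit L2 norm."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def sq_dist(a, b):
    """Squared Euclidean distance between two equal-length vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def cacr_anchor_loss(anchor, positives, negatives,
                     tau_pos=1.0, tau_neg=2.0, alpha=1.0, beta=1.0):
    """Attraction/repulsion loss for one anchor (illustrative sketch, hypothetical).

    Positives farther from the anchor get larger weights (exp(+tau_pos * d)), so
    attraction focuses on hard positives; negatives closer to the anchor get
    larger weights (exp(-tau_neg * d)), so repulsion focuses on hard negatives.
    """
    a = normalize(anchor)
    d_pos = [sq_dist(a, normalize(p)) for p in positives]
    d_neg = [sq_dist(a, normalize(n)) for n in negatives]
    w_pos = softmax([tau_pos * d for d in d_pos])
    w_neg = softmax([-tau_neg * d for d in d_neg])
    attraction = sum(w * d for w, d in zip(w_pos, d_pos))
    repulsion = sum(w * d for w, d in zip(w_neg, d_neg))
    # Minimizing pulls weighted positives in and pushes weighted negatives away.
    return alpha * attraction - beta * repulsion
```

With these weights, a distant positive or a nearby negative dominates its term, which is the intended "hard sample" emphasis. Uniform averaging over positives and negatives is recovered when both temperatures are zero.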

3. Example Usage
(1) Run the code on CIFAR-10 (./examples/cifar10/)
Training: python main_ours.py --Ny 4 --Ns 128 --alpha 1.0 --beta 1.0 --tau_pos 1.0 --tau_neg 0.9 --gpus [gpu_num]
Testing: python linear_eval.py --gpus [gpu_num] --encoder_checkpoint [checkpoint_pth]

(2) Run the code on CIFAR-100 (./examples/cifar100/)
Training: python main_ours.py --Ny 4 --Ns 128 --alpha 1.0 --beta 1.0 --tau_pos 1.0 --tau_neg 2.0 --gpus [gpu_num]
Testing: python linear_eval.py --gpus [gpu_num] --encoder_checkpoint [checkpoint_pth]

(3) Run the code on STL-10 (./examples/stl10/)
Training: python main_ours.py --Ny 4 --Ns 128 --alpha 1.0 --beta 1.0 --tau_pos 1.0 --tau_neg 2.0 --gpus [gpu_num]
Testing: python linear_eval.py --gpus [gpu_num] --encoder_checkpoint [checkpoint_pth]
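The three small-scale training commands differ only in their working directory and in `--tau_neg` (0.9 for CIFAR-10, 2.0 for CIFAR-100 and STL-10). A hypothetical helper that assembles them from one table makes this explicit. It is a dry-run sketch that only builds and prints the commands. It assumes it is run from the repository root.

```python
import shlex

# Flags shared by all small-scale training runs, copied from the commands above.
COMMON = {"--Ny": "4", "--Ns": "128", "--alpha": "1.0", "--beta": "1.0", "--tau_pos": "1.0"}

# Per-dataset working directory and negative temperature.
DATASETS = {
    "cifar10": ("./examples/cifar10/", "0.9"),
    "cifar100": ("./examples/cifar100/", "2.0"),
    "stl10": ("./examples/stl10/", "2.0"),
}


def train_command(dataset, gpu):
    """Return (working_dir, argv) for the training command of `dataset`."""
    workdir, tau_neg = DATASETS[dataset]
    argv = ["python", "main_ours.py"]
    for flag, value in COMMON.items():
        argv += [flag, value]
    argv += ["--tau_neg", tau_neg, "--gpus", str(gpu)]
    return workdir, argv


if __name__ == "__main__":
    for name in DATASETS:
        workdir, argv = train_command(name, gpu=0)
        # Dry run: print instead of subprocess.run(argv, cwd=workdir, check=True).
        print(f"(cd {workdir} && {shlex.join(argv)})")
```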

(4) Run the code on ImageNet-100 and ImageNet (./examples/imagenet/)
(4.1) Prepare the dataset for ImageNet-100: python scripts/create_imagenet_subset data/imagenet data/imagenet100
(4.2) On ImageNet-100:
Training: python main_ours.py \
                 -a resnet50 \
                 --lr 0.03 \
                 --Ny 4 \
                 --Ns 128 \
                 --gpus 0 1 2 3 4 5 6 7 \
                 --moco-k 65536 \
                 --tau_pos 1.0 \
                 --tau_neg 2.0 \
                 --alpha 1.0 --beta 1.0 \
                 --mlp --aug-plus --cos \
                 --dist-url tcp://localhost:10001 \
                 --multiprocessing-distributed \
                 --world-size 1 \
                 --rank 0 \
                 data/imagenet100
Testing: python main_lincls.py \
                 -a resnet50 \
                 --lr 30.0 \
                 --batch-size 256 \
                 --gpus 0 1 2 3 4 5 6 7 \
                 --pretrained [encoder_pth] \
                 --dist-url tcp://localhost:10001 \
                 --multiprocessing-distributed \
                 --world-size 1 \
                 --rank 0 \
                 data/imagenet100
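The `--cos` flag in these commands appears to come from MoCo's `main_moco.py`. There, it replaces the step schedule with a half-cycle cosine decay from `--lr` towards zero over training. Assuming the same behaviour here, the per-epoch learning rate can be sketched as follows. The function name is hypothetical. The defaults of 200 epochs and milestones at 120 and 160 are MoCo's defaults, not values confirmed for this repository.

```python
import math


def learning_rate(epoch, base_lr=0.03, epochs=200, cos=True, milestones=(120, 160)):
    """Learning rate at the start of `epoch` (0-indexed), mirroring MoCo's schedule.

    cos=True:  half-cycle cosine decay from base_lr towards 0.
    cos=False: multiply by 0.1 at every milestone epoch already reached.
    """
    if cos:
        return base_lr * 0.5 * (1.0 + math.cos(math.pi * epoch / epochs))
    lr = base_lr
    for milestone in milestones:
        if epoch >= milestone:
            lr *= 0.1
    return lr


if __name__ == "__main__":
    for epoch in (0, 50, 100, 150, 199):
        print(epoch, round(learning_rate(epoch), 5))
```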
(4.3) On ImageNet-1K / WebVision v1 / ImageNet-22K:
Training: python main_ours.py \
                 -a resnet50 \
                 --lr 0.03 \
                 --Ny 4 \
                 --Ns 256 \
                 --gpus 0 1 2 3 4 5 6 7 \
                 --moco-k 65536 \
                 --tau_pos 1.0 \
                 --tau_neg 2.0 \
                 --alpha 1.0 --beta 1.0 \
                 --mlp --aug-plus --cos \
                 --dist-url tcp://localhost:10001 \
                 --multiprocessing-distributed \
                 --world-size 1 \
                 --rank 0 \
                 data/imagenet1k   # or the webvision / imagenet22k data folder
Testing: python main_lincls.py \
                 -a resnet50 \
                 --lr 30.0 \
                 --batch-size 256 \
                 --gpus 0 1 2 3 4 5 6 7 \
                 --pretrained [encoder_pth] \
                 --dist-url tcp://localhost:10001 \
                 --multiprocessing-distributed \
                 --world-size 1 \
                 --rank 0 \
                 data/imagenet1k
(The WebVision v1 code is provided in the subfolder "webvision". For ImageNet-22K, only the data path needs to be changed to the corresponding data folder.)
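`--moco-k 65536` suggests that the ImageNet code keeps a MoCo-style queue of K negative keys. In MoCo, each training step enqueues the newest batch of keys and overwrites the oldest ones through a circular pointer. This requires K to be divisible by the (global) batch size. The class below is a framework-free sketch of that bookkeeping, assuming this repository keeps MoCo's design. `KeyQueue` is a hypothetical name. The real code stores the queue as a tensor buffer.

```python
class KeyQueue:
    """Fixed-size FIFO of key vectors updated through a circular pointer (MoCo-style)."""

    def __init__(self, size, dim):
        self.size = size
        # MoCo starts from random normalized keys; zeros keep this sketch simple.
        self.keys = [[0.0] * dim for _ in range(size)]
        self.ptr = 0

    def enqueue(self, batch):
        """Overwrite the oldest len(batch) slots with the newest keys."""
        n = len(batch)
        # As in MoCo, the queue size must be a multiple of the (global) batch size,
        # so a batch never wraps around the end of the buffer.
        assert self.size % n == 0, "queue size must be divisible by the batch size"
        self.keys[self.ptr:self.ptr + n] = [list(k) for k in batch]
        self.ptr = (self.ptr + n) % self.size
```

With 65,536 keys and, for example, a global batch of 256, every key in the queue is replaced once every 256 iterations.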

(5) Detection and Segmentation: After training the encoder on ImageNet, you can follow the detection and segmentation instructions in MoCo to run these downstream tasks.
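MoCo's detection instructions first convert the pretrained checkpoint into a detectron2-readable pickle with `detection/convert-pretrain-to-detectron2.py`. That script keeps only the query-encoder weights and renames ResNet keys to detectron2's scheme. The key renaming can be sketched in plain Python, assuming the checkpoints saved here use MoCo's `module.encoder_q.` prefix (not verified against this repository). `to_detectron2_key` and `convert_state_dict` are hypothetical helpers. The actual script also converts each tensor to a NumPy array and pickles the result under a `"model"` key.

```python
PREFIX = "module.encoder_q."


def to_detectron2_key(key):
    """Map a MoCo-style checkpoint key to detectron2's ResNet naming, or None to drop it."""
    if not key.startswith(PREFIX):
        return None  # e.g. momentum-encoder ("encoder_k") weights and the queue buffer
    k = key[len(PREFIX):]
    if "layer" not in k:
        k = "stem." + k  # ResNet stem conv1/bn1 (the fc head is prefixed too and ignored later)
    for t in (1, 2, 3, 4):
        k = k.replace(f"layer{t}", f"res{t + 1}")
    for t in (1, 2, 3):
        k = k.replace(f"bn{t}", f"conv{t}.norm")
    k = k.replace("downsample.0", "shortcut")
    k = k.replace("downsample.1", "shortcut.norm")
    return k


def convert_state_dict(state_dict):
    """Keep only query-encoder entries, renamed for detectron2."""
    converted = {}
    for key, value in state_dict.items():
        new_key = to_detectron2_key(key)
        if new_key is not None:
            converted[new_key] = value
    return converted
```

In a real conversion, `state_dict` would come from `torch.load(path, map_location="cpu")["state_dict"]`.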


4. Acknowledgements
This code benefits greatly from previous works, and we would like to thank their authors here.
Furthermore, we will release the pretrained encoder on ImageNet later.



