# Transformed CNNs: Converting pre-trained ResNets to functionally equivalent hybrid models

This repository contains PyTorch code for Transformed CNNs. It is adapted from the repository of the Convolutional Vision Transformer (ConViT), available at https://github.com/facebookresearch/convit.

# Usage
Install PyTorch 1.7.0+, torchvision 0.8.1+, and [pytorch-image-models 0.3.2](https://github.com/rwightman/pytorch-image-models):
```
conda install -c pytorch pytorch torchvision
pip install timm==0.3.2
```
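
After installing, it can help to confirm that the environment meets the versions listed above. The following is a minimal sketch, not part of this repository: the helper names `parse_version` and `check_environment` are hypothetical, and it assumes the packages are registered under the distribution names `torch`, `torchvision`, and `timm`. It treats all three versions as minimums, although the install command pins `timm` to exactly 0.3.2.

```python
from importlib import metadata

# Minimum versions taken from the install instructions above.
# Assumption: these are the distribution names pip/conda register.
REQUIRED = {"torch": "1.7.0", "torchvision": "0.8.1", "timm": "0.3.2"}


def parse_version(text):
    """Turn a string like '1.7.0+cu110' into (1, 7, 0).

    Local/build suffixes after '+' are ignored, and any non-digit tail
    in a component (e.g. '0rc1') is dropped. Hypothetical helper.
    """
    core = text.split("+")[0]
    parts = []
    for piece in core.split(".")[:3]:
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        parts.append(int(digits) if digits else 0)
    # Pad so that '1.7' compares equal to '1.7.0'.
    while len(parts) < 3:
        parts.append(0)
    return tuple(parts)


def check_environment(required=REQUIRED):
    """Return {package: (installed_version_or_None, meets_minimum)}."""
    report = {}
    for name, minimum in required.items():
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            report[name] = (None, False)
            continue
        ok = parse_version(installed) >= parse_version(minimum)
        report[name] = (installed, ok)
    return report


if __name__ == "__main__":
    for name, (installed, ok) in check_environment().items():
        status = "ok" if ok else "MISSING or too old"
        print(f"{name}: {installed} ({status})")
```

Reading the installed version from package metadata avoids importing the heavy libraries just to check that they are present.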

## Training
To check that the reparametrization is exact, use
```
python mapping.py
```

To reparametrize then fine-tune a pre-trained ResNet50 on ImageNet, use
```
python main.py --model resnet50 --transform_model 1 --pretrained 1 --epochs 50 --data-path /path/to/imagenet
```

To train a ResNet50 from scratch for 100 epochs, then reparametrize it and resume training for another 100 epochs, use
```
python main.py --model resnet50 --transform_at 100 --pretrained 0 --epochs 200 --data-path /path/to/imagenet
```

## Data preparation
Download and extract ImageNet train and val images from http://image-net.org/.
The directory structure follows the standard layout for torchvision's [`datasets.ImageFolder`](https://pytorch.org/docs/stable/torchvision/datasets.html#imagefolder), with the training and validation data expected in the `train/` and `val/` folders respectively:
```
/path/to/imagenet/
  train/
    class1/
      img1.jpeg
    class2/
      img2.jpeg
  val/
    class1/
      img3.jpeg
    class2/
      img4.jpeg
```
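
Before starting a long training run, the layout above can be checked with a short script. This is a hedged sketch using only the Python standard library. The function names `class_dirs` and `check_imagefolder_layout` are hypothetical and not part of this repository, and the script only inspects folder structure (it does not open or validate image files).

```python
from pathlib import Path


def class_dirs(split_dir):
    """Return the sorted names of the class subfolders inside one split."""
    return sorted(p.name for p in Path(split_dir).iterdir() if p.is_dir())


def check_imagefolder_layout(root):
    """Return a list of human-readable problems; an empty list means
    the train/ and val/ folders look consistent with the layout above."""
    root = Path(root)
    problems = []
    classes = {}
    for split in ("train", "val"):
        split_dir = root / split
        if not split_dir.is_dir():
            problems.append(f"missing folder: {split_dir}")
            continue
        names = class_dirs(split_dir)
        if not names:
            problems.append(f"no class subfolders in {split_dir}")
        for name in names:
            # Each class folder should directly contain at least one file.
            if not any(f.is_file() for f in (split_dir / name).iterdir()):
                problems.append(f"no files in class folder: {split_dir / name}")
        classes[split] = set(names)
    # Both splits should expose the same set of class names.
    if len(classes) == 2 and classes["train"] != classes["val"]:
        diff = sorted(classes["train"] ^ classes["val"])
        problems.append(f"class folders differ between train and val: {diff}")
    return problems


if __name__ == "__main__":
    import sys

    # Usage: python check_layout.py /path/to/imagenet
    issues = check_imagefolder_layout(sys.argv[1])
    print("\n".join(issues) if issues else "layout looks fine")
```

Comparing the class folder names across the two splits catches a common mistake, such as a misnamed or nested class directory in `val/`, which would otherwise shift the label indices assigned to each class.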