# Example usage: python src/fine_tune_chosen_layer.py --model-arch b32 --checkpoint-path weights/pytorch/ --pretrain_model_name imagenet21k+imagenet2012_ViT-B_32.pth --image-size 384 --batch-size 32 --dataset CIFAR10 --num-classes 10 --int_layers 0 1
import sys
import pickle
import torchvision
import torchvision.transforms as transforms
import torch
import random
import argparse
import os
import numpy as np
import matplotlib.pyplot as plt
from torchvision.datasets import CIFAR10, MNIST, CIFAR100, FGVCAircraft
from model import VisionTransformer
from config import get_ft_config
from checkpoint import load_checkpoint
import torch.nn as nn
import torch.optim as optim
    
def train_epoch(config, FModel, dataloader, optimizer, scheduler, criterion, cur_epoch, log_interval=100):
    """Train ``FModel`` for one pass over ``dataloader``.

    Args:
        config: run configuration; only ``config.device`` is read here.
        FModel: the model being optimised.
        dataloader: iterable yielding ``(inputs, targets)`` batches.
        optimizer: optimizer stepped once per batch.
        scheduler: learning-rate scheduler, also stepped once per batch.
        criterion: loss callable invoked as ``criterion(outputs, targets)``.
        cur_epoch: epoch index, used only in the progress message.
        log_interval: print progress every ``log_interval`` batches
            (must be a positive integer).

    Returns:
        Tuple ``(top1_avg, top5_avg)`` of sample-weighted accuracies in percent.
    """
    set_to_train(FModel)

    top1 = AverageMeter()
    top5 = AverageMeter()
    running_loss = 0.0
    batches_since_log = 0
    for batch_idx, (inputs, targets) in enumerate(dataloader):
        inputs, targets = inputs.to(config.device), targets.to(config.device)

        outputs = FModel(inputs)
        loss = criterion(outputs, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()

        # Accuracy bookkeeping must not take part in autograd.
        prec1, prec5 = compute_accuracy(outputs.detach(), targets.detach(), topk=(1, 5))
        top1.update(prec1.item(), inputs.size(0))
        top5.update(prec5.item(), inputs.size(0))

        # Average over the batches actually accumulated since the last
        # message. Dividing by a fixed 100 under-reports the very first
        # message (one batch only), and dividing again by the batch size is
        # wrong for a criterion that already returns a per-batch mean
        # (this file's main() passes nn.CrossEntropyLoss() with its default
        # reduction).
        running_loss += loss.item()
        batches_since_log += 1
        if batch_idx % log_interval == 0:
            print(f'[Epoch {cur_epoch}, {batch_idx + 1:5d}] loss: {running_loss / batches_since_log:.3f}, acc: {top1.avg}')
            running_loss = 0.0
            batches_since_log = 0

    return top1.avg, top5.avg

def evaluate_epoch(config, FModel, dataloader, criterion):
    """Run one gradient-free evaluation pass over ``dataloader``.

    The model is switched to eval mode, every batch is moved to
    ``config.device`` and scored with top-1 / top-5 accuracy.

    Returns:
        Tuple ``(top1_avg, top5_avg)`` of sample-weighted accuracies in percent.
    """
    FModel.eval()

    meter_top1, meter_top5 = AverageMeter(), AverageMeter()
    with torch.no_grad():
        for images, labels in dataloader:
            images = images.to(config.device)
            labels = labels.to(config.device)

            logits = FModel(images)

            # The criterion is still evaluated on every batch, but its value
            # is not part of what this function returns.
            criterion(logits, labels)

            acc1, acc5 = compute_accuracy(logits.data, labels.data, topk=(1, 5))
            n_samples = images.size(0)
            meter_top1.update(acc1.item(), n_samples)
            meter_top5.update(acc5.item(), n_samples)

    return meter_top1.avg, meter_top5.avg

def set_to_train(model):
    """Switch ``model`` into training mode."""
    model.train(True)
    
def compute_accuracy(output, target, topk=(1,)):
    """Compute the precision@k for the specified values of k.

    Args:
        output: 2-D score tensor indexed as ``[sample, class]``.
        target: 1-D tensor of class indices, one per sample.
        topk: iterable of k values to evaluate.

    Returns:
        List of 0-dim tensors, one per entry of ``topk``, each holding the
        percentage (0-100) of samples whose target is among the k highest
        scores. When k exceeds the number of classes every class is within
        the top k, so the result is computed over all classes instead of
        failing.
    """
    batch_size = target.size(0)
    # Tensor.topk raises when asked for more entries than the dimension
    # holds, so never request more than the number of classes.
    maxk = min(max(topk), output.size(1))

    _, pred = output.topk(maxk, dim=1, largest=True, sorted=True)
    pred = pred.t()  # row r now holds every sample's (r + 1)-th best class
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # Slicing past the last row is harmless: it simply keeps all rows.
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res

class AverageMeter(object):
    """Track the latest value and the running weighted mean of a metric.

    Adapted from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all statistics back to zero."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and refresh the mean."""
        self.val = val
        self.count = self.count + n
        self.sum = self.sum + val * n
        self.avg = self.sum / self.count
        
def set_seed(manualSeed=666):
    """Seed Python, NumPy and PyTorch (CPU and CUDA) with ``manualSeed``.

    Also sets cuDNN to deterministic mode with benchmarking disabled and
    writes the seed to the ``PYTHONHASHSEED`` environment variable.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(manualSeed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

    os.environ['PYTHONHASHSEED'] = str(manualSeed)
    

def _build_dataloaders(config):
    """Create the train and test DataLoaders for ``config.dataset``.

    Supported names are "aircraft" (FGVCAircraft, "trainval" / "test"
    splits), "CIFAR10" and "CIFAR100". Train and test data share the same
    resize -> tensor -> normalise pipeline.

    Raises:
        ValueError: if ``config.dataset`` is not one of the supported names.
    """
    if config.dataset == "aircraft":
        # This branch passes an explicit [size, size] pair to Resize.
        resize = transforms.Resize([config.image_size, config.image_size])
        dataset_cls = torchvision.datasets.FGVCAircraft
        train_kwargs = {"split": "trainval"}
        test_kwargs = {"split": "test"}
    elif config.dataset in ("CIFAR10", "CIFAR100"):
        resize = transforms.Resize(config.image_size)
        dataset_cls = getattr(torchvision.datasets, config.dataset)
        train_kwargs = {"train": True}
        test_kwargs = {"train": False}
    else:
        raise ValueError("That dataset is not yet implemented!")

    transform = transforms.Compose([
        resize,
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    trainset = dataset_cls(
        root=config.data_dir, download=True, transform=transform, **train_kwargs)
    trainloader = torch.utils.data.DataLoader(
        trainset, batch_size=config.batch_size, shuffle=True, num_workers=1)

    testset = dataset_cls(
        root=config.data_dir, download=True, transform=transform, **test_kwargs)
    testloader = torch.utils.data.DataLoader(
        testset, batch_size=config.batch_size, shuffle=False, num_workers=1)

    return trainloader, testloader


def _select_trainable_params(model, int_layers):
    """Freeze ``model`` and re-enable gradients only for what will be trained.

    Args:
        model: network exposing a ``classifier`` submodule and, for
            per-layer fine-tuning, submodules named
            ``transformer.encoder_layers.<i>``.
        int_layers: empty for classifier-only training, ``["all"]`` for the
            whole model, otherwise encoder-layer indices (ints or numeric
            strings).

    Returns:
        Tuple ``(need_param_list, trainable_params)``: the parameters chosen
        through ``int_layers``, and the duplicate-free list (chosen
        parameters plus classifier) to hand to the optimizer.

    Raises:
        ValueError: if a requested encoder layer does not exist in ``model``.
    """
    # Assigning ``module.requires_grad = False`` only creates a plain Python
    # attribute; ``requires_grad_`` is what actually updates the parameters.
    model.requires_grad_(False)

    need_param_list = []  # params that will be trained besides the classifier
    if not int_layers:
        print("Transfer Learning")
    elif len(int_layers) == 1 and int_layers[0] == "all":
        # Checked before any int() conversion so that "all" is reachable.
        print("Will fine tune the whole model")
        need_param_list = list(model.parameters())
    else:
        modules = dict(model.named_modules())
        for layer in int_layers:
            key = f"transformer.encoder_layers.{int(layer)}"
            if key not in modules:
                raise ValueError(f"Model has no module named {key!r}")
            print(key)
            need_param_list.extend(modules[key].parameters())

    # The classifier is always trained. De-duplicate by identity because in
    # the "all" case (or with repeated indices) a parameter would otherwise
    # be passed to the optimizer more than once.
    trainable_params = []
    seen = set()
    for param in need_param_list + list(model.classifier.parameters()):
        if id(param) in seen:
            continue
        seen.add(id(param))
        param.requires_grad_(True)
        trainable_params.append(param)

    return need_param_list, trainable_params


def main():
    """Fine-tune selected encoder layers (plus the classifier) of a ViT.

    Reads the run configuration from ``get_ft_config()``, builds the data
    loaders and the model, loads the pre-trained checkpoint, trains for
    ``config.train_steps // len(trainloader)`` epochs and saves the weights
    with the best test top-1 accuracy as ``model_best.pth``.
    """
    config = get_ft_config()

    if config.checkpoint_path is None:
        sys.exit('Need to input the path to a pre-trained model!')

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    config.device = device

    set_seed(config.seed)

    # Dataset part
    trainloader, testloader = _build_dataloaders(config)

    # create model
    print("create model")
    model = VisionTransformer(
             image_size=(config.image_size, config.image_size),
             patch_size=(config.patch_size, config.patch_size),
             emb_dim=config.emb_dim,
             mlp_dim=config.mlp_dim,
             num_heads=config.num_heads,
             num_layers=config.num_layers,
             num_classes=config.num_classes,
             attn_dropout_rate=config.attn_dropout_rate,
             dropout_rate=config.dropout_rate)

    # load checkpoint
    if config.checkpoint_path:
        # os.path.join works whether or not checkpoint_path ends with a separator.
        state_dict = load_checkpoint(
            os.path.join(config.checkpoint_path, config.pretrain_model_name))
        if config.num_classes != state_dict['classifier.weight'].size(0):
            # Head size differs from the checkpoint: drop it so the freshly
            # initialised classifier is kept.
            del state_dict['classifier.weight']
            del state_dict['classifier.bias']
            print("re-initialize fc layer")
            model.load_state_dict(state_dict, strict=False)
        else:
            model.load_state_dict(state_dict)
        print("Load pretrained weights from {}".format(config.checkpoint_path))

    # send model to device
    model = model.to(device)
    model.eval()

    # Now we need to set model layers of interest to require grad
    int_layers = list(config.int_layers or [])
    print(len(int_layers))
    print(f"Fine tune for layers {int_layers}")

    need_param_list, trainable_params = _select_trainable_params(model, int_layers)

    criterion = nn.CrossEntropyLoss()
    print(f"Trainable layers: model_train: {len(need_param_list)}, model_ft: {len(list(model.classifier.parameters()))}")
    # create optimizers and learning rate scheduler
    optimizer = torch.optim.SGD(
        params=trainable_params,
        lr=config.lr,
        weight_decay=config.wd,
        momentum=0.9)
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer=optimizer,
        max_lr=config.lr,
        pct_start=config.warmup_steps / config.train_steps,
        total_steps=config.train_steps)

    epochs = config.train_steps // len(trainloader)
    best_test_acc = 0.0
    train_accs = []
    test_accs = []
    save_name = "-".join([str(i) for i in int_layers])  # the folder to save fine tune results
    checkpoint_dir = os.path.join(
        config.checkpoint_path, f"{config.dataset}_ft", f"fine_tune_{save_name}")
    os.makedirs(checkpoint_dir, exist_ok=True)

    for epoch in range(epochs):
        train_top1, train_top5 = train_epoch(config, model, trainloader, optimizer, scheduler, criterion, epoch)
        test_top1, test_top5 = evaluate_epoch(config, model, testloader, criterion)
        train_accs.append(train_top1)
        test_accs.append(test_top1)

        is_best = test_top1 > best_test_acc
        if is_best:
            best_test_acc = test_top1
        print(f"Finish Epoch {epoch}, Training Acc: {train_top1}, Test Acc: {test_top1}, Current best Test Acc: {best_test_acc}, current LR: {scheduler.get_last_lr()}")

        if is_best:
            # Only assemble the checkpoint when it is actually written.
            print("Save current model (best)")
            state = {
                'arch': config.model_arch,
                'epoch': epoch,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'test_acc': test_top1
            }
            torch.save(state, os.path.join(checkpoint_dir, 'model_best.pth'))

    print(f"Training Accs are: {train_accs}")
    print(f"Test Accs are: {test_accs}")

# Run the fine-tuning pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()