import torch
from args import parse_arguments
from finetune_functions import finetune_ft, finetune_lpips, finetune_flyp, finetune_flyp_ce
from utils import epochs


def finetune(args, finetuned_encoder=None):
    """Dispatch fine-tuning to the routine matching ``args.finetune_loss``.

    Args:
        args: Parsed command-line arguments. Only ``args.finetune_loss`` is
            read here; the whole namespace is forwarded to the selected
            fine-tuning routine.
        finetuned_encoder: Optional previously fine-tuned encoder to start
            from. It is forwarded as ``finetuned_image_encoder`` for the
            'ce', 'ls' and 'l2sp' losses, and as ``finetuned_clip_encoder``
            for 'flyp' and 'flyp_ce'. It is not passed to the 'lpips' routine.

    Raises:
        ValueError: If ``args.finetune_loss`` is not a supported loss name.
    """
    loss = args.finetune_loss
    if loss in ('ce', 'ls', 'l2sp'):
        finetune_ft(args, finetuned_image_encoder=finetuned_encoder)
    elif loss == 'lpips':
        # NOTE(review): finetuned_encoder is dropped for this loss. Confirm
        # whether finetune_lpips should start from the loaded checkpoint.
        finetune_lpips(args)
    elif loss == 'flyp':
        finetune_flyp(args, finetuned_clip_encoder=finetuned_encoder)
    elif loss == 'flyp_ce':
        finetune_flyp_ce(args, finetuned_clip_encoder=finetuned_encoder)
    else:
        # Fail loudly rather than silently skipping fine-tuning on a typo.
        raise ValueError(
            f"Unsupported finetune_loss {loss!r}; expected one of "
            "'ce', 'ls', 'l2sp', 'lpips', 'flyp', 'flyp_ce'."
        )



if __name__ == '__main__':

    args = parse_arguments()
    print('=' * 100)
    print(f'Finetuning {args.model} on {args.train_dataset}')
    print('=' * 100)
    # Fixed hyper-parameters for this script. They override whatever was
    # supplied on the command line.
    args.lr = 1e-5
    # The epoch count is looked up by the base dataset name, *before* the
    # 'Val' suffix is appended below.
    args.epochs = epochs[args.train_dataset]
    args.train_dataset = args.train_dataset + 'Val'
    args.batch_size = 128

    # Bind up front so the variable always exists when finetune() is called.
    image_encoder = None
    if args.model_checkpoint_path is not None:
        # NOTE(review): torch.load unpickles arbitrary objects, so only load
        # trusted checkpoints. The checkpoint appears to be a pickled model
        # object (attribute access below), not a state_dict. On PyTorch >= 2.6
        # the default weights_only=True may reject it; confirm the torch
        # version and pass weights_only=False if needed.
        finetuned_checkpoint = torch.load(args.model_checkpoint_path)
        if args.finetune_loss in ('ce', 'ls', 'lpips', 'l2sp'):
            # These losses start from the image encoder stored in the checkpoint.
            image_encoder = finetuned_checkpoint.image_encoder
        elif args.finetune_loss in ('flyp', 'flyp_ce'):
            # The FLYP variants take the whole checkpointed model.
            image_encoder = finetuned_checkpoint
        else:
            raise ValueError(
                f"Cannot use checkpoint with unsupported finetune_loss "
                f"{args.finetune_loss!r}"
            )
        print('Checkpoint loaded!')
    else:
        print('No finetuned checkpoint')

    print(args)
    finetune(args, finetuned_encoder=image_encoder)