import torch
from args import parse_arguments
from finetune_functions import finetune_ft, finetune_lpips, finetune_flyp, finetune_flyp_ce
from utils import epochs


def finetune(args, finetuned_encoder=None):
    """Run one finetuning stage selected by ``args.finetune_loss`` and return its model.

    Dispatch table:

    * ``'ce'``, ``'ls'``, ``'l2sp'`` -> ``finetune_ft``
    * ``'lpips'``                    -> ``finetune_lpips``
    * ``'flyp'``                     -> ``finetune_flyp``
    * ``'flyp_ce'``                  -> ``finetune_flyp_ce``

    Args:
        args: Parsed arguments. Only ``args.finetune_loss`` is read here; the
            whole namespace is forwarded to the selected routine.
        finetuned_encoder: Optional encoder to start from. It is forwarded as
            ``finetuned_image_encoder`` to the classifier routines
            (``finetune_ft`` / ``finetune_lpips``) and as
            ``finetuned_clip_encoder`` to the FLYP routines.

    Returns:
        Whatever the selected routine returns. Callers in this file expect the
        classifier routines' result to expose ``.image_encoder`` and use the
        FLYP routines' result directly as an encoder.

    Raises:
        ValueError: If ``args.finetune_loss`` is not one of the losses above.
    """
    loss = args.finetune_loss
    # Each branch returns the trained model so callers can chain stages
    # (the second stage starts from the first stage's result).
    if loss in ('ce', 'ls', 'l2sp'):
        return finetune_ft(args, finetuned_image_encoder=finetuned_encoder)
    if loss == 'lpips':
        return finetune_lpips(args, finetuned_image_encoder=finetuned_encoder)
    if loss == 'flyp':
        return finetune_flyp(args, finetuned_clip_encoder=finetuned_encoder)
    if loss == 'flyp_ce':
        return finetune_flyp_ce(args, finetuned_clip_encoder=finetuned_encoder)
    # Fail loudly instead of silently returning None for a typo'd loss name.
    raise ValueError(f'Unsupported finetune_loss: {loss!r}')



def _encoder_for_loss(model, finetune_loss):
    """Return the part of ``model`` to pass to ``finetune`` as ``finetuned_encoder``.

    Classifier-style losses ('ce', 'ls', 'l2sp', 'lpips') read the encoder
    from ``model.image_encoder``. The FLYP losses ('flyp', 'flyp_ce') use
    ``model`` itself as the encoder.

    Raises:
        ValueError: If ``finetune_loss`` is not one of the losses above.
    """
    if finetune_loss in ('ce', 'ls', 'l2sp', 'lpips'):
        return model.image_encoder
    if finetune_loss in ('flyp', 'flyp_ce'):
        return model
    raise ValueError(f'Unsupported finetune_loss: {finetune_loss!r}')


if __name__ == '__main__':
    # Two sequential finetuning stages: first on the 'Val' variant of
    # first_train_dataset, then continuing from that model on the 'Val'
    # variant of second_train_dataset.

    args = parse_arguments()
    print('='*100)
    print(f'Finetuning {args.model} on {args.first_train_dataset}')
    print('='*100)
    # Hard-coded hyper-parameters; these override whatever parse_arguments set.
    args.lr = 1e-5
    # The epoch count is looked up with the dataset name *before* the 'Val'
    # suffix is appended.
    args.epochs = epochs[args.first_train_dataset]
    args.first_train_dataset = args.first_train_dataset + 'Val'
    args.train_dataset = args.first_train_dataset
    args.batch_size = 128

    if args.model_checkpoint_path is not None:
        # NOTE(review): torch.load unpickles arbitrary objects, so only load
        # trusted checkpoints. The loaded object is used as a model (not a
        # state_dict). PyTorch >= 2.6 defaults to weights_only=True, which may
        # reject such a file, so confirm the torch version in use.
        finetuned_checkpoint = torch.load(args.model_checkpoint_path)
        image_encoder = _encoder_for_loss(finetuned_checkpoint, args.finetune_loss)
        print('Checkpoint loaded!')
    else:
        image_encoder = None
        print('No finetuned checkpoint')

    print(args)
    finetuned_model = finetune(args, finetuned_encoder=image_encoder)

    # Second stage: same recipe, starting from the first stage's model.
    args.epochs = epochs[args.second_train_dataset]
    args.second_train_dataset = args.second_train_dataset + 'Val'
    args.train_dataset = args.second_train_dataset
    args.batch_size = 128

    second_image_encoder = _encoder_for_loss(finetuned_model, args.finetune_loss)
    second_finetuned_model = finetune(args, finetuned_encoder=second_image_encoder)