import os

import torch


def save_checkpoint(net, clf, critic, critic_z, lsh_attention, epoch, run_name, args, acc, scalar_logger, script_name, optim, attention_optimizer):
    """Log accuracy and save a training checkpoint for ``epoch``.

    The checkpoint is written with ``torch.save`` to
    ``os.path.join(args.save_location, f"{run_name}_epoch_{epoch}.pth")``.
    The save directory is created if needed. An empty ``save_location``
    means the current working directory.

    Args:
        net, clf, critic: objects exposing ``state_dict()`` (presumably
            ``torch.nn.Module`` instances — confirm against callers).
        critic_z: optional object with ``state_dict()``. Saved under
            ``'critic_z'`` only when not ``None``.
        lsh_attention: saved under ``'lsh_attention'`` only when it has a
            ``WQ`` attribute. That attribute is used as the marker for
            "has learnable parameters".
        epoch: epoch number. It is stored in the checkpoint, used as the
            logging step and embedded in the file name.
        run_name: prefix for the checkpoint file name.
        args: namespace-like object. ``vars(args)`` is stored and
            ``args.save_location`` selects the output directory.
        acc: accuracy value. It is stored and logged as ``('Acc', acc)``.
        scalar_logger: object with a ``log_value(step, (name, value))``
            method.
        script_name: stored verbatim under ``'script'``.
        optim: optimizer whose ``state_dict()`` is stored under ``'optim'``.
        attention_optimizer: optional optimizer. Its ``state_dict()`` is
            stored under ``'att_optim'`` when not ``None``.

    Raises:
        Whatever ``os.makedirs``, ``torch.save`` or ``os.replace`` raise.
        On failure no partial file is left at the final destination.
    """
    print('Saving..')
    state = {
        'net': net.state_dict(),
        'clf': clf.state_dict(),
        'critic': critic.state_dict(),
        'epoch': epoch,
        'args': vars(args),
        'script': script_name,
        'acc': acc,
        'optim': optim.state_dict(),
    }

    if hasattr(lsh_attention, 'WQ'):  # check if lsh_attention contains learnable parameters
        state['lsh_attention'] = lsh_attention.state_dict()

    if critic_z is not None:
        state['critic_z'] = critic_z.state_dict()

    if attention_optimizer is not None:
        state['att_optim'] = attention_optimizer.state_dict()

    scalar_logger.log_value(epoch, ('Acc', acc))

    save_dir = args.save_location
    # makedirs('') raises FileNotFoundError. An empty location means the
    # CWD, which os.path.join('') already yields. exist_ok makes a prior
    # isdir() check unnecessary.
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    destination = os.path.join(save_dir, f"{run_name}_epoch_{epoch}.pth")

    # Write to a sibling temp file and swap it in atomically. An
    # interrupted save then never leaves a truncated checkpoint under the
    # real name.
    tmp_destination = destination + '.tmp'
    try:
        torch.save(state, tmp_destination)
        os.replace(tmp_destination, destination)
    except BaseException:
        if os.path.exists(tmp_destination):
            os.remove(tmp_destination)
        raise
