import copy
import train as train
from argparser import DefaultArguments


def single_train(dataset, method, batch_size, lr, phi, training_step, lr_adversary, epsilon,
                 *, print_loss=False, average_over=10, pretrain_steps=250, test_every=5):
    """Configure and launch a single training run via ``train.main``.

    Starts from a fresh ``DefaultArguments()`` instance, overrides the fields
    listed below, prints every resulting attribute (sorted by name) as
    ``name = value``, and then calls ``train.main(args)``.

    Args:
        dataset: Stored as ``args.dataset``. Values listed in this script:
            'uci_adult', 'compas', 'law_school', 'celebA'.
        method: Stored as ``args.model_name``. Values listed in this script:
            'baseline', 'DRO', 'ARL', 'VFair'.
        batch_size: Stored as ``args.batch_size``.
        lr: Stored as ``args.lr_learner``.
        phi: Stored as ``args.phi``.
        training_step: Stored as ``args.train_steps``.
        lr_adversary: Stored as ``args.lr_adversary`` (marked as an ARL
            setting in this script).
        epsilon: Stored as ``args.epsilon``.
        print_loss: Whether the loss should be printed; stored as
            ``args.print_loss``. Defaults to False.
        average_over: How many times the results are averaged; stored as
            ``args.average_over``. Defaults to 10.
        pretrain_steps: Stored as ``args.pretrain_steps`` (marked as an ARL
            setting in this script). Defaults to 250.
        test_every: Stored as ``args.test_every``. Defaults to 5.

    Returns:
        None. The run's outputs are whatever ``train.main`` produces; the
        checkpoint directory passed to it is
        ``checkpoints/<dataset>/<method>/``.
    """
    # Load the default arguments and apply the run-wide switches.
    default_args = DefaultArguments()
    default_args.print_loss = print_loss
    default_args.average_over = average_over

    # Apply the per-run overrides to a shallow copy of the defaults object.
    args = copy.copy(default_args)

    args.dataset = dataset
    args.model_name = method
    args.batch_size = batch_size
    args.lr_learner = lr
    # Settings marked as ARL-specific.
    args.lr_adversary = lr_adversary
    args.pretrain_steps = pretrain_steps

    args.phi = phi
    args.epsilon = epsilon
    args.test_every = test_every
    args.log_dir = f'checkpoints/{dataset}/{args.model_name}/'
    args.train_steps = training_step

    # Echo the full configuration so each run is self-describing in its log.
    for k, v in sorted(vars(args).items()):
        print(k, '=', v)

    # Train the model.
    train.main(args)


if __name__ == '__main__':
    # Example configuration: VFair trained on the COMPAS dataset.
    run_config = {
        'dataset': 'compas',
        'method': 'VFair',
        'batch_size': 32,
        'lr': 0.01,
        'phi': 0.9,
        'training_step': 495,
        'lr_adversary': 1,
        'epsilon': 3,
    }
    single_train(**run_config)
