import torch.optim as optim





def define_optimizer(args, models_gen):
    """Build one Adam optimizer per model in ``models_gen``.

    Args:
        args: Config object read for two attributes. ``num_classes`` is how
            many optimizers to create: entries ``0 .. num_classes - 1`` of
            ``models_gen`` are used. ``learning_rate`` is passed as ``lr`` to
            every ``optim.Adam``.
        models_gen: Indexable collection of modules exposing ``.parameters()``.
            It must have at least ``num_classes`` entries; with a list,
            indexing past the end raises ``IndexError``.

    Returns:
        A list of ``optim.Adam`` instances. Element ``k`` optimizes the
        parameters of ``models_gen[k]``.
    """
    # Index by position (not by iterating models_gen) so exactly num_classes
    # optimizers are built, matching the model at each class slot.
    return [
        optim.Adam(models_gen[k].parameters(), lr=args.learning_rate)
        for k in range(args.num_classes)
    ]