import argparse

from train import Trainer

if __name__ == "__main__":
    
    print('[Info] Experiment begins')
    
    parser = argparse.ArgumentParser(description='Fairness Through Matching')
    
    # SETTING
    parser.add_argument('--C_train', action='store_true')
    
    # SEEDS
    parser.add_argument('--seed', type=int, default=2022)
    parser.add_argument('--C_seed', type=int, default=2022)
    
    # DATASETS
    parser.add_argument('--source', type=str, default='Dutch_0')
    parser.add_argument('--target', type=str, default='Dutch_1')
    
    # NETWORKS
    parser.add_argument('--C_model', type=str, default='1MLP')
    parser.add_argument('--OT_model', type=str, default='2MLP')
    parser.add_argument('--act', type=str, default='LeakyReLU')
    parser.add_argument('--C_act', type=str, default='ReLU')
    parser.add_argument('--last_act', type=str, default='Sigmoid')
    parser.add_argument('--bn', action='store_true')
    parser.add_argument('--dropout', action='store_false')
    
    # OPTIMIZATION
    parser.add_argument('--batch_size', type=int, default=1024)
    parser.add_argument('--opt', type=str, default='Adam')
    parser.add_argument('--C_opt', type=str, default='Adam')
    parser.add_argument('--wd', type=float, default=0.0)
    parser.add_argument('--lr', type=float, default=5e-3)
    parser.add_argument('--C_lr', type=float, default=1e-3)
    parser.add_argument('--warmup_steps', type=int, default=400)
    parser.add_argument('--val_freq', type=int, default=50)
    parser.add_argument('--C_val_freq', type=int, default=50)
    parser.add_argument('--epochs', type=int, default=500)
    parser.add_argument('--C_epochs', type=int, default=500)
    
    # HYPERPARAMETERS
    parser.add_argument('--lmda_ipm', type=float, default=100.0)
    parser.add_argument('--lmda_f', type=float, default=0.0)
    
    """ INITIALIZATIONS """
    args = parser.parse_args()
    for key, value in vars(args).items():
        print(f'\t [{key}]: {value}')
    
    """ FTM """
    trainer = Trainer()
    trainer._load_data(source=args.source, target=args.target, batch_size=args.batch_size, seed=args.seed)
    trainer.train_OT_bidirect(args)
    for args.C_seed in [2023, 2024, 2025, 2026, 2027]:
        trainer.train_FairC(args)
    print('[Info] Experiment done')