from main import main
import argparse
import os

def _str2bool(value):
    """Convert a command-line token to a bool.

    ``argparse`` with ``type=bool`` calls ``bool(str)``, which is True for
    every non-empty string (so ``--scheduler_flag False`` would silently
    enable the scheduler). This converter accepts the usual spellings
    explicitly instead.

    Args:
        value: The raw string from the command line. A bool is passed
            through unchanged.

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised boolean
            spelling; argparse reports it as a usage error.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('true', 't', 'yes', 'y', '1', 'on'):
        return True
    if token in ('false', 'f', 'no', 'n', '0', 'off'):
        return False
    raise argparse.ArgumentTypeError(f'Boolean value expected, got {value!r}')


def build_parser():
    """Build the command-line parser for training/evaluation.

    Returns:
        An ``argparse.ArgumentParser`` defining the model, optimiser,
        scheduler, loss and data-location options whose parsed namespace is
        passed to ``main``.
    """
    parser = argparse.ArgumentParser()

    # Model size and optimisation hyperparameters.
    parser.add_argument('--batch_size', type=int, default=512, help='Batch size')
    parser.add_argument('--embed_size', type=int, default=1024, help='Embedding size')
    parser.add_argument('--num_layers', type=int, default=4, help='Number of layers')
    parser.add_argument('--num_head', type=int, default=8, help='Number of attention heads')
    parser.add_argument('--lr', type=float, default=0.00001, help='Learning rate')
    parser.add_argument('--epochs', type=int, default=200, help='Number of epochs')
    parser.add_argument('--step_size', type=int, default=10, help='Step size for learning rate scheduler')
    parser.add_argument('--gamma', type=float, default=0.999, help='Gamma value for learning rate scheduler')
    # ``type=bool`` would treat any non-empty string (including "False") as
    # True, so the value is parsed explicitly.
    parser.add_argument('--scheduler_flag', type=_str2bool, default=True,
                        help='Flag to enable/disable scheduler (true/false)')
    parser.add_argument('--patience', type=int, default=8, help='Patience for early stopping')

    # Loss-function selection and its component weights / parameters.
    parser.add_argument('--loss_type', type=str, default='bce',
                        choices=['bce', 'focal', 'weighted_bce', 'combined'],
                        help='Type of loss function to use')
    parser.add_argument('--focal_weight', type=float, default=0.7, help='Weight for focal loss component')
    parser.add_argument('--bce_weight', type=float, default=0.2, help='Weight for BCE loss component')
    parser.add_argument('--consistency_weight', type=float, default=0.1, help='Weight for consistency loss component')
    parser.add_argument('--focal_alpha', type=float, default=0.75, help='Alpha parameter for focal loss')
    parser.add_argument('--focal_gamma', type=float, default=2.0, help='Gamma parameter for focal loss')
    parser.add_argument('--label_smoothing', type=float, default=0.1, help='Label smoothing parameter')

    # Regularisation.
    parser.add_argument('--weight_decay', type=float, default=1e-5, help='Weight decay for optimizer')
    parser.add_argument('--l2_lambda', type=float, default=0.01, help='L2 regularization lambda')

    # NOTE: the default path begins with "yourpath", which looks like a
    # placeholder to be replaced with the local checkout location.
    parser.add_argument('--data_dir', type=str,
                        default='yourpath/scUniGP/Data_processing/Dataspilt/Specific/mESC/TFs_500',
                        help='Path to the dataset directory')
    return parser


parser = build_parser()

# Parse the command line, then hand the dataset directory (positionally) and
# the full option namespace to the project's training/evaluation entry point.
args = parser.parse_args()
print(f'data_dir: {args.data_dir}')
main(args.data_dir, args)
# Read args.data_dir again after main() so the message reflects its final value.
print('Training or Evaluation on {} finished.'.format(args.data_dir))