import argparse
import yaml

class ConfigDict(dict):
    """A ``dict`` whose keys can also be read, written and deleted as attributes.

    ``cfg.lr`` is equivalent to ``cfg['lr']``, ``cfg.lr = x`` to
    ``cfg['lr'] = x`` and ``del cfg.lr`` to ``del cfg['lr']``.

    Attribute access to a missing key raises ``AttributeError`` (item access
    still raises ``KeyError``), so ``hasattr``, ``getattr(cfg, name, default)``,
    ``copy.deepcopy`` and pickling work as they do for ordinary objects.
    """

    def __getattr__(self, name):
        # Only invoked after normal lookup fails, so real dict methods
        # (``keys``, ``items``, ...) still win over same-named keys.
        try:
            return self[name]
        except KeyError:
            # getattr/hasattr and the copy/pickle protocols only treat
            # AttributeError as "no such attribute"; a KeyError would escape.
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            ) from None

    def __setattr__(self, name, value):
        # Store as a dict item rather than in the instance __dict__.
        self[name] = value

    def __delattr__(self, name):
        # Mirror __setattr__: attribute deletion removes the dict item.
        try:
            del self[name]
        except KeyError:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            ) from None

def _str2bool(value):
    """Convert a command-line string such as ``True``/``false``/``1``/``no`` to a bool.

    ``argparse`` with ``type=bool`` calls ``bool(text)``, which is ``True`` for
    every non-empty string -- including ``"False"`` -- so boolean options need
    an explicit parser.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean,
            so argparse reports an ordinary usage error.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('true', 't', 'yes', 'y', '1', 'on'):
        return True
    # '' maps to False for backward compatibility: with the old ``type=bool``
    # an empty string was the only value that disabled the flag.
    if text in ('false', 'f', 'no', 'n', '0', 'off', ''):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def get_config(argv=None):
    """Build the experiment configuration from command-line arguments.

    Args:
        argv: Argument strings to parse. ``None`` (the default) parses
            ``sys.argv[1:]`` as before; passing a list allows calling this
            from notebooks or tests without touching ``sys.argv``.

    Returns:
        ConfigDict holding the CLI-overridable settings together with the
        fixed settings hard-coded below.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--wandb', type=_str2bool, default=True,
                        help='Whether to use wandb for logging (true/false)')
    parser.add_argument('--n_layers', type=int, default=2, help='Number of layers in the model')
    parser.add_argument('--hidden_channels', type=int, default=32, help='Number of hidden channels in the model')
    parser.add_argument('--lr', type=float, default=1e-3, help='Learning rate for the optimizer')
    parser.add_argument('--batch_size', type=int, default=32, help='Batch size for training')
    parser.add_argument("--exp_name", type=str, default='FakeNewsTwitter', help="name for the experiment")
    parser.add_argument("--checkpoint_dir", type=str, default='FakeNewsData/checkpoints', help="Path to the checkpoint directory")
    parser.add_argument("--dropout", type=float, default=0.2, help="Dropout rate")
    parser.add_argument("--seed", type=int, default=0, help="Random seed for reproducibility")

    args = parser.parse_args(argv)

    default_config = ConfigDict(
        # Experiment bookkeeping / data loading.
        exp_name=args.exp_name,
        model_checkpoints_dir=args.checkpoint_dir,  # note: key name differs from the CLI flag
        batch_size=args.batch_size,
        val_batch_size=10,
        test_batch_size=10,
        epochs=400,
        seed=args.seed,

        # Optimisation / model size. Values without an ``args.`` source are
        # fixed here and cannot be overridden from the command line.
        wd=1e-4,
        dropout=args.dropout,
        lr=args.lr,
        n_layers=args.n_layers,
        normalize_rho=False,
        hidden_channels=args.hidden_channels,
        out_channels=1,

        # Task / architecture specifics (fixed).
        is_graph_task=True,
        max_num_GNANs=2,
        num_biom=2,
        num_biom_embed=5,
        # NOTE(review): presumably per-group input feature widths -- confirm
        # against the model/data code before changing.
        feature_groups=[768, 300, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 300, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],

        # Loss and learning-rate schedule (fixed).
        weighted_loss=True,
        scheduler_T0=15,
        scheduler_decay=0.8,
        end_lr=5e-8,

        biomarker_groups=[['single'], ['not_single']],
        wandb=args.wandb,
    )

    return default_config