# config.py

import argparse
import yaml


class ConfigDict(dict):
    """A ``dict`` whose keys can also be read, written and deleted as attributes.

    ``cfg.lr`` is equivalent to ``cfg["lr"]`` and ``cfg.lr = 1e-3`` to
    ``cfg["lr"] = 1e-3``. Attribute writes never touch the instance
    ``__dict__``; every value lives in the mapping itself.
    """

    def __getattr__(self, name):
        # Only called when normal attribute lookup fails. A missing key must
        # surface as AttributeError (not KeyError) so that hasattr(),
        # getattr(obj, name, default) and copy.deepcopy() -- which probes
        # for ``__deepcopy__`` -- keep working.
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name) from None

    def __setattr__(self, name, value):
        # Route attribute assignment into the mapping.
        self[name] = value

    def __delattr__(self, name):
        # Mirror __getattr__/__setattr__: ``del cfg.x`` removes key "x".
        try:
            del self[name]
        except KeyError:
            raise AttributeError(name) from None


def _str2bool(value):
    """Convert a command-line string such as ``"False"`` or ``"1"`` to a bool.

    ``argparse`` with ``type=bool`` calls ``bool(str)``, which is True for
    every non-empty string (including ``"False"``), so boolean flags that
    take a value need an explicit parser.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ("true", "t", "yes", "y", "1", "on"):
        return True
    if normalized in ("false", "f", "no", "n", "0", "off"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def _build_parser():
    """Create the command-line parser for the experiment hyper-parameters."""
    parser = argparse.ArgumentParser()
    # parser.add_argument("--config_path", type=str, default='configs/run1.yaml', help="Path to the config file")
    parser.add_argument("--config_path", type=str, default='configs_P12/run1.yaml', help="Path to the config file")
    parser.add_argument('--wandb', type=_str2bool, default=True, help='Whether to use wandb for logging')
    parser.add_argument('--n_layers', type=int, default=2, help='Number of layers in the model')
    parser.add_argument('--hidden_channels', type=int, default=32, help='Number of hidden channels in the model')
    parser.add_argument('--lr', type=float, default=1e-3, help='Learning rate for the optimizer')
    parser.add_argument('--batch_size', type=int, default=32, help='Batch size for training')
    parser.add_argument('--dropout', type=float, default=0.0)
    parser.add_argument('--wandb_exp_name', type=str, default='LoSP12GMAN')
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--val_batch_size', type=int, default=10, help='Batch size for validation')
    parser.add_argument('--test_batch_size', type=int, default=10, help='Batch size for testing')

    parser.add_argument('--run_name', type=str, default='Develop', help='Run name for the experiment')

    parser.add_argument('--deepset_n_layers', type=int, default=2, help='Number of layers in the deepset')
    return parser


def _load_yaml_overrides(path):
    """Read the YAML file at *path* and return its top-level mapping."""
    with open(path, "r", encoding="utf-8") as f:
        raw_cfg = yaml.safe_load(f)
    if raw_cfg is None:
        # An empty YAML document loads as None; treat it as "no overrides".
        return {}
    if not isinstance(raw_cfg, dict):
        raise ValueError(
            f"Config file {path!r} must contain a mapping at the top level, "
            f"got {type(raw_cfg).__name__}"
        )
    return raw_cfg


def get_config(argv=None):
    """Build the experiment configuration from CLI flags and a YAML file.

    Values start from the command-line flags (see ``_build_parser``) plus the
    fixed settings below. Every top-level key of the YAML file named by
    ``--config_path`` is then written over them.

    Args:
        argv: Argument list to parse. ``None`` (the default) parses
            ``sys.argv[1:]``, as the function always did.

    Returns:
        ConfigDict: the merged configuration, readable as ``cfg.key``.

    Raises:
        SystemExit: on invalid command-line arguments (argparse behaviour).
        OSError: if the config file cannot be opened.
        ValueError: if the YAML file's top level is not a mapping.
    """
    args = _build_parser().parse_args(argv)

    raw_cfg = _load_yaml_overrides(args.config_path)

    default_config = ConfigDict(

        wandb_exp_name=args.wandb_exp_name,
        # wandb_exp_name = 'SetGNANPhysionet'

        exp_name="Physionet 12 -- DeepSet GNAN",
        run_name=args.run_name,
        # exp_name="Physionet 19 -- DeepSet GNAN",

        logging_dir="develop_logs",
        config_path=args.config_path,

        sequential_data_dir="../../merged",

        graph_data_dir="",

        model_checkpoints_dir='P12_checkpoints',

        train_ratio=0.8,
        val_ratio=0.10,

        batch_size=args.batch_size,
        val_batch_size=args.val_batch_size,
        test_batch_size=args.test_batch_size,
        epochs=1000,

        deepset_n_layers=args.deepset_n_layers,

        wd=1e-4,
        dropout=args.dropout,
        seed=args.seed,
        lr=args.lr,
        n_layers=args.n_layers,
        normalize_rho=False,
        hidden_channels=args.hidden_channels,
        out_channels=1,
        is_graph_task=True,
        # max_num_GNANs=34,
        max_num_GNANs=36,

        # num_biom=34,
        num_biom=36,
        num_biom_embed=5,
        feature_groups=[1, 2, 1, 5],

        weighted_loss=True,
        scheduler_T0=15,
        scheduler_decay=0.8,
        end_lr=5e-8,

        biomarker_groups=[],
        wandb=args.wandb
    )

    # Override biomarker_groups or any other parameter with the YAML values.
    # NOTE: YAML keys take precedence over everything above, including values
    # passed explicitly on the command line.
    default_config.update(raw_cfg)

    return default_config