# config.py

import argparse
import yaml

class ConfigDict(dict):
    """A ``dict`` whose keys can also be read, written and deleted as attributes.

    ``cfg.lr`` is equivalent to ``cfg["lr"]``, ``cfg.lr = 1e-4`` to
    ``cfg["lr"] = 1e-4`` and ``del cfg.lr`` to ``del cfg["lr"]``.
    Missing keys surface as ``AttributeError`` so that ``hasattr``,
    ``getattr(cfg, name, default)``, ``copy.deepcopy`` and ``pickle`` work.
    """

    def __getattr__(self, name):
        # Only called after normal attribute lookup fails, so real dict
        # methods (items, keys, update, ...) are never shadowed by keys.
        try:
            return self[name]
        except KeyError:
            # The attribute protocol requires AttributeError; raising KeyError
            # breaks hasattr(), getattr() with a default, deepcopy and pickle.
            raise AttributeError(
                f"{type(self).__name__} has no attribute or key {name!r}"
            ) from None

    def __setattr__(self, name, value):
        # Store attributes as dict items so they appear in items()/keys().
        self[name] = value

    def __delattr__(self, name):
        try:
            del self[name]
        except KeyError:
            raise AttributeError(
                f"{type(self).__name__} has no attribute or key {name!r}"
            ) from None

def _str2bool(value):
    """Parse a command-line boolean such as ``True``/``false``/``1``/``no``.

    ``argparse`` with ``type=bool`` calls ``bool(str)``, which is ``True`` for
    every non-empty string (including ``"False"``), so boolean-valued flags
    need an explicit parser.

    Raises:
        argparse.ArgumentTypeError: if *value* is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in {"true", "t", "yes", "y", "1", "on"}:
        return True
    if normalized in {"false", "f", "no", "n", "0", "off"}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def _build_arg_parser():
    """Return the ArgumentParser describing every command-line option."""
    parser = argparse.ArgumentParser()

    # General training / model options
    parser.add_argument("--config_path", type=str, default='configs/P_12/biology_informed.yaml',
                        help="Path to the config file")
    parser.add_argument('--wandb', type=_str2bool, default=True,
                        help='Whether to use wandb for logging')
    parser.add_argument('--n_layers', type=int, default=2,
                        help='Number of layers in the model')
    parser.add_argument('--hidden_channels', type=int, default=32,
                        help='Number of hidden channels in the model')
    parser.add_argument('--lr', type=float, default=1e-3,
                        help='Learning rate for the optimizer')
    parser.add_argument('--warmup_epochs', type=int, default=5,
                        help='Warmup epochs for the scheduler')
    parser.add_argument('--batch_size', type=int, default=32,
                        help='Batch size for training')
    parser.add_argument('--val_batch_size', type=int, default=10,
                        help='Batch size for validation')
    parser.add_argument('--test_batch_size', type=int, default=10,
                        help='Batch size for testing')
    parser.add_argument('--gnan_mode', type=str, default='per_group',
                        choices=['single', 'per_group', 'per_biomarker'],
                        help='GNAN mode to use')
    parser.add_argument('--seed', type=int, default=42,
                        help='Random seed for reproducibility')
    parser.add_argument('--dropout', type=float, default=0.2,
                        help='Dropout rate for the model')
    parser.add_argument('--run_name', type=str, default='Develop',
                        help='W&B run name (overrides auto-composed name)')

    # Ablation study arguments
    parser.add_argument('--disable_deepset', action='store_true',
                        help='Disable DeepSet aggregation for ablation')
    parser.add_argument('--disable_distance_embedding', action='store_true',
                        help='Disable distance embedding for ablation')
    parser.add_argument('--use_simple_aggregation', action='store_true',
                        help='Use simple aggregation for ablation')
    parser.add_argument('--feature_processor_type', type=str,
                        choices=['gnan', 'simple_linear', 'identity'], default='gnan',
                        help='Type of feature processor to use (gnan, simple_linear, identity)')

    # P12-specific label control
    parser.add_argument('--predictive_label', type=str, default='mortality',
                        choices=['mortality', 'LoS'],
                        help="Which label to predict for P12: 'mortality' or 'LoS' (length of stay > threshold)")
    parser.add_argument('--los_threshold_days', type=int, default=3,
                        help='Threshold in days for LoS classification (stay > threshold -> 1)')

    # Raindrop data loading flags
    parser.add_argument('--use_raindrop_data', action='store_true',
                        help='Load P12 via Raindrop splits/preprocessing')
    parser.add_argument('--raindrop_base_path', type=str,
                        default='/home/dcm.aau.dk/km20bf/Biomarker_FeatureGroup_GNAN/baselines/Raindrop/P12data',
                        help='Base path to Raindrop P12data')
    parser.add_argument('--raindrop_split_idx', type=int, default=1,
                        help='Split index (1..5) to mirror Raindrop')
    parser.add_argument('--raindrop_split_type', type=str, default='random',
                        choices=['random', 'age', 'gender'],
                        help='Split type as in Raindrop')
    parser.add_argument('--raindrop_reverse', action='store_true',
                        help='Reverse flag for age/gender splits (Raindrop)')

    # Cached dataset flags (PSV files with newlabel, e.g., /tmp)
    parser.add_argument('--use-cached-dataset', dest='use_cached_dataset', action='store_true',
                        help='Use cached PSV dataset directly (e.g., /tmp/*.psv)')
    parser.add_argument('--cached_dataset_dir', type=str, default='/tmp',
                        help='Directory containing cached PSV files with newlabel')
    parser.add_argument('--split_pkl_path', type=str, default='P12_data_splits/split_1.pkl',
                        help='Path to split_1.pkl to define train/val/test when using cached dataset')

    # Balanced sampling strategy (multi-GPU only)
    parser.add_argument('--balance_strategy', type=str, default='upsample',
                        choices=['upsample', 'downsample'],
                        help="How to balance classes per batch in DDP: 'upsample' (repeat minority) or 'downsample' (subsample majority)")
    parser.add_argument('--upsample_factor', type=int, default=3,
                        help='Upsample factor for the sampler')

    return parser


def _load_yaml_overrides(config_path):
    """Return the top-level mapping in the YAML file *config_path* ({} if empty)."""
    with open(config_path, "r", encoding="utf-8") as f:
        overrides = yaml.safe_load(f)
    if overrides is None:
        # An empty YAML document loads as None; treat it as "no overrides".
        return {}
    if not isinstance(overrides, dict):
        raise ValueError(
            f"Config file {config_path!r} must contain a mapping at the top level, "
            f"got {type(overrides).__name__}"
        )
    return overrides


def get_config(argv=None):
    """Build the experiment configuration from the CLI and a YAML file.

    Command-line arguments are parsed first. They fill a ConfigDict together
    with fixed built-in defaults. Every top-level key of the YAML file named
    by ``--config_path`` is then written into the result. YAML values
    therefore take precedence over both the built-in defaults and the
    command-line flags, and YAML keys not listed below are added as-is.

    Args:
        argv: Argument list to parse. ``None`` (the default) reads
            ``sys.argv[1:]``, as ``argparse`` does.

    Returns:
        ConfigDict: the merged configuration (attribute- and key-accessible).

    Raises:
        OSError: if the YAML file cannot be opened.
        ValueError: if the YAML file's top level is not a mapping.
        yaml.YAMLError: if the YAML file is malformed.
    """
    args = _build_arg_parser().parse_args(argv)
    overrides = _load_yaml_overrides(args.config_path)

    # NOTE: the earlier PhysioNet-19 setup used
    # exp_name="Physionet 19 -- DeepSet GNAN", num_biom=34 and max_num_GNANs=34.
    default_config = ConfigDict(
        exp_name="Physionet 12 -- DeepSet GNAN",
        run_name=args.run_name,
        logging_dir="develop_logs",
        config_path=args.config_path,

        sequential_data_dir="../../merged",
        graph_data_dir="",
        model_checkpoints_dir='P12_checkpoints',

        train_ratio=0.8,
        val_ratio=0.10,

        batch_size=args.batch_size,
        val_batch_size=args.val_batch_size,
        test_batch_size=args.test_batch_size,
        epochs=1000,

        wd=1e-3,
        dropout=args.dropout,
        lr=args.lr,
        n_layers=args.n_layers,
        normalize_rho=False,
        hidden_channels=args.hidden_channels,
        out_channels=1,
        is_graph_task=True,
        max_num_GNANs=36,
        warmup_epochs=args.warmup_epochs,

        num_biom=36,
        num_biom_embed=3,
        feature_groups=[1, 1, 1, 1, 3],

        weighted_loss=True,
        scheduler_T0=15,
        scheduler_decay=0.8,
        end_lr=5e-8,

        biomarker_groups=[],
        wandb=args.wandb,
        gnan_mode=args.gnan_mode,
        seed=args.seed,

        # Ablation flags
        disable_deepset=args.disable_deepset,
        disable_distance_embedding=args.disable_distance_embedding,
        use_simple_aggregation=args.use_simple_aggregation,
        feature_processor_type=args.feature_processor_type,

        # Labeling options
        predictive_label=args.predictive_label,
        los_threshold_days=args.los_threshold_days,

        # Raindrop loading options
        raindrop_base_path=args.raindrop_base_path,
        use_raindrop_data=args.use_raindrop_data,
        raindrop_split_idx=args.raindrop_split_idx,
        raindrop_split_type=args.raindrop_split_type,
        raindrop_reverse=args.raindrop_reverse,

        # Cached dataset options
        use_cached_dataset=args.use_cached_dataset,
        cached_dataset_dir=args.cached_dataset_dir,
        split_pkl_path=args.split_pkl_path,

        # Sampler options (only used in multi-GPU runner)
        balance_strategy=args.balance_strategy,
        upsample_factor=args.upsample_factor,
    )

    # The YAML file may override any entry (e.g. biomarker_groups) or add new ones.
    default_config.update(overrides)

    return default_config


