import argparse
import yaml
import os


def load_yaml_config(path):
    """Read the YAML file at *path* and return its parsed content.

    The file is decoded explicitly as UTF-8 so the result does not depend on
    the platform's default locale encoding.

    Args:
        path: Location of the YAML file.

    Returns:
        Whatever ``yaml.safe_load`` produces for the document (``None`` for an
        empty document).
    """
    with open(path, "r", encoding="utf-8") as f:
        return yaml.safe_load(f)


def _build_base_parser():
    """Create the parser declaring every statically known command line option.

    Options that exist only in a model's YAML file are not declared here; they
    reach the final namespace as parser defaults installed by ``get_args``.
    """
    parser = argparse.ArgumentParser(description='Unified Clinical Learner Args')

    # Basic parameters (common control items)
    parser.add_argument('--model', type=str, required=True, help='Model name, used to locate YAML config')
    parser.add_argument('--mode', type=str, default='train', help='train or test mode')
    parser.add_argument('--gpu', type=int, nargs='+', default=[0], help='List of GPU indices, supports single or multiple indices, e.g. `--gpu 0` or `--gpu 0 1`')
    parser.add_argument('--fold', type=int, default=1, choices=[1, 2, 3, 4, 5])
    parser.add_argument('--epochs', type=int, default=100)
    parser.add_argument('--batch_size', type=int, default=16)
    parser.add_argument('--num_workers', type=int, default=4)
    parser.add_argument('--seed', type=int, default=42)
    # NOTE: store_true with default=True means the flag only changes anything
    # when the YAML config (installed as parser defaults in get_args) sets
    # save_checkpoint to false.
    parser.add_argument('--save_checkpoint', action='store_true', default=True)
    parser.add_argument('--dev_run', action='store_true')
    parser.add_argument('--use_triplet', action='store_true')
    parser.add_argument('--config_root', type=str, default=os.path.join(os.path.dirname(__file__), 'configs'))
    parser.add_argument('--checkpoint_path', type=str, default=None, help='Path to checkpoint for testing mode')

    # Data subset selection
    parser.add_argument('--matched', action='store_true', help='Use matched subset for all splits')
    parser.add_argument('--cross_eval', type=str, choices=['matched_to_full', 'full_to_matched'],
                        help='Cross evaluation: matched_to_full (train matched, test full) or full_to_matched (train full, test matched)')

    # Demographic features control
    parser.add_argument('--use_demographics', action='store_true',
                        help='Include demographic features (age, gender, admission_type, race) in EHR data')
    parser.add_argument('--demographic_cols', type=str, nargs='+',
                        default=['age', 'gender', 'admission_type', 'race'],
                        help='List of demographic columns to include')

    # Label weight control for class imbalance
    # NOTE: the action is store_false, so use_label_weights is True unless the
    # flag is passed. The runtime behaviour is kept as-is; only the help text
    # was corrected, because it used to claim that the flag *enables* weights.
    parser.add_argument('--use_label_weights', action='store_false',
                        help='Disable label weights for class imbalance (weights are on by default; this flag turns them off)')
    parser.add_argument('--label_weight_method', type=str,
                        choices=['balanced', 'inverse', 'sqrt_inverse', 'log_inverse', 'custom'],
                        default='balanced',
                        help='Method for calculating label weights')
    parser.add_argument('--custom_label_weights', type=str, nargs='+',
                        help='Custom label weights as key-value pairs')

    # CXR modality dropout for robustness evaluation
    parser.add_argument('--cxr_dropout_rate', type=float, default=0.0,
                        help='Dropout rate for CXR data during training/validation (0.0-1.0)')
    parser.add_argument('--cxr_dropout_seed', type=int, default=None,
                        help='Random seed for CXR dropout (uses main seed if not specified)')

    # Fairness evaluation options
    parser.add_argument('--compute_fairness', action='store_true',
                        help='Enable fairness metrics computation during evaluation')
    parser.add_argument('--fairness_attributes', type=str, nargs='+',
                        default=['age', 'race', 'gender'],
                        help='List of sensitive attributes for fairness evaluation')
    parser.add_argument('--fairness_age_bins', type=float, nargs='+',
                        default=[0, 40, 60, 80],
                        help='Age bins for fairness evaluation (if age is included)')
    parser.add_argument('--fairness_intersectional', action='store_true',
                        help='Compute intersectional fairness metrics')
    parser.add_argument('--fairness_include_cxr', action='store_true',
                        help='Whether to include CXR availability in fairness analysis')

    # Predictions saving options
    parser.add_argument('--save_predictions', action='store_true',
                        help='Save test predictions and labels to experiment directory')
    parser.add_argument('--predictions_save_dir', type=str, default=None,
                        help='Custom directory to save predictions (if not specified, uses experiment directory)')

    parser.add_argument('--task', type=str, default='mortality', help='phenotype or mortality')
    parser.add_argument('--patience', type=int, default=10, help='number of epoch to wait for best')
    parser.add_argument('--log_dir', type=str, default=None, help='Log directory path. If not specified, will use default path.')

    # Dataset paths
    parser.add_argument('--resized_cxr_root', type=str, help='Path to the cxr data',
                        default='/root/autodl-tmp/benchmark/benchmark_dataset/mimic_cxr_resized')
    parser.add_argument('--image_meta_path', type=str, help='Path to the image meta data',
                        default='/root/autodl-tmp/benchmark/benchmark_dataset/mimic-cxr-2.0.0-metadata.csv')
    # An alternative location previously used for the next two defaults was
    # /hdd/DataProcessing/benchmark_data/250827 (and its data_pkls
    # subdirectory); pass --ehr_root / --pkl_dir to point there instead.
    parser.add_argument('--ehr_root', type=str, help='Path to the data dir',
                        default='/root/autodl-tmp/benchmark/benchmark_dataset/DataProcessing/benchmark_data/250827')
    parser.add_argument('--pkl_dir', type=str, help='Path to the pkl data',
                        default='/root/autodl-tmp/benchmark/benchmark_dataset/DataProcessing/benchmark_data/250827/data_pkls')

    parser.add_argument('--demographics_in_model_input', action='store_true',
                        help='Include demographic features in model input (for models trained with demographics)')

    return parser


def _coerce_override(raw_value, original_value):
    """Convert the command line string *raw_value* to the type of *original_value*.

    ``bool`` is tested before ``int`` because ``bool`` is a subclass of ``int``.
    Values whose YAML original is neither bool, int nor float stay strings.

    Raises:
        ValueError: If *raw_value* cannot be parsed as the int/float required.
    """
    if isinstance(original_value, bool):
        return raw_value.lower() in ('true', 'yes', 'y', '1')
    if isinstance(original_value, int):
        return int(raw_value)
    if isinstance(original_value, float):
        return float(raw_value)
    return raw_value


def _apply_cli_overrides(yaml_config, unknown_args):
    """Apply ``--name value`` / ``--name=value`` / ``--name`` tokens to *yaml_config*.

    Only names already present in *yaml_config* are applied (in place); every
    other token is reported with a warning and skipped.

    Args:
        yaml_config: Mapping loaded from the model YAML file; mutated in place.
        unknown_args: Tokens the base parser did not recognise.

    Returns:
        Dict mapping each overridden name to ``(original_value, new_value)``.

    Raises:
        ValueError: If a value cannot be converted to the YAML value's type.
    """
    modified = {}
    total = len(unknown_args)
    idx = 0
    while idx < total:
        token = unknown_args[idx]
        idx += 1

        if not token.startswith('--'):
            print(f"  Warning: Ignoring invalid parameter format: {token}")
            continue

        # Accept both "--name value" and "--name=value".
        name, has_inline_value, inline_value = token[2:].partition('=')
        if has_inline_value:
            raw_value = inline_value
        elif idx < total and not unknown_args[idx].startswith('--'):
            raw_value = unknown_args[idx]
            idx += 1  # the value token has been consumed as well
        else:
            # No value, treat as boolean flag
            raw_value = 'True'

        if name not in yaml_config:
            print(f"  Warning: Parameter '{name}' not defined in YAML config, will be ignored")
            continue

        current_value = yaml_config[name]
        try:
            new_value = _coerce_override(raw_value, current_value)
        except ValueError as err:
            raise ValueError(
                f"Cannot convert value '{raw_value}' for parameter '{name}' "
                f"to {type(current_value).__name__}"
            ) from err

        # When the same parameter is given twice, keep reporting the value
        # that came from the YAML file as the original.
        original_value = modified[name][0] if name in modified else current_value
        yaml_config[name] = new_value
        modified[name] = (original_value, new_value)

    return modified


def _assign_split_flags(args):
    """Derive train/val/test ``*_matched`` flags from ``cross_eval`` and ``matched``.

    ``cross_eval`` takes precedence over ``matched``. The flags are written
    onto *args* and the chosen mode is printed.
    """
    if args.cross_eval == 'matched_to_full':
        train_val_matched, test_matched = True, False
        message = "Cross evaluation mode: Training on matched data, testing on full data"
    elif args.cross_eval == 'full_to_matched':
        train_val_matched, test_matched = False, True
        message = "Cross evaluation mode: Training on full data, testing on matched data"
    elif args.matched:
        train_val_matched, test_matched = True, True
        message = "Using matched data for all splits"
    else:
        train_val_matched, test_matched = False, False
        message = "Using full data for all splits"

    args.train_matched = train_val_matched
    args.val_matched = train_val_matched
    args.test_matched = test_matched
    print(message)


def get_args():
    """Parse the command line and merge it with the selected model's YAML config.

    Steps:
      1. Parse the statically declared options (``--model`` is required).
      2. Load ``<config_root>/<model>.yaml``.
      3. Apply unrecognised ``--name value`` tokens as overrides of keys that
         exist in the YAML config, converting them to the YAML value's type.
      4. Install the (possibly overridden) YAML values as parser defaults and
         parse again, so explicitly passed declared options still win.
      5. Derive ``train_matched`` / ``val_matched`` / ``test_matched``.

    Returns:
        argparse.Namespace holding the declared options, every YAML key, and
        the three derived ``*_matched`` flags.

    Raises:
        FileNotFoundError: If the YAML config for the model does not exist.
        TypeError: If the YAML document is not a mapping.
        ValueError: If an override cannot be converted to the YAML value's type.
    """
    parser = _build_base_parser()

    # First parse to get model name and unknown arguments
    partial_args, unknown_args = parser.parse_known_args()
    model_name = partial_args.model
    config_path = os.path.join(partial_args.config_root, f'{model_name}.yaml')

    if not os.path.exists(config_path):
        raise FileNotFoundError(f"YAML config not found for model `{model_name}` at: {config_path}")

    # An empty YAML file parses to None; treat it as "no YAML parameters".
    yaml_config = load_yaml_config(config_path) or {}
    if not isinstance(yaml_config, dict):
        raise TypeError(
            f"YAML config at {config_path} must be a mapping, "
            f"got {type(yaml_config).__name__}"
        )
    print(f"\nLoading YAML config: {config_path}")

    # Process unknown arguments (only those defined in YAML)
    modified_params = {}
    if unknown_args:
        print("\nDetected command line argument overrides:")
        modified_params = _apply_cli_overrides(yaml_config, unknown_args)

    # Print modified parameters
    if modified_params:
        for param_name, (old_value, new_value) in modified_params.items():
            print(f"Modified parameter: {param_name} = {new_value} (original: {old_value})")
    else:
        print("No YAML parameters were modified")
    print("")  # Empty line

    # YAML values (including the overrides applied above) become defaults, so
    # YAML-only keys show up on the namespace and declared options passed on
    # the command line still take precedence.
    parser.set_defaults(**yaml_config)

    # The leftover tokens were already applied or warned about above, so they
    # are deliberately discarded here instead of aborting the run with
    # "unrecognized arguments".
    args, _ = parser.parse_known_args()

    _assign_split_flags(args)

    # Print demographic features configuration
    if args.use_demographics:
        print(f"Using demographic features: {args.demographic_cols}")
    else:
        print("Demographic features disabled")

    return args