
import argparse
import yaml

# Config keys holding optional output paths. get_run_config fills any that are
# missing with DEFAULT_PATH; validate_config requires each to be None or a str.
save_paths = [
    'log_file',
    'mainc_log_file',
    'save_path',
    'mainc_save_path',
    'dump_path',
    'attacked_dataset_path',
    'reconstructed_dataset_path',
]

# Config keys holding frequencies. get_run_config fills any that are missing
# with DEFAULT_FREQ; validate_config requires each to be a positive int.
freqs = ['save_freq', 'dump_freq', 'attacked_save_freq']

# Fallback values applied by get_run_config for keys absent from the YAML file.
DEFAULT_FREQ = 1
DEFAULT_DEVICE = 'cuda:0'
DEFAULT_PATH = None

def _require(condition, message):
    """Raise AssertionError with *message* unless *condition* is truthy.

    Used instead of the ``assert`` statement because asserts are stripped when
    Python runs with ``-O``, which would silently disable config validation.
    """
    if not condition:
        raise AssertionError(message)


def _is_int(value):
    """Return True for genuine ints; bools are rejected because ``bool`` subclasses ``int``."""
    return isinstance(value, int) and not isinstance(value, bool)


def validate_config(cfg):
    """Validate the run configuration dictionary.

    Checks that every required key is present and has the expected type and
    range. The checks are explicit (not ``assert`` statements), so they still
    run under ``python -O``.

    Args:
        cfg (dict): The configuration dictionary to validate.

    Raises:
        AssertionError: If any check fails; the message names the offending key.
    """
    # check datasets
    _require('test_datasets' in cfg, "test_datasets not specified")
    _require('dataset_paths' in cfg, "dataset_paths not specified")
    # Type-check before len() so a bad value (e.g. None) fails validation
    # instead of raising TypeError.
    _require(isinstance(cfg['test_datasets'], (list, tuple)),
             "test_datasets must be a list")
    _require(isinstance(cfg['dataset_paths'], (list, tuple)),
             "dataset_paths must be a list")
    _require(len(cfg['test_datasets']) == len(cfg['dataset_paths']),
             "test_datasets and dataset_paths list len mismatch")
    _require(len(cfg['test_datasets']) > 0, "No datasets specified")

    # check names
    for key in ('loss_name', 'attack', 'codec'):
        _require(isinstance(cfg.get(key), str), f"{key} must be specified as a string")

    # check presets
    for key in ('attack_preset', 'defence_preset'):
        _require(_is_int(cfg.get(key)), f"{key} must be specified as an integer")

    # check other params
    device = cfg.get('device')
    # The isinstance guard keeps the substring test from raising TypeError on non-strings.
    _require(isinstance(device, str) and (device == 'cpu' or 'cuda' in device),
             f"device must be 'cpu' or a cuda device string, got {device!r}")
    for freq in freqs:
        value = cfg.get(freq)
        _require(_is_int(value) and value > 0,
                 f"{freq} must be a positive integer, got {value!r}")
    for key in ('run_all_presets', 'only_default_preset'):
        _require(isinstance(cfg.get(key), bool), f"{key} must be specified as a bool")
    batch_size = cfg.get('batch_size')
    _require(_is_int(batch_size) and batch_size > 0,
             f"batch_size must be a positive integer, got {batch_size!r}")

    for p in save_paths:
        # Check membership explicitly: a missing key is a validation failure,
        # not a KeyError and not an implicit None.
        _require(p in cfg and (cfg[p] is None or isinstance(cfg[p], str)),
                 f"{p} must be specified as None or a string")




def _load_yaml_config(path):
    """Load the YAML file at *path* and return its top-level mapping.

    Raises:
        OSError: If the file cannot be opened (propagated from ``open``).
        ValueError: If the file is not valid YAML, or its top level is not a
            mapping (e.g. the file is empty, which yields None).
    """
    # Binary mode lets PyYAML detect the encoding (UTF-8 / UTF-16 BOM) itself
    # rather than depending on the platform's default text encoding.
    with open(path, 'rb') as stream:
        try:
            run_cfg = yaml.safe_load(stream)
        except yaml.YAMLError as exc:
            print(exc)
            raise ValueError('YAML config openning failed!') from exc
    if not isinstance(run_cfg, dict):
        # Without this guard an empty file (None) or a top-level list would
        # crash later with TypeError/AttributeError instead of a clear error.
        raise ValueError(
            f'YAML config {path!r} must contain a mapping at the top level, '
            f'got {type(run_cfg).__name__}.'
        )
    print(f'==== YAML config loaded successfully ====\n{run_cfg}')
    return run_cfg


def _apply_cli_overrides(run_cfg, args):
    """Overwrite keys in *run_cfg* with every CLI value that was explicitly given."""
    # Order determines only the order in which warnings are printed.
    for key in ('attack', 'codec', 'attack_preset', 'loss_name'):
        value = getattr(args, key)
        if value is not None:
            print(f'[Warning] {key} set via cli argument, value in run_cfg ignored.')
            run_cfg[key] = value


def _apply_defaults(run_cfg):
    """Fill optional keys missing from *run_cfg* with the module-level defaults."""
    for p in save_paths:
        if p not in run_cfg:
            print(f'[Warning] {p} not specified, setting to {DEFAULT_PATH}.')
            run_cfg[p] = DEFAULT_PATH

    if 'device' not in run_cfg:
        print(f'[Warning] device not specified, setting to {DEFAULT_DEVICE}.')
        run_cfg['device'] = DEFAULT_DEVICE

    for freq in freqs:
        if freq not in run_cfg:
            print(f'[Warning] {freq} not specified, setting to {DEFAULT_FREQ}.')
            run_cfg[freq] = DEFAULT_FREQ

    run_cfg.setdefault('run_all_presets', False)
    run_cfg.setdefault('only_default_preset', False)


def get_run_config(argv=None):
    """Parse the YAML config and command-line arguments into a run configuration.

    Loads a configuration from the YAML file given by ``--config``, overrides
    YAML values with any CLI arguments provided (``--codec``, ``--attack``,
    ``--attack_preset``, ``--loss_name``), fills defaults for optional keys,
    forces ``batch_size`` to 1, and validates the result.

    Args:
        argv (list[str] | None): Arguments to parse. Defaults to None, in which
            case argparse reads ``sys.argv[1:]``.

    Returns:
        dict: The validated run configuration dictionary.

    Raises:
        OSError: If the YAML configuration file cannot be opened.
        ValueError: If the YAML file cannot be parsed or its top level is not a mapping.
        AssertionError: If the resulting configuration fails validation.
        SystemExit: Raised by argparse on invalid or missing CLI arguments.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--codec", type=str, default=None)
    parser.add_argument("--attack", type=str, default=None)
    parser.add_argument("--attack_preset", type=int, default=None)
    parser.add_argument("--loss_name", type=str, default=None)
    parser.add_argument("--config", type=str, required=True)
    args = parser.parse_args(argv)

    run_cfg = _load_yaml_config(args.config)
    _apply_cli_overrides(run_cfg, args)
    _apply_defaults(run_cfg)

    # For now only batch size 1 is supported, due to processing images of
    # different resolutions. Warn instead of silently discarding a configured value.
    if run_cfg.get('batch_size', 1) != 1:
        print(f"[Warning] batch_size={run_cfg['batch_size']} not supported, forcing to 1.")
    run_cfg['batch_size'] = 1

    validate_config(run_cfg)
    print('==== YAML config validated successfully ====')
    return run_cfg