import argparse
from pathlib import Path
from typing import Any, Dict, Optional, Sequence

import yaml


def _as_int_list(value: Any) -> list:
    """Return a list with every element of ``value`` converted to ``int``."""
    return [int(item) for item in value]


def _as_int_or_int_list(value: Any) -> Any:
    """Convert a list element-wise to ``int``; convert anything else with ``int``."""
    if isinstance(value, list):
        return [int(item) for item in value]
    return int(value)


# Section name -> {key -> converter}. Only keys listed here are touched; every
# other entry of the configuration is returned exactly as the YAML parser
# produced it.
_CONFIG_CASTS: Dict[str, Dict[str, Any]] = {
    'training': {
        'learning_rate': float,
        'weight_decay': float,
        'beta': float,
        'epochs': int,
        'early_stopping_patience': int,
        'save_frequency': int,
        # IWAE k can be int or list[int]
        'iwae_k': _as_int_or_int_list,
    },
    'data': {
        'train_size': int,
        'test_size': int,
        'batch_size': int,
        'num_workers': int,
        'leave_one_out_ratio': float,
    },
    'model': {
        'latent_dim': int,
        'input_dim': int,
        # Optional architecture selector
        'arch': str,
        'in_channels': int,
        # Hierarchical latent dims list (optional)
        'latent_dims': _as_int_list,
        # Optional encoder/decoder hidden dims
        'encoder_hidden_dims': _as_int_list,
        'decoder_hidden_dims': _as_int_list,
        # CNN channel lists
        'encoder_channels': _as_int_list,
        'decoder_channels': _as_int_list,
    },
    'validation': {
        'random_seed': int,
    },
}


def load_config(config_path: str) -> Dict[str, Any]:
    """Load a configuration from a YAML file and normalize known value types.

    The file is parsed with ``yaml.safe_load``. Known keys in the
    ``training``, ``data``, ``model`` and ``validation`` sections are then
    converted in place to ``int``, ``float``, ``str`` or ``list[int]`` as
    listed in ``_CONFIG_CASTS``. This guards against values the YAML parser
    hands back with another type (for example a number written in a form the
    parser reads as a string).

    Args:
        config_path: Path of the YAML file to read (decoded as UTF-8).

    Returns:
        The parsed configuration. An empty file yields an empty dict.
        Sections that are absent, null or not mappings are left as they are,
        and so are individual values that are absent or null.

    Raises:
        OSError: If the file cannot be opened.
        yaml.YAMLError: If the file is not valid YAML.
        TypeError: If the top level of the document is not a mapping, or a
            value cannot be converted to its expected type.
        ValueError: If a value cannot be converted to its expected type.
    """
    with open(config_path, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)

    # safe_load returns None for an empty document; treat it as "no settings"
    # instead of failing on the membership tests below.
    if config is None:
        return {}
    if not isinstance(config, dict):
        raise TypeError(
            f"Top level of config file {config_path!r} must be a mapping, "
            f"got {type(config).__name__}"
        )

    for section_name, casts in _CONFIG_CASTS.items():
        section = config.get(section_name)
        # A section written without a body (e.g. "training:") parses to None.
        if not isinstance(section, dict):
            continue
        for key, cast in casts.items():
            value = section.get(key)
            # Null means "not set": keep it rather than failing in int()/float().
            if value is not None:
                section[key] = cast(value)

    return config


def save_config(config: Dict[str, Any], save_path: str) -> None:
    """Write ``config`` to ``save_path`` as YAML.

    The file is opened in write mode, so any existing content is replaced.
    ``default_flow_style=False`` asks the dumper for block-style output.

    Args:
        config: Configuration mapping to serialize.
        save_path: Destination file path.
    """
    with open(save_path, mode='w') as out_file:
        yaml.dump(config, stream=out_file, default_flow_style=False)


def parse_args(argv: Optional[Sequence[str]] = None) -> argparse.Namespace:
    """Parse command line arguments.

    Args:
        argv: Argument strings to parse, without the program name. When
            ``None`` (the default) argparse reads ``sys.argv[1:]``, which is
            the behavior of calling this function with no arguments.

    Returns:
        Namespace with ``config`` (required path), ``device`` (``'cpu'``,
        ``'cuda'`` or ``None``), ``train_size`` (``int`` or ``None``) and
        ``leave_one_out_ratio`` (``float`` or ``None``).

    Raises:
        SystemExit: Raised by argparse on invalid arguments or ``--help``.
    """
    parser = argparse.ArgumentParser(description='VAE MNIST Experiment')
    parser.add_argument(
        '--config', type=str, required=True,
        help='Path to configuration file',
    )
    parser.add_argument(
        '--device', type=str, choices=['cpu', 'cuda'], default=None,
        help='Force device selection (cpu or cuda). If omitted, auto-detect. '
             'You can also use env VAE_DEVICE.',
    )
    # The two overrides below default to None so "not given" can be told
    # apart from any real value.
    parser.add_argument(
        '--train_size', type=int, default=None,
        help='Override training data size',
    )
    parser.add_argument(
        '--leave_one_out_ratio', type=float, default=None,
        help='Override leave-one-out ratio',
    )
    return parser.parse_args(argv)


def update_config_with_args(config: Dict[str, Any], args: argparse.Namespace) -> Dict[str, Any]:
    """Apply command line overrides to the ``data`` section of ``config``.

    ``train_size`` and ``leave_one_out_ratio`` are copied from ``args`` into
    ``config['data']`` when they are not ``None``. The ``data`` section is
    created if it is missing or null, but only when there is an override to
    store, so a call without overrides leaves ``config`` untouched.

    Args:
        config: Configuration mapping; modified in place.
        args: Parsed arguments. An override attribute that is missing from
            the namespace is treated like ``None``.

    Returns:
        The same ``config`` object that was passed in.
    """
    for key in ('train_size', 'leave_one_out_ratio'):
        value = getattr(args, key, None)
        if value is None:
            continue
        # The config file may omit the data section or leave it empty.
        if config.get('data') is None:
            config['data'] = {}
        config['data'][key] = value
    return config