import os
from typing import Dict, Any

# TODO: May need to just add fields to the existing config instead of creating a new one.
def build_experiment_config(args) -> Dict[str, Any]:
    """Build the flat experiment configuration dict from parsed arguments.

    Values are read from ``args`` and from the dataset section ``args.dataset``.
    Options that may come from either place are resolved in this order:
    a non-``None`` attribute on ``args``, then a non-``None`` entry in the
    dataset config, then a built-in default.

    Args:
        args: Object exposing the run options as attributes. It must provide
            ``dataset`` (mapping-like: supports ``[]``, ``.get``, ``in`` and
            attribute access for ``path`` / ``name``), ``dynamics`` (attribute
            access), ``wandb`` (supports ``.get``), and the training/experiment
            attributes read directly below (``epochs``, ``lr``, ...).

    Returns:
        A dict holding dataset, model, training, logging and IC-solver settings.

    Raises:
        ValueError: If the dataset config defines no neuron populations.
        AttributeError: If a required attribute is missing from ``args``.
    """
    dataset_cfg = args.dataset

    # Each population maps to a (start, end) index range; normalise to tuples.
    neurons = {name: tuple(bounds) for name, bounds in dataset_cfg['neurons'].items()}
    if not neurons:
        raise ValueError("dataset config 'neurons' must define at least one population")

    # The input size is the largest end index over all populations.
    input_size = max(end for _start, end in neurons.values())

    dataset_path = os.path.join(args.data_base_path, dataset_cfg.path)

    # Latent sizes: use the per-population sizes when given (total = their sum),
    # otherwise give every population a single latent dimension.
    population_latent_sizes = dataset_cfg.get('population_latent_sizes', None)
    if population_latent_sizes:
        latent_size = sum(population_latent_sizes.values())
    else:
        population_latent_sizes = {pop: 1 for pop in neurons}
        latent_size = len(population_latent_sizes)

    # Support both 'stimulated_region_name' (newer) and 'stimulated_region'.
    stimulated_region_name = None
    if 'stimulated_region_name' in dataset_cfg:
        stimulated_region_name = dataset_cfg.get('stimulated_region_name')
    elif 'stimulated_region' in dataset_cfg:
        stimulated_region_name = dataset_cfg.get('stimulated_region')
    stimulated_populations = _match_stimulated_populations(neurons, stimulated_region_name)

    # Windowing defaults: the fallback window length is the dataset's 'total_t',
    # then its 'window_size', then 0 (no legacy timepoint fallback).
    default_time = dataset_cfg.get('total_t', None)
    if default_time is None:
        default_time = dataset_cfg.get('window_size', None)
    default_time = int(default_time) if default_time is not None else 0

    window_size = int(_resolve_option(args, 'window_size', default_time, dataset_cfg))
    step_size = int(_resolve_option(args, 'step_size', window_size, dataset_cfg))
    causal_model = bool(_resolve_option(args, 'causal_model', False, dataset_cfg))
    ic_window_size = int(
        _resolve_option(args, 'ic_window_size', max(1, window_size // 2), dataset_cfg)
    )
    seed = int(_resolve_option(args, 'seed', 0, dataset_cfg))

    # Training schedule defaults for the progressive loss window.
    epochs_per_group = int(_resolve_option(args, 'epochs_per_group', 300, dataset_cfg))
    points_per_group = int(_resolve_option(args, 'points_per_group', 20000, dataset_cfg))

    return {
        # Dataset information
        'data_type': dataset_cfg.name,
        'dataset_path': dataset_path,
        'neurons': neurons,
        'input_size': input_size,

        'causal_model': causal_model,
        'solver_type': args.solver,
        'dynamics_dt': args.dynamics_dt,
        'window_size': window_size,
        'step_size': step_size,
        'ic_window_size': ic_window_size,
        'seed': seed,

        # Dynamics architecture
        'dynamics_model_type': args.dynamics.dynamics_model,
        'dynamics_monotonic': args.dynamics.dynamics_monotonic,
        'dynamics_hidden_dim': args.dynamics.dynamics_hidden_dim,
        'dynamics_nonlinearity': args.dynamics.dynamics_nonlinearity,
        'compositional_func': args.dynamics.compositional_func,

        # Readout architecture
        'readout_type': args.dynamics.readout,
        'force_percent': args.dynamics.force_percent,
        'nonlinear_readout': args.dynamics.nonlinear_readout,
        'readout_hidden_dim': args.dynamics.readout_hidden_dim,
        'readout_nonlinearity': args.dynamics.readout_nonlinearity,

        # Latent space configuration
        'latent_size': latent_size,
        'encoder_size': args.encoder_size,
        'population_latent_sizes': population_latent_sizes,
        'stimulated_populations': stimulated_populations,
        'stimulated_region_name': stimulated_region_name,

        # Training parameters. The *_per_group entries appear exactly once and
        # carry the resolved values, so the dataset/default fallbacks apply.
        'epochs': args.epochs,
        'epochs_per_group': epochs_per_group,
        'points_per_group': points_per_group,
        'lr': args.lr,
        'lr_scheduler': args.lr_scheduler,
        'batch_size': args.batch_size,
        'dropout': args.dropout,
        'weight_decay': args.weight_decay,

        # Regularization
        'l1_reg': args.l1_reg,
        'orth_reg': args.orth_reg,  # Orthogonal regularization
        'noise': args.noise,
        # External input channels expected by the dynamics cell (always 1 now)
        'external_input_size': 1,

        # Experiment settings
        'heatmap': args.heatmap,
        'plot': args.plot,
        'save_plots': args.save_plots,
        'output_dir': args.output_dir,
        'validate_only': args.validate_only,

        # Weights and Biases (wandb) settings
        'use_wandb': getattr(args, 'use_wandb', False),
        'wandb_project': args.wandb.get('project', None),
        'wandb_entity': args.wandb.get('entity', None),
        'wandb_name': args.wandb.get('name', None),
        'wandb_dir': args.wandb.get('dir', './results/wandb_logs'),
        'wandb_mode': args.wandb.get('mode', 'online'),

        # Additional dataset-specific settings
        'zscore': dataset_cfg.get('zscore', False),
        'scalers_path': dataset_cfg.get('scalers_path', None),

        # IC solver (validation-time) options; read from args only.
        'use_ic_solver': bool(_resolve_option(args, 'use_ic_solver', False)),
        'ic_steps': int(_resolve_option(args, 'ic_steps', 200)),
        'ic_l2': float(_resolve_option(args, 'ic_l2', 1e-2)),
    }


def _resolve_option(args, key, default, dataset_cfg=None):
    """Return the first non-None value of ``args.<key>``, ``dataset_cfg[key]``, ``default``.

    An attribute that exists on ``args`` but is ``None`` is treated as unset,
    so the lookup continues to the dataset config (when given) and the default.
    """
    value = getattr(args, key, None)
    if value is None and dataset_cfg is not None:
        value = dataset_cfg.get(key, None)
    return default if value is None else value


def _match_stimulated_populations(neurons, region_name):
    """Return population names whose prefix matches ``region_name``, or None.

    A name matches when it starts with the region key as given, or when the
    upper-cased or lower-cased forms match (e.g. 'cfa' matches 'CFA_*').
    Returns None when no region is given, nothing matches, or a population
    name does not support string prefix matching (best-effort lookup).
    """
    if not region_name:
        return None
    region_key = str(region_name)
    upper_key = region_key.upper()
    lower_key = region_key.lower()
    try:
        matches = [
            name for name in neurons
            if name.startswith(region_key)
            or name.upper().startswith(upper_key)
            or name.lower().startswith(lower_key)
        ]
    except (AttributeError, TypeError):
        # Non-string population names: give up rather than fail the whole build.
        return None
    return matches or None


def get_experiment_name(config: Dict[str, Any], custom_name: str = None) -> str:
    """Generate an experiment name from the configuration.

    Args:
        config: Experiment configuration; the keys 'epochs', 'latent_size',
            'l1_reg' and 'force_percent' are read.
        custom_name: If non-empty, returned unchanged instead of a generated name.

    Returns:
        ``custom_name`` when given, otherwise a dash-joined name of the form
        ``"<epochs>E-<latent_size>L"`` followed by ``"L1-<l1_reg>"`` and
        ``"F<percent>"`` when those settings are truthy.

    Raises:
        KeyError: If a required key is missing from ``config`` (only when no
            custom name is given).
    """
    if custom_name:
        return custom_name

    # Always-present parts: epoch count and latent size.
    name_parts = [f"{config['epochs']}E", f"{config['latent_size']}L"]

    # Optional regularization info.
    l1_reg = config['l1_reg']
    if l1_reg:
        name_parts.append(f"L1-{l1_reg}")

    force_percent = config['force_percent']
    if force_percent:
        # Round away binary floating-point noise before truncating to a whole
        # percent: 0.29 * 100 is 28.999999999999996, which plain int() would
        # turn into 28. Genuinely fractional percents still truncate.
        percent = int(round(force_percent * 100, 9))
        name_parts.append(f"F{percent}")

    return "-".join(name_parts)