"""
Traffic Flow - Ablation Experiments

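Model families available for the ablation experiments (the exact experiment
names, including the training-mode suffix, are listed in ALL_AVAILABLE_MODELS):

FluxNet model series:
- FluxNet_N_1D: No constraint (formerly named FluxNet_U_1D; that name now
  denotes the upper-bound variant below)
- FluxNet_P_1D: Positive flux constraint (softplus)
- FluxNet_L_1D: Lower bound constraint
- FluxNet_U_1D: Upper bound constraint (ablation)
- FluxNet_D_1D: Double bound constraint (optimal)
- FluxNet_D_1D_no_dcl: Double bound constraint without the DCL loss (ablation)

FNO model series:
- FNO_1D: Standard FNO
- FNO_1D_soft: With soft conservation loss
- FNO_1D_bound_soft: With double bound constraint (sigmoid) and soft conservation loss
- FNO_FluxD_1D: FNO backbone with the FluxNet-D flux head

Baseline model series:
- CNN_1D_direct: Direct prediction
- CNN_1D_residual: Residual prediction
- CNN_1D_direct_soft / CNN_1D_residual_soft: With soft conservation loss
- CNN_1D_bound: With double bound constraint (sigmoid)
- CNN_1D_bound_soft: With double bound constraint (sigmoid) and soft conservation loss

Each model can be trained in one of two modes, selected by the name suffix:
- onestep training (suffix `_onestep`)
- pushforward training (suffix `_pf`)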
"""

import os
import sys

# Put the project root at the front of sys.path before the project import
# below (presumably so `experiments.*` resolves when this file is run
# directly). The three dirname() calls climb from this file's absolute path to
# the directory two levels above the folder that contains this file.
project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.insert(0, project_root)

from experiments.common.experiment_runner import (
    run_single_experiment, generate_summary_table, ModelConfig, TrainingConfig
)

# ============================================================================
# Complete list of available models (select from these for experiments)
# ============================================================================
# Base model names of each family, in listing order. Every base name is
# offered with both training-mode suffixes.
_FLUXNET_BASES = (
    'FluxNet_N_1D',          # No constraint (deprecated)
    'FluxNet_P_1D',          # Positive flux (softplus)
    'FluxNet_L_1D',          # Lower bound only
    'FluxNet_D_1D',          # Double bound (RECOMMENDED for traffic flow)
    'FluxNet_U_1D',          # Upper bound only (ablation)
    'FluxNet_D_1D_no_dcl',   # D without DCL loss (ablation)
)
_FNO_BASES = (
    'FNO_1D',                # Standard FNO
    'FNO_1D_soft',           # FNO + soft conservation loss
    'FNO_1D_bound_soft',     # FNO + bound (sigmoid) + soft conservation
    'FNO_FluxD_1D',          # FNO with FluxNet-D head (FNO backbone + flux head)
)
_CNN_BASES = (
    'CNN_1D_direct',
    'CNN_1D_residual',
    'CNN_1D_direct_soft',
    'CNN_1D_residual_soft',
    'CNN_1D_bound',          # sigmoid bound
    'CNN_1D_bound_soft',     # sigmoid bound + soft conservation
)
# Training modes: one-step first, then pushforward.
_TRAINING_SUFFIXES = ('onestep', 'pf')

# Ordered family by family; within a family all onestep names come first,
# followed by all pushforward names.
ALL_AVAILABLE_MODELS = [
    f'{base}_{suffix}'
    for family in (_FLUXNET_BASES, _FNO_BASES, _CNN_BASES)
    for suffix in _TRAINING_SUFFIXES
    for base in family
]

# ============================================================================
# List of models to actually run
# Baseline comparison and ablation experiments as specified in issue126.md
# ============================================================================
# Our method (FluxNet-D): double bound constraint + pushforward.
_OUR_METHOD = [
    'FluxNet_D_1D_pf',
]

# Baseline comparison.
_BASELINE_MODELS = [
    'FNO_1D_pf',                    # FNO baseline
    'FNO_1D_soft_pf',               # FNO + soft conservation loss
    'FNO_1D_bound_soft_pf',         # FNO + bounded (sigmoid) + soft conservation loss
    'CNN_1D_residual_pf',           # CNN baseline (residual prediction)
    'CNN_1D_residual_soft_pf',      # CNN + soft conservation loss
    'CNN_1D_bound_soft_pf',         # CNN + bounded (sigmoid) + soft conservation loss
]

# Ablation experiments.
_ABLATION_MODELS = [
    'FluxNet_P_1D_pf',              # P-head (positive flux)
    'FluxNet_L_1D_pf',              # L-head (lower bound only)
    'FluxNet_U_1D_pf',              # U-head (upper bound only)
    'FluxNet_D_1D_no_dcl_pf',       # D-head without dual consistency loss
    'FluxNet_D_1D_onestep',         # D-head without pushforward
    'FNO_FluxD_1D_pf',              # FNO backbone + D-head
]

# Run order: our method first, then the baselines, then the ablations.
SELECTED_MODELS = [*_OUR_METHOD, *_BASELINE_MODELS, *_ABLATION_MODELS]


# Bound arguments passed to ModelConfig for each FluxNet head type. The five
# prefixes differ in a single character, so at most one can match a given name.
_FLUXNET_HEAD_BOUNDS = {
    'FluxNet_N_1D': {},
    'FluxNet_P_1D': {},
    'FluxNet_L_1D': {'lower_bound': 0.0},
    'FluxNet_D_1D': {'lower_bound': 0.0, 'upper_bound': 1.0},
    'FluxNet_U_1D': {'upper_bound': 1.0},
}


def _conv_arch_kwargs(hparams: dict) -> dict:
    """Return the convolutional architecture arguments used by FluxNet and CNN models."""
    return {
        'base_channels': hparams['base_channels'],
        'num_blocks': hparams['num_blocks'],
        'kernel_size': hparams['kernel_size'],
    }


def _fno_arch_kwargs(hparams: dict) -> dict:
    """Return the FNO architecture arguments, with defaults for missing keys."""
    return {
        'modes': hparams.get('fno_modes', 16),
        'width': hparams.get('fno_width', 64),
        'num_layers': hparams.get('fno_layers', 4),
    }


def _base_training_kwargs(hparams: dict, use_pf: bool) -> dict:
    """Return the TrainingConfig arguments that every experiment shares."""
    return {
        'num_epochs': hparams['num_epochs'],
        'batch_size': hparams['batch_size'],
        'learning_rate': hparams['learning_rate'],
        'weight_decay': hparams['weight_decay'],
        'ndt': hparams['ndt'],
        'num_workers': hparams['num_workers'],
        'use_pushforward': use_pf,
        'unroll_steps': hparams['unroll_steps'],
    }


def get_experiment_config(model_name: str, hparams: dict) -> dict:
    """Generate the experiment configuration for a model name.

    The model family is selected by the prefix of ``model_name`` and the
    training mode by whether the name contains ``'_pf'`` (pushforward).

    Args:
        model_name: Experiment name such as ``'FluxNet_D_1D_pf'``.
        hparams: Hyperparameter dictionary. The training keys (``num_epochs``,
            ``batch_size``, ``learning_rate``, ``weight_decay``, ``ndt``,
            ``num_workers``, ``unroll_steps``) are always required. FluxNet
            and CNN models also require ``base_channels``, ``num_blocks`` and
            ``kernel_size``; FluxNet and FNO_FluxD models require
            ``neighborhood_size``. ``fno_modes``, ``fno_width``,
            ``fno_layers``, ``soft_cons_weight``, ``dcl_weight``,
            ``loss_weight_mode`` and ``loss_weights`` are optional.

    Returns:
        A dict with keys ``'name'`` (the model name), ``'model_config'``
        (a ``ModelConfig``) and ``'training_config'`` (a ``TrainingConfig``).

    Raises:
        ValueError: If ``model_name`` matches no known model family.
        KeyError: If a required hyperparameter is missing from ``hparams``.
    """
    use_pf = '_pf' in model_name
    # TrainingConfig arguments beyond the shared ones; stays empty for models
    # that only need the common training setup.
    extra_training = {}

    fluxnet_type = next(
        (head for head in _FLUXNET_HEAD_BOUNDS if model_name.startswith(head)),
        None,
    )

    # FluxNet series
    if fluxnet_type is not None:
        model_type = fluxnet_type
        model_kwargs = {
            **_conv_arch_kwargs(hparams),
            'neighborhood_size': hparams['neighborhood_size'],
            **_FLUXNET_HEAD_BOUNDS[fluxnet_type],
        }
        if model_name.startswith('FluxNet_D_1D_no_dcl'):
            # Ablation: D-head without DCL loss. The manual weights must not
            # carry a 'dcl_loss' entry (the shared hparams usually do),
            # otherwise the term could be re-enabled despite dcl_weight=0.0.
            manual_weights = hparams.get(
                'loss_weights', {'p_loss': 1.0, 'stability_loss': 0.5}
            )
            extra_training = {
                'dcl_weight': 0.0,  # Disable DCL loss
                'loss_weight_mode': 'manual',
                'loss_weights': {
                    term: weight
                    for term, weight in manual_weights.items()
                    if term != 'dcl_loss'
                },
            }
        elif fluxnet_type == 'FluxNet_D_1D':
            extra_training = {
                'dcl_weight': hparams.get('dcl_weight', 0.1),
                'loss_weight_mode': hparams.get('loss_weight_mode', 'adaptive'),
                'loss_weights': hparams.get('loss_weights', {}),
            }

    # FNO backbone + FluxNet-D head (checked before the plain FNO prefix)
    elif model_name.startswith('FNO_FluxD_1D'):
        model_type = 'FNO_FluxD_1D'
        model_kwargs = {
            **_fno_arch_kwargs(hparams),
            'neighborhood_size': hparams['neighborhood_size'],
            'lower_bound': 0.0,
            'upper_bound': 1.0,
        }

    # FNO series: plain, + soft conservation, + bound (sigmoid) + soft conservation
    elif model_name.startswith('FNO_1D'):
        model_type = 'FNO_1D'
        model_kwargs = {**_fno_arch_kwargs(hparams), 'prediction_mode': 'residual'}
        if model_name.startswith('FNO_1D_bound_soft'):
            # sigmoid bounded to [0,1]
            model_kwargs.update(bound_mode='double', lower_bound=0.0, upper_bound=1.0)
        if model_name.startswith(('FNO_1D_bound_soft', 'FNO_1D_soft')):
            extra_training = {
                'soft_conservation_weight': hparams.get('soft_cons_weight', 0.1),
            }

    # CNN Baseline series
    elif model_name.startswith('CNN_1D'):
        model_type = 'CNN_Baseline_1D'
        model_kwargs = {
            **_conv_arch_kwargs(hparams),
            'prediction_mode': 'residual' if 'residual' in model_name else 'direct',
            'bound_mode': 'double' if 'bound' in model_name else 'none',
            'lower_bound': 0.0,
            'upper_bound': 1.0,
        }
        # Use the configured soft-conservation weight, as the FNO baselines
        # do, so that all baselines are trained with the same weight.
        if 'soft' in model_name:
            soft_cons = hparams.get('soft_cons_weight', 0.1)
        else:
            soft_cons = 0.0
        extra_training = {'soft_conservation_weight': soft_cons}

    else:
        raise ValueError(f"Unknown model: {model_name}")

    return {
        'name': model_name,
        'model_config': ModelConfig(model_type=model_type, **model_kwargs),
        'training_config': TrainingConfig(
            **_base_training_kwargs(hparams, use_pf),
            **extra_training,
        ),
    }


def main():
    """Run the selected traffic-flow ablation experiments and summarize them.

    Builds one configuration per entry of ``SELECTED_MODELS``, runs each
    experiment in turn with ``run_single_experiment``, then calls
    ``generate_summary_table``. An experiment that raises is reported with its
    traceback and skipped so that the remaining experiments still run; the
    final message lists any experiments that failed.
    """
    import traceback

    # ========================================================================
    # Path configuration (using absolute paths)
    # NOTE(review): these paths are specific to one machine/user account.
    # ========================================================================
    save_path = "/home/ml4pf/zshlan/FluxNet/results/traffic_flow/ablation"
    train_folder = "/home/ml4pf/zshlan/FluxNet/dataset/traffic_flow/train"
    val_folder = "/home/ml4pf/zshlan/FluxNet/dataset/traffic_flow/val"
    test_folder = "/home/ml4pf/zshlan/FluxNet/dataset/traffic_flow/test"

    # ========================================================================
    # Hyperparameter configuration
    # ========================================================================
    hparams = {
        # Model architecture parameters
        'base_channels': 64,
        'num_blocks': 6,
        'kernel_size': 5,
        'neighborhood_size': 11,
        # Training parameters
        'num_epochs': 100,
        'batch_size': 16,
        'learning_rate': 1e-3,
        'weight_decay': 1e-2,
        'ndt': 1,
        'num_workers': 4,
        'unroll_steps': 5,
        # FNO specific
        'fno_modes': 16,
        'fno_width': 64,
        'fno_layers': 4,
        # ================================================================
        # Loss weight configuration
        # ================================================================
        # Global loss weight mode: 'adaptive' or 'manual'
        'loss_weight_mode': 'adaptive',
        # DCL loss weight (for FluxNet-D series)
        'dcl_weight': 0.1,
        # Soft conservation loss weight (for baseline models)
        'soft_cons_weight': 0.1,
        # Manual loss weights (takes effect when loss_weight_mode='manual')
        # Available loss terms: p_loss (prediction), dcl_loss (dual consistency),
        # stability_loss (stability), cons_loss (conservation)
        'loss_weights': {
            'p_loss': 1.0,
            'dcl_loss': 0.1,
            'stability_loss': 0.5,
            'cons_loss': 0.1,
        },
    }

    # ========================================================================
    # Experiment control
    # ========================================================================
    gpu_id = 1
    seed = 42
    run_training = True
    run_evaluation = True
    evaluate_mode = 'both'
    visualize_trajectories = 'all'
    # visualize_trajectories = None

    # ========================================================================
    # Generate experiment configurations
    # ========================================================================
    available_models = set(ALL_AVAILABLE_MODELS)
    ablation_experiments = []
    for model_name in SELECTED_MODELS:
        if model_name not in available_models:
            print(f"Warning: Model '{model_name}' is not in the available list")
            continue
        try:
            exp_config = get_experiment_config(model_name, hparams)
        except ValueError as e:
            print(f"Skipping model {model_name}: {e}")
        else:
            ablation_experiments.append(exp_config)

    total = len(ablation_experiments)
    print(f"\nWill run {total} ablation experiments:")
    for exp in ablation_experiments:
        print(f"  - {exp['name']}")
    print()

    # ========================================================================
    # Run experiments
    # ========================================================================
    results = []
    failed_names = []

    for i, exp in enumerate(ablation_experiments, start=1):
        print(f"\n{'#'*80}")
        print(f"# Ablation Experiment {i}/{total}: {exp['name']}")
        print(f"{'#'*80}\n")

        try:
            result = run_single_experiment(
                model_config=exp['model_config'],
                training_config=exp['training_config'],
                dataset_type='traffic_flow',
                train_folder=train_folder,
                val_folder=val_folder,
                test_folder=test_folder,
                save_path=save_path,
                experiment_name=exp['name'],
                gpu_id=gpu_id,
                run_training=run_training,
                run_evaluation=run_evaluation,
                seed=seed,
                evaluate_mode=evaluate_mode,
                visualize_trajectories=visualize_trajectories
            )
        except Exception as e:
            # Best effort: one failing experiment must not abort the rest of
            # the sweep, but it is recorded for the final report.
            print(f"Experiment failed: {e}")
            traceback.print_exc()
            failed_names.append(exp['name'])
            continue
        results.append(result)

    # ========================================================================
    # Generate summary table
    # ========================================================================
    print("\n" + "="*80)
    generate_summary_table(save_path, ablation_experiments, "ablation_summary.md")
    if failed_names:
        # Do not claim overall success when some experiments raised.
        print(f"\n{len(results)}/{total} ablation experiments completed; "
              f"{len(failed_names)} failed: {', '.join(failed_names)}")
        print(f"Results saved at: {save_path}")
    else:
        print(f"\nAll ablation experiments completed! Results saved at: {save_path}")


# Run the experiments only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()
