"""
Shallow Water Equations - Ablation Experiments

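Complete list of available model families (select from these for ablation
experiments; ALL_AVAILABLE_MODELS below holds the exact run names):

FluxNet model series:
- FluxNet_SW_LAP: L-head for h + Advection-Pressure decomposition (RECOMMENDED)
- FluxNet_SW_PAP: P-head for h + Advection-Pressure decomposition (ablation)
- FluxNet_SW_LAP_no_gate: LAP without h^2 pressure gate (ablation)
- FluxNet_SW_PPP: P-head for all fields (ablation)
- FluxNet_SW_LPP: L-head for h, P-head for mx/my (ablation)
- FluxNet_SW_NNN: No constraint (deprecated)

Baseline model series (CNN):
- SW_Baseline_direct: Direct prediction
- SW_Baseline_residual: Residual prediction
- SW_Baseline_bound: Direct prediction with h lower bound constraint
- SW_Baseline_<direct|residual|bound>_soft: The same baseline plus a soft
  conservation loss

Other models:
- FNO_SW: FNO baseline
- FNO_SW_soft: FNO + soft conservation loss
- FNO_SW_Proj_box / FNO_SW_Proj_box_mass: FNO + box projection (h >= 0) /
  box + mass projection
- CNN_SW_Proj_box / CNN_SW_Proj_box_mass: CNN + box projection /
  box + mass projection
- FNO_FluxLAP: FNO backbone with the LAP conservation head (ablation)

Each model can be trained in one of two ways, selected by the name suffix:
- `_onestep`: onestep training
- `_pf`: pushforward training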
"""

import os
import sys

project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.insert(0, project_root)

from experiments.common.experiment_runner import (
    run_single_experiment, generate_summary_table, ModelConfig, TrainingConfig
)

# ============================================================================
# Complete list of available models (select from these for experiments)
# ============================================================================
# Every family below is registered once per training scheme:
# '<family>_onestep' (onestep training) and '<family>_pf' (pushforward).
# Names are expanded group by group. Within a group, all '_onestep' names
# come before all '_pf' names.
ALL_AVAILABLE_MODELS = [
    f'{family}_{scheme}'
    for group in (
        # FluxNet series (recommended: LAP)
        (
            'FluxNet_SW_LAP',          # RECOMMENDED: L-head + Advection-Pressure
            'FluxNet_SW_PAP',          # ablation: P-head + Advection-Pressure
            'FluxNet_SW_LAP_no_gate',  # ablation: LAP without h^2 gate
            'FluxNet_SW_PPP',          # ablation: all P-head
            'FluxNet_SW_LPP',          # ablation: L-head for h, P-head for m
            'FluxNet_SW_NNN',          # deprecated: no constraint
        ),
        # Baseline series (old CNN direct)
        (
            'SW_Baseline_direct',
            'SW_Baseline_residual',
            'SW_Baseline_direct_soft',
            'SW_Baseline_residual_soft',
            'SW_Baseline_bound',
            'SW_Baseline_bound_soft',
        ),
        # FNO (no projection)
        ('FNO_SW',),
        # FNO + soft conservation loss
        ('FNO_SW_soft',),
        # FNO + Projection (strong baseline)
        (
            'FNO_SW_Proj_box',         # FNO + box projection (h>=0)
            'FNO_SW_Proj_box_mass',    # FNO + box + mass projection (RECOMMENDED baseline)
        ),
        # CNN + Projection (strong baseline)
        (
            'CNN_SW_Proj_box',         # CNN + box projection
            'CNN_SW_Proj_box_mass',    # CNN + box + mass projection (RECOMMENDED baseline)
        ),
        # FNO + FluxLAP head (ablation: FNO backbone with LAP head)
        ('FNO_FluxLAP',),
    )
    for scheme in ('onestep', 'pf')
    for family in group
]

# ============================================================================
# List of models to actually run
# Baseline comparison and ablation experiments as specified in issue126.md
# ============================================================================
SELECTED_MODELS = (
    # ===== Our method (FluxNet-LAP) =====
    [
        'FluxNet_SW_LAP_pf',              # Optimal method: L-head + Adv-Pressure + pushforward
    ]
    # ===== Baseline comparison =====
    + [
        'FNO_SW_pf',                      # FNO baseline
        'FNO_SW_soft_pf',                 # FNO + soft conservation loss
        'SW_Baseline_residual_pf',        # CNN baseline (residual prediction)
        'SW_Baseline_residual_soft_pf',   # CNN + soft conservation loss
        'FNO_SW_Proj_box_mass_pf',        # FNO + Box + Mass projection (strong baseline)
        'CNN_SW_Proj_box_mass_pf',        # CNN + Box + Mass projection (strong baseline)
    ]
    # ===== Ablation experiments =====
    + [
        'FluxNet_SW_PPP_pf',              # P-head for all three fields
        'FluxNet_SW_LPP_pf',              # L-head for h, P-head for m
        'FluxNet_SW_PAP_pf',              # P-head for h + Adv-Pressure
        'FluxNet_SW_LAP_no_gate_pf',      # LAP without press_scale gating
        'FluxNet_SW_LAP_onestep',         # LAP without pushforward
        'FNO_FluxLAP_pf',                 # FNO backbone + LAP head
    ]
)


def _make_training_config(hparams: dict, use_pf: bool, soft_cons: float = 0.0) -> TrainingConfig:
    """
    Helper function to create TrainingConfig, uniformly handling loss_weight_mode

    Args:
        hparams: Hyperparameter dictionary. Must contain 'num_epochs',
            'batch_size', 'learning_rate', 'weight_decay', 'ndt',
            'num_workers' and 'unroll_steps'. 'loss_weight_mode'
            (default 'adaptive') and 'loss_weights' (default {}) are optional.
        use_pf: Whether to use pushforward training
        soft_cons: Soft conservation loss weight (non-zero only for the
            *_soft baseline models)

    Returns:
        TrainingConfig instance. Its loss_weights is a shallow copy of
        hparams['loss_weights'], so configs built from the same hparams dict
        do not share one mutable weights dict.

    Raises:
        KeyError: if one of the required keys is missing from hparams.

    Usage:
        1. By default uses adaptive loss weights (loss_weight_mode='adaptive')
        2. For manual specification, set hparams['loss_weight_mode'] = 'manual'
        3. Then specify loss weights in hparams['loss_weights']
           Example: {'p_loss': 1.0, 'stability_loss': 0.5, 'cons_loss': 0.1}
    """
    loss_weights = hparams.get('loss_weights', {})
    if loss_weights is not None:
        # main() builds every experiment from one hparams dict. Copying keeps
        # an in-place change to one config's weights from leaking into the
        # next experiment run in the same process.
        loss_weights = dict(loss_weights)

    return TrainingConfig(
        num_epochs=hparams['num_epochs'],
        batch_size=hparams['batch_size'],
        learning_rate=hparams['learning_rate'],
        weight_decay=hparams['weight_decay'],
        ndt=hparams['ndt'],
        num_workers=hparams['num_workers'],
        use_pushforward=use_pf,
        unroll_steps=hparams['unroll_steps'],
        soft_conservation_weight=soft_cons,
        loss_weight_mode=hparams.get('loss_weight_mode', 'adaptive'),
        loss_weights=loss_weights,
    )


def get_experiment_config(model_name: str, hparams: dict) -> dict:
    """
    Generate experiment configuration based on model name

    Args:
        model_name: Model name
        hparams: Hyperparameter dictionary

    Returns:
        Experiment configuration dictionary

    Loss weight usage:
        - By default uses adaptive loss weights (loss_weight_mode='adaptive')
        - For manual specification, set hparams['loss_weight_mode'] = 'manual'
        - Then specify loss weights in hparams['loss_weights']
        - Available loss terms: p_loss, stability_loss, cons_loss
    """
    use_pf = '_pf' in model_name

    # FluxNet_SW series - new head configuration
    # LAP: RECOMMENDED - L-head + Advection-Pressure decomposition
    if model_name.startswith('FluxNet_SW_LAP_no_gate'):
        return {
            'name': model_name,
            'model_config': ModelConfig(
                model_type='FluxNet_SW_2D',
                base_channels=hparams['base_channels'],
                num_blocks=hparams['num_blocks'],
                kernel_size=hparams['kernel_size'],
                neighborhood_size=hparams['neighborhood_size'],
                head_config='LAP_no_gate',
                lower_bound=0.0
            ),
            'training_config': _make_training_config(hparams, use_pf)
        }

    elif model_name.startswith('FluxNet_SW_LAP'):
        return {
            'name': model_name,
            'model_config': ModelConfig(
                model_type='FluxNet_SW_2D',
                base_channels=hparams['base_channels'],
                num_blocks=hparams['num_blocks'],
                kernel_size=hparams['kernel_size'],
                neighborhood_size=hparams['neighborhood_size'],
                head_config='LAP',
                lower_bound=0.0
            ),
            'training_config': _make_training_config(hparams, use_pf)
        }

    elif model_name.startswith('FluxNet_SW_PAP'):
        return {
            'name': model_name,
            'model_config': ModelConfig(
                model_type='FluxNet_SW_2D',
                base_channels=hparams['base_channels'],
                num_blocks=hparams['num_blocks'],
                kernel_size=hparams['kernel_size'],
                neighborhood_size=hparams['neighborhood_size'],
                head_config='PAP',
                lower_bound=0.0
            ),
            'training_config': _make_training_config(hparams, use_pf)
        }

    elif model_name.startswith('FluxNet_SW_PPP'):
        return {
            'name': model_name,
            'model_config': ModelConfig(
                model_type='FluxNet_SW_2D',
                base_channels=hparams['base_channels'],
                num_blocks=hparams['num_blocks'],
                kernel_size=hparams['kernel_size'],
                neighborhood_size=hparams['neighborhood_size'],
                head_config='PPP',
                lower_bound=0.0
            ),
            'training_config': _make_training_config(hparams, use_pf)
        }

    elif model_name.startswith('FluxNet_SW_LPP'):
        return {
            'name': model_name,
            'model_config': ModelConfig(
                model_type='FluxNet_SW_2D',
                base_channels=hparams['base_channels'],
                num_blocks=hparams['num_blocks'],
                kernel_size=hparams['kernel_size'],
                neighborhood_size=hparams['neighborhood_size'],
                head_config='LPP',
                lower_bound=0.0
            ),
            'training_config': _make_training_config(hparams, use_pf)
        }

    elif model_name.startswith('FluxNet_SW_NNN'):
        return {
            'name': model_name,
            'model_config': ModelConfig(
                model_type='FluxNet_SW_2D',
                base_channels=hparams['base_channels'],
                num_blocks=hparams['num_blocks'],
                kernel_size=hparams['kernel_size'],
                neighborhood_size=hparams['neighborhood_size'],
                head_config='NNN'
            ),
            'training_config': _make_training_config(hparams, use_pf)
        }

    # Baseline series
    elif model_name.startswith('SW_Baseline'):
        pred_mode = 'residual' if 'residual' in model_name else 'direct'
        soft_cons = hparams.get('soft_cons_weight', 0.1) if 'soft' in model_name else 0.0
        bound_h = 'bound' in model_name

        return {
            'name': model_name,
            'model_config': ModelConfig(
                model_type='FluxNet_SW_Baseline',
                base_channels=hparams['base_channels'],
                num_blocks=hparams['num_blocks'],
                kernel_size=hparams['kernel_size'],
                prediction_mode=pred_mode,
                bound_h=bound_h,
                lower_bound=0.0
            ),
            'training_config': _make_training_config(hparams, use_pf, soft_cons)
        }

    # FNO + soft conservation loss
    elif model_name.startswith('FNO_SW_soft'):
        soft_cons = hparams.get('soft_cons_weight', 0.1)
        return {
            'name': model_name,
            'model_config': ModelConfig(
                model_type='FNO_SW',
                modes=hparams.get('fno_modes', 16),
                width=hparams.get('fno_width', 64),
                num_layers=hparams.get('fno_layers', 4)
            ),
            'training_config': _make_training_config(hparams, use_pf, soft_cons)
        }

    # FNO series (no projection)
    elif model_name.startswith('FNO_SW') and 'Proj' not in model_name and 'FluxLAP' not in model_name:
        return {
            'name': model_name,
            'model_config': ModelConfig(
                model_type='FNO_SW',
                modes=hparams.get('fno_modes', 16),
                width=hparams.get('fno_width', 64),
                num_layers=hparams.get('fno_layers', 4)
            ),
            'training_config': _make_training_config(hparams, use_pf)
        }

    # FNO + Projection
    elif model_name.startswith('FNO_SW_Proj'):
        proj_mode = 'box_mass' if 'box_mass' in model_name else 'box'
        return {
            'name': model_name,
            'model_config': ModelConfig(
                model_type='FNO_SW_Proj',
                modes=hparams.get('fno_modes', 16),
                width=hparams.get('fno_width', 64),
                num_layers=hparams.get('fno_layers', 4),
                projection_mode=proj_mode,
                prediction_mode='residual'
            ),
            'training_config': _make_training_config(hparams, use_pf)
        }

    # CNN + Projection
    elif model_name.startswith('CNN_SW_Proj'):
        proj_mode = 'box_mass' if 'box_mass' in model_name else 'box'
        return {
            'name': model_name,
            'model_config': ModelConfig(
                model_type='CNN_SW_Proj',
                base_channels=hparams['base_channels'],
                num_blocks=hparams['num_blocks'],
                kernel_size=hparams['kernel_size'],
                projection_mode=proj_mode,
                prediction_mode='residual'
            ),
            'training_config': _make_training_config(hparams, use_pf)
        }

    # FNO + FluxLAP head (ablation: FNO backbone with LAP conservation head)
    elif model_name.startswith('FNO_FluxLAP'):
        return {
            'name': model_name,
            'model_config': ModelConfig(
                model_type='FNO_FluxLAP',
                modes=hparams.get('fno_modes', 16),
                width=hparams.get('fno_width', 64),
                num_layers=hparams.get('fno_layers', 4),
                neighborhood_size=hparams['neighborhood_size'],
                lower_bound=0.0
            ),
            'training_config': _make_training_config(hparams, use_pf)
        }

    else:
        raise ValueError(f"Unknown model: {model_name}")


def main():
    """
    Configure, run and summarize the shallow-water ablation experiments.

    Every model in SELECTED_MODELS that is also in ALL_AVAILABLE_MODELS is
    turned into an experiment configuration and passed to run_single_experiment
    in order. A failing experiment is reported with its traceback and skipped,
    so the remaining experiments still run. Afterwards a markdown summary table
    is written under save_path, and any failed experiments are listed.

    Paths, hyperparameters and run-control flags are edited inline below.
    """
    import traceback

    # ========================================================================
    # Path configuration (using absolute paths)
    # ========================================================================
    save_path = "/home/ml4pf/zshlan/FluxNet/results/shallow_water/ablation"
    train_folder = "/home/ml4pf/zshlan/FluxNet/dataset/shallow_water/train"
    val_folder = "/home/ml4pf/zshlan/FluxNet/dataset/shallow_water/val"
    # Alternative test sets:
    #   /home/ml4pf/zshlan/FluxNet/dataset/shallow_water/test
    #   /home/ml4pf/zshlan/FluxNet/dataset/shallow_water/test_256
    test_folder = "/home/ml4pf/zshlan/FluxNet/dataset/shallow_water/test_long"

    # ========================================================================
    # Hyperparameter configuration
    # ========================================================================
    hparams = {
        # Model architecture parameters
        'base_channels': 64,
        'num_blocks': 6,
        'kernel_size': 5,
        'neighborhood_size': 5,
        # Training parameters
        'num_epochs': 50,
        'batch_size': 16,
        'learning_rate': 1e-3,
        'weight_decay': 1e-2,
        'ndt': 1,
        'num_workers': 4,
        'unroll_steps': 5,
        # FNO specific
        'fno_modes': 16,
        'fno_width': 64,
        'fno_layers': 4,
        # ================================================================
        # Loss weight configuration
        # ================================================================
        'loss_weight_mode': 'adaptive',
        # Soft conservation weight; used only by the *_soft models.
        'soft_cons_weight': 0.1,
        # Explicit loss weights, intended for loss_weight_mode='manual'.
        'loss_weights': {
            'p_loss': 1.0,
            'stability_loss': 0.5,
            'cons_loss': 0.1,
        },
    }

    # ========================================================================
    # Experiment control
    # ========================================================================
    gpu_id = 1
    seed = 42
    run_training = False
    run_evaluation = True
    evaluate_mode = 'rollout'        # 'onestep', 'rollout', 'both'
    visualize_trajectories = 'all'  # 'all', None, or list of h5 paths

    # ========================================================================
    # Generate experiment configurations
    # ========================================================================
    available_models = set(ALL_AVAILABLE_MODELS)
    ablation_experiments = []
    for model_name in SELECTED_MODELS:
        if model_name not in available_models:
            print(f"Warning: Model '{model_name}' is not in the available list")
            continue
        try:
            ablation_experiments.append(get_experiment_config(model_name, hparams))
        except ValueError as e:
            print(f"Skipping model {model_name}: {e}")

    print(f"\nWill run {len(ablation_experiments)} ablation experiments:")
    for exp in ablation_experiments:
        print(f"  - {exp['name']}")
    print()

    # ========================================================================
    # Run experiments
    # ========================================================================
    failed_experiments = []
    for i, exp in enumerate(ablation_experiments, start=1):
        print(f"\n{'#'*80}")
        print(f"# Ablation Experiment {i}/{len(ablation_experiments)}: {exp['name']}")
        print(f"{'#'*80}\n")

        try:
            run_single_experiment(
                model_config=exp['model_config'],
                training_config=exp['training_config'],
                dataset_type='shallow_water',
                train_folder=train_folder,
                val_folder=val_folder,
                test_folder=test_folder,
                save_path=save_path,
                experiment_name=exp['name'],
                gpu_id=gpu_id,
                run_training=run_training,
                run_evaluation=run_evaluation,
                seed=seed,
                evaluate_mode=evaluate_mode,
                visualize_trajectories=visualize_trajectories
            )
        except Exception as e:
            # Top-level boundary: log the failure with its traceback and keep
            # going so one broken model does not abort the whole sweep.
            print(f"Experiment failed: {e}")
            traceback.print_exc()
            failed_experiments.append(exp['name'])

    # ========================================================================
    # Generate summary table
    # ========================================================================
    print("\n" + "="*80)
    generate_summary_table(save_path, ablation_experiments, "ablation_summary.md")
    if failed_experiments:
        print(f"\n{len(failed_experiments)}/{len(ablation_experiments)} "
              f"ablation experiments failed:")
        for name in failed_experiments:
            print(f"  - {name}")
        print(f"Results of the remaining experiments saved at: {save_path}")
    else:
        print(f"\nAll ablation experiments completed! Results saved at: {save_path}")


if __name__ == "__main__":
    main()
