"""
Convection-Diffusion - Ablation Experiments

Complete list of available models (select from these for ablation experiments):

FluxNet model series:
- FluxNet_N_1D: No constraint (formerly FluxNet_U_1D)
- FluxNet_P_1D: Positive flux constraint (softplus)
- FluxNet_L_1D: Lower bound constraint (optimal, c >= 0)

Baseline model series:
- CNN_1D_direct: Direct prediction
- CNN_1D_residual: Residual prediction
- CNN_1D_bound: With lower bound constraint (softplus)
- CNN_1D_*_soft: Any of the above with soft conservation loss (e.g. CNN_1D_direct_soft)

Each model name ends with a training-mode suffix:
- _onestep: onestep training
- _pf: pushforward training (pf)
"""

import os
import sys

# project_root is the grandparent of the directory that contains this file.
# It is put first on sys.path before the project import below; presumably it is
# the repository root holding the `experiments` package — TODO confirm layout.
project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.insert(0, project_root)

from experiments.common.experiment_runner import (
    run_single_experiment, generate_summary_table, ModelConfig, TrainingConfig
)

# ============================================================================
# Complete list of available models (select from these for experiments)
# ============================================================================
# Every selectable name is "<variant>_<training mode>". The list is ordered by
# family (FluxNet, then CNN baseline), then by training mode (onestep, then
# pushforward), then by variant in the order written below.
ALL_AVAILABLE_MODELS = [
    f'{variant}_{training_mode}'
    for family in (
        # FluxNet series
        ('FluxNet_N_1D', 'FluxNet_P_1D', 'FluxNet_L_1D'),
        # CNN Baseline series
        (
            'CNN_1D_direct',
            'CNN_1D_residual',
            'CNN_1D_direct_soft',
            'CNN_1D_residual_soft',
            'CNN_1D_bound',
            'CNN_1D_bound_soft',
        ),
    )
    for training_mode in ('onestep', 'pf')
    for variant in family
]

# ============================================================================
# List of models to actually run
# According to issue126.md: convection-diffusion is a toy case, only needs simple ablation
# ============================================================================
SELECTED_MODELS = [
    # --- Our method (FluxNet-L): lower bound constraint + pushforward ---
    "FluxNet_L_1D_pf",

    # --- Ablations ---
    # No constraint
    "FluxNet_N_1D_pf",
    # Positive flux
    "FluxNet_P_1D_pf",
    # L-head without pushforward
    "FluxNet_L_1D_onestep",
]


def _make_training_config(hparams: dict, use_pf: bool, soft_cons: float = 0.0) -> TrainingConfig:
    """
    Build a TrainingConfig from the shared hyperparameter dict.

    Args:
        hparams: Hyperparameters. Must contain 'num_epochs', 'batch_size',
            'learning_rate', 'weight_decay', 'ndt', 'num_workers' and
            'unroll_steps'. Optional keys: 'loss_weight_mode' (defaults to
            'adaptive') and 'loss_weights' (defaults to an empty dict).
        use_pf: Forwarded as ``use_pushforward``.
        soft_cons: Forwarded as ``soft_conservation_weight``.

    Returns:
        A new TrainingConfig.

    Raises:
        KeyError: If a required key is missing from ``hparams``.

    Loss weight usage:
        1. By default uses adaptive loss weights (loss_weight_mode='adaptive')
        2. For manual specification, set hparams['loss_weight_mode'] = 'manual'
        3. Then specify loss weights in hparams['loss_weights']
           Example: {'p_loss': 1.0, 'stability_loss': 0.5, 'cons_loss': 0.1}
    """
    loss_weights = hparams.get('loss_weights', {})
    if isinstance(loss_weights, dict):
        # Give each config its own copy. Otherwise every experiment built from
        # the same hparams would alias one dict, and an in-place update by one
        # run would leak into later runs and back into hparams.
        loss_weights = dict(loss_weights)

    return TrainingConfig(
        num_epochs=hparams['num_epochs'],
        batch_size=hparams['batch_size'],
        learning_rate=hparams['learning_rate'],
        weight_decay=hparams['weight_decay'],
        ndt=hparams['ndt'],
        num_workers=hparams['num_workers'],
        use_pushforward=use_pf,
        unroll_steps=hparams['unroll_steps'],
        soft_conservation_weight=soft_cons,
        loss_weight_mode=hparams.get('loss_weight_mode', 'adaptive'),
        loss_weights=loss_weights,
    )


def get_experiment_config(model_name: str, hparams: dict) -> dict:
    """
    Translate a model name into an experiment configuration.

    The name is matched by prefix to pick the model family, and by substring
    to pick the options:
        - '_pf' anywhere in the name enables pushforward training
        - FluxNet_N_1D / FluxNet_P_1D / FluxNet_L_1D select the FluxNet variant
          (the L variant additionally gets lower_bound=0.0)
        - CNN_1D selects the CNN baseline; 'residual', 'soft' and 'bound' in
          the name switch on residual prediction, the soft conservation weight
          (hparams['soft_cons_weight'], default 0.1) and the lower bound mode

    Loss weight usage:
        - By default uses adaptive loss weights (loss_weight_mode='adaptive')
        - For manual specification, set hparams['loss_weight_mode'] = 'manual'
        - Then specify loss weights in hparams['loss_weights']
        - Available loss terms: p_loss (prediction), stability_loss (stability), cons_loss (conservation)

    Returns:
        A dict with keys 'name', 'model_config' and 'training_config'.

    Raises:
        ValueError: If the name matches no known model family.
    """
    pushforward = '_pf' in model_name

    # First FluxNet variant whose type string prefixes the name, if any.
    fluxnet_type = next(
        (
            candidate
            for candidate in ('FluxNet_N_1D', 'FluxNet_P_1D', 'FluxNet_L_1D')
            if model_name.startswith(candidate)
        ),
        None,
    )

    if fluxnet_type is not None:
        # FluxNet series
        arch_kwargs = {
            key: hparams[key]
            for key in ('base_channels', 'num_blocks', 'kernel_size', 'neighborhood_size')
        }
        if fluxnet_type == 'FluxNet_L_1D':
            # Only the L variant is given a lower bound.
            arch_kwargs['lower_bound'] = 0.0
        soft_weight = 0.0
        model_config = ModelConfig(model_type=fluxnet_type, **arch_kwargs)

    elif model_name.startswith('CNN_1D'):
        # CNN Baseline series
        if 'residual' in model_name:
            prediction = 'residual'
        else:
            prediction = 'direct'

        if 'soft' in model_name:
            soft_weight = hparams.get('soft_cons_weight', 0.1)
        else:
            soft_weight = 0.0

        if 'bound' in model_name:
            bound = 'lower'
        else:
            bound = 'none'

        model_config = ModelConfig(
            model_type='CNN_Baseline_1D',
            base_channels=hparams['base_channels'],
            num_blocks=hparams['num_blocks'],
            kernel_size=hparams['kernel_size'],
            prediction_mode=prediction,
            bound_mode=bound,
            lower_bound=0.0
        )

    else:
        raise ValueError(f"Unknown model: {model_name}")

    return {
        'name': model_name,
        'model_config': model_config,
        'training_config': _make_training_config(hparams, pushforward, soft_weight),
    }


def main():
    """
    Configure and run the convection-diffusion ablation experiments.

    Steps:
        1. Define dataset/result paths, hyperparameters and run switches.
        2. Build one experiment configuration per entry of SELECTED_MODELS
           that is also listed in ALL_AVAILABLE_MODELS.
        3. Run each experiment through run_single_experiment. A failing
           experiment is reported and recorded, and the remaining ones still run.
        4. Generate the summary table and report which experiments failed, if any.
    """
    import traceback

    # ========================================================================
    # Path configuration (using absolute paths)
    # ========================================================================
    save_path = "/home/ml4pf/zshlan/FluxNet/results/convection_diffusion/ablation"
    train_folder = "/home/ml4pf/zshlan/FluxNet/dataset/convection_diffusion/train"
    val_folder = "/home/ml4pf/zshlan/FluxNet/dataset/convection_diffusion/val"
    test_folder = "/home/ml4pf/zshlan/FluxNet/dataset/convection_diffusion/test"

    # ========================================================================
    # Hyperparameter configuration
    # ========================================================================
    hparams = {
        # Model architecture parameters
        'base_channels': 32,
        'num_blocks': 6,
        'kernel_size': 5,
        'neighborhood_size': 3,
        # Training parameters
        'num_epochs': 100,
        'batch_size': 16,
        'learning_rate': 1e-3,
        'weight_decay': 1e-2,
        'ndt': 1,
        'num_workers': 4,
        'unroll_steps': 5,
        # ================================================================
        # Loss weight configuration
        # ================================================================
        'loss_weight_mode': 'adaptive',
        # Only read for model names containing 'soft'.
        'soft_cons_weight': 0.1,
        # NOTE(review): whether these are used when loss_weight_mode is
        # 'adaptive' depends on the trainer, which is not visible here.
        'loss_weights': {
            'p_loss': 1.0,
            'stability_loss': 0.5,
            'cons_loss': 0.1,
        },
    }

    # ========================================================================
    # Experiment control
    # ========================================================================
    gpu_id = 2
    seed = 42
    # NOTE(review): both switches are False here; what run_single_experiment
    # does in that case is not visible from this file — confirm before running.
    run_training = False
    run_evaluation = False
    evaluate_mode = 'both'
    visualize_trajectories = 'all'

    # ========================================================================
    # Generate experiment configurations
    # ========================================================================
    ablation_experiments = []
    for model_name in SELECTED_MODELS:
        if model_name in ALL_AVAILABLE_MODELS:
            try:
                exp_config = get_experiment_config(model_name, hparams)
                ablation_experiments.append(exp_config)
            except ValueError as e:
                print(f"Skipping model {model_name}: {e}")
        else:
            print(f"Warning: Model '{model_name}' is not in the available list")

    print(f"\nWill run {len(ablation_experiments)} ablation experiments:")
    for exp in ablation_experiments:
        print(f"  - {exp['name']}")
    print()

    # ========================================================================
    # Run experiments
    # ========================================================================
    results = []
    # Names of experiments that raised, so the final report can list them.
    failed_experiments = []

    for i, exp in enumerate(ablation_experiments):
        print(f"\n{'#'*80}")
        print(f"# Ablation Experiment {i+1}/{len(ablation_experiments)}: {exp['name']}")
        print(f"{'#'*80}\n")

        try:
            result = run_single_experiment(
                model_config=exp['model_config'],
                training_config=exp['training_config'],
                dataset_type='convection_diffusion',
                train_folder=train_folder,
                val_folder=val_folder,
                test_folder=test_folder,
                save_path=save_path,
                experiment_name=exp['name'],
                gpu_id=gpu_id,
                run_training=run_training,
                run_evaluation=run_evaluation,
                seed=seed,
                evaluate_mode=evaluate_mode,
                visualize_trajectories=visualize_trajectories
            )
            results.append(result)
        except Exception as e:
            # Deliberately best-effort: one failing experiment must not abort
            # the remaining ones, but it is remembered for the final report.
            print(f"Experiment failed: {e}")
            traceback.print_exc()
            failed_experiments.append(exp['name'])

    # ========================================================================
    # Generate summary table
    # ========================================================================
    print("\n" + "="*80)
    generate_summary_table(save_path, ablation_experiments, "ablation_summary.md")

    if failed_experiments:
        # Do not claim success when some experiments raised.
        print(
            f"\n{len(failed_experiments)} of {len(ablation_experiments)} "
            f"ablation experiments failed:"
        )
        for failed_name in failed_experiments:
            print(f"  - {failed_name}")
        print(f"Results saved at: {save_path}")
    else:
        print(f"\nAll ablation experiments completed! Results saved at: {save_path}")


# Run the ablation pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
