"""
Spinodal Decomposition - Ablation Experiments

Complete list of available models (select from these for ablation experiments):

FluxNet model series:
- FluxNet_N: No constraint (formerly FluxNet_U)
- FluxNet_P: Positive flux constraint (softplus)
- FluxNet_L: Lower bound constraint (phi >= 0)
- FluxNet_D: Double bound constraint (0 <= phi <= 1, optimal)

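Baseline model series:
- CNN_2D_direct: Direct prediction
- CNN_2D_residual: Residual prediction
- CNN_2D_bound: With double bound constraint (sigmoid)
- <variant>_soft (e.g. CNN_2D_residual_soft): the same model with an added soft
  conservation loss; available for the direct, residual, and bound variants

Each model can be trained in one of two modes, selected by the name suffix:
- `_onestep`: one-step training
- `_pf`: pushforward training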
"""

import os
import sys

project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.insert(0, project_root)

from experiments.common.experiment_runner import (
    run_single_experiment, generate_summary_table, ModelConfig, TrainingConfig
)

# ============================================================================
# Complete list of available models (select from these for experiments)
# ============================================================================
# Every model name has the form "<family>_<variant>_<mode>". The catalog below
# is the full product of variants x training modes for each family. Modes form
# the outer loop, so all one-step models of a family come before its
# pushforward models.
_FLUXNET_VARIANTS = ('N', 'P', 'L', 'D')
_CNN_VARIANTS = (
    'direct',
    'residual',
    'direct_soft',
    'residual_soft',
    'bound',
    'bound_soft',
)
_TRAINING_MODES = ('onestep', 'pf')

ALL_AVAILABLE_MODELS = [
    f'FluxNet_{variant}_{mode}'
    for mode in _TRAINING_MODES
    for variant in _FLUXNET_VARIANTS
] + [
    f'CNN_2D_{variant}_{mode}'
    for mode in _TRAINING_MODES
    for variant in _CNN_VARIANTS
]

# ============================================================================
# Models to actually run, in execution order.
# Keep one name per line so individual models can be commented out.
# ============================================================================
SELECTED_MODELS = [
    # --- FluxNet series ---
    "FluxNet_N_onestep",              # no constraint
    "FluxNet_P_onestep",              # positive flux constraint (softplus)
    "FluxNet_L_onestep",              # lower bound constraint (phi >= 0)
    "FluxNet_D_onestep",              # double bound constraint (0 <= phi <= 1)
    "FluxNet_N_pf",                   # no constraint, pushforward training
    "FluxNet_D_pf",                   # double bound, pushforward training
    # --- CNN baseline series ---
    "CNN_2D_residual_onestep",        # residual prediction
    "CNN_2D_residual_soft_onestep",   # residual prediction + soft conservation loss
    "CNN_2D_bound_soft_onestep",      # direct prediction, double bound + soft conservation loss
    "CNN_2D_residual_pf",             # residual prediction, pushforward training
]


# Extra ModelConfig keyword arguments for each FluxNet constraint variant.
# N and P pass no explicit bounds, L fixes the lower bound, and D fixes both.
_FLUXNET_BOUND_KWARGS = {
    'FluxNet_N': {},
    'FluxNet_P': {},
    'FluxNet_L': {'lower_bound': 0.0},
    'FluxNet_D': {'lower_bound': 0.0, 'upper_bound': 1.0},
}


def _build_training_config(hparams: dict, use_pf: bool, **extra):
    """Return the TrainingConfig shared by every ablation model.

    Args:
        hparams: Hyperparameter dict. The keys read are ``num_epochs``,
            ``batch_size``, ``learning_rate``, ``weight_decay``, ``ndt``,
            ``num_workers`` and ``unroll_steps``.
        use_pf: Whether pushforward training is enabled.
        **extra: Additional TrainingConfig keyword arguments. The CNN
            baselines use this for ``soft_conservation_weight``. FluxNet
            models pass nothing, so TrainingConfig's own default applies.

    Raises:
        KeyError: If ``hparams`` is missing one of the keys listed above.
    """
    return TrainingConfig(
        num_epochs=hparams['num_epochs'],
        batch_size=hparams['batch_size'],
        learning_rate=hparams['learning_rate'],
        weight_decay=hparams['weight_decay'],
        ndt=hparams['ndt'],
        num_workers=hparams['num_workers'],
        use_pushforward=use_pf,
        unroll_steps=hparams['unroll_steps'],
        **extra,
    )


def get_experiment_config(model_name: str, hparams: dict) -> dict:
    """Build the experiment configuration for one ablation model name.

    Model names have the form ``<family>_<variant>_<mode>``. A name ending in
    ``_pf`` selects pushforward training; any other name is configured for
    one-step training.

    Args:
        model_name: For example ``'FluxNet_D_pf'`` or
            ``'CNN_2D_residual_soft_onestep'``.
        hparams: Hyperparameter dict. It must provide ``base_channels``,
            ``num_blocks`` and ``kernel_size``; FluxNet models also need
            ``neighborhood_size``. The training keys are read by
            ``_build_training_config``.

    Returns:
        A dict with keys ``'name'``, ``'model_config'`` (a ModelConfig) and
        ``'training_config'`` (a TrainingConfig).

    Raises:
        ValueError: If ``model_name`` matches no known model family.
        KeyError: If ``hparams`` is missing a required key.
    """
    # The training mode is the name suffix. A substring test would misfire on
    # any future variant whose name merely contains "_pf".
    use_pf = model_name.endswith('_pf')

    # FluxNet series. Names containing '1D' are deliberately not matched here
    # and fall through to the ValueError below. Presumably 1D variants are
    # configured elsewhere; confirm.
    for family, bound_kwargs in _FLUXNET_BOUND_KWARGS.items():
        if model_name.startswith(family + '_') and '1D' not in model_name:
            return {
                'name': model_name,
                'model_config': ModelConfig(
                    model_type=family,
                    base_channels=hparams['base_channels'],
                    num_blocks=hparams['num_blocks'],
                    kernel_size=hparams['kernel_size'],
                    neighborhood_size=hparams['neighborhood_size'],
                    **bound_kwargs,
                ),
                'training_config': _build_training_config(hparams, use_pf),
            }

    # CNN baseline series. Each behavior is toggled by a component of the name.
    if model_name.startswith('CNN_2D'):
        pred_mode = 'residual' if 'residual' in model_name else 'direct'
        soft_cons = 0.1 if 'soft' in model_name else 0.0
        bound_mode = 'double' if 'bound' in model_name else 'none'

        return {
            'name': model_name,
            'model_config': ModelConfig(
                model_type='CNN_Baseline_2D',
                base_channels=hparams['base_channels'],
                num_blocks=hparams['num_blocks'],
                kernel_size=hparams['kernel_size'],
                prediction_mode=pred_mode,
                bound_mode=bound_mode,
                lower_bound=0.0,
                upper_bound=1.0
            ),
            'training_config': _build_training_config(
                hparams, use_pf, soft_conservation_weight=soft_cons
            ),
        }

    raise ValueError(f"Unknown model: {model_name}")


def main():
    """Build the selected ablation configurations, run each one, and write a summary.

    For each name in ``SELECTED_MODELS``:

    - Names missing from ``ALL_AVAILABLE_MODELS`` are warned about and skipped.
    - Names that ``get_experiment_config`` rejects with ValueError are also skipped.

    Each remaining experiment goes through ``run_single_experiment``. An
    exception in one run is reported with its traceback, and the sweep moves
    on to the next run. At the end, ``generate_summary_table`` writes
    ``ablation_summary.md`` under ``save_path``, and any failed experiments
    are listed by name.
    """
    import traceback

    # ========================================================================
    # Path configuration (using absolute paths)
    # ========================================================================
    save_path = "/home/ml4pf/zshlan/FluxNet/results/spinodal_decomposition/ablation"
    train_folder = "/home/ml4pf/zshlan/FluxNet/dataset/spinodal_decomposition/train"
    val_folder = "/home/ml4pf/zshlan/FluxNet/dataset/spinodal_decomposition/val"
    test_folder = "/home/ml4pf/zshlan/FluxNet/dataset/spinodal_decomposition/test"

    # ========================================================================
    # Hyperparameter configuration
    # ========================================================================
    hparams = {
        'base_channels': 64,
        'num_blocks': 4,
        'kernel_size': 3,
        'neighborhood_size': 15,
        'num_epochs': 100,
        'batch_size': 16,
        'learning_rate': 1e-3,
        'weight_decay': 1e-2,
        'ndt': 1,
        'num_workers': 4,
        'unroll_steps': 5,
    }

    # ========================================================================
    # Experiment control
    # ========================================================================
    gpu_id = 0
    seed = 42
    run_training = True
    run_evaluation = True
    evaluate_mode = 'both'
    visualize_trajectories = 'all'

    # ========================================================================
    # Generate experiment configurations
    # ========================================================================
    available = set(ALL_AVAILABLE_MODELS)
    ablation_experiments = []
    for model_name in SELECTED_MODELS:
        if model_name not in available:
            print(f"Warning: Model '{model_name}' is not in the available list")
            continue
        try:
            exp_config = get_experiment_config(model_name, hparams)
        except ValueError as e:
            print(f"Skipping model {model_name}: {e}")
            continue
        ablation_experiments.append(exp_config)

    total = len(ablation_experiments)
    print(f"\nWill run {total} ablation experiments:")
    for exp in ablation_experiments:
        print(f"  - {exp['name']}")
    print()

    # ========================================================================
    # Run experiments
    # ========================================================================
    failed = []

    for i, exp in enumerate(ablation_experiments, start=1):
        print(f"\n{'#'*80}")
        print(f"# Ablation Experiment {i}/{total}: {exp['name']}")
        print(f"{'#'*80}\n")

        try:
            run_single_experiment(
                model_config=exp['model_config'],
                training_config=exp['training_config'],
                dataset_type='spinodal_decomposition',
                train_folder=train_folder,
                val_folder=val_folder,
                test_folder=test_folder,
                save_path=save_path,
                experiment_name=exp['name'],
                gpu_id=gpu_id,
                run_training=run_training,
                run_evaluation=run_evaluation,
                seed=seed,
                evaluate_mode=evaluate_mode,
                visualize_trajectories=visualize_trajectories
            )
        except Exception as e:
            # Top-level boundary: one failing run must not abort the whole
            # sweep. It is reported here and listed again at the end.
            print(f"Experiment '{exp['name']}' failed: {e}")
            traceback.print_exc()
            failed.append(exp['name'])

    # ========================================================================
    # Generate summary table
    # ========================================================================
    print("\n" + "="*80)
    generate_summary_table(save_path, ablation_experiments, "ablation_summary.md")
    if failed:
        # Do not claim success when some runs raised.
        print(f"\n{len(failed)} of {total} ablation experiments failed: "
              f"{', '.join(failed)}")
        print(f"Results for the remaining experiments saved at: {save_path}")
    else:
        print(f"\nAll ablation experiments completed! Results saved at: {save_path}")


if __name__ == "__main__":
    # Script entry point: run the configured ablation sweep.
    main()
