"""
Spinodal Decomposition - Ablation Experiments (ndt=1)

Short-term prediction: ndt=1 means predicting the next frame

Recommended model:
- FluxNet_D: Double bound constraint (0 <= phi <= 1)
"""

import os
import sys

# Climb three directory levels up from this file and put that directory first
# on sys.path so the `experiments.common` import in this script can resolve.
project_root = os.path.abspath(__file__)
for _ in range(3):
    project_root = os.path.dirname(project_root)
sys.path.insert(0, project_root)

from experiments.common.experiment_runner import (
    run_single_experiment, generate_summary_table, ModelConfig, TrainingConfig
)

# ============================================================================
# Complete list of available models
# ============================================================================
# Every name is '<variant>_<training mode>'. Within each family the 'onestep'
# entries are listed first, followed by the 'pf' (pushforward) entries.
ALL_AVAILABLE_MODELS = [
    f'{variant}_{training_mode}'
    for family in (
        # FluxNet series
        (
            'FluxNet_N',
            'FluxNet_P',
            'FluxNet_L',
            'FluxNet_D',
        ),
        # CNN Baseline series
        (
            'CNN_2D_direct',
            'CNN_2D_residual',
            'CNN_2D_direct_soft',
            'CNN_2D_residual_soft',
            'CNN_2D_bound',
            'CNN_2D_bound_soft',
        ),
    )
    for training_mode in ('onestep', 'pf')
    for variant in family
]

# ============================================================================
# List of models to actually run
# ============================================================================
SELECTED_MODELS = (
    ['FluxNet_D_onestep', 'FluxNet_D_pf']  # FluxNet series
    + ['CNN_2D_bound_soft_onestep']        # CNN Baseline series
)


def _make_training_config(hparams: dict, use_pf: bool, soft_cons: float = 0.0,
                          dcl_weight: float = 0.1) -> TrainingConfig:
    """
    Helper function to create TrainingConfig, uniformly handling loss_weight_mode

    Args:
        hparams: Hyperparameter dictionary. Must contain 'num_epochs',
            'batch_size', 'learning_rate', 'weight_decay', 'ndt',
            'num_workers' and 'unroll_steps'; 'loss_weight_mode' and
            'loss_weights' are optional.
        use_pf: Whether to use pushforward training
        soft_cons: Soft conservation loss weight (only used for baselines)
        dcl_weight: DCL loss weight (only used for FluxNet-D)

    Returns:
        TrainingConfig instance

    Raises:
        KeyError: If one of the required keys is missing from hparams.

    Usage:
        1. By default uses adaptive loss weights (loss_weight_mode='adaptive')
        2. For manual specification, set hparams['loss_weight_mode'] = 'manual'
        3. Then specify loss weights in hparams['loss_weights']
           Example: {'p_loss': 1.0, 'dcl_loss': 0.1, 'stability_loss': 0.5}
    """
    manual_weights = hparams.get('loss_weights', {})
    if manual_weights is not None:
        # Give each config its own copy: the same hparams dict is reused for
        # every experiment, so handing out the original mapping would let an
        # in-place change made through one config leak into all the others.
        manual_weights = dict(manual_weights)

    return TrainingConfig(
        num_epochs=hparams['num_epochs'],
        batch_size=hparams['batch_size'],
        learning_rate=hparams['learning_rate'],
        weight_decay=hparams['weight_decay'],
        ndt=hparams['ndt'],
        num_workers=hparams['num_workers'],
        use_pushforward=use_pf,
        unroll_steps=hparams['unroll_steps'],
        soft_conservation_weight=soft_cons,
        dcl_weight=dcl_weight,
        loss_weight_mode=hparams.get('loss_weight_mode', 'adaptive'),
        loss_weights=manual_weights,
    )


def get_experiment_config(model_name: str, hparams: dict) -> dict:
    """
    Build the experiment description for one model name.

    The result is a dict with the keys 'name', 'model_config' (a ModelConfig)
    and 'training_config' (a TrainingConfig). Pushforward training is enabled
    whenever the name contains '_pf'.

    Loss weight usage:
        - Adaptive loss weights are the default (loss_weight_mode='adaptive')
        - To set them by hand, put hparams['loss_weight_mode'] = 'manual'
          and list the weights in hparams['loss_weights']
        - For FluxNet-D: available terms are p_loss, dcl_loss, stability_loss
        - For baseline: available terms are p_loss, cons_loss, stability_loss

    Raises:
        ValueError: If model_name matches neither a FluxNet variant nor the
            'CNN_2D' prefix.
    """
    pushforward = '_pf' in model_name

    # FluxNet series: the variants share one architecture configuration and
    # differ only in which bound arguments are handed to ModelConfig.
    fluxnet_bounds = {
        'FluxNet_N': {},
        'FluxNet_P': {},
        'FluxNet_L': {'lower_bound': 0.0},
        'FluxNet_D': {'lower_bound': 0.0, 'upper_bound': 1.0},
    }
    if '1D' not in model_name:
        for variant, bounds in fluxnet_bounds.items():
            if not model_name.startswith(variant + '_'):
                continue
            model_cfg = ModelConfig(
                model_type=variant,
                base_channels=hparams['base_channels'],
                num_blocks=hparams['num_blocks'],
                kernel_size=hparams['kernel_size'],
                neighborhood_size=hparams['neighborhood_size'],
                **bounds,
            )
            # Only FluxNet_D forwards the DCL weight taken from hparams.
            extra_training_args = {}
            if variant == 'FluxNet_D':
                extra_training_args['dcl_weight'] = hparams.get('dcl_weight', 0.1)
            training_cfg = _make_training_config(hparams, pushforward, **extra_training_args)
            return {
                'name': model_name,
                'model_config': model_cfg,
                'training_config': training_cfg,
            }

    # CNN Baseline series: the options are decoded from substrings of the name.
    if model_name.startswith('CNN_2D'):
        if 'residual' in model_name:
            prediction_mode = 'residual'
        else:
            prediction_mode = 'direct'

        if 'soft' in model_name:
            conservation_weight = hparams.get('soft_cons_weight', 0.1)
        else:
            conservation_weight = 0.0

        if 'bound' in model_name:
            bound_mode = 'double'
        else:
            bound_mode = 'none'

        model_cfg = ModelConfig(
            model_type='CNN_Baseline_2D',
            base_channels=hparams['base_channels'],
            num_blocks=hparams['num_blocks'],
            kernel_size=hparams['kernel_size'],
            prediction_mode=prediction_mode,
            bound_mode=bound_mode,
            lower_bound=0.0,
            upper_bound=1.0,
        )
        training_cfg = _make_training_config(hparams, pushforward, conservation_weight)
        return {
            'name': model_name,
            'model_config': model_cfg,
            'training_config': training_cfg,
        }

    raise ValueError(f"Unknown model: {model_name}")


def main():
    """
    Run the ndt=1 spinodal decomposition ablation experiments.

    Steps:
        1. Define the dataset/result paths and the shared hyperparameters.
        2. Build one experiment configuration per entry of SELECTED_MODELS,
           skipping names that are not in ALL_AVAILABLE_MODELS or that
           get_experiment_config rejects with ValueError.
        3. Call run_single_experiment for each configuration. A failing
           experiment is reported with its traceback and does not stop the
           remaining ones.
        4. Call generate_summary_table and report which experiments failed,
           if any.
    """
    import traceback  # only used to report failing experiments

    # ========================================================================
    # Path configuration
    # ========================================================================
    save_path = "/home/ml4pf/zshlan/FluxNet/results/spinodal_decomposition/ablation_ndt1"
    train_folder = "/home/ml4pf/zshlan/FluxNet/dataset/spinodal_decomposition/train"
    val_folder = "/home/ml4pf/zshlan/FluxNet/dataset/spinodal_decomposition/val"
    test_folder = "/home/ml4pf/zshlan/FluxNet/dataset/spinodal_decomposition/test"

    # ========================================================================
    # Hyperparameter configuration - ndt=1 (short-term prediction, corresponding to 10dt)
    # ========================================================================
    hparams = {
        # Model architecture parameters
        'base_channels': 64,
        'num_blocks': 4,
        'kernel_size': 3,
        'neighborhood_size': 15,  # Small neighborhood suitable for small time steps
        # Training parameters
        'num_epochs': 100,
        'batch_size': 16,
        'learning_rate': 1e-3,
        'weight_decay': 1e-2,
        'ndt': 1,  # Short-term prediction, corresponding to 10dt
        'num_workers': 4,
        'unroll_steps': 5,
        # ================================================================
        # Loss weight configuration
        # ================================================================
        # Loss weight mode: 'adaptive' or 'manual'
        'loss_weight_mode': 'adaptive',
        # DCL loss weight (for FluxNet-D)
        'dcl_weight': 0.1,
        # Soft conservation loss weight (for baseline)
        'soft_cons_weight': 0.1,
        # Manual loss weights (takes effect when loss_weight_mode='manual')
        # Available loss terms: p_loss (prediction), dcl_loss (dual consistency), stability_loss (stability), cons_loss (conservation)
        'loss_weights': {
            'p_loss': 1.0,
            'dcl_loss': 0.1,
            'stability_loss': 0.5,
        },
    }

    # ========================================================================
    # Experiment control
    # ========================================================================
    gpu_id = 0
    seed = 42
    run_training = True
    run_evaluation = True
    evaluate_mode = 'both'
    visualize_trajectories = 'all'

    # ========================================================================
    # Generate experiment configurations
    # ========================================================================
    ablation_experiments = []
    for model_name in SELECTED_MODELS:
        if model_name not in ALL_AVAILABLE_MODELS:
            print(f"Warning: Model '{model_name}' is not in the available list")
            continue
        try:
            ablation_experiments.append(get_experiment_config(model_name, hparams))
        except ValueError as e:
            print(f"Skipping model {model_name}: {e}")

    total = len(ablation_experiments)
    print(f"\n[ndt=1] Will run {total} ablation experiments:")
    for exp in ablation_experiments:
        print(f"  - {exp['name']}")
    print()

    # ========================================================================
    # Run experiments
    # ========================================================================
    # Names of experiments that raised, so the final message does not claim
    # success for runs that actually failed.
    failed_experiments = []

    for i, exp in enumerate(ablation_experiments, start=1):
        print(f"\n{'#'*80}")
        print(f"# Ablation Experiment {i}/{total}: {exp['name']} (ndt=1)")
        print(f"{'#'*80}\n")

        try:
            run_single_experiment(
                model_config=exp['model_config'],
                training_config=exp['training_config'],
                dataset_type='spinodal_decomposition',
                train_folder=train_folder,
                val_folder=val_folder,
                test_folder=test_folder,
                save_path=save_path,
                experiment_name=exp['name'],
                gpu_id=gpu_id,
                run_training=run_training,
                run_evaluation=run_evaluation,
                seed=seed,
                evaluate_mode=evaluate_mode,
                visualize_trajectories=visualize_trajectories
            )
        except Exception as e:
            # Deliberately broad: one failing model must not abort the
            # remaining experiments of the ablation.
            print(f"Experiment failed: {e}")
            traceback.print_exc()
            failed_experiments.append(exp['name'])

    # ========================================================================
    # Generate summary table
    # ========================================================================
    print("\n" + "="*80)
    generate_summary_table(save_path, ablation_experiments, "ablation_summary_ndt1.md")
    if failed_experiments:
        print(f"\n{len(failed_experiments)}/{total} ablation experiments failed: "
              f"{', '.join(failed_experiments)}")
        print(f"Results of the remaining experiments saved at: {save_path}")
    else:
        print(f"\nAll ablation experiments completed! Results saved at: {save_path}")


# Run the ablation experiments only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
