"""EBM Training Package - Clean modular architecture."""

from .datasets import (
    PromptResponseDataset,
    CSVDataset, 
    HFDataset, 
    collate
)
from .negative_sampling import (
    mask_sentences,
    mask_tokens,
    sample_negative_responses,
    sample_negative_prompts,
)
from .losses import (
    contrastive_loss,
    individual_contrastive_losses,
    summed_contrastive_loss,
    infonce_loss,
    individual_infonce_losses,
    compute_accuracies,
)
from .utils import (
    get_cached_batch,
    compute_batch_energy_stats,
    calculate_averages_from_history,
    format_pbar_postfix,
    plot_timeseries,
    analyze_per_sample,
    load_checkpoint,
    save_checkpoint,
    upload_file_to_drive,
    set_seed,
)
from .validation import validate_model
from .training import train_one_epoch
from .config import (
    TrainingConfig,
    DataConfig,
    TrainingState,
    create_training_state,
)

# Public API of the package: the names exported by `from <package> import *`.
# Every entry is re-exported from one of the relative imports at the top of
# this file; entries are grouped below by the submodule they come from.
# Keep this a plain list literal of strings and update it together with the
# import block whenever a name is added or removed.
__all__ = [
    # .datasets
    "PromptResponseDataset",
    "CSVDataset",
    "HFDataset",
    "collate",
    # .negative_sampling
    "mask_sentences",
    "mask_tokens",
    "sample_negative_responses",
    "sample_negative_prompts",
    # .losses
    "contrastive_loss",
    "individual_contrastive_losses",
    "summed_contrastive_loss",
    "infonce_loss",
    "individual_infonce_losses",
    "compute_accuracies",
    # .utils (note: set_seed is listed first here but imported last above)
    "set_seed",
    "get_cached_batch",
    "compute_batch_energy_stats",
    "calculate_averages_from_history",
    "format_pbar_postfix",
    "plot_timeseries",
    "analyze_per_sample",
    "load_checkpoint",
    "save_checkpoint",
    "upload_file_to_drive",
    # .validation
    "validate_model",
    # .training
    "train_one_epoch",
    # .config
    "TrainingConfig",
    "DataConfig",
    "TrainingState",
    "create_training_state",
]
