"""
Config-driven training and evaluation script for the Energy-Based Model (EBM).

This script serves as the main entry point for the training and evaluation
pipeline. It orchestrates the entire process, from parsing command-line
arguments and loading data to initializing the model and executing the
training loop.

The architecture is fully modular and configuration-driven:
- All parameters are defined in typed dataclasses in `src/ebm_training/config.py`
  and `src/energy_model/config/ebm_configs.py`.
- The main script (`main` function) populates these config objects from CLI arguments.
- Helper functions handle setup, data loading, and model creation.
- The core training and validation logic resides in the `ebm_training` package,
  which is called from the main loop.
- All dynamic state (metrics, progress) is managed in a `TrainingState` object.

Usage:
    python train_ebm_clean.py [options]
"""
import os
import warnings
warnings.filterwarnings("ignore")

import argparse
import json
import logging
import sys
import time
from pathlib import Path

import math
import numpy as np
import pandas as pd
import torch
from datasets import load_dataset
from torch.optim import Optimizer
from torch.utils.data import DataLoader
from tqdm import tqdm

# Make the repository root (the parent directory of 'src/') importable so the
# `from src.<package> import ...` statements below resolve when this script is
# run directly from 'scripts/'. 'src/' itself is also kept on the path (as it
# was before) for any modules that import siblings without the 'src.' prefix.
# Assumes the script lives in 'scripts/' and 'src/' is a sibling directory.
_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, os.path.abspath(os.path.join(_SCRIPT_DIR, '..', 'src')))
sys.path.insert(0, os.path.abspath(os.path.join(_SCRIPT_DIR, '..')))

from src.energy_model.models import EnergyModel
from src.energy_model.config import EBMConfig
from src.ebm_training import (
    TrainingConfig, DataConfig, TrainingState,
    PromptResponseDataset, collate,
    train_one_epoch, validate_model,
    create_training_state, calculate_averages_from_history,
    load_checkpoint, save_checkpoint,
    analyze_per_sample, plot_timeseries,
    upload_file_to_drive, set_seed,
)


def setup_logging(results_dir: str, verbose: bool = False) -> None:
    """Configure the root logger to write to the console and to a log file.

    Both handlers use the ``'%(asctime)s - %(levelname)s - %(message)s'``
    format. Any handlers already attached to the root logger are removed first
    (``force=True``), so calling this more than once, or after another library
    has configured logging, does not duplicate output.

    Args:
        results_dir (str): Directory where ``ebm_training.log`` is written.
            Created if it does not exist.
        verbose (bool, optional): If True, sets the logging level to DEBUG.
            Defaults to False (INFO).
    """
    os.makedirs(results_dir, exist_ok=True)
    log_file_path = Path(results_dir) / "ebm_training.log"

    level = logging.DEBUG if verbose else logging.INFO
    # Handlers must be passed *into* basicConfig: it only attaches the
    # formatter to handlers it installs itself. force=True replaces any
    # pre-existing root handlers; without it basicConfig would be a no-op
    # (level not applied) whenever the root logger was already configured.
    logging.basicConfig(
        level=level,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=[
            logging.StreamHandler(),
            logging.FileHandler(log_file_path, encoding="utf-8"),
        ],
        force=True,
    )


def parse_arguments(argv: "list[str] | None" = None) -> argparse.Namespace:
    """Build the command-line parser and parse the arguments.

    The arguments are organized into groups that correspond to the different
    configuration dataclasses (`DataConfig`, `TrainingConfig`, `EBMConfig`).
    Default values are read from freshly constructed instances of those
    dataclasses, so the config classes stay the single source of truth for
    defaults.

    After parsing, `--data-frac` must lie in (0, 1] and `--val-split` in
    (0, 1). Out-of-range values abort with a usage error here, instead of
    failing later inside dataset subsampling or splitting.

    Args:
        argv (list[str] | None, optional): Argument list to parse. Defaults to
            None, in which case argparse reads ``sys.argv[1:]``.

    Returns:
        argparse.Namespace: An object containing all parsed command-line arguments.
    """
    parser = argparse.ArgumentParser(description="Train the Energy-Based Model.")

    # Default config instances supply the argparse defaults.
    data_defaults = DataConfig()
    train_defaults = TrainingConfig()
    model_defaults = EBMConfig()

    # --- Data Arguments ---
    data_group = parser.add_argument_group('Data Configuration')
    data_group.add_argument('--dataset-type', type=str, default=data_defaults.dataset_type,
                            choices=["csv", "hf"], help="Type of dataset to load.")
    data_group.add_argument('--csv-path', type=str, default=data_defaults.csv_path,
                            help="Path to the primary data file, if dataset type is CSV.")
    data_group.add_argument('--hf-dataset-name', type=str, default=data_defaults.hf_dataset_name,
                            help="Name of the HuggingFace dataset to use, if dataset type is HF.")
    data_group.add_argument('--hf-dataset-config', type=str, default=data_defaults.hf_dataset_config,
                            help="Configuration of the HuggingFace dataset to use (e.g., 'all', 'medicine').")
    data_group.add_argument('--hf-dataset-split', type=str, default=data_defaults.hf_dataset_split,
                            help="Split of the HuggingFace dataset to use, if dataset type is HF.")
    data_group.add_argument('--gpt2-path', type=str, default=data_defaults.gpt2_path,
                            help="Path to CSV with separate GPT-2 responses for the HF dataset, if dataset type is HF.")
    data_group.add_argument('--prompt-col', type=str, default=data_defaults.prompt_col,
                            help="The column name for prompts in the dataset.")
    data_group.add_argument('--response-col', type=str, default=data_defaults.response_col,
                            help="The column name for the base LLM's responses ('answer' for CSV or 'chatgpt_answers' for HF dataset).")
    data_group.add_argument('--human-col', type=str, default=data_defaults.human_col,
                            help="The column name for the human responses.")
    data_group.add_argument('--gpt2-col', type=str, default=data_defaults.gpt2_col,
                            help="The column name for GPT2's responses.")
    data_group.add_argument('--data-frac', type=float, default=data_defaults.data_frac,
                            help="Fraction of the full dataset to use (0 < data_frac ≤ 1.0).")
    data_group.add_argument('--val-split', type=float, default=data_defaults.val_split,
                            help="Fraction of the selected data to use for validation.")

    # --- Training Arguments ---
    train_group = parser.add_argument_group('Training Configuration')
    train_group.add_argument('--epochs', type=int, default=train_defaults.epochs,
                             help="Number of training epochs.")
    train_group.add_argument('--batch-size', type=int, default=train_defaults.batch_size,
                             help="Training batch size.")
    train_group.add_argument('--lr', type=float, default=train_defaults.lr, help="Learning rate.")
    train_group.add_argument('--loss-strategy', type=str, default=train_defaults.loss_strategy,
                             choices=["infonce", "infonce_expanded", "sum", "weighted_sum", "sequential"],
                             help="The loss calculation and model update strategy to use.")
    train_group.add_argument('--margin', type=float, default=train_defaults.margin,
                             help="Margin for the contrastive loss in the 'sequential', 'sum', and 'weighted_sum' strategies.")
    train_group.add_argument('--temperature', type=float, default=train_defaults.temperature,
                             help="Temperature for the InfoNCE loss in 'infonce' and 'infonce_expanded' strategies.")
    train_group.add_argument('--off-context-weight', type=float, default=train_defaults.off_context_weight,
                             help="Weight for off-context negatives in 'weighted_sum' strategy.")
    train_group.add_argument('--k-candidates', type=int, default=train_defaults.k_candidates,
                             help="Candidates for hard negative mining.")

    # --- System and Resuming ---
    sys_group = parser.add_argument_group('System and Saving')
    sys_group.add_argument('--seed', type=int, default=train_defaults.seed,
                           help="Random seed for reproducibility.")
    sys_group.add_argument('--device', type=str, default=train_defaults.device,
                           help="Device to use (e.g. 'auto', 'cuda', 'cpu', etc.).")
    sys_group.add_argument('--num-workers', type=int, default=train_defaults.num_workers,
                           help="Number of worker processes for data loading.")
    sys_group.add_argument('--log-verbose', action='store_true', default=train_defaults.log_verbose,
                           help='Enable verbose logging (level = "DEBUG").')
    sys_group.add_argument('--results-dir', type=str, default=train_defaults.results_dir,
                           help="Directory to save all outputs (logs, models, plots).")
    sys_group.add_argument('--save-every-n-epochs', type=int, default=train_defaults.save_every_n_epochs,
                           help="Save a checkpoint every N epochs. Set to 0 to disable epoch-based saving.")
    sys_group.add_argument('--save-every-n-batches', type=int, default=train_defaults.save_every_n_batches,
                           help="Save a checkpoint every N batches. Set to 0 to disable batch-based saving.")
    sys_group.add_argument('--save-best-per-method', action='store_true', default=train_defaults.save_best_per_method,
                           help="If set, save a separate best model checkpoint for each sampling method's accuracy.")
    sys_group.add_argument('--resume-from-checkpoint', type=str, default=train_defaults.resume_from_checkpoint,
                           help="Path to checkpoint file to resume training from.")
    sys_group.add_argument("--no-final-analysis", dest="run_final_analysis", action="store_false",
                           help="Skip the final per-sample analysis and CSV generation after training.")
    sys_group.add_argument('--evaluate-only', type=str, default=train_defaults.evaluate_only, metavar="PATH_TO_MODEL",
                           help="Skip training and run a final evaluation on the model at the specified path.")
    sys_group.add_argument('--upload-to-gdrive', action='store_true', default=train_defaults.upload_to_gdrive,
                           help="Enable uploading checkpoints and models to Google Drive.")
    sys_group.add_argument('--gdrive-folder-id', type=str, default=train_defaults.gdrive_folder_id,
                           help="The ID of the Google Drive folder to upload files to.")
    sys_group.add_argument('--gdrive-creds-path', type=str, default=train_defaults.gdrive_creds_path,
                           help="Path to the PyDrive2 credentials file (creds.dat).")
    # `--no-final-analysis` is a store_false flag, so its default must come from the config explicitly.
    parser.set_defaults(run_final_analysis=train_defaults.run_final_analysis, save_best_per_method=train_defaults.save_best_per_method)

    # --- EBM Model Arguments ---
    model_group = parser.add_argument_group('EBM Model Configuration')
    model_group.add_argument('--text-encoder-model-type', type=str, default=model_defaults.text_encoder_model_type,
                             choices=["SentenceBERT", "BERT_CLS"], help="Type of sentence embedding model to use.")
    model_group.add_argument('--placeholder-type', type=str, default=model_defaults.placeholder_type,
                             choices=["learnable", "static"], help="Type of placeholder embedding for padding sentences.")
    model_group.add_argument('--placeholder-init-type', type=str, default=model_defaults.placeholder_init_type,
                             choices=["random", "zero"], help="Initialization for learnable placeholders.")
    model_group.add_argument('--n-sentences', type=int, default=model_defaults.n_sentences,
                             help="Number of sentences to normalize text to by padding or truncating.")
    model_group.add_argument('--d-model', type=int, default=model_defaults.d_model,
                             help="The embedding dimension of the model.")
    model_group.add_argument('--self-attention-n-layers', type=int, default=model_defaults.self_attention_n_layers,
                             help="Number of self-attention encoder layers.")
    model_group.add_argument('--cross-attention-n-layers', type=int, default=model_defaults.cross_attention_n_layers,
                             help="Number of cross-attention encoder layers.")
    model_group.add_argument('--attention-n-heads', type=int, default=model_defaults.attention_n_heads,
                             help="Number of attention heads.")
    model_group.add_argument('--dropout-rate', type=float, default=model_defaults.dropout_rate,
                             help="Dropout rate for the EBM's attention layers.")
    model_group.add_argument('--energy-head-mlp-layers', type=int, default=model_defaults.energy_head_mlp_layers,
                             help="Number of MLP layers in the final energy head.")
    model_group.add_argument('--energy-head-hidden-factor', type=int, default=model_defaults.energy_head_hidden_factor,
                             help="Multiplier for the hidden dimension size in the energy head MLP.")
    model_group.add_argument('--energy-head-pooling-type', type=str, default=model_defaults.energy_head_pooling_type,
                             choices=["flatten", "attention"], help="Pooling strategy for the energy head.")
    model_group.add_argument('--activation-fn', type=str, default=model_defaults.activation_fn,
                             choices=["ReLU", "GELU"], help="Activation function for MLP layers.")

    args = parser.parse_args(argv)

    # Enforce the ranges stated in the help text. parser.error() prints usage
    # and exits with status 2, consistent with other argparse errors. The
    # `not (...)` form also rejects NaN.
    if not 0.0 < args.data_frac <= 1.0:
        parser.error(f"--data-frac must satisfy 0 < data_frac <= 1.0, got {args.data_frac}")
    if not 0.0 < args.val_split < 1.0:
        parser.error(f"--val-split must satisfy 0 < val_split < 1.0, got {args.val_split}")

    return args


def create_configs_from_args(
    args: argparse.Namespace
) -> tuple[DataConfig, TrainingConfig, EBMConfig]:
    """Translate parsed CLI arguments into the three typed config objects.

    Each config is built from an explicit list of argument names. Every listed
    name is read from `args` and passed as a keyword argument of the same name
    to the corresponding dataclass.

    Args:
        args (argparse.Namespace): Output of the CLI parser.

    Returns:
        tuple[DataConfig, TrainingConfig, EBMConfig]: The data, training and
            model configurations, in that order.
    """
    data_fields = (
        "dataset_type", "csv_path",
        "hf_dataset_name", "hf_dataset_config", "hf_dataset_split",
        "gpt2_path",
        "prompt_col", "response_col", "human_col", "gpt2_col",
        "data_frac", "val_split",
    )
    training_fields = (
        "batch_size", "epochs", "lr",
        "loss_strategy", "margin", "temperature", "off_context_weight", "k_candidates",
        "seed", "device", "num_workers", "log_verbose", "results_dir",
        "save_every_n_epochs", "save_every_n_batches", "save_best_per_method",
        "resume_from_checkpoint", "run_final_analysis", "evaluate_only",
        "upload_to_gdrive", "gdrive_folder_id", "gdrive_creds_path",
    )
    model_fields = (
        "text_encoder_model_type", "placeholder_type", "placeholder_init_type",
        "n_sentences", "d_model",
        "self_attention_n_layers", "cross_attention_n_layers", "attention_n_heads",
        "dropout_rate",
        "energy_head_mlp_layers", "energy_head_hidden_factor", "energy_head_pooling_type",
        "activation_fn",
    )

    def _kwargs_for(names):
        # Same attribute reads as `args.<name>`, gathered into kwargs.
        return {name: getattr(args, name) for name in names}

    data_cfg = DataConfig(**_kwargs_for(data_fields))
    train_cfg = TrainingConfig(**_kwargs_for(training_fields))
    model_cfg = EBMConfig(**_kwargs_for(model_fields))

    return data_cfg, train_cfg, model_cfg


def setup_device(train_cfg: TrainingConfig) -> torch.device:
    """Resolve and log the torch device requested by the training config.

    A value of ``'auto'`` selects CUDA when it is available and falls back to
    the CPU otherwise. Any other value is handed directly to ``torch.device``.

    Args:
        train_cfg (TrainingConfig): Config whose ``device`` field is read.

    Returns:
        torch.device: The device to run on.
    """
    requested = train_cfg.device

    if requested != 'auto':
        chosen = torch.device(requested)
        logging.info(f"Using specified device: {chosen}")
        return chosen

    if not torch.cuda.is_available():
        chosen = torch.device('cpu')
        logging.info("CUDA not available, using CPU.")
        return chosen

    chosen = torch.device('cuda')
    logging.info(f"Using CUDA device: {torch.cuda.get_device_name()}")
    return chosen


def create_datasets(
    data_cfg: DataConfig,
    train_cfg: TrainingConfig
) -> tuple:
    """Load, normalize, optionally merge/subsample, and split the dataset.

    Steps:
    1. Load either a local CSV or a Hugging Face Hub dataset into a
       `datasets` object.
    2. Lower-case all column names.
    3. If `gpt2_path` is set, add a `gpt2_col` column by looking up each
       example's prompt in that CSV. Missing prompts get "".
    4. If `data_frac < 1`, shuffle with `train_cfg.seed` and keep that
       fraction.
    5. Split into train/validation with `val_split` and `train_cfg.seed`.

    Args:
        data_cfg (DataConfig): Configuration for the dataset source and structure.
        train_cfg (TrainingConfig): Supplies `seed` and `num_workers`.

    Returns:
        tuple: ``(train_ds, val_ds, full_ds)`` as `PromptResponseDataset`
            objects. `full_ds` wraps the (possibly subsampled) data before
            splitting.

    Raises:
        ValueError: If `dataset_type` is unsupported, or if subsampling leaves
            fewer than 2 examples (too few to split).
    """
    if data_cfg.dataset_type == "csv":
        logging.info(f"Loading data from CSV: {data_cfg.csv_path}")
        # A single `data_files` entry is exposed by `datasets` as one split
        # named "train". `hf_dataset_split` applies only to Hub datasets.
        hf_full = load_dataset("csv", data_files=data_cfg.csv_path, split="train")
    elif data_cfg.dataset_type == "hf":
        logging.info(f"Loading data from Hub: {data_cfg.hf_dataset_name}")
        # SECURITY NOTE(review): trust_remote_code=True lets the Hub dataset
        # run its own loading code locally; only use trusted dataset names.
        hf_full = load_dataset(
            data_cfg.hf_dataset_name, data_cfg.hf_dataset_config,
            split=data_cfg.hf_dataset_split, trust_remote_code=True
        )
    else:
        raise ValueError(f"Unsupported data type: {data_cfg.dataset_type}")

    hf_full = hf_full.rename_columns({col: col.lower() for col in hf_full.column_names})
    logging.info(f"Standardized columns to lowercase: {hf_full.column_names}")

    if data_cfg.gpt2_path:
        logging.info(f"Merging with GPT-2 data from: {data_cfg.gpt2_path}")
        gpt2_df = pd.read_csv(data_cfg.gpt2_path)
        gpt2_df.columns = gpt2_df.columns.str.lower()
        # prompt -> GPT-2 response. For duplicate prompts the last row wins;
        # missing responses become "".
        gpt2_lookup = pd.Series(
            gpt2_df[data_cfg.gpt2_col].values, index=gpt2_df[data_cfg.prompt_col]
        ).fillna("").to_dict()

        hf_full = hf_full.map(
            lambda ex: {data_cfg.gpt2_col: gpt2_lookup.get(ex.get(data_cfg.prompt_col, ""), "")},
            num_proc=train_cfg.num_workers if train_cfg.num_workers > 0 else None
        )

    if data_cfg.data_frac < 1.0:
        logging.info(f"Subsampling data to {data_cfg.data_frac*100:.2f}%...")
        num_samples = int(len(hf_full) * data_cfg.data_frac)
        # train_test_split needs at least one train and one validation example.
        if num_samples < 2:
            raise ValueError(
                f"data_frac={data_cfg.data_frac} keeps only {num_samples} of "
                f"{len(hf_full)} examples; at least 2 are required to split."
            )
        hf_full = hf_full.shuffle(seed=train_cfg.seed).select(range(num_samples))

    split = hf_full.train_test_split(test_size=data_cfg.val_split, seed=train_cfg.seed)

    train_ds = PromptResponseDataset(split["train"], data_config=data_cfg)
    val_ds = PromptResponseDataset(split["test"], data_config=data_cfg)
    full_ds = PromptResponseDataset(hf_full, data_config=data_cfg)

    logging.info(f"Data loaded successfully: {len(train_ds)} train samples, {len(val_ds)} val samples.")

    return train_ds, val_ds, full_ds


def create_data_loaders(
    train_ds,
    val_ds,
    train_cfg: TrainingConfig
) -> tuple[DataLoader, DataLoader]:
    """Wrap the training and validation datasets in PyTorch DataLoaders.

    Both loaders share the batch size, worker count and `collate` function
    from the config. Only the training loader shuffles. Host-memory pinning is
    enabled only when CUDA is available and the configured device is not
    explicitly ``'cpu'``, because pinned memory only speeds up host-to-GPU
    copies.

    Args:
        train_ds (Dataset): The training dataset.
        val_ds (Dataset): The validation dataset.
        train_cfg (TrainingConfig): Supplies `batch_size`, `num_workers`
            and `device`.

    Returns:
        tuple[DataLoader, DataLoader]: The training and validation DataLoaders.
    """
    pin_memory = torch.cuda.is_available() and str(train_cfg.device) != "cpu"
    loader_kwargs = dict(
        batch_size=train_cfg.batch_size,
        collate_fn=collate,
        num_workers=train_cfg.num_workers,
        pin_memory=pin_memory,
    )
    train_loader = DataLoader(train_ds, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_ds, shuffle=False, **loader_kwargs)
    logging.info(f"DataLoaders created: {len(train_loader)} train batches, {len(val_loader)} val batches.")

    return train_loader, val_loader


def create_model(
    model_cfg: EBMConfig,
    train_cfg: TrainingConfig,
    device: torch.device
) -> tuple[EnergyModel, Optimizer]:
    """Build the EnergyModel on `device` together with an AdamW optimizer.

    Args:
        model_cfg (EBMConfig): Architecture configuration for the model.
        train_cfg (TrainingConfig): Supplies the learning rate (`lr`).
        device (torch.device): Device the model's parameters are moved to.

    Returns:
        tuple[EnergyModel, Optimizer]: The model and an AdamW optimizer over
            all of its parameters.
    """
    energy_model = EnergyModel(model_cfg)
    energy_model = energy_model.to(device)

    adamw = torch.optim.AdamW(energy_model.parameters(), lr=train_cfg.lr)

    n_params = sum(param.numel() for param in energy_model.parameters())
    logging.info(f"Model created with {n_params:,} parameters.")

    return energy_model, adamw


def log_and_save_evaluation_summary(
    batch_accuracy_history: list[dict],
    dataset_name: str,
    results_dir: str,
) -> None:
    """Compute, log and save the final accuracy summary of an evaluation run.

    Used by the `evaluate-only` mode.
    1. Averages the per-batch accuracies with `calculate_averages_from_history`.
    2. Logs a summary with the 'overall' accuracy first (if present), followed
       by every other method in alphabetical order.
    3. Writes the same text to ``ebm_eval_<dataset_name>.log`` in
       `results_dir`, where the name is lower-cased and spaces become
       underscores.

    Args:
        batch_accuracy_history (list[dict]): One dict of accuracy metrics per
            batch.
        dataset_name (str): Name of the evaluated split (e.g. "Validation Set").
            Used in the heading and the output file name.
        results_dir (str): Directory for the summary file. Created if missing.
    """
    avg_accuracies = calculate_averages_from_history(batch_accuracy_history)

    summary_lines = [f"\n--- {dataset_name} Evaluation Summary ---"]
    if 'overall' in avg_accuracies:
        summary_lines.append(f"Overall Accuracy: {avg_accuracies['overall']:.2f}%")
    for method in sorted(k for k in avg_accuracies if k != 'overall'):
        summary_lines.append(f"  {method:<25s} | Accuracy: {avg_accuracies[method]:.2f}%")

    summary_text = "\n".join(summary_lines)
    logging.info(summary_text + "\n")

    filename = f"ebm_eval_{dataset_name.lower().replace(' ', '_')}.log"
    eval_log_path = Path(results_dir) / filename
    eval_log_path.parent.mkdir(parents=True, exist_ok=True)
    # lstrip() drops the leading newline, which is only there for console layout.
    with open(eval_log_path, "w", encoding="utf-8") as f:
        f.write(summary_text.lstrip())
    logging.info(f"Saved evaluation summary to: {eval_log_path}")


def _latest_metric(history: dict, method: str) -> float:
    """Return the last value recorded for `method` in `history`, or NaN if absent or empty."""
    values = history.get(method)
    if values is None or len(values) == 0:
        return float('nan')
    return values[-1]


def log_and_save_epoch_end(
    epoch: int,
    epoch_start_time: float,
    model: EnergyModel,
    opt: Optimizer,
    state: TrainingState,
    train_cfg: TrainingConfig,
) -> None:
    """Handle best-model saving, checkpointing and logging at the end of an epoch.

    Runs after validation, in this order:
    1.  For each method in `state.best_acc`, if its latest validation accuracy
        is not NaN and beats the stored best, update `state.best_acc`. For
        'overall', or for every method when `save_best_per_method` is set,
        also save ``best_model_<method>.pt`` (and optionally upload it).
    2.  Save a full training checkpoint every `save_every_n_epochs` epochs and
        on the final epoch, when that interval is > 0 (optionally uploaded).
    3.  Log a multi-line summary (overall header plus one line per method) and
        append it to ``ebm_training_epochs.log``.

    Methods with no recorded metrics are reported as NaN and never counted as
    a new best.

    Args:
        epoch (int): The just-completed epoch number.
        epoch_start_time (float): The `time.time()` timestamp from the start of the epoch.
        model (EnergyModel): The model being trained.
        opt (Optimizer): The optimizer, required for saving checkpoints.
        state (TrainingState): The object holding all dynamic training state.
        train_cfg (TrainingConfig): The configuration object for training parameters.
    """
    # Measure before the saving/upload I/O below so it does not inflate the
    # reported epoch time.
    epoch_duration = time.time() - epoch_start_time
    results_dir = Path(train_cfg.results_dir)

    # 1) New best validation accuracies -> best-model weights.
    # Only values are reassigned here (no keys added), so iterating the dict is safe.
    for method in state.best_acc:
        latest_acc = _latest_metric(state.epoch_avg_accuracy_val, method)
        if np.isnan(latest_acc) or not latest_acc > state.best_acc[method]:
            continue
        state.best_acc[method] = latest_acc
        logging.info(f"New best val acc for '{method}': {latest_acc:.2f}%")
        if method == "overall" or train_cfg.save_best_per_method:
            model_save_path = results_dir / f"best_model_{method}.pt"
            torch.save(model.state_dict(), model_save_path)
            logging.info(f"Saved best model for '{method}' to: {model_save_path}")
            if train_cfg.upload_to_gdrive:
                upload_file_to_drive(
                    train_cfg.gdrive_folder_id, model_save_path,
                    train_cfg.gdrive_creds_path, train_cfg.results_dir)

    # 2) Periodic checkpoint (always on the final epoch when enabled).
    if train_cfg.save_every_n_epochs > 0 and (epoch % train_cfg.save_every_n_epochs == 0 or epoch == train_cfg.epochs):
        checkpoint_path = save_checkpoint(
            epoch, state.global_step, model, opt,
            state.epoch_avg_energy_train, state.epoch_avg_energy_val,
            state.epoch_avg_loss_train, state.epoch_avg_loss_val,
            state.epoch_avg_accuracy_train, state.epoch_avg_accuracy_val,
            state.batch_loss_train, state.batch_loss_val,
            state.batch_accuracy_train, state.batch_accuracy_val,
            state.batch_energy_train, state.batch_energy_val,
            state.best_acc, state.encoding_cache, train_cfg.results_dir
        )
        if train_cfg.upload_to_gdrive:
            upload_file_to_drive(
                train_cfg.gdrive_folder_id, checkpoint_path,
                train_cfg.gdrive_creds_path, train_cfg.results_dir)

    # 3) Multi-line epoch summary: console + appended to the epochs log.
    epochs_log_path = results_dir / "ebm_training_epochs.log"
    header = (
        f"Epoch {epoch:03d} | Time: {epoch_duration:5.1f}s | "
        f"Overall Loss(T/V): {_latest_metric(state.epoch_avg_loss_train, 'overall'):7.4f}"
        f"/{_latest_metric(state.epoch_avg_loss_val, 'overall'):7.4f} | "
        f"Overall Acc(T/V): {_latest_metric(state.epoch_avg_accuracy_train, 'overall'):6.2f}%"
        f"/{_latest_metric(state.epoch_avg_accuracy_val, 'overall'):6.2f}% "
        f"(Best: {state.best_acc['overall']:6.2f}%)"
    )
    log_lines = [header]

    for method in sorted(state.best_acc.keys()):
        if method == "overall":
            continue
        train_loss = _latest_metric(state.epoch_avg_loss_train, method)
        val_loss = _latest_metric(state.epoch_avg_loss_val, method)
        train_acc = _latest_metric(state.epoch_avg_accuracy_train, method)
        val_acc = _latest_metric(state.epoch_avg_accuracy_val, method)
        log_lines.append(
            f"  - {method:<25s} | Loss(T/V): {train_loss:7.4f}/{val_loss:7.4f} | Acc(T/V): {train_acc:6.2f}%/{val_acc:6.2f}%"
        )

    final_log_entry = "\n".join(log_lines)
    logging.info("\n" + final_log_entry + "\n")
    with open(epochs_log_path, "a", encoding="utf-8") as f:
        f.write(final_log_entry + "\n\n")
    logging.info(f"Saved epoch summary to: {epochs_log_path}")


def _plot_batch_history(history, results_dir, title: str, metric: str) -> None:
    """Plot a per-batch metric history with one series per column.

    ``history`` is converted with ``pd.DataFrame`` once and every resulting
    column is passed to ``plot_timeseries`` as its own series, using
    "Batch #" as the x-axis label.

    Args:
        history: Batch-level metric history stored on the training state.
        results_dir: Directory the plot is written to.
        title: Plot title prefix (e.g. "Eval Train").
        metric: Metric label (e.g. "Energy", "Loss", "Accuracy").
    """
    frame = pd.DataFrame(history)
    plot_timeseries(
        {col: frame[col].tolist() for col in frame.columns},
        results_dir, title, metric, "Batch #")


def _last_value(history, key: str) -> float:
    """Return the most recent value recorded under ``key`` in ``history``.

    Returns NaN when the key is missing or its series is empty. This can
    happen, for example, when no epoch ran in this invocation. The fallback
    keeps the final report from raising and matches the ``float('nan')``
    fallback used by the per-epoch summary.

    Args:
        history: Mapping of method name to a sequence of per-epoch values.
        key: Method name to look up.

    Returns:
        The last recorded value, or ``float('nan')`` if none exists.
    """
    values = history.get(key)
    if values is None or len(values) == 0:
        return float("nan")
    return values[-1]


def _evaluate_split(model, loader, full_ds, device, train_cfg, data_cfg, state,
                    split_name: str, plot_title: str, summary_name: str) -> None:
    """Evaluate ``model`` on one data split, then plot and summarise the results.

    ``validate_model`` is called with ``is_eval_only=True``. The plots and the
    summary then read ``state.batch_*_val``; presumably ``validate_model``
    appends this split's batch metrics there (verify against its definition).

    Args:
        split_name: Human-readable split name used in log messages.
        plot_title: Title prefix passed to the plots.
        summary_name: Label passed to ``log_and_save_evaluation_summary``.
    """
    logging.info(f"Running evaluation on {split_name}...")
    validate_model(model, loader, full_ds, device, train_cfg, data_cfg, state,
                   is_eval_only=True)

    logging.info(f"Plotting evaluation summaries for {split_name}...")
    for history, metric in ((state.batch_energy_val, "Energy"),
                            (state.batch_loss_val, "Loss"),
                            (state.batch_accuracy_val, "Accuracy")):
        _plot_batch_history(history, train_cfg.results_dir, plot_title, metric)
    log_and_save_evaluation_summary(state.batch_accuracy_val, summary_name, train_cfg.results_dir)


def _run_evaluation_only(model, train_ds, val_ds, full_ds, train_loader, val_loader,
                         device, train_cfg, data_cfg, state) -> None:
    """Load the weights at ``train_cfg.evaluate_only`` and evaluate both splits.

    The training set is evaluated first, then the validation set. If
    ``train_cfg.run_final_analysis`` is set, per-sample analysis follows.
    """
    logging.info(f"Running evaluation only mode for model: {train_cfg.evaluate_only}")
    model.load_state_dict(torch.load(train_cfg.evaluate_only, map_location=device))
    logging.info("Model state loaded.")

    _evaluate_split(model, train_loader, full_ds, device, train_cfg, data_cfg, state,
                    "training set", "Eval Train", "Train Set")

    # Reset the batch-level buffers so the validation-set plots and summary
    # do not also contain the training-set batches recorded above.
    state.batch_energy_val.clear()
    state.batch_loss_val.clear()
    state.batch_accuracy_val.clear()

    _evaluate_split(model, val_loader, full_ds, device, train_cfg, data_cfg, state,
                    "validation set", "Eval Val", "Val Set")

    if train_cfg.run_final_analysis:
        logging.info("Running final per-sample analysis...")
        analyze_per_sample(model, val_ds, device, train_cfg, data_cfg, state.encoding_cache, "eval_val")
        analyze_per_sample(model, train_ds, device, train_cfg, data_cfg, state.encoding_cache, "eval_train")

    logging.info("Evaluation complete.")


def _plot_final_summaries(state, results_dir) -> None:
    """Plot epoch-level and batch-level histories for every metric and split.

    A history is plotted only if the corresponding ``epoch_avg_*`` or
    ``batch_*`` attribute exists on ``state``.
    """
    logging.info("Plotting final training summaries...")
    for metric in ["energy", "loss", "accuracy"]:
        for split in ["train", "val"]:
            epoch_key = f"epoch_avg_{metric}_{split}"
            batch_key = f"batch_{metric}_{split}"
            if hasattr(state, epoch_key):
                plot_timeseries(
                    getattr(state, epoch_key),
                    results_dir, split.title(), metric.title(), "Epoch #")
            if hasattr(state, batch_key):
                _plot_batch_history(getattr(state, batch_key), results_dir,
                                    split.title(), metric.title())


def _write_final_stats(state, model, total_time: float, results_dir) -> None:
    """Log and write the end-of-training statistics report.

    The report covers total time, parameter count, final per-method losses,
    and final per-method accuracies (with best val accuracy). "overall" is
    listed first. Missing or empty histories are reported as NaN rather than
    raising. The report is written to
    ``<results_dir>/ebm_training_final_stats.log``, overwriting any existing
    file.
    """
    def overall_first(name):
        # Sort "overall" ahead of the individual methods, then alphabetically.
        return (name != 'overall', name)

    summary_lines = [
        "--- Final Training Statistics ---",
        f"Total training time: {total_time:.1f}s ({total_time/60:.1f} minutes)",
        f"Model parameters: {sum(p.numel() for p in model.parameters()):,}",
        "\n--- Final Per-Method Average Losses ---",
    ]
    for m in sorted(state.epoch_avg_loss_train, key=overall_first):
        train_loss = _last_value(state.epoch_avg_loss_train, m)
        val_loss = _last_value(state.epoch_avg_loss_val, m)
        summary_lines.append(f"  {m:<25s} Train: {train_loss:.4f} | Val: {val_loss:.4f}")

    summary_lines.append("\n--- Final Per-Method Average Accuracies ---")
    for m in sorted(state.epoch_avg_accuracy_train, key=overall_first):
        train_acc = _last_value(state.epoch_avg_accuracy_train, m)
        val_acc = _last_value(state.epoch_avg_accuracy_val, m)
        best_acc = state.best_acc.get(m, 0.0)
        summary_lines.append(f"  {m:<25s} Train: {train_acc:.2f}% | Val: {val_acc:.2f}% (Best Val: {best_acc:.2f}%)")

    final_stats_summary = "\n".join(summary_lines)
    logging.info("\n" + final_stats_summary + "\n")

    stats_log_path = Path(results_dir) / "ebm_training_final_stats.log"
    with open(stats_log_path, "w") as f:
        f.write(final_stats_summary)
    logging.info(f"Saved full training stats to: {stats_log_path}")


def main():
    """Main script entry point to orchestrate the training and evaluation pipeline.

    Steps: parse CLI args, set up logging/seed/device, dump the config, build
    the datasets, loaders, model and optimizer, then do one of the following:

    - Evaluate a saved model (``--evaluate_only``-style config) and return.
    - Train, optionally resuming from a checkpoint, then plot, analyse, save
      the final model and write a final statistics report.
    """
    args = parse_arguments()

    # 1) Setup
    setup_logging(args.results_dir, args.log_verbose)
    logging.info("Starting EBM training script...")

    data_cfg, train_cfg, model_cfg = create_configs_from_args(args)
    set_seed(train_cfg.seed)
    device = setup_device(train_cfg)

    results_dir = Path(train_cfg.results_dir)
    os.makedirs(train_cfg.results_dir, exist_ok=True)
    logging.info(f"Saving config to: {results_dir}")
    config_dict = {
        'data_config': data_cfg.__dict__,
        'model_config': model_cfg.__dict__,
        'training_config': train_cfg.__dict__
    }
    with open(results_dir / 'config.json', 'w') as f:
        json.dump(config_dict, f, indent=2)

    # 2) Data Loading
    logging.info("Creating datasets and dataloaders...")
    train_ds, val_ds, full_ds = create_datasets(data_cfg, train_cfg)
    train_loader, val_loader = create_data_loaders(train_ds, val_ds, train_cfg)

    # 3) Model and Optimizer
    logging.info("Creating model and optimizer...")
    model, opt = create_model(model_cfg, train_cfg, device)

    # 4) Methods Initialization
    sampling_methods = ["positive", "sentence_masking", "token_masking", "off_context", "human", "gpt2",
                        "sentence_masking_prompt", "token_masking_prompt", "off_context_prompt"]
    if train_cfg.loss_strategy == "infonce_expanded":
        sampling_methods.append("off_context_batch_avg")
    state = create_training_state(sampling_methods)

    # 5) Evaluation-Only Mode
    if train_cfg.evaluate_only:
        _run_evaluation_only(model, train_ds, val_ds, full_ds, train_loader, val_loader,
                             device, train_cfg, data_cfg, state)
        return

    # 6) Training Mode Initialization
    if train_cfg.resume_from_checkpoint:
        logging.info(f"Resuming training from: {train_cfg.resume_from_checkpoint}")
        (state.start_epoch, state.global_step, state.epoch_avg_energy_train,
         state.epoch_avg_energy_val, state.epoch_avg_loss_train, state.epoch_avg_loss_val,
         state.epoch_avg_accuracy_train, state.epoch_avg_accuracy_val,
         state.batch_loss_train, state.batch_loss_val, state.batch_accuracy_train,
         state.batch_accuracy_val, state.batch_energy_train, state.batch_energy_val,
         state.best_acc, state.encoding_cache
        ) = load_checkpoint(train_cfg.resume_from_checkpoint, model, opt, device)

        # A global step that is not a multiple of the epoch length means the
        # checkpoint was taken mid-epoch: resume from that batch. Otherwise the
        # checkpointed epoch finished, so continue with the next one.
        batches_per_epoch = len(train_loader)
        if state.global_step > 0 and state.global_step % batches_per_epoch != 0:
            state.resume_batch_idx = state.global_step % batches_per_epoch
            logging.info(f"Resuming training from epoch {state.start_epoch}, batch {state.resume_batch_idx + 1}.")
        elif state.global_step > 0:
            state.start_epoch += 1
            logging.info(f"Resuming training from the start of epoch {state.start_epoch}.")

    # 7) Main Training Loop
    logging.info(f"Starting training from epoch {state.start_epoch} for {train_cfg.epochs} epochs...")
    start_time = time.time()

    for epoch in range(state.start_epoch, train_cfg.epochs + 1):
        logging.info(f"Starting epoch {epoch}/{train_cfg.epochs}...")

        epoch_start_time = time.time()

        train_one_epoch(
            epoch, model, train_loader, full_ds, device, train_cfg, data_cfg, opt, state
        )
        # After the first resumed epoch, all subsequent epochs start from the beginning
        state.resume_batch_idx = 0

        logging.info("Running validation...")
        validate_model(
            model, val_loader, full_ds, device, train_cfg, data_cfg, state,
            epoch=epoch,
        )

        log_and_save_epoch_end(
            epoch, epoch_start_time, model, opt, state, train_cfg
        )

    # 8) Finalization
    total_time = time.time() - start_time
    logging.info(f"Training finished in {total_time / 60:.2f} minutes.")

    _plot_final_summaries(state, train_cfg.results_dir)

    # Run and save final per-sample analysis
    if train_cfg.run_final_analysis:
        logging.info("Running final per-sample analysis...")
        analyze_per_sample(model, train_ds, device, train_cfg, data_cfg, state.encoding_cache, "train")
        analyze_per_sample(model, val_ds, device, train_cfg, data_cfg, state.encoding_cache, "val")

    # Save the final model
    final_model_path = results_dir / "final_model.pt"
    torch.save(model.state_dict(), final_model_path)
    logging.info(f"Saved final model to: {final_model_path}")
    if train_cfg.upload_to_gdrive:
        upload_file_to_drive(train_cfg.gdrive_folder_id, final_model_path,
                             train_cfg.gdrive_creds_path, train_cfg.results_dir)

    # Save final training statistics summary
    _write_final_stats(state, model, total_time, train_cfg.results_dir)

    logging.info("EBM training script completed.")


# Run the full pipeline only when executed as a script, not when imported.
if __name__ == "__main__":
    main()