"""Configuration classes for EBM training.

This module defines the dataclasses used to manage the static configuration
and the mutable state of the training and evaluation pipeline. Concerns are
split into three dataclasses:
- TrainingConfig: Parameters that control the training loop, checkpointing,
  and system settings.
- DataConfig: Parameters that define the data source and its structure.
- TrainingState: A container for all mutable state, such as metric histories,
  best-accuracy records, an encoding cache, and progress counters.

The `create_training_state` factory builds a fresh `TrainingState` whose
history dictionaries are keyed by the sampling methods in use.
"""
from __future__ import annotations

from collections.abc import Iterable
from dataclasses import dataclass, field
from typing import Any


@dataclass
class DataConfig:
    """Configuration for dataset source, structure, and processing.

    Attributes:
        dataset_type: The type of dataset to load ('csv' or 'hf').
        csv_path: The local file path to the dataset if `dataset_type` is 'csv'.
        hf_dataset_name: The name of the dataset on the Hugging Face Hub if
            `dataset_type` is 'hf'.
        hf_dataset_config: The Hugging Face dataset configuration name
            (presumably the subset passed to the loader alongside
            `hf_dataset_name`; defaults to 'all').
        hf_dataset_split: The split to use from the Hugging Face dataset.
        gpt2_path: Optional path to a CSV containing pre-computed GPT-2 responses
            to merge with the main dataset.
        prompt_col: The column name for prompts in the dataset.
        response_col: The column name for the golden responses.
        human_col: The column name for human-written negative responses.
        gpt2_col: The column name for GPT-2 generated negative responses.
        data_frac: The fraction of the total dataset to use (0.0 to 1.0).
        train_split: Nominal training fraction. Marked as not used; it is kept
            only so the field exists on the config.
        val_split: The fraction of the selected data to set aside for validation.
        test_split: Nominal test fraction. Marked as not used; it is kept only
            so the field exists on the config.

    All column-name fields (`prompt_col`, `response_col`, `human_col`,
    `gpt2_col`) are normalized to lowercase after initialization.
    """
    # Dataset source
    dataset_type: str = "csv"
    csv_path: str | None = None
    hf_dataset_name: str = "Hello-SimpleAI/HC3"
    hf_dataset_config: str = "all"
    hf_dataset_split: str = "train"
    gpt2_path: str | None = None

    # Column mapping
    prompt_col: str = "question"
    response_col: str = "answer"
    human_col: str = "human_answers"
    gpt2_col: str = "gpt2"

    # Data processing
    # Use a float literal so the default matches the `float` annotation.
    data_frac: float = 1.0
    train_split: float = 0.8  # Not used
    val_split: float = 0.1
    test_split: float = 0.1  # Not used

    def __post_init__(self):
        """Ensure all column name fields are lowercase after initialization."""
        self.prompt_col = self.prompt_col.lower()
        self.response_col = self.response_col.lower()
        self.human_col = self.human_col.lower()
        self.gpt2_col = self.gpt2_col.lower()


@dataclass
class TrainingConfig:
    """Static settings for the training loop, the runtime system, and outputs.

    Attributes are listed in the same order in which they are declared (which
    is also the positional order accepted by the generated ``__init__``).

    Attributes:
        batch_size: Number of samples in each training batch.
        epochs: Total number of epochs to train for.
        lr: Learning rate handed to the AdamW optimizer.
        loss_strategy: Name of the loss computation / model update strategy.
        margin: Margin used by the margin-based contrastive losses.
        temperature: Temperature used by the InfoNCE loss.
        off_context_weight: Multiplier applied to off-context negative samples
            when the 'weighted_sum' loss strategy is selected.
        k_candidates: How many negative candidates to generate for the hard
            negative mining strategies.
        seed: Random seed used for reproducibility.
        device: Target device: 'auto', 'cpu', or 'cuda'.
        num_workers: Number of DataLoader worker processes.
        log_verbose: When True, logging is switched to the DEBUG level.
        results_dir: Directory that receives every output (logs, models, plots).
        save_every_n_epochs: When > 0, write a full checkpoint every N epochs.
        save_every_n_batches: When > 0, write a full checkpoint every N batches.
        save_best_per_method: When True, write a separate model checkpoint
            whenever any sampling method reaches a new best validation accuracy.
        resume_from_checkpoint: Checkpoint file from which to resume training.
        run_final_analysis: When True, a detailed per-sample analysis is run
            once training finishes and stored as a CSV.
        evaluate_only: Model file to evaluate directly, skipping training.
        upload_to_gdrive: When True, checkpoints are uploaded to Google Drive.
        gdrive_folder_id: Identifier of the destination Google Drive folder.
        gdrive_creds_path: Location of the PyDrive2 credentials file (creds.dat).
    """

    # ---- Optimization / loss hyperparameters ----
    batch_size: int = 16
    epochs: int = 10
    lr: float = 5e-5
    loss_strategy: str = "sum"
    margin: float = 0.5
    temperature: float = 0.1
    off_context_weight: float = 2.0
    k_candidates: int = 5

    # ---- Runtime environment ----
    seed: int = 42
    device: str = "auto"
    num_workers: int = 0
    log_verbose: bool = True

    # ---- Output location and checkpoint cadence ----
    results_dir: str = "results"
    save_every_n_epochs: int = 0
    save_every_n_batches: int = 0
    save_best_per_method: bool = False

    # ---- Run mode (resume / analysis / evaluation-only) ----
    resume_from_checkpoint: str | None = None
    run_final_analysis: bool = True
    evaluate_only: str | None = None

    # ---- Optional Google Drive upload ----
    upload_to_gdrive: bool = False
    gdrive_folder_id: str | None = None
    gdrive_creds_path: str | None = None


@dataclass
class TrainingState:
    """Mutable bookkeeping for a single training run.

    Every metric history, cache, and progress counter that the training and
    validation code updates lives on this one object, so it can be threaded
    through the loop and mutated in place. The six epoch-level dictionaries
    and ``best_acc`` have no defaults and must be supplied by the caller
    (``create_training_state`` builds them); every other field starts out
    empty or at its initial counter value.

    Attributes:
        epoch_avg_energy_train: Method name -> per-epoch average training energy.
        epoch_avg_energy_val: Method name -> per-epoch average validation energy.
        epoch_avg_loss_train: Method name -> per-epoch average training loss.
        epoch_avg_loss_val: Method name -> per-epoch average validation loss.
        epoch_avg_accuracy_train: Method name -> per-epoch average training
            accuracy.
        epoch_avg_accuracy_val: Method name -> per-epoch average validation
            accuracy.
        best_acc: Method name -> best validation accuracy observed so far.
        batch_energy_train: One energy-statistics dict per training batch.
        batch_energy_val: One energy-statistics dict per validation batch.
        batch_loss_train: One loss-statistics dict per training batch.
        batch_loss_val: One loss-statistics dict per validation batch.
        batch_accuracy_train: One accuracy-statistics dict per training batch.
        batch_accuracy_val: One accuracy-statistics dict per validation batch.
        encoding_cache: Text -> previously computed encoding, so repeated
            strings need not be encoded again.
        start_epoch: Epoch number at which training starts or resumes.
        global_step: Running count of processed batches.
        resume_batch_idx: Batch index inside an epoch to resume from after a
            mid-epoch checkpoint.
    """

    # Per-epoch histories (required; keyed by sampling-method name)
    epoch_avg_energy_train: dict[str, list[float]]
    epoch_avg_energy_val: dict[str, list[float]]
    epoch_avg_loss_train: dict[str, list[float]]
    epoch_avg_loss_val: dict[str, list[float]]
    epoch_avg_accuracy_train: dict[str, list[float]]
    epoch_avg_accuracy_val: dict[str, list[float]]

    # Best validation accuracy per method (required)
    best_acc: dict[str, float]

    # Per-batch histories; default_factory gives each instance its own list
    batch_energy_train: list[dict[str, float]] = field(default_factory=list)
    batch_energy_val: list[dict[str, float]] = field(default_factory=list)
    batch_loss_train: list[dict[str, float]] = field(default_factory=list)
    batch_loss_val: list[dict[str, float]] = field(default_factory=list)
    batch_accuracy_train: list[dict[str, float]] = field(default_factory=list)
    batch_accuracy_val: list[dict[str, float]] = field(default_factory=list)

    # Memoized encodings keyed by input text
    encoding_cache: dict[str, Any] = field(default_factory=dict)

    # Progress counters
    start_epoch: int = 1
    global_step: int = 0
    resume_batch_idx: int = 0


def create_training_state(sampling_methods: Iterable[str]) -> TrainingState:
    """Factory function to create a new `TrainingState` object.

    Initializes all the epoch-level tracking dictionaries with the correct
    keys based on the provided sampling methods:

    * energy histories are keyed by every sampling method;
    * loss/accuracy histories and best-accuracy records are keyed by
      ``"overall"`` followed by every method except ``"positive"``.

    Args:
        sampling_methods: The sampling method names (including 'positive')
            that will be used. Any iterable is accepted, including one-shot
            iterators such as generators; it is consumed exactly once.

    Returns:
        TrainingState: An initialized state object ready for a new training
        run. Every history list is a distinct, empty list, and every
        best-accuracy entry starts at 0.0.
    """
    # Materialize once: the method names feed several comprehensions below,
    # and a one-shot iterator would otherwise be exhausted after the first,
    # silently leaving the remaining dictionaries empty.
    methods = list(sampling_methods)

    def _empty_histories(keys: list[str]) -> dict[str, list[float]]:
        # A fresh list per key so no two histories alias each other.
        return {k: [] for k in keys}

    # Loss/accuracy are tracked for an "overall" entry plus every method
    # other than "positive".
    metrics_methods = ["overall"] + [m for m in methods if m != "positive"]

    return TrainingState(
        epoch_avg_energy_train=_empty_histories(methods),
        epoch_avg_energy_val=_empty_histories(methods),
        epoch_avg_loss_train=_empty_histories(metrics_methods),
        epoch_avg_loss_val=_empty_histories(metrics_methods),
        epoch_avg_accuracy_train=_empty_histories(metrics_methods),
        epoch_avg_accuracy_val=_empty_histories(metrics_methods),
        best_acc=dict.fromkeys(metrics_methods, 0.0),
    )
