"""Training utilities for EBM training.

This module provides a collection of helper functions used during the training,
validation, and analysis of the energy-based model. It includes utilities for
data handling (caching, batching), energy computation, metric calculation,
and results visualization.
"""

from __future__ import annotations

import logging
from argparse import Namespace
from pathlib import Path
from typing import Dict, List, Set, Tuple, Union

import torch
from torch.optim import Optimizer
from torch.utils.data import DataLoader
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm
import math
import random
import numpy as np

from .config import DataConfig, TrainingConfig
from .datasets import PromptResponseDataset, collate
from .negative_sampling import sample_negative_responses, sample_negative_prompts
from .losses import individual_contrastive_losses, individual_infonce_losses, infonce_loss, compute_accuracies


def set_seed(seed: int = 42) -> None:
    """Seed every random number generator used by this module.

    Seeds Python's `random`, NumPy's global generator and PyTorch's CPU
    generator; when CUDA is available, all CUDA devices are seeded as well.

    Args:
        seed (int): The seed value shared by all generators.
    """
    # Same call order as before: random -> numpy -> torch.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def get_cached_batch(
    texts: List[str], 
    cache: Dict[str, Tuple[torch.Tensor, torch.Tensor]]
) -> tuple[torch.Tensor, torch.Tensor]:
    """Builds a batch by looking up pre-encoded texts in a cache.

    Args:
        texts (List[str]): The texts to look up, in batch order. Every text
            must already be a key of `cache`; a missing one raises `KeyError`.
        cache (Dict[str, Tuple[torch.Tensor, torch.Tensor]]): Mapping from a
            text to its `(embedding, mask)` pair.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: The embeddings stacked along a new
            leading dimension, and the masks stacked the same way.
    """
    # Look each text up once, then split the pairs into two stacks.
    entries = [cache[text] for text in texts]
    stacked_embeddings = torch.stack([entry[0] for entry in entries])
    stacked_masks = torch.stack([entry[1] for entry in entries])
    return stacked_embeddings, stacked_masks


def compute_invalid_masks(
    original_texts: List[str],
    transformed_texts: List[str],
    device: torch.device,
) -> torch.Tensor:
    """Flags samples whose transformed text equals the original text.

    Intended for procedural negatives (e.g. sentence masking) where the
    transformation can be a no-op, for instance on a single-sentence text.

    Args:
        original_texts (List[str]): The untransformed texts.
        transformed_texts (List[str]): The texts after transformation, paired
            positionally with `original_texts`.
        device (torch.device): Device on which the mask tensor is created.

    Returns:
        torch.Tensor: A boolean tensor with one entry per pair; `True` means
        the transformation left the text unchanged.
    """
    # zip() pairs the lists positionally and stops at the shorter one.
    unchanged_flags = [
        before == after
        for before, after in zip(original_texts, transformed_texts)
    ]
    return torch.tensor(unchanged_flags, dtype=torch.bool, device=device)


def apply_invalid_mask(
    energy_tensor: torch.Tensor,
    mask: torch.Tensor,
    invalid_value: float = float("inf"),
) -> torch.Tensor:
    """Returns a copy of the energies with invalid entries overwritten.

    Entries selected by `mask` are replaced by `invalid_value` (infinity by
    default) so downstream code can recognise and skip them. The input tensor
    itself is left untouched.

    Args:
        energy_tensor (torch.Tensor): The energy values to be masked.
        mask (torch.Tensor): Boolean mask; `True` marks an invalid sample.
        invalid_value (float): Replacement value for invalid entries.

    Returns:
        torch.Tensor: A new tensor with the masked entries replaced.
    """
    # Work on a clone so the caller's tensor is never modified in place.
    masked = energy_tensor.clone()
    masked[mask] = invalid_value
    return masked


def compute_batch_energy_stats(
    pos_energy: torch.Tensor,
    neg_energies_dict: Dict[str, torch.Tensor],
) -> Dict[str, float]:
    """Summarises a batch as one mean energy per sample group.

    The positive energies are averaged as-is. For every negative sampling
    method only the finite energies are averaged, so entries set to `inf`
    do not distort the mean; a method with no finite energy at all is
    reported as NaN.

    Args:
        pos_energy (torch.Tensor): Energies of the positive samples.
        neg_energies_dict (Dict[str, torch.Tensor]): Energies of the negative
            samples, keyed by sampling method name.

    Returns:
        Dict[str, float]: Mean energy under the key 'positive' plus one entry
        per negative sampling method.
    """
    def _finite_mean(values: torch.Tensor) -> float:
        # Drop inf/NaN entries before averaging.
        kept = values[torch.isfinite(values)]
        return kept.mean().item() if kept.numel() > 0 else float("nan")

    stats = {"positive": pos_energy.mean().item()}
    stats.update(
        (method, _finite_mean(energies))
        for method, energies in neg_energies_dict.items()
    )
    return stats


def calculate_averages_from_history(
    batch_history: List[Dict[str, float]]
) -> Dict[str, float]:
    """Averages every metric over a list of per-batch metric dictionaries.

    A metric does not need to appear in every batch. Values that are `None`
    or NaN are left out of the average; a metric with no usable value at all
    is reported as 0.0.

    Args:
        batch_history (List[Dict[str, float]]): One metric dictionary per
            batch.

    Returns:
        Dict[str, float]: Mean of the usable values for each metric name seen
            in the history, or an empty dict when the history is empty.
    """
    if not batch_history:
        return {}

    # Collect the usable values of each metric in a single pass. The bucket
    # is created even when the value is unusable so the metric still shows
    # up in the result (with 0.0).
    usable_by_metric = {}
    for record in batch_history:
        for metric, value in record.items():
            bucket = usable_by_metric.setdefault(metric, [])
            if value is not None and not math.isnan(value):
                bucket.append(value)

    return {
        metric: (sum(bucket) / len(bucket) if bucket else 0.0)
        for metric, bucket in usable_by_metric.items()
    }


def prepare_text_sets(
    prompts: List[str],
    responses: List[str],
    response_negs: Dict[str, List[str]],
    prompt_negs: Dict[str, List[str]],
) -> Tuple[Set[str], Dict[str, List[str]]]:
    """Splits a batch's texts into cacheable and non-cacheable groups.

    Prompts, responses and the "human", "gpt2", "off_context" and
    "off_context_prompt" negatives are treated as stable and collected into
    one set. The sentence/token masking negatives (for responses and prompts)
    are treated as procedural and returned per method.

    Args:
        prompts (List[str]): A batch of input prompts.
        responses (List[str]): A batch of ground-truth responses.
        response_negs (Dict[str, List[str]]): Negative response texts keyed by
            sampling method.
        prompt_negs (Dict[str, List[str]]): Negative prompt texts keyed by
            sampling method.

    Returns:
        Tuple[Set[str], Dict[str, List[str]]]: A tuple containing:
            - The set of unique stable texts.
            - A dict mapping each procedural method present in the inputs to
              its list of texts (the same list objects that were passed in).
    """
    # --- Stable texts -------------------------------------------------------
    stable_texts = {*prompts, *responses}
    for neg_type in ("human", "gpt2", "off_context"):
        stable_texts.update(response_negs.get(neg_type, ()))
    stable_texts.update(prompt_negs.get("off_context_prompt", ()))

    # --- Procedural texts ---------------------------------------------------
    # A method is looked up among the response negatives first and only then
    # among the prompt negatives.
    procedural_texts = {}
    for neg_type in (
        "sentence_masking",
        "token_masking",
        "sentence_masking_prompt",
        "token_masking_prompt",
    ):
        source = response_negs if neg_type in response_negs else prompt_negs
        if neg_type in source:
            procedural_texts[neg_type] = source[neg_type]

    return stable_texts, procedural_texts


def update_encoding_cache(
    texts: List[str],
    cache: Dict[str, Tuple[torch.Tensor, torch.Tensor]],
    model: torch.nn.Module,
) -> None:
    """Encodes texts in one call and stores the results in the cache.

    Each text is mapped to its own `(embedding, mask)` pair, taken from the
    matching row of the encoder output and moved to the CPU before being
    stored. Nothing happens (the encoder is not called) for an empty list.

    Args:
        texts (List[str]): The texts to encode and cache.
        cache (Dict[str, Tuple[torch.Tensor, torch.Tensor]]): The cache,
            updated in place.
        model (torch.nn.Module): A model exposing a `text_encoder` callable
            that returns an `(embeddings, masks)` pair for a list of texts.
    """
    if texts:
        batch_embeddings, batch_masks = model.text_encoder(texts)
        for row, text in enumerate(texts):
            # Row `row` of the encoder output belongs to `texts[row]`.
            embedding_cpu = batch_embeddings[row].cpu()
            mask_cpu = batch_masks[row].cpu()
            cache[text] = (embedding_cpu, mask_cpu)


def encode_procedural_texts(
    procedural_texts: Dict[str, List[str]],
    model: torch.nn.Module,
) -> dict[str, tuple[torch.Tensor, torch.Tensor]]:
    """Encodes each group of procedural texts with the model's text encoder.

    The encoder is called once per method, and nothing is cached.

    Args:
        procedural_texts (Dict[str, List[str]]): Generated texts keyed by the
            procedural method that produced them.
        model (torch.nn.Module): A model exposing a `text_encoder` callable
            that returns an `(embeddings, masks)` pair for a list of texts.

    Returns:
        Dict[str, Tuple[torch.Tensor, torch.Tensor]]: The `(embeddings, masks)`
        pair for each method name.
    """
    def _encode(group: List[str]) -> tuple[torch.Tensor, torch.Tensor]:
        # Unpack explicitly so the stored value is always a 2-tuple.
        embeddings, masks = model.text_encoder(group)
        return embeddings, masks

    return {
        method: _encode(group) for method, group in procedural_texts.items()
    }


def prepare_encoded_negatives(
    response_negs: Dict[str, List[str]],
    prompt_negs: Dict[str, List[str]],
    encoding_cache: Dict[str, Tuple[torch.Tensor, torch.Tensor]],
    encoded_procedural: Dict[str, Tuple[torch.Tensor, torch.Tensor]],
) -> Tuple[Dict[str, Tuple[torch.Tensor, torch.Tensor]], 
           Dict[str, Tuple[torch.Tensor, torch.Tensor]]]:
    """Collects the encoded negatives for responses and for prompts.

    Stable negatives ("human", "gpt2", "off_context" for responses and
    "off_context_prompt" for prompts) are read from the encoding cache, so
    their texts must already be cached. Procedural negatives (sentence/token
    masking) are taken as-is from `encoded_procedural`.

    Args:
        response_negs (Dict[str, List[str]]): Raw negative response texts.
        prompt_negs (Dict[str, List[str]]): Raw negative prompt texts.
        encoding_cache (Dict): Cache mapping stable texts to their encodings.
        encoded_procedural (Dict): Freshly encoded procedural negatives, keyed
            by method name.

    Returns:
        Tuple containing two dictionaries:
            - encoded_response_negs: Maps method names to encoded response tensors.
            - encoded_prompt_negs: Maps method names to encoded prompt tensors.
    """
    # ---- Response side -----------------------------------------------------
    encoded_response_negs = {
        neg_type: get_cached_batch(response_negs[neg_type], encoding_cache)
        for neg_type in ("human", "gpt2", "off_context")
        if neg_type in response_negs
    }
    encoded_response_negs.update(
        (neg_type, encoded_procedural[neg_type])
        for neg_type in ("sentence_masking", "token_masking")
        if neg_type in encoded_procedural
    )

    # ---- Prompt side -------------------------------------------------------
    encoded_prompt_negs = {}
    off_context_key = "off_context_prompt"
    if off_context_key in prompt_negs:
        encoded_prompt_negs[off_context_key] = get_cached_batch(
            prompt_negs[off_context_key], encoding_cache
        )
    encoded_prompt_negs.update(
        (neg_type, encoded_procedural[neg_type])
        for neg_type in ("sentence_masking_prompt", "token_masking_prompt")
        if neg_type in encoded_procedural
    )

    return encoded_response_negs, encoded_prompt_negs


def format_pbar_postfix(
    batch_loss_stats: Dict[str, float],
    batch_acc_stats: Dict[str, float],
    batch_energy_stats: Dict[str, float],
    loss_history: List[Dict[str, float]],
) -> Dict[str, str]:
    """Builds the dictionary of formatted strings shown next to the tqdm bar.

    The first four entries are the overall batch loss, the running average of
    the overall loss, the overall accuracy and the positive energy. They are
    followed, per sampling method in alphabetical order, by that method's
    batch loss, running-average loss, accuracy and mean energy (whichever are
    available).

    Args:
        batch_loss_stats (Dict[str, float]): Loss values for the current batch.
        batch_acc_stats (Dict[str, float]): Accuracy values for the current batch.
        batch_energy_stats (Dict[str, float]): Mean energy values for the current batch.
        loss_history (List[Dict[str, float]]): Loss dictionaries of the batches
            to average over for the running average.

    Returns:
        Dict[str, str]: A dictionary of formatted strings for tqdm.set_postfix.
    """
    def _running_mean(name: str) -> float:
        # Average over the history, skipping missing (None) and NaN values.
        usable = []
        for record in loss_history:
            value = record.get(name)
            if value is not None and not math.isnan(value):
                usable.append(value)
        return sum(usable) / len(usable) if usable else 0.0

    avg_losses_so_far = {name: _running_mean(name) for name in batch_loss_stats}

    # Headline metrics; a missing value is displayed as zero.
    postfix = {}
    postfix["loss"] = f"{batch_loss_stats.get('overall', 0.0):.4f}"
    postfix["avg_loss"] = f"{avg_losses_so_far.get('overall', 0.0):.4f}"
    postfix["acc"] = f"{batch_acc_stats.get('overall', 0.0):.2f}%"
    postfix["pos_E"] = f"{batch_energy_stats.get('positive', 0.0):.3f}"

    # (key suffix, value table, format spec, unit) for per-method entries,
    # listed in the order they are emitted.
    columns = (
        ("loss", batch_loss_stats, ".4f", ""),
        ("avg_loss", avg_losses_so_far, ".4f", ""),
        ("acc", batch_acc_stats, ".2f", "%"),
        ("E", batch_energy_stats, ".3f", ""),
    )
    headline_keys = ("overall", "positive")
    every_name = {*batch_loss_stats, *batch_acc_stats, *batch_energy_stats}
    for name in sorted(n for n in every_name if n not in headline_keys):
        for suffix, table, spec, unit in columns:
            if name in table:
                postfix[f"{name}_{suffix}"] = f"{table[name]:{spec}}{unit}"

    return postfix


def plot_timeseries(
    series: Dict[str, List[float]],
    results_dir: str,
    plot_type: str,       # e.g. "train", "val", "eval train", "eval val"
    metric_name: str,     # e.g. "Energy", "Loss", "Accuracy"
    x_label: str          # e.g. "Epoch #" or "Batch #"
) -> None:
    """Generates and saves a time-series plot of training/validation metrics.

    One line is drawn per entry of `series`, with 'positive' and 'overall'
    first and the remaining labels in alphabetical order. The x-axis is simply
    1..N, where N is the length of the first series; series of any other
    length are skipped. Nothing is written when the first series is empty.

    The image is saved as `<metric>_<plot_type>_<x>.png` inside `results_dir`
    (all lower-case, spaces replaced by underscores, `<x>` being the first
    word of `x_label`). Plotting is best-effort: any failure is logged and
    swallowed, and the figure is always closed so repeated calls do not
    accumulate open figures.

    Args:
        series (Dict[str, List[float]]): A dictionary where keys are series
            labels (e.g., method names) and values are lists of metric points.
        results_dir (str): The directory path to save the plot image. It must
            already exist.
        plot_type (str): The type of plot (e.g., "Train", "Val"). Used for titles.
        metric_name (str): The name of the metric being plotted (e.g., "Loss").
        x_label (str): The label for the x-axis (e.g., "Epoch #" or "Batch #").
    """
    # Fallbacks so the error handler below can always format its message, even
    # if the failure happens before the display names are derived.
    display_plot_type = plot_type
    display_metric = metric_name
    fig = None

    try:
        # For display in titles and labels (e.g., "eval_val" -> "Eval Val")
        display_plot_type = plot_type.replace('_', ' ').title()
        display_metric = metric_name.replace('_', ' ').title()

        # For use in filenames (e.g., "Eval Val" -> "eval_val")
        filename_plot_type = plot_type.lower().replace(' ', '_')
        filename_metric = metric_name.lower().replace(' ', '_')
        x_unit = x_label.split()[0]
        filename_xlabel = x_unit.lower()

        # The first series defines the number of points on the x-axis.
        num_points = len(next(iter(series.values()), []))
        if num_points == 0:
            return

        xs = range(1, num_points + 1)
        fig, ax = plt.subplots()

        # 'positive'/'overall' sort first (False < True), then alphabetical.
        sorted_series = sorted(
            series.items(),
            key=lambda item: (item[0] not in ['positive', 'overall'], item[0]),
        )
        for method, values in sorted_series:
            # Series that do not span the full x-range cannot share the axis.
            if len(values) == num_points:
                ax.plot(xs, values, label=method)

        ax.set_xlabel(x_label)
        ax.set_ylabel(f"Avg {display_metric}")
        ax.set_title(f"{display_plot_type} Average {display_metric} per {x_unit}")
        ax.legend()
        fig.tight_layout()

        out = Path(results_dir) / f"{filename_metric}_{filename_plot_type}_{filename_xlabel}.png"
        fig.savefig(out)
        logging.info(f"Saved {filename_metric}‐per‐{filename_xlabel} plot to: {out}")

    except Exception as e:
        logging.error(f"Could not plot {display_metric} for {display_plot_type}: {e}")

    finally:
        # Close the figure on every path (including a failed savefig);
        # pyplot keeps figures alive until they are explicitly closed.
        if fig is not None:
            plt.close(fig)


def analyze_per_sample(
    model: torch.nn.Module, 
    dataset: PromptResponseDataset,
    device: torch.device, 
    train_cfg: TrainingConfig,
    data_cfg: DataConfig,
    encoding_cache: Dict[str, Tuple[torch.Tensor, torch.Tensor]],
    prefix: str,
) -> None:
    """Analyzes per-sample metrics for a dataset and saves results to a CSV file.

    The dataset is walked one sample at a time (batch size 1, no shuffling)
    under `torch.no_grad()`. For every sample all supported negative
    responses and negative prompts are generated, their energies are computed
    with `model.forward_from_encoded`, and losses/accuracies are derived
    according to `train_cfg.loss_strategy`. One CSV row is recorded per sample
    and the table is written to `<results_dir>/<prefix>_sample_stats.csv`.

    Sentence-masking negatives that are identical to the original text get an
    energy of `inf` so they are recognisable as invalid.

    If the dataset yields no samples, a warning is logged and no file is
    written.

    Args:
        model (torch.nn.Module): The trained energy model to use for analysis.
            Must expose `text_encoder` and `forward_from_encoded`.
        dataset (PromptResponseDataset): The dataset to analyze.
        device (torch.device): The device on which the per-method negative
            energy tensors are created.
        train_cfg (TrainingConfig): Configuration object; `loss_strategy`,
            `margin`, `temperature` and `results_dir` are read.
        data_cfg (DataConfig): Data configuration, forwarded to the "human"
            and "gpt2" negative samplers.
        encoding_cache (Dict): The dictionary of pre-computed text encodings.
            Missing stable texts are encoded and added to it.
        prefix (str): A prefix for the output filename (e.g., "train", "eval_val").
    """
    loader = DataLoader(dataset, batch_size=1, shuffle=False, collate_fn=collate)
    rows = []
    # Every method name that contributed columns to at least one row; used to
    # build the final column order instead of relying on the last sample only.
    seen_methods = set()

    with torch.no_grad():
        for batch_data in tqdm(loader, desc=f"per-sample analysis ({prefix})"):
            idx, prompt_list, response_list = batch_data
            idx, prompt, response = idx[0], prompt_list[0], response_list[0]
            idx_list = [idx]

            # 1) Generate all negative text samples for this single item
            response_negs = {
                "sentence_masking": sample_negative_responses(
                    sampling_type="sentence_masking", responses=response_list
                ),
                "token_masking": sample_negative_responses(
                    sampling_type="token_masking", responses=response_list
                ),
                "off_context": sample_negative_responses(
                    sampling_type="off_context", responses=response_list, 
                    indices=idx_list, dataset=dataset
                ),
                "human": sample_negative_responses(
                    sampling_type="human", responses=response_list, 
                    indices=idx_list, dataset=dataset, data_cfg=data_cfg
                ),
                "gpt2": sample_negative_responses(
                    sampling_type="gpt2", responses=response_list, 
                    indices=idx_list, dataset=dataset, data_cfg=data_cfg
                ),
            }
            prompt_negs = {
                "sentence_masking_prompt": sample_negative_prompts(
                    sampling_type="sentence_masking_prompt", prompts=prompt_list
                ),
                "token_masking_prompt": sample_negative_prompts(
                    sampling_type="token_masking_prompt", prompts=prompt_list
                ),
                "off_context_prompt": sample_negative_prompts(
                    sampling_type="off_context_prompt", prompts=prompt_list, 
                    indices=idx_list, dataset=dataset
                ),
            }

            # 2) Centralized Encoding Step: stable texts go through the cache,
            #    procedural texts are encoded fresh every time.
            stable_texts, procedural_texts = prepare_text_sets(prompt_list, response_list, response_negs, prompt_negs)
            cache_misses = [t for t in stable_texts if t not in encoding_cache]
            update_encoding_cache(cache_misses, encoding_cache, model)
            encoded_procedural = encode_procedural_texts(procedural_texts, model)

            # 3) Retrieve all required encoded tensors
            pos_x_encoded, pos_x_mask = get_cached_batch(prompt_list, encoding_cache)
            pos_y_encoded, pos_y_mask = get_cached_batch(response_list, encoding_cache)
            encoded_response_negs, encoded_prompt_negs = prepare_encoded_negatives(
                response_negs, prompt_negs, encoding_cache, encoded_procedural
            )

            # 4) Compute energies using forward_from_encoded.
            #    Response negatives keep the true prompt; prompt negatives keep
            #    the true response. `.item()` relies on batch_size == 1.
            pos_e = model.forward_from_encoded((pos_x_encoded, pos_x_mask), (pos_y_encoded, pos_y_mask)).flatten()
            response_neg_es = {m: model.forward_from_encoded((pos_x_encoded, pos_x_mask), negs).item() 
                               for m, negs in encoded_response_negs.items()}
            prompt_neg_es = {m: model.forward_from_encoded(negs, (pos_y_encoded, pos_y_mask)).item() 
                             for m, negs in encoded_prompt_negs.items()}
                
            # Handle invalid samples: a sentence-masking negative identical to
            # the original text is not a real negative.
            if response_negs["sentence_masking"][0] == response:
                response_neg_es["sentence_masking"] = float('inf')
            if prompt_negs["sentence_masking_prompt"][0] == prompt:
                prompt_neg_es["sentence_masking_prompt"] = float('inf')
            
            all_neg_es_dict_scalar = {**response_neg_es, **prompt_neg_es}
            all_neg_e_tensors = {m: torch.tensor([e], device=device) for m, e in all_neg_es_dict_scalar.items()}
                        
            # 5) Calculate metrics
            if train_cfg.loss_strategy in ['sum', 'weighted_sum', 'sequential']:
                individual_losses = individual_contrastive_losses(pos_e, all_neg_e_tensors, train_cfg.margin)
                # NaN losses are excluded from the overall sum.
                valid_losses = [loss for loss in individual_losses.values() if not torch.isnan(loss)]
                overall_loss = sum(valid_losses) if valid_losses else torch.tensor(0.0)
            else: # infonce, infonce_expanded
                # For batch_size=1, infonce_expanded is the same as infonce
                individual_losses = individual_infonce_losses(pos_e, all_neg_e_tensors, temperature=train_cfg.temperature)
                overall_loss = infonce_loss(pos_e, list(all_neg_e_tensors.values()), temperature=train_cfg.temperature)

            # Calculate accuracies
            accuracies = compute_accuracies(pos_e, all_neg_e_tensors)

            # 6) Record for CSV
            row_data = {
                "pos_prompt": prompt,
                "pos_response": response,
                "pos_energy": pos_e.item(),
                "overall_accuracy": accuracies.get("overall", float('nan')),
                "overall_loss": overall_loss.item(),
            }
            # First (and only) negative text of each method; built once per
            # sample rather than once per method.
            raw_negs = {k: v[0] for k, v in {**response_negs, **prompt_negs}.items()}
            sample_methods = sorted(accuracies.keys() - {"overall"})
            seen_methods.update(sample_methods)
            for method in sample_methods:
                row_data[f"neg_{method}_text"] = raw_negs.get(method, "")
                row_data[f"neg_{method}_energy"] = all_neg_es_dict_scalar.get(method, float('nan'))
                row_data[f"{method}_accuracy"] = accuracies.get(method, float('nan'))
                row_data[f"{method}_loss"] = individual_losses.get(method, torch.tensor(float('nan'))).item()
            rows.append(row_data)

    # An empty dataset produces no rows (and no per-method columns); there is
    # nothing meaningful to save.
    if not rows:
        logging.warning(f"No samples to analyze for '{prefix}'; skipping per-sample CSV.")
        return

    # Save CSV with ordered columns
    df = pd.DataFrame(rows)

    prefix_cols = [
        "pos_prompt", "pos_response", "pos_energy", 
        "overall_accuracy", "overall_loss"
    ]
    final_col_order = prefix_cols.copy()
    for method in sorted(seen_methods):
        final_col_order.extend([
            f"neg_{method}_text",
            f"neg_{method}_energy",
            f"{method}_accuracy",
            f"{method}_loss",
        ])
    final_col_order_existing = [col for col in final_col_order if col in df.columns]

    df = df[final_col_order_existing]
    stats_path = Path(train_cfg.results_dir)/f"{prefix}_sample_stats.csv"
    df.to_csv(stats_path, index=False)
    logging.info(f"Saved per-sample analysis to: {stats_path}")


def load_checkpoint(
    checkpoint_path: Union[str, Path], 
    model: torch.nn.Module, 
    optimizer: Optimizer, 
    device: torch.device
) -> Tuple[int, int, dict, dict, dict, dict, dict, dict, 
           list, list, list, list, list, list, dict, dict]:
    """Restores model/optimizer state and returns the saved training history.

    The model and optimizer are updated in place from the checkpoint's
    'model_state_dict' and 'optimizer_state_dict' entries. 'epoch' and
    'global_step' are required; every history entry is optional and falls
    back to an empty dict or list when absent from the file.

    Note: `torch.load` is pickle-based; only load checkpoints from sources you
    trust.

    Args:
        checkpoint_path (Union[str, Path]): Path to the checkpoint file (.pt).
        model (torch.nn.Module): The model instance to load the state_dict into.
        optimizer (Optimizer): The optimizer instance to load the state_dict into.
        device (torch.device): The device to map the loaded tensors to.

    Returns:
        A tuple containing all saved training state:
        (epoch, global_step, epoch_avg_energy_train, epoch_avg_energy_val,
         epoch_avg_loss_train, epoch_avg_loss_val, epoch_avg_accuracy_train,
         epoch_avg_accuracy_val, batch_loss_train, batch_loss_val,
         batch_accuracy_train, batch_accuracy_val, batch_energy_train,
         batch_energy_val, best_acc, encoding_cache)
    """
    logging.info(f"Loading checkpoint from: {checkpoint_path}")
    state = torch.load(checkpoint_path, map_location=device)

    model.load_state_dict(state['model_state_dict'])
    optimizer.load_state_dict(state['optimizer_state_dict'])

    # Optional entries, listed in the order they appear in the returned tuple.
    epoch_summary_keys = (
        'epoch_avg_energy_train', 'epoch_avg_energy_val',
        'epoch_avg_loss_train', 'epoch_avg_loss_val',
        'epoch_avg_accuracy_train', 'epoch_avg_accuracy_val',
    )
    batch_history_keys = (
        'batch_loss_train', 'batch_loss_val',
        'batch_accuracy_train', 'batch_accuracy_val',
        'batch_energy_train', 'batch_energy_val',
    )
    epoch_summaries = [state.get(key, {}) for key in epoch_summary_keys]
    batch_histories = [state.get(key, []) for key in batch_history_keys]

    return (
        state['epoch'],
        state['global_step'],
        *epoch_summaries,
        *batch_histories,
        state.get('best_acc', {}),
        state.get('encoding_cache', {}),
    )


def save_checkpoint(
    epoch: int, global_step: int, model: torch.nn.Module, optimizer: Optimizer, 
    epoch_avg_energy_train: dict, epoch_avg_energy_val: dict,
    epoch_avg_loss_train: dict, epoch_avg_loss_val: dict,
    epoch_avg_accuracy_train: dict, epoch_avg_accuracy_val: dict,
    batch_loss_train: list, batch_loss_val: list,
    batch_accuracy_train: list, batch_accuracy_val: list,
    batch_energy_train: list, batch_energy_val: list,
    best_acc: dict, encoding_cache: dict, results_dir: str
) -> Path:
    """Writes the full training state to a checkpoint file.

    The file is named `checkpoint_epoch_<epoch>_step_<global_step>.pt` and is
    placed in `results_dir`. Model and optimizer are stored as their
    state_dicts; every other argument is stored as given, under a key equal
    to its parameter name.

    Args:
        All arguments are training state variables to be saved, except
        `results_dir`, which is the directory the file is written to.

    Returns:
        Path: The path to the saved checkpoint file.
    """
    destination = Path(results_dir) / f"checkpoint_epoch_{epoch}_step_{global_step}.pt"

    # Keyword names double as the keys read back when the checkpoint is loaded.
    payload = dict(
        epoch=epoch,
        global_step=global_step,
        model_state_dict=model.state_dict(),
        optimizer_state_dict=optimizer.state_dict(),
        epoch_avg_energy_train=epoch_avg_energy_train,
        epoch_avg_energy_val=epoch_avg_energy_val,
        epoch_avg_loss_train=epoch_avg_loss_train,
        epoch_avg_loss_val=epoch_avg_loss_val,
        epoch_avg_accuracy_train=epoch_avg_accuracy_train,
        epoch_avg_accuracy_val=epoch_avg_accuracy_val,
        batch_loss_train=batch_loss_train,
        batch_loss_val=batch_loss_val,
        batch_accuracy_train=batch_accuracy_train,
        batch_accuracy_val=batch_accuracy_val,
        batch_energy_train=batch_energy_train,
        batch_energy_val=batch_energy_val,
        best_acc=best_acc,
        encoding_cache=encoding_cache,
    )

    torch.save(payload, destination)
    logging.info(f"Saved checkpoint to: {destination}")
    return destination


def upload_file_to_drive(
    folder_id: str,
    local_path: Union[str, Path],
    creds_path: Union[str, Path],
    results_dir: Union[str, Path],
) -> None:
    """Uploads a local file to a specified Google Drive folder.

    This function handles PyDrive2 authentication, including token refreshing,
    and uploads a file. If a file with the same name already exists in the
    target folder, it will be updated. Upon successful upload, the local
    file is deleted to save space.

    The function is best-effort and never raises for upload problems: it logs
    and returns if the `pydrive2` library is not installed, if the folder ID
    or credentials path are not provided, if no credentials can be loaded, or
    if the upload itself fails. A failure to delete the local file after a
    successful upload is logged separately as a warning.

    Args:
        folder_id (str): The ID of the target Google Drive folder.
        local_path (Union[str, Path]): The path to the local file to be uploaded.
        creds_path (Union[str, Path]): The path to the PyDrive2 credentials file
            (e.g., 'creds.dat').
        results_dir (Union[str, Path]): A local directory used to store a
            writable copy of the credentials.
    """
    import os
    import shutil

    try:
        from pydrive2.auth import GoogleAuth
        from pydrive2.drive import GoogleDrive
    except ImportError:
        logging.warning("PyDrive2 not installed, skipping Google Drive upload.")
        return

    if not folder_id or not creds_path:
        logging.warning("Google Drive folder ID or credentials path not provided, skipping upload.")
        return

    file_name = os.path.basename(local_path)

    try:
        # Authentication logic inside the function for robustness. The
        # credentials are copied once so that token refreshes can be written
        # back without touching the original file.
        writable_creds_path = os.path.join(results_dir, "creds.dat")
        if not os.path.exists(writable_creds_path):
            shutil.copy(creds_path, writable_creds_path)

        gauth = GoogleAuth()
        gauth.LoadCredentialsFile(writable_creds_path)

        if gauth.credentials is None:
            logging.error("GDrive credentials not found.")
            return
        elif gauth.access_token_expired:
            gauth.Refresh()
            gauth.SaveCredentialsFile(writable_creds_path)
            logging.info("GDrive token refreshed successfully.")
        else:
            gauth.Authorize()

        drive = GoogleDrive(gauth)

        # Search for existing file by name. Backslashes and single quotes
        # must be escaped inside a quoted Drive query value, otherwise a name
        # such as "it's.pt" produces a malformed query.
        escaped_name = file_name.replace("\\", "\\\\").replace("'", "\\'")
        query = f"'{folder_id}' in parents and title = '{escaped_name}' and trashed=false"
        file_list = drive.ListFile({'q': query}).GetList()

        if file_list:
            # File exists, update it
            gfile = file_list[0]
            logging.info(f"Updating existing file on Google Drive: {gfile['title']}")
        else:
            # File does not exist, create new
            logging.info(f"Creating new file on Google Drive: {file_name}")
            gfile = drive.CreateFile({
                'title': file_name,
                'parents': [{'id': folder_id}]
            })

        try:
            gfile.SetContentFile(str(local_path))
            gfile.Upload()
        finally:
            # NOTE(review): SetContentFile appears to keep the opened file on
            # `gfile.content`; close it defensively so the handle is not
            # leaked and the local file can be removed on platforms that
            # refuse to delete open files — confirm against PyDrive2.
            content = getattr(gfile, "content", None)
            if content is not None and hasattr(content, "close"):
                content.close()

        logging.info(f"Successfully uploaded {file_name} to Google Drive.")

    except Exception as e:
        logging.error(f"Failed to upload {file_name} to Google Drive: {e}")
        return

    # Remove the local file after successful upload to save local space. This
    # is reported on its own so that a cleanup problem is not mistaken for a
    # failed upload.
    try:
        os.remove(local_path)
        logging.info(f"Removed local file: {local_path}")
    except OSError as e:
        logging.warning(f"Uploaded {file_name} but could not remove local file {local_path}: {e}")
