"""
Common utilities for training.
Contains only the functions that are actually used in the codebase.
"""
import random
from typing import List

import numpy as np
import torch


def init_single_gpu(seed=42):
    """Seed the global RNGs and pick the device for a single-GPU run.

    Seeds PyTorch, Python's ``random`` module and NumPy's legacy global RNG
    with the same value, and builds a separate ``torch.Generator`` seeded
    identically.

    Args:
        seed: Seed value applied to every RNG.

    Returns:
        A ``(device, generator)`` tuple: ``device`` is ``cuda`` when CUDA is
        available and ``cpu`` otherwise; ``generator`` is the seeded
        ``torch.Generator``.
    """
    # Apply the same seed to every global RNG in one pass.
    for seed_fn in (torch.manual_seed, random.seed, np.random.seed):
        seed_fn(seed)

    # Generator.manual_seed returns the generator itself, so it can be chained.
    generator = torch.Generator().manual_seed(seed)

    # Prefer CUDA whenever it is usable; otherwise run on the CPU.
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    return device, generator


def compute_ece(probs, labels_binary, n_bins=10):
    """Compute the Expected Calibration Error (ECE) over equal-width bins.

    The interval [0, 1] is split into ``n_bins`` equal-width bins. For every
    non-empty bin, the absolute gap between the mean of ``labels_binary``
    (bin accuracy) and the mean of ``probs`` (bin confidence) is weighted by
    the number of samples in the bin. The weighted gaps are summed and divided
    by ``len(labels_binary)``.

    Args:
        probs: Array-like of predicted probabilities. Values outside [0, 1]
            fall into no bin and contribute nothing to the sum.
        labels_binary: Array-like aligned with ``probs``; its per-bin mean is
            used as the bin accuracy (presumably 0/1 correctness labels).
        n_bins: Number of equal-width bins.

    Returns:
        The ECE value, or ``0.0`` when there are no samples.
    """
    # Accept plain sequences as well as arrays (lists do not support the
    # element-wise comparisons and boolean-mask indexing used below).
    probs = np.asarray(probs)
    labels_binary = np.asarray(labels_binary)

    n_samples = len(labels_binary)
    if n_samples == 0:
        # Nothing to evaluate; avoid dividing by zero.
        return 0.0

    bin_boundaries = np.linspace(0, 1, n_bins + 1)
    bin_edges = zip(bin_boundaries[:-1], bin_boundaries[1:])

    ece = 0.0
    for bin_idx, (bin_lower, bin_upper) in enumerate(bin_edges):
        # Bins are (lower, upper], except the first one, which is closed on
        # the left so that predictions of exactly 0.0 are not silently
        # dropped from the numerator while still counted in the denominator.
        if bin_idx == 0:
            above_lower = probs >= bin_lower
        else:
            above_lower = probs > bin_lower
        in_bin = above_lower & (probs <= bin_upper)
        bin_size = np.sum(in_bin)

        if bin_size > 0:
            bin_accuracy = np.mean(labels_binary[in_bin])
            bin_confidence = np.mean(probs[in_bin])
            ece += bin_size * np.abs(bin_accuracy - bin_confidence)

    return ece / n_samples


def compute_gradient_accumulation_steps(effective_batch_size, batch_size, accelerator=None):
    """Calculate gradient accumulation steps, accounting for the process count.

    The result is ``effective_batch_size // batch_size``, further floor-divided
    by ``accelerator.num_processes`` when an accelerator is given, and clamped
    to a minimum of 1.

    Args:
        effective_batch_size: Target number of samples per optimizer step.
        batch_size: Number of samples per forward/backward pass on one process.
        accelerator: Optional object exposing ``num_processes``; when ``None``
            a single process is assumed.

    Returns:
        The number of accumulation steps, never less than 1.

    Raises:
        ZeroDivisionError: If ``batch_size`` (or ``num_processes``) is 0.
    """
    num_processes = 1 if accelerator is None else accelerator.num_processes
    steps = effective_batch_size // batch_size // num_processes
    # Floor division yields 0 when the effective batch is smaller than one
    # global batch; at least one step is always required per optimizer update.
    return max(1, steps)