"""
Utility Functions

Common utilities for training, evaluation, and data management.
"""

import os
import shutil
import torch
import yaml


# Directory configuration
# Root of the dataset tree; the trailing "/" is stripped so the value can be
# joined with os.path.join without producing a double separator.
DATA_DIR = "/data/hd/projects/Datasets/Neural/".rstrip("/")
# Absolute path of the directory that contains this module.
BASE_DIR = os.path.split(os.path.abspath(__file__))[0]


class AverageMeter(object):
    """
    Computes and stores the average and current value.

    Useful for tracking metrics during training epochs. ``val`` is the most
    recent value passed to ``update``, ``sum`` is the sample-weighted total,
    ``count`` is the total weight, and ``avg`` is ``sum / count``.

    Example:
        losses = AverageMeter()
        for batch in dataloader:
            loss = compute_loss(...)
            losses.update(loss.item(), batch_size)
        print(f"Average loss: {losses.avg}")
    """
    def __init__(self):
        self.reset()

    def reset(self):
        """Zero all running statistics."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """
        Record a new value.

        Args:
            val: Latest value (e.g. a per-batch mean metric).
            n: Weight of ``val``, typically the number of samples it averages.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # With n == 0 on a fresh meter (e.g. an empty batch) count stays 0;
        # keep the previous avg instead of raising ZeroDivisionError.
        if self.count:
            self.avg = self.sum / self.count


def accuracy(output, target, topk=(1,)):
    """
    Computes the accuracy over the k top predictions.

    Args:
        output: Model scores, shape (batch_size, num_classes).
        target: Ground-truth class indices, shape (batch_size,).
        topk: Tuple of k values for top-k accuracy computation. A k larger
            than num_classes is allowed: every class is then in the top k,
            so (for valid targets) that entry is 100%.

    Returns:
        List with one 1-element float tensor per entry of ``topk``, holding
        the percentage (0-100) of samples whose target is among the top-k
        predictions.
    """
    with torch.no_grad():
        # torch.topk raises if k exceeds the class dimension; clamp so that
        # e.g. top-5 on a 3-class problem is still well defined.
        maxk = min(max(topk), output.size(1))
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        # After transpose, row i holds every sample's i-th best class.
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            # correct[:k] is a non-contiguous slice; reshape copies only if needed.
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res


def save_config_file(model_checkpoints_folder, args):
    """
    Save configuration arguments to ``config.yml`` inside a folder.

    The folder is created if missing; an existing ``config.yml`` in it is
    overwritten.

    Args:
        model_checkpoints_folder: Directory to save config file
        args: Configuration arguments (argparse namespace or dict)
    """
    os.makedirs(model_checkpoints_folder, exist_ok=True)
    # Write unconditionally: re-running into an existing folder must still
    # record the configuration used for this run.
    # NOTE(review): yaml.dump on an argparse.Namespace emits a python/object
    # tag that yaml.safe_load cannot read back; consider passing vars(args).
    with open(os.path.join(model_checkpoints_folder, 'config.yml'), 'w') as outfile:
        yaml.dump(args, outfile, default_flow_style=False)


def save_model(model, optimizer, opt, epoch, save_file):
    """
    Write a training checkpoint to ``save_file`` via ``torch.save``.

    The saved dict has the keys ``opt``, ``model`` (model state_dict),
    ``optimizer`` (optimizer state_dict) and ``epoch``, in that order.

    Args:
        model: PyTorch model
        optimizer: Optimizer whose state_dict is stored
        opt: Training options
        epoch: Current epoch number
        save_file: Path to save checkpoint
    """
    print('==> Saving...')
    torch.save(
        dict(
            opt=opt,
            model=model.state_dict(),
            optimizer=optimizer.state_dict(),
            epoch=epoch,
        ),
        save_file,
    )


def save_checkpoint(state, is_best, filename='checkpoint.pth.tar', best_filename=None):
    """
    Save checkpoint and optionally copy it as the best model.

    Args:
        state: Checkpoint state dictionary (anything ``torch.save`` accepts)
        is_best: Whether this is the best model so far
        filename: Path to save checkpoint
        best_filename: Destination of the best-model copy. Defaults to
            ``model_best.pth.tar`` in the same directory as ``filename``
            (the current directory when ``filename`` has no directory part).
    """
    torch.save(state, filename)
    if is_best:
        if best_filename is None:
            # Keep the best copy next to the checkpoint rather than in the
            # CWD, so runs saving to different folders don't clobber each other.
            best_filename = os.path.join(os.path.dirname(filename), 'model_best.pth.tar')
        shutil.copyfile(filename, best_filename)
