import torch
import sys


def run_one_epoch(
    epoch, model, loader, criterion, accuracy_funcs, device, optimizer=None, train=False
):
    """Run one pass over ``loader``, either training or evaluating ``model``.

    Args:
        epoch: Epoch index. Not used inside this function; kept so existing
            callers keep working.
        model: Module-like object (e.g. ``torch.nn.Module``). It is switched
            to ``train()`` or ``eval()`` mode according to ``train`` and called
            on each batch.
        loader: Iterable yielding ``(batch_data, batch_labels)`` pairs.
        criterion: Callable ``criterion(preds, labels)`` returning a loss
            tensor that supports ``.item()`` and ``.backward()``.
        accuracy_funcs: Re-iterable collection of callables
            ``f(preds, labels)``, each returning a single-element tensor.
        device: Target passed to ``.to(device)`` for every batch.
        optimizer: Optimizer zeroed and stepped once per batch. It is required
            when ``train`` is True.
        train: If True, back-propagate and update parameters. If False, run
            with gradient tracking disabled.

    Returns:
        dict with ``"loss"`` (mean of the per-batch losses) and
        ``"accuracy"`` (list of per-batch means, one per accuracy function,
        in the order of ``accuracy_funcs``).

    Raises:
        ValueError: If ``train`` is True and ``optimizer`` is None, or if
            ``loader`` yields no batches.
    """
    if train and optimizer is None:
        raise ValueError("An optimizer is required when train=True")

    if train:
        model.train()
    else:
        model.eval()

    total_loss = 0.0
    accs = [0.0 for _ in accuracy_funcs]
    n_batches = 0

    # Evaluation does not need an autograd graph; disabling it saves the
    # memory and compute of tracking every intermediate activation.
    with torch.set_grad_enabled(train):
        for n_batches, (batch_data, batch_labels) in enumerate(loader, start=1):
            batch_data = batch_data.to(device)
            batch_labels = batch_labels.to(device)
            if train:
                optimizer.zero_grad()

            # Call the module itself rather than .forward() so that any
            # registered forward hooks are executed.
            preds = model(batch_data)

            loss = criterion(preds, batch_labels)
            if train:
                loss.backward()
                optimizer.step()

            batch_loss = loss.item()
            total_loss += batch_loss
            for m, accuracy in enumerate(accuracy_funcs):
                accs[m] += accuracy(preds, batch_labels).item()

            # Periodic progress report during training only.
            if train and n_batches % 50 == 0:
                train_accs = "".join(
                    f"Train Accuracy {k}: {acc / n_batches:.3f}, "
                    for k, acc in enumerate(accs)
                )
                print(f"Batch {n_batches}: Train Loss: {batch_loss:.3f}, {train_accs}")

    if n_batches == 0:
        # Averaging over zero batches is undefined; fail with a clear message.
        raise ValueError("loader yielded no batches")

    return {
        "loss": total_loss / n_batches,
        "accuracy": [acc / n_batches for acc in accs],
    }
