from typing import List

from algorithms.utils.types import LossValues


def print_mean_epoch_losses(epoch: int, all_losses: List[LossValues]) -> None:
    """Print the mean of each loss component collected over one epoch.

    Args:
        epoch: Epoch number shown at the start of the printed line.
        all_losses: One entry per recorded step; each entry is unpacked into
            exactly six components in the order
            (total, tau, chance, choice, value, l2).

    Every mean is printed with three significant digits. If ``all_losses`` is
    empty, a short note is printed instead of dividing by zero.

    Raises:
        ValueError: if an entry does not unpack into exactly six components.
    """
    if not all_losses:
        # Nothing to average; avoid the ZeroDivisionError from `/ n` below.
        print("Epoch {0}. No losses recorded.".format(epoch))
        return

    total_loss = tau_loss = chance_loss = choice_loss = value_loss = l2_loss = 0
    # Unpacking in the loop target rejects entries with the wrong arity.
    for total, tau, chance, choice, value, l2 in all_losses:
        total_loss += total
        tau_loss += tau
        chance_loss += chance
        choice_loss += choice
        value_loss += value
        l2_loss += l2

    n = len(all_losses)
    print(("Epoch {0}. Total: {1:.3g}, Tau: {2:.3g}, Chance: {3:.3g}, Choice: {4:.3g}, "
           "Value: {5:.3g}, L2: {6:.3g}").format(
        epoch, total_loss / n, tau_loss / n, chance_loss / n, choice_loss / n, value_loss / n, l2_loss / n))
