
import numpy as np
import torch

from tabicl.core.enums import Task
from tabicl.results.prediction_metrics import PredictionMetrics


class PredictionMetricsTracker:
    """
    Prediction metrics tracker that accumulates predictions and true values to compute metrics at the end.

    ``update`` accepts torch.Tensor inputs. Each one is detached, moved to CPU and stored as a numpy
    array, so the tracker holds no autograd graph or device memory between calls.
    """

    def __init__(self, task: Task):
        # Task forwarded to PredictionMetrics.from_prediction in get_metrics.
        self.task = task
        self.reset()


    def reset(self) -> None:
        """Discard all accumulated predictions and true values."""

        self.ys_pred: list[np.ndarray] = []
        self.ys_true: list[np.ndarray] = []


    @staticmethod
    def _to_numpy(tensor: torch.Tensor) -> np.ndarray:
        """Detach a tensor, move it to CPU and return it as a numpy array."""

        tensor = tensor.detach().cpu()
        # numpy has no bfloat16 dtype, so Tensor.numpy() raises on it; upcast to float32 first.
        if tensor.dtype == torch.bfloat16:
            tensor = tensor.float()
        return tensor.numpy()


    def update(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> None:
        """
        Store one step's predictions and true values.

        Only element ``[0]`` along the leading axis of each tensor is kept.
        NOTE(review): presumably the leading axis is a batch of size 1; any further entries
        are silently dropped. Confirm against callers.
        """

        # Index before converting so that dropped entries are never copied to the CPU.
        self.ys_pred.append(self._to_numpy(y_pred[0]))
        self.ys_true.append(self._to_numpy(y_true[0]))


    def get_metrics(self) -> PredictionMetrics:
        """
        Concatenate all stored entries along axis 0 and compute metrics for ``self.task``.

        Raises:
            ValueError: if ``update`` has not been called since construction or the last ``reset``.
        """

        if not self.ys_pred:
            # np.concatenate would raise an opaque "need at least one array" ValueError here.
            raise ValueError(
                "PredictionMetricsTracker.get_metrics called before any update; no predictions to evaluate."
            )

        y_pred = np.concatenate(self.ys_pred, axis=0)
        y_true = np.concatenate(self.ys_true, axis=0)

        return PredictionMetrics.from_prediction(y_pred, y_true, self.task)

