
import numpy as np
import torch

from tabicl.core.enums import Task
from tabicl.data.preprocessor import Preprocessor
from tabicl.results.prediction_metrics import PredictionMetrics


class PredictionMetricsTracker:
    """
    Prediction metrics tracker that accumulates predictions and true values to compute metrics at the end.

    ``update`` accepts ``torch.Tensor`` inputs. Each tensor is detached, moved to CPU and
    converted to a NumPy array, then passed through ``preprocessor.inverse_transform_y``
    before being stored. ``get_metrics`` concatenates everything stored since the last
    ``reset`` and builds a ``PredictionMetrics`` for the configured task.
    """

    def __init__(self, task: Task, preprocessor: Preprocessor) -> None:
        """
        Args:
            task: Task forwarded to ``PredictionMetrics.from_prediction``.
            preprocessor: Object whose ``inverse_transform_y`` is applied to every
                prediction and target before accumulation (presumably mapping them
                back to the original target scale — confirm).
        """
        self.task = task
        self.preprocessor = preprocessor
        self.reset()

    def reset(self) -> None:
        """Discard all accumulated predictions and true values."""
        self.ys_pred: list[np.ndarray] = []
        self.ys_true: list[np.ndarray] = []

    def _to_original_scale(self, y: torch.Tensor) -> np.ndarray:
        """Convert ``y`` to a CPU NumPy array, keep index 0 of its leading axis, and inverse-transform it."""
        # NOTE(review): only index 0 along the leading axis is kept, which looks like it
        # assumes a leading batch dimension of size 1 — confirm with callers.
        y_np = y.detach().cpu().numpy()[0]
        return self.preprocessor.inverse_transform_y(y_np)

    def update(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> None:
        """
        Record one set of predictions and true values.

        Args:
            y_pred: Model predictions.
            y_true: Ground-truth targets, aligned with ``y_pred``.
        """
        y_pred_ori = self._to_original_scale(y_pred)
        y_true_ori = self._to_original_scale(y_true)

        # Append only after both conversions succeed so the two lists stay aligned.
        self.ys_pred.append(y_pred_ori)
        self.ys_true.append(y_true_ori)

    def get_metrics(self) -> PredictionMetrics:
        """
        Compute metrics over everything accumulated since construction or the last ``reset``.

        Returns:
            ``PredictionMetrics`` built from the predictions and true values,
            each concatenated along axis 0.

        Raises:
            ValueError: If ``update`` has not been called since construction or the last ``reset``.
        """
        # np.concatenate on an empty list raises a cryptic error; fail with a clear one instead.
        if not self.ys_pred:
            raise ValueError("No predictions accumulated; call update() before get_metrics().")

        y_pred = np.concatenate(self.ys_pred, axis=0)
        y_true = np.concatenate(self.ys_true, axis=0)

        return PredictionMetrics.from_prediction(y_pred, y_true, self.task)

