import torch
from sklearn.metrics import accuracy_score, roc_auc_score


class EpochLogger:
    """Accumulate per-batch labels, predictions and losses; summarize per epoch.

    Call :meth:`update_stats` once per batch, then :meth:`write_epoch` at the
    end of the epoch to get a metrics dict (every key prefixed with ``name``)
    and reset the accumulators.

    Args:
        cfg: Config object. ``cfg.dataset.task_type`` selects whether
            classification metrics are computed, and ``cfg.model.dim_out == 1``
            additionally enables AUC-ROC.
        name: Prefix for metric keys, e.g. ``"train"`` -> ``"train/accuracy"``.
    """

    def __init__(self, cfg, name):
        self.cfg = cfg
        self.name = name
        self._iter = 0  # number of batches recorded this epoch
        self._true = []  # per-batch label tensors (detached)
        self._pred = []  # per-batch prediction tensors (detached)
        self._custom_stats = {}  # key -> running sum of value * batch_size
        self._loss = 0  # running sum of per-batch losses

    def update_stats(self, true, pred, loss, **kwargs):
        """Record the results of one batch.

        Args:
            true: Label tensor; its first dimension is taken as the batch size.
            pred: Prediction-score tensor for the batch.
            loss: Batch loss (Python number or tensor).
            **kwargs: Extra per-batch statistics; each value is accumulated
                weighted by the batch size.
        """
        batch_size = true.shape[0]
        # Detach so the autograd graph of every batch is not kept alive until
        # the end of the epoch (and so .numpy() works later).
        if isinstance(loss, torch.Tensor):
            loss = loss.detach()
        self._iter += 1
        self._loss += loss
        self._true.append(true.detach())
        self._pred.append(pred.detach())

        for key, val in kwargs.items():
            if key not in self._custom_stats:
                self._custom_stats[key] = val * batch_size
            else:
                self._custom_stats[key] += val * batch_size

    def reset(self):
        """Clear all accumulated state for the next epoch."""
        self._iter = 0
        self._true = []
        self._pred = []
        self._loss = 0
        self._custom_stats = {}

    def compute_basic(self):
        """Return ``{f"{name}/average_loss": mean per-batch loss}``.

        The value is NaN when no batch has been recorded (avoids dividing by 0).
        """
        if self._iter == 0:
            average = float("nan")
        else:
            average = self._loss / self._iter
        return {f"{self.name}/average_loss": average}

    def _get_pred_int(self, pred_score):
        """Convert prediction scores to integer class labels.

        1-D or single-column scores are thresholded at 0.5 and flattened to
        shape ``(N,)``; multi-column scores use the argmax over dim 1.
        """
        if len(pred_score.shape) == 1 or pred_score.shape[1] == 1:
            return (pred_score.reshape(-1) > 0.5).long()
        return pred_score.max(dim=1)[1]

    def compute_performance_metrics(self):
        """Return classification metrics for the accumulated epoch.

        When ``cfg.dataset.task_type == "classification"`` this reports
        accuracy, plus AUC-ROC when ``cfg.model.dim_out == 1``. Any other task
        type, or an epoch with no recorded batches, yields an empty dict.
        """
        res = {}
        if self.cfg.dataset.task_type != "classification" or not self._pred:
            return res

        # sklearn needs host-side NumPy data; the tensors may live on a GPU.
        true = torch.cat(self._true).cpu().numpy()
        pred_score = torch.cat(self._pred).cpu()
        pred_int = self._get_pred_int(pred_score).numpy()

        res[f"{self.name}/accuracy"] = accuracy_score(true, pred_int)

        # Binary classification: the single output column is the positive-class
        # score; flatten (N, 1) -> (N,) as roc_auc_score expects.
        if self.cfg.model.dim_out == 1:
            scores = pred_score.to(dtype=torch.float32).reshape(-1).numpy()
            try:
                res[f"{self.name}/auc_roc"] = roc_auc_score(true, scores)
            except ValueError as e:
                # e.g. only one class present in the labels this epoch.
                print(f"Error in AUC-ROC calculation: {e}")

        return res

    def write_epoch(self):
        """Summarize the epoch, reset all accumulators, and return the metrics.

        The loss is averaged exactly once, inside :meth:`compute_basic`.
        """
        stats = {
            **self.compute_basic(),
            **self.compute_performance_metrics(),
        }
        self.reset()
        return stats
