import torch.nn as nn
import pytorch_lightning as pl

from conf.dataset import DatasetParams
from conf.model import LoggingParams
from utils.utils import display_tensor, display_mask


class LogStrategy(nn.Module):
    """
    Base class for logging strategies.

    Stores the logging and dataset configuration objects and exposes two
    hooks, ``log_train`` and ``log_generate``. In this base class both hooks
    raise ``NotImplementedError``, so concrete strategies must override them.
    """

    def __init__(
        self,
        params: LoggingParams,
        params_data: DatasetParams,
    ):
        super().__init__()
        # Both configuration objects are kept on the instance for use by subclasses.
        self.params, self.params_data = params, params_data

    @staticmethod
    def already_logged(
        plMod: pl.LightningModule,
        batch_idx: int,
        batch_size: int,
    ) -> int:
        """
        Return how many more samples may still be logged.

        The total logging budget is read from
        ``plMod.params.logging.log_steps.max_quantity``. Samples belonging to
        earlier batches are counted as ``batch_idx * batch_size``, which
        assumes every earlier batch had ``batch_size`` samples. That count is
        subtracted from the budget and the result is clamped at zero.

        Despite the method name, the return value is the *remaining* quota,
        not the number already logged.
        """
        # TODO: the budget should not be read from ``log_steps``.
        budget = plMod.params.logging.log_steps.max_quantity
        consumed = batch_idx * batch_size
        return max(0, budget - consumed)

    def log_train(
        self,
        stage_prefix: str,
        prediction,
        input_to_model,
        batch,
        ts,
        x0_hat,
        plMod,
        batch_idx,
        modes,
    ):
        """
        Hook for logging during a training step; must be overridden.

        The meaning of each argument is defined by the caller and the
        overriding subclass; it is not established in this base class.
        """
        raise NotImplementedError()

    def log_generate(
        self,
        stage_prefix,
        plMod,
        batch,
        modes,
        batch_idx,
        idx=None,
    ):
        """
        Hook for logging during generation; must be overridden.

        The meaning of each argument (including the optional ``idx``) is
        defined by the caller and the overriding subclass; it is not
        established in this base class.
        """
        raise NotImplementedError()
