import torch
import torch.nn as nn


class CustomModel(nn.Module):
    """Thin ``nn.Module`` wrapper around a seq2seq model and its tokenizer.

    The wrapper keeps the model, the tokenizer and an output directory
    together, and exposes a ``forward`` that returns ``(loss, logits)``
    instead of the wrapped model's output object.
    """

    def __init__(self, tokenizer, model, output: str = "outputs"):
        """Store the wrapped model, its tokenizer and the output directory.

        Args:
            tokenizer: Tokenizer paired with ``model`` (a T5 tokenizer in the
                intended use). It is only stored, never called, by this class.
            model: Model to wrap (a T5 model in the intended use). Its call
                result must expose ``.loss`` and ``.logits`` attributes.
            output: Directory intended for model checkpoints. This class only
                stores the value. Defaults to ``"outputs"``.
        """
        super().__init__()
        self.model = model
        self.tokenizer = tokenizer
        self.output = output

    def forward(
        self, input_ids, attention_mask, decoder_attention_mask=None, labels=None
    ):
        """Run the wrapped model and unpack its loss and logits.

        Args:
            input_ids: Encoder input, passed positionally to the wrapped model.
            attention_mask: Forwarded as the ``attention_mask`` keyword.
            decoder_attention_mask: Forwarded as the ``decoder_attention_mask``
                keyword. Optional; ``None`` is passed through unchanged.
            labels: Forwarded as the ``labels`` keyword. Optional.

        Returns:
            A ``(loss, logits)`` tuple read from the wrapped model's output.
            NOTE(review): with ``labels=None`` the loss is presumably ``None``
            for Hugging Face models -- confirm against the wrapped model.
        """
        result = self.model(
            input_ids,
            attention_mask=attention_mask,
            labels=labels,
            decoder_attention_mask=decoder_attention_mask,
        )

        return result.loss, result.logits
