from src.trainer.base import MyTrainingArguments

from dataclasses import dataclass
from transformers import Trainer
from typing import Optional, Literal
import torch
import sys

@dataclass
class FinetuneArguments(MyTrainingArguments):
    """Argument container for fine-tuning runs, built on ``MyTrainingArguments``.

    Declares two fine-tuning switches plus defaults for several
    schedule/logging fields. The base class is defined elsewhere, so whether
    these fields override inherited ones of the same name is not visible
    here -- TODO confirm against ``src.trainer.base``.
    """
    # NOTE(review): presumably restricts training to the model head only;
    # the code that consumes this flag is not in this file -- confirm.
    train_only_head: bool = False
    # NOTE(review): presumably restricts training to the last MLP block only;
    # consumer not visible here -- confirm.
    train_only_last_mlp: bool = False
    # Number of training epochs; declared as float, default 1.0.
    num_train_epochs: float = 1.0
    # Defaults to None; how a None learning rate is resolved is not visible
    # in this file -- verify against the code that builds the optimizer.
    learning_rate: Optional[float] = None
    # Logging interval in steps (per the field name), default 50.
    logging_steps: int = 50
    # One of "no", "steps", "epoch"; default "no".
    eval_strategy: Literal["no", "steps", "epoch"] = "no"
    # Evaluation interval in steps (per the field name), default 1.
    # NOTE(review): presumably only relevant when eval_strategy == "steps".
    eval_steps: int = 1
    # Checkpoint-saving interval in steps (per the field name), default 1000.
    save_steps: int = 1000
class FinetuneTrainer(Trainer):
    """``Trainer`` subclass that flushes stdout each step and rejects NaN losses."""

    def training_step(self, *args, **kwargs):
        """Run the parent training step, then flush ``sys.stdout``.

        All arguments are forwarded unchanged to ``Trainer.training_step`` and
        its return value is passed back as-is.
        """
        out = super().training_step(*args, **kwargs)
        # Presumably here so buffered output (e.g. when stdout is redirected
        # to a file) shows up promptly after every step.
        sys.stdout.flush()
        return out

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Compute the loss via the parent ``Trainer`` and fail fast on NaN.

        Args:
            model: Forwarded to ``Trainer.compute_loss``.
            inputs: Forwarded to ``Trainer.compute_loss``.
            return_outputs: Forwarded to ``Trainer.compute_loss``.
            num_items_in_batch: Forwarded to ``Trainer.compute_loss``.

        Returns:
            Whatever ``Trainer.compute_loss`` returned, unmodified. When that
            is a tuple, its first element is treated as the loss tensor.

        Raises:
            AssertionError: If any element of the loss tensor is NaN.
        """
        out = super().compute_loss(model, inputs, return_outputs, num_items_in_batch)
        # The parent may return the loss alone or a tuple whose first element
        # is the loss; pick the tensor to validate in either case.
        loss = out[0] if isinstance(out, tuple) else out
        # Raise explicitly rather than using ``assert``: assert statements are
        # removed under ``python -O``, which would silently disable this
        # check. The exception type and message are kept for compatibility.
        if torch.isnan(loss).any().item():
            raise AssertionError("nan detected in the loss")
        return out