import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from transformers import PreTrainedModel
from transformers import Trainer

import copy
from typing import Dict, Union, Any

class CustomSLTrainer(Trainer):
    """``Trainer`` subclass whose loss is produced by ``self.compute_loss_func``.

    The loss function is called with this class's own convention,
    ``loss_fn(model, inputs, labels, return_outputs=True, weight=weight)``,
    and is expected to return a ``(loss, outputs)`` pair.
    """

    def compute_loss(
        self,
        model,
        inputs: Dict[str, Any],
        return_outputs: bool = False,
        num_items_in_batch=None,
        weight=None,
    ):
        """Compute the loss for one batch via ``self.compute_loss_func``.

        Args:
            model: The (possibly wrapped) model handed in by the training loop.
                It is unwrapped with ``self.accelerator.unwrap_model`` before
                being passed to the loss function.
            inputs: Batch mapping; must contain a ``"labels"`` entry.
            return_outputs: When true, return ``(loss, outputs)`` instead of
                only ``loss``.
            num_items_in_batch: Only used to decide whether the loss is
                rescaled by the number of processes; it is not forwarded to
                the loss function.
            weight: Forwarded unchanged to the loss function as ``weight=``;
                its meaning is defined by that function.

        Returns:
            ``loss`` or, if ``return_outputs`` is true, ``(loss, outputs)``.

        Raises:
            KeyError: If ``inputs`` has no ``"labels"`` entry.
            TypeError: If no ``compute_loss_func`` was configured on the trainer.
        """
        labels = inputs["labels"]

        loss_fn = self.compute_loss_func
        if loss_fn is None:
            # Same exception type the bare call would raise, but with a
            # message that says what is actually missing.
            raise TypeError(
                "CustomSLTrainer requires `compute_loss_func` to be set on the "
                "trainer, but it is None."
            )

        # The loss function receives the underlying model rather than the
        # accelerator's wrapper.
        unwrapped_model = self.accelerator.unwrap_model(model)

        loss, outputs = loss_fn(
            unwrapped_model,
            inputs,
            labels,
            return_outputs=True,
            weight=weight,
        )

        # Rescale only when an item count was actually supplied. The stock
        # Trainer.compute_loss applies the same `is not None` guard; without
        # it, calls that omit the count (such as the Trainer's evaluation
        # path) would report a loss multiplied by the process count.
        # NOTE(review): num_items_in_batch is not passed to the loss function,
        # so whether the loss is normalised by the global item count depends
        # on compute_loss_func itself -- confirm the rescale is wanted here.
        should_rescale = (
            self.args.average_tokens_across_devices
            and self.model_accepts_loss_kwargs
            and num_items_in_batch is not None
        )
        if should_rescale:
            loss *= self.accelerator.num_processes

        return (loss, outputs) if return_outputs else loss