"""class dLLMTrainer(Trainer): 
    def __init__(self, loss_type='vanilla', bottom_k_percent = None, seed=1234, mask_token_id=126336, *args, **kwargs): 
        super().__init__(*args, **kwargs) 
        self.bottom_k_percent = bottom_k_percent # Will be set if using bottom-r training 
        self.rng = random.Random(seed) 
        self.loss_mode = loss_type self.losses = { 'vanilla': self.vanilla_sft_loss, 'rdro': self.rdro_loss, 'bottom_k': self.bottom_k, 'top_k': self.top_k, 'mixed': self.mixture_sft_loss } self.mask_token_id = mask_token_id self.mix_modes = [ ("vanilla", self.vanilla_sft_loss), ("topk", partial(self.bottom_k, top_k=True)), ("bottomk",partial(self.bottom_k, top_k=False)), ] self.mix_probs = [1/3, 1/3, 1/3] self._mix_cached_gs = None self._mix_cached_idx = None def _get_mix_idx(self): global_step = int(self.state.global_step) # Ensure the same loss was used in the grad-accum steps. if self._mix_cached_gs == global_step: return self._mix_cached_idx if dist.is_available() and dist.is_initialized(): # Sample the Loss on rank0 and broadcast, so all ranks use the same loss. if dist.get_rank() == 0: idx = self._sample_idx() t = torch.tensor([idx], device=self.args.device, dtype=torch.long) else: t = torch.tensor([0], device=self.args.device, dtype=torch.long) dist.broadcast(t, src=0) idx = int(t.item()) else: idx = self._sample_idx() self._mix_cached_gs = global_step self._mix_cached_idx = idx return idx def _sample_idx(self): return self.rng.choices(range(len(self.mix_modes)), weights=self.mix_probs, k=1)[0] def mixture_sft_loss(self, model, inputs, num_items_in_batch=None, return_outputs=False): idx = self._get_mix_idx() loss_name, loss_fn = self.mix_modes[idx] # if (self.state.global_step) % self.args.logging_steps == 0: # self.log({ # "loss_mode": idx, # }) print(f'RANK [{dist.get_rank()}] [LOSS TYPE] Loss Idx: {self._mix_cached_idx} - Loss Name: {loss_name} - Global Step: {self._mix_cached_gs}\n') loss, outputs = loss_fn(model, inputs, num_items_in_batch, return_outputs=True) return (loss, outputs) if return_outputs else loss
        """