import os

import torch
from torch.optim import AdamW
from transformers import Trainer

class dLLMTrainer(Trainer):
    """Hugging Face ``Trainer`` subclass with a custom optimizer, loss and save hook.

    Differences from the stock ``Trainer``:

    * ``create_optimizer`` builds a single-group ``torch.optim.AdamW`` over all
      trainable parameters with weight decay disabled.
    * ``compute_loss`` pops the batch fields this trainer expects, calls the
      model positionally and divides its loss by ``t``.
    * ``save_model`` writes ``model.save_pretrained`` output to
      ``<output_dir>/pretrained_model/``.
    """

    def __init__(self, *args,
                 loss_calc="dica",
                 data_collator=None,
                 **kwargs):
        """Create the trainer.

        Args:
            *args: Positional arguments forwarded to ``Trainer.__init__``.
            loss_calc: Name of the loss computation used by ``compute_loss``.
                Only ``"codila"`` is handled there.
                NOTE(review): the default ``"dica"`` is not handled by
                ``compute_loss``, so training raises ``ValueError`` unless
                ``loss_calc="codila"`` is passed explicitly -- confirm the
                intended default.
            data_collator: Forwarded to ``Trainer.__init__``.
            **kwargs: Remaining keyword arguments forwarded to ``Trainer.__init__``.
        """
        # Debug output of the arguments forwarded to the base Trainer.
        print("TRAINER ARGS ======", args)
        print("TRAINER KWARGS =======", kwargs)
        super().__init__(*args, data_collator=data_collator, **kwargs)
        self.loss_calc = loss_calc

    def create_optimizer(self):
        """Create a single-group AdamW optimizer over all trainable params.

        The learning rate, betas and epsilon come from ``self.args``
        (``learning_rate``, ``adam_beta1``/``adam_beta2``, ``adam_epsilon``).
        Weight decay is fixed at 0.0 for every parameter.

        Returns:
            The optimizer stored on ``self.optimizer``. If one already exists,
            it is returned unchanged.
        """
        if self.optimizer is not None:
            return self.optimizer

        # TODO: Check if I should set different learning rate for the AR and Diffusion
        trainable = [p for p in self.model.parameters() if p.requires_grad]
        self.optimizer = AdamW(
            [{"params": trainable, "lr": self.args.learning_rate, "weight_decay": 0.0}],
            # Honour the Adam hyper-parameters from TrainingArguments instead of
            # silently falling back to torch's defaults.
            betas=(self.args.adam_beta1, self.args.adam_beta2),
            eps=self.args.adam_epsilon,
        )
        return self.optimizer

    def shift_logits(self, logits):
        """Shift logits one position to the right along dim 1.

        Output position ``i`` holds ``logits[:, i - 1]`` for ``i >= 1``.
        Position 0 repeats ``logits[:, 0]``, so the length along dim 1 is
        unchanged. The intent is to align next-token logits with the
        ``input_ids`` they predict; dim 1 is presumably the sequence axis.
        """
        return torch.cat([logits[:, :1], logits[:, :-1]], dim=1)

    def compute_loss(self, model, inputs, num_items_in_batch=None, return_outputs=False):
        """Compute the training loss for one batch.

        Pops ``input_ids``, ``block_ranges``, ``labels``, ``block_mask_indices``
        and ``t`` from ``inputs`` (mutating it). For ``loss_calc == "codila"``
        it calls the model and returns the model's loss divided by ``t``, i.e.
        the loss weighted by 1/t.

        Args:
            model: Called positionally as
                ``model(input_ids, block_ranges, block_mask_indices, labels)``.
                Its output must expose ``.loss``.
            inputs: Batch mapping produced by the data collator.
            num_items_in_batch: Accepted for ``Trainer`` compatibility; unused.
            return_outputs: If True, also return the model outputs.

        Returns:
            ``final_loss``, or ``(final_loss, outputs)`` when ``return_outputs``.

        Raises:
            KeyError: If one of the expected batch fields is missing.
            ValueError: If ``self.loss_calc`` is not ``"codila"``.
        """
        input_ids = inputs.pop("input_ids")
        block_ranges = inputs.pop("block_ranges")
        labels = inputs.pop("labels")
        block_mask_indices = inputs.pop("block_mask_indices")
        t = inputs.pop("t")

        # ---- loss computation ----
        if self.loss_calc == "codila":
            outputs = model(input_ids, block_ranges, block_mask_indices, labels)
            # NOTE(review): if `t` is a per-sample tensor while outputs.loss is a
            # scalar, this produces a non-scalar loss that backward() rejects --
            # confirm the shapes of `t` and outputs.loss.
            final_loss = outputs.loss / t

        else:
            raise ValueError(f"Unknown loss_calc method: {self.loss_calc}")

        return (final_loss, outputs) if return_outputs else final_loss

    def save_model(self, output_dir: str | None = None, _internal_call: bool = False):
        """Save the model via ``save_pretrained`` under ``<output_dir>/pretrained_model/``.

        Args:
            output_dir: Base directory. When None, falls back to
                ``self.args.output_dir``, as the stock ``Trainer.save_model`` does.
            _internal_call: Accepted for ``Trainer`` compatibility; unused.
        """
        # ``trainer.save_model()`` is commonly called with no argument; without this
        # fallback the path construction below would fail on None.
        if output_dir is None:
            output_dir = self.args.output_dir
        # Trailing "" keeps the trailing separator in the printed path.
        save_dir = os.path.join(output_dir, "pretrained_model", "")
        # NOTE(review): unlike the base Trainer.save_model this is not gated on
        # self.args.should_save, so every process writes in distributed runs --
        # confirm this is intended.
        self.model.save_pretrained(save_dir)
        print(f"Pretrained model saved to {save_dir}.", flush=True)
