import torch
from trl import SFTTrainer
from transformers import AutoModelForCausalLM, TrainingArguments, AutoTokenizer, BitsAndBytesConfig, DataCollatorForLanguageModeling
from datasets import load_dataset, Dataset
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import LambdaLR
from transformers.trainer_utils import TrainOutput
import time 
import pdb
from tqdm import tqdm


class LookGradSFTTrainer(SFTTrainer):
    """SFTTrainer that stops in ``pdb`` right after each backward pass.

    Meant for interactive debugging: when the debugger opens inside
    ``training_step`` the freshly computed gradients are available on the
    model's parameters (``param.grad``) and can be inspected before the
    trainer performs its optimizer step.
    """

    def training_step(self, model, inputs, num_items_in_batch=None):
        """Compute the loss for one batch, backpropagate, then break into pdb.

        Args:
            model: The model being trained, as passed in by the Trainer loop.
            inputs: The batch for this step; moved to the proper device by
                ``self._prepare_inputs``.
            num_items_in_batch: Accepted (and ignored) so this override works
                with transformers versions whose training loop passes a third
                argument; older versions call it with two arguments.

        Returns:
            The detached loss tensor for this batch.
        """
        model.train()
        inputs = self._prepare_inputs(inputs)

        # Go through compute_loss rather than model(**inputs).loss so any
        # Trainer/SFTTrainer loss customisation is honoured, and a batch
        # without a loss fails with the base Trainer's explicit error instead
        # of an AttributeError on None.backward().
        loss = self.compute_loss(model, inputs)

        if self.args.n_gpu > 1:
            # nn.DataParallel returns one loss per device; reduce to a scalar
            # before backward, as the stock Trainer does.
            loss = loss.mean()

        # Use Accelerate's backward instead of loss.backward(): it applies the
        # mixed-precision GradScaler (and DeepSpeed/other wrappers) that the
        # trainer's optimizer step expects. A bare backward() leaves fp16
        # gradients unscaled, and the later unscale step then shrinks them.
        self.accelerator.backward(loss)

        # Gradients are populated here. NOTE(review): under fp16 they are still
        # multiplied by the loss scale, and depending on the installed
        # transformers/accelerate versions they may already be divided by
        # gradient_accumulation_steps — confirm before interpreting magnitudes.
        pdb.set_trace()

        return loss.detach()

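# Debugging notes: commands typed at the pdb prompt opened by
# LookGradSFTTrainer.training_step to inspect a LoRA ``lora_B`` tensor.
# NOTE: state_dict() returns detached tensors (keep_vars=False by default), so
# these commands show the lora_B *weights*, not their gradients. To see the
# gradients, look up the parameter objects themselves, e.g.
#   params = dict(self.model.named_parameters())
#   params[Bkey[20]].grad
# (parameter names normally match the state_dict keys — verify for your model).
#
#  model_state = self.model.state_dict()
#  Bkey = [key for key in model_state.keys() if 'lora_B' in key]
#  model_state[Bkey[20]]

# (Pdb) model_state = self.model.state_dict()
# (Pdb) Bkey = [key for key in model_state.keys() if 'lora_B' in key]
# (Pdb) model_state[Bkey[20]]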
