
from tqdm.auto import tqdm
import torch
import torch.nn.functional as F



def replay_compute(train_dataloader_prune, buffer, args):
    """Pad the first batch of ``train_dataloader_prune`` and store it in ``buffer``.

    Only a single batch is consumed from the loader (the rest, if any, is
    ignored). Each sequence tensor is padded on the right to the configured
    maximum length before being handed to ``buffer.add_data``.

    Args:
        train_dataloader_prune: Iterable yielding dict batches with the keys
            ``'input_ids'``, ``'labels'``, ``'attention_mask'`` and ``'task'``.
            ``len()`` is not required.
        buffer: Replay buffer exposing
            ``add_data(examples=..., labels=..., attention_mask=..., task=...)``.
        args: Namespace providing ``max_source_length``, ``max_target_length``
            and ``tokenizer.pad_token_id``.

    Returns:
        The same ``buffer`` object. It is returned unchanged (``add_data`` is
        not called) when the loader yields no batch.
    """
    # Exactly one batch is processed, so the bar has exactly one step. The
    # context manager guarantees the bar is closed on every exit path.
    with tqdm(total=1) as progress_bar, torch.no_grad():
        # Only the first batch is used.
        inputs = next(iter(train_dataloader_prune), None)
        if inputs is None:
            # Empty loader: nothing to add. Returning avoids leaking a bare
            # StopIteration out of this function.
            return buffer

        input_ids = inputs['input_ids']
        labels = inputs['labels']
        attention_mask = inputs['attention_mask']
        task = inputs['task']

        # Right-pad dimension 1 to the configured maximum lengths. If a tensor
        # is already longer, the negative pad amount makes F.pad truncate it.
        # Labels use -100, the default ignore_index of torch cross-entropy,
        # so padded target positions do not contribute to a loss.
        labels = F.pad(labels, (0, args.max_target_length - labels.size(1)), "constant", -100)
        input_ids = F.pad(input_ids, (0, args.max_source_length - input_ids.size(1)), "constant", args.tokenizer.pad_token_id)
        # 0 in the attention mask marks the padded positions as masked out.
        attention_mask = F.pad(attention_mask, (0, args.max_source_length - attention_mask.size(1)), "constant", 0)

        # needed for the buffer module
        buffer.add_data(examples=input_ids, labels=labels, attention_mask=attention_mask, task=task)

        progress_bar.update(1)
        progress_bar.set_description('DERPP Compute Iter ')

    return buffer
