from multiprocessing import reduction
import pytorch_lightning as pl
from models import utils
from transformers import (
    Adafactor,
    GPT2LMHeadModel,
    GPT2Config,
    GPT2DoubleHeadsModel,
    AutoTokenizer,
    GPT2Tokenizer,
    GPT2TokenizerFast,
    GPTJForCausalLM,
    AutoModelForCausalLM,
    AutoTokenizer,
    AutoConfig
)
# from models.GPT_Neo import GPTNeoForCausalLM
import jsonlines
import torch
from Datasets import Custom_Dataset_GPT2
from torch.utils.data import RandomSampler
from torch.utils.data import DataLoader, ConcatDataset
from collections import Counter
from tqdm import tqdm
import re
import torch.nn.functional as F
import torch.distributed as dist
from torchmetrics.functional import accuracy
from transformers import get_polynomial_decay_schedule_with_warmup, get_constant_schedule_with_warmup
import string
import math
import os
import csv
import random
import pandas as pd
from datetime import datetime
import deepspeed
import numpy as np

class GPT2Valid(pl.LightningModule):
    """Validation-only LightningModule for causal LMs (GPT-2 / GPT-Neo / GPT-J / OPT).

    Depending on ``hparams.mode`` it logs memorization metrics (greedy
    next-token accuracy, extraction likelihood), per-document losses collected
    into ``self.valid_df``, or general-LM metrics (perplexity, verbalizer
    classification, LAMBADA-style completion).
    """

    # Number of leading tokens used as the generation prompt in
    # ``validation_no_reduce`` and as the logged ``prefix`` column.
    PREFIX_LENGTH = 100

    def __init__(self, hparams):
        super(GPT2Valid, self).__init__()
        self.check_validation_only = hparams.check_validation_only
        self.eval_dataset = hparams.eval_dataset
        self.mode = hparams.mode

        # Model initialization
        self.tokenizer = GPT2Tokenizer.from_pretrained(
            hparams.tokenizer_name_or_path, fast=False)
        if 'gpt' in hparams.tokenizer_name_or_path:
            # GPT-2 style tokenizers have no pad token; reuse EOS for padding.
            self.tokenizer.pad_token = self.tokenizer.eos_token
        if hparams.method == 'baseline':
            if 'gpt-j' in hparams.model_name_or_path:
                self.model = GPTJForCausalLM.from_pretrained(
                    hparams.model_name_or_path, revision='float16', torch_dtype=torch.float16, low_cpu_mem_usage=True,
                    resid_pdrop=0, embd_pdrop=0, attn_pdrop=0, pad_token_id=self.tokenizer.eos_token_id)
            else:
                # Architectures name their dropout fields differently; try each
                # family in turn, falling through on TypeError (presumably raised
                # when the model rejects the unknown kwargs — confirm per version).
                try:  # GPT
                    self.model = AutoModelForCausalLM.from_pretrained(
                        hparams.model_name_or_path, resid_pdrop=0, embd_pdrop=0, attn_pdrop=0, pad_token_id=self.tokenizer.eos_token_id)
                except TypeError:
                    try:  # GPT-Neo
                        self.model = AutoModelForCausalLM.from_pretrained(
                            hparams.model_name_or_path, resid_dropout=0, embed_dropout=0, attention_dropout=0, pad_token_id=self.tokenizer.eos_token_id)
                    except TypeError:  # OPT
                        self.model = AutoModelForCausalLM.from_pretrained(
                            hparams.model_name_or_path, torch_dtype=torch.float16, dropout=0, attention_dropout=0, activation_dropout=0)
        elif hparams.method == 'scratch':
            config = AutoConfig.from_pretrained(hparams.model_name_or_path)
            config.vocab_size = len(self.tokenizer)
            config.resid_pdrop = 0
            config.embd_pdrop = 0
            config.attn_pdrop = 0
            config.pad_token_id = self.tokenizer.eos_token_id
            # Auto classes cannot be instantiated directly; build randomly
            # initialized weights from the config instead.
            self.model = AutoModelForCausalLM.from_config(config)
        else:
            raise Exception(f'Currently not supporting {hparams.method}')

        self.save_hyperparameters(hparams)
        print(f'Using {hparams.method} as method')
        if hparams.negative_loss:
            print('********Training on negative loss!!********')

        if self.check_validation_only:
            self.model.eval()

        self.model.resize_token_embeddings(len(self.tokenizer))

        self.vocab_size = self.tokenizer.vocab_size
        self.batch_size_per_gpu = self.hparams.eval_batch_size
        self.pred_log = self.hparams.pred_log
        self.output_dir = self.hparams.output_dir
        # True until the first validation pass has been written out
        # (see on_validation_end / validation_epoch_end).
        self.init_validation = True
        self.valid_df = None

    def freeze_params(self, model):
        """Disable gradients for every parameter of ``model``."""
        for param in model.parameters():
            param.requires_grad = False

    def normalize_answer(self, s):
        """Lower text and remove special tokens, punctuation, articles and extra whitespace.

        ``<extra_id_0>`` / ``<extra_id_1>`` are stripped first: ``remove_punc``
        deletes ``<``, ``>`` and ``_``, so stripping them afterwards never matched.
        """
        def remove_articles(text):
            return re.sub(r"\b(a|an|the)\b", " ", text)

        def white_space_fix(text):
            return " ".join(text.split())

        def remove_punc(text):
            exclude = set(string.punctuation)
            return "".join(ch for ch in text if ch not in exclude)

        def lower(text):
            return text.lower()

        def rid_of_specials(text):
            text = text.replace("<extra_id_0>", "")
            text = text.replace("<extra_id_1>", "")
            return text

        return white_space_fix(remove_articles(remove_punc(lower(rid_of_specials(s)))))

    def exact_match_score(self, prediction, ground_truth):
        """Return 1 if the normalized strings are equal, else 0."""
        return int(self.normalize_answer(prediction) == self.normalize_answer(ground_truth))

    def _f1_score(self, prediction, ground_truth):
        """Token-level F1 between the normalized prediction and ground truth (0 if no overlap)."""
        prediction_tokens = self.normalize_answer(prediction).split()
        ground_truth_tokens = self.normalize_answer(ground_truth).split()
        common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
        num_same = sum(common.values())
        if num_same == 0:
            return 0
        precision = 1.0 * num_same / len(prediction_tokens)
        recall = 1.0 * num_same / len(ground_truth_tokens)
        return (2 * precision * recall) / (precision + recall)

    def get_dataset(self, dataset_name, tokenizer, subset_path, type_path, args, length=None):
        """Build a ``Custom_Dataset_GPT2``.

        ``length`` (when truthy) overrides both the input and output lengths;
        otherwise ``args.max_input_length`` / ``args.max_output_length`` are used.
        """
        input_length = length if length else args.max_input_length
        output_length = length if length else args.max_output_length
        return Custom_Dataset_GPT2(dataset_name=dataset_name, tokenizer=tokenizer, subset_path=subset_path, type_path=type_path,
                                   input_length=input_length, output_length=output_length, args=args)

    def lmap(self, f, x):
        """list(map(f, x))"""
        return list(map(f, x))

    def forward(self, input_ids, attention_mask=None, lm_labels=None):
        return self.model(
            input_ids,
            attention_mask=attention_mask,
            labels=lm_labels,
        )

    def _step(self, batch):
        """Teacher-forced forward pass; returns ``(loss, logits)``.

        NOTE: pad positions of ``batch["target_ids"]`` are overwritten with -100
        *in place*. ``validation_no_reduce`` relies on this so that its per-token
        CrossEntropyLoss (default ignore_index=-100) skips padding.
        """
        lm_labels = batch["target_ids"]
        lm_labels[lm_labels == self.tokenizer.pad_token_id] = -100
        outputs = self(
            input_ids=batch["source_ids"],
            attention_mask=batch["source_mask"],
            lm_labels=lm_labels,
        )
        return outputs[0], outputs[1]

    def ids_to_clean_text(self, generated_ids):
        """Decode token ids (skipping special tokens) and strip surrounding whitespace."""
        gen_text = self.tokenizer.batch_decode(
            generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
        )
        return self.lmap(str.strip, gen_text)

    def on_validation_end(self):
        """After the very first validation pass, dump ``valid_df`` to ``outputs/init_<run>.csv``."""
        if self.hparams.mode in ['log_individual_ppl', 'negative_inject'] and self.init_validation:
            self.valid_df.to_csv(f'outputs/init_{self.hparams.wandb_run_name}.csv')
            self.init_validation = False

    def _eval_length(self):
        """Sequence length for memorization metrics: ``injected_length`` if set, else ``max_input_length``."""
        return self.hparams.injected_length if self.hparams.injected_length else self.hparams.max_input_length

    def validation_accuracy(self, batch, dataset_name):
        """Greedy next-token accuracy over every prefix of ``source_ids``.

        For each position ``i`` in ``[1, max_len)`` one token is greedily
        generated from ``source_ids[..., :i]`` and compared with
        ``source_ids[..., i]``. Logs ``{dataset_name}/acc``.

        Returns:
            ``(preds, labels)`` transposed so the position axis is last; they are
            1-D when the batch axis was squeezed away (batch size 1).
        """
        input_ids = batch['source_ids']
        max_len = self._eval_length()

        labels, preds = [], []
        for i in range(1, max_len):
            label = input_ids[..., i]
            prompt = input_ids[..., :i]
            try:
                pred = self.model.generate(prompt, max_length=i + 1)[:, -1]
            except IndexError:
                # Fallback kept from the original code for prompts whose generate()
                # output is 1-D; the exact triggering shape isn't visible here.
                pred = self.model.generate(torch.squeeze(prompt), max_length=i + 1).squeeze()[-1]

            labels.append(torch.squeeze(label))
            preds.append(torch.squeeze(pred))

        preds = torch.stack(preds)
        labels = torch.stack(labels)

        score = accuracy(preds, labels, ignore_index=-100)
        self.log(f'{dataset_name}/acc', score, on_epoch=True,
                 prog_bar=True, logger=True, add_dataloader_idx=False, sync_dist=True)

        return torch.t(preds), torch.t(labels)

    def validation_el(self, batch, dataset_name, N=(10,)):
        """Extraction likelihood of each example's suffix given each possible prefix.

        For every split point ``i`` the model greedily continues ``source_ids[..., :i]``
        up to ``max_len`` tokens. Two scores are accumulated per example:

        * hard: the generated suffix equals the true suffix exactly;
        * soft (per ``n`` in ``N``): the number of generated n-grams that occur
          in the true suffix, divided by the number of true-suffix n-grams.

        The hard score is averaged over the ``max_len - 1`` split points. The soft
        score is averaged over the ``max_len - n`` split points whose suffix is at
        least ``n`` tokens long (assuming ``utils.ngram_of_1D_tensor`` yields no
        n-grams for shorter inputs — TODO confirm).

        Returns:
            dict mapping ``'el'`` and ``'el_soft_{n}-gram'`` to per-example tensors.
        """
        input_ids = batch['source_ids']
        max_len = self._eval_length()

        batch_size = input_ids.shape[0]
        hard_numerator = [0] * batch_size
        soft_numerator = {n: [0] * batch_size for n in N}

        for i in reversed(range(1, max_len)):
            label = input_ids[..., i:max_len]
            prompt = input_ids[..., :i]
            pred = self.model.generate(prompt, max_length=max_len)[..., i:]

            for example_idx in range(batch_size):
                p, l = pred[example_idx], label[example_idx]
                # extraction likelihood (hard)
                if torch.equal(p, l):
                    hard_numerator[example_idx] += 1
                # extraction likelihood (soft)
                for n in N:
                    p_ngram = utils.ngram_of_1D_tensor(p, n)
                    l_ngram = utils.ngram_of_1D_tensor(l, n)
                    if len(l_ngram) == 0:  # n-gram isn't defined for this suffix
                        continue
                    l_unique = set(l_ngram)
                    n_hits = sum(1 for gram in p_ngram if gram in l_unique)
                    soft_numerator[n][example_idx] += n_hits / len(l_ngram)

        # Guard the denominators so degenerate lengths yield 0 instead of dividing by zero.
        hard_denominator = max(max_len - 1, 1)
        hard_score = [h / hard_denominator for h in hard_numerator]
        soft_score = {
            n: [s / max(max_len - n, 1) for s in soft_numerator[n]]
            for n in N
        }

        self.log(f'{dataset_name}/el', sum(hard_score) / len(hard_score),
                 prog_bar=True, logger=True, add_dataloader_idx=False, sync_dist=True)
        for n in N:
            self.log(f'{dataset_name}/el_soft_{n}-gram', sum(soft_score[n]) / len(soft_score[n]),
                     prog_bar=True, logger=True, add_dataloader_idx=False, sync_dist=True)

        ret = {'el': torch.Tensor(hard_score)}
        for k in soft_score.keys():
            ret[f'el_soft_{k}-gram'] = torch.Tensor(soft_score[k])
        return ret

    def validation_no_reduce(self, batch, dataset_name='original', calc_acc=True, calc_el=True, generate_example=True):
        """Per-document validation metrics for the injected data.

        Logs the batch-mean ``val_loss`` / ``val_ppl`` and returns a dict with
        per-example ``doc_id`` and ``loss`` (mean token loss over non-ignored
        positions), plus optionally ``acc``, ``el*`` and ``preds`` (the tokens
        generated after the first ``PREFIX_LENGTH`` tokens).
        """
        loss_reduced, score = self._step(batch)
        self.log('val_loss', loss_reduced, on_epoch=True, prog_bar=True, logger=True, add_dataloader_idx=False, sync_dist=True)
        self.log('val_ppl', torch.exp(loss_reduced), on_epoch=True, prog_bar=True, logger=True, add_dataloader_idx=False, sync_dist=True)

        value_dict = {}
        if calc_acc:
            preds, labels = self.validation_accuracy(batch, dataset_name)
            if len(preds.shape) == 1:
                preds = torch.unsqueeze(preds, 0)
                labels = torch.unsqueeze(labels, 0)

            accs = []
            for pred, label in zip(preds, labels):
                try:
                    accs.append(accuracy(pred, label, ignore_index=-100))
                except IndexError:
                    pass
            value_dict['acc'] = torch.stack(accs) if accs else accs

        if calc_el:
            value_dict.update(self.validation_el(batch, dataset_name))

        if generate_example:
            prompt = batch['source_ids'][..., :self.PREFIX_LENGTH]
            pred = self.model.generate(prompt, max_length=self._eval_length())[..., self.PREFIX_LENGTH:]
            value_dict['preds'] = pred

        # batch['target_ids'] already holds -100 at pad positions (set in place by
        # _step), so those positions get loss 0 under the default ignore_index.
        shift_logits = score[..., :-1, :].contiguous().squeeze()
        shift_labels = batch['target_ids'][..., 1:].contiguous().squeeze()
        loss_fct = torch.nn.CrossEntropyLoss(reduction='none')
        loss_no_reduce = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

        # reduce along sequence only, keep the batch axis
        if len(batch['target_ids'].shape) > 1:
            loss_no_reduce = loss_no_reduce.view(batch['target_ids'].shape[0], -1)  # (batch, seq_len)
        else:
            loss_no_reduce = torch.unsqueeze(loss_no_reduce, 0)
        # Exact-zero entries are treated as ignored (padding) positions.
        mean_losses = torch.stack([seq_loss[seq_loss != 0].mean() for seq_loss in loss_no_reduce])

        value_dict['doc_id'] = batch['doc_id']
        value_dict['loss'] = mean_losses
        return value_dict

    def validation_general_lm(self, batch):
        """Dispatch a general-LM evaluation batch on ``batch["task_type"][0]``."""
        task = batch["task"][0]
        task_type = batch["task_type"][0]

        if task_type == 'ppl':
            loss, _ = self._step(batch)
            self.log(f'{task}/loss', loss,
                     on_epoch=True, prog_bar=True, logger=True, add_dataloader_idx=False, sync_dist=True)
        elif task_type == 'classification':
            self.classification_verbalizer(padding_length=self.hparams.max_input_length, task=task, batch=batch, choices=batch["choices"], answer_index=batch["answer_index"])
        elif task_type == 'completion':
            self.lambada_evaluation(padding_length=self.hparams.max_input_length, task='lambada', batch=batch)
        else:
            raise Exception(f'Currently, {task} not implemented..')

    def validation_step(self, batch, batch_idx, dataloader_idx=-1):
        """Route a validation batch according to ``self.mode``."""
        if self.mode == 'mwo':
            # validation_accuracy has no perplexity option; the former
            # ``calc_perp=True`` keyword made every call raise TypeError.
            if dataloader_idx == 0:
                self.validation_accuracy(batch, dataset_name='train/')
            elif dataloader_idx == 1:
                self.validation_accuracy(batch, dataset_name='val_entity_O/')
            else:
                self.validation_accuracy(batch, dataset_name='val_entity_X/')
        elif self.mode == 'log_individual_ppl':
            return self.validation_no_reduce(batch)
        elif self.mode == 'general_lm_eval':
            return self.validation_general_lm(batch)
        elif self.mode == 'negative_inject':
            if dataloader_idx in [0, -1]:
                return self.validation_no_reduce(batch)
            else:
                self.validation_general_lm(batch)
        else:
            raise Exception(
                f'Currently not supporting {self.mode} for validation')

    def validation_epoch_end(self, output):
        """Gather per-document results from all processes and merge them into ``self.valid_df``.

        Columns are suffixed with ``init`` on the first pass and with the
        zero-padded epoch number afterwards.
        """
        if self.hparams.mode not in ['log_individual_ppl', 'negative_inject']:
            return

        log_col_name = 'init' if self.init_validation else f'{self.current_epoch:02d}'

        print('All Gather')
        if len(self.hparams.valid_sets) > 1:
            # multiple dataloaders: only the first (injected data) is tabulated
            outputs = self.all_gather(output)[0]
        else:
            outputs = self.all_gather(output)
        print('Done Gather')
        keys = outputs[0].keys()  # [doc_id, loss, acc, el, el_soft..., preds]
        full_output = {k: [] for k in keys}

        for out in outputs:
            for k in keys:
                if k == 'preds':
                    # Keep one row per document so each decodes to its own string;
                    # flattening first (as before) merged all rows into one string.
                    n_docs = out['doc_id'].numel()
                    rows = out['preds'].reshape(n_docs, -1)
                    full_output['preds'].extend(self.tokenizer.batch_decode(rows.tolist()))
                else:
                    full_output[k].append(torch.flatten(out[k]))

        # refactor into a pandas-friendly format
        for k in keys:
            if k != 'preds':
                full_output[k] = torch.flatten(torch.cat(full_output[k])).cpu().numpy()

        # rename every key with the column suffix, except 'doc_id'
        for k in list(keys):
            full_output[f'{k}_{log_col_name}'] = full_output.pop(k)
        full_output['doc_id'] = full_output.pop(f'doc_id_{log_col_name}')

        df = pd.DataFrame(full_output)

        # merge into the df that stores results from all ddp processes
        df['doc_id'] = df['doc_id'].astype(int)
        df = df.drop_duplicates(['doc_id'])
        df = df.set_index('doc_id')
        self.valid_df = self.valid_df.combine_first(df)
        self.valid_df = self.valid_df.reindex(self.valid_df_index)
        print(self.valid_df)
        # log forgetting chunk-wise (temporarily hard-coded: 4 chunks, 10-gram)
        if len(self.hparams.train_sets) > 1:
            splits = np.array_split(self.valid_df, 4)
            for i, split in enumerate(splits):
                print(i, split.index)
                self.log(f'{i}/el_soft_10-gram', split[f'el_soft_10-gram_{log_col_name}'].mean(),
                         prog_bar=True, logger=True, add_dataloader_idx=False)
                self.log(f'{i}/acc', split[f'acc_{log_col_name}'].mean(),
                         prog_bar=True, logger=True, add_dataloader_idx=False)

    def get_rid_of_pad(self, tokens):
        """Pop trailing -100 / pad-token ids from ``tokens`` in place and return it.

        An all-padding list becomes empty (instead of raising IndexError), so
        callers' ``assert len(...) > 0`` checks report it.
        """
        while tokens and (tokens[-1] == -100 or tokens[-1] == self.tokenizer.pad_token_id):
            tokens.pop()
        return tokens

    def val_dataloader(self):
        """Build one non-shuffled DataLoader per entry in ``hparams.valid_sets``.

        In the per-document logging modes this also initializes ``self.valid_df``
        from the first dataset's ``.dataset`` frame, indexed by ``doc_id``.
        """
        datasets = []
        for i in range(len(self.hparams.valid_sets)):
            dataset_name = self.hparams.valid_sets[i]
            subset_path = self.hparams.subset_path[i]
            type_path = self.hparams.valid_type_path[i]

            length = self.hparams.injected_length if 'extraction' in dataset_name else None

            datasets.append(self.get_dataset(
                dataset_name=dataset_name, tokenizer=self.tokenizer, subset_path=subset_path, type_path=type_path, args=self.hparams, length=length))

        if self.mode in ['log_individual_ppl', 'negative_inject'] and self.valid_df is None:
            self.valid_df = datasets[0].dataset
            self.valid_df = self.valid_df.set_index('doc_id')
            self.valid_df_index = self.valid_df.index
            self.valid_df['prefix'] = self.valid_df['text'].apply(
                lambda x: self.tokenizer.decode(self.tokenizer.encode(x)[:self.PREFIX_LENGTH]))

        dataloaders = []
        for i, dataset in enumerate(datasets):
            if i == 0 and self.mode in ['log_individual_ppl', 'negative_inject']:  # injected data (should be refactored later)
                dataloaders.append(DataLoader(dataset,
                    batch_size=self.hparams.train_batch_size * self.hparams.gradient_accumulation_steps * len(self.hparams.train_sets),
                    num_workers=self.hparams.num_workers, shuffle=False))
            else:
                dataloaders.append(DataLoader(dataset, batch_size=self.hparams.eval_batch_size,
                                              num_workers=self.hparams.num_workers, shuffle=False))
        return dataloaders

    @property
    def eot_token_id(self):
        # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
        return self.tokenizer.eos_token_id

    @property
    def max_length(self):
        """Model context size: ``config.n_ctx``, falling back to ``max_position_embeddings``."""
        try:
            return self.model.config.n_ctx
        except AttributeError:
            # gptneoconfig doesn't have n_ctx apparently
            return self.model.config.max_position_embeddings

    @property
    def batch_size(self):
        # TODO: fix multi-gpu
        return self.batch_size_per_gpu  # * gpus

    @property
    def device(self):
        # TODO: fix multi-gpu
        # ``_device`` is not set in this class; presumably maintained by
        # Lightning's device mixin — confirm for the installed version.
        return self._device

    def tok_encode(self, string: str):
        """Encode ``string`` without adding special tokens."""
        return self.tokenizer.encode(string, add_special_tokens=False)

    def _model_call(self, inps):
        """
        inps: a torch tensor of shape [batch, sequence]
        the size of sequence may vary from call to call
        returns: a torch tensor of shape [batch, sequence, vocab] with the
        logits returned from the model
        """
        with torch.no_grad():
            return self.model(inps)[0]

    def _build_padded_input(self, context_enc, continuation_enc, padding_length, pad_id):
        """Prepare one scoring input for ``_model_call``.

        Keeps the last ``padding_length`` tokens of ``context + continuation``,
        drops the final token (it is only ever a prediction target) and
        right-pads with ``pad_id`` to ``padding_length``.

        Returns:
            ``(inp, inplen)``: ``inp`` has shape [1, padding_length]; ``inplen``
            is the number of unpadded tokens.
        """
        inp = torch.tensor(
            (context_enc + continuation_enc)[-padding_length:][:-1],
            dtype=torch.long,
        ).to(self.device)
        inplen, = inp.shape
        padding = torch.full((padding_length - inplen,), pad_id, dtype=torch.long, device=inp.device)
        return torch.cat([inp, padding], dim=0).unsqueeze(0), inplen

    def classification_verbalizer(self, padding_length, task, batch, choices, answer_index):
        """Multiple-choice accuracy via the summed log-likelihood of each choice.

        For each choice the choice text is appended to the de-padded context. The
        model is run once over the batch, and the negative summed log-probability
        of the choice tokens is recorded. The predicted answer is the choice with
        the lowest value. Logs ``{task}/acc``.

        Args:
            padding_length: fixed length every model input is padded to.
            task: metric name prefix.
            batch: dict containing ``source_ids``.
            choices: ``choices[c_idx]`` yields one string per example (presumably
                the default-collated per-choice tuple — TODO confirm).
            answer_index: gold choice index per example.
        """
        source_ids = batch["source_ids"].tolist()
        batch_size = len(source_ids)
        answer_idx = [answer_index[i] for i in range(batch_size)]

        # answers[b, c] = negative log-likelihood of choice c for example b
        answers = torch.zeros(batch_size, len(choices), device=self.device)

        for c_idx in range(len(choices)):
            choice_ids = self.tokenizer.batch_encode_plus(list(choices[c_idx]), max_length=self.hparams.max_input_length, add_special_tokens=False,
                                                          padding='max_length', truncation=True, return_tensors="pt")["input_ids"].tolist()
            inps, inplens, cont_toks_list = [], [], []
            for i in range(batch_size):
                # get_rid_of_pad strips in place; repeating it on source_ids[i] for later choices is a no-op
                context_enc = self.get_rid_of_pad(source_ids[i])
                continuation_enc = self.get_rid_of_pad(choice_ids[i])
                # sanity check
                assert len(context_enc) > 0
                assert len(continuation_enc) > 0
                assert len(continuation_enc) <= self.max_length
                inp, inplen = self._build_padded_input(
                    context_enc, continuation_enc, padding_length, self.tokenizer.pad_token_id)
                inps.append(inp)
                inplens.append(inplen)
                cont_toks_list.append(continuation_enc)

            batched_inps = torch.cat(inps, dim=0)  # [batch, padding_length]
            multi_logits = F.log_softmax(self._model_call(batched_inps), dim=-1)  # [batch, padding_length, vocab]
            for b_idx, (logits, inplen, cont_toks) in enumerate(zip(multi_logits, inplens, cont_toks_list)):
                contlen = len(cont_toks)
                # logits at position j predict token j+1, so this window scores the continuation
                logits = logits[inplen - contlen:inplen].unsqueeze(0)  # [1, contlen, vocab]
                cont_toks = torch.tensor(cont_toks, dtype=torch.long).unsqueeze(0).to(self.device)  # [1, contlen]
                token_logprobs = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(-1)  # [1, contlen]
                answers[b_idx][c_idx] = -float(token_logprobs.sum())

        answer_idx = torch.Tensor(answer_idx).to(self.device)
        predicted = torch.argmin(answers, dim=1)
        batch_acc = int(torch.where(predicted == answer_idx, 1, 0).sum())
        batch_acc_avg = batch_acc / batch_size

        self.log(f'{task}/acc', batch_acc_avg, prog_bar=True, logger=True, add_dataloader_idx=False, sync_dist=True)

    def lambada_evaluation(self, padding_length, task, batch):
        """Completion evaluation: log mean NLL, accuracy and F1 of the target continuation.

        An example counts as correct if the per-token greedy argmax equals the
        continuation exactly, or if the decoded texts match after normalization.
        Logs ``{task}/loss``, ``{task}/acc`` and ``{task}/f1``.
        """
        source_ids = batch["source_ids"].tolist()
        target_ids = batch["target_ids"].tolist()
        batch_size = len(source_ids)
        batch_loss = 0
        batch_acc = 0
        batch_f1 = 0
        inps, inplens, cont_toks_list = [], [], []
        for i in range(batch_size):
            if source_ids[i] == target_ids[i]:
                # Identical source/target: use the tokens from position
                # padding_length - 10 onward as the continuation.
                context_enc = source_ids[i][:padding_length - 10]
                continuation_enc = target_ids[i][padding_length - 10:]
            else:
                context_enc = self.get_rid_of_pad(source_ids[i])
                continuation_enc = self.get_rid_of_pad(target_ids[i])
            # sanity check
            assert len(context_enc) > 0
            assert len(continuation_enc) > 0
            assert len(continuation_enc) <= self.max_length

            # Padding value is irrelevant for a causal LM: only positions < inplen are read.
            inp, inplen = self._build_padded_input(context_enc, continuation_enc, padding_length, 0)
            inps.append(inp)
            inplens.append(inplen)
            cont_toks_list.append(continuation_enc)

        batched_inps = torch.cat(inps, dim=0)  # [batch, padding_length]
        multi_logits = F.log_softmax(self._model_call(batched_inps), dim=-1).cpu()  # [batch, padding_length, vocab]
        for logits, inplen, cont_toks in zip(multi_logits, inplens, cont_toks_list):
            contlen = len(cont_toks)
            logits = logits[inplen - contlen:inplen].unsqueeze(0)  # [1, contlen, vocab]
            # Check if per-token argmax is exactly equal to continuation
            greedy_tokens = logits.argmax(dim=-1)
            cont_toks = torch.tensor(cont_toks, dtype=torch.long).unsqueeze(0)  # [1, contlen]
            max_equal = (greedy_tokens == cont_toks).all()
            predicted = self.ids_to_clean_text(greedy_tokens)
            ground_truth = self.ids_to_clean_text(cont_toks)
            em = self.exact_match_score(predicted[0], ground_truth[0])
            f1 = self._f1_score(predicted[0], ground_truth[0])

            # log-probs at the continuation token indices
            token_logprobs = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(-1)  # [1, contlen]
            loss = -float(token_logprobs.sum())
            if bool(max_equal) or em == 1:
                batch_acc += 1

            batch_loss += loss
            batch_f1 += f1

        batch_loss_avg = batch_loss / batch_size
        batch_acc_avg = batch_acc / batch_size
        batch_f1_avg = batch_f1 / batch_size
        self.log(f'{task}/loss', batch_loss_avg, prog_bar=True, logger=True, add_dataloader_idx=False, sync_dist=True)
        self.log(f'{task}/acc', batch_acc_avg, prog_bar=True, logger=True, add_dataloader_idx=False, sync_dist=True)
        self.log(f'{task}/f1', batch_f1_avg, prog_bar=True, logger=True, add_dataloader_idx=False, sync_dist=True)