import os
import copy
import json
import torch
import torch.distributed as dist
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict
from typing import Any, Dict, List, Optional, Tuple, Union
from hexa.models.modeling_r2c2 import R2C2ForConditionalGeneration
from hexa.utils.metrics import AverageMetric
from transformers import BartConfig


def rename_state_dict(state_dict, offset, vocab_size, prefix='model', remove_str='seq2seq_'):
    """
    Rename the keys of a checkpoint state dict to the layout loaded by this module.

    Each key is lower-cased and ``remove_str`` is stripped from it. The key
    ``start`` is dropped, keys containing ``encoder.`` or ``decoder.`` get
    ``prefix`` prepended, ``ffn.lin`` becomes ``ffn_lin`` and a key equal to
    ``embeddings.weight`` becomes ``lm_head.weight``. All other keys (including
    position embeddings) are copied with their tensors unchanged.

    Two entries are appended afterwards: ``final_logits_bias`` (zeros of shape
    ``[1, vocab_size]``) and ``model.shared.weight`` (the embedding matrix).

    :param state_dict:
        mapping from parameter name to tensor.
    :param offset:
        unused; kept so existing callers keep working.
    :param vocab_size:
        width of the generated ``final_logits_bias``.
    :param prefix:
        prefix prepended (with a dot) to encoder/decoder keys.
    :param remove_str:
        substring removed from every key; ``None`` disables the removal.

    :return:
        an ``OrderedDict`` with the renamed keys.
    :raises KeyError:
        if the state dict holds no embedding matrix.
    """
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        k = k.lower()
        if remove_str is not None:
            k = k.replace(remove_str, '')
        if k == 'start':
            continue
        if 'encoder.' in k:
            k = f'{prefix}.' + k
        if 'decoder.' in k:
            k = f'{prefix}.' + k
        if 'ffn.lin' in k:
            k = k.replace('ffn.lin', 'ffn_lin')
        if k == 'embeddings.weight':
            k = 'lm_head.weight'
        new_state_dict[k] = v

    # The loop above matches the embedding matrix on the *normalised* key, so
    # fall back to the renamed entry when the raw key differs only by case or
    # by ``remove_str``; the raw key is still preferred when it exists.
    if 'embeddings.weight' in state_dict:
        shared_weight = state_dict['embeddings.weight']
    elif 'lm_head.weight' in new_state_dict:
        shared_weight = new_state_dict['lm_head.weight']
    else:
        raise KeyError(
            "state_dict has no 'embeddings.weight' entry to use as the shared embedding"
        )

    new_state_dict['final_logits_bias'] = torch.zeros([1, vocab_size])
    new_state_dict['model.shared.weight'] = shared_weight
    return new_state_dict


# parlai > utils > fp16.py
class FP16SafeCrossEntropy(torch.nn.Module):
    """
    Cross entropy loss that is safe to use with FP16 scores.

    The log-softmax is computed in FP32, which avoids overflow when the
    incoming scores are half precision.
    """

    def __init__(
        self,
        weight: Optional[torch.Tensor] = None,
        ignore_index: int = -100,
        reduction: str = 'none',
    ):
        """
        :param weight:
            optional per-class weights, stored as the buffer ``weight``.
        :param ignore_index:
            target value that contributes no loss; the default of -100 matches
            the default of ``torch.nn.functional.nll_loss``.
        :param reduction:
            reduction mode forwarded to ``nll_loss``.
        """
        super().__init__()
        self.ignore_index = ignore_index
        self.reduction = reduction
        self.register_buffer('weight', weight)  # type: ignore

    def forward(self, scores, targets):
        """
        Return the negative log-likelihood of ``targets`` under ``scores``.

        The softmax is taken over dimension 1 of ``scores``.
        """
        log_probs = F.log_softmax(scores, 1, dtype=torch.float32)
        return F.nll_loss(
            log_probs,
            targets,
            weight=self.weight,
            ignore_index=self.ignore_index,
            reduction=self.reduction,
        )


class HFR2C2Model(R2C2ForConditionalGeneration):
    def __init__(self, opt):
        """
        Build the model from ``opt`` and load its initial weights.

        ``opt.model`` is deep-copied so the caller's options are not mutated,
        extended with ``opt.trainer.fp16`` and ``opt.dataset.n_docs``, and
        converted with ``BartConfig.from_dict`` before being handed to the
        parent constructor. The checkpoint is loaded before the loss criterion
        is created.
        """
        model_opt = copy.deepcopy(opt.model)
        model_opt.fp16 = opt.trainer.fp16
        model_opt.n_docs = opt.dataset.n_docs
        hf_config = BartConfig.from_dict(model_opt)

        super().__init__(hf_config)
        self.config = hf_config

        # Cache the special token ids used for decoder inputs, masks and the loss.
        self.pad_token_id = hf_config.pad_token_id
        self.start_token_id = hf_config.bos_token_id
        self.end_token_id = hf_config.eos_token_id

        self._load()
        self.criterion = self._build_criterion()
        
    def _build_criterion(self):
        """
        Return the per-token loss module.

        Both variants ignore ``self.pad_token_id`` and use ``reduction='none'``.
        With ``config.fp16`` set, the FP16-safe variant (softmax done in FP32)
        is returned instead of ``torch.nn.CrossEntropyLoss``.
        """
        if self.config.fp16:
            return FP16SafeCrossEntropy(ignore_index=self.pad_token_id, reduction='none')
        return torch.nn.CrossEntropyLoss(
            ignore_index=self.pad_token_id, reduction='none'
        )
        
    def _load(self):
        """
        Load initial weights from ``self.config.model_file`` onto the CPU.

        With ``config.model_init_from_hf`` set, the file is used directly as the
        state dict. Otherwise its ``'model'`` entry is converted with
        ``rename_state_dict`` first.
        """
        checkpoint_path = self.config.model_file
        print(f'Load checkpoint from {checkpoint_path}')
        is_hf_checkpoint = self.config.model_init_from_hf

        states = torch.load(checkpoint_path, map_location=torch.device('cpu'))
        if not is_hf_checkpoint:
            states = rename_state_dict(
                states['model'], offset=0, vocab_size=self.config.vocab_size
            )
        self.load_state_dict(states)
        
    # parlai > agents > rag > model_types.py > get_forced_decoder_inputs
    @staticmethod
    def _get_forced_decoder_inputs(
        inputs: torch.LongTensor,
        bsz: int,
        start_idx: int,
        end_idx: int,
        generation_model: str,
        start_param: Optional[torch.nn.Parameter] = None,
    ) -> torch.LongTensor:
        """
        Prepend the decoder start token(s) to ``inputs``.

        For ``generation_model == 'bart'`` the pair ``[end_idx, start_idx]`` is
        prepended to each of the ``bsz`` rows. Otherwise a single column is
        prepended: ``start_param`` when it is given, else ``start_idx``.

        The result is the plain seq2seq decoder input; no RAG-specific
        formatting is applied.
        """
        if generation_model == 'bart':
            start_tokens = torch.LongTensor([end_idx, start_idx])
            prefix = start_tokens.to(inputs).detach().expand(bsz, 2)
        elif start_param is not None:
            prefix = start_param.detach().expand(bsz, 1).to(inputs)
        else:
            start_tokens = torch.LongTensor([start_idx])
            prefix = start_tokens.expand(inputs.size(0), 1).to(inputs)
        return torch.cat([prefix, inputs], 1)  # type: ignore
    
    # parlai > agents > fid > fid.py > concat_enc_outs
    def _concat_enc_outs(
        self,
        bsz: int,
        encoder_hidden_states: torch.Tensor,
        attention_mask: torch.BoolTensor,
        right_padded: bool = True,
    ) -> Tuple[torch.Tensor, torch.BoolTensor]:
        """
        Concatenate per-document encoder outputs into one sequence per example.

        This is the "FiD" step: every query/document pair is encoded on its
        own, so the encodings belonging to one example are joined before being
        passed to the decoder.

        :param bsz:
            number of examples; the leading dimension of the inputs is split
            into ``bsz`` equal groups.
        :param encoder_hidden_states:
            encoder outputs whose leading dimension is ``bsz * n_docs`` and
            whose last dimension is ``config.d_model``.
        :param attention_mask:
            encoder mask with the same leading dimension; it is flattened and
            used to select the positions that are kept.
        :param right_padded:
            pad the concatenated sequences on the right (true) or on the left
            (false).

        :return (new_out, new_mask):
            the concatenated encoder output, padded to the longest sequence in
            the batch, and the matching mask.
        """
        hidden_size = self.config.d_model
        docs_per_example = encoder_hidden_states.size(0) // bsz
        chunk_sizes = [docs_per_example] * bsz

        # Flatten each example's documents into a single sequence, keeping only
        # the positions selected by the mask.
        kept_states: List[torch.Tensor] = []
        for states_i, mask_i in zip(
            encoder_hidden_states.split(chunk_sizes, dim=0),
            attention_mask.split(chunk_sizes, dim=0),
        ):
            flat_mask = mask_i.view(-1)
            kept_states.append(states_i.reshape(-1, hidden_size)[flat_mask])
        kept_lengths = [kept.size(0) for kept in kept_states]
        max_len = max(kept_lengths)

        # Padding positions hold pad_token_id; they are masked out by new_mask.
        new_out = encoder_hidden_states.new_full(
            (bsz, max_len, hidden_size), self.pad_token_id
        )
        new_mask: torch.BoolTensor = attention_mask.new_zeros((bsz, max_len))  # type: ignore

        for row, (kept, length) in enumerate(zip(kept_states, kept_lengths)):
            if right_padded:
                span = slice(0, length)
            else:
                span = slice(max_len - length, max_len)
            new_out[row, span] = kept
            new_mask[row, span] = True

        return new_out, new_mask
    
    # parlai > agents > rag > module.py > RagModel > decode_forced
    def _prepare_decoder_input_ids(self, label_vec):
        """
        Build forced-decoding inputs from the label tensor.

        The last position along dimension 1 is dropped and the start token(s)
        selected by ``config.generation_model`` are prepended.
        """
        bsz, seqlen = label_vec.shape[:2]
        shifted_labels = label_vec.narrow(1, 0, seqlen - 1)
        return self._get_forced_decoder_inputs(
            inputs=shifted_labels,
            bsz=bsz,
            start_idx=self.start_token_id,
            end_idx=self.end_token_id,
            generation_model=self.config.generation_model,
        )
    
    # projects > seeker > agents > seeker_modules.py > interleave_fid_combo_outputs
    def _interleave_fid_combo_outputs(
        self,
        encoder_hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        skip_retrieval_vec: torch.Tensor,
        is_doc_added: torch.Tensor,
        right_padded: bool = True,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Merge retrieval and no-retrieval encoder outputs in original batch order.

        Rows selected by ``is_doc_added`` are input+document encodings and are
        concatenated per example with ``_concat_enc_outs``; the remaining rows
        are encodings of the input alone. ``skip_retrieval_vec`` decides, per
        output row, which of the two groups the next row is taken from.

        :return (new_out, new_mask):
            zero-initialised tensors padded to the longer of the two groups,
            filled on the right or left according to ``right_padded``.
        """
        # Split the encoder rows into the two groups.
        no_doc_rows = is_doc_added == 0
        enc_out_skip = encoder_hidden_states[no_doc_rows]
        mask_skip = attention_mask[no_doc_rows]
        enc_out_retr = encoder_hidden_states[is_doc_added]
        mask_retr = attention_mask[is_doc_added]

        n_retrieval = (skip_retrieval_vec == 0).sum().item()
        enc_out_retr, mask_retr = self._concat_enc_outs(
            bsz=n_retrieval,
            encoder_hidden_states=enc_out_retr,
            attention_mask=mask_retr,
            right_padded=right_padded,
        )

        bsz = enc_out_skip.size(0) + enc_out_retr.size(0)
        dim = enc_out_skip.size(-1)
        seqlen = max(enc_out_skip.size(1), enc_out_retr.size(1))

        # Re-assemble one row per example, consuming each group in order.
        new_out = enc_out_retr.new_zeros((bsz, seqlen, dim))
        new_mask = mask_retr.new_zeros((bsz, seqlen))
        skip_cursor = 0
        retr_cursor = 0
        for row, skip in enumerate(skip_retrieval_vec):
            if skip:
                vec, mask = enc_out_skip[skip_cursor], mask_skip[skip_cursor]
                skip_cursor += 1
            else:
                vec, mask = enc_out_retr[retr_cursor], mask_retr[retr_cursor]
                retr_cursor += 1

            if right_padded:
                new_out[row, : vec.size(0), :] = vec
                new_mask[row, : mask.size(0)] = mask
            else:
                new_out[row, -vec.size(0):, :] = vec
                new_mask[row, -mask.size(0):] = mask
        return new_out, new_mask
    
    # parlai > agents > rag > model_types.py > RagTurn > compute_loss
    def _compute_loss(
        self,
        scores: torch.Tensor,
        preds: torch.LongTensor,
        label_vec: torch.LongTensor,
        input_turns_cnt: Any,
    ) -> Tuple[torch.Tensor, List[int], torch.Tensor, torch.Tensor]:
        """
        Compute the token-averaged loss and the per-row metric counters.

        :param scores:
            decoder logits; when their length along dimension 1 differs from
            the labels', the first position is dropped.
        :param preds:
            predicted token ids, trimmed together with ``scores``.
        :param label_vec:
            target token ids; positions equal to ``self.pad_token_id`` are
            excluded from the counters.
        :param input_turns_cnt:
            per-example repeat counts, used only when the batch size of
            ``scores`` differs from that of ``label_vec``.

        :return (loss, metric_loss, metric_correct, metric_target_tokens):
            scalar loss divided by the number of non-pad targets, the summed
            loss per row as a list, and the correct / target token counts.
        """
        if scores.size(1) != label_vec.size(1):
            # Drop the extra leading position so scores line up with the labels.
            scores, preds = scores[:, 1:, :], preds[:, 1:]  # type: ignore

        expand_labels = label_vec.size(0) != scores.size(0)
        if expand_labels:
            assert self.config.turn_marginalize == 'doc_only'
            label_vec = label_vec.repeat_interleave(
                input_turns_cnt, dim=0
            )  # type: ignore

        # Per-token loss, summed over the sequence dimension.
        vocab_size = scores.size(-1)
        token_loss = self.criterion(scores.reshape(-1, vocab_size), label_vec.view(-1))
        loss = token_loss.view(scores.shape[:-1]).sum(dim=1)
        metric_loss = loss.tolist()

        if expand_labels:
            # NOTE(review): `sum_across_turns` and `self.discount_factor` are not
            # defined in the visible part of this file - confirm they exist
            # before relying on this branch.
            assert self.config.turn_marginalize == 'doc_only'
            loss = sum_across_turns(
                loss, input_turns_cnt, discount=self.discount_factor
            )
            metric_loss = sum_across_turns(loss, input_turns_cnt).tolist()

        # Metric counters over non-pad target positions.
        notnull = label_vec.ne(self.pad_token_id)
        target_tokens = notnull.long().sum(dim=-1)
        correct = ((label_vec == preds) * notnull).sum(dim=-1)
        metric_target_tokens, metric_correct = target_tokens, correct
        if expand_labels:
            metric_target_tokens = sum_across_turns(target_tokens, input_turns_cnt)
            metric_correct = sum_across_turns(correct, input_turns_cnt)

        loss = loss.sum() / target_tokens.sum()  # average loss per token
        return loss, metric_loss, metric_correct, metric_target_tokens

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        lm_labels: torch.Tensor,
        skip_retrieval_vec: torch.Tensor,
        is_doc_added: torch.Tensor,
        return_dict : Optional = None,
        input_turns_cnt : Optional = None,
        max_n_docs: Optional = None,
    ):
        """
        Encode the inputs, force-decode ``lm_labels`` and compute the loss.

        The encoder outputs are used as they are when every example skips
        retrieval, concatenated per example when none does, and interleaved
        otherwise. ``return_dict`` and ``max_n_docs`` are accepted but not used
        here.

        :return:
            dict with the keys ``loss``, ``metric_loss``,
            ``metric_target_tokens``, ``metric_correct``, ``logits`` and
            ``preds``.
        """
        encoder_hidden_states = self.model.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
        ).last_hidden_state

        if not torch.all(skip_retrieval_vec):
            if torch.all(~skip_retrieval_vec):
                # Every example has documents: plain FiD concatenation.
                encoder_hidden_states, attention_mask = self._concat_enc_outs(
                    bsz=lm_labels.shape[0],
                    encoder_hidden_states=encoder_hidden_states,
                    attention_mask=attention_mask,
                    right_padded=True,
                )
            else:
                # Mixed batch: keep the original example order.
                encoder_hidden_states, attention_mask = self._interleave_fid_combo_outputs(
                    encoder_hidden_states=encoder_hidden_states,
                    attention_mask=attention_mask,
                    skip_retrieval_vec=skip_retrieval_vec,
                    is_doc_added=is_doc_added,
                    right_padded=True,
                )

        # Forced decoding on the labels.
        decoder_input_ids = self._prepare_decoder_input_ids(lm_labels)
        decoder_attention_mask = decoder_input_ids.ne(self.pad_token_id)
        decoder_out = self.model.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=attention_mask,
        )
        lm_logits = self.lm_head(decoder_out.last_hidden_state) + self.final_logits_bias
        preds = lm_logits.max(dim=-1).indices

        loss, metric_loss, metric_correct, metric_target_tokens = self._compute_loss(
            scores=lm_logits,
            preds=preds,
            label_vec=lm_labels,
            input_turns_cnt=input_turns_cnt,
        )

        outputs = {
            'loss': loss,
            'metric_loss': metric_loss,
            'metric_target_tokens': metric_target_tokens,
            'metric_correct': metric_correct,
            'logits': lm_logits,
            'preds': preds,
        }
        return outputs


class TrainInferenceComboHFR2C2Model(R2C2ForConditionalGeneration):
    def __init__(self, opt):
        """
        Build the model from ``opt`` and load its initial weights.

        ``opt.model`` is deep-copied so the caller's options are not mutated,
        extended with ``opt.trainer.fp16`` and ``opt.dataset.n_docs``, and
        converted with ``BartConfig.from_dict`` before being handed to the
        parent constructor. The checkpoint is loaded before the loss criterion
        is created.
        """
        model_opt = copy.deepcopy(opt.model)
        model_opt.fp16 = opt.trainer.fp16
        model_opt.n_docs = opt.dataset.n_docs
        hf_config = BartConfig.from_dict(model_opt)

        super().__init__(hf_config)
        self.config = hf_config

        # Cache the special token ids used for decoder inputs, masks and the loss.
        self.pad_token_id = hf_config.pad_token_id
        self.start_token_id = hf_config.bos_token_id
        self.end_token_id = hf_config.eos_token_id

        self._load()
        self.criterion = self._build_criterion()

    def _build_criterion(self):
        """
        Return the per-token loss module.

        Both variants ignore ``self.pad_token_id`` and use ``reduction='none'``.
        With ``config.fp16`` set, the FP16-safe variant (softmax done in FP32)
        is returned instead of ``torch.nn.CrossEntropyLoss``.
        """
        if self.config.fp16:
            return FP16SafeCrossEntropy(ignore_index=self.pad_token_id, reduction='none')
        return torch.nn.CrossEntropyLoss(
            ignore_index=self.pad_token_id, reduction='none'
        )

    def _load(self):
        """
        Load initial weights from ``self.config.model_file`` onto the CPU.

        Without ``config.model_init_from_hf`` the file is read as a checkpoint
        whose ``'model'`` entry is converted with ``rename_state_dict``.

        With ``config.model_init_from_hf`` the path is either a state-dict file
        or a directory. For a directory the weights are read from
        ``pytorch_model.bin`` inside it. If that file does not exist, the
        ``pytorch_model-*.bin`` shards in the directory are merged into it and
        the shard files are deleted. In a distributed run only rank 0 merges,
        and all ranks synchronise on a barrier before loading.

        :raises AssertionError:
            if a merge is needed but fewer than two shard files are found.
        """
        init_model = self.config.model_file
        print(f'Load checkpoint from {init_model}')

        if not self.config.model_init_from_hf:
            states = torch.load(init_model, map_location=torch.device('cpu'))['model']
            states = rename_state_dict(states, offset=0, vocab_size=self.config.vocab_size)
            self.load_state_dict(states)
            return

        if os.path.isdir(init_model):
            model_fpath = os.path.join(init_model, 'pytorch_model.bin')
            # get_rank() raises when no process group exists, so a
            # non-distributed run is treated as the (only) main process.
            distributed = dist.is_available() and dist.is_initialized()
            is_main_process = (not distributed) or dist.get_rank() == 0

            if is_main_process and not os.path.exists(model_fpath):
                from glob import glob

                # Sorted so the merge does not depend on directory listing order.
                sharded_files = sorted(
                    glob(os.path.join(init_model, 'pytorch_model-*.bin'))
                )
                assert len(sharded_files) > 1, (
                    f'expected at least two pytorch_model-*.bin shards in {init_model}'
                )
                print('The sharded model files will be merged.')
                new_state_dict = {}
                for path in sharded_files:
                    state = torch.load(path, map_location='cpu')
                    for k, v in state.items():
                        # The first shard that provides a key wins.
                        new_state_dict.setdefault(k, v)
                torch.save(new_state_dict, model_fpath)
                for path in sharded_files:
                    os.remove(path)

            init_model = model_fpath
            if distributed:
                # Wait until rank 0 has finished writing the merged file.
                dist.barrier()

        states = torch.load(init_model, map_location=torch.device('cpu'))
        self.load_state_dict(states)

        # parlai > agents > rag > model_types.py > get_forced_decoder_inputs

    @staticmethod
    def _get_forced_decoder_inputs(
        inputs: torch.LongTensor,
        bsz: int,
        start_idx: int,
        end_idx: int,
        generation_model: str,
        start_param: Optional[torch.nn.Parameter] = None,
    ) -> torch.LongTensor:
        """
        Prepend the decoder start token(s) to ``inputs``.

        For ``generation_model == 'bart'`` the pair ``[end_idx, start_idx]`` is
        prepended to each of the ``bsz`` rows. Otherwise a single column is
        prepended: ``start_param`` when it is given, else ``start_idx``.

        The result is the plain seq2seq decoder input; no RAG-specific
        formatting is applied.
        """
        if generation_model == 'bart':
            start_tokens = torch.LongTensor([end_idx, start_idx])
            prefix = start_tokens.to(inputs).detach().expand(bsz, 2)
        elif start_param is not None:
            prefix = start_param.detach().expand(bsz, 1).to(inputs)
        else:
            start_tokens = torch.LongTensor([start_idx])
            prefix = start_tokens.expand(inputs.size(0), 1).to(inputs)
        return torch.cat([prefix, inputs], 1)  # type: ignore

    # parlai > agents > fid > fid.py > concat_enc_outs
    def _concat_enc_outs(
        self,
        bsz: int,
        encoder_hidden_states: torch.Tensor,
        attention_mask: torch.BoolTensor,
        right_padded: bool = True,
    ) -> Tuple[torch.Tensor, torch.BoolTensor]:
        """
        Concatenate per-document encoder outputs into one sequence per example.

        This is the "FiD" step: every query/document pair is encoded on its
        own, so the encodings belonging to one example are joined before being
        passed to the decoder.

        :param bsz:
            number of examples; the leading dimension of the inputs is split
            into ``bsz`` equal groups.
        :param encoder_hidden_states:
            encoder outputs whose leading dimension is ``bsz * n_docs`` and
            whose last dimension is ``config.d_model``.
        :param attention_mask:
            encoder mask with the same leading dimension; it is flattened and
            used to select the positions that are kept.
        :param right_padded:
            pad the concatenated sequences on the right (true) or on the left
            (false).

        :return (new_out, new_mask):
            the concatenated encoder output, padded to the longest sequence in
            the batch, and the matching mask.
        """
        hidden_size = self.config.d_model
        docs_per_example = encoder_hidden_states.size(0) // bsz
        chunk_sizes = [docs_per_example] * bsz

        # Flatten each example's documents into a single sequence, keeping only
        # the positions selected by the mask.
        kept_states: List[torch.Tensor] = []
        for states_i, mask_i in zip(
            encoder_hidden_states.split(chunk_sizes, dim=0),
            attention_mask.split(chunk_sizes, dim=0),
        ):
            flat_mask = mask_i.view(-1)
            kept_states.append(states_i.reshape(-1, hidden_size)[flat_mask])
        kept_lengths = [kept.size(0) for kept in kept_states]
        max_len = max(kept_lengths)

        # Padding positions hold pad_token_id; they are masked out by new_mask.
        new_out = encoder_hidden_states.new_full(
            (bsz, max_len, hidden_size), self.pad_token_id
        )
        new_mask: torch.BoolTensor = attention_mask.new_zeros((bsz, max_len))  # type: ignore

        for row, (kept, length) in enumerate(zip(kept_states, kept_lengths)):
            if right_padded:
                span = slice(0, length)
            else:
                span = slice(max_len - length, max_len)
            new_out[row, span] = kept
            new_mask[row, span] = True

        return new_out, new_mask

    # parlai > agents > rag > module.py > RagModel > decode_forced
    def _prepare_decoder_input_ids(self, label_vec):
        """
        Build forced-decoding inputs from the label tensor.

        The last position along dimension 1 is dropped and the start token(s)
        selected by ``config.generation_model`` are prepended.
        """
        bsz, seqlen = label_vec.shape[:2]
        shifted_labels = label_vec.narrow(1, 0, seqlen - 1)
        return self._get_forced_decoder_inputs(
            inputs=shifted_labels,
            bsz=bsz,
            start_idx=self.start_token_id,
            end_idx=self.end_token_id,
            generation_model=self.config.generation_model,
        )

        # projects > seeker > agents > seeker_modules.py > interleave_fid_combo_outputs

    def _interleave_fid_combo_outputs(
        self,
        encoder_hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        skip_retrieval_vec: torch.Tensor,
        is_doc_added: torch.Tensor,
        right_padded: bool = True,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Merge retrieval and no-retrieval encoder outputs in original batch order.

        Rows selected by ``is_doc_added`` are input+document encodings and are
        concatenated per example with ``_concat_enc_outs``; the remaining rows
        are encodings of the input alone. ``skip_retrieval_vec`` decides, per
        output row, which of the two groups the next row is taken from.

        :return (new_out, new_mask):
            zero-initialised tensors padded to the longer of the two groups,
            filled on the right or left according to ``right_padded``.
        """
        # Split the encoder rows into the two groups.
        no_doc_rows = is_doc_added == 0
        enc_out_skip = encoder_hidden_states[no_doc_rows]
        mask_skip = attention_mask[no_doc_rows]
        enc_out_retr = encoder_hidden_states[is_doc_added]
        mask_retr = attention_mask[is_doc_added]

        n_retrieval = (skip_retrieval_vec == 0).sum().item()
        enc_out_retr, mask_retr = self._concat_enc_outs(
            bsz=n_retrieval,
            encoder_hidden_states=enc_out_retr,
            attention_mask=mask_retr,
            right_padded=right_padded,
        )

        bsz = enc_out_skip.size(0) + enc_out_retr.size(0)
        dim = enc_out_skip.size(-1)
        seqlen = max(enc_out_skip.size(1), enc_out_retr.size(1))

        # Re-assemble one row per example, consuming each group in order.
        new_out = enc_out_retr.new_zeros((bsz, seqlen, dim))
        new_mask = mask_retr.new_zeros((bsz, seqlen))
        skip_cursor = 0
        retr_cursor = 0
        for row, skip in enumerate(skip_retrieval_vec):
            if skip:
                vec, mask = enc_out_skip[skip_cursor], mask_skip[skip_cursor]
                skip_cursor += 1
            else:
                vec, mask = enc_out_retr[retr_cursor], mask_retr[retr_cursor]
                retr_cursor += 1

            if right_padded:
                new_out[row, : vec.size(0), :] = vec
                new_mask[row, : mask.size(0)] = mask
            else:
                new_out[row, -vec.size(0):, :] = vec
                new_mask[row, -mask.size(0):] = mask
        return new_out, new_mask

        # parlai > agents > rag > model_types.py > RagTurn > compute_loss

    def _compute_loss(
        self,
        scores: torch.Tensor,
        preds: torch.LongTensor,
        label_vec: torch.LongTensor,
        input_turns_cnt: Any,
    ) -> Tuple[torch.Tensor, List[int], torch.Tensor, torch.Tensor]:
        """
        Compute the token-averaged loss and the per-row metric counters.

        :param scores:
            decoder logits; when their length along dimension 1 differs from
            the labels', the first position is dropped.
        :param preds:
            predicted token ids, trimmed together with ``scores``.
        :param label_vec:
            target token ids; positions equal to ``self.pad_token_id`` are
            excluded from the counters.
        :param input_turns_cnt:
            per-example repeat counts, used only when the batch size of
            ``scores`` differs from that of ``label_vec``.

        :return (loss, metric_loss, metric_correct, metric_target_tokens):
            scalar loss divided by the number of non-pad targets, the summed
            loss per row as a list, and the correct / target token counts.
        """
        if scores.size(1) != label_vec.size(1):
            # Drop the extra leading position so scores line up with the labels.
            scores, preds = scores[:, 1:, :], preds[:, 1:]  # type: ignore

        if label_vec.size(0) != scores.size(0):
            # Labels are repeated per turn so they match the rows of `scores`.
            assert self.config.turn_marginalize == 'doc_only'
            label_vec = label_vec.repeat_interleave(
                input_turns_cnt, dim=0
            )  # type: ignore

        # Per-token loss, summed over the sequence dimension.
        vocab_size = scores.size(-1)
        token_loss = self.criterion(scores.reshape(-1, vocab_size), label_vec.view(-1))
        loss = token_loss.view(scores.shape[:-1]).sum(dim=1)
        metric_loss = loss.tolist()

        # Metric counters over non-pad target positions.
        notnull = label_vec.ne(self.pad_token_id)
        metric_target_tokens = notnull.long().sum(dim=-1)
        metric_correct = ((label_vec == preds) * notnull).sum(dim=-1)

        loss = loss.sum() / metric_target_tokens.sum()  # average loss per token
        return loss, metric_loss, metric_correct, metric_target_tokens

    def forward_loss(
            self,
            input_ids: torch.Tensor,
            attention_mask: torch.Tensor,
            lm_labels: torch.Tensor,
            skip_retrieval_vec: torch.Tensor,
            is_doc_added: torch.Tensor,
            return_dict: Optional = None,
            input_turns_cnt: Optional = None,
            max_n_docs: Optional = None,
    ):
        """
        Encode the inputs, force-decode ``lm_labels`` and compute the loss.

        The encoder outputs are used as they are when every example skips
        retrieval, concatenated per example when none does, and interleaved
        otherwise. ``return_dict`` and ``max_n_docs`` are accepted but not used
        here.

        :return:
            dict with the keys ``loss``, ``metric_loss``,
            ``metric_target_tokens``, ``metric_correct``, ``logits`` and
            ``preds``.
        """
        encoder_hidden_states = self.model.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
        ).last_hidden_state

        if not torch.all(skip_retrieval_vec):
            if torch.all(~skip_retrieval_vec):
                # Every example has documents: plain FiD concatenation.
                encoder_hidden_states, attention_mask = self._concat_enc_outs(
                    bsz=lm_labels.shape[0],
                    encoder_hidden_states=encoder_hidden_states,
                    attention_mask=attention_mask,
                    right_padded=True,
                )
            else:
                # Mixed batch: keep the original example order.
                encoder_hidden_states, attention_mask = self._interleave_fid_combo_outputs(
                    encoder_hidden_states=encoder_hidden_states,
                    attention_mask=attention_mask,
                    skip_retrieval_vec=skip_retrieval_vec,
                    is_doc_added=is_doc_added,
                    right_padded=True,
                )

        # Forced decoding on the labels.
        decoder_input_ids = self._prepare_decoder_input_ids(lm_labels)
        decoder_attention_mask = decoder_input_ids.ne(self.pad_token_id)
        decoder_out = self.model.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=attention_mask,
        )
        lm_logits = self.lm_head(decoder_out.last_hidden_state) + self.final_logits_bias
        preds = lm_logits.max(dim=-1).indices

        loss, metric_loss, metric_correct, metric_target_tokens = self._compute_loss(
            scores=lm_logits,
            preds=preds,
            label_vec=lm_labels,
            input_turns_cnt=input_turns_cnt,
        )

        outputs = {
            'loss': loss,
            'metric_loss': metric_loss,
            'metric_target_tokens': metric_target_tokens,
            'metric_correct': metric_correct,
            'logits': lm_logits,
            'preds': preds,
        }
        return outputs
