import gc
import inspect
import re
from copy import deepcopy
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import deepspeed
import torch
import transformers
from torchtyping import TensorType
from transformers.modeling_outputs import ModelOutput, BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.models.bloom import modeling_bloom
from transformers.models.opt import modeling_opt
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask

from trlx.data.method_configs import MethodConfig, register_method
from trlx.models.modeling_base import PreTrainedModelWrapper
from trlx.utils.modeling import (
    hf_get_decoder,
    hf_get_decoder_blocks,
    hf_get_decoder_final_norm,
    hf_get_hidden_size,
    hf_get_lm_head,
    hf_get_num_hidden_layers,
    make_head,
)


# P3O Configs


@dataclass
@register_method
class P3OConfig(MethodConfig):
    """
    Config for P3O method

    :param p3o_epochs: Number of updates per batch
    :type p3o_epochs: int

    :param num_rollouts: Number of experiences to observe before learning
    :type num_rollouts: int

    :param chunk_size: Presumably the rollout-generation chunk size (as in trlx's PPOConfig);
        not used by :meth:`loss` -- TODO confirm against the trainer
    :type chunk_size: int

    :param num_responses_per_query: Presumably the number of responses sampled per prompt;
        :meth:`loss` itself only compares responses ``[0]`` and ``[1]``
    :type num_responses_per_query: int

    :param kl_coef: Weight of the log-ratio (KL) penalty subtracted from the reward
        difference when :meth:`loss` forms the pairwise advantage ``q_diff``
    :type kl_coef: float

    :param cliprange: Half-width of the interval, centred on the rollout-time log-ratio gap,
        into which the current log-ratio gap is clamped in the clipped objective
    :type cliprange: float

    :param cliprange_ratio: The joint importance ratio of a pair is clamped to
        ``[1 / cliprange_ratio, cliprange_ratio]``; expected to be >= 1
    :type cliprange_ratio: float

    :param scale_reward: Reward scaling option, presumably consumed by the trainer as in
        trlx's PPOConfig; not used by :meth:`loss`
    :type scale_reward: Optional[str]

    :param ref_mean: Reference reward mean (presumably for reward normalization); not used by :meth:`loss`
    :type ref_mean: Optional[float]

    :param ref_std: Reference reward std (presumably for reward normalization); not used by :meth:`loss`
    :type ref_std: Optional[float]

    :param cliprange_reward: Reward clip range, presumably applied by the trainer; not used by :meth:`loss`
    :type cliprange_reward: float

    :param gen_kwargs: Additional kwargs for the generation
    :type gen_kwargs: Dict[str, Any]

    :param gen_experience_kwargs: if this is not None, then the experience is generated using this
    :type gen_experience_kwargs: Dict[str, Any]

    :param num_value_layers_unfrozen: Number of layers in the value branch; not used by :meth:`loss`
    :type num_value_layers_unfrozen: int

    :param clip_tokenwise: Not used by :meth:`loss`; consumed elsewhere (TODO confirm)
    :type clip_tokenwise: bool

    :param avg_tokenwise: Not used by :meth:`loss`; consumed elsewhere (TODO confirm)
    :type avg_tokenwise: bool

    :param scale_q: If True, :meth:`loss` divides the pairwise advantage ``q_diff`` by its
        batch standard deviation
    :type scale_q: bool
    """

    p3o_epochs: int
    num_rollouts: int
    chunk_size: int
    num_responses_per_query: int
    kl_coef: float
    cliprange: float
    cliprange_ratio: float
    scale_reward: Optional[str]
    ref_mean: Optional[float]
    ref_std: Optional[float]
    cliprange_reward: float
    gen_kwargs: dict
    gen_experience_kwargs: Optional[dict] = None
    num_value_layers_unfrozen: int = 0
    clip_tokenwise: bool = False
    avg_tokenwise: bool = True
    scale_q: bool = False

    def loss(
        self,
        logratio: List[TensorType["batch_size"]],
        rewards: List[TensorType["batch_size"]],
        old_logratio: List[TensorType["batch_size"]],
    ):
        """Clipped pairwise P3O policy loss over two responses per prompt.

        Only entries ``[0]`` and ``[1]`` of each list are used.

        :param logratio: current log-ratios of responses 0 and 1 (these carry gradients);
            presumably sequence-level ``log pi - log pi_ref`` -- confirm against the trainer
        :param rewards: rewards of responses 0 and 1
        :param old_logratio: log-ratios of the same responses recorded at rollout time
        :return: ``(loss, stats)``: the scalar loss tensor and a dict of float diagnostics
        """
        lr_0, lr_1 = logratio[0], logratio[1]
        old_lr_0, old_lr_1 = old_logratio[0], old_logratio[1]
        reward_diff = rewards[0] - rewards[1]

        # Pairwise advantage: reward gap minus the KL-penalty gap. Detached so it acts as a
        # constant weight on the policy-gradient term below.
        q_diff = (reward_diff - self.kl_coef * (lr_0 - lr_1)).detach()
        # Honour the `scale_q` config field (it used to be shadowed by a local `False`).
        # A std needs at least two samples; eps guards against all pairs being tied.
        if self.scale_q and q_diff.numel() > 1:
            q_diff = q_diff / (torch.std(q_diff) + 1e-8)

        # Joint importance ratio of the pair (current vs. rollout-time log-ratios), clamped to
        # [1/cliprange_ratio, cliprange_ratio] and used as a constant weight.
        ratio = torch.exp((lr_0 - old_lr_0) + (lr_1 - old_lr_1)).detach()
        clipped_ratio = torch.clamp(ratio, 1 / self.cliprange_ratio, self.cliprange_ratio)

        # PPO-style pessimistic objective on the log-ratio gap: the clipped variant keeps the
        # gap within `cliprange` of its rollout-time value.
        gap = lr_0 - lr_1
        old_gap = old_lr_0 - old_lr_1
        loss1 = -q_diff * clipped_ratio * gap / 2
        loss1_clip = -q_diff * clipped_ratio * torch.clamp(gap, old_gap - self.cliprange, old_gap + self.cliprange) / 2
        loss = torch.max(loss1, loss1_clip).mean()

        # log quantity of interest: log-ratios of the higher-reward ("chosen") and lower-reward
        # ("lose") response of each pair; ties contribute 0 to both.
        first_wins = (rewards[0] > rewards[1]).float()
        second_wins = (rewards[1] > rewards[0]).float()
        logratio_chosen_mean = torch.mean(lr_0 * first_wins + lr_1 * second_wins).item()
        logratio_lose_mean = torch.mean(lr_0 * second_wins + lr_1 * first_wins).item()
        policy_clipfrac = torch.sum((loss1_clip > loss1).float()) / loss1_clip.shape[0]

        return loss, {
            "loss": loss.item(),
            "reward_diff_mean": torch.mean(reward_diff).item(),
            "reward_diff_abs": torch.mean(torch.abs(reward_diff)).item(),
            "reward_diff_std": torch.std(reward_diff).item(),
            "logratio_chosen_mean": logratio_chosen_mean,
            "logratio_lose_mean": logratio_lose_mean,
            "logratio_gap_mean": logratio_chosen_mean - logratio_lose_mean,
            "ratio_mean": ratio.mean().item(),
            "ratio_max": ratio.max().item(),
            "ratio_min": ratio.min().item(),
            "policy_clipfrac": policy_clipfrac.item(),
            "q_diff_abs_mean": torch.mean(torch.abs(q_diff)).item(),
            "q_diff_mean": torch.mean(q_diff).item(),
            "q_diff_std": torch.std(q_diff).item(),
        }


# CausalLM architectures


@dataclass
class CausalLMOutputWithValue(ModelOutput):
    """Causal-LM model output with an extra ``value`` field appended.

    Holds the same fields as a ``transformers`` causal-LM output (loss, logits,
    past_key_values, hidden_states, attentions, cross_attentions) plus ``value``.
    Field order matters: the wrappers in this file index the output positionally
    (e.g. ``outputs[1:]``), and ``ModelOutput`` exposes its non-``None`` fields in
    declaration order.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
    # Value-head output; the wrappers below currently always pass value=None.
    value: Optional[torch.FloatTensor] = None


def make_value_branch(base_model, num_value_layers_unfrozen):
    """Build the value head for ``base_model``.

    When ``num_value_layers_unfrozen`` is 0, return ``make_head(hidden_size, 1)`` as is.
    Otherwise build a trainable (``frozen=False``) branch of the top
    ``num_value_layers_unfrozen`` layers, using the class chosen by ``hf_get_branch_class``,
    and replace that branch's ``lm_head`` with the ``make_head`` output.
    """
    head = make_head(hf_get_hidden_size(base_model.config), 1)
    if num_value_layers_unfrozen != 0:
        branch = hf_get_branch_class(base_model.config)(
            base_model,
            num_layers_unfrozen=num_value_layers_unfrozen,
            frozen=False,
        )
        branch.lm_head = head
        return branch
    return head


# class MistralModelBranch(transformers.PreTrainedModel):
#     """
#     Take the last `num_layers_unfrozen` layers of the pretrained mistral model
#     """

#     def __init__(
#         self,
#         base_model: transformers.PreTrainedModel,
#         num_layers_unfrozen: int,
#     ):
#         super().__init__(base_model.config)
#         self.padding_idx = base_model.model.config.pad_token_id
#         self.vocab_size = base_model.model.config.vocab_size

#         self.embed_tokens = deepcopy(base_model.model.embed_tokens)
#         self.embed_tokens.requires_grad_(False)
#         self.layers = deepcopy(base_model.model.layers[-num_layers_unfrozen:])
#         self.norm = deepcopy(base_model.model.norm)
#         self.lm_head = deepcopy(base_model.lm_head)
#         self.gradient_checkpointing = False

#     def forward(
#         self,
#         input_ids: torch.LongTensor = None,
#         hidden_states: torch.FloatTensor = None,
#         attention_mask: Optional[torch.Tensor] = None,
#         position_ids: Optional[torch.LongTensor] = None,
#         past_key_values: Optional[List[torch.FloatTensor]] = None,
#         inputs_embeds: Optional[torch.FloatTensor] = None,
#         use_cache: Optional[bool] = None,
#         output_attentions: Optional[bool] = None,
#         output_hidden_states: Optional[bool] = None,
#         return_dict: Optional[bool] = None,
#     ) -> Union[Tuple, CausalLMOutputWithPast]:
#         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
#         output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
#         use_cache = use_cache if use_cache is not None else self.config.use_cache

#         return_dict = return_dict if return_dict is not None else self.config.use_return_dict

#         # retrieve input_ids and inputs_embeds
#         if input_ids is not None and inputs_embeds is not None:
#             raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
#         elif input_ids is not None:
#             batch_size, seq_length = input_ids.shape
#         elif inputs_embeds is not None:
#             batch_size, seq_length, _ = inputs_embeds.shape
#         else:
#             raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")

#         seq_length_with_past = seq_length
#         past_key_values_length = 0

#         if past_key_values is not None:
#             past_key_values_length = past_key_values[0][0].shape[2]
#             seq_length_with_past = seq_length_with_past + past_key_values_length

#         if position_ids is None:
#             device = input_ids.device if input_ids is not None else inputs_embeds.device
#             position_ids = torch.arange(past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device)
#             position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
#         else:
#             position_ids = position_ids.view(-1, seq_length).long()

#         if inputs_embeds is None:
#             inputs_embeds = self.embed_tokens(input_ids)

#         if attention_mask is not None and hasattr(self.config, "_flash_attn_2_enabled") and self.config._flash_attn_2_enabled and past_key_values is not None:
#             is_padding_right = attention_mask[:, -1].sum().item() != batch_size
#             if is_padding_right:
#                 raise ValueError(
#                     "You are attempting to perform batched generation with padding_side='right'"
#                     " this may lead to unexpected behaviour for Flash Attention version of Mistral. Make sure to "
#                     " call `tokenizer.padding_side  = 'left'` before tokenizing the input. "
#                 )

#         if getattr(self.config, "_flash_attn_2_enabled", False):
#             # 2d mask is passed through the layers
#             attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
#         else:
#             # 4d mask is passed through the layers
#             attention_mask = _prepare_4d_causal_attention_mask(
#                 attention_mask,
#                 (batch_size, seq_length),
#                 inputs_embeds,
#                 past_key_values_length,
#                 sliding_window=self.config.sliding_window,
#             )

#         # hidden_states = inputs_embeds

#         if self.gradient_checkpointing and self.training:
#             if use_cache:
#                 print("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
#                 use_cache = False

#         # decoder layers
#         all_hidden_states = () if output_hidden_states else None
#         all_self_attns = () if output_attentions else None
#         next_decoder_cache = () if use_cache else None

#         for idx, decoder_layer in enumerate(self.layers):
#             if output_hidden_states:
#                 all_hidden_states += (hidden_states,)

#             past_key_value = past_key_values[idx] if past_key_values is not None else None

#             if self.gradient_checkpointing and self.training:
#                 layer_outputs = self._gradient_checkpointing_func(
#                     decoder_layer.__call__,
#                     hidden_states,
#                     attention_mask,
#                     position_ids,
#                     past_key_value,
#                     output_attentions,
#                     use_cache,
#                 )
#             else:
#                 layer_outputs = decoder_layer(
#                     hidden_states,
#                     attention_mask=attention_mask,
#                     position_ids=position_ids,
#                     past_key_value=past_key_value,
#                     output_attentions=output_attentions,
#                     use_cache=use_cache,
#                 )

#             hidden_states = layer_outputs[0]

#             if use_cache:
#                 next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)

#             if output_attentions:
#                 all_self_attns += (layer_outputs[1],)

#         hidden_states = self.norm(hidden_states)

#         # add hidden states from the last decoder layer
#         if output_hidden_states:
#             all_hidden_states += (hidden_states,)

#         next_cache = next_decoder_cache if use_cache else None
#         if not return_dict:
#             outputs = tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
#         else:
#             outputs = BaseModelOutputWithPast(
#                 last_hidden_state=hidden_states,
#                 past_key_values=next_cache,
#                 hidden_states=all_hidden_states,
#                 attentions=all_self_attns,
#             )

#         logits = self.lm_head(hidden_states)
#         logits = logits.float()

#         if not return_dict:
#             return (logits,) + outputs[1:]

#         return CausalLMOutputWithPast(
#             loss=None,
#             logits=logits,
#             past_key_values=outputs.past_key_values,
#             hidden_states=outputs.hidden_states,
#             attentions=outputs.attentions,
#         )


# class MistralModelWithHydraHead(PreTrainedModelWrapper):
#     """An `AutoModel` class wrapper for `transformers` causal models that have a
#     language modeling head and a value head
#     """

#     def __init__(
#         self,
#         base_model: transformers.PreTrainedModel,
#         peft_config=None,
#         num_layers_unfrozen=0,
#     ):
#         super().__init__(base_model, peft_config=peft_config)
#         self.num_layers_unfrozen = num_layers_unfrozen
#         # self.v_head = make_mistral_value_branch(base_model, num_layers_unfrozen)
#         self.v_head = None
#         self.frozen_head = MistralModelBranch(base_model, num_layers_unfrozen)
#         for param in self.frozen_head.parameters():
#             param.requires_grad = False
#         self.frozen_head = self.frozen_head.eval()

#     def forward(
#         self,
#         input_ids: torch.LongTensor = None,
#         attention_mask: Optional[torch.Tensor] = None,
#         past_key_values: Optional[List[torch.FloatTensor]] = None,
#         position_ids: Optional[List[torch.FloatTensor]] = None,
#         head_mask: Optional[torch.Tensor] = None,
#         inputs_embeds: Optional[torch.FloatTensor] = None,
#         use_cache: Optional[bool] = None,
#         output_attentions: Optional[bool] = None,
#         output_hidden_states: Optional[bool] = None,
#         return_dict: Optional[bool] = None,
#         ignore_peft_adapter: Optional[bool] = None,
#     ) -> Union[Tuple, CausalLMOutputWithValue]:
#         forward_kwargs = {
#             "input_ids": input_ids,
#             "attention_mask": attention_mask,
#             "position_ids": position_ids,
#             "output_hidden_states": True,
#             "return_dict": True,
#         }

#         outputs = self.base_model(**forward_kwargs)
#         forward_kwargs["hidden_states"] = outputs["hidden_states"][-(self.num_layers_unfrozen + 1)]
#         forward_kwargs.pop("return_dict", None)
#         # value = self.v_head(**forward_kwargs).logits.squeeze(-1)
#         value = None
#         if not return_dict:
#             outputs = (outputs.logits,) + outputs[1:] + (value,)
#             return outputs

#         return CausalLMOutputWithValue(**outputs, value=value)

#     def forward_hydra(
#         self,
#         input_ids: torch.LongTensor = None,
#         attention_mask: Optional[torch.Tensor] = None,
#         past_key_values: Optional[List[torch.FloatTensor]] = None,
#         position_ids: Optional[List[torch.FloatTensor]] = None,
#         head_mask: Optional[torch.Tensor] = None,
#         inputs_embeds: Optional[torch.FloatTensor] = None,
#         use_cache: Optional[bool] = None,
#         output_attentions: Optional[bool] = None,
#         output_hidden_states: Optional[bool] = None,
#         return_dict: Optional[bool] = None,
#     ) -> Union[torch.FloatTensor, CausalLMOutputWithValue]:
#         forward_kwargs = {
#             "input_ids": input_ids,
#             "attention_mask": attention_mask,
#             "position_ids": position_ids,
#             "output_hidden_states": True,
#             "return_dict": True,
#         }
#         outputs = self.forward(**forward_kwargs)
#         forward_kwargs["hidden_states"] = outputs["hidden_states"][-(self.num_layers_unfrozen + 1)]
#         hydra_outputs = self.frozen_head(**forward_kwargs)

#         if not return_dict:
#             return hydra_outputs.logits
#         return hydra_outputs

#     def generate(self, *args, **kwargs) -> Union[ModelOutput, torch.LongTensor]:
#         return self.base_model.generate(*args, **kwargs)

#     def state_dict(self, *args, heads_only=False, **kwargs):
#         """
#         Returns the state dictionary of the model. We add the state dictionary of the value head
#         to the state dictionary of the wrapped model by prepending the key with `v_head.`.
#         """
#         # state_dict = self.v_head.state_dict(*args, **dict(prefix="v_head.", **kwargs))
#         state_dict = {}
#         if not heads_only:
#             state_dict = {
#                 **state_dict,
#                 **self.base_model.state_dict(*args, **dict(prefix="" if self.peft_type else "base_model.", **kwargs)),
#             }

#             if self.frozen_head:
#                 state_dict = {
#                     **state_dict,
#                     **self.frozen_head.state_dict(*args, **dict(prefix="frozen_head.", **kwargs)),
#                 }

#         return state_dict

#     def post_init(self, state_dict):
#         """
#         Load `state_dict` into the model. If peft was used to train the model,
#         only the value head would be present in the loaded `state_dict`, so the
#         loading has to be not strict. Also `frozen_head` will be recreated and
#         loaded from the checkpoint, to comply with deepspeed checkpoint loading.
#         """
#         strict = not self.peft_type and any(k.startswith("base_model.") or k.startswith("v_head.") for k in state_dict)

#         # if not self.peft_type and self.frozen_head is None:
#         #     for k in state_dict:
#         #         match = re.search(r"^frozen_head\..+\.(\d+)\.", k)
#         #         if match:
#         #             self.num_layers_unfrozen = max(self.num_layers_unfrozen, int(match.group(1)) + 1)

#         #     config = self.base_model.config
#         #     branch_class = hf_get_branch_class(config)
#         #     self.frozen_head = branch_class(
#         #         self.base_model,
#         #         num_layers_unfrozen=self.num_layers_unfrozen,
#         #     ).eval()

#         self.load_state_dict(state_dict, strict=strict)
#         del state_dict
#         gc.collect()  # noqa: E702


class AutoModelForCausalLMWithValueHead(PreTrainedModelWrapper):
    """An `AutoModel` class wrapper for `transformers` causal models that have a
    language modeling head and a value head.

    NOTE: the value head is currently disabled. ``v_head`` is ``None`` and
    :meth:`forward` always reports ``value=None``.
    """

    _auto_model_parent_class = transformers.AutoModelForCausalLM
    _supported_modules = ["v_head"]
    _supported_args = ["peft_config", "num_value_layers_unfrozen"]

    def __init__(
        self,
        base_model: transformers.PreTrainedModel,
        peft_config=None,
        num_value_layers_unfrozen=0,
    ):
        super().__init__(base_model, peft_config=peft_config)
        self.num_value_layers_unfrozen = num_value_layers_unfrozen
        # Value head disabled; `make_value_branch(base_model, num_value_layers_unfrozen)` would build it.
        self.v_head = None

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        ignore_peft_adapter: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithValue]:
        """Run the wrapped causal LM and attach a ``value`` slot to its outputs.

        Keyword arguments not accepted by the base model are dropped by
        ``get_compatible_forward_kwargs``. Hidden states are always requested.

        :param ignore_peft_adapter: when truthy and a peft adapter is attached, run without it.
            LoRA adapters are temporarily disabled. Other adapter types bypass the PeftModel.
        :return: if ``return_dict`` is falsy (including the default ``None``), the tuple
            ``(logits, *other_outputs, value)``; otherwise a :class:`CausalLMOutputWithValue`.
            ``value`` is always ``None`` while the value head is disabled.
        """
        forward_kwargs = self.get_compatible_forward_kwargs(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        forward_kwargs["output_hidden_states"] = True
        forward_kwargs["return_dict"] = True

        if self.peft_type == "PREFIX_TUNING":
            # In this case peft redefines past_key_values, remove it to avoid an exception.
            forward_kwargs.pop("past_key_values", None)

        if self.peft_type and ignore_peft_adapter:
            if "LORA" in self.peft_type:
                # For LORA, temporarily disable the adapter
                lora_model = self.base_model.base_model
                lora_model.disable_adapter_layers()
                try:
                    outputs = self.base_model(**forward_kwargs)
                finally:
                    # Re-enable even if the forward pass raises, so the adapter is never left off.
                    lora_model.enable_adapter_layers()
            else:
                # For prompt or prefix adapters, just use the base model of PeftModel
                outputs = self.base_model.base_model(**forward_kwargs)
        else:
            outputs = self.base_model(**forward_kwargs)

        # TODO: Re-enable the value head (and apply PEFT to the value branch); until then no value is computed.
        value = None

        if not return_dict:
            return (outputs.logits,) + outputs[1:] + (value,)

        return CausalLMOutputWithValue(**outputs, value=value)

    def generate(self, *args, **kwargs) -> Union[ModelOutput, torch.LongTensor]:
        """Delegate generation to the wrapped base model."""
        return self.base_model.generate(*args, **kwargs)

    def state_dict(self, *args, heads_only=False, **kwargs):
        """
        Return the model's state dictionary.

        Base-model entries are prefixed with ``base_model.`` and are omitted when
        ``heads_only`` is True. Value-head entries are prefixed with ``v_head.`` and
        are included only when a value head exists.
        """
        state_dict = {}
        if not heads_only:
            state_dict.update(self.base_model.state_dict(*args, **dict(prefix="base_model.", **kwargs)))
        # `v_head` is None while the value head is disabled; calling `.state_dict()` on it would crash.
        if self.v_head is not None:
            state_dict.update(self.v_head.state_dict(*args, **dict(prefix="v_head.", **kwargs)))
        return state_dict

    def post_init(self, state_dict):
        """
        Load ``state_dict`` into this wrapper.

        Loading is strict only when some key starts with ``base_model.`` or ``v_head.``,
        i.e. the checkpoint was saved by this wrapper. Other checkpoints are loaded
        non-strictly. The reference to ``state_dict`` is dropped and garbage collection
        runs afterwards.
        """
        super().post_init()

        trlx_checkpoint = any(k.startswith("base_model.") or k.startswith("v_head.") for k in state_dict)
        self.load_state_dict(state_dict, strict=trlx_checkpoint)

        del state_dict
        gc.collect()  # noqa: E702


class AutoModelForCausalLMWithHydraValueHead(AutoModelForCausalLMWithValueHead):
    """Value-head wrapper that also carries a frozen "hydra" reference head.

    Without peft, ``frozen_head`` is a frozen (``.eval()``) branch built by the class
    returned from ``hf_get_branch_class``, covering the top ``num_layers_unfrozen`` layers
    of the base model. :meth:`forward_hydra` feeds it the hidden state that precedes those
    layers. With peft no copy is made; :meth:`forward_hydra` instead runs the base model
    with the adapter ignored.
    """

    _supported_modules = ["v_head", "frozen_head"]
    _supported_args = ["num_layers_unfrozen", "peft_config", "num_value_layers_unfrozen"]

    def __init__(
        self,
        base_model: transformers.PreTrainedModel,
        *,
        num_layers_unfrozen: int = -1,
        peft_config=None,
        num_value_layers_unfrozen: int = 0,
    ):
        super().__init__(base_model, peft_config=peft_config, num_value_layers_unfrozen=num_value_layers_unfrozen)
        self.num_layers_unfrozen = num_layers_unfrozen

        # Build the frozen head only for a positive layer count without peft.
        # Otherwise `post_init` may still build it from the checkpoint keys.
        if self.num_layers_unfrozen > 0 and not self.peft_type:
            config = self.base_model.config
            branch_class = hf_get_branch_class(config)
            self.frozen_head = branch_class(
                self.base_model,
                num_layers_unfrozen=self.num_layers_unfrozen,
            ).eval()
        else:
            self.frozen_head = None

    def forward_hydra(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[torch.FloatTensor, CausalLMOutputWithValue]:
        """Compute outputs of the frozen reference head.

        :return: just the logits when ``return_dict`` is falsy (including the default
            ``None``); otherwise the full output object of the frozen head (or of
            :meth:`forward` with the adapter ignored, under peft).
        :raises RuntimeError: without peft, if no frozen head has been built.
        """
        forward_kwargs = self.get_compatible_forward_kwargs(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        return_dict = forward_kwargs.get("return_dict", True)
        forward_kwargs["return_dict"] = True
        forward_kwargs["output_hidden_states"] = True

        if self.peft_type:
            # Under peft the reference is the base model with the adapter ignored.
            hydra_outputs = self.forward(**forward_kwargs, ignore_peft_adapter=True)
        else:
            # Fail fast with a clear message instead of an opaque "'NoneType' object is not
            # callable" after a full (wasted) forward pass.
            if self.frozen_head is None:
                raise RuntimeError(
                    "forward_hydra requires a frozen head: construct the model with num_layers_unfrozen > 0"
                )
            outputs = self.forward(**forward_kwargs)
            # Select the hidden state before the first branching layer
            input_hidden_state = outputs.hidden_states[-(self.num_layers_unfrozen + 1)]

            output_shape = outputs.hidden_states[-1].size()
            # forward_kwargs.pop("input_ids", None)  # Ignore `input_ids` for branch head
            forward_kwargs.pop("inputs_embeds", None)  # Ignore `inputs_embeds` for branch head
            hydra_outputs = self.frozen_head(input_hidden_state, output_shape, **forward_kwargs)

        if not return_dict:
            return hydra_outputs.logits
        return hydra_outputs

    def state_dict(self, *args, heads_only=False, **kwargs):
        """
        Return the model's state dictionary.

        Value-head entries (prefix ``v_head.``) are included only when a value head
        exists. Unless ``heads_only`` is True, the dictionary also contains the base model
        entries and the frozen head entries. Base model keys are prefixed with
        ``base_model.``, or left unprefixed under peft. Frozen head keys are prefixed
        with ``frozen_head.``.
        """
        state_dict = {}
        if self.v_head is not None:
            state_dict.update(self.v_head.state_dict(*args, **dict(prefix="v_head.", **kwargs)))
        if not heads_only:
            state_dict.update(
                self.base_model.state_dict(*args, **dict(prefix="" if self.peft_type else "base_model.", **kwargs))
            )

            if self.frozen_head is not None:
                state_dict.update(self.frozen_head.state_dict(*args, **dict(prefix="frozen_head.", **kwargs)))

        return state_dict

    def post_init(self, state_dict):
        """
        Load `state_dict` into the model. If peft was used to train the model,
        only the value head would be present in the loaded `state_dict`, so the
        loading has to be not strict. Also `frozen_head` will be recreated and
        loaded from the checkpoint, to comply with deepspeed checkpoint loading.
        """
        strict = not self.peft_type and any(k.startswith("base_model.") or k.startswith("v_head.") for k in state_dict)

        if not self.peft_type and self.frozen_head is None:
            # Infer the branch depth from keys like "frozen_head.<...>.<layer_idx>.<...>".
            for k in state_dict:
                match = re.search(r"^frozen_head\..+\.(\d+)\.", k)
                if match:
                    self.num_layers_unfrozen = max(self.num_layers_unfrozen, int(match.group(1)) + 1)

            # NOTE(review): if no frozen_head keys exist and num_layers_unfrozen <= 0, a branch is
            # still built. The branch classes slice e.g. `blocks[-0:]` / `blocks[1:]`, which copies
            # (almost) the whole trunk. Confirm that callers rely on this before changing it.
            config = self.base_model.config
            branch_class = hf_get_branch_class(config)
            self.frozen_head = branch_class(
                self.base_model,
                num_layers_unfrozen=self.num_layers_unfrozen,
            ).eval()

        self.load_state_dict(state_dict, strict=strict)
        del state_dict
        gc.collect()  # noqa: E702


class ModelBranch(transformers.PreTrainedModel):
    """Upper trunk of a pretrained causal LM, used as the reference branch when
    computing the KL-divergence penalty.

    Holds deep copies of the last ``num_layers_unfrozen`` decoder blocks, the final
    norm and the LM head of the source model.
    """

    def __init__(
        self,
        base_model: transformers.PreTrainedModel,
        *,
        num_layers_unfrozen: int,
        frozen=True,
    ):
        """
        Args:
            base_model (transformers.PreTrainedModel): model whose upper trunk is copied
            num_layers_unfrozen (int): how many trailing decoder blocks to copy
            frozen (bool): when True, every parameter of the copy gets ``requires_grad=False``
        """
        super().__init__(base_model.config)

        blocks = hf_get_decoder_blocks(base_model)[-num_layers_unfrozen:]
        norm = hf_get_decoder_final_norm(base_model)
        head = hf_get_lm_head(base_model)

        # GatheredParameters temporarily materializes (possibly ZeRO-partitioned) parameters
        # so the deep copies below contain full tensors.
        to_gather = [*blocks.parameters(), *norm.parameters(), *head.parameters()]
        with deepspeed.zero.GatheredParameters(to_gather, modifier_rank=None):
            self.decoder_blocks = deepcopy(blocks)
            self.final_norm = deepcopy(norm)
            self.lm_head = deepcopy(head)

        self.hidden_size = hf_get_hidden_size(self.config)
        self.model_parallel = False
        self.device_map = None
        self.last_device = None
        self.gradient_checkpointing = False

        # Module.requires_grad_ applies the flag to every parameter of the branch.
        if frozen:
            self.requires_grad_(False)

class MistralModelBranch(transformers.PreTrainedModel):
    """Frozen copy of the top ``num_layers_unfrozen`` decoder layers, final norm and LM head of a Mistral model.

    ``forward`` consumes the hidden states produced by the trunk just below the copied layers
    and returns the reference logits.
    """

    def __init__(
        self,
        base_model: transformers.PreTrainedModel,
        num_layers_unfrozen: int,
    ):
        """
        Args:
            base_model (transformers.PreTrainedModel): Causal LM exposing ``.model`` (with ``embed_tokens``,
                ``layers`` and ``norm``) and ``.lm_head``
            num_layers_unfrozen (int): Number of top decoder layers to copy
        """
        super().__init__(base_model.config)
        self.padding_idx = base_model.model.config.pad_token_id
        self.vocab_size = base_model.model.config.vocab_size

        # NOTE(review): `forward` never reads `embed_tokens`; it is kept so the module's
        # parameters / state-dict keys are unchanged.
        self.embed_tokens = deepcopy(base_model.model.embed_tokens)
        self.embed_tokens.requires_grad_(False)
        self.layers = deepcopy(base_model.model.layers[-num_layers_unfrozen:])
        self.layers.requires_grad_(False)
        self.norm = deepcopy(base_model.model.norm)
        self.norm.requires_grad_(False)
        self.lm_head = deepcopy(base_model.lm_head)
        self.lm_head.requires_grad_(False)
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.FloatTensor = None,
        output_shape: torch.Size = None,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        """Run the copied layers, final norm and LM head on trunk hidden states.

        Args:
            hidden_states: Activations entering the first copied layer. The first two dimensions
                are read as (batch, seq). Required.
            output_shape: Accepted for signature parity with the other branch heads; unused.
            input_ids / inputs_embeds: Optional and otherwise ignored, apart from the check that
                they are not both given. Batch and sequence sizes come from ``hidden_states``.
            attention_mask: Optional 2D padding mask.
            position_ids: Optional positions; by default they continue after any cached positions.
            past_key_values: Optional per-layer cache. ``past_key_values[0][0].shape[2]`` is read as the
                cached length.

        Returns:
            ``CausalLMOutputWithValue`` when ``return_dict`` is truthy. Otherwise a tuple
            ``(logits, cache, all_hidden_states, attentions)`` with ``None`` entries dropped.

        Raises:
            ValueError: If ``hidden_states`` is missing, if both ``input_ids`` and ``inputs_embeds``
                are given, or on right padding with Flash Attention 2 and a cache.
        """
        if hidden_states is None:
            raise ValueError("MistralModelBranch requires the trunk `hidden_states` as input")
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # The branch starts from trunk activations, so sizes and device come from them, not from token ids.
        batch_size, seq_length = hidden_states.shape[:2]
        device = hidden_states.device

        past_key_values_length = 0
        if past_key_values is not None:
            past_key_values_length = past_key_values[0][0].shape[2]

        if position_ids is None:
            position_ids = torch.arange(past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device)
            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
        else:
            position_ids = position_ids.view(-1, seq_length).long()

        flash_attn_2 = getattr(self.config, "_flash_attn_2_enabled", False)
        if flash_attn_2 and attention_mask is not None and past_key_values is not None:
            is_padding_right = attention_mask[:, -1].sum().item() != batch_size
            if is_padding_right:
                raise ValueError(
                    "You are attempting to perform batched generation with padding_side='right'"
                    " this may lead to unexpected behaviour for Flash Attention version of Mistral. Make sure to "
                    " call `tokenizer.padding_side  = 'left'` before tokenizing the input. "
                )

        if flash_attn_2:
            # 2d mask is passed through the layers (dropped entirely when nothing is padded)
            attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
        else:
            # 4d mask is passed through the layers. The hidden states fill the `inputs_embeds` slot;
            # presumably only their dtype/device are read to build the mask.
            attention_mask = _prepare_4d_causal_attention_mask(
                attention_mask,
                (batch_size, seq_length),
                hidden_states,
                past_key_values_length,
                sliding_window=self.config.sliding_window,
            )

        if self.gradient_checkpointing and self.training:
            if use_cache:
                print("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
                use_cache = False

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = () if use_cache else None

        for idx, decoder_layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            past_key_value = past_key_values[idx] if past_key_values is not None else None

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    attention_mask,
                    position_ids,
                    past_key_value,
                    output_attentions,
                    use_cache,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_value,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                # The cache entry follows the attention weights when those are also returned.
                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None
        logits = self.lm_head(hidden_states)

        if not return_dict:
            return tuple(v for v in [logits, next_cache, all_hidden_states, all_self_attns] if v is not None)

        return CausalLMOutputWithValue(
            logits=logits,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )


class GPTModelBranch(ModelBranch):
    """Frozen upper trunk for GPT-style decoders.

    Block keyword arguments are filtered against each block's ``forward`` signature. This supports
    blocks that do not accept ``position_ids`` (e.g. ``GPT2Block``) or cross-attention inputs.
    """

    def forward(  # noqa: max-complexity
        self,
        hidden_states: torch.Tensor,  # Takes as input hidden_states instead of input_ids
        output_shape: torch.Tensor,  # output_size given by main trunk
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = False,
    ) -> Union[Tuple, CausalLMOutputWithValue]:
        """Run the copied blocks, final norm and LM head on trunk hidden states.

        Reference:
        https://github.com/huggingface/transformers/blob/2411f0e465e761790879e605a4256f3d4afb7f82/src/transformers/models/gpt2/modeling_gpt2.py#L743  # noqa: E501

        Args:
            hidden_states: Activations entering the first copied block. The first two dims are read as (batch, seq).
            output_shape: Shape that the final-norm output is viewed to before the LM head.
            head_mask: Optional head mask. It is expanded per layer of the full model, and only the entries
                for the last ``len(self.decoder_blocks)`` layers (the ones this branch owns) are used.
            return_dict: When falsy, only ``(logits, None, None)`` is returned.

        Returns:
            ``CausalLMOutputWithValue`` with logits, cache, hidden states and attentions, or the short tuple above.

        Raises:
            ValueError: If an attention mask is given with a non-positive batch size.
        """
        batch_size, seq_length = hidden_states.shape[:2]

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        device = hidden_states.device
        num_blocks = len(self.decoder_blocks)

        if past_key_values is None:
            past_length = 0
            past_key_values = tuple([None] * num_blocks)
        else:
            past_length = past_key_values[0][0].size(-2)

        if position_ids is None:
            position_ids = torch.arange(past_length, seq_length + past_length, dtype=torch.long, device=device)
            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
        else:
            position_ids = position_ids.view(-1, seq_length)

        if attention_mask is not None:
            if batch_size <= 0:
                raise ValueError("batch_size has to be defined and > 0")
            # [batch, src_len] -> additive [batch, 1, 1, src_len]: 0 where attended, dtype minimum where masked.
            attention_mask = attention_mask.view(batch_size, -1)
            attention_mask = attention_mask[:, None, None, :]
            attention_mask = attention_mask.to(dtype=self.dtype)  # fp16 compatibility
            attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min

        if self.config.add_cross_attention and encoder_hidden_states is not None:
            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
            if encoder_attention_mask is None:
                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
            encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
        else:
            encoder_attention_mask = None

        # `get_head_mask` expands the mask over every layer of the full model, but this branch only holds the
        # last `num_blocks` of them. Keep the matching tail so block i receives its own layer's mask.
        # (Without a mask this is a list of None.)
        head_mask = self.get_head_mask(head_mask, hf_get_num_hidden_layers(self.config))[-num_blocks:]

        presents = () if use_cache else None
        all_self_attentions = () if output_attentions else None
        all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
        all_hidden_states = () if output_hidden_states else None
        for i, (block, layer_past) in enumerate(zip(self.decoder_blocks, past_key_values)):
            if self.model_parallel:
                torch.cuda.set_device(hidden_states.device)
                if layer_past is not None:
                    layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past)
                if attention_mask is not None:
                    attention_mask = attention_mask.to(hidden_states.device)
                if isinstance(head_mask, torch.Tensor):
                    head_mask = head_mask.to(hidden_states.device)
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            kwargs = dict(
                layer_past=layer_past,
                attention_mask=attention_mask,
                position_ids=position_ids,
                head_mask=head_mask[i],
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                use_cache=use_cache,
                output_attentions=output_attentions,
            )

            # Assumes we are never training the branch
            block_params = inspect.getfullargspec(block.forward).args
            if "encoder_hidden_states" not in block_params:
                kwargs.pop("encoder_hidden_states")
                kwargs.pop("encoder_attention_mask")
            # Remove position_ids for GPT2Block
            if "position_ids" not in block_params:
                kwargs.pop("position_ids")

            outputs = block(hidden_states, **kwargs)

            hidden_states = outputs[0]
            if use_cache is True:
                presents = presents + (outputs[1],)

            if output_attentions:
                all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
                if self.config.add_cross_attention:
                    all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],)

            if self.model_parallel:
                # `device_map` may be left as None; treat that as "no device hand-off".
                for k, v in (self.device_map or {}).items():
                    if i == v[-1] and "cuda:" + str(k) != self.last_device:
                        hidden_states = hidden_states.to("cuda:" + str(k + 1))

        hidden_states = self.final_norm(hidden_states)

        hidden_states = hidden_states.view(output_shape)
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if self.model_parallel:
            # The branch has no `transformer` submodule to read a first device from; just move the
            # activations next to the LM head.
            hidden_states = hidden_states.to(self.lm_head.weight.device)

        lm_logits = self.lm_head(hidden_states)

        if not return_dict:
            # Tuple mode returns only the logits; the cache / hidden-state slots are placeholders.
            return (lm_logits, None, None)

        return CausalLMOutputWithValue(
            logits=lm_logits,
            past_key_values=presents,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            cross_attentions=all_cross_attentions,
        )


class OPTModelBranch(ModelBranch):
    """Frozen upper trunk for OPT: copied top decoder layers, optional final norm and LM head."""

    @staticmethod
    def _legacy_attention_mask(attention_mask, input_shape, hidden_states, past_key_values_length):
        """Combine causal and padding masks using the `modeling_opt` mask helpers, when those exist."""
        combined_attention_mask = None
        if input_shape[-1] > 1:
            # `modeling_opt._make_causal_mask` @ transformers==4.27.1 doesn't have the `device` argument
            if "device" in inspect.getfullargspec(modeling_opt._make_causal_mask).args:
                kwargs = dict(device=hidden_states.device)
            else:
                kwargs = {}

            combined_attention_mask = modeling_opt._make_causal_mask(
                input_shape,
                hidden_states.dtype,
                past_key_values_length=past_key_values_length,
                **kwargs,
            ).to(hidden_states.device)

        expanded_attn_mask = modeling_opt._expand_mask(attention_mask, hidden_states.dtype, tgt_len=input_shape[-1]).to(hidden_states.device)
        return expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask

    def forward(  # noqa: max-complexity
        self,
        hidden_states: torch.Tensor,
        output_shape: torch.Tensor,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = False,
    ) -> Union[Tuple, CausalLMOutputWithValue]:
        """Run the copied layers, final norm and LM head on trunk hidden states.

        Reference:
        https://github.com/huggingface/transformers/blob/bdb84e2bada3658f99c6a81c963ec562f8485151/src/transformers/models/opt/modeling_opt.py#L840  # noqa: E501

        Args:
            hidden_states: Activations entering the first copied layer; the leading dims are (batch, seq).
            output_shape: Accepted for signature parity with the other branch heads; unused here.
            attention_mask: Optional 2D padding mask covering cached plus new positions. It defaults to
                all ones.
            head_mask: Optional mask with exactly one entry per copied layer.

        Returns:
            ``CausalLMOutputWithValue``, or a tuple with ``None`` entries dropped when ``return_dict`` is falsy.

        Raises:
            ValueError: If ``head_mask`` does not have one entry per copied layer.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
        input_shape = hidden_states.size()[:-1]

        if attention_mask is None:
            # The padding mask must also cover cached positions, matching the causal mask width.
            attention_mask = torch.ones(
                (hidden_states.shape[0], hidden_states.shape[1] + past_key_values_length),
                dtype=torch.bool,
                device=hidden_states.device,
            )

        if hasattr(modeling_opt, "_make_causal_mask") and hasattr(modeling_opt, "_expand_mask"):
            attention_mask = self._legacy_attention_mask(attention_mask, input_shape, hidden_states, past_key_values_length)
        else:
            # NOTE(review): the OPT-specific helpers are presumably gone from newer transformers releases;
            # the generic builder yields the same additive [bsz, 1, tgt, src] mask.
            attention_mask = _prepare_4d_causal_attention_mask(attention_mask, input_shape, hidden_states, past_key_values_length)

        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = () if use_cache else None

        for attn_mask, mask_name in zip([head_mask], ["head_mask"]):
            if attn_mask is not None:
                if attn_mask.size()[0] != (len(self.decoder_blocks)):
                    raise ValueError(f"The `{mask_name}` should be specified for {len(self.decoder_blocks)} layers, but it is for" f" {attn_mask.size()[0]}.")

        for idx, decoder_layer in enumerate(self.decoder_blocks):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            past_key_value = past_key_values[idx] if past_key_values is not None else None

            layer_outputs = decoder_layer(
                hidden_states,
                past_key_value=past_key_value,
                attention_mask=attention_mask,
                layer_head_mask=(head_mask[idx] if head_mask is not None else None),
                use_cache=use_cache,
                output_attentions=output_attentions,
            )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        if self.final_norm is not None:
            hidden_states = self.final_norm(hidden_states)

        # TODO: Add output projection support
        # https://github.com/huggingface/transformers/blob/699e90437f984d69ad3c9b891dd2e9d0fc2cffe4/src/transformers/models/opt/modeling_opt.py#L499  # noqa: E501

        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None

        lm_logits = self.lm_head(hidden_states).contiguous()

        if not return_dict:
            # NOTE(review): position 1 is the final hidden state, not the cache (unlike HF's tuple layout).
            return tuple(
                v
                for v in [
                    lm_logits,
                    hidden_states,
                    next_cache,
                    all_hidden_states,
                    all_self_attns,
                ]
                if v is not None
            )

        return CausalLMOutputWithValue(
            logits=lm_logits,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )


class BloomModelBranch(ModelBranch):
    """Frozen upper trunk for Bloom: copied top blocks, final layer norm and LM head, driven by ALiBi biases."""

    @staticmethod
    def _build_causal_mask(attention_mask, input_shape, hidden_states, past_key_values_length):
        """Return Bloom's boolean 4D mask from a 2D padding mask.

        True marks positions that are masked out (the legacy helpers are combined with ``|``).
        """
        _, src_length = input_shape
        if hasattr(modeling_bloom, "_make_causal_mask") and hasattr(modeling_bloom, "_expand_mask"):
            combined_attention_mask = None
            if src_length > 1:
                combined_attention_mask = modeling_bloom._make_causal_mask(
                    input_shape,
                    device=attention_mask.device,
                    past_key_values_length=past_key_values_length,
                )
            expanded_attn_mask = modeling_bloom._expand_mask(attention_mask, tgt_length=src_length)
            return expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask | combined_attention_mask

        # NOTE(review): the Bloom-specific helpers are presumably removed in newer transformers releases.
        # Build the additive 4D mask (0 = attend, dtype minimum = masked) and convert it to Bloom's
        # boolean convention: nonzero -> True -> masked.
        return _prepare_4d_causal_attention_mask(attention_mask, input_shape, hidden_states, past_key_values_length).bool()

    def forward(  # noqa: max-complexity
        self,
        hidden_states: torch.Tensor,  # Takes as input hidden_states instead of input_ids
        output_shape: torch.Tensor,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = False,
    ) -> Union[Tuple, CausalLMOutputWithValue]:
        """Run the copied blocks, final norm and LM head on trunk hidden states.

        Reference:
        https://github.com/huggingface/transformers/blob/2411f0e465e761790879e605a4256f3d4afb7f82/src/transformers/models/bloom/modeling_bloom.py#L623  # noqa: E501

        Args:
            hidden_states: Activations entering the first copied block. The first two dims are read as (batch, seq).
            output_shape: Accepted for signature parity with the other branch heads; unused here.
            attention_mask: Optional 2D padding mask over cached plus new positions. It defaults to all ones.
            head_mask: Optional head mask. It is expanded per layer of the full model, and only the entries
                for the last ``len(self.decoder_blocks)`` layers are used.

        Returns:
            ``CausalLMOutputWithValue``, or a tuple with ``None`` entries dropped when ``return_dict`` is falsy.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        batch_size, seq_length = hidden_states.shape[:2]
        num_blocks = len(self.decoder_blocks)

        if past_key_values is None:
            past_key_values = tuple([None] * num_blocks)

        # `get_head_mask` covers every layer of the full model; keep only the tail this branch owns.
        head_mask = self.get_head_mask(head_mask, hf_get_num_hidden_layers(self.config))[-num_blocks:]

        presents = () if use_cache else None
        all_self_attentions = () if output_attentions else None
        all_hidden_states = () if output_hidden_states else None

        seq_length_with_past = seq_length
        past_key_values_length = 0
        if past_key_values[0] is not None:
            past_key_values_length = past_key_values[0][0].shape[2]
            seq_length_with_past = seq_length_with_past + past_key_values_length
        if attention_mask is None:
            attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
        else:
            attention_mask = attention_mask.to(hidden_states.device)

        alibi = modeling_bloom.build_alibi_tensor(attention_mask, self.config.n_head, dtype=hidden_states.dtype)
        causal_mask = self._build_causal_mask(attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length)

        for i, (block, layer_past) in enumerate(zip(self.decoder_blocks, past_key_values)):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            outputs = block(
                hidden_states,
                layer_past=layer_past,
                attention_mask=causal_mask,
                head_mask=head_mask[i],
                use_cache=use_cache,
                output_attentions=output_attentions,
                alibi=alibi,
            )

            hidden_states = outputs[0]
            if use_cache is True:
                presents = presents + (outputs[1],)

            if output_attentions:
                all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)

        hidden_states = self.final_norm(hidden_states)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        lm_logits = self.lm_head(hidden_states)

        if not return_dict:
            # NOTE(review): position 1 is the final hidden state, not the cache (unlike HF's tuple layout).
            return tuple(
                v
                for v in [
                    lm_logits,
                    hidden_states,
                    presents,
                    all_hidden_states,
                    all_self_attentions,
                ]
                if v is not None
            )

        return CausalLMOutputWithValue(
            logits=lm_logits,
            past_key_values=presents,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )


class LlamaModelBranch(ModelBranch):
    """Frozen upper trunk for Llama-style decoders: copied top blocks, final norm and LM head."""

    def _make_causal_mask(self, input_ids_shape: torch.Size, dtype: torch.dtype, past_key_values_length: int = 0):
        """
        Build the additive causal mask of shape `[bsz, 1, tgt_len, tgt_len + past_key_values_length]`.

        Current and earlier positions hold 0, and future positions hold the minimum value of `dtype`.
        Cached past positions are always visible. The tensor is created on the default device;
        callers move it.
        """
        bsz, tgt_len = input_ids_shape
        mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min))
        mask_cond = torch.arange(mask.size(-1))
        # Zero out column j <= row i (lower triangle including the diagonal).
        mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
        mask = mask.to(dtype)

        if past_key_values_length > 0:
            mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype), mask], dim=-1)
        return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)

    def _expand_mask(self, mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
        """
        Expand a `[bsz, seq_len]` padding mask to an additive `[bsz, 1, tgt_seq_len, src_seq_len]` mask.

        Entries equal to 1 become 0. Any other value becomes the minimum value of `dtype`.
        """
        bsz, src_len = mask.size()
        tgt_len = tgt_len if tgt_len is not None else src_len

        expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)

        inverted_mask = 1.0 - expanded_mask

        return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)

    def _prepare_decoder_attention_mask(self, attention_mask, input_shape, hidden_states, past_key_values_length):
        """Combine the causal mask (only when seq_len > 1) with the expanded padding mask, on `hidden_states`' device."""
        # create causal mask
        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        combined_attention_mask = None
        if input_shape[-1] > 1:
            combined_attention_mask = self._make_causal_mask(input_shape, hidden_states.dtype, past_key_values_length=past_key_values_length).to(hidden_states.device)

        if attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            expanded_attn_mask = self._expand_mask(attention_mask, hidden_states.dtype, tgt_len=input_shape[-1]).to(hidden_states.device)
            combined_attention_mask = expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
        return combined_attention_mask

    def forward(
        self,
        hidden_states: torch.Tensor,
        output_shape: torch.Tensor,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = False,
    ) -> Union[Tuple, CausalLMOutputWithValue]:
        """Run the copied blocks, final norm and LM head on trunk hidden states.

        Reference:
        https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L491

        Args:
            hidden_states: Activations entering the first copied block. The first two dims are read as (batch, seq).
            output_shape: Shape that the final-norm output is viewed to before the LM head.
            attention_mask: Optional 2D padding mask over cached plus new positions. It defaults to all ones.
            head_mask / encoder_*: Accepted for signature parity; unused.

        Returns:
            ``CausalLMOutputWithValue``, or ``(logits, None, None)`` when ``return_dict`` is falsy.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        batch_size, seq_length = hidden_states.shape[:2]
        device = hidden_states.device
        seq_length_with_past = seq_length
        past_key_values_length = 0

        if past_key_values is not None:
            past_key_values_length = past_key_values[0][0].shape[2]
            seq_length_with_past = seq_length_with_past + past_key_values_length

        if position_ids is None:
            position_ids = torch.arange(past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device)
            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
        else:
            position_ids = position_ids.view(-1, seq_length).long()

        if attention_mask is None:
            # Default mask attends to every cached and new position. (The old code referenced an
            # undefined `inputs_embeds` here and raised NameError.)
            attention_mask = torch.ones((batch_size, seq_length_with_past), dtype=torch.bool, device=device)
        attention_mask = self._prepare_decoder_attention_mask(attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length)

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = () if use_cache else None

        for idx, decoder_layer in enumerate(self.decoder_blocks):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            past_key_value = past_key_values[idx] if past_key_values is not None else None

            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
            )

            hidden_states = layer_outputs[0]

            if use_cache:
                # The cache entry follows the attention weights when those are also returned.
                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.final_norm(hidden_states)
        hidden_states = hidden_states.view(output_shape)
        lm_logits = self.lm_head(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None
        if not return_dict:
            # Tuple mode returns only the logits; the other slots are placeholders.
            return (lm_logits, None, None)

        return CausalLMOutputWithValue(
            logits=lm_logits,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )


class GPTBigCodeModelBranch(ModelBranch):
    """Model branch for GPTBigCode architectures.

    `forward` resumes computation from an intermediate hidden state produced by
    the main trunk, runs it through this branch's decoder blocks, final norm and
    LM head, and returns the resulting logits.
    """

    def __init__(
        self,
        base_model: transformers.PreTrainedModel,
        *,
        num_layers_unfrozen: int,
    ):
        """
        Args:
            base_model (transformers.PreTrainedModel): The pretrained model to extract upper trunk from
            num_layers_unfrozen (int): The number of trainable layers
        """
        super().__init__(base_model, num_layers_unfrozen=num_layers_unfrozen)
        # References (not copies) to trunk state needed to rebuild the masks below.
        self.config = base_model.transformer.config
        # Causal mask buffer; `forward` slices it to the current query/key window.
        self.bias = base_model.transformer.bias
        # Selects the mask layout expected by the blocks (MQA vs. MHA).
        self.multi_query = base_model.transformer.multi_query
        self.get_head_mask = base_model.transformer.get_head_mask

    def forward(  # noqa: C901
        self,
        hidden_states: torch.Tensor,  # Takes as input hidden_states instead of input_ids
        output_shape: torch.Tensor,  # output_size given by main trunk
        past_key_values: Optional[List[torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithValue]:
        """Run the branch's decoder blocks on trunk hidden states and project to logits.

        Reference:
        https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py#L539

        Args:
            hidden_states: Output of the trunk layer preceding this branch. Its
                first two dimensions are read as (batch_size, seq_length).
            output_shape: Accepted for signature parity with the other branches;
                not used by this branch.
            past_key_values: Optional per-block cache. The past length is read
                from dim -2 of the first entry.
            attention_mask: Optional padding mask. It is reshaped to
                (batch_size, 1, -1) and combined with the causal mask.
            token_type_ids, position_ids: Accepted for interface parity; unused here.
            head_mask: Optional head mask, expanded by the trunk's `get_head_mask`.
            encoder_hidden_states, encoder_attention_mask: Only used when
                `config.add_cross_attention` is set. A 2-D or 3-D
                `encoder_attention_mask` is accepted.
            use_cache, output_attentions, output_hidden_states, return_dict:
                Default to the corresponding config values when None.

        Returns:
            `CausalLMOutputWithValue` when `return_dict` is truthy. Otherwise, a
            tuple of the non-None entries of (logits, final hidden states,
            presents, all hidden states, self-attentions, cross-attentions).

        Raises:
            ValueError: If the batch dimension of `hidden_states` is empty.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        batch_size, seq_length = hidden_states.shape[:2]

        if batch_size <= 0:
            raise ValueError("batch_size has to be defined and > 0")

        device = hidden_states.device

        if past_key_values is None:
            past_length = 0
            past_key_values = tuple([None] * len(self.decoder_blocks))
        else:
            past_length = past_key_values[0].size(-2)

        # Self-attention mask.
        query_length = seq_length
        key_length = past_length + query_length
        self_attention_mask = self.bias[None, key_length - query_length : key_length, :key_length].to(device)

        if attention_mask is not None:
            self_attention_mask = self_attention_mask * attention_mask.view(batch_size, 1, -1).to(dtype=torch.bool, device=self_attention_mask.device)

        # MQA models: (batch_size, query_length, n_heads, key_length)
        # MHA models: (batch_size, n_heads, query_length, key_length)
        attention_mask = self_attention_mask.unsqueeze(2 if self.multi_query else 1)

        # If a 2D or 3D attention mask is provided for the cross-attention
        # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
        if self.config.add_cross_attention and encoder_hidden_states is not None and encoder_attention_mask is not None:
            if encoder_attention_mask.dim() == 2:
                # `unsqueeze` is not in-place: keep the result so a 2-D
                # (batch, key) mask becomes (batch, 1, key) and passes the check below.
                encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
            assert encoder_attention_mask.dim() == 3
            encoder_attention_mask = encoder_attention_mask.bool().unsqueeze(2 if self.multi_query else 1)
        else:
            encoder_attention_mask = None

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # head_mask has shape n_layer x batch x n_heads x N x N
        # NOTE(review): the mask is expanded for all `n_layer` trunk layers but
        # indexed from 0 for this branch's (fewer) blocks. Confirm whether a
        # per-layer head_mask should be offset to the branch's layers.
        head_mask = self.get_head_mask(head_mask, self.config.n_layer)

        presents = [] if use_cache else None
        all_self_attentions = () if output_attentions else None
        all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
        all_hidden_states = () if output_hidden_states else None
        for i, (block, layer_past) in enumerate(zip(self.decoder_blocks, past_key_values)):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            outputs = block(
                hidden_states,
                layer_past=layer_past,
                attention_mask=attention_mask,
                head_mask=head_mask[i],
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                use_cache=use_cache,
                output_attentions=output_attentions,
            )

            hidden_states = outputs[0]
            if use_cache:
                presents.append(outputs[1])

            # The block output only carries a cache slot when `use_cache` is set,
            # which shifts the attention indices by one.
            if output_attentions:
                all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
                if self.config.add_cross_attention:
                    all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],)

        hidden_states = self.final_norm(hidden_states)

        # Add last hidden state
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        lm_logits = self.lm_head(hidden_states)

        if not return_dict:
            return tuple(
                v
                for v in [
                    lm_logits,
                    hidden_states,
                    presents,
                    all_hidden_states,
                    all_self_attentions,
                    all_cross_attentions,
                ]
                if v is not None
            )

        return CausalLMOutputWithValue(
            logits=lm_logits,
            past_key_values=presents,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            cross_attentions=all_cross_attentions,
        )


# Seq2Seq architectures


@dataclass
class Seq2SeqLMOutputWithValue(ModelOutput):
    """Seq2seq language-model output extended with a `value` field.

    `AutoModelForSeq2SeqLMWithValueHead.forward` builds this from every field of
    the wrapped model's output. It sets `value` to the value head applied to the
    last decoder hidden state. `T5Branch.forward` fills only `logits`,
    `decoder_hidden_states` and `decoder_attentions`, leaving the rest as None.
    """

    # NOTE(review): ModelOutput subclasses are presumably also accessed
    # positionally / converted to tuples in field order, so do not reorder these
    # fields. Confirm against the transformers ModelOutput docs before changing them.
    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
    cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
    encoder_last_hidden_state: Optional[torch.FloatTensor] = None
    encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
    # Value-head prediction (the head's single output channel squeezed away).
    value: Optional[torch.FloatTensor] = None


class AutoModelForSeq2SeqLMWithValueHead(PreTrainedModelWrapper):
    """An `AutoModel` class wrapper for `transformers` sequence-to-sequence
    models that have a language modeling head and a value head
    """

    _auto_model_parent_class = transformers.AutoModelForSeq2SeqLM
    _supported_modules = ["v_head"]
    _supported_args = ["peft_config", "num_value_layers_unfrozen"]

    def __init__(
        self,
        base_model: transformers.PreTrainedModel,
        peft_config=None,
        num_value_layers_unfrozen=0,
    ):
        """
        Args:
            base_model (transformers.PreTrainedModel): The seq2seq model to wrap.
            peft_config: Optional peft configuration, forwarded to the base wrapper.
            num_value_layers_unfrozen (int): Must be 0; value branches are not
                implemented for seq2seq models.

        Raises:
            NotImplementedError: If `num_value_layers_unfrozen > 0`.
        """
        super().__init__(base_model, peft_config=peft_config)
        # TODO: Support Seq2Seq value branching
        if num_value_layers_unfrozen > 0:
            raise NotImplementedError("Value branches unsupported for Seq2Seq architecture")
        # Scalar value head on top of the decoder's hidden size.
        self.v_head = make_head(hf_get_hidden_size(self.base_model.config), 1)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.FloatTensor] = None,
        encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None,
        past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        decoder_head_mask: Optional[torch.FloatTensor] = None,
        cross_attn_head_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = True,
        output_hidden_states: Optional[bool] = True,
        return_dict: Optional[bool] = None,
        ignore_peft_adapter: Optional[bool] = None,
    ) -> Seq2SeqLMOutputWithValue:
        """Run the wrapped seq2seq model and attach value-head predictions.

        Only the arguments accepted by the wrapped model's `forward` are passed on
        (via `get_compatible_forward_kwargs`). `output_hidden_states` and
        `return_dict` are always forced to True because the value head reads the
        last decoder hidden state.

        Args:
            ignore_peft_adapter: If truthy and a peft adapter is attached, run the
                underlying model without it. LoRA layers are disabled for the
                duration of the call; prompt/prefix adapters are bypassed by
                calling the inner model directly.

        Returns:
            `Seq2SeqLMOutputWithValue` containing every field of the wrapped model's
            output plus `value`. `value` is the value head applied to the last
            decoder hidden state, with its size-1 output dimension squeezed.
        """
        forward_kwargs = self.get_compatible_forward_kwargs(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            encoder_outputs=encoder_outputs,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            decoder_inputs_embeds=decoder_inputs_embeds,
            head_mask=head_mask,
            decoder_head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        forward_kwargs["output_hidden_states"] = True
        forward_kwargs["return_dict"] = True

        if self.peft_type == "PREFIX_TUNING":
            # In this case peft redefines past_key_values, remove it to avoid an exception.
            forward_kwargs.pop("past_key_values", None)

        if self.peft_type and ignore_peft_adapter:
            if "LORA" in self.peft_type:
                # For LORA, temporarily disable the adapter
                lora_model = self.base_model.base_model
                lora_model.disable_adapter_layers()
                try:
                    outputs = self.base_model(**forward_kwargs)
                finally:
                    # Re-enable even if the forward raises (e.g. OOM); otherwise
                    # the policy would silently keep running without its adapter.
                    lora_model.enable_adapter_layers()
            else:
                # For prompt or prefix adapters, just use the base model of PeftModel
                outputs = self.base_model.base_model(**forward_kwargs)
        else:
            outputs = self.base_model(**forward_kwargs)

        last_hidden_state = outputs.decoder_hidden_states[-1]
        value = self.v_head(last_hidden_state).squeeze(-1)

        return Seq2SeqLMOutputWithValue(**outputs, value=value)

    def generate(self, *args, **kwargs) -> Union[ModelOutput, torch.LongTensor]:
        """Delegate generation to the wrapped model (the value head is not used)."""
        return self.base_model.generate(*args, **kwargs)

    def state_dict(self, *args, heads_only=False, **kwargs):
        """
        Returns the state dictionary of the model. We add the state dictionary of the value head
        to the state dictionary of the wrapped model by prepending the key with `v_head.`.
        Wrapped-model keys are prefixed with `base_model.`; with `heads_only=True`
        only the `v_head.` entries are returned.
        """
        state_dict = self.v_head.state_dict(*args, **dict(prefix="v_head.", **kwargs))
        if not heads_only:
            state_dict = {**state_dict, **self.base_model.state_dict(*args, **dict(prefix="base_model.", **kwargs))}

        return state_dict

    def post_init(self, state_dict):
        """
        Load `state_dict` into this wrapper after the base model has been created.

        Loading is strict only when the checkpoint looks like one produced by
        `state_dict` above, i.e. some key starts with `base_model.` or `v_head.`.
        Otherwise, missing/unexpected keys are tolerated. The local reference is
        then dropped and `gc.collect()` is run, presumably to release checkpoint
        memory sooner.
        """
        super().post_init()

        trlx_checkpoint = any(k.startswith("base_model.") or k.startswith("v_head.") for k in state_dict)
        self.load_state_dict(state_dict, strict=trlx_checkpoint)

        del state_dict
        gc.collect()  # noqa: E702


class AutoModelForSeq2SeqLMWithHydraValueHead(AutoModelForSeq2SeqLMWithValueHead):
    """Seq2seq value-head wrapper that can also produce reference logits.

    Without peft, the reference logits come from a `frozen_head` branch built
    over the last `num_layers_unfrozen` decoder layers. With peft, they come from
    the same model run with its adapter ignored.
    """

    _supported_modules = ["v_head", "frozen_head"]
    _supported_args = ["num_layers_unfrozen", "peft_config", "num_value_layers_unfrozen"]

    def __init__(
        self,
        base_model: transformers.PreTrainedModel,
        *,
        num_layers_unfrozen: int = -1,
        peft_config=None,
        num_value_layers_unfrozen: int = 0,
    ):
        """
        Args:
            base_model (transformers.PreTrainedModel): The seq2seq model to wrap.
            num_layers_unfrozen (int): Number of top decoder layers covered by
                `frozen_head`. A branch is only built when this is > 0 and no peft
                adapter is used; otherwise `frozen_head` is None.
            peft_config: Optional peft configuration.
            num_value_layers_unfrozen (int): Must be 0 (see parent class).
        """
        super().__init__(base_model, peft_config=peft_config, num_value_layers_unfrozen=num_value_layers_unfrozen)
        self.num_layers_unfrozen = num_layers_unfrozen

        if self.num_layers_unfrozen > 0 and not self.peft_type:
            branch_class = T5Branch  # TODO: Add support for other model branches
            self.frozen_head = branch_class(
                self.base_model,
                num_layers_unfrozen=self.num_layers_unfrozen,
            ).eval()
        else:
            self.frozen_head = None

    def forward_hydra(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.FloatTensor] = None,
        encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None,
        past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        decoder_head_mask: Optional[torch.FloatTensor] = None,
        cross_attn_head_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Seq2SeqLMOutputWithValue:
        """Compute reference outputs from the frozen branch (or the adapter-free model).

        - With peft, `forward` is run with `ignore_peft_adapter=True`.
        - Otherwise, `forward` is run once. The decoder hidden state that precedes
          the last `num_layers_unfrozen` layers is then fed through
          `frozen_head`, together with the encoder output and both attention masks.

        Returns:
            The output object when `return_dict` is truthy. Otherwise (including the
            default `None`), just the logits tensor.
        """
        forward_kwargs = self.get_compatible_forward_kwargs(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            encoder_outputs=encoder_outputs,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            decoder_inputs_embeds=decoder_inputs_embeds,
            head_mask=head_mask,
            decoder_head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        return_dict = forward_kwargs.get("return_dict", True)
        forward_kwargs["output_hidden_states"] = True
        forward_kwargs["return_dict"] = True

        if self.peft_type:
            hydra_outputs = self.forward(**forward_kwargs, ignore_peft_adapter=True)
        else:
            outputs = self.forward(**forward_kwargs)
            # Select the hidden state before the first branching layer
            input_hidden_state = outputs.decoder_hidden_states[-(self.num_layers_unfrozen + 1)]
            hydra_outputs = self.frozen_head(
                hidden_states=input_hidden_state,
                attention_mask=decoder_attention_mask,
                encoder_hidden_states=outputs.encoder_last_hidden_state,
                encoder_attention_mask=attention_mask,
                use_cache=False,
                output_attentions=False,
                output_hidden_states=True,
                # Always request an output object: `.logits` is read below even
                # when the caller asked for a non-dict result, and the branch
                # returns a bare tuple otherwise.
                return_dict=True,
            )

        if not return_dict:
            return hydra_outputs.logits
        return hydra_outputs

    def state_dict(self, *args, heads_only=False, **kwargs):
        """
        Returns the state dictionary of the model. We add the state dictionary of the value head
        to the state dictionary of the wrapped model by prepending the key with `v_head.`.
        Unless `heads_only` is set, the wrapped model's entries are included. They
        are prefixed with `base_model.`, or left unprefixed when a peft adapter is
        used. `frozen_head.` entries are added when a frozen branch exists.
        """
        state_dict = self.v_head.state_dict(*args, **dict(prefix="v_head.", **kwargs))
        if not heads_only:
            state_dict = {
                **state_dict,
                **self.base_model.state_dict(*args, **dict(prefix="" if self.peft_type else "base_model.", **kwargs)),
            }

            if self.frozen_head:
                state_dict = {
                    **state_dict,
                    **self.frozen_head.state_dict(*args, **dict(prefix="frozen_head.", **kwargs)),
                }

        return state_dict

    def post_init(self, state_dict):
        """
        Load `state_dict` into the model. If peft was used to train the model,
        only the value head would be present in the loaded `state_dict`, so the
        loading has to be not strict. Also `frozen_head` will be recreated and
        loaded from the checkpoint, to comply with deepspeed checkpoint loading.
        """
        strict = not self.peft_type and any(k.startswith("base_model.") or k.startswith("v_head.") for k in state_dict)

        if not self.peft_type and self.frozen_head is None:
            # Infer the branch depth from the highest `frozen_head.decoder_blocks.<i>` key.
            for k in state_dict:
                match = re.search(r"^frozen_head\.decoder_blocks\.(\d+)", k)
                if match:
                    self.num_layers_unfrozen = max(self.num_layers_unfrozen, int(match.group(1)) + 1)

            branch_class = T5Branch  # TODO: Add support for other model branches
            self.frozen_head = branch_class(
                self.base_model,
                num_layers_unfrozen=self.num_layers_unfrozen,
            ).eval()

        self.load_state_dict(state_dict, strict=strict)
        del state_dict
        gc.collect()  # noqa: E702


class T5Branch(ModelBranch):
    """Decoder only T5 branch.

    Replays this branch's decoder blocks on a trunk hidden state, then applies the
    final norm, dropout and LM head. The blocks are set up by `ModelBranch` from
    `num_layers_unfrozen`.
    """

    def __init__(
        self,
        base_model: transformers.PreTrainedModel,
        *,
        num_layers_unfrozen: int,
    ):
        """
        Args:
            base_model (transformers.PreTrainedModel): The T5 model to build the branch from.
            num_layers_unfrozen (int): Number of top decoder layers held by the branch.
        """
        super().__init__(base_model, num_layers_unfrozen=num_layers_unfrozen)
        # Reuse the decoder's dropout module after the final norm, as the decoder does.
        self.dropout = hf_get_decoder(base_model).dropout
        self.is_decoder = True

    def forward(  # noqa: max-complexity
        self,
        hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, Seq2SeqLMOutputWithValue]:
        """Run the branch's decoder blocks on a trunk decoder hidden state.

        Reference:
        https://github.com/huggingface/transformers/blob/bc21aaca789f1a366c05e8b5e111632944886393/src/transformers/models/t5/modeling_t5.py#L899  # noqa: E501

        Args:
            hidden_states: Decoder hidden state from the trunk. Its first two
                dimensions are read as (batch_size, seq_length).
            attention_mask: Decoder attention mask; defaults to all ones.
            encoder_hidden_states: Encoder output, unpacked as
                (batch, encoder_seq_len, hidden).
            encoder_attention_mask: Encoder mask; defaults to all ones over the
                encoder positions when `encoder_hidden_states` is given.
            use_cache: Passed to each block. No cache is returned by this branch.
            output_attentions: Collect self-attention weights when truthy.
            output_hidden_states, return_dict: Default to the config values when None.

        Returns:
            `Seq2SeqLMOutputWithValue` with `logits`, `decoder_hidden_states` and
            `decoder_attentions` when `return_dict` is truthy, else `(lm_logits,)`.
        """
        batch_size, seq_length = hidden_states.shape[:2]
        input_shape = (batch_size, seq_length)

        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if attention_mask is None:
            attention_mask = torch.ones(batch_size, seq_length, device=hidden_states.device)
        extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)

        if self.is_decoder and encoder_hidden_states is not None:
            if encoder_attention_mask is None:
                encoder_seq_length = encoder_hidden_states.shape[1]
                encoder_attention_mask = torch.ones(batch_size, encoder_seq_length, device=hidden_states.device, dtype=torch.long)
            encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
        else:
            encoder_extended_attention_mask = None

        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None
        # Position biases are computed by the first block and reused by the rest.
        position_bias = None
        encoder_decoder_position_bias = None

        for layer_module in self.decoder_blocks:
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_outputs = layer_module(
                hidden_states,
                attention_mask=extended_attention_mask,
                position_bias=position_bias,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_extended_attention_mask,
                encoder_decoder_position_bias=encoder_decoder_position_bias,
                use_cache=use_cache,
                output_attentions=output_attentions,
            )

            # The block omits the key/value slot whenever caching is off. Insert a
            # placeholder so the layout is always (hidden, kv, pos_bias, ...).
            # Test truthiness rather than `is False`: `use_cache` defaults to None.
            if not use_cache:
                layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:]

            hidden_states = layer_outputs[0]

            position_bias = layer_outputs[2]
            if self.is_decoder and encoder_hidden_states is not None:
                encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3]

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[3],)

        hidden_states = self.final_norm(hidden_states)
        hidden_states = self.dropout(hidden_states)

        # Add last layer
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        sequence_output = hidden_states

        if self.config.tie_word_embeddings:
            # Rescale output before projecting on vocab
            # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586  # noqa: E501
            sequence_output = sequence_output * (self.config.d_model**-0.5)

        lm_logits = self.lm_head(sequence_output)

        if not return_dict:
            return (lm_logits,)

        return Seq2SeqLMOutputWithValue(
            logits=lm_logits,
            decoder_hidden_states=all_hidden_states,
            decoder_attentions=all_attentions,
        )


# Branch class utils


def hf_get_branch_class(
    config: "transformers.PretrainedConfig",
) -> "ModelBranch":
    """Returns the model branch class for the given config.

    The branch is selected from the first entry of `config.architectures`.

    Args:
        config: A model config whose `architectures` list names the model class.

    Returns:
        The `ModelBranch` subclass implementing branching for that architecture.

    Raises:
        ValueError: If `config.architectures` is missing or empty, or names an
            architecture without a branch implementation.
    """
    gpt_branch_supported_archs = [
        "GPTJForCausalLM",
        "GPT2LMHeadModel",
        "GPTNeoForCausalLM",
        "GPTNeoXForCausalLM",
    ]
    opt_branch_supported_archs = ["OPTForCausalLM"]
    bloom_branch_supported_archs = ["BloomModel", "BloomForCausalLM"]
    llama_branch_supported_archs = ["LlamaModel", "LlamaForCausalLM"]
    bigcode_branch_supported_archs = ["GPTBigCodeModel", "GPTBigCodeForCausalLM"]
    mistral_branch_supported_archs = ["MistralForCausalLM"]

    architectures = getattr(config, "architectures", None)
    if not architectures:
        # Configs built without a concrete model class have `architectures=None`;
        # fail with a clear message instead of an opaque TypeError/IndexError.
        raise ValueError("Cannot select a model branch class: `config.architectures` is missing or empty.")
    arch = architectures[0]

    if arch in gpt_branch_supported_archs:
        return GPTModelBranch
    elif arch in opt_branch_supported_archs:
        return OPTModelBranch
    elif arch in bloom_branch_supported_archs:
        return BloomModelBranch
    elif arch in llama_branch_supported_archs:
        return LlamaModelBranch
    elif arch in bigcode_branch_supported_archs:
        return GPTBigCodeModelBranch
    elif arch in mistral_branch_supported_archs:
        return MistralModelBranch

    all_supported_archs = [
        *gpt_branch_supported_archs,
        *opt_branch_supported_archs,
        *bloom_branch_supported_archs,
        *llama_branch_supported_archs,
        *bigcode_branch_supported_archs,
        *mistral_branch_supported_archs,
    ]
    raise ValueError(f"Unsupported architecture: `{arch}`. The following architectures are " f"available for model branching:\n{all_supported_archs}")
