import copy
import math
import warnings
from typing import Callable, List, Optional, Tuple

import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn, vmap
from torch.func import functional_call, stack_module_state  # type: ignore

from .utils import combine_masks, KVCache, LayerKVCache, repeat_kv


class ShortCircuit(nn.Module):
    """Transformer trunk shared by a policy head and a value head.

    ``hidden_module`` encodes the input embeddings, ``policy_head`` maps the
    encoding to the stacked outputs of its policy layers, and ``value_head``
    maps it to one scalar per output token.  After construction every
    sub-module is re-initialised through :meth:`_init_weights`.

    Args:
        embedding_size: Width of the embeddings / hidden states.
        n_heads: Number of attention heads.
        n_layers: Number of decoder layers in the shared trunk.
        intermediate_size: Inner width of the gated MLPs.
        n_policy_layers: Size of each policy stack (decoder layers plus the
            action head).
        n_value_layers: Depth parameter forwarded to the value head.
        act_fn: Activation callable forwarded to the sub-modules.
        num_key_value_heads: Number of key/value heads; defaults to
            ``n_heads``.
        dropout: Forwarded to the sub-modules.
        is_causal: Stored on the instance; not read inside this class.
        attention_bias: Forwarded to the sub-modules.
        position_embeddings: Enables the position-embedding layer of the
            trunk.
        rms_norm_eps: Epsilon forwarded to the RMS normalisation layers.
    """

    def __init__(
        self,
        embedding_size: int,
        n_heads: int,
        n_layers: int,
        intermediate_size: int,
        n_policy_layers: int,
        n_value_layers: int,
        act_fn: Callable,
        num_key_value_heads: int | None = None,
        dropout: float = 0.0,
        is_causal: bool = True,
        attention_bias: bool = False,
        position_embeddings: bool = False,
        rms_norm_eps: float = 1e-6,
    ) -> None:
        super().__init__()
        self.embedding_size = embedding_size
        self.intermediate_size = intermediate_size
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.n_policy_layers = n_policy_layers
        self.n_value_layers = n_value_layers
        self.act_fn = act_fn
        self.is_causal = is_causal
        self.position_embeddings = position_embeddings
        if num_key_value_heads is None:
            num_key_value_heads = self.n_heads
        self.num_key_value_heads = num_key_value_heads
        self.head_dim = self.embedding_size // self.n_heads

        # Shared trunk: layer indices 0 .. n_layers - 1.
        self.hidden_module = ShortCircuitHiddenModule(
            embedding_size=self.embedding_size,
            n_heads=self.n_heads,
            n_layers=self.n_layers,
            intermediate_size=self.intermediate_size,
            n_policy_layers=self.n_policy_layers,
            act_fn=self.act_fn,
            num_key_value_heads=num_key_value_heads,
            dropout=dropout,
            attention_bias=attention_bias,
            position_embeddings=position_embeddings,
            rms_norm_eps=rms_norm_eps,
        )

        # Policy head: its layer indices start right after the trunk.
        self.policy_head = ShortCircuitPolicyModule(
            embedding_size=self.embedding_size,
            n_heads=self.n_heads,
            n_policy_layers=self.n_policy_layers,
            intermediate_size=self.intermediate_size,
            act_fn=self.act_fn,
            layer_idx_offset=self.n_layers,
            dropout=dropout,
            attention_bias=attention_bias,
            rms_norm_eps=rms_norm_eps,
        )

        # Value head: offset skips the trunk plus the 4 policy stacks built by
        # ShortCircuitPolicyModule (n_policy_layers indices each).
        self.value_head = ShortCircuitValueModule(
            embedding_size=self.embedding_size,
            n_heads=self.n_heads,
            n_value_layers=self.n_value_layers,
            intermediate_size=self.intermediate_size,
            act_fn=self.act_fn,
            layer_idx_offset=self.n_layers + 4 * self.n_policy_layers,
            dropout=dropout,
            attention_bias=attention_bias,
            rms_norm_eps=rms_norm_eps,
        )
        self.apply(self._init_weights)

    def forward(
        self,
        inputs_embeds: torch.Tensor,
        causal_mask: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        num_outputs: int = 1,
        cache: KVCache | None = None,
        use_cache: bool | None = None,
        get_value: bool = False,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Run the trunk and the policy head, optionally also the value head.

        Args:
            inputs_embeds: Input embeddings fed to the trunk.
            causal_mask: Optional mask merged with ``attention_mask`` through
                ``combine_masks`` for the trunk and the policy head.
            attention_mask: Optional mask.  The value head receives
                ``combine_masks(None, attention_mask)``, i.e. a mask built
                without ``causal_mask``.
            num_outputs: Number of leading output (target) positions.
            cache: Optional cache container forwarded to every sub-module.
            use_cache: Forwarded to every sub-module.
            get_value: When true, also evaluate the value head.

        Returns:
            The policy-head output, or ``(actions, value)`` when
            ``get_value`` is true.
        """
        hidden_states = inputs_embeds

        mask = combine_masks(causal_mask, attention_mask)

        # num_outputs must reach the trunk too: its position-embedding layer
        # uses it to tell target positions from node positions.  Without it
        # the trunk silently assumed a single target.
        hidden_states = self.hidden_module(
            inputs_embeds=hidden_states,
            attention_mask=mask,
            num_outputs=num_outputs,
            cache=cache,
            use_cache=use_cache,
        )

        actions = self.policy_head(
            hidden_states=hidden_states,
            attention_mask=mask,
            num_outputs=num_outputs,
            cache=cache,
            use_cache=use_cache,
        )

        if not get_value:
            return actions

        value = self.value_head(
            hidden_states=hidden_states,
            attention_mask=combine_masks(None, attention_mask),
            num_outputs=num_outputs,
            cache=cache,
            use_cache=use_cache,
        )

        return actions, value

    def get_hidden_module(self) -> nn.Module:
        """Return the shared trunk."""
        return self.hidden_module

    def get_policy_head(self) -> nn.Module:
        """Return the policy head."""
        return self.policy_head

    def get_value_head(self) -> nn.Module:
        """Return the value head."""
        return self.value_head

    def _init_weights(self, module: nn.Module, std: float = 0.02) -> None:
        """Initialise one sub-module; used as the callback of ``self.apply``.

        ``nn.Linear`` weights get Xavier-uniform values and zero biases.  Any
        other module exposing a ``weight`` gets Xavier-uniform values when the
        weight has more than one dimension and ``N(0, std)`` values otherwise;
        its bias, if present, is zeroed.

        Args:
            module: The sub-module handed in by ``nn.Module.apply``.
            std: Standard deviation used for one-dimensional weights.
        """
        if isinstance(module, nn.Linear):
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                module.bias.data.zero_()
        elif hasattr(module, "weight") and module.weight is not None:
            if module.weight.ndim > 1:
                nn.init.xavier_uniform_(module.weight)
            else:
                # NOTE(review): this branch also hits the 1-D scale of
                # LlamaRMSNorm, replacing its all-ones initialisation with
                # N(0, std) samples -- confirm this is intended.
                module.weight.data.normal_(mean=0.0, std=std)
            if hasattr(module, "bias") and module.bias is not None:
                module.bias.data.zero_()


class ShortCircuitHiddenModule(nn.Module):
    """Shared trunk: optional position embeddings, decoder layers, final norm.

    The decoder layers carry the indices ``0 .. n_layers - 1``; those indices
    are used (as strings) to look up the per-layer entry of ``cache``.
    ``num_key_value_heads``, ``dropout`` and ``attention_bias`` are stored on
    the instance but are not handed to the decoder layers.
    """

    def __init__(
        self,
        embedding_size: int,
        n_heads: int,
        n_layers: int,
        intermediate_size: int,
        n_policy_layers: int,
        act_fn: Callable,
        num_key_value_heads: int | None = None,
        dropout: float = 0.0,
        attention_bias: bool = False,
        position_embeddings: bool = False,
        rms_norm_eps: float = 1e-6,
    ) -> None:
        super().__init__()

        # Plain configuration values.
        self.embedding_size = embedding_size
        self.intermediate_size = intermediate_size
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.n_policy_layers = n_policy_layers
        self.num_key_value_heads = num_key_value_heads
        self.act_fn = act_fn
        self.dropout = dropout
        self.attention_bias = attention_bias
        self.rms_norm_eps = rms_norm_eps
        self.position_embeddings = position_embeddings

        # The embedding layer only exists when the feature is switched on.
        if position_embeddings:
            self.position_embeddings_layer = PositionEmbeddingModule(
                embedding_size,
                rms_norm_eps,
            )

        decoder_stack = []
        for depth in range(n_layers):
            decoder_stack.append(
                LlamaDecoderLayer(
                    embedding_size=embedding_size,
                    n_heads=n_heads,
                    intermediate_size=intermediate_size,
                    layer_idx=depth,
                    act_fn=act_fn,
                    rms_norm_eps=rms_norm_eps,
                )
            )
        self.base_layers = nn.ModuleList(decoder_stack)
        self.norm = LlamaRMSNorm(embedding_size, eps=rms_norm_eps)

    def forward(
        self,
        inputs_embeds: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        num_outputs: int = 1,
        cache: KVCache | None = None,
        use_cache: bool | None = None,
    ) -> torch.Tensor:
        """Encode ``inputs_embeds`` with the trunk.

        Args:
            inputs_embeds: Input embeddings.
            attention_mask: Optional mask handed to every decoder layer.
            num_outputs: Forwarded to the position-embedding layer when it is
                enabled.
            cache: Optional container indexed with ``str(layer_idx)``.
            use_cache: Forwarded to every decoder layer.

        Returns:
            The normalised output of the last decoder layer.
        """
        states = inputs_embeds
        if self.position_embeddings:
            states = self.position_embeddings_layer(states, num_outputs)

        for block in self.base_layers:
            layer_cache = None if cache is None else cache[str(block.layer_idx)]
            states = block(
                states,
                attention_mask=attention_mask,
                cache=layer_cache,
                use_cache=use_cache,
            )

        return self.norm(states)


class ShortCircuitPolicyModule(nn.Module):
    """Bundle of four independent policy layers evaluated on the same input.

    Each policy layer receives its own block of ``n_policy_layers`` layer
    indices, starting at ``layer_idx_offset``.  ``forward`` always evaluates
    the layers one after another; ``ensemble_forward`` is a ``vmap``-based
    alternative that has to be enabled with :meth:`prepare_ensemble` and is
    not called by ``forward``.
    """

    def __init__(
        self,
        embedding_size: int,
        n_heads: int,
        n_policy_layers: int,
        intermediate_size: int,
        act_fn: Callable,
        layer_idx_offset: int = 0,
        num_key_value_heads: int | None = None,
        dropout: float = 0.0,
        attention_bias: bool = False,
        rms_norm_eps: float = 1e-6,
    ) -> None:
        super().__init__()
        self.embedding_size = embedding_size
        self.intermediate_size = intermediate_size
        self.n_heads = n_heads
        self.n_policy_layers = n_policy_layers
        self.act_fn = act_fn
        self.layer_idx_offset = layer_idx_offset
        self.num_key_value_heads = num_key_value_heads
        self.dropout = dropout

        # Four policy layers, each owning n_policy_layers consecutive indices.
        first_idx = self.layer_idx_offset
        stop_idx = self.layer_idx_offset + 4 * self.n_policy_layers
        policies = []
        for start_idx in range(first_idx, stop_idx, n_policy_layers):
            policies.append(
                ShortCircuitPolicyLayer(
                    embedding_size=self.embedding_size,
                    n_heads=self.n_heads,
                    n_layers=self.n_policy_layers,
                    intermediate_size=self.intermediate_size,
                    layer_idx_offset=start_idx,
                    act_fn=act_fn,
                    rms_norm_eps=rms_norm_eps,
                )
            )
        self.policy_layers = torch.nn.ModuleList(policies)
        self.ensemble_ready = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        num_outputs: int = 1,
        cache: KVCache | None = None,
        use_cache: bool | None = None,
    ) -> torch.Tensor:
        """Evaluate every policy layer and stack the results along dim 1.

        The layers are run sequentially regardless of ``ensemble_ready``.
        """
        per_policy = []
        for policy in self.policy_layers:
            per_policy.append(
                policy(
                    hidden_states=hidden_states,
                    attention_mask=attention_mask,
                    num_outputs=num_outputs,
                    cache=cache,
                    use_cache=use_cache,
                )
            )
        return torch.stack(per_policy, dim=1)

    def prepare_ensemble(self):
        """Build the state needed by :meth:`ensemble_forward`.

        Stores a copy of the first policy layer on the ``meta`` device as a
        stateless template and stacks the parameters and buffers of all policy
        layers.
        """
        template = copy.deepcopy(self.policy_layers[0])
        self.meta_policy_layer = template.to("meta")
        self.params, self.buffs = stack_module_state(self.policy_layers)  # type: ignore
        self.ensemble_ready = True

    def ensemble_forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        num_outputs: int = 1,
        cache: KVCache | None = None,
        use_cache: bool | None = None,
    ) -> torch.Tensor:
        """Evaluate all policy layers at once through ``vmap``.

        Requires :meth:`prepare_ensemble` to have been called.  ``cache`` and
        ``use_cache`` are accepted but not forwarded to the layers.
        """

        def run_one(layer_params, layer_buffers, states, mask, n_out):
            # Evaluate the template with one layer's slice of the state.
            return functional_call(
                self.meta_policy_layer,  # type: ignore
                (layer_params, layer_buffers),
                (states, mask, n_out),
            )

        # Map over the stacked parameters/buffers only; the inputs are shared.
        batched = vmap(run_one, in_dims=(0, 0, None, None, None), out_dims=1)
        return batched(
            self.params, self.buffs, hidden_states, attention_mask, num_outputs
        )


class ShortCircuitPolicyLayer(nn.Module):
    """One policy stack: ``n_layers - 1`` decoder layers, a norm, an action head.

    The decoder layers take the indices ``layer_idx_offset`` up to
    ``layer_idx_offset + n_layers - 2``; the action head takes the next one.
    """

    def __init__(
        self,
        embedding_size: int,
        n_heads: int,
        n_layers: int,
        intermediate_size: int,
        act_fn: Callable,
        rms_norm_eps: float = 1e-6,
        layer_idx_offset: int = 0,
    ) -> None:
        super().__init__()
        self.hidden_size = embedding_size
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.intermediate_size = intermediate_size
        self.act_fn = act_fn
        self.rms_norm_eps = rms_norm_eps

        # Index of the action head; decoder layers use the ones before it.
        head_idx = layer_idx_offset + self.n_layers - 1

        self.layers = nn.ModuleList(
            LlamaDecoderLayer(
                embedding_size=self.hidden_size,
                n_heads=self.n_heads,
                layer_idx=idx,
                intermediate_size=self.intermediate_size,
                act_fn=self.act_fn,
                rms_norm_eps=self.rms_norm_eps,
            )
            for idx in range(layer_idx_offset, head_idx)
        )
        self.norm = LlamaRMSNorm(self.hidden_size, eps=self.rms_norm_eps)
        self.action_layer = ShortCircuitActionHead(
            self.hidden_size, self.n_heads, head_idx
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        num_outputs: int = 1,
        cache: Optional[KVCache] = None,
        use_cache: Optional[bool] = False,
    ) -> torch.FloatTensor:
        """Run the decoder layers, normalise, and apply the action head.

        Args:
            hidden_states: Input hidden states.
            attention_mask: Optional mask handed to every sub-layer.
            num_outputs: Forwarded to the action head.
            cache: Optional container indexed with ``str(layer_idx)``.
            use_cache: Forwarded to every sub-layer.

        Returns:
            The output of the action head.
        """

        def entry_for(layer_idx):
            # Per-layer cache entry, or None when no cache was supplied.
            return None if cache is None else cache[str(layer_idx)]

        states = hidden_states
        for block in self.layers:
            states = block(
                states,
                attention_mask=attention_mask,
                cache=entry_for(block.layer_idx),
                use_cache=use_cache,
            )

        states = self.norm(states)

        return self.action_layer(
            states,
            attention_mask=attention_mask,
            num_outputs=num_outputs,
            cache=entry_for(self.action_layer.layer_idx),
            use_cache=use_cache,
        )


class ShortCircuitActionHead(nn.Module):
    """Query/key projection head returning head-averaged scaled dot products.

    Only ``q_proj`` and ``k_proj`` exist; there is no value or output
    projection and no softmax.
    """

    def __init__(
        self,
        embedding_size: int,
        n_heads: int,
        layer_idx: Optional[int] = None,
        num_key_value_heads: Optional[int] = None,
        dropout: float = 0.0,
        attention_bias: bool = False,
    ):
        super().__init__()
        self.layer_idx = layer_idx
        self.attention_dropout = dropout
        self.hidden_size = embedding_size
        self.num_heads = n_heads
        self.head_dim = self.hidden_size // self.num_heads
        # Key/value heads default to one per query head.
        self.num_key_value_heads = (
            self.num_heads if num_key_value_heads is None else num_key_value_heads
        )
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads

        if self.hidden_size != self.head_dim * self.num_heads:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )

        query_width = self.num_heads * self.head_dim
        key_width = self.num_key_value_heads * self.head_dim
        self.q_proj = nn.Linear(self.hidden_size, query_width, bias=attention_bias)
        self.k_proj = nn.Linear(self.hidden_size, key_width, bias=attention_bias)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        num_outputs: int = 1,
        cache: Optional[LayerKVCache] = None,
        use_cache: bool = False,
    ) -> torch.Tensor:
        """Compute scaled query-key scores averaged over the heads.

        When the sequence holds more than one position, the first
        ``num_outputs`` positions are dropped before projecting.
        ``attention_mask``, ``cache`` and ``use_cache`` are accepted but not
        used here.

        Args:
            hidden_states: Three-dimensional input; the first two sizes are
                read as batch and sequence length.
            num_outputs: Number of leading positions to drop.

        Returns:
            ``QK^T / sqrt(head_dim)`` averaged over the head dimension.
        """
        batch, seq_len, _ = hidden_states.size()

        # Strip the leading output positions unless a single position came in.
        if seq_len > 1:
            hidden_states = hidden_states[:, num_outputs:, :]
            seq_len = seq_len - num_outputs

        queries = self.q_proj(hidden_states)
        keys = self.k_proj(hidden_states)

        # (batch, seq, heads, head_dim) -> (batch, heads, seq, head_dim)
        queries = queries.view(
            batch, seq_len, self.num_heads, self.head_dim
        ).transpose(1, 2)
        keys = keys.view(
            batch, seq_len, self.num_key_value_heads, self.head_dim
        ).transpose(1, 2)

        keys = repeat_kv(keys, self.num_key_value_groups)

        scores = torch.matmul(queries, keys.transpose(2, 3))
        scores = scores / math.sqrt(self.head_dim)
        return scores.mean(1)


class ShortCircuitValueModule(nn.Module):
    """Value branch: a stack of decoder layers followed by the value head.

    ``hidden_module`` refines the trunk output and ``value_head`` reduces it
    to one scalar per output position.

    Args:
        embedding_size: Width of the hidden states.
        n_heads: Number of attention heads.
        n_value_layers: Depth parameter forwarded to both sub-modules.
        intermediate_size: Inner width of the gated MLPs.
        act_fn: Activation callable forwarded to both sub-modules.
        layer_idx_offset: First layer index used by ``hidden_module``.
        num_key_value_heads: Number of key/value heads; defaults to
            ``n_heads``.
        dropout: Forwarded to both sub-modules.
        attention_bias: Forwarded to both sub-modules.
        rms_norm_eps: Epsilon forwarded to the RMS normalisation layers.
    """

    def __init__(
        self,
        embedding_size: int,
        n_heads: int,
        n_value_layers: int,
        intermediate_size: int,
        act_fn: Callable,
        layer_idx_offset: int = 0,
        num_key_value_heads: int | None = None,
        dropout: float = 0.0,
        attention_bias: bool = False,
        rms_norm_eps: float = 1e-6,
    ) -> None:
        super().__init__()
        self.attention_dropout = dropout
        self.embedding_size = embedding_size
        self.n_heads = n_heads
        self.hidden_size = embedding_size
        self.num_heads = n_heads
        self.intermediate_size = intermediate_size
        self.head_dim = self.hidden_size // self.num_heads
        self.layer_idx_offset = layer_idx_offset
        self.act_fn = act_fn
        if num_key_value_heads is None:
            num_key_value_heads = self.num_heads
        self.num_key_value_heads = num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.n_value_layers = n_value_layers

        self.hidden_module = ShortCircuitHiddenValueModule(
            embedding_size=self.embedding_size,
            n_heads=self.n_heads,
            n_value_layers=self.n_value_layers,
            intermediate_size=self.intermediate_size,
            act_fn=self.act_fn,
            layer_idx_offset=self.layer_idx_offset,
            num_key_value_heads=self.num_key_value_heads,
            dropout=self.attention_dropout,
            attention_bias=attention_bias,
            rms_norm_eps=rms_norm_eps,
        )

        self.value_head = ShortCircuitValueHead(
            embedding_size=self.embedding_size,
            n_heads=self.n_heads,
            n_value_layers=self.n_value_layers,
            intermediate_size=self.intermediate_size,
            act_fn=self.act_fn,
            layer_idx_offset=self.layer_idx_offset + 4 * self.n_value_layers,
            num_key_value_heads=self.num_key_value_heads,
            dropout=self.attention_dropout,
            attention_bias=attention_bias,
            rms_norm_eps=rms_norm_eps,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        num_outputs: int = 1,
        cache: Optional[KVCache] = None,
        use_cache: bool = False,
    ) -> torch.Tensor:
        """Compute the value estimate for the leading output positions.

        Args:
            hidden_states: Encoded sequence coming from the trunk.
            attention_mask: Optional mask applied in the decoder layers and
                in the value head.
            num_outputs: Number of leading output (target) positions.
            cache: Forwarded to the value head only.
            use_cache: Forwarded to the value head only.

        Returns:
            The value head output, one scalar per output position.
        """
        # The mask has to reach the decoder layers as well; previously it was
        # dropped here, so these layers attended to every position while the
        # value head itself was masked.  The cache is intentionally not
        # forwarded: it would be indexed with this stack's layer indices,
        # which the caller's cache is not known to contain.
        hidden_states = self.hidden_module(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
        )

        value = self.value_head(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            num_outputs=num_outputs,
            cache=cache,
            use_cache=use_cache,
        )

        return value


class ShortCircuitHiddenValueModule(nn.Module):
    """Decoder stack of the value branch followed by an RMS norm.

    Holds ``n_value_layers - 1`` decoder layers whose indices start at
    ``layer_idx_offset``.
    """

    def __init__(
        self,
        embedding_size: int,
        n_heads: int,
        n_value_layers: int,
        intermediate_size: int,
        act_fn: Callable,
        layer_idx_offset: int = 0,
        num_key_value_heads: int | None = None,
        dropout: float = 0.0,
        attention_bias: bool = False,
        rms_norm_eps: float = 1e-6,
    ) -> None:
        super().__init__()
        self.attention_dropout = dropout
        # Both naming schemes are kept as attributes.
        self.embedding_size = self.hidden_size = embedding_size
        self.n_heads = self.num_heads = n_heads
        self.intermediate_size = intermediate_size
        self.head_dim = self.hidden_size // self.num_heads
        self.layer_idx_offset = layer_idx_offset
        self.act_fn = act_fn
        self.num_key_value_heads = (
            self.num_heads if num_key_value_heads is None else num_key_value_heads
        )
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.n_value_layers = n_value_layers

        if self.hidden_size != self.head_dim * self.num_heads:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )

        first_idx = self.layer_idx_offset
        stop_idx = self.layer_idx_offset + self.n_value_layers - 1
        stack = []
        for idx in range(first_idx, stop_idx):
            stack.append(
                LlamaDecoderLayer(
                    embedding_size=self.embedding_size,
                    n_heads=self.n_heads,
                    layer_idx=idx,
                    intermediate_size=self.intermediate_size,
                    act_fn=act_fn,
                    rms_norm_eps=rms_norm_eps,
                )
            )
        self.layers = nn.ModuleList(stack)
        self.norm = LlamaRMSNorm(self.hidden_size, eps=rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        cache: Optional[KVCache] = None,
        use_cache: Optional[bool] = False,
    ) -> torch.Tensor:
        """Apply the decoder layers in order, then the final norm.

        Args:
            hidden_states: Input hidden states.
            attention_mask: Optional mask handed to every decoder layer.
            cache: Optional container indexed with ``str(layer_idx)``.
            use_cache: Forwarded to every decoder layer.

        Returns:
            The normalised output of the last decoder layer.
        """
        states = hidden_states
        for block in self.layers:
            layer_cache = None if cache is None else cache[str(block.layer_idx)]
            states = block(
                states,
                attention_mask=attention_mask,
                cache=layer_cache,
                use_cache=use_cache,
            )

        return self.norm(states)


class ShortCircuitValueHead(nn.Module):
    """Cross-attention from the output positions to the remaining positions.

    The first ``num_outputs`` positions act as queries, the other positions
    as keys and values.  The attended result goes through a residual MLP
    block and a final linear layer with a single output feature.
    """

    def __init__(
        self,
        embedding_size: int,
        n_heads: int,
        n_value_layers: int,
        intermediate_size: int,
        act_fn: Callable,
        layer_idx_offset: int = 0,
        num_key_value_heads: int | None = None,
        dropout: float = 0.0,
        attention_bias: bool = False,
        rms_norm_eps: float = 1e-6,
    ) -> None:
        super().__init__()
        self.attention_dropout = dropout
        # Both naming schemes are kept as attributes.
        self.embedding_size = self.hidden_size = embedding_size
        self.n_heads = self.num_heads = n_heads
        self.intermediate_size = intermediate_size
        self.head_dim = self.hidden_size // self.num_heads
        self.layer_idx_offset = layer_idx_offset
        self.act_fn = act_fn
        self.num_key_value_heads = (
            self.num_heads if num_key_value_heads is None else num_key_value_heads
        )
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.n_value_layers = n_value_layers

        self.mlp = LlamaMLP(self.hidden_size, intermediate_size, act_fn=act_fn)
        self.input_layernorm = LlamaRMSNorm(self.hidden_size, eps=rms_norm_eps)
        self.post_attention_layernorm = LlamaRMSNorm(self.hidden_size, eps=rms_norm_eps)

        query_width = self.num_heads * self.head_dim
        kv_width = self.num_key_value_heads * self.head_dim
        self.q_proj = nn.Linear(self.hidden_size, query_width, bias=attention_bias)
        self.k_proj = nn.Linear(self.hidden_size, kv_width, bias=attention_bias)
        self.v_proj = nn.Linear(self.hidden_size, kv_width, bias=attention_bias)
        self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=attention_bias)
        self.value_layer = nn.Linear(self.hidden_size, 1)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        num_outputs: int = 1,
        cache: Optional[LayerKVCache] = None,
        use_cache: Optional[bool] = False,
    ) -> torch.Tensor:
        """Estimate one value per output position.

        ``cache`` and ``use_cache`` are accepted but not used here.

        Args:
            hidden_states: Three-dimensional input whose first
                ``num_outputs`` positions are the output positions.
            attention_mask: Optional four-dimensional mask; it is cropped to
                the output rows and the non-output columns.
            num_outputs: Number of leading output positions.

        Returns:
            Tensor of shape ``(batch, num_outputs)``.
        """
        # TODO: Check if this is correct
        batch, _, _ = hidden_states.size()

        # Residual branch: un-normalised output positions.
        skip = hidden_states[:, :num_outputs, :].view(
            batch, num_outputs, self.embedding_size
        )

        normed = self.input_layernorm(hidden_states)
        target_part = normed[:, :num_outputs, :]
        node_part = normed[:, num_outputs:, :]

        q = self.q_proj(target_part)
        k = self.k_proj(node_part)
        v = self.v_proj(node_part)

        # (batch, seq, heads, head_dim) -> (batch, heads, seq, head_dim)
        q = q.view(batch, num_outputs, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(
            batch, k.shape[1], self.num_key_value_heads, self.head_dim
        ).transpose(1, 2)
        v = v.view(
            batch, v.shape[1], self.num_key_value_heads, self.head_dim
        ).transpose(1, 2)

        # Keep only the output rows and the non-output columns of the mask.
        if attention_mask is not None:
            attention_mask = attention_mask[:, :, :num_outputs, num_outputs:]

        k = repeat_kv(k, self.num_key_value_groups)
        v = repeat_kv(v, self.num_key_value_groups)

        drop_p = self.attention_dropout if self.training else 0.0
        attended = F.scaled_dot_product_attention(
            q,
            k,
            v,
            attn_mask=attention_mask,
            dropout_p=drop_p,
            is_causal=False,
        )

        # Merge the heads back into the hidden dimension.
        attended = attended.transpose(1, 2).contiguous()
        attended = attended.view(batch, num_outputs, self.hidden_size)

        # First residual connection (attention), then the MLP block.
        after_attn = self.o_proj(attended) + skip
        mlp_out = self.mlp(self.post_attention_layernorm(after_attn))
        combined = after_attn + mlp_out

        scores = self.value_layer(self.act_fn(combined))
        return scores.view(batch, num_outputs)


class LlamaDecoderLayer(nn.Module):
    """Pre-norm decoder block: self-attention and gated MLP, each residual.

    Args:
        embedding_size: Width of the hidden states.
        n_heads: Number of attention heads.
        layer_idx: Index of this layer; stored and passed to the attention
            module.
        intermediate_size: Inner width of the gated MLP.
        act_fn: Activation callable used by the MLP.
        rms_norm_eps: Epsilon of both RMS normalisation layers.
        num_key_value_heads: Number of key/value heads of the attention
            module; ``None`` means one per query head.
        dropout: Attention dropout probability.
        attention_bias: Whether the attention projections carry a bias.
    """

    def __init__(
        self,
        embedding_size: int,
        n_heads: int,
        layer_idx: int,
        intermediate_size: int,
        act_fn: Callable,
        rms_norm_eps: float = 1e-6,
        num_key_value_heads: Optional[int] = None,
        dropout: float = 0.0,
        attention_bias: bool = False,
    ) -> None:
        super().__init__()
        self.hidden_size = embedding_size
        self.intermediate_size = intermediate_size
        self.n_heads = n_heads
        self.layer_idx = layer_idx

        # The three trailing options default to the values LlamaAttention
        # uses on its own, so existing call sites build the same module.
        self.self_attn = LlamaAttention(
            self.hidden_size,
            self.n_heads,
            layer_idx=layer_idx,
            num_key_value_heads=num_key_value_heads,
            dropout=dropout,
            attention_bias=attention_bias,
        )

        self.mlp = LlamaMLP(self.hidden_size, self.intermediate_size, act_fn=act_fn)
        self.input_layernorm = LlamaRMSNorm(self.hidden_size, eps=rms_norm_eps)
        self.post_attention_layernorm = LlamaRMSNorm(self.hidden_size, eps=rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        cache: Optional[LayerKVCache] = None,
        use_cache: Optional[bool] = False,
        **kwargs,
    ) -> torch.Tensor:
        """Apply the attention and MLP sub-blocks with residual connections.

        Args:
            hidden_states: Input of shape ``(batch, seq_len, embed_dim)``.
            attention_mask: Optional mask handed unchanged to the attention
                module.
            cache: Optional per-layer cache handed to the attention module.
            use_cache: Handed to the attention module.
            **kwargs: Passed through to the attention module's ``forward``;
                that method accepts no extra keywords, so anything supplied
                here raises ``TypeError``.

        Returns:
            Hidden states with the same shape as the input.
        """
        residual = hidden_states

        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
        hidden_states = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            cache=cache,
            use_cache=use_cache,
            **kwargs,
        )
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states


class LlamaAttention(nn.Module):
    """Multi-head self-attention built on ``scaled_dot_product_attention``.

    Keys and values are projected to ``num_key_value_heads`` heads and passed
    through ``repeat_kv`` with ``num_heads // num_key_value_heads`` before the
    attention call.
    """

    def __init__(
        self,
        embedding_size: int,
        n_heads: int,
        layer_idx: Optional[int] = None,
        num_key_value_heads: Optional[int] = None,
        dropout: float = 0.0,
        attention_bias: bool = False,
    ):
        super().__init__()
        self.layer_idx = layer_idx
        self.attention_dropout = dropout
        self.hidden_size = embedding_size
        self.num_heads = n_heads
        self.head_dim = self.hidden_size // self.num_heads
        # Key/value heads default to one per query head.
        self.num_key_value_heads = (
            self.num_heads if num_key_value_heads is None else num_key_value_heads
        )
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads

        if self.hidden_size != self.head_dim * self.num_heads:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )

        query_width = self.num_heads * self.head_dim
        kv_width = self.num_key_value_heads * self.head_dim
        self.q_proj = nn.Linear(self.hidden_size, query_width, bias=attention_bias)
        self.k_proj = nn.Linear(self.hidden_size, kv_width, bias=attention_bias)
        self.v_proj = nn.Linear(self.hidden_size, kv_width, bias=attention_bias)
        self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=attention_bias)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        cache: Optional[LayerKVCache] = None,
        use_cache: bool = False,
    ) -> torch.Tensor:
        """Attend over ``hidden_states`` and project the result.

        ``cache`` and ``use_cache`` are accepted but not used here.

        Args:
            hidden_states: Three-dimensional input; the first two sizes are
                read as batch and sequence length.
            attention_mask: Optional mask passed as ``attn_mask`` to the
                attention call.

        Returns:
            Tensor of shape ``(batch, seq_len, hidden_size)``.
        """
        batch, seq_len, _ = hidden_states.size()

        def split_heads(projected: torch.Tensor, n_heads: int) -> torch.Tensor:
            # (batch, seq, n_heads * head_dim) -> (batch, n_heads, seq, head_dim)
            return projected.view(batch, seq_len, n_heads, self.head_dim).transpose(
                1, 2
            )

        q = self.q_proj(hidden_states)
        k = self.k_proj(hidden_states)
        v = self.v_proj(hidden_states)

        q = split_heads(q, self.num_heads)
        k = split_heads(k, self.num_key_value_heads)
        v = split_heads(v, self.num_key_value_heads)

        k = repeat_kv(k, self.num_key_value_groups)
        v = repeat_kv(v, self.num_key_value_groups)

        # Dropout is only active in training mode.
        drop_p = self.attention_dropout if self.training else 0.0
        context = F.scaled_dot_product_attention(
            q,
            k,
            v,
            attn_mask=attention_mask,
            dropout_p=drop_p,
            is_causal=False,
        )

        # Merge the heads back into the hidden dimension.
        context = context.transpose(1, 2).contiguous()
        context = context.view(batch, seq_len, self.hidden_size)

        return self.o_proj(context)


class LlamaMLP(nn.Module):
    """Gated feed-forward block: ``down(act(gate(x)) * up(x))``.

    All three projections are bias-free linear layers.
    """

    def __init__(self, embedding_size: int, intermediate_size: int, act_fn: Callable):
        super().__init__()
        self.hidden_size = embedding_size
        self.intermediate_size = intermediate_size

        outer, inner = self.hidden_size, self.intermediate_size
        self.gate_proj = nn.Linear(outer, inner, bias=False)
        self.up_proj = nn.Linear(outer, inner, bias=False)
        self.down_proj = nn.Linear(inner, outer, bias=False)
        self.act_fn = act_fn

    def forward(self, x):
        """Return the gated projection of ``x``; the last dimension is kept."""
        gate = self.act_fn(self.gate_proj(x))
        lifted = self.up_proj(x)
        return self.down_proj(gate * lifted)


class LlamaRMSNorm(nn.Module):
    """Root-mean-square normalisation with a learnable per-feature scale."""

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        """
        LlamaRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Normalise the last dimension by its RMS and apply the scale.

        The statistics are computed in float32; the normalised values are
        cast back to the input dtype before being multiplied by ``weight``.
        """
        original_dtype = hidden_states.dtype
        upcast = hidden_states.to(torch.float32)
        mean_square = upcast.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.variance_epsilon)
        normalised = (upcast * inv_rms).to(original_dtype)
        return self.weight * normalised


class PositionEmbeddingModule(nn.Module):
    """A Module responsible for injecting learnable position embeddings
    to the nodes and the target.

    Two embedding vectors are learned: ``target_embedding`` is added to the
    first ``num_outputs`` positions and ``node_embedding`` to all remaining
    positions.  The sum is passed through an RMS normalisation layer.
    """

    def __init__(self, embedding_size: int, rms_norm_eps: float = 1e-6):
        super().__init__()
        self.embedding_size = embedding_size
        self.target_embedding = nn.Parameter(torch.empty(embedding_size))
        self.node_embedding = nn.Parameter(torch.empty(embedding_size))
        # xavier_uniform_ needs at least two dimensions, so each vector is
        # initialised through a (1, embedding_size) view of itself.
        torch.nn.init.xavier_uniform_(self.target_embedding.unsqueeze(0))
        torch.nn.init.xavier_uniform_(self.node_embedding.unsqueeze(0))
        self.layer_norm = LlamaRMSNorm(self.embedding_size, rms_norm_eps)

    def forward(
        self,
        inputs_embeds: torch.Tensor,
        num_outputs: int = 1,
    ):
        """Add the target/node embeddings and normalise the result.

        Args:
            inputs_embeds: Embeddings whose second-to-last dimension indexes
                positions and whose last dimension is the feature dimension.
                The tensor is left untouched.
            num_outputs: Number of leading positions treated as targets.

        Returns:
            The normalised embeddings.
        """
        # Work on a copy: adding in place to the argument would overwrite the
        # caller's tensor and raises for a leaf tensor that requires grad.
        shifted = inputs_embeds.clone()
        shifted[..., :num_outputs, :].add_(self.target_embedding)
        shifted[..., num_outputs:, :].add_(self.node_embedding)

        position_embeds = self.layer_norm(shifted)

        return position_embeds
