from __future__ import annotations
import copy
import math
import os
from pathlib import Path
from typing import Callable, List, Optional, Tuple
import torch
from torch.nn.attention import SDPBackend, sdpa_kernel
from torch import nn
from torch.func import functional_call, stack_module_state
import yaml  # type: ignore

from .layers import LlamaDecoderLayer, LlamaMLP, LlamaRMSNorm

from .utils import ActionHeadKVCache, combine_masks, KVCache, LayerKVCache, repeat_kv
from .vae import VAEConfig, VAEEncoder


class ShortCircuit3(nn.Module):
    """Policy/value network made of a shared trunk and separate heads.

    Components, in registration order:

    * ``truth_table_encoder`` (optional) - applied to the raw inputs in
      :meth:`forward` and :meth:`create_embeddings`.
    * ``hidden_module`` - a :class:`ShortCircuit3HiddenModule` trunk.
    * ``policy_head`` - a :class:`ShortCircuitPolicyModule`.
    * ``value_head`` - a :class:`ShortCircuitValueModule`, created only when
      ``create_value_head`` is true.

    Once everything is built, :meth:`_init_weights` is applied to every
    registered submodule via ``nn.Module.apply``.

    Args:
        embedding_size: Width of the hidden representation.
        n_heads: Number of attention heads.
        n_layers: Number of decoder layers in the trunk.
        intermediate_size: Width of the MLP inside each decoder layer.
        n_policy_layers: Layer budget handed to each policy ensemble member.
        n_value_layers: Layer budget handed to the value module.
        act_fn: Activation callable forwarded to the sub-modules.
        num_key_value_heads: Number of key/value heads; defaults to
            ``n_heads`` when ``None``.
        dropout: Dropout probability forwarded to the sub-modules.
        is_causal: Stored on the instance; not read anywhere in this class.
        attention_bias: Forwarded to the sub-modules.
        position_embeddings: Forwarded to the trunk.
        rms_norm_eps: Epsilon forwarded to the normalisation layers.
        create_value_head: Whether to build ``value_head``.
        truth_table_encoder: Callable that turns raw inputs into embeddings.
            Required by :meth:`forward` and :meth:`create_embeddings`.
    """

    def __init__(
        self,
        embedding_size: int,
        n_heads: int,
        n_layers: int,
        intermediate_size: int,
        n_policy_layers: int,
        n_value_layers: int,
        act_fn: Callable,
        num_key_value_heads: int | None = None,
        dropout: float = 0.0,
        is_causal: bool = True,
        attention_bias: bool = False,
        position_embeddings: bool = False,
        rms_norm_eps: float = 1e-6,
        create_value_head: bool = True,
        truth_table_encoder: Optional[TruthTableEncoder] = None,
    ) -> None:
        super(ShortCircuit3, self).__init__()
        self.embedding_size = embedding_size
        self.intermediate_size = intermediate_size
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.n_policy_layers = n_policy_layers
        self.n_value_layers = n_value_layers
        self.act_fn = act_fn
        self.is_causal = is_causal
        self.position_embeddings = position_embeddings
        # Assigned before the other sub-modules so that, when it is an
        # nn.Module, it keeps its position as the first registered child.
        self.truth_table_encoder = truth_table_encoder
        if num_key_value_heads is None:
            num_key_value_heads = self.n_heads
        self.num_key_value_heads = num_key_value_heads
        self.head_dim = self.embedding_size // self.n_heads

        self.hidden_module = ShortCircuit3HiddenModule(
            embedding_size=self.embedding_size,
            n_heads=self.n_heads,
            n_layers=self.n_layers,
            intermediate_size=self.intermediate_size,
            n_policy_layers=self.n_policy_layers,
            act_fn=self.act_fn,
            num_key_value_heads=num_key_value_heads,
            dropout=dropout,
            attention_bias=attention_bias,
            position_embeddings=position_embeddings,
            rms_norm_eps=rms_norm_eps,
        )

        # Policy layer indices start right after the trunk's n_layers.
        self.policy_head = ShortCircuitPolicyModule(
            embedding_size=self.embedding_size,
            n_heads=self.n_heads,
            n_policy_layers=self.n_policy_layers,
            intermediate_size=self.intermediate_size,
            act_fn=self.act_fn,
            layer_idx_offset=self.n_layers,
            dropout=dropout,
            attention_bias=attention_bias,
            rms_norm_eps=rms_norm_eps,
        )

        if create_value_head:
            # The policy module reserves 4 * n_policy_layers layer indices, so
            # the value module starts numbering after them.
            self.value_head = ShortCircuitValueModule(
                embedding_size=self.embedding_size,
                n_heads=self.n_heads,
                n_value_layers=self.n_value_layers,
                intermediate_size=self.intermediate_size,
                act_fn=self.act_fn,
                layer_idx_offset=self.n_layers + 4 * self.n_policy_layers,
                dropout=dropout,
                attention_bias=attention_bias,
                rms_norm_eps=rms_norm_eps,
            )
        # NOTE: apply() visits every registered child, which includes
        # truth_table_encoder when it is an nn.Module.
        self.apply(self._init_weights)

    def _require_encoder(self) -> Callable:
        """Return the truth-table encoder or raise a descriptive ``TypeError``.

        ``TypeError`` is the exception calling ``None`` would have raised, so
        existing handlers keep working while the message names the cause.
        """
        encoder = self.truth_table_encoder
        if encoder is None:
            raise TypeError(
                "ShortCircuit3 has no truth_table_encoder: pass one to the "
                "constructor before calling forward() or create_embeddings()."
            )
        return encoder

    def forward(
        self,
        inputs_embeds: torch.Tensor,
        causal_mask: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        cache: KVCache | None = None,
        use_cache: bool | None = None,
        get_value: bool = False,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Encode the inputs, run the trunk and return policy (and value) outputs.

        Args:
            inputs_embeds: Raw inputs handed to ``truth_table_encoder``.
            causal_mask: Optional mask merged with ``attention_mask`` through
                ``combine_masks`` for the trunk and the policy head.
            attention_mask: Optional mask; the value head receives it combined
                with ``None`` instead of ``causal_mask``.
            cache: Optional key/value cache forwarded to every sub-module.
            use_cache: Forwarded to every sub-module.
            get_value: When true, also evaluate the value head.

        Returns:
            The policy head output, or ``(policy_output, value)`` when
            ``get_value`` is true.

        Raises:
            AttributeError: If ``get_value`` is true but the model was built
                with ``create_value_head=False``.
            TypeError: If no ``truth_table_encoder`` was supplied.
        """
        # Fail before doing any work instead of after the full policy pass.
        if get_value and getattr(self, "value_head", None) is None:
            raise AttributeError(
                "get_value=True requires a value head, but this ShortCircuit3 "
                "was built with create_value_head=False."
            )
        encoder = self._require_encoder()

        hidden_states = encoder(inputs_embeds)

        mask = combine_masks(causal_mask, attention_mask)

        hidden_states = self.hidden_module(
            inputs_embeds=hidden_states,
            attention_mask=mask,
            cache=cache,
            use_cache=use_cache,
        )

        actions = self.policy_head(
            hidden_states=hidden_states,
            attention_mask=mask,
            cache=cache,
            use_cache=use_cache,
        )

        if not get_value:
            return actions

        # The value head deliberately gets a mask built without causal_mask.
        value = self.value_head(
            hidden_states=hidden_states,
            attention_mask=combine_masks(None, attention_mask),
            cache=cache,
            use_cache=use_cache,
        )

        return actions, value

    def get_hidden_module(self) -> nn.Module:
        """Return the shared trunk."""
        return self.hidden_module

    def get_policy_head(self) -> nn.Module:
        """Return the policy head."""
        return self.policy_head

    def get_value_head(self) -> nn.Module:
        """Return the value head.

        Raises ``AttributeError`` when the model was built with
        ``create_value_head=False``.
        """
        return self.value_head

    def _init_weights(self, module: nn.Module, std: float = 0.02) -> None:
        """Initialise one submodule in place (callback for ``nn.Module.apply``).

        * ``nn.Linear``: Xavier-uniform weight, zero bias.
        * Any other module exposing a non-``None`` ``weight``: Xavier-uniform
          when the weight has more than one dimension, otherwise
          ``normal(0, std)``; a non-``None`` ``bias`` is zeroed.

        NOTE(review): 1-D weights (presumably the norm scales - confirm)
        are drawn from ``normal(0, std)`` rather than set to one.
        """
        if isinstance(module, nn.Linear):
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                module.bias.data.zero_()
        elif hasattr(module, "weight") and module.weight is not None:
            if module.weight.ndim > 1:  # type: ignore
                nn.init.xavier_uniform_(module.weight)  # type: ignore
            else:
                module.weight.data.normal_(mean=0.0, std=std)  # type: ignore
            if hasattr(module, "bias") and module.bias is not None:
                module.bias.data.zero_()  # type: ignore

    def create_embeddings(self, input_nodes: torch.Tensor, cache_node_embeddings: torch.Tensor | None = None) -> torch.Tensor:
        """Encode ``input_nodes``, reusing already computed embeddings.

        When ``cache_node_embeddings`` is given, only the entries of
        ``input_nodes`` beyond its length along dim ``-2`` are encoded and the
        result is concatenated after the cached embeddings along dim ``-2``.

        Raises:
            TypeError: If no ``truth_table_encoder`` was supplied.
        """
        encoder = self._require_encoder()
        if cache_node_embeddings is not None:
            n_cached = cache_node_embeddings.size(-2)
            return torch.cat(
                [
                    cache_node_embeddings,
                    encoder(input_nodes[:, n_cached:, :]),
                ],
                dim=-2,
            )
        return encoder(input_nodes)

    def get_action_logits(
        self,
        node_embeddings: torch.Tensor,
        causal_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Run trunk and policy head on precomputed embeddings, without a cache.

        ``causal_mask`` is passed unchanged as the attention mask of both the
        trunk and the policy head.
        """
        hidden_states = self.hidden_module(
            inputs_embeds=node_embeddings,
            attention_mask=causal_mask,
            cache=None,
            use_cache=False,
        )

        actions = self.policy_head(
            hidden_states=hidden_states,
            attention_mask=causal_mask,
            cache=None,
            use_cache=False,
        )
        return actions

class ShortCircuit3HiddenModule(nn.Module):
    """Shared trunk: position embeddings, a decoder stack and a final norm.

    The ``position_embeddings`` flag is stored but the position-embedding
    layer is always created and always applied. ``num_key_value_heads``,
    ``dropout`` and ``attention_bias`` are stored on the instance only; they
    are not forwarded to the decoder layers.
    """

    def __init__(
        self,
        embedding_size: int,
        n_heads: int,
        n_layers: int,
        intermediate_size: int,
        n_policy_layers: int,
        act_fn: Callable,
        num_key_value_heads: int | None = None,
        dropout: float = 0.0,
        attention_bias: bool = False,
        position_embeddings: bool = False,
        rms_norm_eps: float = 1e-6,
    ) -> None:
        super().__init__()

        # Mirror the configuration on the instance.
        self.embedding_size = embedding_size
        self.intermediate_size = intermediate_size
        self.n_heads = n_heads
        self.num_key_value_heads = num_key_value_heads
        self.n_layers = n_layers
        self.n_policy_layers = n_policy_layers
        self.act_fn = act_fn
        self.dropout = dropout
        self.attention_bias = attention_bias
        self.position_embeddings = position_embeddings
        self.rms_norm_eps = rms_norm_eps

        # Sub-modules are created in a fixed order: embeddings, stack, norm.
        self.position_embeddings_layer = PositionEmbeddingModule(
            self.embedding_size, self.rms_norm_eps
        )

        decoder_stack = []
        for idx in range(self.n_layers):
            decoder_stack.append(
                LlamaDecoderLayer(
                    embedding_size=self.embedding_size,
                    n_heads=self.n_heads,
                    intermediate_size=self.intermediate_size,
                    layer_idx=idx,
                    act_fn=act_fn,
                    rms_norm_eps=rms_norm_eps,
                )
            )
        self.base_layers = nn.ModuleList(decoder_stack)

        self.norm = LlamaRMSNorm(self.embedding_size, eps=self.rms_norm_eps)

    def forward(
        self,
        inputs_embeds: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        cache: KVCache | None = None,
        use_cache: bool | None = None,
    ) -> torch.Tensor:
        """Apply position embeddings, every decoder layer, then the final norm.

        When ``cache`` is given, each layer receives the entry stored under
        the string form of its ``layer_idx``; otherwise it receives ``None``.
        """
        states = self.position_embeddings_layer(inputs_embeds)

        for block in self.base_layers:
            block_cache = None
            if cache is not None:
                block_cache = cache[str(block.layer_idx)]
            states = block(
                states,
                attention_mask=attention_mask,
                cache=block_cache,
                use_cache=use_cache,
            )

        return self.norm(states)


class ShortCircuitPolicyModule(nn.Module):
    """Ensemble of four :class:`ShortCircuitPolicyLayer` members.

    Each member is given ``n_policy_layers`` consecutive layer indices,
    starting at ``layer_idx_offset``. :meth:`forward` evaluates the members
    one after another; :meth:`ensemble_forward` is an opt-in alternative that
    evaluates them together with ``torch.vmap``.

    ``num_key_value_heads``, ``dropout`` and ``attention_bias`` are accepted
    for signature compatibility; only the first two are stored and none of
    them is forwarded to the members.
    """

    def __init__(
        self,
        embedding_size: int,
        n_heads: int,
        n_policy_layers: int,
        intermediate_size: int,
        act_fn: Callable,
        layer_idx_offset: int = 0,
        num_key_value_heads: int | None = None,
        dropout: float = 0.0,
        attention_bias: bool = False,
        rms_norm_eps: float = 1e-6,
    ) -> None:
        super().__init__()
        self.embedding_size = embedding_size
        self.n_heads = n_heads
        self.n_policy_layers = n_policy_layers
        self.intermediate_size = intermediate_size
        self.act_fn = act_fn
        self.layer_idx_offset = layer_idx_offset
        self.num_key_value_heads = num_key_value_heads
        self.dropout = dropout
        # Offsets advance in steps of n_policy_layers, giving four members.
        self.policy_layers = torch.nn.ModuleList(
            [
                ShortCircuitPolicyLayer(
                    embedding_size=self.embedding_size,
                    n_heads=self.n_heads,
                    n_layers=self.n_policy_layers,
                    intermediate_size=self.intermediate_size,
                    layer_idx_offset=new_layer_idx_offset,
                    act_fn=act_fn,
                    rms_norm_eps=rms_norm_eps,
                )
                for new_layer_idx_offset in range(
                    self.layer_idx_offset,
                    self.layer_idx_offset + 4 * self.n_policy_layers,
                    n_policy_layers,
                )
            ]
        )
        # Lazily populated by ensemble_forward().
        self.meta_policy_layer = None
        self.params = None
        self.buffs = None

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        cache: KVCache | None = None,
        use_cache: bool | None = None,
    ) -> torch.Tensor:
        """Evaluate every ensemble member and stack the results along dim 1.

        All members receive the same ``hidden_states``, ``attention_mask``,
        ``cache`` and ``use_cache``.
        """
        member_outputs = [
            layer(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                cache=cache,
                use_cache=use_cache,
            )
            for layer in self.policy_layers
        ]
        return torch.stack(member_outputs, dim=1)

    def _construct_meta_layer(self):
        """Create the parameter-free template used by ``functional_call``.

        The template is stored with ``object.__setattr__`` so it is NOT
        registered as a submodule: its meta-device parameters must not show
        up in ``state_dict()`` or be touched by ``.to(device)``.
        """
        meta_layer = copy.deepcopy(self.policy_layers[0]).to("meta")
        object.__setattr__(self, "meta_policy_layer", meta_layer)

    def ensemble_forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        cache: KVCache | None = None,
        use_cache: bool | None = None,
    ) -> torch.Tensor:
        """Evaluate all members at once with ``torch.vmap``.

        The stacked parameters/buffers are mapped over their leading
        dimension while ``hidden_states`` and ``attention_mask`` are shared;
        member outputs are placed along dim 1, like :meth:`forward`.

        ``cache`` and ``use_cache`` are accepted for signature parity with
        :meth:`forward` but are not used on this path.

        NOTE(review): the stacked state is built once and cached. If
        ``stack_module_state`` returns copies rather than views (confirm
        against the torch.func docs), later updates to ``policy_layers``
        (training steps, ``load_state_dict``, device moves) are not reflected
        until ``self.params`` is reset to ``None``.
        """
        if self.meta_policy_layer is None:
            self._construct_meta_layer()
        if self.params is None:
            self.params, self.buffs = stack_module_state([*self.policy_layers])

        template = self.meta_policy_layer

        def _call_single_policy_layer(
            params, buffers, hidden_states, attention_mask
        ):
            return functional_call(
                template,  # type: ignore
                (params, buffers),
                (hidden_states, attention_mask),
            )

        # One in_dims entry per positional argument: map over the stacked
        # state, broadcast the activations and the mask.
        out = torch.vmap(
            _call_single_policy_layer, in_dims=(0, 0, None, None), out_dims=1
        )(self.params, self.buffs, hidden_states, attention_mask)

        return out


class ShortCircuitPolicyLayer(nn.Module):
    """One ensemble member: decoder layers, a norm and an action head.

    Of the ``n_layers`` indices assigned to this member, the first
    ``n_layers - 1`` go to ``LlamaDecoderLayer`` instances and the last one to
    the :class:`ShortCircuitActionHead`.
    """

    def __init__(
        self,
        embedding_size: int,
        n_heads: int,
        n_layers: int,
        intermediate_size: int,
        act_fn: Callable,
        rms_norm_eps: float = 1e-6,
        layer_idx_offset: int = 0,
    ) -> None:
        super().__init__()
        self.hidden_size = embedding_size
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.intermediate_size = intermediate_size
        self.act_fn = act_fn
        self.rms_norm_eps = rms_norm_eps

        # The action head takes the last index of this member's range.
        action_head_idx = layer_idx_offset + self.n_layers - 1

        decoder_layers = []
        for idx in range(layer_idx_offset, action_head_idx):
            decoder_layers.append(
                LlamaDecoderLayer(
                    embedding_size=self.hidden_size,
                    n_heads=self.n_heads,
                    layer_idx=idx,
                    intermediate_size=self.intermediate_size,
                    act_fn=self.act_fn,
                    rms_norm_eps=self.rms_norm_eps,
                )
            )
        self.layers = nn.ModuleList(decoder_layers)

        self.norm = LlamaRMSNorm(self.hidden_size, eps=self.rms_norm_eps)
        self.action_layer = ShortCircuitActionHead(
            self.hidden_size, self.n_heads, action_head_idx
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        cache: Optional[KVCache] = None,
        use_cache: Optional[bool] = False,
    ) -> torch.FloatTensor:
        """Run the decoder layers and the norm, then return the action head output.

        When ``cache`` is given, every layer (and the action head) receives
        the entry stored under the string form of its own ``layer_idx``.
        """
        have_cache = cache is not None

        for block in self.layers:
            block_cache = cache[str(block.layer_idx)] if have_cache else None
            hidden_states = block(
                hidden_states,
                attention_mask=attention_mask,
                cache=block_cache,
                use_cache=use_cache,
            )

        normed_states = self.norm(hidden_states)

        head_cache = None
        if have_cache:
            head_cache = cache[str(self.action_layer.layer_idx)]

        return self.action_layer(
            normed_states,
            attention_mask=attention_mask,
            cache=head_cache,
            use_cache=use_cache,
        )


class ShortCircuitActionHead(nn.Module):
    """Attention-score head: returns head-averaged scaled query/key scores.

    Only query and key projections exist; there is no value or output
    projection, and no softmax is applied to the scores.
    """

    def __init__(
        self,
        embedding_size: int,
        n_heads: int,
        layer_idx: Optional[int] = None,
        num_key_value_heads: Optional[int] = None,
        dropout: float = 0.0,
        attention_bias: bool = False,
    ):
        """Build the query/key projections.

        Args:
            embedding_size: Hidden width; must be divisible by ``n_heads``.
            n_heads: Number of query heads.
            layer_idx: Index stored as ``self.layer_idx``.
            num_key_value_heads: Number of key heads; defaults to ``n_heads``.
            dropout: Stored as ``attention_dropout``; not read in ``forward``.
            attention_bias: Whether the projections carry a bias term.

        Raises:
            ValueError: If ``embedding_size`` is not divisible by ``n_heads``.
        """
        super().__init__()
        self.layer_idx = layer_idx
        self.attention_dropout = dropout
        self.hidden_size = embedding_size
        self.num_heads = n_heads
        self.head_dim = self.hidden_size // self.num_heads
        if num_key_value_heads is None:
            num_key_value_heads = self.num_heads
        self.num_key_value_heads = num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )

        self.q_proj = nn.Linear(
            self.hidden_size, self.num_heads * self.head_dim, bias=attention_bias
        )
        self.k_proj = nn.Linear(
            self.hidden_size,
            self.num_key_value_heads * self.head_dim,
            bias=attention_bias,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        cache: Optional[LayerKVCache] = None,
        use_cache: bool = False,
        num_outputs: int = 1,
    ) -> torch.Tensor:
        """Return query/key scores averaged over heads.

        ``attention_mask`` is accepted but not used in this method. When the
        input holds more than one position, the first ``num_outputs``
        positions are dropped before projecting. The result is the scaled
        dot product of queries and keys with the head dimension averaged
        away, i.e. ``(batch, n_queries, n_keys)``.
        """

        # if cache is None: # We want to remove the output nodes from the attention
        # hidden_states = hidden_states[:, num_outputs:, :]

        bsz, q_len, _ = hidden_states.size()

        # Single-position inputs are kept whole; longer inputs lose their
        # leading num_outputs positions.
        if q_len > 1:
            hidden_states = hidden_states[:, num_outputs:, :]
            q_len -= num_outputs

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim)
        key_states = key_states.view(
            bsz, q_len, self.num_key_value_heads, self.head_dim
        )

        # NOTE(review): this branch looks broken as written and will raise if
        # reached: `self.layer_id` is never assigned in this class (__init__
        # sets `self.layer_idx`), and `value_states` is not defined anywhere
        # in this method. The ActionHeadKVCache API is not visible here, so
        # the intended calls need to be confirmed before fixing.
        if cache is not None and use_cache:
            if self.layer_id not in cache:
                layer_cache = ActionHeadKVCache(
                    key_states,
                )
                cache[self.layer_id] = layer_cache
            else:
                layer_cache = cache.get_layer_cache(self.layer_id)
                key_states, value_states = layer_cache.update_kvcache(key_states, value_states)

        # Move the head dimension ahead of the sequence dimension.
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)

        key_states = repeat_kv(key_states, self.num_key_value_groups)

        # Scaled dot-product scores, averaged over dim 1 (the head dimension).
        attn_weights = (
            torch.matmul(query_states, key_states.transpose(2, 3))
            / math.sqrt(self.head_dim)
        ).mean(1)

        # if cache is not None:
        #     attn_weights = cache.update_attention_weights(attn_weights)

        return attn_weights


class ShortCircuitValueModule(nn.Module):
    """Value branch: a decoder stack followed by a scalar value head.

    ``hidden_module`` is a :class:`ShortCircuitHiddenValueModule` and
    ``value_head`` a :class:`ShortCircuitValueHead`; the head's layer offset
    is ``layer_idx_offset + 4 * n_value_layers``.
    """

    def __init__(
        self,
        embedding_size: int,
        n_heads: int,
        n_value_layers: int,
        intermediate_size: int,
        act_fn: Callable,
        layer_idx_offset: int = 0,
        num_key_value_heads: int | None = None,
        dropout: float = 0.0,
        attention_bias: bool = False,
        rms_norm_eps: float = 1e-6,
    ) -> None:
        super().__init__()

        # Both naming schemes (embedding_size/hidden_size, n_heads/num_heads)
        # are kept as instance attributes.
        self.embedding_size = self.hidden_size = embedding_size
        self.n_heads = self.num_heads = n_heads
        self.intermediate_size = intermediate_size
        self.n_value_layers = n_value_layers
        self.layer_idx_offset = layer_idx_offset
        self.act_fn = act_fn
        self.attention_dropout = dropout
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = (
            self.num_heads if num_key_value_heads is None else num_key_value_heads
        )
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads

        # Arguments shared by the two sub-modules.
        common_kwargs = dict(
            embedding_size=self.embedding_size,
            n_heads=self.n_heads,
            n_value_layers=self.n_value_layers,
            intermediate_size=self.intermediate_size,
            act_fn=self.act_fn,
            num_key_value_heads=self.num_key_value_heads,
            dropout=self.attention_dropout,
            attention_bias=attention_bias,
            rms_norm_eps=rms_norm_eps,
        )

        self.hidden_module = ShortCircuitHiddenValueModule(
            layer_idx_offset=self.layer_idx_offset,
            **common_kwargs,
        )
        self.value_head = ShortCircuitValueHead(
            layer_idx_offset=self.layer_idx_offset + 4 * self.n_value_layers,
            **common_kwargs,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        cache: Optional[KVCache] = None,
        use_cache: bool = False,
    ) -> torch.Tensor:
        """Return the value head's output for ``hidden_states``.

        The decoder stack is called with ``hidden_states`` only; the mask,
        cache and ``use_cache`` flag reach the value head but not the stack.
        """
        stack_output = self.hidden_module(hidden_states=hidden_states)

        return self.value_head(
            hidden_states=stack_output,
            attention_mask=attention_mask,
            cache=cache,
            use_cache=use_cache,
        )


class ShortCircuitHiddenValueModule(nn.Module):
    """Decoder stack of the value branch, followed by a final norm.

    ``n_value_layers - 1`` ``LlamaDecoderLayer`` instances are created, with
    indices starting at ``layer_idx_offset``.
    """

    def __init__(
        self,
        embedding_size: int,
        n_heads: int,
        n_value_layers: int,
        intermediate_size: int,
        act_fn: Callable,
        layer_idx_offset: int = 0,
        num_key_value_heads: int | None = None,
        dropout: float = 0.0,
        attention_bias: bool = False,
        rms_norm_eps: float = 1e-6,
    ) -> None:
        super().__init__()

        self.embedding_size = self.hidden_size = embedding_size
        self.n_heads = self.num_heads = n_heads
        self.intermediate_size = intermediate_size
        self.n_value_layers = n_value_layers
        self.layer_idx_offset = layer_idx_offset
        self.act_fn = act_fn
        self.attention_dropout = dropout
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = (
            self.num_heads if num_key_value_heads is None else num_key_value_heads
        )
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads

        # Reject widths that do not split evenly across the heads.
        if self.hidden_size != self.head_dim * self.num_heads:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )

        first_idx = self.layer_idx_offset
        stop_idx = self.layer_idx_offset + self.n_value_layers - 1
        decoder_layers = []
        for idx in range(first_idx, stop_idx):
            decoder_layers.append(
                LlamaDecoderLayer(
                    embedding_size=self.embedding_size,
                    n_heads=self.n_heads,
                    layer_idx=idx,
                    intermediate_size=self.intermediate_size,
                    act_fn=act_fn,
                    rms_norm_eps=rms_norm_eps,
                )
            )
        self.layers = nn.ModuleList(decoder_layers)
        self.norm = LlamaRMSNorm(self.hidden_size, eps=rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        cache: Optional[KVCache] = None,
        use_cache: Optional[bool] = False,
    ) -> torch.Tensor:
        """Run every decoder layer in order and normalise the result.

        When ``cache`` is given, each layer receives the entry stored under
        the string form of its ``layer_idx``.
        """
        states = hidden_states

        for block in self.layers:
            block_cache = None
            if cache is not None:
                block_cache = cache[str(block.layer_idx)]
            states = block(
                states,
                attention_mask=attention_mask,
                cache=block_cache,
                use_cache=use_cache,
            )

        return self.norm(states)


class ShortCircuitValueHead(nn.Module):
    """Single-query attention block that reduces a sequence to one value.

    Without a cache, the first position provides the only query and the
    remaining positions provide keys and values. The attended vector goes
    through a residual MLP block and a final linear layer with one output.
    """

    def __init__(
        self,
        embedding_size: int,
        n_heads: int,
        n_value_layers: int,
        intermediate_size: int,
        act_fn: Callable,
        layer_idx_offset: int = 0,
        num_key_value_heads: int | None = None,
        dropout: float = 0.0,
        attention_bias: bool = False,
        rms_norm_eps: float = 1e-6,
    ) -> None:
        """Build the projections, norms, MLP and the scalar output layer.

        Args:
            embedding_size: Hidden width.
            n_heads: Number of query heads.
            n_value_layers: Stored on the instance only.
            intermediate_size: Width of the MLP.
            act_fn: Activation used by the MLP and before the output layer.
            layer_idx_offset: Stored on the instance only.
            num_key_value_heads: Number of key/value heads; defaults to
                ``n_heads``.
            dropout: Attention dropout probability (training mode only).
            attention_bias: Whether the q/k/v/o projections carry a bias.
            rms_norm_eps: Epsilon of the two norm layers.
        """
        super().__init__()
        self.attention_dropout = dropout
        self.embedding_size = embedding_size
        self.n_heads = n_heads
        self.hidden_size = embedding_size
        self.num_heads = n_heads
        self.intermediate_size = intermediate_size
        self.head_dim = self.hidden_size // self.num_heads
        self.layer_idx_offset = layer_idx_offset
        self.act_fn = act_fn
        if num_key_value_heads is None:
            num_key_value_heads = self.num_heads
        self.num_key_value_heads = num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.n_value_layers = n_value_layers

        self.mlp = LlamaMLP(self.hidden_size, intermediate_size, act_fn=act_fn)
        self.input_layernorm = LlamaRMSNorm(self.hidden_size, eps=rms_norm_eps)
        self.post_attention_layernorm = LlamaRMSNorm(self.hidden_size, eps=rms_norm_eps)

        self.q_proj = nn.Linear(
            self.hidden_size, self.num_heads * self.head_dim, bias=attention_bias
        )
        self.k_proj = nn.Linear(
            self.hidden_size,
            self.num_key_value_heads * self.head_dim,
            bias=attention_bias,
        )
        self.v_proj = nn.Linear(
            self.hidden_size,
            self.num_key_value_heads * self.head_dim,
            bias=attention_bias,
        )
        self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=attention_bias)
        self.value_layer = nn.Linear(self.hidden_size, 1)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        cache: Optional[LayerKVCache] = None,
        use_cache: Optional[bool] = False,
    ) -> torch.Tensor:
        """Return one value per batch element, shaped ``(batch, 1)``.

        Args:
            hidden_states: 3-D input; dim 0 is the batch and dim 1 the
                sequence.
            attention_mask: Optional mask passed to scaled dot-product
                attention.
            cache: Optional cache. When given, the queries come from
                ``cache["v"].get_queries()`` and keys/values are merged
                through ``cache.update_kvcache``; every position of
                ``hidden_states`` is then treated as a key/value position.
            use_cache: Accepted for interface parity; not read here.
        """
        # NOTE(review): the residual is taken from the LAST position while
        # the query below comes from the FIRST one - confirm which position
        # is intended before changing (original TODO: "Check if this is
        # correct"; a skip connection may not be needed at all).
        residual_target = hidden_states[:, -1, :][:, None, :]

        hidden_states = self.input_layernorm(hidden_states)

        bsz, _, _ = hidden_states.size()

        if cache is None:
            # Position 0 is the query; positions 1.. are keys/values.
            query_states = self.q_proj(hidden_states[:, :1, :])
            hidden_node_states = hidden_states[:, 1:, :]
        else:
            query_states = cache["v"].get_queries()  # type: ignore
            hidden_node_states = hidden_states

        key_states = self.k_proj(hidden_node_states)
        value_states = self.v_proj(hidden_node_states)

        # Reshape with the query's own length. Using the input sequence
        # length here made .view() fail for any sequence longer than one,
        # because the query holds a single position.
        query_states = query_states.view(
            bsz, query_states.shape[1], self.num_heads, self.head_dim
        )
        key_states = key_states.view(
            bsz, key_states.shape[1], self.num_key_value_heads, self.head_dim
        )
        value_states = value_states.view(
            bsz, value_states.shape[1], self.num_key_value_heads, self.head_dim
        )

        if cache is not None:
            key_states, value_states = cache.update_kvcache(  # type: ignore
                key_states,
                value_states,
            )

        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        # Flash attention stays enabled, but the other kernels are allowed as
        # fallbacks: the flash backend alone rejects inputs it cannot handle
        # (e.g. an explicit attn_mask or fp32 tensors on CUDA).
        sdpa_backends = [
            SDPBackend.FLASH_ATTENTION,
            SDPBackend.EFFICIENT_ATTENTION,
            SDPBackend.MATH,
        ]
        with sdpa_kernel(sdpa_backends):
            attn_output = torch.nn.functional.scaled_dot_product_attention(
                query_states,
                key_states,
                value_states,
                attn_mask=attention_mask,
                dropout_p=self.attention_dropout if self.training else 0.0,
                is_causal=False,
            )

        # The block produces exactly one attended vector per batch element.
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(bsz, 1, self.hidden_size)

        attn_output = self.o_proj(attn_output)

        # Residual connection around attention, then around the MLP.
        hidden_target = attn_output + residual_target
        residual_target = hidden_target

        hidden_states = self.post_attention_layernorm(hidden_target)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual_target + hidden_states

        value = self.value_layer(self.act_fn(hidden_states))

        return value.view(bsz, 1)


class PositionEmbeddingModule(nn.Module):
    """Inject learnable position embeddings into the target and the nodes.

    Two learnable vectors are kept: ``target_embedding`` is added to the
    first sequence position and ``node_embedding`` to every other position.
    The sum is then passed through an RMS norm.
    """

    def __init__(self, embedding_size: int, rms_norm_eps: float = 1e-6):
        """Create the two embedding vectors and the norm layer.

        Args:
            embedding_size: Length of each embedding vector.
            rms_norm_eps: Epsilon handed to the norm layer.
        """
        super(PositionEmbeddingModule, self).__init__()
        self.embedding_size = embedding_size
        self.target_embedding = nn.Parameter(torch.empty(embedding_size))
        self.node_embedding = nn.Parameter(torch.empty(embedding_size))
        # Xavier init needs at least 2 dims; the unsqueezed view shares
        # storage with the parameter, so the in-place init fills it.
        torch.nn.init.xavier_uniform_(self.target_embedding.unsqueeze(0))
        torch.nn.init.xavier_uniform_(self.node_embedding.unsqueeze(0))
        self.layer_norm = LlamaRMSNorm(self.embedding_size, rms_norm_eps)

    def forward(
        self,
        inputs_embeds: torch.Tensor,
    ):
        """Return the normalised embeddings with position offsets added.

        The offsets are applied to a copy, so ``inputs_embeds`` is left
        untouched. Adding them in place on the argument corrupted embeddings
        that callers reuse across calls, and raised for leaf tensors that
        require grad.

        Args:
            inputs_embeds: 3-D tensor whose dim 1 is the sequence; position 0
                is treated as the target and the rest as nodes.
        """
        # clone() keeps dtype and device, so the in-place adds below behave
        # exactly as they did on the original tensor.
        shifted = inputs_embeds.clone()
        shifted[:, :1, :].add_(self.target_embedding)
        shifted[:, 1:, :].add_(self.node_embedding)

        return self.layer_norm(shifted)


class StackedPolicyModule(nn.Module):
    """Policy ensemble evaluated in one pass over an extra ensemble dimension.

    Instead of holding one module per ensemble member, this variant uses
    batched layers that each take ``ensemble_size`` and process all members
    together.
    """

    def __init__(
        self,
        embedding_size: int,
        n_heads: int,
        n_policy_layers: int,
        intermediate_size: int,
        act_fn: Callable,
        layer_idx_offset: int = 0,
        num_key_value_heads: int | None = None,
        ensemble_size: int = 4,  # Number of policy networks in the ensemble
        dropout: float = 0.0,
        attention_bias: bool = False,
        rms_norm_eps: float = 1e-6,
    ) -> None:
        super().__init__()
        self.embedding_size = embedding_size
        self.intermediate_size = intermediate_size
        self.n_heads = n_heads
        self.num_key_value_heads = num_key_value_heads
        self.n_policy_layers = n_policy_layers
        self.layer_idx_offset = layer_idx_offset
        self.ensemble_size = ensemble_size
        self.act_fn = act_fn
        self.dropout = dropout

        # The last index of the range is reserved for the action head.
        action_head_idx = layer_idx_offset + self.n_policy_layers - 1

        # Batched decoder layers, one per remaining index.
        batched_layers = []
        for idx in range(layer_idx_offset, action_head_idx):
            batched_layers.append(
                BatchedDecoderLayer(
                    embedding_size=self.embedding_size,
                    n_heads=self.n_heads,
                    intermediate_size=self.intermediate_size,
                    layer_idx=idx,
                    act_fn=act_fn,
                    rms_norm_eps=rms_norm_eps,
                    ensemble_size=ensemble_size,
                )
            )
        self.layers = nn.ModuleList(batched_layers)

        # Batched RMS norm applied after the decoder layers.
        self.norm = BatchedRMSNorm(
            self.embedding_size, eps=rms_norm_eps, ensemble_size=ensemble_size
        )

        # Batched action head.
        self.action_layer = BatchedShortCircuitActionHead(
            self.embedding_size,
            self.n_heads,
            action_head_idx,
            ensemble_size=ensemble_size,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        cache: KVCache | None = None,
        use_cache: bool | None = None,
    ) -> torch.Tensor:
        """Broadcast the inputs over the ensemble and run the batched stack.

        ``hidden_states`` (3-D) gains an ensemble dimension at position 1;
        a 4-D ``attention_mask`` is expanded the same way. The whole
        ``cache`` object is handed to every layer and to the action head.
        """
        n_members = self.ensemble_size

        # [batch, seq, dim] -> [batch, ensemble, seq, dim], without copying.
        states = hidden_states[:, None].expand(-1, n_members, -1, -1)

        if attention_mask is not None:
            attention_mask = attention_mask[:, None].expand(
                -1, n_members, -1, -1, -1
            )

        # The ensemble dimension is carried through every layer.
        for block in self.layers:
            states = block(
                states,
                attention_mask=attention_mask,
                cache=cache,
                use_cache=use_cache,
            )

        states = self.norm(states)

        # Action head output for all ensemble members.
        return self.action_layer(
            states,
            attention_mask=attention_mask,
            cache=cache,
            use_cache=use_cache,
        )


class BatchedRMSNorm(nn.Module):
    """RMS normalisation with an independent gain vector per ensemble member.

    The statistics are computed in float32 over the last axis; the gain of
    member ``e`` is applied to slice ``[:, e]`` of the input.
    """

    def __init__(self, hidden_size: int, eps: float = 1e-6, ensemble_size: int = 4):
        super().__init__()
        self.ensemble_size = ensemble_size
        self.variance_epsilon = eps
        # One row of gains per ensemble member, all initialised to 1.
        self.weight = nn.Parameter(torch.ones(ensemble_size, hidden_size))

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Normalise ``hidden_states`` ([batch, ensemble, seq_len, hidden_size])."""
        orig_dtype = hidden_states.dtype
        upcast = hidden_states.to(torch.float32)

        # Mean of squares over the feature axis, kept for broadcasting.
        mean_square = upcast.pow(2).mean(-1, keepdim=True)
        normed = upcast * torch.rsqrt(mean_square + self.variance_epsilon)

        # [ensemble, hidden] -> [1, ensemble, 1, hidden] so each member's gain
        # lines up with its own slice of the input.
        gain = self.weight[None, :, None, :]
        return gain * normed.to(orig_dtype)


class BatchedDecoderLayer(nn.Module):
    """Pre-norm decoder block (attention + gated MLP) for a whole ensemble.

    Every sub-module keeps a separate parameter set per ensemble member and
    operates on tensors carrying an ensemble axis at dim 1.
    """

    def __init__(
        self,
        embedding_size: int,
        n_heads: int,
        layer_idx: int,
        intermediate_size: int,
        act_fn: Callable,
        ensemble_size: int = 4,
        rms_norm_eps: float = 1e-6,
    ) -> None:
        super().__init__()
        self.hidden_size = embedding_size
        self.intermediate_size = intermediate_size
        self.n_heads = n_heads
        self.layer_idx = layer_idx
        self.ensemble_size = ensemble_size

        # Per-member self-attention.
        self.self_attn = BatchedAttention(
            embedding_size,
            n_heads,
            layer_idx=layer_idx,
            ensemble_size=ensemble_size,
        )

        # Per-member gated feed-forward network.
        self.mlp = BatchedMLP(
            embedding_size,
            intermediate_size,
            act_fn=act_fn,
            ensemble_size=ensemble_size,
        )

        # Both norms share the same configuration.
        norm_kwargs = dict(eps=rms_norm_eps, ensemble_size=ensemble_size)
        self.input_layernorm = BatchedRMSNorm(embedding_size, **norm_kwargs)
        self.post_attention_layernorm = BatchedRMSNorm(embedding_size, **norm_kwargs)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        cache: Optional[LayerKVCache] = None,
        use_cache: Optional[bool] = False,
        **kwargs,
    ) -> torch.Tensor:
        """Apply the attention and MLP sub-blocks, each with a residual add.

        Args:
            hidden_states: [batch, ensemble, seq_len, embedding_size].
            attention_mask: Forwarded to the attention module.
            cache: Forwarded to the attention module.
            use_cache: Forwarded to the attention module.
            **kwargs: Extra keyword arguments forwarded to the attention module.

        Returns:
            Tensor with the same shape as ``hidden_states``.
        """
        # Attention sub-block: norm -> attention -> residual.
        attn_out = self.self_attn(
            hidden_states=self.input_layernorm(hidden_states),
            attention_mask=attention_mask,
            cache=cache,
            use_cache=use_cache,
            **kwargs,
        )
        after_attn = hidden_states + attn_out

        # Feed-forward sub-block: norm -> MLP -> residual.
        mlp_out = self.mlp(self.post_attention_layernorm(after_attn))
        return after_attn + mlp_out


class BatchedAttention(nn.Module):
    """Multi-head self-attention evaluated for every ensemble member at once.

    Each member owns its own Q/K/V/O projection matrices. They are stored
    stacked along a leading ensemble axis and applied with ``einsum`` so the
    members never mix. KV caching is not supported.
    """

    def __init__(
        self,
        embedding_size: int,
        n_heads: int,
        layer_idx: Optional[int] = None,
        ensemble_size: int = 4,
        num_key_value_heads: Optional[int] = None,
        dropout: float = 0.0,
        attention_bias: bool = False,
    ):
        """Create the stacked projection parameters.

        Args:
            embedding_size: Model width; must be divisible by ``n_heads``.
            n_heads: Number of query heads.
            layer_idx: Stored on the module; not used by ``forward``.
            ensemble_size: Number of independent parameter sets.
            num_key_value_heads: Number of key/value heads; defaults to
                ``n_heads``.
            dropout: Attention dropout probability, applied in training mode.
            attention_bias: Whether to add per-member bias vectors to the
                Q/K/V/O projections.

        Raises:
            ValueError: If ``embedding_size`` is not divisible by ``n_heads``.
        """
        super().__init__()
        self.layer_idx = layer_idx
        self.attention_dropout = dropout
        self.hidden_size = embedding_size
        self.num_heads = n_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.ensemble_size = ensemble_size

        if num_key_value_heads is None:
            num_key_value_heads = self.num_heads
        self.num_key_value_heads = num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )

        # Separate projection matrices for each ensemble member, laid out as
        # [ensemble, in_features, out_features].
        self.q_proj = nn.Parameter(
            torch.empty(ensemble_size, self.hidden_size, self.num_heads * self.head_dim)
        )
        self.k_proj = nn.Parameter(
            torch.empty(
                ensemble_size,
                self.hidden_size,
                self.num_key_value_heads * self.head_dim,
            )
        )
        self.v_proj = nn.Parameter(
            torch.empty(
                ensemble_size,
                self.hidden_size,
                self.num_key_value_heads * self.head_dim,
            )
        )
        self.o_proj = nn.Parameter(
            torch.empty(ensemble_size, self.hidden_size, self.hidden_size)
        )

        if attention_bias:
            self.q_bias = nn.Parameter(
                torch.zeros(ensemble_size, self.num_heads * self.head_dim)
            )
            self.k_bias = nn.Parameter(
                torch.zeros(ensemble_size, self.num_key_value_heads * self.head_dim)
            )
            self.v_bias = nn.Parameter(
                torch.zeros(ensemble_size, self.num_key_value_heads * self.head_dim)
            )
            self.o_bias = nn.Parameter(torch.zeros(ensemble_size, self.hidden_size))
        else:
            self.register_parameter("q_bias", None)
            self.register_parameter("k_bias", None)
            self.register_parameter("v_bias", None)
            self.register_parameter("o_bias", None)

        # NOTE(review): xavier_uniform_ sees the full 3-D tensor, so its
        # fan-in/fan-out are derived from all three axes, not from a single
        # member's 2-D matrix -- confirm this scale is intended.
        for param in [self.q_proj, self.k_proj, self.v_proj, self.o_proj]:
            nn.init.xavier_uniform_(param)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        cache: Optional[LayerKVCache] = None,
        use_cache: bool = False,
    ) -> torch.Tensor:
        """Run self-attention independently for each ensemble member.

        Args:
            hidden_states: [batch, ensemble, seq, hidden].
            attention_mask: Optional mask, broadcastable (right-aligned) to
                [batch, ensemble, heads, seq, seq].
            cache: Must be ``None`` whenever ``use_cache`` is truthy.
            use_cache: KV caching flag; caching is not implemented here.

        Returns:
            Tensor of shape [batch, ensemble, seq, hidden].

        Raises:
            NotImplementedError: If ``use_cache`` is truthy and a cache is given.
        """
        bsz, ensemble_size, q_len, _ = hidden_states.size()

        # --- Projections (einsum keeps the ensemble members separate) ---
        # [B, E, L, D] x [E, D, H*Hd] -> [B, E, L, H*Hd]
        query_states = torch.einsum("beld,edh->belh", hidden_states, self.q_proj)
        # [B, E, L, D] x [E, D, KvH*Hd] -> [B, E, L, KvH*Hd]
        key_states = torch.einsum("beld,edh->belh", hidden_states, self.k_proj)
        value_states = torch.einsum("beld,edh->belh", hidden_states, self.v_proj)

        if self.q_bias is not None:
            query_states = query_states + self.q_bias.view(1, ensemble_size, 1, -1)
            key_states = key_states + self.k_bias.view(1, ensemble_size, 1, -1)
            value_states = value_states + self.v_bias.view(1, ensemble_size, 1, -1)

        # --- Split heads: [B, E, L, H*Hd] -> [B, E, H, L, Hd] ---
        query_states = query_states.view(
            bsz, ensemble_size, q_len, self.num_heads, self.head_dim
        ).transpose(-3, -2)
        key_states = key_states.view(
            bsz, ensemble_size, q_len, self.num_key_value_heads, self.head_dim
        ).transpose(-3, -2)
        value_states = value_states.view(
            bsz, ensemble_size, q_len, self.num_key_value_heads, self.head_dim
        ).transpose(-3, -2)

        # --- KV caching is not implemented for the batched version ---
        if use_cache and cache is not None:
            raise NotImplementedError("KVCache not implemented for BatchedAttention")

        # --- Repeat KV heads so K/V have as many heads as Q ---
        key_states = batched_repeat_kv(key_states, self.num_key_value_groups)
        value_states = batched_repeat_kv(value_states, self.num_key_value_groups)

        # --- Fold the ensemble axis into the batch axis ---
        # The fused SDPA kernels only accept 4-D q/k/v, so attention runs on
        # [B*E, H, L, Hd] and the ensemble axis is restored afterwards.
        query_states = query_states.flatten(0, 1)
        key_states = key_states.flatten(0, 1)
        value_states = value_states.flatten(0, 1)

        if attention_mask is not None:
            # Left-pad to 5-D so the mask keeps the right-aligned broadcasting
            # it would have had against [B, E, H, L, L] scores, then
            # materialise the batch/ensemble axes and fold them together.
            missing_dims = 5 - attention_mask.dim()
            if missing_dims > 0:
                attention_mask = attention_mask.reshape(
                    (1,) * missing_dims + tuple(attention_mask.shape)
                )
            attention_mask = attention_mask.expand(
                bsz, ensemble_size, -1, -1, -1
            ).flatten(0, 1)

        # --- Scaled dot-product attention ---
        # Flash attention stays enabled, but the efficient and math backends
        # are allowed as fallbacks: the flash kernel cannot be selected for
        # every input (e.g. with an explicit attn_mask), and restricting SDPA
        # to it alone then leaves no usable kernel.
        with sdpa_kernel(
            [
                SDPBackend.FLASH_ATTENTION,
                SDPBackend.EFFICIENT_ATTENTION,
                SDPBackend.MATH,
            ]
        ):
            attn_output = torch.nn.functional.scaled_dot_product_attention(
                query_states,
                key_states,
                value_states,
                attn_mask=attention_mask,
                dropout_p=self.attention_dropout if self.training else 0.0,
            )
        # attn_output: [B*E, H, L, Hd]

        # --- Merge heads: [B*E, H, L, Hd] -> [B*E, L, H, Hd] -> [B, E, L, D] ---
        attn_output = attn_output.transpose(1, 2).reshape(
            bsz, ensemble_size, q_len, self.hidden_size
        )

        # --- Output projection: [B, E, L, D] x [E, D, D] -> [B, E, L, D] ---
        attn_output = torch.einsum("beld,edh->belh", attn_output, self.o_proj)

        if self.o_bias is not None:
            attn_output = attn_output + self.o_bias.view(1, ensemble_size, 1, -1)

        return attn_output


class BatchedMLP(nn.Module):
    """Gated feed-forward network with one weight set per ensemble member.

    Computes ``down(act_fn(gate(x)) * up(x))`` where all three projections
    are stacked along a leading ensemble axis and applied with ``einsum``.
    """

    def __init__(
        self,
        embedding_size: int,
        intermediate_size: int,
        act_fn: Callable,
        ensemble_size: int = 4,
    ):
        super().__init__()
        self.hidden_size = embedding_size
        self.intermediate_size = intermediate_size
        self.ensemble_size = ensemble_size

        # Weights are laid out as [ensemble, in_features, out_features].
        widen_shape = (ensemble_size, embedding_size, intermediate_size)
        narrow_shape = (ensemble_size, intermediate_size, embedding_size)
        self.gate_proj = nn.Parameter(torch.empty(widen_shape))
        self.up_proj = nn.Parameter(torch.empty(widen_shape))
        self.down_proj = nn.Parameter(torch.empty(narrow_shape))

        self.act_fn = act_fn

        # NOTE(review): xavier_uniform_ is applied to the whole 3-D tensor, so
        # fan-in/fan-out are derived from all three axes instead of a single
        # member's 2-D matrix -- confirm this scale is intended.
        nn.init.xavier_uniform_(self.gate_proj)
        nn.init.xavier_uniform_(self.up_proj)
        nn.init.xavier_uniform_(self.down_proj)

    def forward(self, x):
        """Apply the gated MLP to ``x`` of shape [batch, ensemble, seq, hidden]."""
        # Widen: [B, E, S, D] x [E, D, M] -> [B, E, S, M] for both branches.
        gated = self.act_fn(torch.einsum("besd,edm->besm", x, self.gate_proj))
        lifted = torch.einsum("besd,edm->besm", x, self.up_proj)

        # Narrow back: [B, E, S, M] x [E, M, D] -> [B, E, S, D].
        return torch.einsum("besm,emh->besh", gated * lifted, self.down_proj)


class BatchedShortCircuitActionHead(nn.Module):
    """Action head producing node-to-node scores for every ensemble member.

    Only query and key projections exist. The output is the scaled Q.K^T
    score matrix averaged over heads; no softmax, mask or value projection is
    applied. KV caching is not supported.
    """

    def __init__(
        self,
        embedding_size: int,
        n_heads: int,
        layer_idx: Optional[int] = None,
        ensemble_size: int = 4,
        num_key_value_heads: Optional[int] = None,
        dropout: float = 0.0,
        attention_bias: bool = False,
    ):
        """Create the stacked query/key projection parameters.

        Args:
            embedding_size: Model width.
            n_heads: Number of query heads.
            layer_idx: Stored on the module; not used by ``forward``.
            ensemble_size: Number of independent parameter sets.
            num_key_value_heads: Number of key heads; defaults to ``n_heads``.
            dropout: Stored as ``attention_dropout``; not used by ``forward``.
            attention_bias: Whether to add per-member bias vectors to Q and K.
        """
        super().__init__()
        self.layer_idx = layer_idx
        self.attention_dropout = dropout
        self.hidden_size = embedding_size
        self.num_heads = n_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.ensemble_size = ensemble_size

        if num_key_value_heads is None:
            num_key_value_heads = self.num_heads
        self.num_key_value_heads = num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads

        # Separate projection matrices for each ensemble member, laid out as
        # [ensemble, in_features, out_features].
        self.q_proj = nn.Parameter(
            torch.empty(ensemble_size, self.hidden_size, self.num_heads * self.head_dim)
        )
        self.k_proj = nn.Parameter(
            torch.empty(
                ensemble_size,
                self.hidden_size,
                self.num_key_value_heads * self.head_dim,
            )
        )

        if attention_bias:
            self.q_bias = nn.Parameter(
                torch.zeros(ensemble_size, self.num_heads * self.head_dim)
            )
            self.k_bias = nn.Parameter(
                torch.zeros(ensemble_size, self.num_key_value_heads * self.head_dim)
            )
        else:
            self.register_parameter("q_bias", None)
            self.register_parameter("k_bias", None)

        # Initialize parameters
        for param in [self.q_proj, self.k_proj]:
            nn.init.xavier_uniform_(param)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        cache: Optional[LayerKVCache] = None,
        use_cache: bool = False,
        num_outputs: int = 1,
    ) -> torch.Tensor:
        """Score every node against every other node, per ensemble member.

        Args:
            hidden_states: [batch, ensemble, seq, hidden].
            attention_mask: Accepted for interface parity; not used here.
            cache: Must be ``None`` whenever ``use_cache`` is truthy.
            use_cache: KV caching flag; caching is not implemented here.
            num_outputs: Number of leading sequence positions dropped before
                scoring when the sequence is longer than one token.

        Returns:
            Tensor of shape [batch, ensemble, nodes, nodes], where ``nodes`` is
            ``seq - num_outputs`` (or ``seq`` when ``seq == 1``).

        Raises:
            NotImplementedError: If ``use_cache`` is truthy and a cache is given.
        """
        bsz, ensemble_size, q_len, _ = hidden_states.size()

        # --- Slicing: drop the leading target/output positions ---
        if q_len > 1:
            hidden_states_nodes = hidden_states[:, :, num_outputs:, :]
            nodes_len = q_len - num_outputs
        else:
            # A single token is passed through unsliced. This head may not be
            # meaningful for q_len == 1; the behaviour is kept as-is.
            hidden_states_nodes = hidden_states
            nodes_len = q_len

        # --- Projections (node states only) ---
        # [B, E, Ln, D] x [E, D, H*Hd] -> [B, E, Ln, H*Hd]
        query_states = torch.einsum("beld,edh->belh", hidden_states_nodes, self.q_proj)
        # [B, E, Ln, D] x [E, D, KvH*Hd] -> [B, E, Ln, KvH*Hd]
        key_states = torch.einsum("beld,edh->belh", hidden_states_nodes, self.k_proj)

        if self.q_bias is not None:
            query_states = query_states + self.q_bias.view(1, ensemble_size, 1, -1)
            key_states = key_states + self.k_bias.view(1, ensemble_size, 1, -1)

        # --- Split heads: [B, E, Ln, H*Hd] -> [B, E, H, Ln, Hd] ---
        query_states = query_states.view(
            bsz, ensemble_size, nodes_len, self.num_heads, self.head_dim
        ).transpose(-2, -3)
        # [B, E, Ln, KvH*Hd] -> [B, E, KvH, Ln, Hd]
        key_states = key_states.view(
            bsz, ensemble_size, nodes_len, self.num_key_value_heads, self.head_dim
        ).transpose(-2, -3)

        # --- KV caching is not implemented ---
        if use_cache and cache is not None:
            raise NotImplementedError(
                "KVCache not implemented for BatchedShortCircuitActionHead"
            )

        # --- Repeat KV heads: [B, E, KvH, Ln, Hd] -> [B, E, H, Ln, Hd] ---
        key_states = batched_repeat_kv(key_states, self.num_key_value_groups)

        # --- Raw scaled scores (plain matmul; no softmax or mask) ---
        # [B, E, H, Ln, Hd] x [B, E, H, Hd, Ln] -> [B, E, H, Ln, Ln]
        attn_weights = torch.matmul(
            query_states, key_states.transpose(-1, -2)
        ) / math.sqrt(self.head_dim)

        # --- Average over the head axis (dim 2): -> [B, E, Ln, Ln] ---
        attn_weights = attn_weights.mean(2)

        # The view does not change the layout; it pins the expected shape.
        return attn_weights.view(bsz, ensemble_size, nodes_len, nodes_len)


def batched_repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat each key/value head ``n_rep`` times, keeping the ensemble axis.

    Args:
        x: Tensor with shape [batch, ensemble, heads, seq, head_dim].
        n_rep: Number of consecutive copies to make of every head.

    Returns:
        Tensor with shape [batch, ensemble, heads*n_rep, seq, head_dim]. When
        ``n_rep`` is 1 the input tensor itself is returned.
    """
    # Unpacking first also enforces that the input is exactly 5-D.
    batch, ensemble, kv_heads, seq_len, head_dim = x.shape
    if n_rep != 1:
        # Add a repeat axis right after the head axis, broadcast it, then
        # merge it into the head axis so copies of one head are adjacent.
        repeated = x[:, :, :, None].expand(
            batch, ensemble, kv_heads, n_rep, seq_len, head_dim
        )
        x = repeated.reshape(batch, ensemble, kv_heads * n_rep, seq_len, head_dim)
    return x


class TruthTableEncoder(nn.Module):
    """
    Encoder that converts truth tables to embeddings.
    Can be initialized from scratch or loaded from a pre-trained VAE checkpoint.
    """

    def __init__(
        self,
        input_dim: int | None = None,
        hidden_dims: list[int] | None = None,
        latent_dim: int | None = None,
        checkpoint_path: str | None = None,
        config_path: str | None = None,
        freeze: bool = True,
        dtype: torch.dtype = torch.bfloat16,
    ):
        """Build the encoder, either from a checkpoint or from explicit sizes.

        Args:
            input_dim: Input width; required when ``checkpoint_path`` is None.
            hidden_dims: Hidden layer widths; defaults to ``[256, 128, 64]``.
            latent_dim: Output width; required when ``checkpoint_path`` is None.
            checkpoint_path: Optional VAE checkpoint to load weights from. When
                given, the sizes come from the config file instead.
            config_path: YAML config describing the checkpointed model;
                defaults to ``<checkpoint dir>/../config.yaml``.
            freeze: If True, disable gradients for all parameters.
            dtype: Dtype every layer is cast to; ``None`` skips the cast.

        Raises:
            ValueError: If sizes are missing when building from scratch, or
                the checkpoint has no ``"model_state"`` entry.
            FileNotFoundError: If ``checkpoint_path`` does not exist.
        """
        super().__init__()
        self.dtype = dtype
        if checkpoint_path is not None:
            # Load from checkpoint
            self._load_from_checkpoint(checkpoint_path, config_path)
        else:
            # Create from scratch
            if input_dim is None or latent_dim is None:
                raise ValueError(
                    "Must provide input_dim and latent_dim when not loading from checkpoint"
                )

            self.input_dim = input_dim
            self.latent_dim = latent_dim

            # Default hidden dimensions if not provided
            if hidden_dims is None:
                hidden_dims = [256, 128, 64]

            # Create encoder network
            self._build_encoder(hidden_dims)

        # Freeze the encoder if requested
        if freeze:
            for param in self.parameters():
                param.requires_grad = False

    def _build_encoder(self, hidden_dims: list[int]):
        """Build the encoder network architecture.

        Creates ``self.encoder`` (Linear -> RMSNorm -> LeakyReLU per hidden
        width) and the ``self.fc_mu`` output layer, then casts both to
        ``self.dtype`` when it is set.
        """
        layer_dims = [self.input_dim, *hidden_dims]
        modules: list[nn.Module] = []

        for in_dim, out_dim in zip(layer_dims, layer_dims[1:]):
            modules.extend(
                [
                    nn.Linear(in_dim, out_dim),
                    LlamaRMSNorm(out_dim),
                    nn.LeakyReLU(),
                ]
            )

        self.encoder = nn.Sequential(*modules)
        self.fc_mu = nn.Linear(layer_dims[-1], self.latent_dim)  # type: ignore

        # Cast every layer (norms included) in one place so that the
        # from-scratch and from-checkpoint paths end up with the same dtypes.
        if self.dtype is not None:
            self.encoder = self.encoder.to(self.dtype)
            self.fc_mu = self.fc_mu.to(self.dtype)

    def _load_from_checkpoint(self, checkpoint_path: str, config_path: str | None):
        """Load the encoder from a VAE checkpoint.

        Reads the layer sizes from the YAML config, builds the network, then
        copies the matching tensors out of ``checkpoint["model_state"]``.
        Missing or unexpected keys are reported on stdout, not raised.
        """
        if not os.path.exists(checkpoint_path):
            raise FileNotFoundError(f"Checkpoint not found at {checkpoint_path}")

        if config_path is None:
            config_path = str(Path(checkpoint_path).parent.parent / "config.yaml")
        model_config = self._extract_model_config_from_yaml(config_path)
        self.input_dim = model_config.get("input_dim", 256)
        self.latent_dim = model_config.get("latent_dim", 1024)
        hidden_dims = model_config.get("hidden_dims", [256, 128, 64])

        # SECURITY NOTE: weights_only=False unpickles arbitrary Python objects
        # and can execute code embedded in the file. Only load checkpoints
        # from trusted sources.
        checkpoint_data = torch.load(
            checkpoint_path, map_location="cpu", weights_only=False
        )

        # Build the encoder with the right dimensions (also applies the dtype).
        self._build_encoder(hidden_dims)

        if "model_state" not in checkpoint_data:
            raise ValueError("No model state found in checkpoint")

        # Full VAE state dict
        vae_state_dict = checkpoint_data["model_state"]

        # Keep only encoder-related tensors. Keys under "encoder." lose that
        # prefix; top-level "fc_mu." keys are kept verbatim.
        prefix = "encoder."
        encoder_state_dict = {}
        for key, value in vae_state_dict.items():
            if key.startswith(prefix):
                encoder_state_dict[key[len(prefix):]] = value
            elif key.startswith("fc_mu."):
                encoder_state_dict[key] = value

        # Non-strict load: mismatches are reported below instead of raising.
        missing_keys, unexpected_keys = self.load_state_dict(
            encoder_state_dict, strict=False
        )

        if missing_keys:
            print(f"Warning: Missing keys when loading encoder: {missing_keys}")
        if unexpected_keys:
            print(
                f"Warning: Unexpected keys when loading encoder: {unexpected_keys}"
            )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode ``x`` with the hidden stack and project it with ``fc_mu``."""
        x = self.encoder(x)
        return self.fc_mu(x)

    def _extract_model_config_from_yaml(self, config_path: str) -> dict:
        """Extract model configuration from wandb config.yaml file.

        Entries are expected in the form ``{key: {"value": ...}}``. A
        ``model`` entry takes precedence; otherwise individual top-level keys
        are collected. An empty or non-mapping file yields an empty dict, so
        the caller falls back to its defaults.
        """
        with open(config_path, "r", encoding="utf-8") as f:
            config = yaml.safe_load(f)

        # safe_load returns None for an empty document.
        if not isinstance(config, dict):
            return {}

        # Extract VAE config - first check if there's a 'model' section
        model_entry = config.get("model")
        if isinstance(model_entry, dict) and "value" in model_entry:
            return model_entry["value"]

        # Fall back to top-level config
        model_config = {}
        for key in ["embedding_size", "hidden_dims", "latent_dim", "input_dim"]:
            entry = config.get(key)
            if isinstance(entry, dict) and "value" in entry:
                model_config[key] = entry["value"]

        return model_config
