from typing import Optional, Tuple, List
import flax.linen as nn
import jax
import copy 
import numpy as np 
import jax.numpy as jnp
from flax.linen.attention import dot_product_attention_weights
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.traverse_util import flatten_dict, unflatten_dict
from transformers.models.vit.modeling_flax_vit import (
    FlaxViTPreTrainedModel,
    ViTConfig,
    FlaxBaseModelOutput,
    FlaxBaseModelOutputWithPooling,
    FlaxViTPatchEmbeddings,
    FlaxViTPooler,
    ACT2FN,
    FlaxViTIntermediate,
    FlaxViTOutput,
    FlaxPreTrainedModel,
    FlaxSequenceClassifierOutput,
)
from typing import Callable

def print_model(flax_params, file=None):
    """Print each leaf of a nested parameter dict as ``path/to/leaf shape``.

    Args:
        flax_params: nested mapping of parameters (anything ``flatten_dict`` accepts).
        file: optional writable stream; when falsy the lines go to stdout.
    """
    for key_path, leaf in flatten_dict(flax_params).items():
        print(f"{'/'.join(key_path)} {leaf.shape}", file=file or None)
def print_model_with_prefix(flax_params, prefix: str, file=None):
    """Print shape and value of every parameter whose dotted path starts with ``prefix``.

    Args:
        flax_params: nested mapping of parameters (anything ``flatten_dict`` accepts).
        prefix: dotted-path prefix (e.g. ``"encoder.layer"``) used to filter leaves.
        file: optional writable stream; when falsy the output goes to stdout.
    """
    for key_path, leaf in flatten_dict(flax_params).items():
        dotted = ".".join(key_path)
        if not dotted.startswith(prefix):
            continue
        print(f"{dotted} {leaf.shape} \n {leaf} \n \n", file=file or None)
def create_sinusoidal_positions(n_pos, dim):
    """Build a fixed sinusoidal position table of shape ``(n_pos, dim)`` as a JAX array.

    The raw angle for row ``pos`` and column ``j`` is ``pos / 10000 ** (2 * (j // 2) / dim)``.
    The output stores the sines of the even-indexed angle columns in its first
    ``ceil(dim / 2)`` columns and the cosines of the odd-indexed angle columns after them.
    """
    # Per-column denominators depend only on j, so compute them once.
    denominators = [np.power(10000, 2 * (j // 2) / dim) for j in range(dim)]
    angles = np.array([[pos / denom for denom in denominators] for pos in range(n_pos)])

    half = dim // 2 + dim % 2  # number of even-indexed columns == ceil(dim / 2)
    table = np.zeros_like(angles)
    table[:, :half] = np.sin(angles[:, ::2])
    table[:, half:] = np.cos(angles[:, 1::2])
    return jnp.array(table)

class LMCFlaxViTEmbeddings(nn.Module):
    """Turn pixels into ``[CLS] + patch`` tokens, optionally adding absolute position embeddings.

    ``config.position_embeddings`` selects the absolute scheme applied here:
    ``"learnable"`` adds a trained ``position_embeddings`` parameter, ``"sinusoidal"`` adds a
    fixed table from ``create_sinusoidal_positions``; any other value adds nothing in this module.
    """

    config: ViTConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        cfg = self.config
        scaled_init = jax.nn.initializers.variance_scaling(cfg.initializer_range**2, "fan_in", "truncated_normal")
        self.cls_token = self.param("cls_token", scaled_init, (1, 1, cfg.hidden_size))
        self.patch_embeddings = FlaxViTPatchEmbeddings(cfg, dtype=self.dtype)
        n_tokens = self.patch_embeddings.num_patches + 1  # patches plus the CLS token
        mode = cfg.position_embeddings
        if mode == "learnable":
            self.position_embeddings = self.param("position_embeddings", scaled_init, (1, n_tokens, cfg.hidden_size))
        elif mode == "sinusoidal":
            # Fixed (non-trainable) table with a leading broadcast axis for the batch.
            self.position_embeddings = create_sinusoidal_positions(n_tokens, cfg.hidden_size)[None, ...]
        self.dropout = nn.Dropout(rate=cfg.hidden_dropout_prob)

    def __call__(self, pixel_values, deterministic=True):
        patch_tokens = self.patch_embeddings(pixel_values)
        cls = jnp.broadcast_to(self.cls_token, (pixel_values.shape[0], 1, self.config.hidden_size))
        tokens = jnp.concatenate((cls, patch_tokens), axis=1)
        if self.config.position_embeddings in ("learnable", "sinusoidal"):
            tokens = tokens + self.position_embeddings
        return self.dropout(tokens, deterministic=deterministic)

class LMCFlaxViTSelfAttention(nn.Module):
    """Multi-head self-attention with optional rotary position embeddings (RoPE).

    When ``config.position_embeddings == "rope"`` and ``sinusoidal_pos`` is not ``None``,
    queries and keys (and values too, if ``config.rotary_value`` is truthy) are rotated with
    interleaved rotary embeddings before attention weights are computed.
    """

    config: ViTConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        if self.config.hidden_size % self.config.num_attention_heads != 0:
            # f-strings so the offending values actually appear in the message.
            raise ValueError(
                f"`config.hidden_size`: {self.config.hidden_size} has to be a multiple of `config.num_attention_heads`:"
                f" {self.config.num_attention_heads}"
            )
        self.query = nn.Dense(
            self.config.hidden_size,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.variance_scaling(
                self.config.initializer_range**2, mode="fan_in", distribution="truncated_normal"
            ),
            use_bias=self.config.qkv_bias,
        )
        self.key = nn.Dense(
            self.config.hidden_size,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.variance_scaling(
                self.config.initializer_range**2, mode="fan_in", distribution="truncated_normal"
            ),
            use_bias=self.config.qkv_bias,
        )
        self.value = nn.Dense(
            self.config.hidden_size,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.variance_scaling(
                self.config.initializer_range**2, mode="fan_in", distribution="truncated_normal"
            ),
            use_bias=self.config.qkv_bias,
        )
        if self.config.position_embeddings == "rope":
            # Whether values are rotated as well as queries/keys.
            self.rotary_value = self.config.rotary_value

    def __call__(self, hidden_states, sinusoidal_pos, deterministic: bool = True, output_attentions: bool = False):
        """Compute self-attention over ``hidden_states``.

        Args:
            hidden_states: 3-D input ``(batch, seq_len, features)``; projections are split into
                ``num_attention_heads`` heads of size ``hidden_size // num_attention_heads``.
            sinusoidal_pos: position table or ``None``. It only gates RoPE here: the rotary
                table is rebuilt for this layer's own ``head_dim`` and exact sequence length.
            deterministic: disables attention dropout when True.
            output_attentions: also return the attention weights.

        Returns:
            ``(attn_output,)`` or ``(attn_output, attn_weights)`` with ``attn_output`` of shape
            ``(batch, seq_len, hidden_size)``.
        """
        head_dim = self.config.hidden_size // self.config.num_attention_heads
        split_heads = hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
        query_states = self.query(hidden_states).reshape(split_heads)
        value_states = self.value(hidden_states).reshape(split_heads)
        key_states = self.key(hidden_states).reshape(split_heads)
        if sinusoidal_pos is not None and self.config.position_embeddings == "rope":
            # Build exactly seq_len rows for this layer's head_dim. A layer may be configured
            # differently from the module that built ``sinusoidal_pos``, and a table shorter than
            # the sequence would otherwise fail inside the rotary einsum.
            apply_sinusoidal_pos = create_sinusoidal_positions(hidden_states.shape[1], head_dim)
            if self.rotary_value:
                query_states, key_states, value_states = self.apply_rotary_position_embeddings(
                    apply_sinusoidal_pos, query_states, key_states, value_states
                )
            else:
                query_states, key_states = self.apply_rotary_position_embeddings(
                    apply_sinusoidal_pos, query_states, key_states
                )
        dropout_rng = None
        if not deterministic and self.config.attention_probs_dropout_prob > 0.0:
            dropout_rng = self.make_rng("dropout")
        attn_weights = dot_product_attention_weights(
            query_states,
            key_states,
            dropout_rng=dropout_rng,
            dropout_rate=self.config.attention_probs_dropout_prob,
            broadcast_dropout=True,
            deterministic=deterministic,
            dtype=self.dtype,
            precision=None,
        )
        attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
        attn_output = attn_output.reshape(attn_output.shape[:2] + (-1,))  # merge heads
        outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
        return outputs

    @staticmethod
    def apply_rotary_position_embeddings(sinusoidal_pos, query_layer, key_layer, value_layer=None):
        """Rotate query/key (and optionally value) features with interleaved rotary embeddings.

        Args:
            sinusoidal_pos: ``(seq_len, head_dim)`` table whose first half holds sines and second
                half cosines (the layout produced by ``create_sinusoidal_positions``); ``head_dim``
                must be even for the split below.
            query_layer, key_layer, value_layer: ``(batch, seq_len, heads, head_dim)`` tensors.

        Returns:
            ``(query, key)`` or ``(query, key, value)`` after rotation.
        """
        sin, cos = jnp.split(sinusoidal_pos, 2, axis=-1)
        # Repeat each frequency twice: [s0, s0, s1, s1, ...] so adjacent feature pairs share an angle.
        sin_pos = jnp.stack([sin, sin], axis=-1).reshape(sinusoidal_pos.shape)
        cos_pos = jnp.stack([cos, cos], axis=-1).reshape(sinusoidal_pos.shape)

        def rotate_layer(layer, sin_pos, cos_pos):
            # Pairwise rotation partner: [x0, x1, x2, x3, ...] -> [-x1, x0, -x3, x2, ...].
            rotate_half_layer = jnp.stack([-layer[..., 1::2], layer[..., ::2]], axis=-1).reshape(layer.shape)
            rotary_matrix_cos = jnp.einsum("bslh,...sh->bslh", layer, cos_pos)
            rotary_matrix_sin = jnp.einsum("bslh,...sh->bslh", rotate_half_layer, sin_pos)
            return rotary_matrix_cos + rotary_matrix_sin

        query_layer = rotate_layer(query_layer, sin_pos, cos_pos)
        key_layer = rotate_layer(key_layer, sin_pos, cos_pos)
        if value_layer is not None:
            value_layer = rotate_layer(value_layer, sin_pos, cos_pos)
            return query_layer, key_layer, value_layer
        return query_layer, key_layer
class LMCFlaxViTSelfOutput(nn.Module):
    """Attention output projection: ``Dense(hidden_size)`` followed by dropout.

    ``input_tensor`` is accepted for signature compatibility but is not used here; no residual
    connection is added inside this module.
    """

    config: ViTConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        scaled_init = jax.nn.initializers.variance_scaling(
            self.config.initializer_range**2, "fan_in", "truncated_normal"
        )
        self.dense = nn.Dense(self.config.hidden_size, kernel_init=scaled_init, dtype=self.dtype)
        self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)

    def __call__(self, hidden_states, input_tensor, deterministic: bool = True):
        projected = self.dense(hidden_states)
        return self.dropout(projected, deterministic=deterministic)
class LMCFlaxViTAttention(nn.Module):
    """Self-attention followed by its output projection (no residual, no layer norm).

    Returns ``(hidden_states,)`` or ``(hidden_states, attention_weights)`` when
    ``output_attentions`` is set.
    """

    config: ViTConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.attention = LMCFlaxViTSelfAttention(self.config, dtype=self.dtype)
        self.output = LMCFlaxViTSelfOutput(self.config, dtype=self.dtype)

    def __call__(self, hidden_states, sinusoidal_pos, deterministic=True, output_attentions: bool = False):
        self_attn = self.attention(
            hidden_states, sinusoidal_pos, deterministic=deterministic, output_attentions=output_attentions
        )
        projected = self.output(self_attn[0], hidden_states, deterministic=deterministic)
        if not output_attentions:
            return (projected,)
        return (projected, self_attn[1])


class LMCFlaxViTMLP(nn.Module):
    """Feed-forward expert: Dense(intermediate_size) -> activation -> Dense(hidden_size) -> dropout."""

    config: ViTConfig
    intermediate_size: int
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        cfg = self.config
        scaled_init = jax.nn.initializers.variance_scaling(cfg.initializer_range**2, "fan_in", "truncated_normal")
        self.intermediate = nn.Dense(self.intermediate_size, kernel_init=scaled_init, dtype=self.dtype)
        self.activation = ACT2FN[cfg.hidden_act]
        self.output = nn.Dense(cfg.hidden_size, kernel_init=scaled_init, dtype=self.dtype)
        self.dropout = nn.Dropout(rate=cfg.hidden_dropout_prob)

    def __call__(self, layer_output, deterministic: bool = True):
        expanded = self.activation(self.intermediate(layer_output))
        return self.dropout(self.output(expanded), deterministic=deterministic)

class LMCFlaxViTRouter(nn.Module):
    """Sigmoid top-k router for the MoE block, modelled on DeepSeek-V3 routing.

    Scores are ``sigmoid(x @ router_weight.T + router_bias)``. Expert *selection* uses
    ``scores + e_score_correction_bias`` in every path. The returned *weights* are always the
    uncorrected scores of the selected experts, so the correction bias only steers which experts
    are chosen and never scales their outputs.
    """

    config: ViTConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.n_routed_experts = self.config.num_routed_experts
        self.n_group = getattr(self.config, "n_group", 1)
        self.topk_group = getattr(self.config, "topk_group", 1)
        self.top_k = self.config.topk
        self.routed_scaling_factor = self.config.routed_scaling_factor
        self.norm_topk_prob = getattr(self.config, "norm_topk_prob", False)
        # Weight and bias for router computation
        kernel_init = jax.nn.initializers.normal(self.config.initializer_range)
        self.router_weight = self.param(
            "router_weight", kernel_init, (self.n_routed_experts, self.config.hidden_size)
        )
        self.router_bias = self.param(
            "router_bias", lambda rng, shape: jnp.zeros(shape, dtype=self.dtype), (self.n_routed_experts,)
        )
        self.e_score_correction_bias = self.param(
            "e_score_correction_bias", lambda rng, shape: jnp.zeros(shape, dtype=self.dtype), (self.n_routed_experts,)
        )

    def get_topk_indices(self, scores):
        """Group-limited expert selection.

        Args:
            scores: ``(n_tokens, n_routed_experts)`` uncorrected sigmoid scores.

        Returns:
            ``(topk_indices, topk_weights)``, each ``(n_tokens, top_k)``. Indices are chosen from
            the corrected scores restricted to the ``topk_group`` best groups. Weights are the
            uncorrected ``scores`` at those indices.
        """
        experts_per_group = self.n_routed_experts // self.n_group
        scores_for_choice = scores + self.e_score_correction_bias[None, :]
        # Rank groups by the sum of their two best corrected scores (needs >= 2 experts per group).
        top2_scores = jax.lax.top_k(scores_for_choice.reshape(-1, self.n_group, experts_per_group), 2)[0]
        group_scores = jnp.sum(top2_scores, axis=-1)  # (n_tokens, n_group)
        _, group_idx = jax.lax.top_k(group_scores, self.topk_group)
        rows = jnp.arange(group_scores.shape[0])[:, None]
        group_mask = jnp.zeros(group_scores.shape, dtype=bool).at[rows, group_idx].set(True)
        score_mask = jnp.repeat(group_mask[:, :, None], experts_per_group, axis=-1).reshape(
            -1, self.n_routed_experts
        )
        # Experts outside the selected groups are masked to a score of 0 before the final top-k.
        scores_for_choice = jnp.where(score_mask, scores_for_choice, jnp.zeros_like(scores_for_choice))
        _, topk_indices = jax.lax.top_k(scores_for_choice, self.top_k)
        topk_weights = jnp.take_along_axis(scores, topk_indices, axis=-1)
        return topk_indices, topk_weights

    def __call__(self, hidden_states):
        """Route tokens to experts.

        Args:
            hidden_states: ``(..., hidden_size)`` token features.

        Returns:
            ``(topk_indices, topk_weights)``, each ``(..., top_k)``. Weights are optionally
            normalised to sum to 1 (``norm_topk_prob``) and then scaled by
            ``routed_scaling_factor``.
        """
        router_logits = jnp.matmul(hidden_states, self.router_weight.T) + self.router_bias
        scores = jax.nn.sigmoid(router_logits)
        if self.n_group > 1:
            topk_indices, topk_weights = self.get_topk_indices(scores)
        else:
            # Same contract as the grouped path: corrected scores select, raw scores weight.
            _, topk_indices = jax.lax.top_k(scores + self.e_score_correction_bias, self.top_k)
            topk_weights = jnp.take_along_axis(scores, topk_indices, axis=-1)
        if self.norm_topk_prob:
            denominator = jnp.sum(topk_weights, axis=-1, keepdims=True) + 1e-20
            topk_weights = topk_weights / denominator
        topk_weights = topk_weights * self.routed_scaling_factor
        return topk_indices, topk_weights

class LMCFlaxViTMoE(nn.Module):
    """Mixture-of-experts feed-forward block with optional shared experts.

    There are ``num_routed_experts`` LMCFlaxViTMLP experts; each token's output is the
    router-weighted sum of its ``topk`` selected experts. If ``num_shared_experts > 0``, one
    wider MLP (intermediate size ``intermediate_size * num_shared_experts``) is applied to every
    token and its output is added.
    """

    config: ViTConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        self.n_experts = self.config.num_routed_experts
        self.n_shared_experts = self.config.num_shared_experts
        self.top_k = self.config.topk
        self.router = LMCFlaxViTRouter(config=self.config, dtype=self.dtype)
        self.experts = [
            LMCFlaxViTMLP(config=self.config, intermediate_size=self.config.intermediate_size, dtype=self.dtype,)
            for _ in range(self.n_experts)
        ]
        if self.n_shared_experts > 0:
            self.shared_experts = LMCFlaxViTMLP(config=self.config, intermediate_size=self.config.intermediate_size * self.n_shared_experts, dtype=self.dtype)

    def __call__(self, hidden_states, deterministic: bool = True):
        """Apply the MoE feed-forward to ``hidden_states`` (``(..., hidden_size)``), preserving its shape.

        Every routed expert is evaluated densely on every token. Routing only decides each
        expert's per-token weight, which is 0 for experts a token was not assigned to.

        Raises:
            ValueError: if there is nothing to compute, i.e. no routed experts (or ``topk == 0``)
                and no shared experts.
        """
        orig_shape = hidden_states.shape
        if self.n_experts == 0 or self.top_k == 0:
            # No routing is possible, so the block reduces to the shared experts alone.
            if self.n_shared_experts <= 0:
                raise ValueError(
                    "LMCFlaxViTMoE needs routed experts with topk > 0 or at least one shared expert; "
                    f"got num_routed_experts={self.n_experts}, topk={self.top_k}, "
                    f"num_shared_experts={self.n_shared_experts}"
                )
            return self.shared_experts(hidden_states, deterministic=deterministic)

        hidden_states_flat = hidden_states.reshape(-1, orig_shape[-1])
        topk_indices, topk_weights = self.router(hidden_states_flat)  # each (n_tokens, top_k)
        expert_outputs = jnp.stack(
            [expert(hidden_states_flat, deterministic=deterministic) for expert in self.experts],
            axis=1,
        )  # (n_tokens, n_experts, hidden)
        # Scatter each token's top-k weights into a dense (n_tokens, n_experts) matrix. Unselected
        # experts get weight 0 here, so no separate routing mask is needed.
        weights_per_expert = (
            jax.nn.one_hot(topk_indices, self.n_experts, dtype=self.dtype) * topk_weights[..., None]
        ).sum(axis=1)
        final_output = (expert_outputs * weights_per_expert[..., None]).sum(axis=1).reshape(orig_shape)
        if self.n_shared_experts > 0:
            final_output = final_output + self.shared_experts(hidden_states, deterministic=deterministic)
        return final_output
class LMCFlaxViTLayer(nn.Module):
    """Pre-norm transformer block: LN -> attention -> residual, then LN -> MoE -> residual.

    If both ``attention_input`` (a caller-owned dict) and ``idx`` are given, the normalised input
    to attention is stored in ``attention_input[idx]`` as a side effect.
    NOTE(review): storing values into a Python dict from inside a transformed/jitted call
    will capture tracers — confirm this is only used in eager/debug runs.
    """

    config: ViTConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        self.attention = LMCFlaxViTAttention(self.config, dtype=self.dtype)
        self.moe = LMCFlaxViTMoE(self.config, dtype=self.dtype)
        self.layernorm_before = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
        self.layernorm_after = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)

    def __call__(self, hidden_states, sinusoidal_pos, deterministic: bool = True, output_attentions: bool = False, attention_input: dict | None = None, idx: int | None = None):
        normed = self.layernorm_before(hidden_states)
        if attention_input is not None and idx is not None:
            attention_input[idx] = normed
        attn_outs = self.attention(
            normed, sinusoidal_pos, deterministic=deterministic, output_attentions=output_attentions,
        )
        after_attn = attn_outs[0] + hidden_states  # first residual
        moe_out = self.moe(self.layernorm_after(after_attn), deterministic=deterministic)
        block_out = moe_out + after_attn  # second residual
        if output_attentions:
            return (block_out, attn_outs[1])
        return (block_out,)

class LMCFlaxViTLayerCollection(nn.Module):
    """Stack of ``config.num_hidden_layers`` LMCFlaxViTLayer blocks named ``"0"``, ``"1"``, ...

    Layers whose index is in ``config.lmc_layer_indices`` are built from ``config.lmc_config``;
    all other layers use ``config``.
    """

    config: ViTConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        self.layers = [
            LMCFlaxViTLayer(self.config.lmc_config, name=str(i), dtype=self.dtype) if i in self.config.lmc_layer_indices
            else LMCFlaxViTLayer(self.config, name=str(i), dtype=self.dtype)
            for i in range(self.config.num_hidden_layers)
        ]

    def __call__(
        self,
        hidden_states,
        sinusoidal_pos,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        attention_input: dict | None = None,
    ):
        """Run all layers in order.

        Returns:
            With ``return_dict=True``, a ``FlaxBaseModelOutput``. Otherwise a tuple
            ``(last_hidden_state, all_hidden_states, all_attentions)`` with unrequested
            (``None``) entries dropped. ``all_hidden_states`` holds each layer's input plus the
            final output.
        """
        all_attentions = () if output_attentions else None
        all_hidden_states = () if output_hidden_states else None

        for i, layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)
            # The layer only records into attention_input when both it and idx are not None.
            layer_outputs = layer(
                hidden_states,
                sinusoidal_pos,
                deterministic=deterministic,
                output_attentions=output_attentions,
                attention_input=attention_input,
                idx=i,
            )
            hidden_states = layer_outputs[0]
            if output_attentions:
                all_attentions += (layer_outputs[1],)

        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        # Include the optional collections so tuple outputs carry them too.
        outputs = (hidden_states, all_hidden_states, all_attentions)
        if not return_dict:
            return tuple(v for v in outputs if v is not None)

        return FlaxBaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
        )
class LMCFlaxViTEncoder(nn.Module):
    """Transformer encoder: builds the rotary position table and runs the layer stack."""

    config: ViTConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        self.layer = LMCFlaxViTLayerCollection(self.config, dtype=self.dtype)

    def __call__(
        self,
        hidden_states,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        attention_input: dict | None = None,
    ):
        """Encode ``hidden_states`` (``(batch, seq_len, hidden_size)`` token embeddings).

        The sinusoidal table is built for exactly ``seq_len`` positions with one row per token
        and ``hidden_size // num_attention_heads`` columns. This replaces a precomputed table
        sized with the incorrect patch count ``image_size**2 // patch_size``; the rows are
        identical, and any sequence length is supported.
        """
        head_dim = self.config.hidden_size // self.config.num_attention_heads
        sinusoidal_pos = create_sinusoidal_positions(hidden_states.shape[1], head_dim)
        return self.layer(
            hidden_states,
            sinusoidal_pos,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            attention_input=attention_input,
        )


class LMCFlaxViTPreTrainedModel(FlaxPreTrainedModel):
    """
    Abstract ``FlaxPreTrainedModel`` wrapper for the LMC ViT modules: weight initialisation and a
    simple interface for downloading and loading pretrained models. Subclasses set ``module_class``.
    """

    config_class = ViTConfig
    base_model_prefix = "vit"
    main_input_name = "pixel_values"
    module_class: nn.Module = None

    def __init__(
        self,
        config: ViTConfig,
        input_shape=None,
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
        _do_init: bool = True,
        **kwargs,
    ):
        module = self.module_class(config=config, dtype=dtype, **kwargs)
        # Default dummy input is a single channels-last image.
        shape = (
            input_shape
            if input_shape is not None
            else (1, config.image_size, config.image_size, config.num_channels)
        )
        super().__init__(config, module, input_shape=shape, seed=seed, dtype=dtype, _do_init=_do_init)

    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
        """Initialise parameters; if ``params`` is given, only fill in ``self._missing_keys`` from a fresh init."""
        dummy_pixels = jnp.zeros(input_shape, dtype=self.dtype)
        params_rng, dropout_rng = jax.random.split(rng)
        fresh = self.module.init(
            {"params": params_rng, "dropout": dropout_rng}, dummy_pixels, return_dict=False
        )["params"]

        if params is None:
            return fresh

        flat_fresh = flatten_dict(unfreeze(fresh))
        flat_given = flatten_dict(unfreeze(params))
        for key in self._missing_keys:
            flat_given[key] = flat_fresh[key]
        self._missing_keys = set()
        return freeze(unflatten_dict(flat_given))

    def __call__(
        self,
        pixel_values,
        params: Optional[dict] = None,
        dropout_rng: jax.random.PRNGKey = None,
        train: bool = False,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        attention_input: dict | None = None,
    ):
        """Run the module on channels-first ``pixel_values`` (transposed to channels-last internally)."""
        cfg = self.config
        if output_attentions is None:
            output_attentions = cfg.output_attentions
        if output_hidden_states is None:
            output_hidden_states = cfg.output_hidden_states
        if return_dict is None:
            return_dict = cfg.return_dict

        pixels_nhwc = jnp.transpose(pixel_values, (0, 2, 3, 1))
        rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}

        return self.module.apply(
            {"params": params or self.params},
            jnp.array(pixels_nhwc, dtype=jnp.float32),
            not train,  # deterministic
            output_attentions,
            output_hidden_states,
            return_dict,
            attention_input=attention_input,
            rngs=rngs,
        )


class LMCFlaxViTModule(nn.Module):
    """ViT backbone: embeddings -> encoder -> final layer norm -> optional pooler."""

    config: ViTConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
    add_pooling_layer: bool = True

    def setup(self):
        self.embeddings = LMCFlaxViTEmbeddings(self.config, dtype=self.dtype)
        self.encoder = LMCFlaxViTEncoder(self.config, dtype=self.dtype)
        self.layernorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
        self.pooler = FlaxViTPooler(self.config, dtype=self.dtype) if self.add_pooling_layer else None

    def __call__(
        self,
        pixel_values,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        attention_input: dict | None = None,
    ):
        """Return the normalised sequence output plus pooled output (when pooling is enabled).

        Tuple form: ``(sequence_output[, pooled], *encoder_extras)``.
        """
        embedded = self.embeddings(pixel_values, deterministic=deterministic)
        encoder_outputs = self.encoder(
            embedded,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            attention_input=attention_input,
        )
        sequence_output = self.layernorm(encoder_outputs[0])
        pooled = self.pooler(sequence_output) if self.add_pooling_layer else None

        if return_dict:
            return FlaxBaseModelOutputWithPooling(
                last_hidden_state=sequence_output,
                pooler_output=pooled,
                hidden_states=encoder_outputs.hidden_states,
                attentions=encoder_outputs.attentions,
            )

        head = (sequence_output,) if pooled is None else (sequence_output, pooled)
        return head + encoder_outputs[1:]



class LMCFlaxViTModel(LMCFlaxViTPreTrainedModel):
    """Pretrained-model wrapper around the bare ``LMCFlaxViTModule`` backbone."""

    module_class = LMCFlaxViTModule

class LMCFlaxViTForImageClassificationModule(nn.Module):
    """ViT backbone (no pooler) with a linear classifier on the [CLS] token."""

    config: ViTConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.vit = LMCFlaxViTModule(config=self.config, dtype=self.dtype, add_pooling_layer=False)
        self.classifier = nn.Dense(
            self.config.num_labels,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.variance_scaling(
                self.config.initializer_range**2, "fan_in", "truncated_normal"
            ),
        )

    def __call__(
        self,
        pixel_values=None,
        deterministic: bool = True,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        attention_input: dict | None = None,
    ):
        """Return classification logits of shape ``(batch, num_labels)``.

        Tuple form (``return_dict=False``): ``(logits, *backbone_extras)``, where the extras are
        whatever hidden states / attentions the backbone returned.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.vit(
            pixel_values,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            attention_input=attention_input,
        )

        hidden_states = outputs[0]
        # Classify from the [CLS] token at sequence position 0.
        logits = self.classifier(hidden_states[:, 0, :])

        if not return_dict:
            # The backbone is built with add_pooling_layer=False, so its tuple is
            # (last_hidden_state, *extras): the extras start at index 1, not 2.
            return (logits,) + outputs[1:]

        return FlaxSequenceClassifierOutput(
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
class LMCFlaxViTForImageClassification(LMCFlaxViTPreTrainedModel):
    """Pretrained-model wrapper around ``LMCFlaxViTForImageClassificationModule``."""

    module_class = LMCFlaxViTForImageClassificationModule
