import torch.nn.functional as F
import torch.nn as nn
import torch
from utils.constants import MODEL_VOCAB_SIZES, FEATURE_DIMS
from einops import repeat
from vit_pytorch import ViT
from utils.Architectures_utils import *
from matplotlib import pyplot as plt
import math

def get_model(args, max_sequence_length, actual_sequence_length, input_dim, input_shape):
    """Instantiate the probe model selected by ``args.probe_model``.

    Each model family takes a different set of constructor arguments, so every
    registry entry pairs the class with the keyword arguments it is built with.

    Args:
        args: Configuration namespace; ``args.probe_model`` selects the model
            and the whole namespace is forwarded to the model constructor.
        max_sequence_length: Forwarded to the sequence (Transformer) LOS models.
        actual_sequence_length: Forwarded to ``ATP_R_MLP``.
        input_dim: Forwarded to the sequence and activation models.
        input_shape: Forwarded to the activation models.

    Returns:
        The constructed ``nn.Module``.

    Raises:
        ValueError: If ``args.probe_model`` is not a registered model name.
    """
    sequence_kwargs = {
        'args': args,
        'max_sequence_length': max_sequence_length,
        'input_dim': input_dim,
    }
    flat_sequence_kwargs = {
        'args': args,
        'actual_sequence_length': actual_sequence_length,
    }
    activation_kwargs = {
        'args': args,
        'input_dim': input_dim,
        'input_shape': input_shape,
    }

    registry = {
        # LOS-based
        'LOS-Net': (LOS_Net, sequence_kwargs),
        'ATP_R_Transf': (ATP_R_Transf, sequence_kwargs),
        'ATP_R_MLP': (ATP_R_MLP, flat_sequence_kwargs),
        # activations-based
        'ACT-Vit': (ACT_Vit, activation_kwargs),
        'ACT-Vit-with-symmetries': (ACT_Vit_with_symmetries, activation_kwargs),
        'ACT-Vit-with-symmetries-V2': (ACT_Vit_with_symmetries_V2, activation_kwargs),
        'ACT-MLP': (ACT_MLP, activation_kwargs),
        'ACT-Vit-foundation': (ACT_Vit_foundation, activation_kwargs),
        'ACT-MLP-foundation': (ACT_MLP_Foundation, activation_kwargs),
    }

    try:
        model_cls, kwargs = registry[args.probe_model]
    except KeyError:
        # Single failure mode for every unknown name (including near-miss
        # spellings such as 'ACT-ViT'), instead of a bare KeyError.
        raise ValueError(f"Unknown model: {args.probe_model}") from None
    return model_cls(**kwargs)
    

######################## LOS ########################
class ATP_R_MLP(nn.Module):
    """MLP probe over per-token ``normalized_ATP`` values and their ranks.

    Every token gets a ``hidden_dim`` feature: ``normalized_ATP`` scaled by a
    learned vector, plus a rank encoding chosen by ``args.rank_encoding``
    (``'scale_encoding'`` or ``'one_hot_encoding'``). The per-token features are
    flattened over the sequence and passed through ``args.num_layers`` linear
    layers (BatchNorm + ReLU + dropout between them), followed by a sigmoid.

    Args:
        args: Configuration namespace; reads ``hidden_dim``, ``dropout``,
            ``num_layers``, ``rank_encoding`` and ``LLM``.
        actual_sequence_length: Number of tokens per sample; the first linear
            layer expects ``hidden_dim * actual_sequence_length`` inputs.

    Raises:
        ValueError: If ``args.rank_encoding`` is not a supported value.
    """

    def __init__(self, args, actual_sequence_length):
        super(ATP_R_MLP, self).__init__()
        self.args = args
        self.hidden_dim = args.hidden_dim
        self.dropout = args.dropout
        self.num_layers = args.num_layers
        self.actual_sequence_length = actual_sequence_length

        # Learned direction that lifts the scalar normalized_ATP to hidden_dim.
        self.param_for_normalized_ATP = nn.Parameter(torch.randn(1, 1, self.hidden_dim))
        if self.args.rank_encoding == 'scale_encoding':
            self.param_for_ATP_R = nn.Parameter(torch.randn(1, 1, self.hidden_dim))
        elif self.args.rank_encoding == 'one_hot_encoding':
            # One learned vector per possible rank value (vocabulary-sized table).
            self.one_hot_embedding = nn.Embedding(
                MODEL_VOCAB_SIZES[self.args.LLM],
                self.hidden_dim,
            )
        else:
            raise ValueError("Invalid encoding type. Please choose either 'scale_encoding' or 'one_hot_encoding'.")

        # Linear stack: the first layer consumes the flattened sequence, the
        # last layer emits a single logit; BatchNorm only between hidden layers.
        self.lin_layers = nn.ModuleList()
        self.batch_norms = nn.ModuleList()
        for i in range(self.num_layers):
            in_dim = self.hidden_dim if i > 0 else self.hidden_dim * self.actual_sequence_length
            out_dim = self.hidden_dim if (i + 1) < self.num_layers else 1
            self.lin_layers.append(nn.Linear(in_dim, out_dim))
            if (i + 1) < self.num_layers:
                self.batch_norms.append(nn.BatchNorm1d(out_dim))

        # Output activation
        self.sigmoid = nn.Sigmoid()

    def compute_encoded_ATP_R(self, normalized_ATP, ATP_R):
        """Scale-encode the rank and combine it with ``normalized_ATP``.

        ``ATP_R`` is mapped linearly to ``2 * (0.5 - ATP_R / vocab_size)``
        (rank 0 -> 1, rank ``vocab_size`` -> -1), then multiplied by
        ``normalized_ATP`` and the learned ``param_for_ATP_R`` vector.
        """
        encoded_ATP_R = 2 * (0.5 - (ATP_R / MODEL_VOCAB_SIZES[self.args.LLM]))

        return normalized_ATP * encoded_ATP_R.unsqueeze(-1) * self.param_for_ATP_R

    def forward(self, sorted_TDS_normalized, normalized_ATP, ATP_R):
        """Return one probability per sample.

        Args:
            sorted_TDS_normalized: Not used by this model; accepted so all LOS
                models share one call signature.
            normalized_ATP: Per-token scalar feature; assumed shape
                ``[B, N, 1]`` so it broadcasts against ``[1, 1, hidden_dim]``
                -- TODO confirm against the data loader.
            ATP_R: Per-token rank; assumed shape ``[B, N]`` (integer dtype is
                required for ``'one_hot_encoding'``).

        Returns:
            Tensor of sigmoid outputs with the trailing singleton dim squeezed.

        Raises:
            ValueError: If ``args.rank_encoding`` is not a supported value.
        """
        # Rank encoding
        if self.args.rank_encoding == 'scale_encoding':
            encoded_ATP_R = self.compute_encoded_ATP_R(normalized_ATP=normalized_ATP, ATP_R=ATP_R)
        elif self.args.rank_encoding == 'one_hot_encoding':
            encoded_ATP_R = normalized_ATP * self.one_hot_embedding(ATP_R)
        else:
            raise ValueError("Invalid encoding type. Please choose either 'scale_encoding' or 'one_hot_encoding'.")

        # Encoding of the normalized ATP value itself
        encoded_normalized_ATP = normalized_ATP * self.param_for_normalized_ATP
        x = encoded_ATP_R + encoded_normalized_ATP
        x = x.flatten(start_dim=1)

        for i in range(self.num_layers):
            x = self.lin_layers[i](x)
            if (i + 1) < self.num_layers:
                x = self.batch_norms[i](x)
                x = F.relu(x)
                # F.dropout defaults to training=True; tie it to the module's
                # mode so dropout is disabled after model.eval().
                x = F.dropout(x, p=self.dropout, training=self.training)

        return self.sigmoid(x).squeeze(-1)  # Apply sigmoid for binary classification


class ATP_R_Transf(nn.Module):
    """Transformer-encoder probe over per-token ``normalized_ATP`` and rank.

    Token features are built exactly as a learned scaling of ``normalized_ATP``
    plus a rank encoding (``args.rank_encoding``), a CLS token is prepended,
    learned positional embeddings are added, and the sequence runs through
    ``args.num_layers`` ``nn.TransformerEncoderLayer`` blocks. The pooled
    representation (``args.pool``: ``'cls'`` or ``'mean'``) is mapped to a
    single sigmoid output.

    Args:
        args: Configuration namespace; reads ``hidden_dim``, ``heads``,
            ``dropout``, ``num_layers``, ``pool``, ``rank_encoding``, ``LLM``.
        max_sequence_length: Size of the positional table minus the CLS slot;
            input sequences must not be longer than this.
        input_dim: Stored on the instance; not otherwise used by this class.
    """

    def __init__(self, args, max_sequence_length, input_dim=1):
        super().__init__()
        self.args = args
        self.input_dim = input_dim
        self.max_sequence_length = max_sequence_length
        self.hidden_dim = args.hidden_dim
        self.heads = args.heads
        self.dropout = args.dropout
        self.num_layers = args.num_layers
        self.pool = args.pool
        assert self.pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        width = self.hidden_dim
        rank_mode = self.args.rank_encoding

        # Learned lift of the scalar ATP value, then the rank encoder.
        self.param_for_normalized_ATP = nn.Parameter(torch.randn(1, 1, width))
        if rank_mode == 'scale_encoding':
            self.param_for_ATP_R = nn.Parameter(torch.randn(1, 1, width))
        elif rank_mode == 'one_hot_encoding':
            self.one_hot_embedding = nn.Embedding(MODEL_VOCAB_SIZES[self.args.LLM], width)
        else:
            raise ValueError("Invalid encoding type. Please choose either 'scale_encoding' or 'one_hot_encoding'.")

        # Learned CLS token and positional table (+1 slot for the CLS position).
        self.cls_token = nn.Parameter(torch.randn(1, 1, width))
        self.pos_embedding = nn.Embedding(self.max_sequence_length + 1, width)

        # Encoder stack; the feed-forward width equals the model width.
        encoder_blocks = []
        for _ in range(self.num_layers):
            encoder_blocks.append(
                nn.TransformerEncoderLayer(
                    d_model=width,
                    nhead=self.heads,
                    dropout=self.dropout,
                    dim_feedforward=width,
                    batch_first=True,
                )
            )
        self.attention_layers = nn.ModuleList(encoder_blocks)

        # Classification head
        self.mlp_head = nn.Linear(width, 1)
        self.sigmoid = nn.Sigmoid()

    def compute_encoded_ATP_R(self, normalized_ATP, ATP_R):
        """Return ``normalized_ATP`` times the scale-encoded rank times ``param_for_ATP_R``.

        The rank is mapped linearly to ``2 * (0.5 - ATP_R / vocab_size)``.
        """
        vocab_size = MODEL_VOCAB_SIZES[self.args.LLM]
        rank_scale = 2 * (0.5 - (ATP_R / vocab_size))
        return normalized_ATP * rank_scale.unsqueeze(-1) * self.param_for_ATP_R

    def forward(self, sorted_TDS_normalized, normalized_ATP, ATP_R):
        """Return one probability per sample.

        ``sorted_TDS_normalized`` is accepted for signature parity with the
        other LOS models and is not used here. ``normalized_ATP`` is assumed to
        be ``[B, N, 1]`` and ``ATP_R`` ``[B, N]`` -- TODO confirm with caller.
        """
        rank_mode = self.args.rank_encoding
        if rank_mode == 'scale_encoding':
            rank_term = self.compute_encoded_ATP_R(normalized_ATP=normalized_ATP, ATP_R=ATP_R)
        elif rank_mode == 'one_hot_encoding':
            rank_term = normalized_ATP * self.one_hot_embedding(ATP_R)
        else:
            raise ValueError("Invalid encoding type. Please choose either 'scale_encoding' or 'one_hot_encoding'.")

        atp_term = normalized_ATP * self.param_for_normalized_ATP
        tokens = rank_term + atp_term

        # Prepend the CLS token (the unpack also enforces a 3-D token tensor).
        batch_size, seq_len, _ = tokens.shape
        cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b=batch_size)
        tokens = torch.cat((cls_tokens, tokens), dim=1)

        # Positions 0..seq_len, position 0 being the CLS slot.
        positions = torch.arange(seq_len + 1, device=tokens.device).unsqueeze(0)
        tokens = tokens + self.pos_embedding(positions)

        for encoder_block in self.attention_layers:
            tokens = encoder_block(tokens)

        # Pool either over all positions or by taking the CLS position.
        if self.pool == 'mean':
            pooled = tokens.mean(dim=1)
        else:
            pooled = tokens[:, 0]

        logits = self.mlp_head(pooled)
        return self.sigmoid(logits).squeeze(-1)
    

class LOS_Net(nn.Module):
    """Transformer-encoder probe over ``sorted_TDS_normalized``, ``normalized_ATP`` and rank.

    Each token is the concatenation of two ``hidden_dim // 2`` halves:
    a linear projection of ``sorted_TDS_normalized`` and the sum of the
    ``normalized_ATP`` encoding and the rank encoding (``args.rank_encoding``).
    A CLS token is prepended, learned positional embeddings are added, the
    sequence runs through ``args.num_layers`` ``nn.TransformerEncoderLayer``
    blocks, and the pooled vector (``args.pool``) is mapped to a sigmoid output.

    Args:
        args: Configuration namespace; reads ``hidden_dim``, ``heads``,
            ``dropout``, ``num_layers``, ``pool``, ``rank_encoding``, ``LLM``.
        max_sequence_length: Size of the positional table minus the CLS slot;
            input sequences must not be longer than this.
        input_dim: Size of the last dimension of ``sorted_TDS_normalized``.

    Raises:
        ValueError: If ``args.hidden_dim`` is odd or ``args.rank_encoding`` is
            not a supported value.
    """

    def __init__(self, args, max_sequence_length, input_dim=1):
        super().__init__()

        self.args = args
        self.max_sequence_length = max_sequence_length
        self.input_dim = input_dim
        self.hidden_dim = args.hidden_dim
        self.heads = args.heads
        self.dropout = args.dropout
        self.num_layers = args.num_layers
        self.pool = args.pool

        assert self.pool in {'cls', 'mean'}, "Pool type must be either 'cls' (CLS token) or 'mean' (mean pooling)"

        # Tokens are two concatenated hidden_dim // 2 halves while the CLS token
        # is hidden_dim wide; with an odd hidden_dim the concat in forward()
        # would fail, so reject it up front.
        if self.hidden_dim % 2 != 0:
            raise ValueError(f"hidden_dim must be even for LOS_Net, got {self.hidden_dim}.")
        half_dim = self.hidden_dim // 2

        # Learned direction that lifts the scalar normalized_ATP to half_dim.
        self.param_for_normalized_ATP = nn.Parameter(torch.randn(1, 1, half_dim))
        if self.args.rank_encoding == 'scale_encoding':
            self.param_for_ATP_R = nn.Parameter(torch.randn(1, 1, half_dim))
        elif self.args.rank_encoding == 'one_hot_encoding':
            # One learned vector per possible rank value (vocabulary-sized table).
            self.one_hot_embedding = nn.Embedding(
                MODEL_VOCAB_SIZES[self.args.LLM],
                half_dim,
            )
        else:
            raise ValueError("Invalid encoding type. Please choose either 'scale_encoding' or 'one_hot_encoding'.")

        # Input embedding layer
        self.input_proj = nn.Linear(input_dim, half_dim)

        # CLS token
        self.cls_token = nn.Parameter(torch.randn(1, 1, self.hidden_dim))

        # Positional embeddings (+1 slot for the CLS position)
        self.pos_embedding = nn.Embedding(self.max_sequence_length + 1, self.hidden_dim)

        # Transformer encoder layers
        self.attention_layers = nn.ModuleList([
            nn.TransformerEncoderLayer(
                d_model=self.hidden_dim,
                nhead=self.heads,
                dropout=self.dropout,
                dim_feedforward=self.hidden_dim,
                batch_first=True
            ) for _ in range(self.num_layers)
        ])

        # Classification head
        self.mlp_head = nn.Linear(self.hidden_dim, 1)
        self.sigmoid = nn.Sigmoid()

    def compute_encoded_ATP_R(self, normalized_ATP, ATP_R):
        """Scale-encode the rank and combine it with ``normalized_ATP``.

        ``ATP_R`` is mapped linearly to ``2 * (0.5 - ATP_R / vocab_size)``
        (rank 0 -> 1, rank ``vocab_size`` -> -1), then multiplied by
        ``normalized_ATP`` and the learned ``param_for_ATP_R`` vector.
        """
        encoded_ATP_R = 2 * (0.5 - (ATP_R / MODEL_VOCAB_SIZES[self.args.LLM]))
        return normalized_ATP * encoded_ATP_R.unsqueeze(-1) * self.param_for_ATP_R

    def forward(self, sorted_TDS_normalized, normalized_ATP, ATP_R):
        """
        Forward pass for LOS_Net.

        Args:
            sorted_TDS_normalized (torch.Tensor): Shape [B, N, input_dim].
            normalized_ATP (torch.Tensor): Shape [B, N, 1].
            ATP_R (torch.Tensor): Shape [B, N].

        Returns:
            torch.Tensor: Sigmoid probabilities of shape [B] (the trailing
            singleton dimension of the head output is squeezed).

        Raises:
            ValueError: If ``args.rank_encoding`` is not a supported value.
        """
        # Rank encoding
        if self.args.rank_encoding == 'scale_encoding':
            encoded_ATP_R = self.compute_encoded_ATP_R(normalized_ATP=normalized_ATP, ATP_R=ATP_R)
        elif self.args.rank_encoding == 'one_hot_encoding':
            encoded_ATP_R = normalized_ATP * self.one_hot_embedding(ATP_R)
        else:
            raise ValueError("Invalid encoding type. Please choose either 'scale_encoding' or 'one_hot_encoding'.")

        # Encoding of the normalized ATP value itself
        encoded_normalized_ATP = normalized_ATP * self.param_for_normalized_ATP

        # Projection of the normalized, sorted TDS features
        encoded_sorted_TDS_normalized = self.input_proj(sorted_TDS_normalized.to(torch.float32))

        # Concatenate the two halves into hidden_dim-wide tokens
        x = torch.cat((encoded_sorted_TDS_normalized, encoded_ATP_R + encoded_normalized_ATP), dim=-1)

        # Prepend the CLS token
        b, n, _ = x.shape
        cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b=b)
        x = torch.cat((cls_tokens, x), dim=1)

        # Positional embeddings for positions 0..n (0 is the CLS slot)
        pos_indices = torch.arange(n + 1, device=x.device).unsqueeze(0)
        x += self.pos_embedding(pos_indices)

        # Transformer layers
        for layer in self.attention_layers:
            x = layer(x)

        # Pooling: mean over all positions, or the CLS position
        x = x.mean(dim=1) if self.pool == 'mean' else x[:, 0]

        # Classification head
        x = self.mlp_head(x)
        return self.sigmoid(x).squeeze(-1)
   
######################## LOS ########################


######################## Activations ########################


class ACT_Vit_with_symmetries(nn.Module):
    """Activation probe that feeds the raw activation grid to ``SdInvVit``.

    Args:
        args: Configuration namespace; reads ``hidden_dim``, ``heads``,
            ``dropout``, ``num_layers``, ``patch_size`` (a string) and ``pool``.
        input_dim: Stored on the instance; not otherwise used by this class.
        input_shape: Three-element shape whose first two entries are used as
            the image height and width.
    """

    def __init__(self, args, input_dim, input_shape):
        super().__init__()

        self.input_dim = input_dim
        self.input_shape = input_shape
        self.hidden_dim = args.hidden_dim
        self.heads = args.heads
        self.dropout = args.dropout
        self.num_layers = args.num_layers
        # NOTE(review): eval() executes whatever string is configured as
        # patch_size; if only literals such as "(2, 2)" are expected,
        # ast.literal_eval would be the safer choice.
        self.patch_size = eval(args.patch_size)

        # The first two entries of input_shape act as image height / width.
        grid_h, grid_w, _ = self.input_shape

        assert grid_h % self.patch_size[0] == 0 and grid_w % self.patch_size[1] == 0, "Patch size must be divisible by input shape"

        # Backbone configuration, spelled out once and splatted below.
        backbone_config = {
            'image_size': (grid_h, grid_w),
            'patch_size': self.patch_size,
            'num_classes': 1,
            'dim': self.hidden_dim,
            'depth': self.num_layers,
            'heads': self.heads,
            'mlp_dim': self.hidden_dim,
            'dropout': self.dropout,
            'emb_dropout': self.dropout,
            'channels': self.hidden_dim,
            'pool': args.pool,
        }
        self.vit = SdInvVit(**backbone_config)
        self.sigmoid = nn.Sigmoid()

    def forward(self, activations):
        """Return the sigmoid of the backbone output with the last dim squeezed."""
        grid = activations.to(torch.float)
        # Unpacking enforces a 4-D input, laid out as (B, L, N, d).
        _batch, _rows, _cols, _feat = grid.shape

        # Channels-first layout for the backbone: (B, d, L, N).
        channels_first = grid.permute(0, 3, 1, 2)

        logits = self.vit(channels_first)
        return self.sigmoid(logits).squeeze(-1)

class ACT_Vit_with_symmetries_V2(ACT_Vit_with_symmetries):
    """``ACT_Vit_with_symmetries`` whose backbone is rebuilt with ``DS_model=DSS``.

    The parent constructor runs first (and builds its own backbone); this
    class then replaces ``self.vit`` with a new ``SdInvVit`` that has the same
    configuration plus the ``DS_model`` argument.
    """

    def __init__(self, args, input_dim, input_shape):
        super().__init__(args, input_dim, input_shape)

        grid_h, grid_w, _ = self.input_shape

        # Same configuration as the parent's backbone, plus DS_model.
        backbone_config = {
            'image_size': (grid_h, grid_w),
            'patch_size': self.patch_size,
            'num_classes': 1,
            'dim': self.hidden_dim,
            'depth': self.num_layers,
            'heads': self.heads,
            'mlp_dim': self.hidden_dim,
            'dropout': self.dropout,
            'emb_dropout': self.dropout,
            'channels': self.hidden_dim,
            'pool': args.pool,
            'DS_model': DSS,
        }
        self.vit = SdInvVit(**backbone_config)


class ACT_Vit(nn.Module):
    """Activation probe: linear projection to ``hidden_dim`` followed by a ViT.

    The ViT's ``to_patch_embedding`` is replaced with ``PatchEmbeddingWithPos``.

    Args:
        args: Configuration namespace; reads ``hidden_dim``, ``heads``,
            ``dropout``, ``num_layers``, ``patch_size`` (a string) and ``pool``.
        input_dim: Size of the last dimension of the activations.
        input_shape: Three-element shape whose first two entries are used as
            the image height and width.
    """

    def __init__(self, args, input_dim, input_shape):
        super().__init__()

        self.input_dim = input_dim
        self.input_shape = input_shape
        self.hidden_dim = args.hidden_dim
        self.heads = args.heads
        self.dropout = args.dropout
        self.num_layers = args.num_layers
        # NOTE(review): eval() executes whatever string is configured as
        # patch_size; if only literals such as "(2, 2)" are expected,
        # ast.literal_eval would be the safer choice.
        self.patch_size = eval(args.patch_size)

        # The first two entries of input_shape act as image height / width.
        grid_h, grid_w, _ = self.input_shape

        assert grid_h % self.patch_size[0] == 0 and grid_w % self.patch_size[1] == 0, "Patch size must be divisible by input shape"

        # Per-position projection from the activation width to the ViT width.
        self.first_layer = nn.Linear(in_features=self.input_dim, out_features=self.hidden_dim)

        backbone_config = {
            'image_size': (grid_h, grid_w),
            'patch_size': self.patch_size,
            'num_classes': 1,
            'dim': self.hidden_dim,
            'depth': self.num_layers,
            'heads': self.heads,
            'mlp_dim': self.hidden_dim,
            'dropout': self.dropout,
            'emb_dropout': self.dropout,
            'channels': self.hidden_dim,
            'pool': args.pool,
        }
        self.vit = ViT(**backbone_config)

        # A non-tuple patch size is duplicated into (height, width).
        if isinstance(self.patch_size, tuple):
            patch_h, patch_w = self.patch_size
        else:
            patch_h, patch_w = self.patch_size, self.patch_size
        values_per_patch = self.hidden_dim * patch_h * patch_w
        patch_count = (grid_h // patch_h) * (grid_w // patch_w)

        # NOTE(review): this parameter is registered but not read anywhere in
        # this class -- confirm whether it is still needed.
        self.pos_embedding = nn.Parameter(torch.randn(1, patch_count, self.input_dim, patch_h * patch_w))

        # Swap in the project's patch-embedding module.
        self.vit.to_patch_embedding = PatchEmbeddingWithPos(
            patch_height=patch_h,
            patch_width=patch_w,
            patch_dim=values_per_patch,
            hidden_dim=self.hidden_dim,
        )

        self.sigmoid = nn.Sigmoid()

    def forward(self, activations):
        """Return the sigmoid of the ViT output with the last dim squeezed."""
        projected = self.first_layer(activations.to(torch.float))
        # Unpacking enforces a 4-D tensor, laid out as (B, L, N, d).
        _batch, _rows, _cols, _feat = projected.shape

        # Channels-first layout for the ViT: (B, d, L, N).
        channels_first = projected.permute(0, 3, 1, 2)

        logits = self.vit(channels_first)
        return self.sigmoid(logits).squeeze(-1)

class ACT_MLP(nn.Module):
    """MLP probe over flattened activations.

    The ``(B, L, N, d)`` activation tensor is flattened to ``(B, L * N * d)``
    and passed through ``args.num_layers`` linear layers (BatchNorm + ReLU +
    dropout between them), followed by a sigmoid.

    Args:
        args: Configuration namespace; reads ``hidden_dim``, ``num_layers``
            and ``dropout``.
        input_dim: Stored on the instance; not otherwise used by this class.
        input_shape: Shape whose last three entries are multiplied to get the
            input width of the first linear layer.
    """

    def __init__(self, args, input_dim, input_shape):
        super(ACT_MLP, self).__init__()
        self.input_dim = input_dim
        self.input_shape = input_shape
        self.hidden_dim = args.hidden_dim
        self.num_layers = args.num_layers
        self.dropout = args.dropout

        # Linear stack: the first layer consumes the flattened activations, the
        # last layer emits a single logit; BatchNorm only between hidden layers.
        self.lin_layers = nn.ModuleList()
        self.batch_norms = nn.ModuleList()
        for i in range(self.num_layers):
            in_dim = self.hidden_dim if i > 0 else self.input_shape[-1] * self.input_shape[-2] * self.input_shape[-3]
            out_dim = self.hidden_dim if (i + 1) < self.num_layers else 1
            self.lin_layers.append(nn.Linear(in_dim, out_dim))
            if (i + 1) < self.num_layers:
                self.batch_norms.append(nn.BatchNorm1d(out_dim))

        # Output activation
        self.sigmoid = nn.Sigmoid()

    def forward(self, activations):
        """Return one probability per sample.

        Args:
            activations: Tensor of shape ``(B, L, N, d)``.

        Returns:
            Tensor of shape ``(B,)`` with sigmoid outputs.
        """
        B, L, N, d = activations.shape

        x = activations.reshape(B, L * N * d).to(torch.float32)

        for i in range(self.num_layers):
            x = self.lin_layers[i](x)
            if (i + 1) < self.num_layers:
                x = self.batch_norms[i](x)
                x = F.relu(x)
                # F.dropout defaults to training=True; tie it to the module's
                # mode so dropout is disabled after model.eval().
                x = F.dropout(x, p=self.dropout, training=self.training)
        return self.sigmoid(x).squeeze(-1)  # Apply sigmoid for binary classification

class ACT_MLP_Foundation(ACT_MLP):
    """``ACT_MLP`` with a ``forward(activations, indices)`` call signature."""

    def forward(self, activations, indices):
        """Run the parent MLP on ``activations``; ``indices`` is accepted but not used."""
        predictions = super().forward(activations)
        return predictions

# --

class ModuleListPerLLMLinear(nn.Module):
    """Bank of ``num_llms`` independent ``nn.Linear`` adapters, selected per sample.

    Args:
        in_features: Input width shared by every adapter.
        out_features: Output width shared by every adapter.
        num_llms: Number of adapters in the bank.
    """

    def __init__(self, in_features, out_features, num_llms):
        super().__init__()
        self.num_llms = num_llms
        adapter_bank = []
        for _ in range(num_llms):
            adapter_bank.append(nn.Linear(in_features, out_features))
        self.adapters = nn.ModuleList(adapter_bank)

    def forward(self, input, indices):
        """Apply adapter ``indices[b]`` to ``input[b]`` for every sample ``b``.

        Samples are grouped by adapter id so each adapter runs once per batch.
        The adapters run in float32 and the result is cast back to the input
        dtype. Returns a tensor shaped like ``input`` with the last dimension
        replaced by ``out_features``.
        """
        out_shape = tuple(input.shape[:-1]) + (self.adapters[0].out_features,)
        result = torch.zeros(out_shape, device=input.device, dtype=input.dtype)
        for llm_id in torch.unique(indices).tolist():
            rows = indices == llm_id
            projected = self.adapters[llm_id](input[rows].to(torch.float32))
            result[rows] = projected.to(input.dtype)
        return result

        
        
## inspired by nn.Linear
class PerLLMLinear(nn.Module):
    r"""Per-LLM affine map, modelled on ``nn.Linear``.

    Holds ``num_llms`` independent weight matrices (and optional biases) in
    ``nn.Embedding`` tables, one row per LLM, and applies to every sample ``b``
    the pair selected by ``indices[b]``:
    :math:`y_b = x_b A_{i_b}^T + c_{i_b}`.

    Args:
        in_features: size of each input sample
        out_features: size of each output sample
        bias: If set to ``False``, the layer will not learn an additive bias.
            Default: ``True``
        device: forwarded to the underlying ``nn.Embedding`` tables
        dtype: forwarded to the underlying ``nn.Embedding`` tables
        num_llms: number of independent (weight, bias) pairs. Default: ``1``

    Shape:
        - Input: :math:`(B, L, N, H_{in})` with :math:`H_{in} = \text{in\_features}`.
        - Indices: :math:`(B)`, integer ids in ``[0, num_llms)``.
        - Output: :math:`(B, L, N, H_{out})` with :math:`H_{out} = \text{out\_features}`.

    Attributes:
        weight: ``nn.Embedding(num_llms, out_features * in_features)``; row
            ``i`` is the flattened :math:`(\text{out\_features}, \text{in\_features})`
            matrix of LLM ``i``, initialized from
            :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` with
            :math:`k = \frac{1}{\text{in\_features}}`
        bias:   ``nn.Embedding(num_llms, out_features)`` initialized from the
                same distribution, or ``None`` when ``bias=False``.

    Examples::

        >>> m = PerLLMLinear(20, 30, num_llms=3)
        >>> input = torch.randn(8, 4, 5, 20)
        >>> indices = torch.randint(0, 3, (8,))
        >>> output = m(input, indices)
        >>> print(output.size())
        torch.Size([8, 4, 5, 30])
    """

    __constants__ = ["in_features", "out_features"]
    in_features: int
    out_features: int
    weight: nn.Embedding

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        device=None,
        dtype=None,
        num_llms: int = 1,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        # Unlike nn.Linear, weights and biases live in embedding tables with
        # one row per LLM, so a batch of LLM ids gathers one matrix per sample.
        self.num_llms = num_llms
        self.weight = nn.Embedding(num_llms, out_features * in_features, **factory_kwargs)
        if bias:
            self.bias = nn.Embedding(num_llms, out_features, **factory_kwargs)
        else:
            self.register_parameter("bias", None)
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Initialize every per-LLM matrix and bias the way ``nn.Linear`` does."""
        # Setting a=sqrt(5) in kaiming_uniform is the same as initializing with
        # uniform(-1/sqrt(in_features), 1/sqrt(in_features)). For details, see
        # https://github.com/pytorch/pytorch/issues/57109
        #
        # The initializers must write into the embedding tables themselves:
        # calling the embedding (``self.weight(idx)``) returns a gathered copy,
        # so initializing that result would leave the parameters untouched.
        with torch.no_grad():
            per_llm_weights = self.weight.weight.view(
                self.num_llms, self.out_features, self.in_features
            )
            for i in range(self.num_llms):
                # Each slice is treated as a regular nn.Linear weight matrix.
                nn.init.kaiming_uniform_(per_llm_weights[i], a=math.sqrt(5))
            if self.bias is not None:
                fan_in = self.in_features
                bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
                nn.init.uniform_(self.bias.weight, -bound, bound)

    def forward(self, input: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
        """Apply the affine map selected by ``indices[b]`` to ``input[b]``.

        Args:
            input: Tensor of shape ``(B, L, N, in_features)``; cast to float32.
            indices: Integer tensor of shape ``(B,)`` with LLM ids.

        Returns:
            Tensor of shape ``(B, L, N, out_features)``.
        """
        # Gather one (out_features, in_features) matrix per sample.
        weights = self.weight(indices).view(-1, self.out_features, self.in_features)
        out = torch.einsum('blnd,bod->blno', input.to(torch.float32), weights)
        if self.bias is not None:
            # (B, out_features) -> (B, 1, 1, out_features) to broadcast over L, N.
            out = out + self.bias(indices).unsqueeze(1).unsqueeze(1)
        return out

    def extra_repr(self) -> str:
        return f"in_features={self.in_features}, out_features={self.out_features}, bias={self.bias is not None}"

class ACT_Vit_foundation(nn.Module):
    """Activation probe with one input adapter per LLM in front of a ViT.

    ``forward`` routes every sample through the linear adapter selected by its
    LLM index, then through a ViT whose ``to_patch_embedding`` is replaced with
    ``PatchEmbeddingWithPos``.

    Args:
        args: Configuration namespace; reads ``hidden_dim``, ``heads``,
            ``dropout``, ``num_layers``, ``patch_size`` (a string) and ``pool``.
        input_dim: Stored on the instance and used for the shape of
            ``pos_embedding``.
        input_shape: Three-element shape whose first two entries are used as
            the image height and width.
    """

    def __init__(self, args, input_dim, input_shape):
        super().__init__()

        self.input_dim = input_dim
        self.input_shape = input_shape
        self.hidden_dim = args.hidden_dim
        self.heads = args.heads
        self.dropout = args.dropout
        self.num_layers = args.num_layers
        # NOTE(review): eval() executes whatever string is configured as
        # patch_size; if only literals such as "(2, 2)" are expected,
        # ast.literal_eval would be the safer choice.
        self.patch_size = eval(args.patch_size)

        # The first two entries of input_shape act as image height / width.
        grid_h, grid_w, _ = self.input_shape

        assert grid_h % self.patch_size[0] == 0 and grid_w % self.patch_size[1] == 0, "Patch size must be divisible by input shape"

        # One adapter per FEATURE_DIMS entry, all sized for the widest LLM;
        # the entry order should match the LLM indices passed to forward().
        feature_widths = list(FEATURE_DIMS.values())
        self.adapters_slow = ModuleListPerLLMLinear(
            max(feature_widths),
            self.hidden_dim,
            num_llms=len(feature_widths),
        )

        backbone_config = {
            'image_size': (grid_h, grid_w),
            'patch_size': self.patch_size,
            'num_classes': 1,
            'dim': self.hidden_dim,
            'depth': self.num_layers,
            'heads': self.heads,
            'mlp_dim': self.hidden_dim,
            'dropout': self.dropout,
            'emb_dropout': self.dropout,
            'channels': self.hidden_dim,
            'pool': args.pool,
        }
        self.vit = ViT(**backbone_config)

        # A non-tuple patch size is duplicated into (height, width).
        if isinstance(self.patch_size, tuple):
            patch_h, patch_w = self.patch_size
        else:
            patch_h, patch_w = self.patch_size, self.patch_size
        values_per_patch = self.hidden_dim * patch_h * patch_w
        patch_count = (grid_h // patch_h) * (grid_w // patch_w)

        # NOTE(review): this parameter is registered but not read anywhere in
        # this class -- confirm whether it is still needed.
        self.pos_embedding = nn.Parameter(torch.randn(1, patch_count, self.input_dim, patch_h * patch_w))

        # Swap in the project's patch-embedding module.
        self.vit.to_patch_embedding = PatchEmbeddingWithPos(
            patch_height=patch_h,
            patch_width=patch_w,
            patch_dim=values_per_patch,
            hidden_dim=self.hidden_dim,
        )

        self.sigmoid = nn.Sigmoid()

    def forward(self, activations, indices):
        """Return the sigmoid of the ViT output with the last dim squeezed.

        ``indices`` selects, per sample, which adapter projects the activations.
        """
        adapted = self.adapters_slow(activations, indices)

        # Unpacking enforces a 4-D tensor, laid out as (B, L, N, d).
        _batch, _rows, _cols, _feat = adapted.shape

        # Channels-first layout for the ViT: (B, d, L, N).
        channels_first = adapted.permute(0, 3, 1, 2)

        logits = self.vit(channels_first)
        return self.sigmoid(logits).squeeze(-1)


######################## Activations ########################