import torch.nn as nn

from modules import LinearizedConv2d, LinearizedVIT, MemEffAttention, MLP, LinearAttention, SwitchMoE, RNNAtt
from transformers.models.qwen2.modeling_qwen2 import Qwen2Attention
from transformers.models.gpt2.modeling_gpt2 import GPT2Attention

# Ordered (original class, replacement class) pairs used to build REPLACERS.
# Each key is a module class and each value is the class mapped to it; the
# name suggests a consumer elsewhere swaps instances of the key type for the
# value type.
# NOTE(review): whether lookup is by exact ``type(m)`` or by ``isinstance``
# is decided by that consumer, not here. With ``isinstance``, subclasses of a
# key (e.g. of ``nn.Linear``) would also match, and the iteration order below
# would determine which entry wins. Confirm against the caller.
_REPLACEMENT_PAIRS = (
    (nn.Conv2d, LinearizedConv2d),
    (nn.MultiheadAttention, LinearAttention),
    (nn.Linear, LinearizedVIT),
    (MemEffAttention, MLP),
    (Qwen2Attention, SwitchMoE),
    (GPT2Attention, RNNAtt),
)

# dict() preserves the insertion order of the pairs above, so iterating
# REPLACERS yields the same key order as the original literal.
REPLACERS = dict(_REPLACEMENT_PAIRS)