
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from collections import namedtuple
import os
import sys
import pickle
import time

# Directory holding this file (symlinks resolved). It is put at the front of
# sys.path so the per-model modules that live next to this file (e.g.
# ``transformer``, ``conv``) can be imported by their bare module name.
curr_path = os.path.split(os.path.realpath(__file__))[0]
sys.path.insert(0, curr_path)

class BackboneWrapper(nn.Module):
    """Select, build and wrap the backbone named by ``config["model_type"]``.

    Each supported model type is also the name of an importable module that
    defines a ``Backbone`` class whose constructor takes ``config``. Only the
    selected module is imported, and only when the wrapper is constructed.
    """

    # Every accepted value of ``config["model_type"]``. Each entry doubles as
    # the module name that provides the matching ``Backbone`` class.
    SUPPORTED_MODEL_TYPES = (
        "transformer",
        "conv",
        "nystromformer",
        "ammformer",
        "linformer",
        "performer",
        "longformer",
        "bigbird",
        "reformer",
        "htransformer1d",
        "scatterbrain",
        "soft",
        "mra2",
        "yoso_attention",
        "hier_attention",
        "linear_transformer",
        "mra_head",
        "mra2_mra_head",
        "linear_transformer_mra_head",
        "fourier_transformer",
        "mra_head_cuda",
        "fmm_transformer",
        "luna_transformer",
        "linformer_mra_head",
        "fmm_transformer_mra_head",
        "performer_mra_head",
        "luna_transformer_mra_head",
    )

    def __init__(self, config):
        """Build the backbone selected by ``config["model_type"]``.

        Args:
            config: Mapping with at least a ``"model_type"`` key. The whole
                mapping is passed unchanged to the selected ``Backbone``.

        Raises:
            KeyError: if ``config`` has no ``"model_type"`` entry.
            ValueError: if ``model_type`` is not in ``SUPPORTED_MODEL_TYPES``.
        """
        import importlib

        super().__init__()

        self.model_type = config["model_type"]

        # Validate against the whitelist before importing anything. This gives
        # a helpful error and never imports an arbitrary module name from config.
        if self.model_type not in self.SUPPORTED_MODEL_TYPES:
            raise ValueError(
                f"Unsupported model_type {self.model_type!r}; expected one of: "
                f"{', '.join(self.SUPPORTED_MODEL_TYPES)}"
            )

        # Same effect as ``from <model_type> import Backbone`` for the chosen type.
        backbone_module = importlib.import_module(self.model_type)
        self.backbone = backbone_module.Backbone(config)

    def forward(self, X, mask):
        """Run the wrapped backbone on ``(X, mask)`` and return its output as-is.

        NOTE(review): presumably ``X`` is the input sequence and ``mask`` its
        padding mask; the expected shapes are defined by the backbone modules.
        """
        return self.backbone(X, mask)
