from itertools import repeat
from typing import List, Literal, Optional, Tuple, Dict

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops.layers.torch import Rearrange

from ..layers import CosformerLayer, PerformerLayer, VanillaTransformerLayer
from .util import select_pe_encoder
from celldiff.util import check_str_option
from celldiff.modules.attention import FeedForward
from celldiff.models.gears.model import GEARS_Conditioner

# Accepted values for the ``attn_mask_mode`` argument of
# ``BasePathwayEncodingLayer.forward`` (validated there with ``check_str_option``).
ATTN_MASK_MODE = Literal["nonzero", "subset_nonzero"]


def create_activation(name):
    """Build an activation module from its short name.

    Args:
        name: One of ``"relu"``, ``"gelu"``, ``"glu"``, ``"sigmoid"``, ``"prelu"``,
            ``"elu"``, or ``None`` (which yields ``nn.Identity``).

    Returns:
        A freshly constructed ``nn.Module``.

    Raises:
        NotImplementedError: If ``name`` is not recognised.
    """
    if name is None:
        return nn.Identity()
    known = (
        ("relu", nn.ReLU),
        ("gelu", nn.GELU),
        ("glu", nn.GLU),
        ("sigmoid", nn.Sigmoid),
        ("prelu", nn.PReLU),
        ("elu", nn.ELU),
    )
    for label, module_cls in known:
        if name == label:
            return module_cls()
    raise NotImplementedError(f"{name} is not implemented.")


def create_norm(name, n, h=16):
    """Build a normalization module over ``n`` channels.

    Args:
        name: ``"layernorm"``, ``"batchnorm"``, ``"groupnorm"``, or ``None``
            (which yields ``nn.Identity``).
        n: Number of features/channels.
        h: Number of groups, used only for ``"groupnorm"``.

    Raises:
        NotImplementedError: If ``name`` is not recognised.
    """
    if name is None:
        return nn.Identity()
    builders = (
        ("layernorm", lambda: nn.LayerNorm(n)),
        ("batchnorm", lambda: nn.BatchNorm1d(n)),
        ("groupnorm", lambda: nn.GroupNorm(h, n)),
    )
    for label, build in builders:
        if name == label:
            return build()
    raise NotImplementedError(f"{name} is not implemented.")


def batch_apply_norm(norm: nn.Module, x: torch.Tensor) -> torch.Tensor:
    """Apply ``norm`` to a 2-D ``(length, channel)`` or 3-D ``(batch, length, channel)`` tensor.

    For 3-D input, ``BatchNorm1d``/``GroupNorm`` normalise dim 1, so the last two axes
    are swapped around the call to put channels there; ``LayerNorm`` already acts on
    the last axis and ``Identity`` is a no-op.

    Raises:
        NotImplementedError: For a 3-D input with an unsupported norm type.
        ValueError: If ``x`` is neither 2-D nor 3-D.
    """
    rank = x.dim()
    if rank == 2:
        return norm(x)
    if rank != 3:
        raise ValueError(f"Invalid dimension of x: {x.shape=}")

    if isinstance(norm, nn.Identity):
        return x
    if isinstance(norm, nn.LayerNorm):
        return norm(x)
    if not isinstance(norm, (nn.BatchNorm1d, nn.GroupNorm)):
        raise NotImplementedError(f"{norm!r} not supported yet")
    channels_first = x.transpose(-1, -2)
    return norm(channels_first).transpose(-1, -2)


class MLPLayers(nn.Module):
    """Stack of ``Linear -> activation -> Dropout`` blocks, each followed by a norm.

    Layer widths run ``in_dim -> hidden_dim (x num_layers - 1) -> out_dim``. The
    activation is built by ``create_activation(act)`` (PReLU by default) and the norm by
    ``create_norm(norm, ...)``; norms are applied through ``batch_apply_norm`` so 2-D
    and 3-D inputs are both accepted.
    """

    def __init__(self, in_dim, out_dim, hidden_dim, num_layers, dropout, norm=None, act="prelu"):
        super().__init__()
        widths = [in_dim, *([hidden_dim] * (num_layers - 1)), out_dim]
        self.layers = nn.ModuleList()
        self.norms = nn.ModuleList()
        # Modules are created block by block (linear, act, dropout, then norm) to keep
        # parameter-initialisation order stable.
        for fan_in, fan_out in zip(widths, widths[1:]):
            block = nn.Sequential(
                nn.Linear(fan_in, fan_out),
                create_activation(act),
                nn.Dropout(dropout),
            )
            self.layers.append(block)
            self.norms.append(create_norm(norm, fan_out))

    def forward(self, x):
        for block, norm in zip(self.layers, self.norms):
            x = batch_apply_norm(norm, block(x))
        return x


class ResMLPLayers(nn.Module):
    """MLP whose output layer sees the input and every hidden activation.

    ``num_layers - 1`` hidden blocks (``Linear -> PReLU -> Dropout`` followed by a
    norm) of width ``hidden_dim`` are applied in sequence. The original input and all
    hidden outputs are then concatenated on the last (feature) axis and projected to
    ``out_dim`` by ``out_layer`` (followed by ``out_norm``).

    Args:
        in_dim: Input feature width.
        out_dim: Output feature width.
        hidden_dim: Width of each hidden block.
        num_layers: Total number of linear layers (hidden blocks + output); must be > 1.
        dropout: Dropout probability used in every block.
        norm: Norm name passed to ``create_norm`` (``None`` for no norm).
    """

    def __init__(self, in_dim, out_dim, hidden_dim, num_layers, dropout, norm):
        super().__init__()
        assert num_layers > 1, 'At least two layers for MLPs.'
        # One ``hidden_dim`` entry per hidden block, mirroring ``MLPLayers``.
        layer_dims = [in_dim] + [hidden_dim] * (num_layers - 1) + [out_dim]
        self.layers = nn.ModuleList()
        self.norms = nn.ModuleList()
        for i in range(len(layer_dims) - 2):
            self.layers.append(nn.Sequential(
                nn.Linear(layer_dims[i], layer_dims[i + 1]),
                nn.PReLU(),
                nn.Dropout(dropout),
            ))
            self.norms.append(create_norm(norm, layer_dims[i + 1]))
        # Consumes concat(input, hidden_1, ..., hidden_{num_layers-1}), whose width is
        # exactly sum(layer_dims[:-1]).
        self.out_layer = nn.Sequential(
            nn.Linear(sum(layer_dims[:-1]), layer_dims[-1]),
            nn.PReLU(),
            nn.Dropout(dropout),
        )
        self.out_norm = create_norm(norm, layer_dims[-1])

    def forward(self, x):
        # The input itself is part of the skip history so the concatenated width
        # matches ``out_layer``'s in_features.
        hist = [x]
        for layer, norm in zip(self.layers, self.norms):
            x = layer(x)
            x = batch_apply_norm(norm, x)
            hist.append(x)
        # Concatenate on the feature axis (last) so 3-D inputs work as well as 2-D.
        out = self.out_layer(torch.cat(hist, dim=-1))
        out = batch_apply_norm(self.out_norm, out)
        return out


class EmbeddingDict(nn.Module):
    """Condition embeddings keyed by condition name, replicated for ``depth`` layers.

    Every key of ``num_embed_dict`` gets ``depth`` independent
    ``Embedding -> norm -> Rearrange`` stacks whose output is reshaped from
    ``(b, embedding_tokens * embedding_dim)`` to ``(b, embedding_tokens, embedding_dim)``.
    Optionally a ``'text'`` key (pretrained embedding table + projection) and a
    ``'pert'`` key (``GEARS_Conditioner``) are appended.
    """
    TEXT_EMB_DIR = './data/ontology_resources'

    def __init__(self, num_embed_dict, embedding_dim, depth, embedding_tokens=1,
                 norm_layer=None, freeze=False, mask_ratio=0.0, text_emb=None,
                 text_emb_file=None, freeze_text_emb=True, text_proj_type='linear',
                 stackfnn_glu_flag=False, text_proj_hidden_dim=512, text_proj_act=None,
                 text_proj_num_layers=2, text_proj_norm=None, text_proj_dropout=0.,
                 gears_flag=False, gears_mode="single", num_perts=None, gears_hidden_size=64,
                 gears_mlp_layers=2, gears_norm=None, num_go_gnn_layers=1):
        super().__init__()
        size = embedding_dim * embedding_tokens
        n = embedding_tokens
        d = embedding_dim

        self.keys = sorted(num_embed_dict)  # ensure consistent ordering
        self.mask_ratio = mask_ratio

        self.emb_dict = nn.ModuleDict()
        for key in self.keys:
            self.emb_dict[key] = nn.ModuleList([
                nn.Sequential(
                    nn.Embedding(
                        num_embed_dict[key],
                        size,
                        _freeze=freeze,
                    ),
                    create_norm(norm_layer, size),
                    Rearrange('b (n d) -> b n d', n=n, d=d),
                )
                for _ in range(depth)
            ])

        if text_emb is not None or text_emb_file is not None:
            if text_emb is None:
                # torch.load unpickles the file; only point this at trusted files.
                text_emb = torch.load(f'{self.TEXT_EMB_DIR}/{text_emb_file}')
            if text_proj_type == 'linear':
                text_proj = nn.Linear(text_emb.shape[1], size)
            elif text_proj_type == 'stackffn':
                text_proj = FeedForward(text_emb.shape[1], dim_out=size, mult=4, glu=stackfnn_glu_flag)
            elif text_proj_type == 'mlp':
                text_proj = MLPLayers(text_emb.shape[1], size, text_proj_hidden_dim, text_proj_num_layers,
                                      text_proj_dropout, text_proj_norm, text_proj_act)
            else:
                raise NotImplementedError(f"Unsupported text_proj_type {text_proj_type}")

            text_act = create_activation(text_proj_act)
            # Fall back to the shared norm type when no text-specific norm is given.
            if text_proj_norm is None and norm_layer is not None:
                text_norm = create_norm(norm_layer, size)
            else:
                text_norm = create_norm(text_proj_norm, size)
            self.keys.append("text")
            # NOTE(review): text_proj / text_norm / text_act are single instances reused by
            # every depth copy, so their parameters are shared across depths — confirm
            # this is intended.
            self.emb_dict['text'] = nn.ModuleList([
                nn.Sequential(
                    nn.Embedding.from_pretrained(text_emb, freeze=freeze_text_emb),
                    text_proj,
                    text_norm,
                    text_act,
                    Rearrange('b (n d) -> b n d', n=n, d=d),
                )
                for _ in range(depth)
            ])

        if num_perts is not None and gears_flag:
            self.keys.append('pert')
            self.gears_mode = gears_mode
            gears_kwargs = dict(num_perts=num_perts, out_dim=size, mode=gears_mode,
                                hidden_size=gears_hidden_size, mlp_layers=gears_mlp_layers)
            if gears_mode == "single":
                # One independent conditioner per depth level.
                self.emb_dict['pert'] = nn.ModuleList([
                    nn.Sequential(
                        GEARS_Conditioner(num_go_gnn_layers=num_go_gnn_layers, **gears_kwargs),
                        create_norm(gears_norm, size),
                        Rearrange('b (n d) -> b n d', n=n, d=d),
                    )
                    for _ in range(depth)
                ])
            else:
                # A single conditioner with ``depth`` GNN layers; entries are
                # [conditioner, per-depth norms, shared rearrange].
                self.emb_dict['pert'] = nn.ModuleList([
                    GEARS_Conditioner(num_go_gnn_layers=depth, **gears_kwargs),
                    nn.ModuleList([create_norm(gears_norm, size) for _ in range(depth)]),
                    Rearrange('b (n d) -> b n d', n=n, d=d),
                ])

    def __iter__(self):
        yield from self.keys

    def __getitem__(self, key):
        return self.emb_dict[key]

    def forward(self, input: Dict[str, torch.Tensor], aug_graph=None) -> List[torch.Tensor]:
        """Embed every condition and concatenate per layer.

        Args:
            input: Mapping from each key in ``self.keys`` to its index tensor.
            aug_graph: Forwarded to ``GEARS_Conditioner`` for the ``'pert'`` key.

        Returns:
            A list with one tensor per depth level: the embeddings of all keys for that
            level concatenated along dim 1.

        The caller's tensors in ``input`` are never modified.
        """
        # Outer list: condition types; inner list: layer depth
        out = []
        for key in self.keys:
            cond = input[key]
            masked_input = cond.long()
            if self.training:
                # NOTE: NULL condition token added during dataset init, and is
                # set to be the first token (index zero).
                mask = torch.rand_like(cond.float()) < self.mask_ratio
                if key != 'text' and key != "pert":
                    # Out-of-place: ``.long()`` returns ``cond`` itself when it is already
                    # int64, so an in-place write would corrupt the caller's tensor.
                    masked_input = masked_input.masked_fill(mask, 0)

            if (
                isinstance(self[key][0], GEARS_Conditioner)  # parallel | sequential
                or isinstance(self[key][0][0], GEARS_Conditioner)  # single
            ):
                emb_list = []
                if self.gears_mode == "single":
                    for emb in self[key]:
                        gears_out = emb[0](masked_input, aug_graph)
                        emb_list.append(emb[1:](gears_out))
                else:
                    # The conditioner yields one output per depth level.
                    gears_out = self[key][0](masked_input, aug_graph)
                    stack = zip(gears_out, self[key][1], repeat(self[key][2]))
                    for emb, norm, rearrange in stack:
                        emb_list.append(rearrange(norm(emb)))
            else:
                emb_list = [emb(masked_input) for emb in self[key]]

            out.append(emb_list)

        # Consolidate by concatenating along the token dimension in each layer
        out = [torch.cat(embs, dim=1) for embs in zip(*out)]

        return out


class EmbeddingList(EmbeddingDict):
    """``EmbeddingDict`` addressed by list position instead of by name.

    Entry ``i`` of ``num_embed_list`` (and of the ``input`` list passed to ``forward``)
    is stored under the string key ``str(i)``; per-layer outputs concatenate the
    embeddings in list order.
    """

    def __init__(self, num_embed_list, embedding_dim, depth,
                 embedding_tokens=1, freeze=False):
        # nn.ModuleDict only accepts string keys, so positions are stringified.
        num_embed_dict = {str(i): j for i, j in enumerate(num_embed_list)}
        # ``freeze`` must go by keyword: the parent's fifth positional parameter is
        # ``norm_layer``.
        super().__init__(num_embed_dict, embedding_dim, depth, embedding_tokens, freeze=freeze)
        # The parent sorts keys lexicographically ("10" < "2"); restore list order so
        # outputs are concatenated in the same order as the inputs.
        self.keys = list(num_embed_dict)

    def forward(self, input: List[torch.Tensor]) -> List[torch.Tensor]:
        input_dict = {str(i): j for i, j in enumerate(input)}
        return super().forward(input_dict)


class PostConditioner(nn.Module):
    """Linear head with optional additive condition embeddings.

    When ``cond_num_dict`` is given, a depth-1 ``EmbeddingDict`` embeds ``conditions``
    and the result is added to the input before the layer stack. The stack is
    ``num_layers - 1`` blocks of ``Linear -> act -> norm -> Dropout`` at constant width,
    then ``Linear -> out_act`` projecting to ``out_dim``. Only ``cond_type='add'`` is
    accepted.
    """

    def __init__(self, in_dim, out_dim, dropout=0., norm_type="layernorm", num_layers=1, cond_num_dict=None,
                 cond_type='add', cond_emb_dim=None, cond_mask_ratio=0., act="gelu", out_act=None):
        super().__init__()
        # Names (or None) become fresh modules; module instances are used as given.
        # The single ``act`` instance is shared by every hidden block.
        if act is None or isinstance(act, str):
            act = create_activation(act)
        if out_act is None or isinstance(out_act, str):
            out_act = create_activation(out_act)
        assert cond_type == 'add'

        self.cond_type = cond_type
        self.cond_num_dict = cond_num_dict
        if cond_num_dict is None:
            self.cond_embed = None
        else:
            if cond_emb_dim is None:
                cond_emb_dim = in_dim
            if cond_type == 'add':
                assert cond_emb_dim == in_dim, "cond_emb_dim must match in_dim when cond_type == 'add'."
            self.cond_embed = EmbeddingDict(cond_num_dict, cond_emb_dim, 1, 1, None, mask_ratio=cond_mask_ratio)

        concat_cond = cond_num_dict is not None and cond_type == 'concat'
        width = in_dim + cond_emb_dim if concat_cond else in_dim
        self.layers = nn.ModuleList(
            nn.Sequential(
                nn.Linear(width, width),
                act,
                create_norm(norm_type, width),
                nn.Dropout(dropout),
            )
            for _ in range(num_layers - 1)
        )
        self.layers.append(nn.Sequential(nn.Linear(width, out_dim), out_act))

    def forward(self, x, conditions=None):
        if self.cond_embed is not None:
            # Depth is 1, so only entry [0] exists.
            # NOTE(review): squeeze(1) only removes the token axis when it has size 1
            # (i.e. a single condition key) — confirm behaviour for multiple keys.
            x = x + self.cond_embed(conditions)[0].squeeze(1)
        for block in self.layers:
            x = block(x)
        return x


class TransformerLayers(nn.Module):
    """Stack of ``num_layers`` attention layers of the chosen ``model_type``.

    ``model_type`` is one of ``'cosformer'``, ``'performer'`` or ``'transformer'``.
    ``self.dropout`` and ``self.act`` are constructed but not used by ``forward``.
    """

    def __init__(self, hidden_dim, num_heads, num_layers, dropout, activation, model_type='performer'):
        super().__init__()
        self.num_layers = num_layers
        self.dropout = nn.Dropout(dropout)

        layer_cls = {
            'cosformer': CosformerLayer,
            'performer': PerformerLayer,
            'transformer': VanillaTransformerLayer,
        }.get(model_type)
        if layer_cls is None:
            raise NotImplementedError(f'Not implemented transformer type: {model_type}')

        self.layers = nn.ModuleList(
            layer_cls(embed_dim=hidden_dim, num_heads=num_heads, dropout=dropout)
            for _ in range(num_layers)
        )

        self.act = create_activation(activation)

    def forward(self, x, output_attentions=False):
        """Run every layer in order; with ``output_attentions`` also return their attentions."""
        attentions = []
        for layer in self.layers:
            if output_attentions:
                x, attn = layer(x, output_attentions=True)
                attentions.append(attn)
            else:
                x = layer(x)
        return (x, attentions) if output_attentions else x


class Encoder(nn.Module):
    """Map ``in_dim`` features to ``latent_dim``.

    ``layer_type`` selects the backbone:
      * ``'mlp'`` / ``'resmlp'``: ``MLPLayers`` / ``ResMLPLayers``;
      * ``'trans'``: ``Linear(in_dim, hidden_dim)`` -> ``TransformerLayers`` ->
        ``Linear(hidden_dim, latent_dim)``.
    Any other value leaves ``mods`` empty, so ``forward`` returns its input unchanged.
    """

    def __init__(self, in_dim, latent_dim, hidden_dim, num_layers, dropout, norm,
                 layer_type='mlp', attn_type='performer', num_heads=8, activation='gelu', **kwargs):
        super().__init__()
        self.layer_type = layer_type
        self.mods = nn.ModuleList()
        if layer_type == 'mlp':
            self.mods.append(MLPLayers(in_dim, latent_dim, hidden_dim, num_layers, dropout, norm))
        elif layer_type == 'resmlp':
            self.mods.append(ResMLPLayers(in_dim, latent_dim, hidden_dim, num_layers, dropout, norm))
        elif layer_type == 'trans':
            self.mods.append(nn.Linear(in_dim, hidden_dim))
            self.mods.append(TransformerLayers(hidden_dim, num_heads, num_layers, dropout, activation, attn_type))
            self.mods.append(nn.Linear(hidden_dim, latent_dim))

    def forward(self, x, output_attentions=False):
        """Encode ``x``.

        Returns ``x`` or, when ``layer_type == 'trans'`` and ``output_attentions`` is
        set, ``(x, attn_list)`` with the per-layer attentions of ``TransformerLayers``.
        """
        if self.layer_type == 'trans' and output_attentions:
            in_proj, transformer, out_proj = self.mods
            x = in_proj(x)
            # The flag must be forwarded: without it TransformerLayers returns a bare
            # tensor, which tuple-unpacking would silently split along dim 0.
            x, attn_list = transformer(x, output_attentions=True)
            return out_proj(x), attn_list
        for mod in self.mods:
            x = mod(x)
        return x


class Decoder(nn.Module):
    """Map ``latent_dim`` features back to ``out_dim``.

    ``layer_type`` selects the backbone:
      * ``'mlp'`` / ``'resmlp'``: ``MLPLayers`` / ``ResMLPLayers``;
      * ``'trans'``: ``Linear(latent_dim, hidden_dim)`` -> ``TransformerLayers`` ->
        ``Linear(hidden_dim, out_dim)``.
    Any other value leaves ``mods`` empty, so ``forward`` returns its input unchanged.
    """

    def __init__(self, latent_dim, out_dim, hidden_dim, num_layers, dropout, norm,
                 layer_type='mlp', attn_type='performer', num_heads=8, activation='gelu', **kwargs):
        super().__init__()
        self.layer_type = layer_type
        self.mods = nn.ModuleList()
        if layer_type == 'mlp':
            self.mods.append(MLPLayers(latent_dim, out_dim, hidden_dim, num_layers, dropout, norm))
        elif layer_type == 'resmlp':
            self.mods.append(ResMLPLayers(latent_dim, out_dim, hidden_dim, num_layers, dropout, norm))
        elif layer_type == 'trans':
            self.mods.append(nn.Linear(latent_dim, hidden_dim))
            self.mods.append(TransformerLayers(hidden_dim, num_heads, num_layers, dropout, activation, attn_type))
            self.mods.append(nn.Linear(hidden_dim, out_dim))

    def forward(self, x, output_attentions=False):
        """Decode ``x``.

        Returns ``x`` or, when ``layer_type == 'trans'`` and ``output_attentions`` is
        set, ``(x, attn_list)`` with the per-layer attentions of ``TransformerLayers``.
        """
        if self.layer_type == 'trans' and output_attentions:
            in_proj, transformer, out_proj = self.mods
            x = in_proj(x)
            # The flag must be forwarded: without it TransformerLayers returns a bare
            # tensor, which tuple-unpacking would silently split along dim 0.
            x, attn_list = transformer(x, output_attentions=True)
            return out_proj(x), attn_list
        for mod in self.mods:
            x = mod(x)
        return x


class OmicsEmbedder(nn.Module):
    """Expression-weighted sum of learnable per-gene embeddings.

    ``forward`` looks up one ``num_hid`` vector per gene column and multiplies the
    ``(cells, genes)`` matrix ``x`` by the ``(genes, num_hid)`` lookup with
    ``torch.sparse.mm``.
    """

    def __init__(self, pretrained_gene_list, num_hid, gene_emb=None, fix_embedding=False):
        super().__init__()
        self.pretrained_gene_list = pretrained_gene_list
        self.gene_index = {gene: pos for pos, gene in enumerate(pretrained_gene_list)}
        self.num_hid = num_hid

        if gene_emb is None:
            init = torch.randn([len(pretrained_gene_list), num_hid], dtype=torch.float32) * 0.005
            self.emb = nn.Parameter(init)
            if fix_embedding:
                self.emb.requires_grad = False
        else:
            self.emb = nn.Parameter(gene_emb, requires_grad=not fix_embedding)

    def forward(self, x, input_gene_list=None, input_gene_idx=None):
        """Return ``(embedded x, gene indices used)``.

        Gene indices come from ``input_gene_idx`` if given, else from ``input_gene_list``
        (genes absent from the pretrained list are skipped), else ``x`` must have exactly
        one column per pretrained gene.
        """
        if input_gene_idx is not None:
            gene_idx = input_gene_idx
        elif input_gene_list is None:
            if x.shape[1] != len(self.pretrained_gene_list):
                raise ValueError(
                    'The input gene size is not the same as the pretrained gene list. Please provide the input gene list.')
            gene_idx = torch.arange(x.shape[1]).long()
        else:
            known = [self.gene_index[gene] for gene in input_gene_list if gene in self.gene_index]
            gene_idx = torch.tensor(known).long()

        gene_idx = gene_idx.to(x.device)
        gene_vectors = F.embedding(gene_idx, self.emb)
        return torch.sparse.mm(x, gene_vectors), gene_idx


class OmicsEmbeddingLayer(nn.Module):
    """Gene-embedding layer with optional positional encoding.

    Pipeline: ``OmicsEmbedder`` -> activation -> (add or concatenate PE) -> dropout ->
    norm. With ``cat_pe`` the gene features and the PE each use ``num_hidden // 2``
    dimensions.
    """

    def __init__(self, gene_list, num_hidden, norm, activation='gelu', dropout=0.3, pe_type=None, cat_pe=False, gene_emb=None):
        super().__init__()

        self.pe_type = pe_type
        self.cat_pe = cat_pe
        self.act = create_activation(activation)
        self.norm0 = create_norm(norm, num_hidden)
        self.dropout = nn.Dropout(dropout)

        halve_for_pe = pe_type is not None and cat_pe
        num_emb = num_hidden // 2 if halve_for_pe else num_hidden
        # PE encoder is built before the gene embedder to keep initialisation order.
        self.pe_enc = None if pe_type is None else select_pe_encoder(pe_type)(num_emb)

        self.feat_enc = OmicsEmbedder(gene_list, num_emb, gene_emb)

    def forward(self, x, pe_input=None, input_gene_list=None, input_gene_idx=None):
        # NOTE(review): with cat_pe=True but no pe_input the features keep width
        # num_hidden // 2 while norm0 expects num_hidden — confirm callers always pass PE.
        h, gene_idx = self.feat_enc(x, input_gene_list, input_gene_idx)
        h = self.act(h)
        if self.pe_enc is not None and pe_input is not None:
            pe = self.pe_enc(pe_input)
            h = torch.cat([h, pe], 1) if self.cat_pe else h + pe
        return self.norm0(self.dropout(h)), gene_idx


class FeatureEmbedding(nn.Module):
    """Learnable embedding table indexed by feature identifiers.

    Args:
        feat_ids: Ordered feature identifiers; row ``i`` of the table belongs to
            ``feat_ids[i]``.
        dim: Embedding width.
        feat_emb: Optional initial table of shape ``(len(feat_ids), dim)``; defaults to
            ``N(0, 1) * 0.005``.
        fix_embedding: Freeze the table (no gradients).
        scale_grad_by_freq: Forwarded to ``nn.Embedding``.
    """

    def __init__(
        self,
        feat_ids: List[str],
        dim: int,
        feat_emb: Optional[torch.Tensor] = None,
        fix_embedding: bool = False,
        scale_grad_by_freq: bool = False,
    ):
        super().__init__()

        self.feat_ids = feat_ids
        self.feat_id_to_idx = {j: i for i, j in enumerate(feat_ids)}

        self.size = len(feat_ids)
        self.dim = dim

        # Initialize embeddings
        self.emb = nn.Embedding(self.size, self.dim, scale_grad_by_freq=scale_grad_by_freq)
        if feat_emb is None:
            feat_emb = torch.randn(self.size, self.dim).mul_(0.005)
        with torch.no_grad():
            self.emb.weight.copy_(feat_emb)

        if fix_embedding:
            self.emb.requires_grad_(False)

    def get_feat_idx(
        self,
        x: Optional[torch.Tensor],
        input_feat_ids: Optional[List[str]] = None,
    ) -> torch.Tensor:
        """Return the table rows to look up, on ``x``'s device (or the table's if ``x`` is None).

        Raises:
            ValueError: If ``input_feat_ids`` is None and ``x`` does not have exactly
                one column per known feature.
            KeyError: If an id in ``input_feat_ids`` is unknown.
        """
        if input_feat_ids is None:
            if x.shape[1] != self.size:
                raise ValueError(
                    "The input gene size is not the same as the pretrained "
                    "gene list. Please provide the input gene list.",
                )
            feat_idx = torch.arange(self.size)
        else:
            # dtype is pinned so an empty id list still yields valid (long) indices.
            feat_idx = torch.tensor(
                [self.feat_id_to_idx[i] for i in input_feat_ids], dtype=torch.long,
            )
        device = self.emb.weight.device if x is None else x.device
        return feat_idx.to(device)

    def forward(
        self,
        x: Optional[torch.Tensor] = None,
        input_feat_ids: Optional[List[str]] = None,
        *,
        return_all: bool = False
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Return ``(embeddings, indices)``; with ``return_all`` the whole table and ``None``."""
        if return_all:
            feat, feat_idx = self.emb.weight, None
        else:
            feat_idx = self.get_feat_idx(x, input_feat_ids)
            feat = self.emb(feat_idx)
        return feat, feat_idx


class ExpressionEmbedding(nn.Module):
    """Embed scalar expression values with an ``MLPLayers`` of input width 1.

    Callers in this module pass ``x.unsqueeze(-1)``, so the trailing axis has size 1.
    ``dim_out`` defaults to ``dim_hid`` when falsy.
    """

    def __init__(
        self,
        dim_hid: int,
        dim_out: Optional[int] = None,
        num_layers: int = 3,
        dropout: float = 0.0,
        norm: Optional[str] = "batchnorm",
        act: Optional[str] = "relu",
    ):
        super().__init__()
        if not dim_out:
            dim_out = dim_hid
        self.mlp = MLPLayers(
            in_dim=1,
            out_dim=dim_out,
            hidden_dim=dim_hid,
            num_layers=num_layers,
            dropout=dropout,
            norm=norm,
            act=act,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        embedded = self.mlp(x)
        return embedded


class OmicsFeatureEmbeddingLayer(nn.Module):
    """Token embedding = feature embedding (+ or concat) expression embedding.

    Each input column becomes a token whose vector combines the learnable
    ``FeatureEmbedding`` of that feature with an ``ExpressionEmbedding`` of its value;
    the result is normalised (``norm``) and passed through dropout.

    Raises:
        ValueError: If ``cat`` is False and the two embedding widths differ.
    """

    def __init__(
        self,
        feat_ids: List[str],
        norm: Optional[str],
        feat_emb_dim: int,
        exp_emb_dim: Optional[int] = None,
        dropout: float = 0.3,
        cat: bool = False,
        feat_emb: Optional[torch.Tensor] = None,
        fix_feat_embedding: bool = False,
        exp_emb_num_layers: int = 3,
        exp_emb_norm: Optional[str] = "batchnorm",
        exp_emb_act: Optional[str] = "relu",
    ):
        super().__init__()
        exp_emb_dim = exp_emb_dim or feat_emb_dim

        self.cat = cat
        if self.cat:
            hid_dim = feat_emb_dim + exp_emb_dim
        else:
            if feat_emb_dim != exp_emb_dim:
                raise ValueError(
                    "Feature embedding dimension must be equal to the expression "
                    "embedding when concat mode is turned off.",
                )
            hid_dim = feat_emb_dim

        self.feat_enc = FeatureEmbedding(
            feat_ids,
            feat_emb_dim,
            feat_emb,
            fix_feat_embedding,
        )
        self.exp_enc = ExpressionEmbedding(
            exp_emb_dim,
            num_layers=exp_emb_num_layers,
            norm=exp_emb_norm,
            act=exp_emb_act,
        )

        self.norm = create_norm(norm, hid_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        x: torch.Tensor,
        sample_pe_input: Optional[torch.Tensor] = None,
        input_feat_ids: Optional[List[str]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Embed ``x`` of shape ``(num_samples, num_tokens)``.

        Returns:
            ``(out, feat_idx)``: ``out`` is ``(num_samples, num_tokens, hid_dim)``;
            ``feat_idx`` holds the feature-table indices used for the tokens.
        """
        if sample_pe_input is not None:
            import warnings
            warnings.warn(
                "sample_pe_input passed in but not being used currently",
                UserWarning,
                stacklevel=2,
            )

        # feat_emb: (num_tokens, emb_dim)
        feat_emb, feat_idx = self.feat_enc(x, input_feat_ids)
        # exp_emb: (num_samples, num_tokens, emb_dim)
        exp_emb = self.exp_enc(x.unsqueeze(-1))

        # Broadcast the per-token feature embedding over samples without copying.
        feat_emb = feat_emb.unsqueeze(0).expand(x.shape[0], -1, -1)
        if self.cat:
            h = torch.cat((feat_emb, exp_emb), dim=-1)
        else:
            h = feat_emb + exp_emb

        # Dropout is applied to the *normalised* tokens; feeding ``h`` here would
        # silently discard the norm.
        out = batch_apply_norm(self.norm, h)
        out = self.dropout(out)

        return out, feat_idx


class BasePathwayEncodingLayer(nn.Module):
    """Encode per-feature expression into ``num_pathways`` pathway tokens.

    Feature tokens are ``feature embedding + expression embedding``. A cross-attention
    ``BasicTransformerBlock`` is then applied to learned pathway embeddings, with the
    feature tokens passed as its second argument (presumably the cross-attention
    context — confirm against ``BasicTransformerBlock``).
    """

    def __init__(
        self,
        feat_ids: List[str],
        norm: Optional[str],
        feat_emb_dim: int,
        num_pathways: int = 64,
        path_emb_dim: Optional[int] = None,
        exp_emb_dim: Optional[int] = None,
        dropout: float = 0.3,
        feat_emb: Optional[torch.Tensor] = None,
        path_emb: Optional[torch.Tensor] = None,
        fix_feat_embedding: bool = False,
        fix_path_embedding: bool = False,
        exp_emb_num_layers: int = 3,
        exp_emb_norm: Optional[str] = "batchnorm",
        exp_emb_act: Optional[str] = "relu",
    ):
        # :(encode) feat emb + exp emb -> (decode) path emb
        super().__init__()
        exp_emb_dim = exp_emb_dim or feat_emb_dim
        path_emb_dim = path_emb_dim or feat_emb_dim
        path_ids = list(range(num_pathways))
        self.num_pathways = num_pathways

        if (feat_emb_dim != exp_emb_dim) or (exp_emb_dim != path_emb_dim):
            raise ValueError(
                "Feature embedding, pathway embedding, and expression embedding "
                "must all be equal in the current implementation.",
            )
        hid_dim = feat_emb_dim

        self.exp_enc = ExpressionEmbedding(
            exp_emb_dim,
            num_layers=exp_emb_num_layers,
            norm=exp_emb_norm,
            act=exp_emb_act,
        )
        self.feat_enc = FeatureEmbedding(
            feat_ids,
            feat_emb_dim,
            feat_emb,
            fix_feat_embedding,
        )
        self.path_enc = FeatureEmbedding(
            path_ids,
            path_emb_dim,
            path_emb,
            fix_path_embedding,
        )
        from celldiff.modules.attention import BasicTransformerBlock
        self.path_dec = BasicTransformerBlock(
            hid_dim,
            n_heads=4,
            cross_attn=True,
        )

        self.dropout = nn.Dropout(dropout)

    def _subset_observed(
        self,
        x: torch.Tensor,
        input_feat_ids: Optional[List[str]],
    ) -> Tuple[torch.Tensor, List[str]]:
        """Drop feature columns that are zero for every sample in the batch.

        When ``input_feat_ids`` is given, ``x`` is first column-indexed by each id's
        position in ``self.feat_enc.feat_ids`` (NOTE(review): this assumes ``x`` carries
        all features in that global order here — confirm with callers).

        Raises:
            ValueError: If an id in ``input_feat_ids`` is unknown.
        """
        if input_feat_ids is None:
            input_feat_ids = self.feat_enc.feat_ids
        else:
            # O(1) dict lookups instead of one O(F) ``list.index`` scan per id.
            id_to_idx = self.feat_enc.feat_id_to_idx
            for fid in input_feat_ids:
                if fid not in id_to_idx:
                    raise ValueError(f"{fid!r} is not in list")
            x = x[:, [id_to_idx[fid] for fid in input_feat_ids]]

        observed = (x != 0).any(0)  # only genes that are observed in the batch
        x = x[:, observed]
        # Convert the mask once rather than truth-testing one 0-d tensor per element.
        keep = observed.tolist()
        input_feat_ids = [fid for fid, k in zip(input_feat_ids, keep) if k]
        return x, input_feat_ids

    def forward(
        self,
        x: torch.Tensor,
        sample_pe_input: Optional[torch.Tensor] = None,
        input_feat_ids: Optional[List[str]] = None,
        attn_mask_mode: Optional[ATTN_MASK_MODE] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode ``x`` of shape ``(num_samples, num_feat_tokens)``.

        Args:
            attn_mask_mode: ``"nonzero"`` masks zero-expression features out of the
                cross attention; ``"subset_nonzero"`` additionally drops features that
                are zero for the whole batch first; ``None`` applies no mask.

        Returns:
            ``(out, feat_idx)``: ``out`` is ``(num_samples, num_path_tokens, emb_dim)``;
            ``feat_idx`` holds the feature-table indices of the tokens used.
        """
        if sample_pe_input is not None:
            import warnings
            warnings.warn(
                "sample_pe_input passed in but not being used currently",
                UserWarning,
                stacklevel=2,
            )

        attn_mask_mode = check_str_option("attn_mask_mode", attn_mask_mode, ATTN_MASK_MODE)
        if attn_mask_mode == "subset_nonzero":
            # observed genes in batch only (better to move to loader)
            x, input_feat_ids = self._subset_observed(x, input_feat_ids)

        if attn_mask_mode in ("nonzero", "subset_nonzero"):
            # (num_samples, num_path_tokens, num_feat_tokens)
            attn_mask = (x != 0).unsqueeze(1).repeat(1, self.num_pathways, 1)
        else:
            attn_mask = None

        # exp_emb: (num_samples, num_feat_tokens, emb_dim)
        exp_emb = self.exp_enc(x.unsqueeze(-1))
        # feat_emb: (num_feat_tokens, emb_dim)
        feat_emb, feat_idx = self.feat_enc(x, input_feat_ids)
        # path_emb: (num_path_tokens, emb_dim)
        path_emb, _ = self.path_enc(return_all=True)

        # h: (num_samples, num_feat_tokens, emb_dim)
        h = feat_emb.unsqueeze(0) + exp_emb
        h = self.dropout(h)

        # out: (num_samples, num_path_tokens, emb_dim)
        path_emb = path_emb.unsqueeze(0).repeat(h.shape[0], 1, 1)
        out = self.path_dec(path_emb, h, cross_mask=attn_mask)
        out = self.dropout(out)

        return out, feat_idx


class LinearPathwayEncodingLayer(BasePathwayEncodingLayer):
    """Placeholder variant; currently adds nothing to ``BasePathwayEncodingLayer``."""


class MLPPathwayEncodingLayer(BasePathwayEncodingLayer):
    """Placeholder variant; currently adds nothing to ``BasePathwayEncodingLayer``."""


class TransformerPathwayEncodingLayer(BasePathwayEncodingLayer):
    """Placeholder variant; currently adds nothing to ``BasePathwayEncodingLayer``."""
