# --------------------------------------------------------
# References:
# MAE: https://github.com/facebookresearch/mae
# --------------------------------------------------------
import warnings

import numpy as np
import torch
import torch.nn as nn
from einops.layers.torch import Rearrange

from celldiff.modules.attention import BasicTransformerBlock, FeedForward
from celldiff.modules.diffusionmodules.util import (
    ConditionEncoderWrapper,
    MaskedEncoderConditioner,
    timestep_embedding,
)
from celldiff.util import default
from .scmodel import (
    OmicsEmbeddingLayer,
    OmicsFeatureEmbeddingLayer,
    PostConditioner,
    EmbeddingDict,
    create_activation,
    create_norm,
)


class MaskedAutoencoder(nn.Module):
    """Masked autoencoder that predicts from ``x`` conditioned on an encoding of masked ``x_orig``.

    The encoder embeds a (masked) copy of ``x_orig`` and runs ``depth`` blocks, keeping every
    block's output (reshaped into ``cond_tokens`` tokens) as a context. The decoder embeds ``x``,
    optionally adds a timestep embedding, then runs ``depth`` conditioning blocks that consume the
    encoder contexts in reverse order (together with optional condition embeddings), followed by a
    final norm and a ``PostConditioner`` output head.
    """

    def __init__(self, pretrained_gene_list, input_gene_list=None, dropout=0., cell_mask_ratio=0.75, mask_context=True,
                 encoder_type='stackffn', embed_dim=1024, depth=4, dim_head=64, num_heads=4, pe_type=None,
                 feat_mask_ratio=0., decoder_embed_dim=512, decoder_embed_type='linear', decoder_num_heads=4,
                 decoder_dim_head=64, cond_dim=None, cond_tokens=1, cond_type='crossattn', cond_strategy='full_mix',
                 cond_emb_type='linear', cond_num_dict=None, cond_mask_ratio=0.5, cond_cat_input=False,
                 post_cond_num_dict=None, post_cond_layers=2, post_cond_norm='layernorm', post_cond_type='add',
                 post_cond_mask_ratio=0.0, norm_layer='layernorm', mlp_time_embed=False, no_time_embed=False,
                 activation='gelu', mask_strategy='random', mask_mode='v1', mask_dec_cond=False,
                 mask_dec_cond_ratio=False, mask_dec_cond_se=False, mask_dec_cond_semlp=False,
                 mask_dec_cond_concat=False, mask_value=0, pad_value=0, decoder_mask=None, text_emb=None,
                 text_emb_file=None, freeze_text_emb=True, text_proj_type='linear', text_proj_act=None,
                 stackfnn_glu_flag=False, text_proj_hidden_dim=512, text_proj_num_layers=2, text_proj_norm=None,
                 cond_emb_norm=None, num_perts=None, gears_flag=False, gears_hidden_size=64,
                 gears_mode="single", gears_mlp_layers=2, gears_norm=None, num_go_gnn_layers=1):
        """Build the encoder, the conditioning decoder and the output head.

        Selected arguments (the remaining ones are forwarded unchanged to ``EmbeddingDict``,
        ``MaskedEncoderConditioner`` and ``PostConditioner``):
            pretrained_gene_list: Gene vocabulary; its length sets ``in_dim``.
            input_gene_list: Optional gene subset; genes that are also in
                ``pretrained_gene_list`` are mapped to their indices in ``input_gene_idx``.
            cell_mask_ratio, feat_mask_ratio, mask_mode, mask_strategy, mask_value, pad_value:
                Options read by ``random_masking``.
            mask_context: If True, ``forward`` multiplies ``x`` by the mask (deprecated).
            encoder_type: One of 'attn', 'mlp', 'mlpparallel', 'stackffn', 'ffnparallel', 'none'.
            embed_dim, decoder_embed_dim: Must be equal (asserted below).
            cond_tokens: Number of tokens of size ``embed_dim`` each encoder output is split into.
            cond_type: Decoder block type: 'crossattn', 'mlp' or 'stackffn'.
            cond_strategy: 'full_mix' (context at every decoder block) or 'pre_mix'
                (context only at the first decoder block).
            cond_emb_type: 'linear', 'embedding' or 'none'; only used when ``cond_dim`` or
                ``cond_num_dict`` is given.
            decoder_embed_type: 'linear', 'embedder' or 'encoder' (reuse the encoder embedding).
            decoder_mask: None, 'enc', 'inv_enc' or 'dec'; see ``forward``.

        Raises:
            ValueError: Unknown ``encoder_type``, ``cond_emb_type`` or ``cond_type``.
            AssertionError: If one of the asserted configuration constraints fails.
        """
        super().__init__()
        self.depth = depth

        # --------------------------------------------------------------------------
        # MAE masking options
        self.cell_mask_ratio = cell_mask_ratio
        self.feat_mask_ratio = feat_mask_ratio
        self.mask_context = mask_context
        self.mask_mode = mask_mode
        self.mask_strategy = mask_strategy
        self.mask_value = mask_value
        self.pad_value = pad_value
        self.decoder_mask = decoder_mask
        # --------------------------------------------------------------------------

        # --------------------------------------------------------------------------
        # MAE encoder specifics
        activation = create_activation(activation)
        # self.in_dim = len(input_gene_list) if input_gene_list is not None else len(pretrained_gene_list)
        self.in_dim = len(pretrained_gene_list) if pretrained_gene_list is not None else len(input_gene_list)
        self.pretrained_gene_list = pretrained_gene_list
        self.input_gene_list = input_gene_list
        # NOTE(review): the line below (and encoder_embed) require pretrained_gene_list to be
        # non-None, even though in_dim above tolerates None — confirm which is intended.
        pretrained_gene_index = dict(zip(self.pretrained_gene_list, list(range(len(self.pretrained_gene_list)))))
        self.input_gene_idx = torch.tensor([
            pretrained_gene_index[o] for o in self.input_gene_list
            if o in pretrained_gene_index
        ]).long() if self.input_gene_list is not None else None

        assert embed_dim == decoder_embed_dim  # XXX: this seems to be required for MAE (see forward dec)?
        full_embed_dim = embed_dim * cond_tokens
        # Split each flat encoder output into cond_tokens tokens of size embed_dim.
        self.post_encoder_layer = Rearrange('b (n d) -> b n d', n=cond_tokens, d=embed_dim)

        self.encoder_embed = OmicsEmbeddingLayer(pretrained_gene_list, full_embed_dim, 'layernorm',
                                                 dropout=dropout, pe_type=pe_type)

        self.encoder_type = encoder_type
        if encoder_type == 'attn':
            self.blocks = nn.ModuleList([
                BasicTransformerBlock(full_embed_dim, num_heads, dim_head, self_attn=True, cross_attn=False,
                                      dropout=dropout, qkv_bias=True, final_act=activation)
                for _ in range(depth)])
        elif encoder_type in ('mlp', 'mlpparallel'):
            self.blocks = nn.ModuleList([
                nn.Sequential(
                    nn.Linear(full_embed_dim, full_embed_dim),
                    activation,
                    create_norm(norm_layer, full_embed_dim),
                ) for _ in range(depth)])
        elif encoder_type in ('stackffn', 'ffnparallel'):
            self.blocks = nn.ModuleList([
                # FeedForward(full_embed_dim, mult=4, glu=False, dropout=dropout)
                nn.Sequential(
                    FeedForward(full_embed_dim, mult=4, glu=False, dropout=dropout),
                    create_norm(norm_layer, full_embed_dim),
                ) for _ in range(depth)])
        elif encoder_type == 'none':
            self.blocks = None
        else:
            raise ValueError(f'Unknown encoder type {encoder_type}')
        # self.encoder_proj = nn.Linear(full_embed_dim, latent_dim)
        # self.norm = create_norm(norm_layer, full_embed_dim)
        # --------------------------------------------------------------------------

        # --------------------------------------------------------------------------
        # MAE decoder specifics
        self.subset_output = True
        self.decoder_embed_dim = decoder_embed_dim
        self.time_embed = nn.Sequential(
            nn.Linear(decoder_embed_dim, 4 * decoder_embed_dim),
            nn.SiLU(),
            nn.Linear(4 * decoder_embed_dim, decoder_embed_dim),
        ) if mlp_time_embed else nn.Identity()
        self.no_time_embed = no_time_embed

        self.cond_type = cond_type
        assert cond_strategy in ("full_mix", "pre_mix")
        self.cond_strategy = cond_strategy
        self.cond_emb_type = cond_emb_type
        self.cond_tokens = cond_tokens
        self.cond_cat_input = cond_cat_input
        if cond_dim is not None or cond_num_dict is not None:
            if cond_emb_type == 'linear':
                assert cond_dim is not None
                self.cond_embed = nn.Sequential(
                    nn.Linear(cond_dim, decoder_embed_dim * cond_tokens),
                    Rearrange('b (n d) -> b n d', n=cond_tokens, d=decoder_embed_dim),
                )
            elif cond_emb_type == 'embedding':
                assert cond_num_dict is not None
                self.cond_embed = EmbeddingDict(cond_num_dict, decoder_embed_dim, depth,
                                                cond_tokens, mask_ratio=cond_mask_ratio,
                                                text_emb=text_emb, text_emb_file=text_emb_file,
                                                norm_layer=cond_emb_norm,
                                                freeze_text_emb=freeze_text_emb,
                                                text_proj_type=text_proj_type,
                                                text_proj_num_layers=text_proj_num_layers,
                                                stackfnn_glu_flag=stackfnn_glu_flag,
                                                text_proj_hidden_dim=text_proj_hidden_dim,
                                                text_proj_act=text_proj_act,
                                                text_proj_norm=text_proj_norm,
                                                # text_proj_dropout=dropout, G_go=G_go,
                                                # G_go_weight=G_go_weight, num_perts=num_perts,
                                                text_proj_dropout=dropout, gears_flag=gears_flag, num_perts=num_perts,
                                                gears_hidden_size=gears_hidden_size, gears_mode=gears_mode,
                                                gears_mlp_layers=gears_mlp_layers, gears_norm=gears_norm,
                                                num_go_gnn_layers=num_go_gnn_layers)
            elif cond_emb_type == 'none':
                self.cond_embed = None
            else:
                raise ValueError(f"Unknwon condition embedder type {cond_emb_type}")
        else:
            self.cond_embed = None

        if cond_type == 'crossattn':
            self.cond_enc_blocks = nn.ModuleList([
                BasicTransformerBlock(decoder_embed_dim, decoder_num_heads, decoder_dim_head,
                                      self_attn=False, cross_attn=True, context_dim=embed_dim,
                                      qkv_bias=True, dropout=dropout, final_act=activation)
                for _ in range(depth)])
        elif cond_type == 'mlp':
            self.cond_enc_blocks = nn.ModuleList([
                ConditionEncoderWrapper(nn.Sequential(
                    nn.Linear(decoder_embed_dim, decoder_embed_dim),
                    activation,
                    create_norm(norm_layer, decoder_embed_dim),
                    nn.Dropout(dropout),
                )) for _ in range(depth)])
        elif cond_type == 'stackffn':
            self.cond_enc_blocks = nn.ModuleList([
                ConditionEncoderWrapper(
                    FeedForward(decoder_embed_dim, mult=4, glu=False, dropout=dropout)
                ) for _ in range(depth)])
        else:
            raise ValueError(f'Unknown conditioning type {cond_type!r}')

        # self.mask_token = nn.Parameter(torch.zeros(1, decoder_embed_dim))
        self.decoder_embed_type = decoder_embed_type
        assert decoder_embed_type in ['linear', 'embedder', 'encoder']
        if decoder_embed_type == 'linear':
            self.decoder_embed = nn.Linear(self.in_dim, decoder_embed_dim)
        elif decoder_embed_type == 'embedder':
            self.decoder_embed = OmicsEmbeddingLayer(pretrained_gene_list, decoder_embed_dim,
                                                     'layernorm', dropout=dropout, pe_type=pe_type)
        elif decoder_embed_type == 'encoder':
            self.decoder_embed = self.encoder_embed

        self.mask_decoder_conditioner = MaskedEncoderConditioner(
            decoder_embed_dim, mult=4, use_ratio=mask_dec_cond_ratio, use_se=mask_dec_cond_se,
            use_semlp=mask_dec_cond_semlp, concat=mask_dec_cond_concat, disable=not mask_dec_cond)

        # self.decoder_blocks = nn.ModuleList([
        #     BasicTransformerBlock(decoder_embed_dim, decoder_num_heads, decoder_dim_head,
        #                           self_attn=False, cross_attn=True, context_dim=embed_dim,
        #                           qkv_bias=True, dropout=dropout, final_act=activation)
        #     for i in range(depth)])

        self.decoder_norm = create_norm(norm_layer, decoder_embed_dim)
        self.post_conditioner = PostConditioner(decoder_embed_dim, self.in_dim, dropout, post_cond_norm,
                                                post_cond_layers, post_cond_num_dict, post_cond_type, act=activation,
                                                cond_emb_dim=decoder_embed_dim, cond_mask_ratio=post_cond_mask_ratio)
        # --------------------------------------------------------------------------

        self.initialize_weights()

    def initialize_weights(self):
        """Apply ``_init_weights`` to every submodule."""
        # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
        # torch.nn.init.normal_(self.mask_token, std=.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Xavier-uniform init for ``nn.Linear`` (zero bias); unit weight / zero bias for ``nn.LayerNorm``."""
        if isinstance(m, nn.Linear):
            # we use xavier_uniform following official JAX ViT:
            torch.nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    # TODO: move to DDPM and get mask from there (masking is indepdent on forward)?
    def random_masking(self, x):
        """Randomly mask entries of a 2-D tensor ``x`` of shape (N, D).

        ``self.mask_mode`` selects the scheme:

        * ``"v1"``: a random ``cell_mask_ratio`` fraction of rows is fully masked. If
          ``feat_mask_ratio > 0``, entries of the remaining rows are also masked, either
          uniformly (``mask_strategy == 'random'``) or only among entries different from
          ``pad_value`` (``mask_strategy == 'none_pad'``). Masked entries are set to
          ``mask_value``; the returned mask is a float tensor (1.0 = masked).
        * ``"v2"``: every row gets a uniform random mask ratio in [0, 1); each row is, with
          probability ``cell_mask_ratio``, fully masked instead. ``feat_mask_ratio`` is
          ignored (with a warning). Masked entries are set to 0; the returned mask is bool.

        Returns:
            ``(x_masked, mask)`` where ``mask`` is 1/True at masked entries. ``x`` is not modified.

        Raises:
            ValueError: Unknown ``mask_mode``.
            NotImplementedError: Unknown ``mask_strategy`` when v1 feature masking is active.
        """
        cell_mask_ratio = self.cell_mask_ratio
        feat_mask_ratio = self.feat_mask_ratio
        N, D = x.shape  # batch, dim

        if self.mask_mode == "v1":
            x_masked = x.clone()

            # apply cell masking: rows outside idx_keep stay fully masked
            len_keep = int(N * (1 - cell_mask_ratio))
            perm = np.random.permutation(N)
            idx_keep = perm[:len_keep]

            # binary mask: 0 is keep, 1 is remove
            mask = torch.ones([N, D], device=x.device)
            mask[idx_keep] = 0

            # apply feature masking on the kept rows
            if feat_mask_ratio > 0:
                if self.mask_strategy == 'random':
                    feat_mask = mask[idx_keep]
                    feat_mask[torch.rand(len_keep, D, device=x.device) <= feat_mask_ratio] = 1
                    mask[idx_keep] = feat_mask
                elif self.mask_strategy == 'none_pad':
                    for i in idx_keep.tolist():
                        # as_tuple=True yields the full 1-D vector of non-padding column indices
                        # (indexing the (k, 1) default output with [0] would keep only the first).
                        non_padding_idx = torch.nonzero(x_masked[i] != self.pad_value, as_tuple=True)[0]
                        n_mask = int(len(non_padding_idx) * feat_mask_ratio)
                        if n_mask == 0:
                            continue
                        # np.random.choice needs host memory, so move the indices off the device.
                        mask_idx = np.random.choice(non_padding_idx.cpu().numpy(), n_mask, replace=False)
                        mask[i, torch.as_tensor(mask_idx, device=mask.device)] = 1
                else:
                    raise NotImplementedError(f'Unsupported mask strategy: {self.mask_strategy}')

            x_masked[mask.bool()] = self.mask_value
        elif self.mask_mode == "v2":
            if feat_mask_ratio != 0:
                warnings.warn(
                    "v2 mask disregards feat_mask_ratio, which is currently "
                    f"set to {feat_mask_ratio!r}.",
                    UserWarning,
                    stacklevel=2,
                )
            mask_ratios = torch.rand(N, 1, device=x.device)
            # A ratio of 1 makes every entry of that row masked (rand < 1 always holds).
            mask_ratios[torch.rand(N, device=x.device) < cell_mask_ratio] = 1
            mask = torch.rand_like(x) < mask_ratios

            # Zero masked entries while every kept entry stays at its own position
            # (masked_scatter would instead fill kept slots from the *start* of x).
            x_masked = x.masked_fill(mask, 0)
        else:
            raise ValueError(f"Unknown mask mode {self.mask_mode!r}")

        return x_masked, mask

    def forward_encoder(self, x, pe_input=None, input_gene_list=None, input_gene_idx=None):
        """Embed ``x`` and run the encoder blocks.

        Returns:
            ``(hist, gene_idx)``: ``hist`` holds one context per block, each reshaped by
            ``post_encoder_layer`` into ``cond_tokens`` tokens (a list of ``depth`` Nones when
            ``encoder_type == 'none'``); ``gene_idx`` is whatever ``encoder_embed`` returns as
            its second output.
        """
        input_gene_list = default(input_gene_list, self.input_gene_list)
        input_gene_idx = default(input_gene_idx, self.input_gene_idx)
        x, gene_idx = self.encoder_embed(x, pe_input, input_gene_list, input_gene_idx)

        if self.blocks is None:
            hist = [None] * self.depth
        elif self.encoder_type in ("mlpparallel", "ffnparallel"):
            # parallel variants: every block sees the embedding directly
            hist = [self.post_encoder_layer(blk(x)) for blk in self.blocks]
        else:
            hist = []
            for blk in self.blocks:  # sequential context encoder blocks
                x = blk(x)
                hist.append(self.post_encoder_layer(x))

        return hist, gene_idx

    def forward_decoder(self, x, context_list, timesteps=None, pe_input=None, conditions=None,
                        input_gene_list=None, input_gene_idx=None, aug_graph=None,
                        return_latent=False, mask=None):
        """Decode ``x`` conditioned on the encoder contexts and optional cell conditions.

        ``context_list`` is consumed in reverse order (last encoder block feeds the first
        decoder block). With ``cond_strategy == 'pre_mix'`` only the first decoder block
        receives a context.

        Returns:
            The normed latent if ``return_latent`` is True, otherwise the output of
            ``post_conditioner(latent, conditions)``.
        """
        # embed tokens
        if self.decoder_embed_type == 'linear':
            x = self.decoder_embed(x)
        else:
            input_gene_list = default(input_gene_list, self.input_gene_list)
            input_gene_idx = default(input_gene_idx, self.input_gene_idx)
            x, _ = self.decoder_embed(x, pe_input, input_gene_list, input_gene_idx)

        # apply masked conditioner
        x = self.mask_decoder_conditioner(x, mask)

        # add time embedding (a single timestep is broadcast over the batch)
        if timesteps is not None and not self.no_time_embed:
            timesteps = timesteps.repeat(x.shape[0]) if len(timesteps) == 1 else timesteps
            time_embed = self.time_embed(timestep_embedding(timesteps, self.decoder_embed_dim))
            x = x + time_embed

        # cell condition embedding: either one per layer (list) or shared across layers
        cond_emb_list = None if self.cond_embed is None else self.cond_embed(conditions, aug_graph=aug_graph)
        if not isinstance(cond_emb_list, list):
            cond_emb_list = [cond_emb_list] * self.depth

        # apply conditioning blocks on a single-token sequence
        x = x.unsqueeze(1)
        stack = zip(self.cond_enc_blocks, reversed(context_list), reversed(cond_emb_list))
        for i, (blk, ctxt, cond_emb) in enumerate(stack):
            full_cond_emb_list = [c for c in (ctxt, cond_emb) if c is not None]
            if self.cond_cat_input:
                full_cond_emb_list.append(x)
            full_cond_emb = torch.cat(full_cond_emb_list, dim=1) if full_cond_emb_list else None

            if self.cond_strategy == "full_mix":
                x = blk(x, context=full_cond_emb)
            elif self.cond_strategy == "pre_mix":
                x = blk(x, context=full_cond_emb if i == 0 else None)
            else:
                raise ValueError(f"Unknown cond_strategy {self.cond_strategy!r}")

        x = x.squeeze(1)

        # apply post conditioner layers
        x = self.decoder_norm(x)
        return x if return_latent else self.post_conditioner(x, conditions)

    def forward_loss(self, target, pred, mask=None):
        """Mask-weighted mean squared error.

        ``mask`` weights each squared error; ``None`` means all entries count. The result is
        NaN when ``mask`` sums to zero.
        """
        if mask is None:
            mask = torch.ones(target.shape, device=target.device)
        loss = (pred - target) ** 2
        loss = (loss * mask).sum() / mask.sum()  # mean loss on removed tokens
        return loss

    def get_latent(self, x_orig, x, timesteps=None, pe_input=None, conditions=None,
                   input_gene_list=None, text_embeddings=None, aug_graph=None, mask=None,
                   input_gene_idx=None):
        """Return the decoder latent (before ``post_conditioner``) for ``x`` given unmasked ``x_orig``.

        ``text_embeddings`` is accepted for backward compatibility but is not used; it was
        previously forwarded by mistake into the decoder's ``input_gene_idx`` slot.
        ``input_gene_idx`` is forwarded to both encoder and decoder.
        """
        context_list, _ = self.forward_encoder(x_orig, pe_input, input_gene_list, input_gene_idx)
        return self.forward_decoder(x, context_list, timesteps, pe_input, conditions, input_gene_list,
                                    input_gene_idx, aug_graph=aug_graph, return_latent=True, mask=mask)

    def forward(self, x_orig, x, timesteps=None, pe_input=None, conditions=None, input_gene_list=None,
                input_gene_idx=None, target_gene_list=None, aug_graph=None, mask=True):
        """Mask ``x_orig``, encode it, and decode ``x`` conditioned on the encoding.

        ``mask`` may be:

        * a tensor: entries where it is nonzero are zeroed in ``x_orig``;
        * ``True``: ``random_masking(x_orig)``; then, if ``decoder_mask`` is set, ``x`` is
          masked too ('enc': same mask, 'inv_enc': inverted mask, 'dec': an independent
          random mask, whose union with the encoder mask is returned);
        * ``False``: no masking (returned mask is all False);
        * ``"all"``: ``x_orig`` fully zeroed, mask all True;
        * ``"showcontext"``: ``x_orig`` kept, mask all True.

        The caller's ``x`` is never modified in place.

        Returns:
            ``(pred, mask)``. ``pred`` is restricted to the ``gene_idx`` columns when
            ``subset_output`` is set; mask columns of genes outside ``target_gene_list`` are zeroed.

        Raises:
            ValueError: Unknown string ``mask``.
            TypeError: Unsupported ``mask`` type.
            NotImplementedError: Unknown ``decoder_mask``.
        """
        if isinstance(mask, torch.Tensor):
            x_orig_masked = x_orig * ~mask.bool()

        elif isinstance(mask, bool):
            if mask:
                x_orig_masked, mask = self.random_masking(x_orig)
                if self.decoder_mask is not None:
                    # Work on a copy so the caller's tensor is left untouched.
                    x = x.clone()
                    if self.decoder_mask == 'enc':
                        x[mask.bool()] = self.mask_value
                    elif self.decoder_mask == 'inv_enc':
                        x[~mask.bool()] = self.mask_value
                    elif self.decoder_mask == 'dec':
                        # random_masking returns (x_masked, mask); only the mask is needed here.
                        _, dec_mask = self.random_masking(x)
                        x[dec_mask.bool()] = self.mask_value
                        mask = (mask.bool() | dec_mask.bool()).float()
                    else:
                        raise NotImplementedError(f"Unsuppoted decoder mask choice: {self.decoder_mask}")
            else:
                x_orig_masked = x_orig
                mask = torch.zeros_like(x_orig, dtype=bool)
        elif isinstance(mask, str):
            if mask == "all":
                x_orig_masked = x_orig * 0  # XXX: assumes mask value is 0
                mask = torch.ones_like(x_orig, dtype=bool)
            elif mask == "showcontext":
                x_orig_masked = x_orig
                mask = torch.ones_like(x_orig, dtype=bool)
            else:
                raise ValueError(f"Unknwon mask type {mask!r}")
        else:
            raise TypeError(f"Unknwon mask specification type {type(mask)}")

        if self.mask_context:
            warnings.warn(
                "After v6.0, mask_context should only be set in the DDPM level, instead of the diffusion model.",
                DeprecationWarning,
                stacklevel=2,
            )
            x = x * mask.bool()

        context_list, gene_idx = self.forward_encoder(x_orig_masked, pe_input, input_gene_list, input_gene_idx)
        pred = self.forward_decoder(x, context_list, timesteps, pe_input, conditions, input_gene_list,
                                    input_gene_idx, aug_graph=aug_graph, mask=mask)

        if target_gene_list is not None:
            # Exclude from the mask every input gene that is not among the target genes.
            gene_to_idx = {gene: i for i, gene in enumerate(self.pretrained_gene_list)}
            target_gene_idx = {gene_to_idx[g] for g in target_gene_list if g in gene_to_idx}
            ignored_gene_idx = [j for j in range(len(gene_idx)) if int(gene_idx[j]) not in target_gene_idx]
            mask[:, ignored_gene_idx] = 0

        if self.subset_output:
            pred = pred[:, gene_idx]

        return pred, mask


class GeneMaskedAutoencoder(nn.Module):
    """Gene-token masked autoencoder.

    The encoder embeds a masked copy of ``x_orig`` with ``OmicsFeatureEmbeddingLayer`` and runs
    ``depth`` self-attention blocks; the decoder embeds ``x`` and runs ``depth`` cross-attention
    blocks against the encoder outputs (in reverse order), then predicts one value per token.
    """

    def __init__(
        self,
        pretrained_gene_list,
        input_gene_list=None,
        dropout=0.,
        mask_ratio=0.75,
        embed_dim=1024,
        depth=4,
        dim_head=64,
        num_heads=4,
        pe_type=None,
        decoder_embed_dim=512,
        decoder_num_heads=4,
        cond_dim=None,
        subset_output=False,
        decoder_dim_head=64,
        norm_layer=nn.LayerNorm,
        mlp_time_embed=False,
        encoder_linear_attn=False,
        decoder_linear_attn=False,
    ):
        """Build encoder/decoder stacks.

        Args:
            pretrained_gene_list: Gene vocabulary passed to both embedding layers.
            input_gene_list: Optional gene subset passed to the encoder embedding.
            mask_ratio: Per-entry probability of masking in ``random_masking``.
            cond_dim: If given, a linear condition embedder ``cond_embed`` is created;
                ``forward_decoder`` requires it whenever ``conditions`` are passed.
            subset_output: If True, ``forward`` keeps only the ``gene_idx`` columns.
            norm_layer: Norm class applied before the prediction head.
            pe_type: Accepted but not used by this class.
        """
        super().__init__()

        # --------------------------------------------------------------------------
        # MAE encoder specifics
        self.in_dim = len(input_gene_list) if input_gene_list is not None else len(pretrained_gene_list)
        self.input_gene_list = input_gene_list
        self.mask_ratio = mask_ratio
        self.encoder_embed = OmicsFeatureEmbeddingLayer(
            pretrained_gene_list,
            'layernorm',
            embed_dim,
            dropout=dropout,
        )

        self.blocks = nn.ModuleList(
            [
                BasicTransformerBlock(
                    embed_dim,
                    num_heads,
                    dim_head,
                    self_attn=True,
                    cross_attn=False,
                    dropout=dropout,
                    qkv_bias=True,
                    final_act=nn.GELU(),
                    linear_attn=encoder_linear_attn,
                )
                for _ in range(depth)
            ]
        )
        # --------------------------------------------------------------------------

        # --------------------------------------------------------------------------
        # MAE decoder specifics
        self.subset_output = subset_output
        self.decoder_embed_dim = decoder_embed_dim
        self.time_embed = nn.Sequential(
            nn.Linear(decoder_embed_dim, 4 * decoder_embed_dim),
            nn.SiLU(),
            nn.Linear(4 * decoder_embed_dim, decoder_embed_dim),
        ) if mlp_time_embed else nn.Identity()

        if cond_dim is not None:
            self.cond_embed = nn.Linear(cond_dim, decoder_embed_dim)

        self.decoder_embed = OmicsFeatureEmbeddingLayer(
            pretrained_gene_list,
            'layernorm',
            decoder_embed_dim,
            dropout=dropout,
        )

        self.decoder_blocks = nn.ModuleList(
            [
                BasicTransformerBlock(
                    decoder_embed_dim,
                    decoder_num_heads,
                    decoder_dim_head,
                    self_attn=False,
                    cross_attn=True,
                    context_dim=embed_dim,
                    qkv_bias=True,
                    dropout=dropout,
                    final_act=nn.GELU(),
                    linear_attn=decoder_linear_attn,
                )
                for _ in range(depth)
            ]
        )

        self.decoder_norm = norm_layer(decoder_embed_dim)
        # one scalar prediction per token
        self.decoder_pred = nn.Sequential(
            nn.Linear(decoder_embed_dim, 1),
            nn.GELU(),
        )
        # --------------------------------------------------------------------------

        self.initialize_weights()

    def initialize_weights(self):
        """Apply ``_init_weights`` to every submodule."""
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Xavier-uniform init for ``nn.Linear`` (zero bias); unit weight / zero bias for ``nn.LayerNorm``."""
        if isinstance(m, nn.Linear):
            # we use xavier_uniform following official JAX ViT:
            torch.nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def random_masking(self, x: torch.Tensor, inplace: bool = False):
        """Mask each entry independently with probability ``mask_ratio`` (masked entries set to 0).

        Returns:
            ``(x_masked, mask)`` with a boolean ``mask`` (True = masked). When ``inplace`` is
            True, ``x`` itself is modified and returned.
        """
        mask = torch.rand(x.shape, device=x.device) <= self.mask_ratio
        x_masked = x.clone() if not inplace else x
        x_masked[mask] = 0
        return x_masked, mask

    def forward_encoder(self, x, pe_input=None):
        """Embed ``x`` and return ``(hist, gene_idx)``, one encoder output per block.

        NOTE(review): every block here is applied to the embedding itself (outputs are not
        chained), whereas ``get_latent`` chains the blocks sequentially — confirm which is intended.
        """
        x, gene_idx = self.encoder_embed(x, pe_input, self.input_gene_list)

        hist = [blk(x) for blk in self.blocks]

        return hist, gene_idx

    def forward_decoder(self, x, context_list, timesteps=None, conditions=None):
        """Decode ``x`` against the encoder outputs and predict one value per token.

        ``context_list`` is consumed in reverse order (last encoder output feeds the first
        decoder block). If ``conditions`` is given, its embedding is added before every block.

        Raises:
            ValueError: ``conditions`` given but the model was built without ``cond_dim``.
        """
        x, _ = self.decoder_embed(x)

        # time embedding (a single timestep is broadcast over the batch)
        if timesteps is not None:
            timesteps = timesteps.repeat(x.shape[0]) if len(timesteps) == 1 else timesteps
            time_embed = self.time_embed(timestep_embedding(timesteps, self.decoder_embed_dim))
            x = x + time_embed.unsqueeze(1)  # expand token dim

        cond_emb = None
        if conditions is not None:
            cond_embed = getattr(self, "cond_embed", None)
            if cond_embed is None:
                raise ValueError("conditions were given but the model was built with cond_dim=None")
            # The condition embedding is identical for every layer; compute it once.
            cond_emb = cond_embed(conditions).unsqueeze(1)  # expand token dim

        for blk, ctxt in zip(self.decoder_blocks, reversed(context_list)):
            if cond_emb is not None:
                x = x + cond_emb
            x = blk(x, context=ctxt)
        x = self.decoder_norm(x)

        # predictor projection
        return self.decoder_pred(x).squeeze(-1)

    def forward_loss(self, target, pred, mask=None):
        """Mask-weighted mean squared error (all entries when ``mask`` is None)."""
        if mask is None:
            mask = torch.ones(target.shape, device=target.device)
        loss = (pred - target) ** 2
        loss = (loss * mask).sum() / mask.sum()  # mean loss on removed tokens
        return loss

    def get_latent(self, x, pe_input=None):
        """Embed ``x`` and run the encoder blocks sequentially, returning the last output."""
        x, _ = self.encoder_embed(x, pe_input, self.input_gene_list)

        for blk in self.blocks:
            x = blk(x)
        return x

    def forward(self, x_orig, x, timesteps=None, pe_input=None, conditions=None, mask=True):
        """Randomly mask ``x_orig`` (if ``mask`` is truthy), encode it and decode ``x``.

        Returns:
            ``(pred, mask)``; both are restricted to the ``gene_idx`` columns when
            ``subset_output`` is set.
        """
        if mask:
            x_orig_masked, mask = self.random_masking(x_orig)
        else:
            x_orig_masked, mask = x_orig, torch.zeros_like(x_orig)

        context_list, gene_idx = self.forward_encoder(x_orig_masked, pe_input)
        pred = self.forward_decoder(x, context_list, timesteps, conditions)

        if self.subset_output:
            pred = pred[:, gene_idx]
            mask = mask[:, gene_idx]

        return pred, mask


class PathwayMaskedAutoencoder(nn.Module):
    """Masked autoencoder built on pathway-level token embeddings.

    The encoder embeds a randomly masked copy of the input with a
    ``BasePathwayEncodingLayer`` and runs it through ``depth`` stacked
    self-attention blocks, keeping every block's output. The decoder embeds its
    own input with a second ``BasePathwayEncodingLayer`` and adds optional time /
    condition embeddings. It then runs ``depth`` cross-attention blocks whose
    contexts are the encoder outputs taken in reverse order (deepest first).
    Finally, the decoder tokens are flattened and projected to one value per
    feature.
    """

    def __init__(
        self,
        pretrained_gene_list,
        input_gene_list=None,
        num_pathways=64,
        dropout=0.,
        mask_ratio=0.75,
        drop_ratio=0.75,
        embed_dim=1024,
        head_embed_dim=1024,
        depth=4,
        dim_head=64,
        num_heads=4,
        pe_type=None,
        decoder_embed_dim=512,
        decoder_num_heads=4,
        cond_dim=None,
        subset_output=False,
        decoder_dim_head=64,
        norm_layer=nn.LayerNorm,
        mlp_time_embed=False,
        encoder_linear_attn=False,
        decoder_linear_attn=False,
        encoder_embed_attn_mask_mode=None,
        decoder_embed_attn_mask_mode=None,
    ):
        """
        Args:
            pretrained_gene_list: gene list the pathway embedding layers are built for.
            input_gene_list: genes present in the input; defaults to
                ``pretrained_gene_list``. Its length sets the output width.
            num_pathways: number of decoder tokens; the prediction head expects
                ``decoder_embed`` to emit exactly this many tokens.
            mask_ratio: per-entry probability of masking an encoder input entry.
            drop_ratio: probability of dropping the encoder input entirely.
            cond_dim: width of the condition vectors. ``None`` builds no condition
                projection, so ``conditions`` must then be ``None``.
            subset_output: not supported yet; ``forward`` raises when set.
            norm_layer: normalization class used in the prediction head.
            pe_type, decoder_embed_attn_mask_mode: accepted for interface
                compatibility; currently unused.
        """
        super().__init__()

        # --------------------------------------------------------------------------
        # MAE encoder specifics
        self.input_gene_list = input_gene_list
        # `is not None` (not `or`): works for array-like gene lists, whose truth
        # value is ambiguous, and matches PriorMaskedAutoencoder.in_dim.
        self.num_feats = len(input_gene_list if input_gene_list is not None else pretrained_gene_list)
        self.num_pathways = num_pathways
        self.mask_ratio = mask_ratio
        self.drop_ratio = drop_ratio
        from .scmodel import BasePathwayEncodingLayer
        self.encoder_embed = BasePathwayEncodingLayer(
            pretrained_gene_list,
            'layernorm',
            embed_dim,
            dropout=dropout,
        )
        self.encoder_embed_attn_mask_mode = encoder_embed_attn_mask_mode

        self.blocks = nn.ModuleList(
            [
                BasicTransformerBlock(
                    embed_dim,
                    num_heads,
                    dim_head,
                    self_attn=True,
                    cross_attn=False,
                    dropout=dropout,
                    qkv_bias=True,
                    final_act=nn.GELU(),
                    linear_attn=encoder_linear_attn,
                )
                for _ in range(depth)
            ]
        )
        # --------------------------------------------------------------------------

        # --------------------------------------------------------------------------
        # MAE decoder specifics
        self.subset_output = subset_output
        self.decoder_embed_dim = decoder_embed_dim
        self.time_embed = nn.Sequential(
            nn.Linear(decoder_embed_dim, 4 * decoder_embed_dim),
            nn.SiLU(),
            nn.Linear(4 * decoder_embed_dim, decoder_embed_dim),
        ) if mlp_time_embed else nn.Identity()

        if cond_dim is not None:
            self.cond_embed = nn.Linear(cond_dim, decoder_embed_dim)

        # The decoder input gets its own pathway encoding layer.
        self.decoder_embed = BasePathwayEncodingLayer(
            pretrained_gene_list,
            'layernorm',
            decoder_embed_dim,
            dropout=dropout,
        )

        self.decoder_blocks = nn.ModuleList(
            [
                BasicTransformerBlock(
                    decoder_embed_dim,
                    decoder_num_heads,
                    decoder_dim_head,
                    self_attn=False,
                    cross_attn=True,
                    context_dim=embed_dim,
                    qkv_bias=True,
                    dropout=dropout,
                    final_act=nn.GELU(),
                    linear_attn=decoder_linear_attn,
                )
                for _ in range(depth)
            ]
        )

        # Flatten all decoder tokens into one cell embedding, then project to features.
        cell_emb_dim = decoder_embed_dim * self.num_pathways
        self.decoder_pred = nn.Sequential(
            Rearrange('b n d -> b (n d)'),
            norm_layer(cell_emb_dim),
            nn.Linear(cell_emb_dim, head_embed_dim),
            nn.GELU(),
            norm_layer(head_embed_dim),
            nn.Linear(head_embed_dim, self.num_feats),
            nn.GELU(),
        )
        # --------------------------------------------------------------------------

        self.initialize_weights()

    def initialize_weights(self):
        """Re-initialize every nn.Linear / nn.LayerNorm submodule (see ``_init_weights``)."""
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Xavier-uniform Linear weights with zero bias; LayerNorm to identity."""
        if isinstance(m, nn.Linear):
            # we use xavier_uniform following official JAX ViT:
            torch.nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def random_masking(self, x: torch.Tensor, mask: bool = True, mask_all: bool = False):
        """Mask the encoder input.

        Args:
            x: encoder input.
            mask: apply random masking (dropping the whole input with probability
                ``self.drop_ratio``, otherwise zeroing each entry with
                probability ``self.mask_ratio``).
            mask_all: always drop the whole input.

        Returns:
            ``(x_masked, mask)`` where ``mask`` is a bool tensor shaped like ``x``
            (True = masked). ``x_masked`` is ``None`` when the input is dropped.
        """
        assert not (mask and mask_all), "Cannot set mask and mask_all at the same time."
        if (not mask) and (not mask_all):  # no mask applied
            x_masked, mask = x, torch.zeros_like(x, dtype=bool)
        elif (torch.rand(1) < self.drop_ratio) or mask_all:  # drop all
            x_masked, mask = None, torch.ones_like(x, dtype=bool)
        else:  # random mask entries
            mask = torch.rand(x.shape, device=x.device) <= self.mask_ratio
            x_masked = torch.where(mask, torch.zeros_like(x), x)
        return x_masked, mask

    def forward_encoder(self, x, pe_input=None):
        """Run the stacked context encoder.

        Returns:
            ``(hist, gene_idx)``: ``hist`` holds the output of every encoder block
            in order (block ``i`` consumes block ``i-1``'s output); ``gene_idx``
            comes from ``encoder_embed``. ``(None, None)`` when ``x`` is ``None``
            (input dropped).
        """
        if x is None:
            return None, None

        x, gene_idx = self.encoder_embed(x, pe_input, self.input_gene_list,
                                         attn_mask_mode=self.encoder_embed_attn_mask_mode)

        # Stack the blocks so later contexts are deeper; forward_decoder pairs them
        # in reverse order, and get_latent uses the same sequential stack.
        hist = []
        for blk in self.blocks:
            x = blk(x)
            hist.append(x)

        return hist, gene_idx

    def forward_decoder(self, x, context_list=None, timesteps=None, conditions=None):
        """Decode ``x`` with optional time/condition embeddings and encoder contexts.

        Args:
            x: decoder input, embedded by ``self.decoder_embed``.
            context_list: encoder outputs from ``forward_encoder``; ``None`` (or
                empty) runs the cross-attention blocks without context.
            timesteps: optional diffusion timesteps; a single value is repeated
                over the batch.
            conditions: optional condition vectors (requires ``cond_dim``).

        Returns:
            The prediction head output.
        """
        x, _ = self.decoder_embed(x)

        # calculate time embedding
        if timesteps is not None:
            timesteps = timesteps.repeat(x.shape[0]) if len(timesteps) == 1 else timesteps
            time_embed = self.time_embed(timestep_embedding(timesteps, self.decoder_embed_dim))
            x = x + time_embed.unsqueeze(1)  # expand token dim

        # condition embedding (cond_embed only exists when cond_dim was given)
        if conditions is not None:
            x = x + self.cond_embed(conditions).unsqueeze(1)  # expand token dim

        # apply Transformer decoder, deepest encoder context first
        context_list = context_list or [None] * len(self.decoder_blocks)
        for blk, context in zip(self.decoder_blocks, reversed(context_list)):
            x = blk(x, context=context)

        return self.decoder_pred(x)

    def forward_loss(self, target, pred, mask=None):
        """Masked MSE: ``sum(mask * (pred - target) ** 2) / sum(mask)``.

        ``mask`` defaults to all ones; bool/int masks are accepted. An all-zero
        mask yields 0 instead of NaN.
        """
        if mask is None:
            mask = torch.ones(target.shape, device=target.device)
        elif not mask.is_floating_point():
            mask = mask.float()
        loss = (pred - target) ** 2
        denom = mask.sum()
        # all-zero mask -> numerator is 0; the floor avoids 0 / 0 = NaN
        denom = denom.clamp_min(torch.finfo(denom.dtype).eps)
        return (loss * mask).sum() / denom  # mean loss on removed tokens

    def get_latent(self, x, pe_input=None):
        """Embed ``x`` and return the output of the last stacked encoder block."""
        # same attention-mask mode as forward_encoder, so latents match training
        x, _ = self.encoder_embed(x, pe_input, self.input_gene_list,
                                  attn_mask_mode=self.encoder_embed_attn_mask_mode)
        for blk in self.blocks:
            x = blk(x)
        return x

    def forward(
        self,
        x_orig,
        x,
        timesteps=None,
        pe_input=None,
        conditions=None,
        mask=True,
        mask_all=False,
    ):
        """Mask ``x_orig`` for the encoder, then decode ``x`` against its contexts.

        Returns:
            ``(pred, mask)`` with ``mask`` from ``random_masking``.

        Raises:
            NotImplementedError: if ``subset_output`` is set (checked up front so
                no forward pass is wasted).
        """
        if self.subset_output:
            raise NotImplementedError("Not supported yet since gene_idx might not be defined")

        x_orig_masked, mask = self.random_masking(x_orig, mask=mask, mask_all=mask_all)
        context_list, _ = self.forward_encoder(x_orig_masked, pe_input)
        pred = self.forward_decoder(x, context_list, timesteps, conditions)
        return pred, mask


class PriorMaskedAutoencoder(nn.Module):
    """Masked autoencoder that predicts the parameters of a count likelihood.

    Inputs are library-size normalized and/or ``log1p`` transformed
    (``scnormalize``), optionally masked, and decoded by a stack of transformer
    blocks into the parameters of the chosen ``prior``:

    * ``'nb'``      -> ``{'mu', 'phi'}``
    * ``'zinb'``    -> ``{'mu', 'phi', 'pi_logits'}``
    * ``'poisson'`` -> ``{'lbd'}``

    Every head is a linear projection to ``in_dim`` outputs; the mean heads
    (``'mu'`` / ``'lbd'``) are followed by a softmax when ``rescale_flag`` is set.
    ``forward`` currently only exercises the decoder-only path.
    """

    def __init__(
        self,
        pretrained_gene_list,
        input_gene_list=None,
        cell_mask_ratio=0.,
        feat_mask_ratio=0.5,
        mask_strategy='random',
        decoder_mask=None,
        depth=4,
        embed_dim=1024,
        dim_head=64,
        num_heads=4,
        decoder_embed_dim=512,
        decoder_dim_head=64,
        decoder_num_heads=4,
        pe_type=None,
        cond_dim=None,
        dropout=0.,
        norm_layer=nn.LayerNorm,
        mlp_time_embed=False,
        decoder_embed_type='embedder',
        block_type='transformer',  # TODO: add block type option
        decoder_only=True,
        activation='gelu',
        prior='nb',  # zinb, nb, poisson
        pred_type='separate',  # separate, joint
        target_sum=None,
        lib_norm_flag=True,
        log_norm_flag=True,
        rescale_flag=False,
    ):
        """
        Args:
            pretrained_gene_list: genes the omics embedding layers are built for.
            input_gene_list: genes present in the input; defaults to
                ``pretrained_gene_list``. Its length (``in_dim``) sets the width of
                the ``'linear'`` decoder embedding and of every prediction head.
            cell_mask_ratio: fraction of cells (rows) removed by ``random_masking``.
            feat_mask_ratio: fraction of entries masked within the kept cells.
            mask_strategy: ``'random'`` (any entry) or ``'none_zero'`` (only
                entries different from ``pad_value``).
            decoder_embed_type: ``'linear'``, ``'embedder'`` or ``'encoder'``
                (reuse the encoder embedding; needs ``decoder_only=False``).
            block_type: not used yet.
            decoder_only: when True no encoder is built and decoder blocks use
                self-attention; otherwise they cross-attend to encoder outputs.
            prior: ``'nb'``, ``'zinb'`` or ``'poisson'``.
            pred_type: currently only stored.
            target_sum: library-size target used by ``forward``'s normalization;
                ``None`` uses the batch median library size.
            lib_norm_flag, log_norm_flag: toggle library-size scaling and
                ``log1p`` in ``scnormalize``.
            rescale_flag: apply a softmax over features to the mean heads.

        Raises:
            AssertionError: unknown ``decoder_embed_type``.
            ValueError: ``decoder_embed_type='encoder'`` with ``decoder_only=True``.
            NotImplementedError: unsupported ``prior``.
        """
        super().__init__()

        # --------------------------------------------------------------------------
        # MAE masking options
        self.cell_mask_ratio = cell_mask_ratio
        self.feat_mask_ratio = feat_mask_ratio
        self.mask_strategy = mask_strategy
        self.decoder_mask = decoder_mask
        self.mask_value = 0  # value written into masked entries
        self.pad_value = 0  # entries equal to this are skipped by the 'none_zero' strategy
        # --------------------------------------------------------------------------

        # --------------------------------------------------------------------------
        # MAE encoder specifics
        activation = create_activation(activation)
        self.in_dim = len(input_gene_list) if input_gene_list is not None else len(pretrained_gene_list)
        self.pretrained_gene_list = pretrained_gene_list
        self.input_gene_list = input_gene_list
        self.decoder_only = decoder_only
        if not self.decoder_only:
            self.encoder_embed = OmicsEmbeddingLayer(
                pretrained_gene_list, embed_dim, 'layernorm', dropout=dropout, pe_type=pe_type)
            self.blocks = nn.ModuleList([
                BasicTransformerBlock(embed_dim, num_heads, dim_head, self_attn=True, cross_attn=False,
                                      dropout=dropout, qkv_bias=True, final_act=activation)
                for _ in range(depth)
            ])
        # --------------------------------------------------------------------------

        # --------------------------------------------------------------------------
        # MAE decoder specifics
        self.decoder_embed_dim = decoder_embed_dim
        self.time_embed = nn.Sequential(
            nn.Linear(decoder_embed_dim, 4 * decoder_embed_dim),
            nn.SiLU(),
            nn.Linear(4 * decoder_embed_dim, decoder_embed_dim),
        ) if mlp_time_embed else nn.Identity()

        if cond_dim is not None:
            self.cond_embed = nn.Linear(cond_dim, decoder_embed_dim)

        self.decoder_embed_type = decoder_embed_type
        assert decoder_embed_type in ['linear', 'embedder', 'encoder']
        if decoder_embed_type == 'linear':
            self.decoder_embed = nn.Linear(self.in_dim, decoder_embed_dim)
        elif decoder_embed_type == 'embedder':
            self.decoder_embed = OmicsEmbeddingLayer(
                pretrained_gene_list, decoder_embed_dim, 'layernorm', dropout=dropout, pe_type=pe_type)
        elif decoder_embed_type == 'encoder':
            if self.decoder_only:
                raise ValueError("decoder_embed_type='encoder' requires decoder_only=False "
                                 "(no encoder embedding is built otherwise).")
            self.decoder_embed = self.encoder_embed

        # cross attention to encoder outputs, or self attention when decoder only
        use_encoder = not self.decoder_only
        self.decoder_blocks = nn.ModuleList([
            BasicTransformerBlock(decoder_embed_dim, decoder_num_heads, decoder_dim_head,
                                  self_attn=not use_encoder, cross_attn=use_encoder,
                                  context_dim=embed_dim, qkv_bias=True, dropout=dropout,
                                  final_act=activation)
            for _ in range(depth)
        ])

        self.decoder_norm = norm_layer(decoder_embed_dim)
        # --------------------------------------------------------------------------

        # --------------------------------------------------------------------------
        # MAE prior specifics
        self.lib_norm_flag = lib_norm_flag
        self.log_norm_flag = log_norm_flag
        self.pred_type = pred_type
        self.target_sum = target_sum
        self.prior = prior
        self.mean_act = nn.Softmax(dim=-1) if rescale_flag else nn.Identity()
        if self.prior == 'nb':
            self.lambda_pred = nn.ModuleDict({
                'mu': nn.Sequential(nn.Linear(decoder_embed_dim, self.in_dim), self.mean_act),
                'phi': nn.Linear(decoder_embed_dim, self.in_dim),
            })
        elif self.prior == 'zinb':
            self.lambda_pred = nn.ModuleDict({
                'mu': nn.Sequential(nn.Linear(decoder_embed_dim, self.in_dim), self.mean_act),
                'phi': nn.Linear(decoder_embed_dim, self.in_dim),
                'pi_logits': nn.Linear(decoder_embed_dim, self.in_dim),
            })
        elif self.prior == 'poisson':
            self.lambda_pred = nn.ModuleDict({
                'lbd': nn.Sequential(nn.Linear(decoder_embed_dim, self.in_dim), self.mean_act),
            })
        else:
            raise NotImplementedError(f'Unsupported prior: {self.prior}')
        # --------------------------------------------------------------------------

        self.initialize_weights()

    def rescale(self, x, library_size=None):
        """Multiply ``x`` by a library size and return the result.

        Args:
            x: values to rescale.
            library_size: factor to multiply by; defaults to ``self.library_size``
                (which must then have been set by the caller).
        """
        library_size = library_size if library_size is not None else self.library_size
        return x * library_size

    def initialize_weights(self):
        """Re-initialize every nn.Linear / nn.LayerNorm submodule (see ``_init_weights``)."""
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Xavier-uniform Linear weights with zero bias; LayerNorm to identity."""
        if isinstance(m, nn.Linear):
            # we use xavier_uniform following official JAX ViT:
            torch.nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def scnormalize(self, x, target_sum=None, eps=1e-10):
        """Library-size normalize and/or ``log1p`` transform ``x`` row-wise.

        Args:
            x: matrix whose rows are cells; row sums are the library sizes.
            target_sum: library size every row is scaled to; ``None`` uses the
                median row sum (+ ``eps``).
            eps: guards against division by zero for empty rows.

        Returns:
            ``(x_normalized, library_size)`` with ``library_size`` of shape (N, 1).
            ``x`` is returned unchanged when both normalization flags are off.
        """
        library_size = x.sum(1, keepdim=True)
        target_sum = library_size.median() + eps if target_sum is None else target_sum
        if self.lib_norm_flag:
            x = x * target_sum / (library_size + eps)
        if self.log_norm_flag:
            x = torch.log1p(x)
        return x, library_size

    def random_masking(self, x):
        """Drop a random subset of cells, then mask features of the kept cells.

        Args:
            x: (N, D) input.

        Returns:
            ``(x_masked, mask, idx_keep, idx_mask)``:
            ``x_masked`` holds only the kept rows, in their original relative
            order, with masked entries set to ``self.mask_value``; ``mask`` is an
            (N, D) float tensor (0 = keep, 1 = removed/masked); ``idx_keep`` and
            ``idx_mask`` are sorted numpy index arrays of kept / dropped rows.
        """
        cell_mask_ratio = self.cell_mask_ratio
        feat_mask_ratio = self.feat_mask_ratio
        N, D = x.shape  # batch, dim
        len_keep = int(N * (1 - cell_mask_ratio))
        perm = np.random.permutation(N)

        # Sorting keeps surviving rows in input order; otherwise x_masked would be
        # shuffled relative to conditions / timesteps / targets even when no cell
        # is dropped (cell_mask_ratio=0).
        idx_keep = np.sort(perm[:len_keep])
        idx_mask = np.sort(perm[len_keep:])
        x_masked = x[idx_keep]  # advanced indexing copies, so x is not modified below

        # generate the binary mask: 0 is keep, 1 is remove
        mask = torch.ones([N, D], device=x.device)
        mask[idx_keep] = 0

        # apply feature masking on the remaining part
        if feat_mask_ratio > 0:
            feat_mask = mask[idx_keep]
            if self.mask_strategy == 'random':
                feat_mask[torch.rand(len_keep, D, device=x.device) <= feat_mask_ratio] = 1
            elif self.mask_strategy == 'none_zero':
                for i, row in enumerate(x_masked):
                    # all non-padding positions of this row (1-D index array)
                    non_padding_idx = torch.nonzero(row != self.pad_value, as_tuple=True)[0].cpu().numpy()
                    n_mask = int(len(non_padding_idx) * feat_mask_ratio)
                    mask_idx = np.random.choice(non_padding_idx, n_mask, replace=False)
                    feat_mask[i, mask_idx] = 1
            else:
                raise NotImplementedError(f'Unsupported mask strategy: {self.mask_strategy}')
            x_masked[feat_mask.bool()] = self.mask_value
            mask[idx_keep] = feat_mask

        return x_masked, mask, idx_keep, idx_mask

    def forward_encoder(self, x, pe_input=None, input_gene_list=None):
        """Run the stacked context encoder (only built when ``decoder_only=False``).

        Returns:
            List with the output of every encoder block, in order; block ``i``
            consumes block ``i-1``'s output.
        """
        input_gene_list = input_gene_list if input_gene_list is not None else self.input_gene_list
        x, _ = self.encoder_embed(x, pe_input, input_gene_list)

        # Stack the blocks so later contexts are deeper; the decoder pairs them in
        # reverse order.
        hist = []
        for blk in self.blocks:
            x = blk(x)
            hist.append(x)

        return hist

    def _embed_decoder_input(self, x, pe_input=None, input_gene_list=None):
        """Apply ``decoder_embed``; non-linear embedders also get ``pe_input`` and the gene list."""
        if self.decoder_embed_type == 'linear':
            return self.decoder_embed(x)
        input_gene_list = input_gene_list if input_gene_list is not None else self.input_gene_list
        x, _ = self.decoder_embed(x, pe_input, input_gene_list)
        return x

    def _time_embedding(self, timesteps, batch_size):
        """Embed ``timesteps``, repeating a single timestep over the batch."""
        if len(timesteps) == 1:
            timesteps = timesteps.repeat(batch_size)
        return self.time_embed(timestep_embedding(timesteps, self.decoder_embed_dim))

    def _run_decoder_blocks(self, x, context_list=None, conditions=None):
        """Apply the decoder blocks (conditions re-added before each) and the final norm."""
        if context_list is None:
            context_list = [None] * len(self.decoder_blocks)
        for depth_idx, blk in enumerate(self.decoder_blocks, start=1):
            if conditions is not None:
                x = x + self.cond_embed(conditions)
            # first decoder block pairs with the last encoder output
            x = blk(x, context=context_list[-depth_idx])
        return self.decoder_norm(x)

    def forward_decoder(self, x, context_list=None, timesteps=None, pe_input=None, conditions=None,
                        input_gene_list=None):
        """Decode ``x`` into the prior parameters.

        Args:
            x: decoder input.
            context_list: encoder outputs, or ``None`` when decoder only.
            timesteps: optional diffusion timesteps (single value is broadcast).
            pe_input, input_gene_list: forwarded to the decoder embedding.
            conditions: optional condition vectors (requires ``cond_dim``).

        Returns:
            Dict mapping each head name of ``self.lambda_pred`` to its output.
        """
        x = self._embed_decoder_input(x, pe_input, input_gene_list)
        if timesteps is not None:
            x = x + self._time_embedding(timesteps, x.shape[0])
        x = self._run_decoder_blocks(x, context_list, conditions)
        return {k: head(x) for k, head in self.lambda_pred.items()}

    # TODO: currently only support decoder only
    def forward(self, x, timesteps=None, pe_input=None, conditions=None, input_gene_list=None,
                mask_flag=True, mask=None):
        """Normalize, mask and decode ``x``.

        Args:
            x: raw input matrix (rows are cells).
            mask: explicit mask (nonzero = masked); takes precedence over ``mask_flag``.
            mask_flag: when no ``mask`` is given, apply ``random_masking``.
            timesteps, pe_input, conditions, input_gene_list: forwarded to
                ``forward_decoder``.

        Returns:
            ``(out_dict, mask)``.

        NOTE(review): with ``cell_mask_ratio > 0`` the decoder sees fewer rows than
        ``conditions`` / per-sample ``timesteps``; only ``cell_mask_ratio=0`` keeps
        them aligned.
        """
        x, _ = self.scnormalize(x, target_sum=self.target_sum)
        if mask is not None:
            # masked_fill returns a new tensor, so the caller's x is never modified
            # (scnormalize returns x itself when both normalization flags are off).
            x_masked = x.masked_fill(mask.bool(), self.mask_value)
            pe_input_masked = pe_input  # TODO: pe_input modification
        elif mask_flag:
            x_masked, mask, idx_keep, _ = self.random_masking(x)
            pe_input_masked = pe_input[idx_keep] if pe_input is not None else None
        else:
            x_masked, pe_input_masked = x, pe_input
            mask = torch.zeros_like(x)

        out_dict = self.forward_decoder(x_masked, None, timesteps, pe_input_masked, conditions, input_gene_list)
        return out_dict, mask

    def get_latent(self, x, timesteps=None, pe_input=None, conditions=None, input_gene_list=None):
        """Return the normalized decoder representation of ``x`` (no prediction heads).

        Unlike ``forward_decoder``, a time embedding is always added; ``timesteps``
        defaults to 0 for every sample.
        """
        x = self._embed_decoder_input(x, pe_input, input_gene_list)
        if timesteps is None:
            timesteps = torch.full((x.shape[0],), fill_value=0, device=x.device)
        x = x + self._time_embedding(timesteps, x.shape[0])
        return self._run_decoder_blocks(x, None, conditions)

    # def forward_model(self, x, timesteps=None, pe_input=None, conditions=None, input_gene_list=None,
    #             mask_flag=True, mask=None):
    #     # masking: length -> length * mask_ratio
    #     if mask_flag and mask is None:
    #         x_context_masked, enc_mask, idx_keep, _ = self.random_masking(x)
    #         if not self.decoder_only:
    #             if enc_mask is None:
    #                 x_context_masked, enc_mask, idx_keep, _ = self.random_masking(x_context)
    #                 pe_input_masked = pe_input[idx_keep] if pe_input is not None else None
    #             if self.decoder_mask is not None and dec_mask is None:
    #                 if self.decoder_mask == 'enc':
    #                     dec_mask = enc_mask
    #                 elif self.decoder_mask == 'inv_enc':
    #                     dec_mask = (~enc_mask.bool()).float()
    #                 elif self.decoder_mask == 'dec':
    #                     _, dec_mask, _, _ = self.random_masking(x_input)
    #                 else:
    #                     raise NotImplementedError(f"Unsuppoted decoder mask choice: {self.decoder_mask}")
    #             x_input[dec_mask.bool()] = self.mask_value
    #             mask = (enc_mask.bool() | dec_mask.bool()).float()
    #         else:
    #             x_input, mask, idx_keep, _ = self.random_masking(x_input)
    #             pe_input_masked = pe_input[idx_keep] if pe_input is not None else None
    #     elif mask is None:
    #         x_context_masked, pe_input_masked = x_context, pe_input
    #         mask = torch.zeros_like(x_context)

    #     if not self.decoder_only:
    #         context_list = self.forward_encoder(x_context_masked, pe_input_masked, input_gene_list)
    #     else:
    #         context_list = None
    #     out_dict = self.forward_decoder(x_input, context_list, timesteps, pe_input, conditions, input_gene_list)
    #     return out_dict, mask

    # def forward(self, x_orig, x=None, timesteps=None, pe_input=None, conditions=None, input_gene_list=None, mask_flag=True,
    #             sample_flag=False): # target_gene_list=None,
    #     # masking: length -> length * mask_ratio
    #     x_orig, self.library_size = self.scnormalize(x_orig, target_sum=self.target_sum)
    #     if x is not None:
    #         x, _ = self.scnormalize(x, target_sum=self.target_sum)
    #         if self.pred_type == 'separate':
    #             if not sample_flag:
    #                 t_0 = torch.full((x_orig.shape[0],), fill_value=0, device=x_orig.device)
    #                 out_dict_0, mask = self.forward_model(x_orig, x_orig, t_0, pe_input, conditions, input_gene_list, mask_flag=mask_flag)
    #             else:
    #                 out_dict_0 = {}
    #             out_dict_t, mask = self.forward_model(x, x, timesteps, pe_input, conditions, input_gene_list, mask_flag=mask_flag, mask=mask)
    #         else:
    #             raise NotImplementedError(f'Unsupported pred type: {self.pred_type}')
    #     else:
    #         t_0 = torch.full((x_orig.shape[0],), fill_value=0, device=x_orig.device)
    #         out_dict_0, mask = self.forward_model(x_orig, x_orig, t_0, pe_input, conditions, input_gene_list, mask_flag=mask_flag)
    #         out_dict_t = {}
    #     out_dict = {
    #         'lbd_0': out_dict_0,
    #         'lbd_t': out_dict_t,
    #     }
    #     return out_dict, mask
