from dataclasses import dataclass
import time
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
import inspect
import math
from loguru import logger
import os
import torch
import torch.nn as nn
import numpy as np
from torch.nn import functional as F
from zinb import *

# modified from https://nn.labml.ai/zh/transformers/rope/index.html
class ROPE_Batch(torch.nn.Module):
    r"""Rotary position embedding (RoPE) driven by explicit per-sample positions.

    The rotation angles are not derived from ``arange(seq_len)``; they come from
    the ``[batch_size, seq_len]`` position tensor handed to :meth:`build_cache`,
    so every sample of a batch can carry its own positions.
    """

    def __init__(self, d,base = 100):
        r"""
        * `d` is the number of leading features $d$ that get rotated
          (expected to be even, the features are rotated in pairs)
        * `base` is the constant used for calculating $\Theta$
        """
        super().__init__()

        self.base = base
        self.d = d
        # Filled in by `build_cache`; `forward` refuses to run while they are None.
        self.cos_cached = None
        self.sin_cached = None

    def build_cache(self, seq_idx=None):
        r"""
        Cache the $\cos$ and $\sin$ values for the given positions.

        * `seq_idx` is a `[batch_size, seq_len]` tensor of positions. It is
          mandatory; the `None` default only exists for signature compatibility.

        Raises `ValueError` when `seq_idx` is missing.
        """
        if seq_idx is None:
            raise ValueError(
                "ROPE_Batch.build_cache() needs a [batch_size, seq_len] position tensor"
            )

        # $\Theta = {\theta_i = base^{-\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
        theta = 1. / (self.base ** (torch.arange(0, self.d, 2).float() / self.d)).to(seq_idx.device)

        self.seq_idx = seq_idx

        # Product of every position with every $\theta_i$: [batch, seq_len, d/2]
        idx_theta = torch.einsum('bn,d->bnd', self.seq_idx, theta)

        # Duplicate so that for position $m$ we have
        # $[m \theta_0, ..., m \theta_{\frac{d}{2}}, m \theta_0, ..., m \theta_{\frac{d}{2}}]$
        idx_theta2 = torch.cat([idx_theta, idx_theta], dim=2)

        # Extra axis at -2 lets the cache broadcast over the heads axis of `x`.
        self.cos_cached = idx_theta2.cos().unsqueeze(-2)
        self.sin_cached = idx_theta2.sin().unsqueeze(-2)

    def _neg_half(self, x: torch.Tensor):
        r"""Return $[-x^{(\frac{d}{2}+1)}, ..., -x^{(d)}, x^{(1)}, ..., x^{(\frac{d}{2})}]$."""
        d_2 = self.d // 2
        return torch.cat([-x[..., d_2:], x[..., :d_2]], dim=-1)

    def forward(self, x):
        r"""
        * `x` is the Tensor at the head of a key or a query with shape
          `[batch_size, seq_len, n_heads, features]`; this is the layout the
          cache built from a `[batch_size, seq_len]` position tensor broadcasts with.

        Only the first `d` features are rotated, the remaining ones are passed
        through untouched. Raises `RuntimeError` if `build_cache` was not called.
        """
        if self.cos_cached is None or self.sin_cached is None:
            raise RuntimeError("ROPE_Batch.build_cache() must be called before forward()")

        # Split the features, rotary embeddings may be applied to a partial set of features.
        x_rope, x_pass = x[..., :self.d], x[..., self.d:]

        neg_half_x = self._neg_half(x_rope)

        # For $i \in {1, 2, ..., \frac{d}{2}}$:
        #   $x^{(i)}_m \cos m \theta_i - x^{(i + \frac{d}{2})}_m \sin m \theta_i$
        #   $x^{(i + \frac{d}{2})}_m \cos m \theta_i + x^{(i)}_m \sin m \theta_i$
        x_rope = (x_rope * self.cos_cached) + (neg_half_x * self.sin_cached)

        return torch.cat((x_rope, x_pass), dim=-1)

class ROPE_2D(torch.nn.Module):
    r"""Two-dimensional rotary embedding built from two `ROPE_Batch` modules.

    The first half of the feature axis is rotated according to coordinate 0 of
    the position tensor, the second half according to coordinate 1.
    """

    def __init__(self, d, base = 100):
        r"""
        * `d` is the number of features $d$ (each axis rotates `d // 2` of them)
        * `base` is the constant used for calculating $\Theta$
        """
        super().__init__()

        self.rotary_pe_x = ROPE_Batch(d//2,base=base)
        self.rotary_pe_y = ROPE_Batch(d//2,base=base)
        self.base = base
        self.d = d

    def build_cache(self,seq_idx=None):
        r"""
        Cache $\cos$ and $\sin$ values for both axes.

        * `seq_idx` holds the 2D positions; `seq_idx[..., 0]` feeds the first
          rotary module and `seq_idx[..., 1]` the second one. It is mandatory;
          the `None` default only exists for signature compatibility.

        Raises `ValueError` when `seq_idx` is missing.
        """
        if seq_idx is None:
            raise ValueError("ROPE_2D.build_cache() needs a position tensor whose last axis is (x, y)")
        self.rotary_pe_x.build_cache(seq_idx[...,0])
        self.rotary_pe_y.build_cache(seq_idx[...,1])

    def forward(self, x: torch.Tensor):
        """
        * `x` is the Tensor at the head of a key or a query with shape `[batch_size,seq_len, n_heads, d]`
        """
        half = x.shape[-1]//2
        x_rope_x = self.rotary_pe_x(x[...,:half])
        x_rope_y = self.rotary_pe_y(x[...,half:])
        return torch.cat((x_rope_x, x_rope_y), dim=-1)

class ROPE_2D_rotate(torch.nn.Module):
    r"""2D rotary embedding evaluated in several rotated coordinate frames.

    The feature axis is cut into `rottimes` equal groups. Group `i` gets a
    `ROPE_2D` whose positions are the input coordinates multiplied by the 2x2
    rotation matrix of angle `(90 / rottimes) * i` degrees.
    """

    def __init__(self, batch_size, d, base = 100,rottimes=4):
        r"""
        * `batch_size` is stored but not otherwise used by this class
        * `d` is the number of features $d$ (split evenly over the `rottimes` groups)
        * `base` is the constant used for calculating $\Theta$
        * `rottimes` is the number of rotated coordinate frames
        """
        super().__init__()

        self.base = base
        self.d = d
        self.rottimes = rottimes
        self.batch_size = batch_size

        rotmtx = []
        for i in range(self.rottimes):
            angle_radians = np.radians((90/self.rottimes)*i)
            rotmtx.append(
                torch.tensor([[np.cos(angle_radians), -np.sin(angle_radians)],
                              [np.sin(angle_radians), np.cos(angle_radians)]])
            )
        # Kept as a plain list (not buffers) so the state_dict layout is unchanged.
        self.rotmtx = rotmtx
        # ModuleList so the sub-modules are registered with nn.Module; they own
        # no parameters or buffers, hence the state_dict stays the same.
        self.rdlist = torch.nn.ModuleList(
            ROPE_2D(self.d//self.rottimes,base=self.base) for _ in range(self.rottimes)
        )

    def build_cache(self,seq_idx=None):
        r"""
        Cache $\cos$ and $\sin$ values for every rotated frame.

        * `seq_idx` is the `[batch_size, seq_len, 2]` coordinate tensor. Each
          rotation matrix is cast to its dtype and device before the product,
          so integer coordinates would truncate the matrix entries.
        """
        for rotation, rope in zip(self.rotmtx, self.rdlist):
            rotation = rotation.to(seq_idx.dtype).to(seq_idx.device)
            # Broadcast matmul: same product as a bmm against the matrix repeated per sample.
            rope.build_cache(torch.matmul(seq_idx, rotation))

    def forward(self, x: torch.Tensor):
        """
        * `x` is the Tensor at the head of a key or a query with shape `[batch_size,seq_len, n_heads, d]`
        """
        B,L,H,_ = x.size()
        # reshape (instead of view) also accepts inputs whose strides forbid a view
        groups = x.reshape(B,L,H,self.rottimes,-1)
        rotated = [rope(groups[...,i,:]) for i, rope in enumerate(self.rdlist)]
        return torch.cat(rotated, dim=-1)

class Sinusoidal2dPE(nn.Module):
    """Fixed 2D sinusoidal positional encoding stored as a frozen lookup table.

    The table holds one `d_model`-vector per grid cell of a `height` x `width`
    grid. Positions flagged as missing receive the learnable `missing_pe` vector.
    """

    def __init__(self, d_model, height=100, width=100,base=100):
        """
        :param d_model: dimension of the model (must be divisible by 4)
        :param height: height of the positions
        :param width: width of the positions
        :param base: base of the geometric frequency progression
        :raises ValueError: if `d_model` is not divisible by 4
        """
        super().__init__()
        if d_model % 4 != 0:
            raise ValueError(
                "Cannot use 2D sin/cos positional encoding: d_model must be "
                "divisible by 4 (got dim={:d})".format(d_model)
            )
        self.d_model = d_model
        self.height = height
        self.width = width
        self.pe_key = "coord"
        self.missing_pe = nn.Parameter(torch.randn(d_model))

        pe = torch.zeros(d_model, height, width)
        # Each spatial axis uses half of d_model
        half = int(d_model / 2)
        div_term = torch.exp(
            torch.arange(0.0, half, 2) * -(math.log(base) / half)
        )
        pos_w = torch.arange(0.0, width).unsqueeze(1)
        pos_h = torch.arange(0.0, height).unsqueeze(1)

        # First half of the channels: sin/cos of the width position, interleaved.
        angle_w = (pos_w * div_term).transpose(0, 1).unsqueeze(1)
        pe[0:half:2, :, :] = torch.sin(angle_w).repeat(1, height, 1)
        pe[1:half:2, :, :] = torch.cos(angle_w).repeat(1, height, 1)

        # Second half of the channels: sin/cos of the height position, interleaved.
        angle_h = (pos_h * div_term).transpose(0, 1).unsqueeze(2)
        pe[half::2, :, :] = torch.sin(angle_h).repeat(1, 1, width)
        pe[half + 1 :: 2, :, :] = torch.cos(angle_h).repeat(1, 1, width)

        # Row index of the table is h * width + w; from_pretrained freezes it.
        self.pe_enc = nn.Embedding.from_pretrained(pe.flatten(1).T)

    def forward(self, coordinates, mask):
        """
        :param coordinates: tensor whose last axis holds (x, y); values are
            truncated to integers and clamped into the grid
        :param mask: boolean tensor over the leading axes of `coordinates`;
            positions where it is False get `missing_pe`
        :return: positional embeddings with a trailing `d_model` axis
        """
        # clamp() is out-of-place: `.long()` returns the very same storage when
        # `coordinates` is already int64, so masked assignment on it would
        # overwrite the caller's tensor.
        x = coordinates[..., 0].long().clamp(0, self.width - 1)
        y = coordinates[..., 1].long().clamp(0, self.height - 1)
        # NOTE(review): the table rows are ordered h * width + w, while this
        # index is x * width + y, i.e. the two only agree (up to relabelling
        # the axes) when height == width. Left as is because trained weights
        # depend on the mapping - confirm before changing.
        pe_input = x * self.width + y

        peemb = self.pe_enc(pe_input)
        peemb[~mask] = self.missing_pe
        return peemb

class Sinusoidal2dPEmodified(nn.Module):
    """2D sinusoidal positional encoding evaluated in several rotated frames.

    The channel axis is split into `rotationtime` groups. Group `i` encodes the
    grid coordinates after multiplying them by the rotation matrix of angle
    `(90 / rotationtime) * i` degrees. The result is a frozen lookup table.
    """

    def __init__(self, d_model, height=200, width=200,base=100,device_type='cuda',rotationtime=8):
        """
        :param d_model: dimension of the model (must be divisible by 4 * rotationtime)
        :param height: height of the positions
        :param width: width of the positions
        :param base: only used in `forward`, where coordinates are divided by `base / 100`;
            the frequencies themselves use a fixed base of 1000
        :param device_type: accepted for interface compatibility, not used here
        :param rotationtime: number of rotated coordinate frames
        :raises ValueError: if `d_model` is not divisible by `4 * rotationtime`
        """
        super().__init__()
        if d_model % (4*rotationtime) != 0:
            raise ValueError(
                "Cannot use rotated 2D sin/cos positional encoding: d_model must be "
                "divisible by 4 * rotationtime = {:d} (got dim={:d})".format(4 * rotationtime, d_model)
            )
        self.d_model = d_model
        self.height = height
        self.width = width
        self.pe_key = "coord"
        self.missing_pe = 0
        self.base = base

        d_rotmodel = d_model//rotationtime
        # Each spatial axis uses half of a group's channels
        d_axis = int(d_rotmodel / 2)

        # Everything below is identical for all rotations, so it is built once.
        div_term = torch.exp(
            torch.arange(0.0, d_axis, 2) * -(math.log(1000) / d_axis)
        )
        grid_w,grid_h = torch.meshgrid(torch.arange(0.0, width),torch.arange(0.0, height),indexing='xy')
        grid = torch.stack([grid_w, grid_h], dim=-1)  # (height, width, 2)

        self.pelist=[]
        for i in range(rotationtime):
            angle_radians = np.radians((90/rotationtime)*i)
            rotation_matrix = torch.tensor([[np.cos(angle_radians), -np.sin(angle_radians)],
                                        [np.sin(angle_radians), np.cos(angle_radians)]]).float()

            rotated = torch.matmul(grid,rotation_matrix)
            pos_w = rotated[...,0].unsqueeze(-1)
            pos_h = rotated[...,1].unsqueeze(-1)

            pe = torch.zeros(d_rotmodel, height, width)
            pe[0:d_axis:2, :, :] = torch.sin(pos_w * div_term).permute(2, 0, 1)
            pe[1:d_axis:2, :, :] = torch.cos(pos_w * div_term).permute(2, 0, 1)
            pe[d_axis::2, :, :] = torch.sin(pos_h * div_term).permute(2, 0, 1)
            pe[d_axis + 1 :: 2, :, :] = torch.cos(pos_h * div_term).permute(2, 0, 1)
            self.pelist.append(pe)

        pe = torch.concat(self.pelist,dim=0)
        # Row index of the table is h * width + w; from_pretrained freezes it.
        self.pe_enc = nn.Embedding.from_pretrained(pe.flatten(1).T)

    @staticmethod
    def _debug_summary(pe_input, mask):
        # Builds the debug message; tolerates a first sample without valid positions.
        first_counts = torch.unique(pe_input[0][mask[0]],return_counts=True)[1]
        max_freq = first_counts.max() if first_counts.numel() > 0 else 0
        return f'pe_input {pe_input.shape} {torch.unique(pe_input[mask],return_counts=True,dim=-1)} max freq: {max_freq} '

    def forward(self, coordinates, mask):
        """
        :param coordinates: tensor whose last axis holds (x, y); each value is
            shifted by -30, divided by `base / 100`, re-centred on the grid
            middle, truncated to an integer and clamped into the grid
        :param mask: boolean tensor over the leading axes of `coordinates`;
            positions where it is False get `missing_pe` (0)
        :return: positional embeddings with a trailing `d_model` axis
        """
        x = ((coordinates[..., 0] -30)/(self.base/100)+self.width/2).long().clamp(0, self.width - 1)
        y = ((coordinates[..., 1] -30)/(self.base/100)+self.height/2).long().clamp(0, self.height - 1)
        # NOTE(review): table rows are ordered h * width + w while this index is
        # x * width + y; equivalent up to axis naming only when height == width.
        # Left unchanged because trained weights depend on it - confirm.
        pe_input = x * self.width + y
        # Lazy logging: the torch.unique calls only run when a sink accepts DEBUG.
        logger.opt(lazy=True).debug("{}", lambda: self._debug_summary(pe_input, mask))
        peemb = self.pe_enc(pe_input)
        peemb[~mask] = self.missing_pe
        return peemb


@torch.jit.script  # good to enable when not using torch.compile, disable when using (our default)
def new_gelu(x):
    """Tanh approximation of GELU: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))."""
    cubic = x + 0.044715 * torch.pow(x, 3.0)
    gate = 1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * cubic)
    return 0.5 * x * gate

# Actually an RMSNorm (the LayerNorm name is kept for compatibility)
class LayerNorm(nn.Module):
    """Root-mean-square normalisation over the last axis with a learnable gain.

    `bias` is accepted for signature compatibility and ignored; there is no
    mean subtraction and no additive term.
    """

    def __init__(self, dim, bias=0,eps = 1e-8):
        super().__init__()
        # ||x|| * dim**-0.5 is the root mean square of the last axis
        self.scale = dim ** -0.5
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        rms = x.norm(dim=-1, keepdim=True) * self.scale
        # the lower bound keeps all-zero rows from dividing by zero
        rms = rms.clamp(min=self.eps)
        return x / rms * self.weight

class CausalSelfAttention(nn.Module):
    """Multi-head self-attention with a causal mask, computed by PyTorch's fused kernel.

    `config` must provide `n_embd`, `n_head`, `bias` and `dropout`.
    """

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # one projection yields query, key and value for all heads at once
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        # output projection applied after the heads are merged again
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.dropout = config.dropout
        self.flash = hasattr(torch.nn.functional, "scaled_dot_product_attention")

    def forward(self, x):
        """Map `x` of shape (batch, seq_len, n_embd) to a tensor of the same shape."""
        batch, seq_len, width = x.size()
        head_dim = width // self.n_head

        def to_heads(t):
            # (batch, seq_len, n_embd) -> (batch, n_head, seq_len, head_dim)
            return t.view(batch, seq_len, self.n_head, head_dim).transpose(1, 2)

        query, key, value = self.c_attn(x).split(self.n_embd, dim=2)
        key = to_heads(key)
        query = to_heads(query)
        value = to_heads(value)

        attended = F.scaled_dot_product_attention(
            query,
            key,
            value,
            attn_mask=None,
            dropout_p=self.dropout if self.training else 0,
            is_causal=True,
        )
        merged = attended.transpose(1, 2).contiguous().view(batch, seq_len, width)
        return self.resid_dropout(self.c_proj(merged))

class TargetSelfAttention(nn.Module):
    """Multi-head self-attention with an optional caller-supplied mask and rotary embedding.

    `config` must provide `n_embd`, `n_head`, `bias` and `dropout`.
    """

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # one projection yields query, key and value for all heads at once
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        # output projection applied after the heads are merged again
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.dropout = config.dropout
        self.flash = hasattr(torch.nn.functional, "scaled_dot_product_attention")

    def attmap(self, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None) -> torch.Tensor:
        """Return the attention weights softmax(Q K^T * scale + mask), after dropout.

        `query` / `key` have the sequence axis at -2 and the feature axis at -1.
        `attn_mask` may be boolean (True = may attend) or additive, and only has
        to be broadcastable to the score tensor, e.g. `(L, S)` or `(B, 1, L, S)`.
        `value` is unused; it is kept so the signature mirrors
        `scaled_dot_product_attention`. Dropout is applied with `train=True`, so
        callers pass `dropout_p=0` outside of training.
        """
        L, S = query.size(-2), key.size(-2)
        scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
        attn_weight = query @ key.transpose(-2, -1) * scale_factor

        # Masks are applied to the full score tensor (not to an (L, S) bias), so
        # batched masks broadcast instead of failing an in-place update.
        if is_causal:
            assert attn_mask is None
            causal = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=0)
            attn_weight = attn_weight.masked_fill(causal.logical_not(), float("-inf"))

        if attn_mask is not None:
            if attn_mask.dtype == torch.bool:
                attn_weight = attn_weight.masked_fill(attn_mask.logical_not(), float("-inf"))
            else:
                attn_weight = attn_weight + attn_mask.to(attn_weight.dtype)

        attn_weight = torch.softmax(attn_weight, dim=-1)
        attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
        return attn_weight

    def forward(self, x, attentionmask=None,rope=None,return_attn=False):
        """Attend over `x` of shape (batch, seq_len, n_embd).

        * `attentionmask` is forwarded to `scaled_dot_product_attention` as `attn_mask`.
        * `rope`, when given, is called on the (batch, seq_len, n_head, head_dim)
          query and key tensors before the heads axis is moved forward.
        * `return_attn=True` additionally returns the attention weights, which are
          recomputed by `attmap` (a second pass, with its own dropout draw
          while training).

        Returns the (batch, seq_len, n_embd) output, or `(output, weights)`.
        """
        B, T, C = x.size()
        head_dim = C // self.n_head
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        k = k.view(B, T, self.n_head, head_dim)
        q = q.view(B, T, self.n_head, head_dim)
        if rope is not None:
            q = rope(q)
            k = rope(k)
        q = q.transpose(1, 2)  # (B, nh, T, hs)
        k = k.transpose(1, 2)  # (B, nh, T, hs)
        v = v.view(B, T, self.n_head, head_dim).transpose(1, 2)  # (B, nh, T, hs)

        dropout_p = self.dropout if self.training else 0
        y = F.scaled_dot_product_attention(
            q,
            k,
            v,
            attn_mask=attentionmask,
            dropout_p=dropout_p,
            is_causal=False,
        )
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        y = self.resid_dropout(self.c_proj(y))
        if return_attn:
            return y, self.attmap(q, k, v, attentionmask, dropout_p=dropout_p, is_causal=False)
        return y

# classic Noam Shazeer paper, except here they use SwiGLU instead of the more popular GEGLU for gating the feedforward
# https://arxiv.org/abs/2002.05202

class SwiGLU(nn.Module):
    """Gated activation: the last axis is halved into (value, gate) and value * silu(gate) is returned."""

    def forward(self, x):
        value, gate = torch.chunk(x, 2, dim=-1)
        return F.silu(gate) * value


class MLP(nn.Module):
    """Position-wise feed-forward network: expand to 4 * n_embd, GELU, project back, dropout.

    `config` must provide `n_embd`, `bias` and `dropout`.
    """

    def __init__(self, config):
        super().__init__()
        hidden = 4 * config.n_embd
        self.c_fc = nn.Linear(config.n_embd, hidden, bias=config.bias)
        self.c_proj = nn.Linear(hidden, config.n_embd, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        hidden = new_gelu(self.c_fc(x))
        return self.dropout(self.c_proj(hidden))


class Block(nn.Module):
    """Pre-norm transformer block: causal self-attention and an MLP, each inside a residual connection."""

    def __init__(self, config):
        super().__init__()
        self.ln_1 = LayerNorm(config.n_embd, bias=config.bias)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)
        self.mlp = MLP(config)

    def forward(self, x):
        # attention sub-layer, then feed-forward sub-layer; both add onto the residual stream
        attended = x + self.attn(self.ln_1(x))
        return attended + self.mlp(self.ln_2(attended))

class TargetBlock(nn.Module): # modified
    """Pre-norm transformer block using `TargetSelfAttention` and a SwiGLU feed-forward.

    The feed-forward expands to 8 * n_embd so that the SwiGLU gate, which halves
    the last axis, leaves 4 * n_embd features for the output projection.
    """

    def __init__(self, config):
        super().__init__()
        self.ln_1 = LayerNorm(config.n_embd, bias=config.bias)
        self.attn = TargetSelfAttention(config)
        self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)
        self.mlp = nn.Sequential(
            nn.Linear(config.n_embd, 8 * config.n_embd, bias=config.bias),
            SwiGLU(),
            nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
        )

    def forward(self, x,attentionmask=None,rope=None,return_attn=False):
        """Run the block; with `return_attn` the attention weights are returned as a second value."""
        normed = self.ln_1(x)
        if not return_attn:
            x = x + self.attn(normed,attentionmask=attentionmask,rope=rope)
            return x + self.mlp(self.ln_2(x))

        attn_out, attn_map = self.attn(normed,attentionmask=attentionmask,rope=rope,return_attn=return_attn)
        x = x + attn_out
        return x + self.mlp(self.ln_2(x)), attn_map

class MSE_loss(nn.Module):  ## based on prefix len to cut prefix losses
    """Next-step mean squared error that ignores the first `loss_len` prefix positions.

    The prediction at position t is compared with the target at position t + 1.
    Positions whose (shifted) `mask` entry is True are excluded.
    """

    def __init__(self, block_size=1000, loss_len=5):
        super().__init__()
        self.block_size = block_size
        self.mse_loss = nn.MSELoss(reduction="none")
        self.loss_len = loss_len

    def forward(self, logits_exp, targets_exp, mask):
        first_target = self.loss_len + 1
        predicted = logits_exp[:, self.loss_len : -1, :]
        expected = targets_exp[:, first_target:, :]
        # after inversion: True is spot with expression, False is empty
        keep = ~mask[:, first_target:]
        elementwise = self.mse_loss(predicted[keep], expected[keep])
        return elementwise.mean()

class MSLE_loss(nn.Module):  ## based on prefix len to cut prefix losses
    """Next-step mean squared error in log1p space, ignoring the first `loss_len` prefix positions.

    The prediction at position t is compared with the target at position t + 1.
    Positions whose (shifted) `mask` entry is True are excluded.
    """

    def __init__(self, block_size=1000, loss_len=5):
        super().__init__()
        self.block_size = block_size
        self.mse_loss = nn.MSELoss(reduction="none")
        self.loss_len = loss_len

    def forward(self, logits_exp, targets_exp, mask):
        first_target = self.loss_len + 1
        predicted = logits_exp[:, self.loss_len : -1, :]
        expected = targets_exp[:, first_target:, :]
        # after inversion: True is spot with expression, False is empty
        keep = ~mask[:, first_target:]
        log_predicted = torch.log1p(predicted[keep])
        log_expected = torch.log1p(expected[keep])
        return self.mse_loss(log_predicted, log_expected).mean()

class MSE_Targetloss(nn.Module):  ## based on prefix len to cut prefix losses
    """Mean squared error on the second half of a doubled sequence.

    With B = `targets_exp.size(-2)`, the prediction at position B + t is compared
    with the target at position t + 1, skipping the first `loss_len` positions.
    Only positions whose `mask` entry is True contribute.
    """

    def __init__(self, loss_len=5,block_size=1000):
        super().__init__()
        self.mse_loss = nn.MSELoss(reduction="mean")
        self.loss_len = loss_len

    def forward(self, logits_exp, targets_exp, mask):
        seq_len = targets_exp.size(-2)
        first_target = self.loss_len + 1
        predicted = logits_exp[:, seq_len + self.loss_len : 2 * seq_len - 1, :]
        expected = targets_exp[:, first_target:seq_len, :]
        keep = mask[:, first_target:seq_len]
        return self.mse_loss(predicted[keep], expected[keep])

    def forwardAE(self,logits_exp, targets_exp):
        """Plain mean squared error between reconstruction and target (no slicing, no mask)."""
        return self.mse_loss(logits_exp, targets_exp)
    
class MSLE_Targetloss(nn.Module):  ## based on prefix len to cut prefix losses
    """Target-weighted squared log error on the second half of a doubled sequence.

    With B = `targets_exp.size(-2)`, the prediction at position B + t is compared
    with the target at position t + 1, skipping the first `loss_len` positions.
    Each squared log1p error is multiplied by `target + 1`. Only positions whose
    `mask` entry is True contribute.
    """

    def __init__(self, loss_len=5,block_size=1000):
        super().__init__()
        self.mse_loss = nn.MSELoss(reduction="none")
        self.loss_len = loss_len

    def forward(self, logits_exp, targets_exp, mask):
        seq_len = targets_exp.size(-2)
        first_target = self.loss_len + 1
        predicted = logits_exp[:, seq_len + self.loss_len : 2 * seq_len - 1, :][
            mask[:, first_target:seq_len]
        ]
        expected = targets_exp[:, first_target:seq_len, :][mask[:, first_target:seq_len]]
        offset = 1  # weight of an all-zero target
        squared_log_error = self.mse_loss(torch.log1p(predicted), torch.log1p(expected))
        weighted = squared_log_error * (expected + offset)
        return weighted.mean()

    def forwardAE(self,logits_exp, targets_exp):
        """Element-wise (unreduced) squared log1p error between reconstruction and target."""
        return self.mse_loss(torch.log1p(logits_exp), torch.log1p(targets_exp))

class Quantize_Targetloss(nn.Module):  ## based on prefix len to cut prefix losses
    """Hierarchical classification loss over a codebook.

    Predictions are scored against every codebook vector. The loss averages, with
    weight 0.25 each, the negative log-likelihood of the exact code and of the
    three coarser groupings stored in the columns of `kmdict.npy`.
    """

    def __init__(self, loss_len=5,block_size=1000,codebook=None,in_dim=1120,basedir=None):
        """
        * `loss_len` is the number of prefix positions excluded from the loss
        * `codebook` is a module exposing `.weight` (one row per code)
        * `basedir` is prepended verbatim to `kmdict.npy`, so it must end with a
          path separator
        * `block_size` and `in_dim` are accepted but not used
        """
        super(Quantize_Targetloss, self).__init__()
        self.criterion = nn.CrossEntropyLoss(reduction="none")
        self.loss_len = loss_len
        self.codebook = codebook

        # One row per code, three columns: the group id of the code at each hierarchy level.
        cdict = np.load(f'{basedir}kmdict.npy')
        self.cdict = torch.tensor(cdict).to(codebook.weight.device)
        # Number of groups at each level
        self.m1,self.m2,self.m3 = cdict.max(0)+1

    @staticmethod
    def _group_nll(code_prob, code_to_group, group_label, n_groups):
        # Sums code probabilities into their groups and returns the per-sample NLL of the true group.
        group_prob = torch.zeros(
            code_prob.shape[0], n_groups, dtype=code_prob.dtype, device=code_prob.device
        )
        index = code_to_group.long().unsqueeze(0).repeat(code_prob.size(0),1)
        group_prob.scatter_add_(1, index, code_prob)
        return F.nll_loss(torch.log(group_prob+1e-10), group_label.long(), reduction='none')

    def forward(self, logits_exp, targets_exp, mask,reduce='mean'):
        """
        With B = `targets_exp.size(-2)`, the prediction at position B + t is scored
        against the code at target position t + 1, skipping the first `loss_len`
        positions; only positions whose `mask` entry is True contribute.

        * `targets_exp[..., 0]` holds the code index of every position
        * `reduce='mean'` returns a scalar, anything else the per-sample losses
        """
        block_size = targets_exp.size(-2)
        logits_exp = logits_exp[:, block_size+self.loss_len:block_size*2-1, :]
        targets_exp = targets_exp[:, 1 + self.loss_len:block_size, :]
        mask = mask[:, 1 + self.loss_len:block_size]

        # Similarity of every kept prediction with every codebook vector
        # (`.data`: the codebook receives no gradient through this loss).
        logit_prob = torch.matmul(logits_exp[mask], self.codebook.weight.data.T)

        # `cdict` is a plain attribute, so it does not follow `.to()` on the module.
        cdict = self.cdict.to(logit_prob.device)
        codes = targets_exp[mask][...,0].long()
        hierarchical_label = cdict[codes] # N*3
        logger.debug(f'{logit_prob.shape,hierarchical_label.shape}')
        logit_prob = F.softmax(logit_prob,dim=-1)

        loss_exp = 0.25*F.nll_loss(torch.log(logit_prob+1e-10), codes, reduction='none')
        for level, n_groups in enumerate((self.m1, self.m2, self.m3)):
            loss_exp = loss_exp + 0.25*self._group_nll(
                logit_prob, cdict[:,level], hierarchical_label[:,level], int(n_groups)
            )

        if reduce == 'mean':
            return loss_exp.mean()
        else:
            return loss_exp

    def forwardAE(self,logits_exp, targets_exp):
        """Auto-encoder variant of the loss; deliberately disabled.

        The previous body terminated the whole process with `exit(0)` before an
        unreachable cross-entropy of `logits_exp @ codebook.weight.T` against
        `targets_exp[:, 0]`. Raising keeps the path disabled without silently
        ending a run with a success status.
        """
        raise NotImplementedError("Quantize_Targetloss.forwardAE is disabled")

class CLF_loss(nn.Module):  ## based on prefix len to cut prefix losses
    """Cross-entropy classification loss on the second half of a doubled sequence.

    With B = `targets_exp.size(-2)`, the prediction at position B + t is compared
    with the label at position t + 1, skipping the first `loss_len` positions.
    Only positions whose `mask` entry is True contribute.
    """

    def __init__(self, block_size=1000, loss_len=5):
        super(CLF_loss, self).__init__()
        self.block_size = block_size
        self.ce_loss = nn.CrossEntropyLoss(reduction="mean")
        self.loss_len = loss_len

    def forward(self, logits_exp, targets_exp, mask):
        """
        * `logits_exp` carries the class scores on its last axis
        * `targets_exp` holds the class indices (cast to int64 here)
        * `mask` is True for positions that take part in the loss
        """
        block_size = targets_exp.size(-2)
        logits_exp = logits_exp[:, block_size+self.loss_len :block_size*2 -1]
        targets_exp = targets_exp[:, 1 + self.loss_len :, :]
        mask = mask[:, 1 + self.loss_len :]
        # The class count is read from the logits rather than fixed to 10, so any
        # number of classes works.
        n_class = logits_exp.size(-1)
        loss_exp = self.ce_loss(
            logits_exp[mask].reshape(-1, n_class), targets_exp[mask].long().reshape(-1)
        )
        return loss_exp

    def forwardAE(self, logits_exp, targets_exp):
        """Plain cross-entropy between class scores and integer labels (no slicing, no mask)."""
        loss_exp = self.ce_loss(
            logits_exp, targets_exp.long()
        )
        return loss_exp


class ZINB_loss(nn.Module):  ## based on prefix len to cut prefix losses
    """Zero-inflated negative binomial negative log-likelihood for next-step prediction.

    The parameters predicted at position t are scored against the target at
    position t + 1, skipping the first `loss_len` prefix positions. Positions
    whose (shifted) `mask` entry is True are excluded.
    """

    def __init__(self, block_size=1000, loss_len=5,ridge_lambda=0.0, scale_lambda=0.0):
        super().__init__()
        self.block_size = block_size
        self.loss_len = loss_len
        self.ridge_lambda = ridge_lambda
        self.scale_lambda = scale_lambda

    def forward(self, logits_exp, targets_exp, mask):
        """
        * `logits_exp` is the triple `(mean, disp, pi)`: NB mean, NB dispersion and
          zero-inflation probability
        * `targets_exp` holds the observed values
        * `mask` is True for positions to ignore
        """
        eps = 1e-10
        first_target = self.loss_len + 1
        # after inversion: True is spot with expression, False is empty
        keep = ~mask[:, first_target:]

        def select(param):
            # prediction window aligned with the shifted targets, valid positions only
            return param[:, self.loss_len : -1, :][keep]

        x = targets_exp[:, first_target:, :][keep]
        mean, disp, pi = (select(param) for param in logits_exp)

        # Negative binomial NLL; the lgamma(x + 1) term is constant in the
        # parameters and left out.
        log_gamma_part = torch.lgamma(disp + eps) - torch.lgamma(x + disp + eps)
        log_ratio_part = (disp + x) * torch.log(1.0 + (mean / (disp + eps))) + (
            x * (torch.log(disp + eps) - torch.log(mean + eps))
        )
        nb_nll = log_gamma_part + log_ratio_part

        # x > 0: the count must come from the NB component.
        nonzero_nll = nb_nll - torch.log(1.0 - pi + eps)
        # x == 0: either the inflation component or an NB zero.
        nb_zero_prob = torch.pow(disp / (disp + mean + eps), disp)
        zero_nll = -torch.log(pi + ((1.0 - pi) * nb_zero_prob) + eps)
        result = torch.where(torch.le(x, 1e-8), zero_nll, nonzero_nll)

        if self.ridge_lambda > 0:
            result += self.ridge_lambda * torch.square(pi)
        result = torch.mean(result)

        if self.scale_lambda > 0:
            # penalise a mismatch between predicted and observed totals over the last axis
            result += self.scale_lambda*F.mse_loss(mean.sum(-1), x.sum(-1))
        return result


@dataclass
class DaoConfig:
    """Hyper-parameters consumed by the model and loss classes of this file.

    Field order and defaults are part of the constructor interface. Fields
    marked "TODO confirm" are not read by any code visible in this part of
    the file, so their description is an assumption.
    """

    # sequence length; GeST asserts it is not None
    block_size: int = 512
    # forwarded to ROPE_2D_rotate, which only stores it
    batch_size: int = 0
    # number of TargetBlock layers in GeST
    n_layer: int = 24
    # number of attention heads; n_embd must be divisible by it
    n_head: int = 20
    # embedding width of the transformer
    n_embd: int = 1120
    # dropout probability used by the attention and MLP modules
    dropout: float = 0.0
    # whether the Linear layers of attention / MLP carry a bias term
    bias: bool = False
    # presumably selects which parts are trained - TODO confirm
    train_mode: str = "frozenAE"
    # presumably the prefix length excluded by the loss classes - TODO confirm
    loss_len: int = 5
    # output size of the expression head (epx_head)
    vocab_size: int = 2000
    # GeST builds an expression head for "mse", "msle" and "quantize"
    task: str = "none"
    # presumably forwarded to ZINB_loss.scale_lambda - TODO confirm
    scale_lambda: float = 0
    # presumably forwarded to ZINB_loss.ridge_lambda - TODO confirm
    ridge_lambda: float = 0
    # encoder variant - TODO confirm accepted values
    encoder: str = 'mlp'
    # decoder variant; GeST checks for 'mlp', 'scimilarity' and 'multimlp'
    decoder: str = 'mlp'
    # TODO confirm meaning (not read in the visible code)
    skipconnect: bool = False
    # TODO confirm meaning (not read in the visible code)
    noise: float = 0
    # `base` argument of ROPE_2D_rotate / Sinusoidal2dPEmodified
    rope_base: float = 100
    # TODO confirm meaning (not read in the visible code)
    loc_emb: str = "ones"
    # forwarded to Sinusoidal2dPEmodified as device_type
    device_type: str = "cuda"
    # a name containing 'ROPE' selects rotary embeddings, otherwise Sinusoidal2dPEmodified
    modeltype: str = "GeST"
    # TODO confirm meaning (not read in the visible code)
    codebook: str = "none"
    # presumably the number of classes of the classification task - TODO confirm
    n_class: int = 10
    pool: str = "single" # "mean" "max" "single"


class GeST(nn.Module):
    def __init__(self, config,onlyAE=False):
        """
        Build the encoder, the optional transformer trunk, the decoder head and the loss.

        * `config` is a `DaoConfig`-style object; `task`, `encoder`, `decoder`,
          `modeltype` and `loc_emb` select which sub-modules are created
        * `onlyAE` skips the transformer stack, the rotary/sinusoidal module and the
          identity adapters, leaving only encoder, decoder and loss

        Raises `ValueError` when `task`, `encoder` or `decoder` names an unsupported
        option (previously the process was terminated with exit code 0).
        """
        super().__init__()
        assert config.block_size is not None
        self.config = config
        self.onlyAE = onlyAE

        # Info logs are only written by rank 0 (or by a non-distributed run)
        is_main_process = os.environ.get("LOCAL_RANK", "0") == "0"
        use_rope = 'ROPE' in self.config.modeltype

        if not onlyAE:
            self.transformer = nn.ModuleDict(
                dict(
                    drop=nn.Dropout(config.dropout),
                    h=nn.ModuleList([TargetBlock(config) for _ in range(config.n_layer)]),
                    ln_f=LayerNorm(config.n_embd, bias=config.bias),
                )
            )
            if use_rope:
                if is_main_process:
                    logger.info('Using ROPE')
                self.rope = ROPE_2D_rotate(config.batch_size, config.n_embd//config.n_head, config.rope_base, 1)
            else:
                if is_main_process:
                    logger.info('Using Sinusoidal2dPEmodified')
                self.rope = None
                self.sinu2d = Sinusoidal2dPEmodified(config.n_embd,base=config.rope_base,device_type=config.device_type)

            # Identity unless a scimilarity encoder/decoder needs a width adapter (set below)
            self.adapterenc = nn.Identity()
            self.adapterdec = nn.Identity()

        # ---- decoder head and loss ----
        task = self.config.task
        if task in ("mse", "msle", "quantize"):
            if self.config.decoder == 'mlp':
                self.epx_head = nn.Linear(
                    config.n_embd, config.vocab_size, bias=True
                )  # expr level
            elif self.config.decoder == 'scimilarity':
                from scimilarity import scimilarity_Decoder
                self.epx_head = scimilarity_Decoder(config.vocab_size,128,[1024,1024,1024])
                if not self.onlyAE:
                    self.adapterdec = nn.Linear(config.n_embd,128,bias=False)
            elif self.config.decoder == 'multimlp': # 3Layer MLP without batch normalization
                self.epx_head = nn.Sequential(
                    nn.Linear(config.n_embd, config.n_embd),
                    nn.LeakyReLU(),
                    nn.Dropout(config.dropout),
                    nn.Linear(config.n_embd, config.n_embd),
                    nn.LeakyReLU(),
                    nn.Linear(config.n_embd, config.vocab_size)
                )
            else:
                raise ValueError(f"unsupported decoder: {self.config.decoder!r}")

            if task == "mse":
                self.criterion = MSE_Targetloss(config.loss_len,config.block_size)
            elif task == "msle":
                self.criterion = MSLE_Targetloss(config.loss_len,config.block_size)
            else:  # "quantize"
                assert self.config.codebook is not None
                self.codebook = nn.Embedding.from_pretrained(torch.tensor(np.load(self.config.codebook))).to(torch.float32).to(config.device_type)
                if is_main_process:
                    logger.info(f'load codebook from {self.config.codebook}')
                    logger.info(f'{self.codebook.weight.data}')
                    logger.debug(f'model codebook norm:{self.codebook.weight.data.norm(dim=-1)}')

                # Everything before the first 'codebook' in the path is handed to the loss as its base directory
                basedir = config.codebook.split('codebook')[0]
                self.criterion = Quantize_Targetloss(config.loss_len,config.block_size,codebook=self.codebook,in_dim=config.vocab_size,basedir=basedir)
        elif task == "clf":
            self.epx_head = nn.Linear(
                config.n_embd, config.vocab_size * 10, bias=False
            )  # expr level
            self.criterion = CLF_loss(config.block_size, config.loss_len)
        elif task == "zinb":
            self.epx_head = ZINB(
                config.n_embd,
                config.vocab_size,
                n_dec_1=config.n_embd,
                softmax=True,
                disp="gene",
            )
            self.criterion = ZINB_loss(config.block_size, config.loss_len,ridge_lambda=0,scale_lambda=config.scale_lambda)
        else:
            raise ValueError(f"unsupported task: {task!r}")

        # ---- encoder ----
        if self.config.encoder == 'mlp':
            self.inlinear = nn.Linear(config.vocab_size, config.n_embd, bias=True)
        elif self.config.encoder == 'scimilarity':
            from scimilarity import scimilarity_Encoder
            self.inlinear = scimilarity_Encoder(config.vocab_size,128,[1024,1024,1024])
            if not self.onlyAE:
                self.adapterenc = nn.Linear(128,config.n_embd,bias=False)
        elif self.config.encoder == 'multimlp':
            self.inlinear = nn.Sequential(
                nn.Linear(config.vocab_size, config.n_embd),
                nn.LeakyReLU(),
                nn.Dropout(config.dropout),
                nn.Linear(config.n_embd, config.n_embd),
                nn.LeakyReLU(),
                nn.Linear(config.n_embd, config.n_embd)
            )
        else:
            raise ValueError(f"unsupported encoder: {self.config.encoder!r}")

        # ---- positional tokens ----
        if 'sinu' in self.config.loc_emb:
            if use_rope:
                self.pe = Sinusoidal2dPEmodified(config.n_embd,base=config.rope_base,device_type=config.device_type)
            else:
                # share the module that is also used for the additive positional embedding
                self.pe = self.sinu2d
        self.skipconnect = config.skipconnect

        # ---- attention mask building blocks (sliced and tiled in forward) ----
        # Built on CPU and moved to the configured device rather than unconditionally to CUDA.
        mask_device = config.device_type
        L = S = self.config.block_size
        self.attentionmask_dig = torch.diag(torch.ones(L, dtype=torch.bool)).to(mask_device)
        self.attentionmask_exp = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0).to(mask_device)
        self.attentionmask_zero = torch.zeros(L, S, dtype=torch.bool).to(mask_device)
        # tril(1) minus tril(0) leaves only the first super-diagonal
        self.attentionmask_pos = (
            torch.ones(L, S, dtype=torch.int).tril(diagonal=1).to(mask_device)
            - torch.ones(L, S, dtype=torch.int).tril(diagonal=0).to(mask_device)
        ).to(torch.bool)

        if is_main_process:
            logger.info("number of parameters: %.2fM" % (self.get_num_params() / 1e6,))
        

    def get_num_params(self):
        """Return the total number of elements over all parameters of the module."""
        total = 0
        for param in self.parameters():
            total += param.numel()
        return total

    def query_index(self,expression=None,iflogit=False):
        """
        Score vectors against the loss module's codebook by dot product.

        * `expression` are the query vectors; when `None`, the model's own codebook
          weights are used as queries
        * `iflogit` returns the raw score matrix instead of the best-matching index

        Both operands are cast to the dtype and device of `self.codebook.weight`.
        """
        reference = self.codebook.weight
        device, dtype = reference.device, reference.dtype

        queries = reference.data if expression is None else expression
        queries = queries.to(dtype).to(device)
        keys = self.criterion.codebook.weight.data.to(dtype).to(device)

        scores = torch.matmul(queries, keys.T)
        if iflogit:
            return scores
        return torch.argmax(scores, dim=-1)

    def forward(
        self,
        inputs_embeds=None,
        coord=None,
        mask=None, # 1 indicates value, 0 indicates padding location
        embedding=False,
        tok_emb=None,
        return_attn=False,
    ):
        """
        Run the two-stream pass: positional tokens followed by expression tokens.

        * `inputs_embeds` is the encoder input; for task "quantize" it holds code
          indices that are looked up in the codebook first
        * `coord` is passed to the positional module and, in the rotary variant,
          duplicated for both token streams; it is not needed for `loc_emb == "ones"`
          without rotary embedding
        * `mask` marks valid (non-padding) positions
        * `embedding` returns the final hidden states instead of the loss
        * `tok_emb` optionally supplies precomputed token embeddings; the tensor
          passed in is not modified
        * `return_attn` additionally returns the per-layer attention outputs

        Returns `(logits_exp, x)` when `embedding` is set (for task "clf" the first
        item is the arg-max class and attentions are not returned), otherwise
        `(logits_exp, loss_exp)`; the attention list is appended when `return_attn`
        is set. With `onlyAE` the call is delegated to `forwardAE`.

        Raises `ValueError` for an unsupported `loc_emb` setting.
        """
        if self.config.task == "quantize":
            target_idx = inputs_embeds.to(torch.long)
            input_idx = target_idx.clone()
            if (self.config.noise!=0) and self.training:
                # replace a random subset of the valid code indices with random codes
                noisemask = mask.unsqueeze(-1) & (torch.rand(input_idx.shape[0], input_idx.shape[1], 1) < self.config.noise).to(input_idx.device)
                random_values = torch.randint(0, self.codebook.weight.shape[0], (noisemask.sum(),1)).to(input_idx.device).to(input_idx.dtype)
                logger.debug(f'noise debug: {noisemask.sum()} {mask.sum()} {noisemask.sum()/mask.sum()} random_values:{random_values[:,0]}')
                input_idx[noisemask] = random_values[:,0]
            inputs_embeds = self.codebook(input_idx[...,0])

        if self.onlyAE:
            return self.forwardAE(inputs_embeds=inputs_embeds,embedding=embedding)

        # ---- encoder ----
        t_forward_s = time.time()

        if tok_emb is None:
            if (self.config.noise!=0) and self.training and self.config.task != "quantize":
                noise = torch.normal(mean=torch.zeros_like(inputs_embeds), std=self.config.noise).to(inputs_embeds.device).to(inputs_embeds.dtype)
            else:
                noise = 0
            if self.config.encoder == 'scimilarity' and self.config.train_mode != 'frozenAE':
                # encode only the valid positions; padding positions stay zero
                tok_emb = torch.zeros([inputs_embeds.shape[0],inputs_embeds.shape[1],self.config.n_embd]).to(inputs_embeds.device)
                tok_nonzeroemb = self.adapterenc(self.inlinear((inputs_embeds+noise)[mask]))
                tok_emb = tok_emb.to(tok_nonzeroemb.dtype)
                tok_emb[mask] = tok_nonzeroemb
            else:
                tok_emb = self.adapterenc(self.inlinear(inputs_embeds+noise))

        t_forward_e = time.time()
        t_forward_d = t_forward_e - t_forward_s
        logger.debug('forward encoder time: %.5f' % t_forward_d)

        # ---- positional tokens ----
        t_forward_s = time.time()

        if self.config.loc_emb =="sinu":
            posemb = self.pe(coord, mask)
        elif self.config.loc_emb =="ones":
            posemb = torch.ones_like(tok_emb)
        elif self.config.loc_emb =="sinu+mean":
            pure_posemb = self.pe(coord, mask)
            # running mean over the sequence, computed on the same device as the inputs
            counts = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device)[None,:,None]+1
            mean_emb = torch.cumsum(inputs_embeds,dim=1) / counts
            mean_emb[:,1:] = mean_emb[:,:-1].clone() # summary of previous tokens
            mean_emb = (mean_emb*(mask.unsqueeze(-1).to(mean_emb.dtype))).to(pure_posemb.dtype)
            posemb = pure_posemb + self.adapterenc(self.inlinear(mean_emb))
        else:
            raise ValueError(f"unsupported loc_emb: {self.config.loc_emb!r}")

        if self.rope is not None:
            # coordinates are repeated once per token stream
            self.rope.build_cache(torch.cat([coord,coord],dim=1))
        else:
            if self.config.loc_emb =="sinu+mean":
                addposemb = pure_posemb
            else:
                addposemb = posemb
            # out-of-place so that a caller-supplied tok_emb is left untouched
            tok_emb = tok_emb + addposemb

        x = torch.cat([posemb,tok_emb],dim=1)

        x = self.transformer.drop(x)

        t_forward_e = time.time()
        t_forward_d = t_forward_e - t_forward_s
        logger.debug('forward positional embedding time: %.5f' % t_forward_d)
        t_forward_s = time.time()

        # ---- attention mask: [[dig, zero], [pos, exp]] over the doubled sequence ----
        L= posemb.size(-2)

        attentionmask1 = torch.cat([self.attentionmask_dig[:L,:L],self.attentionmask_zero[:L,:L]],dim=1)
        attentionmask2 = torch.cat([self.attentionmask_pos[:L,:L],self.attentionmask_exp[:L,:L]],dim=1)
        attentionmask = torch.cat([attentionmask1,attentionmask2],dim=0).bool()

        t_forward_e = time.time()
        t_forward_d = t_forward_e - t_forward_s
        logger.debug('forward attention mask time: %.5f' % t_forward_d)
        t_forward_s = time.time()

        # ---- transformer stack ----
        if return_attn:
            attlist = []
            for block in self.transformer.h:
                x,attn = block(x,attentionmask=attentionmask,rope=self.rope,return_attn=return_attn)
                attlist.append(attn)
        else:
            for block in self.transformer.h:
                x = block(x,attentionmask=attentionmask,rope=self.rope,return_attn=return_attn)
        x = self.transformer.ln_f(x)

        t_forward_e = time.time()
        t_forward_d = t_forward_e - t_forward_s
        logger.debug('forward transformer time: %.5f' % t_forward_d)
        t_forward_s = time.time()

        # ---- decoder ----
        if self.config.decoder == 'scimilarity' and self.config.train_mode != 'frozenAE':
            # decode only the valid positions of the expression stream (second half)
            logits_exp = torch.zeros([x.shape[0],x.shape[1],self.config.vocab_size]).to(x.device)
            doublemask = torch.cat([torch.zeros_like(mask).to(mask.device).to(mask.dtype),mask],dim=1)
            logits_exp_nonzero = self.epx_head(self.adapterdec(x[doublemask]))
            logits_exp = logits_exp.to(logits_exp_nonzero.dtype)
            logits_exp[doublemask] = logits_exp_nonzero
        else:
            logits_exp = self.epx_head(self.adapterdec(x))

        if self.config.task == "msle" or self.config.task == "quantize" or self.config.task == "mse":
            logits_exp = F.elu(logits_exp)

        if self.skipconnect:
            # add the masked running mean of the inputs to the expression stream
            counts = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device)[None,:,None]+1
            mean_emb = torch.cumsum(inputs_embeds,dim=1) / counts
            mean_emb = (mean_emb*(mask.unsqueeze(-1).to(logits_exp.dtype))).to(logits_exp.dtype)

            logits_exp[:,tok_emb.shape[1]:] = logits_exp[:,tok_emb.shape[1]:]+mean_emb

        t_forward_e = time.time()
        t_forward_d = t_forward_e - t_forward_s
        logger.debug('forward decoder time: %.5f' % t_forward_d)

        if self.config.task == "clf":
            b, s, _ = logits_exp.size()
            logits_exp = logits_exp.unsqueeze(-1).view(b, s, -1, 10)
            if embedding:
                return logits_exp.argmax(-1), x

        if embedding:
            if return_attn:
                return logits_exp, x, attlist
            return logits_exp, x

        # ---- loss ----
        t_forward_loss_s = time.time()
        if self.config.task == "quantize":
            inputs_embeds = target_idx
        loss_exp = self.criterion(
            logits_exp=logits_exp, targets_exp=inputs_embeds, mask=mask
        )
        t_forward_loss_e = time.time()
        t_forward_d = t_forward_loss_e - t_forward_loss_s
        logger.debug('forward loss time: %.5f' % t_forward_d)
        if return_attn:
            return logits_exp, loss_exp, attlist

        return logits_exp, loss_exp

    def forward_cellemb(
        self,
        inputs_embeds=None,
        coord=None,
        mask=None,
        embedding=False,
        tok_emb=None,
        return_attn=False,
    ):
        """
        Single-stream pass: only the expression tokens go through the transformer.

        Unlike `forward`, no positional token stream is prepended, no attention mask
        is passed to the blocks and no loss is computed (`0` is returned in its place).

        * `inputs_embeds` is the encoder input; for task "quantize" it holds code
          indices that are looked up in the codebook first
        * `coord` is passed to the positional module / rotary cache
        * `mask` marks valid (non-padding) positions
        * `embedding` returns the final hidden states instead of the placeholder loss
        * `tok_emb` optionally supplies precomputed token embeddings; the tensor
          passed in is not modified
        * `return_attn` additionally returns the per-layer attention outputs

        Returns `(logits_exp, x)` when `embedding` is set (for task "clf" the first
        item is the arg-max class and attentions are not returned), otherwise
        `(logits_exp, 0)`; the attention list is appended when `return_attn` is set.

        Raises `ValueError` for an unsupported `loc_emb` setting.
        """
        if self.config.task == "quantize":
            target_idx = inputs_embeds.to(torch.long)
            input_idx = target_idx.clone()
            if (self.config.noise!=0) and self.training:
                # replace a random subset of the valid code indices with random codes
                noisemask = mask.unsqueeze(-1) & (torch.rand(input_idx.shape[0], input_idx.shape[1], 1) < self.config.noise).to(input_idx.device)
                random_values = torch.randint(0, self.codebook.weight.shape[0], (noisemask.sum(),1)).to(input_idx.device).to(input_idx.dtype)
                logger.debug(f'noise debug: {noisemask.sum()} {mask.sum()} {noisemask.sum()/mask.sum()} random_values:{random_values[:,0]}')
                input_idx[noisemask] = random_values[:,0]
            inputs_embeds = self.codebook(input_idx[...,0])

        # ---- encoder ----
        t_forward_s = time.time()

        if tok_emb is None:
            if (self.config.noise!=0) and self.training and self.config.task != "quantize":
                noise = torch.normal(mean=torch.zeros_like(inputs_embeds), std=self.config.noise).to(inputs_embeds.device).to(inputs_embeds.dtype)
            else:
                noise = 0
            if self.config.encoder == 'scimilarity' and self.config.train_mode != 'frozenAE':
                # encode only the valid positions; padding positions stay zero
                tok_emb = torch.zeros([inputs_embeds.shape[0],inputs_embeds.shape[1],self.config.n_embd]).to(inputs_embeds.device)
                tok_nonzeroemb = self.adapterenc(self.inlinear((inputs_embeds+noise)[mask]))
                tok_emb = tok_emb.to(tok_nonzeroemb.dtype)
                tok_emb[mask] = tok_nonzeroemb
            else:
                tok_emb = self.adapterenc(self.inlinear(inputs_embeds+noise))

        t_forward_e = time.time()
        t_forward_d = t_forward_e - t_forward_s
        logger.debug('forward encoder time: %.5f' % t_forward_d)

        # ---- positional information ----
        t_forward_s = time.time()

        if self.config.loc_emb =="sinu":
            posemb = self.pe(coord, mask)
        elif self.config.loc_emb =="ones":
            posemb = torch.ones_like(tok_emb)
        elif self.config.loc_emb =="sinu+mean":
            pure_posemb = self.pe(coord, mask)
            # running mean over the sequence, computed on the same device as the inputs
            counts = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device)[None,:,None]+1
            mean_emb = torch.cumsum(inputs_embeds,dim=1) / counts
            mean_emb[:,1:] = mean_emb[:,:-1].clone() # summary of previous tokens
            mean_emb = (mean_emb*(mask.unsqueeze(-1).to(mean_emb.dtype))).to(pure_posemb.dtype)
            posemb = pure_posemb + self.adapterenc(self.inlinear(mean_emb))
        else:
            raise ValueError(f"unsupported loc_emb: {self.config.loc_emb!r}")

        if self.rope is not None:
            # single stream, so the coordinates are used as they are
            self.rope.build_cache(coord)
        else:
            if self.config.loc_emb =="sinu+mean":
                addposemb = pure_posemb
            else:
                addposemb = posemb
            # out-of-place so that a caller-supplied tok_emb is left untouched
            tok_emb = tok_emb + addposemb

        x = tok_emb

        x = self.transformer.drop(x)

        t_forward_e = time.time()
        t_forward_d = t_forward_e - t_forward_s
        logger.debug('forward positional embedding time: %.5f' % t_forward_d)
        t_forward_s = time.time()

        # No attention mask is built here; the timing line is kept so logs stay comparable with forward()
        t_forward_e = time.time()
        t_forward_d = t_forward_e - t_forward_s
        logger.debug('forward attention mask time: %.5f' % t_forward_d)
        t_forward_s = time.time()

        # ---- transformer stack ----
        if return_attn:
            attlist = []
            for block in self.transformer.h:
                x,attn = block(x,rope=self.rope,return_attn=return_attn)
                attlist.append(attn)
        else:
            for block in self.transformer.h:
                x = block(x,rope=self.rope,return_attn=return_attn)
        x = self.transformer.ln_f(x)

        t_forward_e = time.time()
        t_forward_d = t_forward_e - t_forward_s
        logger.debug('forward transformer time: %.5f' % t_forward_d)
        t_forward_s = time.time()

        # ---- decoder ----
        if self.config.decoder == 'scimilarity' and self.config.train_mode != 'frozenAE':
            # x holds only the expression stream here, so the plain mask selects the
            # valid positions (a doubled mask would not match the sequence length)
            logits_exp = torch.zeros([x.shape[0],x.shape[1],self.config.vocab_size]).to(x.device)
            logits_exp_nonzero = self.epx_head(self.adapterdec(x[mask]))
            logits_exp = logits_exp.to(logits_exp_nonzero.dtype)
            logits_exp[mask] = logits_exp_nonzero
        else:
            logits_exp = self.epx_head(self.adapterdec(x))

        if self.config.task == "msle" or self.config.task == "quantize" or self.config.task == "mse":
            logits_exp = F.elu(logits_exp)

        if self.skipconnect:
            # add the masked running mean of the inputs to every position of the
            # (single) expression stream
            counts = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device)[None,:,None]+1
            mean_emb = torch.cumsum(inputs_embeds,dim=1) / counts
            mean_emb = (mean_emb*(mask.unsqueeze(-1).to(logits_exp.dtype))).to(logits_exp.dtype)

            logits_exp = logits_exp + mean_emb

        t_forward_e = time.time()
        t_forward_d = t_forward_e - t_forward_s
        logger.debug('forward decoder time: %.5f' % t_forward_d)

        if self.config.task == "clf":
            b, s, _ = logits_exp.size()
            logits_exp = logits_exp.unsqueeze(-1).view(b, s, -1, 10)
            if embedding:
                return logits_exp.argmax(-1), x

        if embedding:
            if return_attn:
                return logits_exp, x, attlist
            return logits_exp, x

        # No loss is computed in this variant; the timing line is kept for log parity
        t_forward_loss_s = time.time()
        t_forward_loss_e = time.time()
        t_forward_d = t_forward_loss_e - t_forward_loss_s
        logger.debug('forward loss time: %.5f' % t_forward_d)
        if return_attn:
            return logits_exp, 0, attlist

        return logits_exp, 0


    def forward_cellembonepos(
        self,
        inputs_embeds=None,
        coord=None,
        mask=None, # 1 indicates value, 0 indicates padding location
        embedding=False,
        tok_emb=None,
        return_attn=False,
    ):
        """
        Disabled variant: it is not aligned with the pre-training setting.

        Always raises `NotImplementedError` so that an accidental call fails loudly
        instead of ending the process with a success exit code.
        """
        # not align with the pre-training setting
        raise NotImplementedError(
            "forward_cellembonepos is disabled: it is not aligned with the pre-training setting"
        )


    def forwardAE(
        self,
        inputs_embeds=None,
        embedding=False
    ):
        """
        Auto-encoder pass: encoder followed directly by the decoder head.

        * `inputs_embeds` gets a length-1 sequence axis inserted at dim 1 before encoding
        * `embedding` returns the encoder output instead of the loss

        Gaussian noise with std `config.noise` is added to the input while training.
        Returns `(logits_exp, tok_emb)` when `embedding` is set (for task "clf" the
        first item is the arg-max class), otherwise `(logits_exp, loss_exp)`.
        """
        inputs_embeds = inputs_embeds.unsqueeze(1)

        encoder_input = inputs_embeds
        if (self.config.noise!=0) and self.training:
            noise = torch.normal(mean=torch.zeros_like(inputs_embeds), std=self.config.noise).to(inputs_embeds.device).to(inputs_embeds.dtype)
            encoder_input = inputs_embeds+noise
        tok_emb = self.inlinear(encoder_input)

        logits_exp = self.epx_head(tok_emb).squeeze(1)

        is_clf = self.config.task == "clf"
        if is_clf:
            # regroup the flat output into 10 class scores per feature
            batch = logits_exp.size()[0]
            _ = logits_exp.size()[1]  # the output is expected to be 2-D at this point
            logits_exp = logits_exp.unsqueeze(-1).view(batch, -1, 10)

        if embedding:
            first = logits_exp.argmax(-1) if is_clf else logits_exp
            return first, tok_emb

        loss_exp = self.criterion.forwardAE(
            logits_exp=logits_exp, targets_exp=inputs_embeds.squeeze(1)
        )
        return logits_exp, loss_exp

    def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):
        """
        Create an AdamW optimizer with separate decay / no-decay parameter groups.

        * `weight_decay` is applied to the first group only
        * `learning_rate` and `betas` are passed to AdamW
        * `device_type` enables the fused AdamW kernel when it is "cuda" and the
          installed PyTorch exposes the `fused` argument

        Biases, LayerNorm and Embedding weights and `criterion*` parameters are not
        decayed; Linear weights and `dec_*` / `epx_head*` / `inlinear*` parameters
        are. Parameters without `requires_grad` are left out and only logged.
        """
        decay, no_decay, frozen = set(), set(), set()

        decayed_layers = (torch.nn.Linear,)
        undecayed_layers = (torch.nn.LayerNorm, LayerNorm, torch.nn.Embedding)
        decayed_prefixes = ("dec_", "epx_head", "inlinear")

        for module_name, module in self.named_modules():
            for param_name, param in module.named_parameters():
                # named_parameters() recurses, so a parameter is seen once per ancestor module
                full_name = f"{module_name}.{param_name}" if module_name else param_name
                if not param.requires_grad:
                    frozen.add(full_name)
                    continue
                if param_name.endswith("bias"):
                    no_decay.add(full_name)
                    continue
                is_weight = param_name.endswith("weight")
                if is_weight and isinstance(module, decayed_layers):
                    decay.add(full_name)
                elif is_weight and isinstance(module, undecayed_layers):
                    no_decay.add(full_name)
                elif param_name.startswith(decayed_prefixes):
                    decay.add(full_name)
                elif param_name.startswith("criterion"):
                    no_decay.add(full_name)

        if "lm_head.weight" in decay:
            decay.discard("lm_head.weight")
            no_decay.add("lm_head.weight")

        # validate that we considered every parameter
        param_dict = {}
        for name, param in self.named_parameters():
            if param.requires_grad:
                param_dict[name] = param
        if "pe.missing_pe" in param_dict:
            no_decay.add("pe.missing_pe")

        in_both = decay & no_decay
        in_either = decay | no_decay
        if "LOCAL_RANK" not in os.environ or os.environ["LOCAL_RANK"] == "0":
            logger.info(f'frozen parameters: {sorted(frozen)}')
        assert (
            len(in_both) == 0
        ), "parameters %s made it into both decay/no_decay sets!" % (str(in_both),)
        assert (
            len(param_dict.keys() - in_either) == 0
        ), "parameters %s were not separated into either decay/no_decay set!" % (
            str(param_dict.keys() - in_either),
        )

        decayed_params = [param_dict[name] for name in sorted(decay)]
        undecayed_params = [param_dict[name] for name in sorted(no_decay)]
        optim_groups = [
            {"params": decayed_params, "weight_decay": weight_decay},
            {"params": undecayed_params, "weight_decay": 0.0},
        ]

        fused_available = "fused" in inspect.signature(torch.optim.AdamW).parameters
        use_fused = (device_type == "cuda") and fused_available
        if "LOCAL_RANK" not in os.environ or os.environ["LOCAL_RANK"] == "0":
            logger.info(f"using fused AdamW: {use_fused}")

        extra_args = {"fused": True} if use_fused else {}
        return torch.optim.AdamW(
            optim_groups, lr=learning_rate, betas=betas, **extra_args
        )

    def estimate_mfu(self, fwdbwd_per_iter, dt):
        """
        Estimate model FLOPs utilisation relative to the A100 bfloat16 peak.

        * `fwdbwd_per_iter` is the number of forward/backward passes per iteration
        * `dt` is the duration of one iteration

        The per-token cost is `6 * N + 12 * L * H * Q * T` with `N` parameters,
        `L` layers, `H` heads, head width `Q` and sequence length `T = block_size`.
        """
        n_params = self.get_num_params()
        cfg = self.config
        head_dim = cfg.n_embd // cfg.n_head
        seq_len = cfg.block_size

        per_token = 6 * n_params + 12 * cfg.n_layer * cfg.n_head * head_dim * seq_len
        per_iter = per_token * seq_len * fwdbwd_per_iter
        achieved = per_iter * (1.0 / dt)

        peak = 312e12  # A100 GPU bfloat16 peak flops is 312 TFLOPS
        return achieved / peak

    def train(self, mode=True):
            """
            Override the default train() to freeze the BN parameters
            """
            super(GeST, self).train(mode)
            if self.config.train_mode == "frozenAE":
                for m in self.inlinear.modules():
                    if isinstance(m, nn.BatchNorm1d):
                        m.eval()
                        logger.info(f"freeze_bn var and mean and weight and bias in inlinear")
                        m.weight.requires_grad = False
                        m.bias.requires_grad = False

                for m in self.epx_head.modules():
                    if isinstance(m, nn.BatchNorm1d):
                        m.eval()
                        logger.info(f"freeze_bn var and mean and weight and bias in epx_head")
                        m.weight.requires_grad = False
                        m.bias.requires_grad = False
# Per-class loss weights for the 338-class annotation task (338 entries), clipped to a maximum of 20.
WEIGHT = torch.tensor([9.74974503e-01, 2.31777786e+00, 6.93234958e+00, 6.17643524e-01,
       6.68795040e-01, 1.47002543e-01, 8.72730024e-02, 1.39751199e+00,
       4.23810241e-01, 6.17552954e-01, 5.44460552e+00, 5.09238497e+00,
       3.26338812e+00, 1.24321841e+00, 1.41798059e+01, 4.41261774e-01,
       9.68027208e-01, 7.29879093e+00, 9.74748841e-01, 4.34142814e-01,
       3.83377548e+00, 6.22482059e-01, 2.08743612e+00, 2.00000000e+01,
       3.87968896e+00, 1.54830969e+01, 1.29781275e+01, 1.24985973e+00,
       2.14484460e+00, 1.85051514e-01, 4.04747945e+00, 1.02218504e+00,
       4.15940975e+00, 1.56557709e+01, 4.68845240e-01, 4.79112897e+00,
       3.25066757e-01, 7.76295367e+00, 5.26655708e-01, 5.45306535e-01,
       1.08170507e-01, 2.60590456e-01, 2.36595639e+00, 6.92721830e-01,
       7.17995459e-01, 1.02742190e+00, 2.91043702e+00, 9.19520167e+00,
       9.38787866e-01, 3.81813451e+00, 5.65288908e+00, 4.31054490e-01,
       6.20647317e-01, 6.37125925e+00, 8.79207175e+00, 3.51976796e+00,
       1.73095042e+00, 4.42141981e+00, 1.79208611e+01, 8.63700239e-01,
       1.55259073e-01, 1.78904094e-01, 1.40896700e+00, 8.58681286e-01,
       2.00000000e+01, 6.65308431e+00, 1.83904033e+01, 9.63707635e+00,
       5.61146218e+00, 4.71601609e+00, 3.84954513e+00, 9.62606255e+00,
       7.28616326e+00, 1.69814612e+01, 2.00000000e+01, 1.08681351e+01,
       2.00000000e+01, 2.00000000e+01, 6.66361134e+00, 1.28789063e+01,
       1.10680746e+01, 5.66048705e+00, 2.00000000e+01, 2.00000000e+01,
       6.38091268e+00, 7.58811237e+00, 2.00000000e+01, 2.00000000e+01,
       6.98987945e+00, 8.92246264e+00, 2.00000000e+01, 2.00000000e+01,
       3.13465007e+00, 5.77695798e+00, 2.00000000e+01, 2.00000000e+01,
       2.00000000e+01, 1.54546876e+01, 1.99120679e+01, 1.41559743e+01,
       3.98053154e+00, 1.01724695e+01, 2.00000000e+01, 5.07092398e+00,
       2.00000000e+01, 3.80605727e+00, 8.81969082e+00, 1.93627695e+01,
       2.00000000e+01, 2.00000000e+01, 2.46208849e+00, 2.00000000e+01,
       2.75705556e+00, 4.46359551e+00, 9.26601181e+00, 6.48406831e+00,
       2.00000000e+01, 1.28006151e+01, 6.02920883e+00, 2.00000000e+01,
       2.00000000e+01, 2.00000000e+01, 2.00000000e+01, 1.10102023e+01,
       2.00000000e+01, 1.69132625e+01, 2.00000000e+01, 1.38532973e+01,
       3.15342746e+00, 2.00000000e+01, 8.25765170e+00, 6.17055292e+00,
       4.75864674e+00, 8.61227478e+00, 5.56327922e+00, 1.82311791e+01,
       2.00000000e+01, 1.57141879e+01, 1.55689551e+01, 1.51217320e+01,
       9.03734413e+00, 2.00000000e+01, 6.15252355e+00, 4.04747945e+00,
       1.87799437e+00, 1.26278932e+01, 2.00000000e+01, 2.00000000e+01,
       3.48914861e+00, 1.63232650e+01, 8.77375493e-01, 1.69814612e+01,
       2.00000000e+01, 4.27988045e+00, 2.00000000e+01, 2.00000000e+01,
       2.00000000e+01, 2.00000000e+01, 2.00000000e+01, 2.00000000e+01,
       2.00000000e+01, 2.00000000e+01, 1.05153617e+01, 1.66458592e+01,
       1.71195218e+01, 2.00000000e+01, 1.57435602e+01, 3.47188983e+00,
       1.03220646e+01, 2.00000000e+01, 2.00000000e+01, 1.19472408e+01,
       8.14584597e+00, 1.55977865e+01, 1.47509715e+01, 7.56765924e+00,
       1.08401605e+01, 2.00000000e+01, 2.00000000e+01, 6.47410049e+00,
       5.91987963e-01, 1.08262272e+01, 2.00000000e+01, 2.00000000e+01,
       2.00000000e+01, 2.00000000e+01, 3.45338447e+00, 2.00000000e+01,
       2.00000000e+01, 1.49605768e+01, 2.00000000e+01, 2.00000000e+01,
       2.00000000e+01, 2.00000000e+01, 2.00000000e+01, 1.79974460e+01,
       2.00000000e+01, 4.14712198e+00, 2.00000000e+01, 2.00000000e+01,
       1.69472932e+01, 6.25301020e+00, 1.59221262e+01, 2.00000000e+01,
       6.20693053e+00, 2.00000000e+01, 2.54158260e+00, 9.78258390e+00,
       7.00732507e+00, 2.00000000e+01, 1.21541194e+01, 1.96335775e+01,
       3.83552128e+00, 1.38990177e+01, 6.11677904e+00, 3.86544504e+00,
       1.23501536e+01, 2.00000000e+01, 2.00000000e+01, 2.00000000e+01,
       9.58225795e+00, 9.28644403e+00, 2.00000000e+01, 1.56267249e+01,
       2.00000000e+01, 2.00000000e+01, 2.00000000e+01, 5.53767570e+00,
       2.00000000e+01, 2.00000000e+01, 2.00000000e+01, 2.00000000e+01,
       2.00000000e+01, 2.00000000e+01, 1.17966453e+00, 2.00000000e+01,
       2.00000000e+01, 5.61520316e+00, 2.00000000e+01, 2.00000000e+01,
       2.00000000e+01, 2.00000000e+01, 5.53040363e+00, 2.00000000e+01,
       4.76133676e+00, 4.90553566e+00, 2.00000000e+01, 2.00000000e+01,
       2.00000000e+01, 2.00000000e+01, 7.69205912e+00, 1.84306449e+01,
       4.82682220e+00, 6.52424844e+00, 2.00000000e+01, 2.65955312e+00,
       8.20136780e+00, 1.29382561e+01, 2.00000000e+01, 2.00000000e+01,
       1.27811908e+01, 1.49871970e+01, 1.52863970e+01, 2.00000000e+01,
       2.00000000e+01, 2.00000000e+01, 2.00000000e+01, 1.81135586e+01,
       2.00000000e+01, 2.00000000e+01, 2.00000000e+01, 2.00000000e+01,
       2.00000000e+01, 2.00000000e+01, 2.00000000e+01, 2.00000000e+01,
       5.09855008e+00, 1.10972394e+01, 2.00000000e+01, 2.00000000e+01,
       2.00000000e+01, 2.00000000e+01, 1.92301478e+01, 2.00000000e+01,
       5.20568896e+00, 2.00000000e+01, 2.00000000e+01, 2.00000000e+01,
       2.00000000e+01, 2.00000000e+01, 2.00000000e+01, 2.00000000e+01,
       1.08821767e+01, 2.00000000e+01, 1.04631115e+01, 2.00000000e+01,
       2.00000000e+01, 2.00000000e+01, 2.00000000e+01, 1.00751253e+01,
       2.00000000e+01, 2.00000000e+01, 2.00000000e+01, 1.06752912e+01,
       9.95603396e+00, 5.02553982e+00, 2.00000000e+01, 1.44721731e+01,
       6.34725300e+00, 2.00000000e+01, 1.43567272e-01, 5.71696514e-01,
       5.08316520e+00, 6.42496261e-02, 4.25394178e+00, 3.94326064e-01,
       8.18542734e+00, 5.59855678e-02, 4.31409951e-02, 6.21196603e-01,
       7.84977142e+00, 3.74513327e+00, 5.57949439e-01, 1.35197508e+01,
       5.93949985e-01, 9.65110026e-02, 1.72461337e-02, 2.24608126e-01,
       5.52967748e-01, 1.42856254e-01, 1.64649401e-01, 3.63224147e-01,
       2.54021821e-02, 7.47451324e-02, 8.29506080e-01, 2.00000000e+01,
       1.62289109e+01, 3.20502463e+00])

class GeST_Anno(GeST):
    """GeST with a linear annotation head on top of pooled expression-token embeddings."""

    def __init__(self, config,onlyAE=False):
        """
        * `config` additionally provides `n_class` (number of output classes) and
          `pool` (how token embeddings are reduced: "single", "max" or "mean")
        * `onlyAE` is passed through to `GeST`
        """
        super(GeST_Anno,self).__init__(config,onlyAE)
        is_main_process = os.environ.get("LOCAL_RANK", "0") == "0"

        self.annomodel = nn.Linear(config.n_embd, config.n_class, bias=True)
        if config.n_class == 338:
            if is_main_process:
                logger.info("use weighted cross entropy loss for annotation task")
            # NOTE: the weights are all ones (only the shape of WEIGHT is used), so this
            # currently matches the unweighted loss; the WEIGHT table itself is not applied.
            self.annocriterion = nn.CrossEntropyLoss(weight=torch.ones_like(WEIGHT))
        else:
            self.annocriterion = nn.CrossEntropyLoss()
        if is_main_process:
            logger.info(f"pool type: {config.pool}")

    def forward(self, inputs_embeds=None, coord=None, mask=None, label=None, embedding=True, tok_emb=None, return_attn=False):
        """
        Classify each sample from its pooled hidden states.

        * `inputs_embeds`, `coord`, `mask`, `tok_emb`, `return_attn` are passed to `GeST.forward`
        * `label` holds the target classes for the cross-entropy loss
        * `embedding` is accepted for interface compatibility only; the hidden states
          are always requested from the parent because pooling needs them

        Pooling runs over the valid positions (per `mask`) of the second half of the
        sequence: "single" takes the second-to-last valid token, "max" and "mean"
        reduce over all valid tokens.

        Returns `(class_logits, loss)`, plus the attention list when `return_attn` is set.
        Raises `ValueError` for an unsupported `config.pool`.
        """
        pool = self.config.pool
        if pool not in ('single', 'max', 'mean'):
            raise ValueError(f"unsupported pool type: {pool!r}")

        results = super().forward(inputs_embeds, coord, mask, True, tok_emb, return_attn)
        hidden = results[1]

        allcemb = torch.zeros([hidden.shape[0],self.config.n_embd]).to(hidden.device)
        # the expression-token stream starts after the positional tokens
        blocksize = mask.shape[1]
        for i in range(hidden.shape[0]):
            valid_tokens = hidden[i,blocksize:][mask[i]]
            if pool == 'single':
                cellemb = valid_tokens[-2]
            elif pool == 'max':
                cellemb = valid_tokens.max(0)[0]
            else:
                cellemb = valid_tokens.mean(0)
            allcemb[i]=cellemb

        x = self.annomodel(allcemb)
        loss = self.annocriterion(x, label)
        logger.debug(f"loss: {loss} label: {label}")
        if return_attn:
            return x, loss, results[2]
        else:
            return x, loss