import math
import torch
from torch import nn
import torch.nn.functional as F

from image_synthesis.utils.misc import instantiate_from_config
from image_synthesis.modeling.utils.misc import gen_attention_mask
from image_synthesis.modeling.transformers.base_transformer import BaseTransformer
try:
    from image_synthesis.modeling.modules.sparse_matmul.sparse_matmul import SparseMatmul
except:
    print("Sparse Matmul compiled Error! Using torch matmul instead!")
    SparseMatmul = None

from image_synthesis.modeling.utils.misc import logits_top_k
import numpy as np
from einops import rearrange
from image_synthesis.distributed.distributed import is_primary, get_rank

# NOTE: We found that our implemented sparse mat mul is slower than torch.einsum or operation @ :(
SparseMatmul = None

class SparseAttention(nn.Module):
    """
    A multi-head masked sparse self-attention layer with a projection at the end.

    The attention pattern is a fixed ``seq_len x seq_len`` mask produced once by
    ``gen_attention_mask`` and stored in the ``mask`` buffer; entries equal to 0
    are filled with ``-inf`` before the softmax. ``forward`` can additionally
    receive a per-sample boolean key mask that hides individual tokens.

    Note: although ``seq_len`` and ``content_spatial_size`` default to ``None``,
    both are used arithmetically in the constructor and therefore must be given.
    """
    def __init__(self,
                 n_embd, # the embed dim
                 n_head, # the number of heads
                 seq_len=None, # the max length of sequence
                 attn_pdrop=0.1, # attention dropout prob
                 resid_pdrop=0.1, # residual attention dropout prob
                 causal=True,
                 attn_type='full',
                 content_spatial_size=None, # H , W
                 conv_attn_kernel_size=None, # only need for dalle_conv attention
    ):
        super().__init__()
        assert n_embd % n_head == 0
        # key, query, value projections for all heads
        self.key = nn.Linear(n_embd, n_embd)
        self.query = nn.Linear(n_embd, n_embd)
        self.value = nn.Linear(n_embd, n_embd)
        # regularization
        self.attn_drop = nn.Dropout(attn_pdrop)
        self.resid_drop = nn.Dropout(resid_pdrop)
        # output projection
        self.proj = nn.Linear(n_embd, n_embd)

        self.n_head = n_head
        self.causal = causal

        # sparse mask describing which positions each token may attend to
        self.content_spatial_size = content_spatial_size
        self.attn_type = attn_type
        self.conv_attn_kernel_size = conv_attn_kernel_size
        content_seq_len = content_spatial_size[0] * content_spatial_size[1]
        # whatever is left of the sequence after the H*W content tokens is condition
        condition_seq_len = seq_len - content_seq_len
        assert condition_seq_len >= 0, 'condition seq should be larger or equal to 0!'
        assert attn_type in ['full', 'dalle_row', 'dalle_col', 'dalle_conv']
        mask = gen_attention_mask(
            H=self.content_spatial_size[0],
            W=self.content_spatial_size[1],
            type=self.attn_type,
            causal=self.causal,
            condition_seq_len=condition_seq_len,
            kernel_size=self.conv_attn_kernel_size
        )
        # stored as (1, 1, seq_len, seq_len) so it broadcasts over batch and heads
        self.register_buffer('mask', mask.view(1, 1, seq_len, seq_len))
        # SparseMatmul is a module-level name; it is None when the custom kernel is disabled
        self.sparse_matmul = SparseMatmul() if SparseMatmul is not None else None

    def forward(self, x, mask=None):
        """
        x: B x T x C
        mask: None or tensor B x T, bool type. Keys whose entry is False are not attended to.

        Returns:
            y: B x T x C, the projected attention output
            att: B x T x T, attention weights averaged over heads
        """
        B, T, C = x.size()
        hs = C // self.n_head
        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        k = self.key(x).view(B, T, self.n_head, hs).transpose(1, 2) # (B, nh, T, hs)
        q = self.query(x).view(B, T, self.n_head, hs).transpose(1, 2) # (B, nh, T, hs)
        v = self.value(x).view(B, T, self.n_head, hs).transpose(1, 2) # (B, nh, T, hs)

        # Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        if self.sparse_matmul is not None:
            att = self.sparse_matmul(
                q.contiguous().view(B*self.n_head, T, hs),
                k.contiguous().view(B*self.n_head, T, hs),
                mask = self.mask[0,0,:T,:T], # 2D mask is ok
            ).view(B, self.n_head, T, T) # B x nh x T x T
            att = att * (1.0 / math.sqrt(k.size(-1)))
        else:
            att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) # (B, nh, T, T)
        att = att.masked_fill(self.mask[:,:,:T,:T] == 0, float('-inf'))

        if mask is not None:
            mask = mask.view(B, 1, 1, T)
            # masked_fill is out-of-place: the result must be assigned back,
            # otherwise the per-sample mask has no effect at all.
            att = att.masked_fill(~mask, float('-inf'))

        att = F.softmax(att, dim=-1) # (B, nh, T, T)
        att = self.attn_drop(att)
        y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side, (B, T, C)
        att = att.mean(dim=1, keepdim=False) # (B, T, T)

        # output projection
        y = self.resid_drop(self.proj(y))
        return y, att


class FullAttention(nn.Module):
    """
    Plain multi-head self-attention followed by an output projection.

    Every token attends to every other token: no causal or padding mask is
    applied in ``forward``. The ``seq_len`` constructor argument is accepted
    but not used, and ``causal`` is only stored on the module.
    """

    def __init__(self,
                 n_embd, # the embed dim
                 n_head, # the number of heads
                 seq_len=None, # the max length of sequence
                 attn_pdrop=0.1, # attention dropout prob
                 resid_pdrop=0.1, # residual attention dropout prob
                 causal=True,
    ):
        super().__init__()
        assert n_embd % n_head == 0
        # per-head key / query / value projections, stored as single linear maps
        self.key = nn.Linear(n_embd, n_embd)
        self.query = nn.Linear(n_embd, n_embd)
        self.value = nn.Linear(n_embd, n_embd)
        # dropout on the attention weights and on the projected output
        self.attn_drop = nn.Dropout(attn_pdrop)
        self.resid_drop = nn.Dropout(resid_pdrop)
        # mixes the concatenated heads back into the embedding space
        self.proj = nn.Linear(n_embd, n_embd)

        self.n_head = n_head
        self.causal = causal

    def forward(self, x, encoder_output, mask=None):
        """
        x: B x T x C
        encoder_output: ignored, present so the layer is call-compatible with CrossAttention
        mask: ignored by this layer

        Returns (y, att) with y: B x T x C and att: B x T x T (mean over heads).
        """
        batch, tokens, channels = x.size()
        head_dim = channels // self.n_head

        def split_heads(t):
            # (B, T, C) -> (B, nh, T, hs)
            return t.view(batch, tokens, self.n_head, head_dim).transpose(1, 2)

        k = split_heads(self.key(x))
        q = split_heads(self.query(x))
        v = split_heads(self.value(x))

        # scaled dot-product scores: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        scale = 1.0 / math.sqrt(k.size(-1))
        scores = torch.matmul(q, k.transpose(-2, -1)) * scale

        weights = self.attn_drop(F.softmax(scores, dim=-1))

        # weighted sum of values, then put the heads side by side again: (B, T, C)
        heads = torch.matmul(weights, v)
        merged = heads.transpose(1, 2).contiguous().view(batch, tokens, channels)

        y = self.resid_drop(self.proj(merged))
        return y, weights.mean(dim=1, keepdim=False)

class CrossAttention(nn.Module):
    """
    Multi-head cross-attention followed by an output projection.

    Queries come from ``x``; keys and values come from ``encoder_output``, which
    may have a different length and channel width (``condition_embd``). When
    ``causal`` is true a lower-triangular ``mask`` buffer is registered, but
    ``forward`` does not apply it. ``condition_seq_len`` is accepted but unused.
    """

    def __init__(self,
                 condition_seq_len,
                 n_embd, # the embed dim
                 condition_embd, # condition dim
                 n_head, # the number of heads
                 seq_len=None, # the max length of sequence
                 attn_pdrop=0.1, # attention dropout prob
                 resid_pdrop=0.1, # residual attention dropout prob
                 causal=True,
    ):
        super().__init__()
        assert n_embd % n_head == 0
        # keys / values are projected from the condition, queries from the content
        self.key = nn.Linear(condition_embd, n_embd)
        self.query = nn.Linear(n_embd, n_embd)
        self.value = nn.Linear(condition_embd, n_embd)
        # dropout on the attention weights and on the projected output
        self.attn_drop = nn.Dropout(attn_pdrop)
        self.resid_drop = nn.Dropout(resid_pdrop)
        # mixes the concatenated heads back into the embedding space
        self.proj = nn.Linear(n_embd, n_embd)

        self.n_head = n_head
        self.causal = causal

        # lower-triangular mask kept as a buffer (part of the state dict)
        if self.causal:
            lower_tri = torch.ones(seq_len, seq_len).tril()
            self.register_buffer("mask", lower_tri.view(1, 1, seq_len, seq_len))

    def forward(self, x, encoder_output, mask=None):
        """
        x: B x T x C
        encoder_output: B x T_E x C_condition
        mask: ignored by this layer

        Returns (y, att) with y: B x T x C and att: B x T x T_E (mean over heads).
        """
        _, tokens, channels = x.size()
        # the batch size used for reshaping is the one of encoder_output
        batch, cond_tokens, _ = encoder_output.size()
        head_dim = channels // self.n_head

        def split_heads(t, length):
            # (B, length, C) -> (B, nh, length, hs)
            return t.view(batch, length, self.n_head, head_dim).transpose(1, 2)

        k = split_heads(self.key(encoder_output), cond_tokens)
        q = split_heads(self.query(x), tokens)
        v = split_heads(self.value(encoder_output), cond_tokens)

        # scaled dot-product scores: (B, nh, T, hs) x (B, nh, hs, T_E) -> (B, nh, T, T_E)
        scale = 1.0 / math.sqrt(k.size(-1))
        scores = torch.matmul(q, k.transpose(-2, -1)) * scale

        weights = self.attn_drop(F.softmax(scores, dim=-1))

        # weighted sum of condition values, heads re-assembled side by side: (B, T, C)
        heads = torch.matmul(weights, v)
        merged = heads.transpose(1, 2).contiguous().view(batch, tokens, channels)

        y = self.resid_drop(self.proj(merged))
        return y, weights.mean(dim=1, keepdim=False)

class GELU2(nn.Module):
    """Sigmoid-gated activation: ``x * sigmoid(1.702 * x)``."""

    def forward(self, x):
        gate = torch.sigmoid(1.702 * x)
        return x * gate

class AdaLayerNorm(nn.Module):
    """
    Layer norm whose output is modulated by a learned per-timestep scale and shift.

    The integer timestep is embedded, passed through SiLU and a linear layer
    that produces ``2 * n_embd`` values, which are split into ``scale`` and
    ``shift`` and applied as ``layernorm(x) * (1 + scale) + shift``.
    """

    def __init__(self, n_embd, diffusion_step):
        super().__init__()
        # one embedding row per diffusion step
        self.emb = nn.Embedding(diffusion_step, n_embd)
        self.silu = nn.SiLU()
        # produces scale and shift in a single projection
        self.linear = nn.Linear(n_embd, n_embd*2)
        self.layernorm = nn.LayerNorm(n_embd)

    def forward(self, x, timestep):
        """
        x: tensor normalised over its last dimension (size n_embd)
        timestep: integer tensor of step indices, one per batch element
        """
        step_emb = self.emb(timestep)
        # unsqueeze(1) adds a length-1 axis so the modulation broadcasts over tokens
        modulation = self.linear(self.silu(step_emb)).unsqueeze(1)
        scale, shift = modulation.chunk(2, dim=2)
        normed = self.layernorm(x)
        return normed * (1 + scale) + shift


class Block(nn.Module):
    """
    A pre-norm Transformer block: attention sub-layer(s) with residual
    connections followed by an MLP with a residual connection.

    ``attn_type`` selects the attention sub-layer:
      * ``'full'``      -> one FullAttention (self-attention)
      * ``'cross'``     -> one CrossAttention over ``encoder_output``
      * ``'selfcross'`` -> FullAttention followed by CrossAttention
      * anything else   -> SparseAttention with that pattern name

    When ``if_upsample`` is set, the input is first passed through an
    ``Upsample`` module (scale 2) mapping ``upsample_pre_channel`` channels
    to ``n_embd``.
    """

    def __init__(self,
                 condition_seq_len,
                 n_embd,
                 n_head,
                 seq_len,
                 attn_pdrop=0.1,
                 resid_pdrop=0.1,
                 causal=True,
                 mlp_hidden_times=4,
                 activate='GELU',
                 attn_type='full',
                 if_upsample=False,
                 upsample_type='bilinear',
                 upsample_pre_channel=0,
                 content_spatial_size=None, # H , W
                 conv_attn_kernel_size=None, # only need for dalle_conv attention
                 condition_dim=1024,
                 diffusion_step=100,
                 ):
        super().__init__()
        self.if_upsample = if_upsample
        if self.if_upsample:
            self.upsample = Upsample(scale=2, upsample_type=upsample_type, dim=n_embd, pre_dim=upsample_pre_channel)

        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)
        self.if_selfcross = False
        if attn_type == 'full':
            self.attn = FullAttention(
                n_embd=n_embd,
                n_head=n_head,
                seq_len=seq_len,
                attn_pdrop=attn_pdrop,
                resid_pdrop=resid_pdrop,
                causal=causal
            )
        elif attn_type == 'cross':
            self.attn = CrossAttention(
                condition_seq_len,
                n_embd=n_embd,
                condition_embd=condition_dim,
                n_head=n_head,
                seq_len=seq_len,
                attn_pdrop=attn_pdrop,
                resid_pdrop=resid_pdrop,
                causal=causal
            )
        elif attn_type == 'selfcross':
            self.attn1 = FullAttention(
                n_embd=n_embd,
                n_head=n_head,
                seq_len=seq_len,
                attn_pdrop=attn_pdrop,
                resid_pdrop=resid_pdrop,
                causal=causal
            )
            self.attn2 = CrossAttention(
                condition_seq_len,
                n_embd=n_embd,
                condition_embd=condition_dim,
                n_head=n_head,
                seq_len=seq_len,
                attn_pdrop=attn_pdrop,
                resid_pdrop=resid_pdrop,
                causal=causal
            )
            self.if_selfcross = True
            # the norm in front of the cross-attention is timestep-conditioned
            self.ln1_1 = AdaLayerNorm(n_embd, diffusion_step)
        else:
            self.attn = SparseAttention(
                n_embd=n_embd,
                n_head=n_head,
                seq_len=seq_len,
                attn_pdrop=attn_pdrop,
                resid_pdrop=resid_pdrop,
                causal=causal,
                attn_type=attn_type,
                content_spatial_size=content_spatial_size, # H , W
                conv_attn_kernel_size=conv_attn_kernel_size, # only need for dalle_conv attention
            )
        assert activate in ['GELU', 'GELU2']
        act = nn.GELU() if activate == 'GELU' else GELU2()
        self.mlp = nn.Sequential(
            nn.Linear(n_embd, mlp_hidden_times * n_embd),
            act,
            nn.Linear(mlp_hidden_times * n_embd, n_embd),
            nn.Dropout(resid_pdrop),
        )

    @staticmethod
    def _apply_norm(norm, x, timestep):
        """Run a norm layer, passing the timestep only to timestep-conditioned norms."""
        if isinstance(norm, AdaLayerNorm):
            return norm(x, timestep)
        # nn.LayerNorm.forward takes a single input; passing timestep would raise
        return norm(x)

    def forward(self, x, encoder_output, timestep, mask=None):
        """
        x: B x T x C content features
        encoder_output: condition features for cross-attention (may be None when unused)
        timestep: integer tensor of diffusion steps, used by timestep-conditioned norms
        mask: optional mask forwarded to the attention layer(s)

        Returns (x, att): the updated features and the attention weights of the
        last attention sub-layer that ran.
        """
        if self.if_upsample:
            x = self.upsample(x)
        if not self.if_selfcross:
            h = self._apply_norm(self.ln1, x, timestep)
            if isinstance(self.attn, SparseAttention):
                # SparseAttention.forward(x, mask=None) has no encoder_output slot
                a, att = self.attn(h, mask=mask)
            else:
                a, att = self.attn(h, encoder_output, mask=mask)
            x = x + a
        else:
            a, att = self.attn1(self._apply_norm(self.ln1, x, timestep), encoder_output, mask=mask)
            x = x + a
            a, att = self.attn2(self._apply_norm(self.ln1_1, x, timestep), encoder_output, mask=mask)
            x = x + a
        x = x + self.mlp(self.ln2(x))

        return x, att

class Upsample(nn.Module):
    """
    2x spatial upsampling of a flattened square token grid.

    The input ``(B, H*W, pre_dim)`` is linearly projected, reshaped to an image,
    upsampled by a factor of two with either ``nn.PixelShuffle`` or bilinear
    ``nn.Upsample``, and flattened back to ``(B, 4*H*W, dim)``.

    Raises:
        ValueError: if ``upsample_type`` is neither ``'pixel_shuffle'`` nor
            ``'bilinear'``.
    """

    def __init__(self, scale, upsample_type, dim, pre_dim):
        super().__init__()
        assert scale == 2
        if upsample_type == 'pixel_shuffle':
            self.upsample = nn.PixelShuffle(2)
            # PixelShuffle(2) folds 4 channels into each output pixel, so project to dim*4
            self.reduction = nn.Linear(pre_dim, dim*4, bias=False)
        elif upsample_type == 'bilinear':
            self.upsample = nn.Upsample(scale_factor=2, mode='bilinear')
            self.reduction = nn.Linear(pre_dim, dim, bias=False)
        else:
            # fail at construction instead of building a module with no layers
            raise ValueError(
                "upsample_type error: expected 'pixel_shuffle' or 'bilinear', got %r" % (upsample_type,)
            )
        self.pre_dim = pre_dim
        self.dim = dim
        self.upsample_type = upsample_type

    def forward(self, x):
        """
        x: B, H*W, C with H == W and C == pre_dim

        Returns a tensor of shape B, 4*H*W, dim.
        """
        B, L, C = x.shape
        H = int(math.sqrt(L))
        W = int(math.sqrt(L))
        assert L == H * W, "input feature has wrong size"
        assert C == self.pre_dim, "wrong in input channel"

        if self.upsample_type == 'pixel_shuffle':
            channels = self.dim * 4
        elif self.upsample_type == 'bilinear':
            channels = self.dim
        else:
            raise ValueError("upsample_type error: %r" % (self.upsample_type,))

        x = self.reduction(x)
        x = x.view(B, H, W, channels)
        x = x.permute(0, 3, 1, 2).contiguous()   # B,C,H,W
        x = self.upsample(x)                     # B,dim,2H,2W
        x = x.permute(0, 2, 3, 1).contiguous().view(B, L*4, self.dim)

        return x


def generate_bool_mask(mask_ratio, seq_len):
    """
    Return a random boolean vector of length ``seq_len`` with exactly
    ``int(seq_len * mask_ratio)`` entries set to True.

    A random permutation of ``0 .. seq_len-1`` is drawn; positions holding a
    value below the threshold are marked True.
    """
    num_true = int(seq_len * mask_ratio)
    return torch.randperm(seq_len).lt(num_true)


class NAR8_32gray_Transformer(BaseTransformer):
    def __init__(
        self,
        *,
        n_layer, # number of layers in transformer
        condition_seq_len, # length of condition sequences
        content_seq_len, # length of content sequences
        embd_pdrop=0., # embedding dropout prob

        n_embd, # the embed dim
        n_head, # the number of heads
        attn_pdrop=0.1, # attention dropout prob
        resid_pdrop=0.1, # residual attention dropout prob
        causal=True,
        block_activate='GELU',
        mlp_hidden_times=4, # the times of hidden dimension in the MLP of attetntion block

        attn_with_mask = False, # perform attention with mask
        attn_condition_with_mask=False,
        attn_content_with_mask=False,
        attn_type='full',
        content_spatial_size=None, # H , W
        conv_attn_kernel_size=None, # only need for dalle_conv attention

        predict_condition=False, # prediction with condition tokens 
        condition_loss_weight=1/8.0,
        content_loss_weight=7/8.0,
        content_emb_config=None,
        condition_emb_config=None,

        condition_ignore_token=-100,
        content_ignore_token=-100,

        diffusion_step=100,
        learnable_keep_rate=True,
        diffusion_keep_rate=0.85,
        onehot_input=False,
        final_loss_weight=None,

    ):
        """
        Build the embeddings, the stack of ``n_layer`` transformer blocks, the
        logits head and the per-step diffusion keep-rate parameter.

        ``attn_type`` is either
          * a string: ``'full'`` / ``'selfcross'`` use that type in every layer,
            ``'cross'`` alternates ``'full'`` (even layers) and ``'cross'``
            (odd layers), anything else raises ``NotImplementedError``; or
          * a dict with list entries ``'start'`` (optional), ``'middle'``
            (required, repeated to fill the remaining layers) and ``'end'``
            (optional); the resulting list must have exactly ``n_layer`` items.

        ``diffusion_keep_rate`` initialises every entry of a length
        ``diffusion_step`` double parameter, trainable iff ``learnable_keep_rate``.
        """
        super().__init__()

        # embeddings for condition and content
        self.content_emb = instantiate_from_config(content_emb_config)
        if condition_emb_config is None:
            # share the condition embed with content embed
            self.condition_emb = None
            assert not predict_condition, 'If want to predict condition token, please provide condition embed config'
            # Without a dedicated condition embedding there is no embed_dim to read;
            # NOTE(review): assumes shared embeddings have width n_embd -- confirm.
            self.condition_dim = n_embd
        else:
            # for condition and config, we learn a seperate embedding
            self.condition_emb = instantiate_from_config(condition_emb_config)
            self.condition_dim = self.condition_emb.embed_dim

        self.condition_ignore_token = condition_ignore_token
        self.content_ignore_token = content_ignore_token

        # drop for embedding
        if embd_pdrop > 0:
            self.drop = nn.Dropout(embd_pdrop)
        else:
            self.drop = None

        # resolve the attention type of every layer
        if isinstance(attn_type, str):
            # membership test: `attn_type == 'full' or 'selfcross'` is always truthy
            if attn_type in ('full', 'selfcross'):
                all_attn_type = [attn_type] * n_layer
            elif attn_type == 'cross':
                all_attn_type = ['full' if index % 2 == 0 else 'cross' for index in range(n_layer)]
            else:
                raise NotImplementedError('unsupported attn_type: %r' % (attn_type,))
        elif isinstance(attn_type, dict):
            start_attn_type = attn_type.get('start', []) # list
            end_attn_type = attn_type.get('end', []) # list
            middle_attn_type = attn_type['middle'] # list
            middle_layers = n_layer - len(start_attn_type) - len(end_attn_type)
            middle_attn_type = middle_attn_type * int(middle_layers / len(middle_attn_type))
            all_attn_type = start_attn_type + middle_attn_type + end_attn_type
            assert len(all_attn_type) == n_layer, 'number of attn type not equal to number of layers!'
        else:
            raise NotImplementedError

        if content_spatial_size is None:
            # default to a square grid when the content length is a perfect square
            s = int(math.sqrt(content_seq_len))
            assert s * s == content_seq_len
            content_spatial_size = (s, s)

        self.blocks = nn.Sequential(*[Block(
                condition_seq_len,
                n_embd=n_embd,
                n_head=n_head,
                seq_len=content_seq_len,
                attn_pdrop=attn_pdrop,
                resid_pdrop=resid_pdrop,
                causal=causal,
                mlp_hidden_times=mlp_hidden_times,
                activate=block_activate,
                attn_type=layer_attn_type,
                content_spatial_size=content_spatial_size, # H , W
                conv_attn_kernel_size=conv_attn_kernel_size, # only need for dalle_conv attention
                condition_dim = self.condition_dim,
                diffusion_step = diffusion_step,
        ) for layer_attn_type in all_attn_type])

        # final prediction head
        out_cls = self.content_emb.num_embed
        self.to_logits = nn.Sequential(
            nn.LayerNorm(n_embd),
            nn.Linear(n_embd, out_cls),
        )

        self.causal = causal
        self.condition_seq_len = condition_seq_len
        self.content_seq_len = content_seq_len
        self.condition_loss_weight = condition_loss_weight
        self.content_loss_weight = content_loss_weight
        self.predict_condition = predict_condition
        self.attn_with_mask = attn_with_mask
        self.attn_condition_with_mask = attn_condition_with_mask
        self.attn_content_with_mask = attn_content_with_mask

        self.diffusion_step = diffusion_step
        # number of discrete token classes the diffusion operates on
        self.diffusion_item_number = self.content_emb.num_embed
        self.diffusion_matrix = []
        self.onehot_input = onehot_input
        if final_loss_weight == 'step':
            self.final_loss_weight = 1/self.diffusion_item_number
        else:
            self.final_loss_weight = final_loss_weight

        # one keep rate per diffusion step, kept in double precision
        self.learnable_keep_rate = learnable_keep_rate
        self.diffusion_keep_rate = nn.Parameter(
            torch.zeros(self.diffusion_step).double(),
            requires_grad=bool(self.learnable_keep_rate),
        )
        self.diffusion_keep_rate.data.fill_(diffusion_keep_rate)

        if self.attn_with_mask:
            print('Warning: attn_with_mask is suppressed!')
            self.attn_condition_with_mask = True
            self.attn_content_with_mask = True

        self.apply(self._init_weights)

        # _init_weights just re-initialised every Linear/Embedding/LayerNorm, including
        # those inside the condition embedding; rebuild the CLIP text embedding afterwards.
        # The None guard avoids subscripting a missing config.
        if condition_emb_config is not None and condition_emb_config['target'] == 'image_synthesis.modeling.embeddings.clip_text_embedding.CLIPTextEmbedding':
            self.condition_emb = instantiate_from_config(condition_emb_config)
            # assert self.condition_emb.embed_dim == self.content_emb.embed_dim

        ########## build diffusion matrix ###################
        # init diffusion matrix #
        # self.diffusion_a_matrix = nn.Parameter(torch.zeros(self.diffusion_step+1).double())
        # self.diffusion_a_matrix.data[0] = 1
        # self.diffusion_b_matrix = nn.Parameter(torch.zeros(self.diffusion_step+1).double())
        # self.update_diffusion_matrix()

        # self.diffusion_a_matrix = self.diffusion_keep_rate.cumprod(dim=0)
        # self.diffusion_a_matrix = torch.cat((torch.ones(1).double(), self.diffusion_a_matrix), dim=0)
        # self.diffusion_b_matrix = (1-self.diffusion_a_matrix)/self.diffusion_item_number

        # print("after initization, diffusion matix of A and B are:")
        # print(self.diffusion_a_matrix)
        # print(self.diffusion_b_matrix)

        # self.build_diffusion_matrix()
        # self.diffusion_matrix.append((1,0))
        # att = np.cumprod(self.diffusion_keep_rate)
        # btt = (1-att)/self.diffusion_item_number
        # for i in range(len(att)):
        #     self.diffusion_matrix.append((att[i], btt[i]))
        # print("a and b of diffusion_matrix is ")
        # print(self.diffusion_matrix)


    def _init_weights(self, module):
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def update_diffusion_matrix(self):
        diffusion_a_matrix = self.diffusion_keep_rate.cumprod(dim=0) 
        diffusion_a_matrix = torch.cat((torch.ones(1).type_as(diffusion_a_matrix), diffusion_a_matrix), dim=0)
        diffusion_b_matrix = (1-diffusion_a_matrix)/self.diffusion_item_number
        return diffusion_a_matrix, diffusion_b_matrix


    # def get_next_diffusion(self, start_ab, a2):
    #     a2 = np.float64(a2)
    #     a1 = np.float64(start_ab[0])
    #     b1 = np.float64(start_ab[1])
    #     b2 = np.float64(1-a2)/np.float64(self.diffusion_item_number)
        # new_a = a1*a2+b1*b2
    #     new_a = a1*a2
    #     new_b = b1*a2+b2
    #     return (new_a, new_b)

    # def build_diffusion_matrix(self):
    #     start_ab = (self.diffusion_keep_rate[0], (1-self.diffusion_keep_rate[0])/self.diffusion_item_number)
    #     self.diffusion_matrix.append(start_ab)
    #     for i in range(0, self.diffusion_step):
    #         self.diffusion_matrix.append(self.get_next_diffusion(self.diffusion_matrix[i], self.diffusion_keep_rate[i]))
    #     assert len(self.diffusion_matrix) == self.diffusion_step + 1
    #     self.diffusion_matrix.insert(0, (1,0))


    @property
    def device(self):
        return self.to_logits[-1].weight.device

    def parameters(self, recurse=True, name=None):
        """
        Following minGPT:
        This long function is unfortunately doing something very simple and is being very defensive:
        We are separating out all parameters of the model into two buckets: those that will experience
        weight decay for regularization and those that won't (biases, and layernorm/embedding weights).
        We are then returning the PyTorch optimizer object.
        """
        # return super().parameters(recurse=True)
        if name is None or name == 'none':
            return super().parameters(recurse=recurse)
        else:
            # separate out all parameters to those that will and won't experience regularizing weight decay
            print("GPTLikeTransformer: get parameters by the overwrite method!")
            decay = set()
            no_decay = set()
            whitelist_weight_modules = (torch.nn.Linear, )
            blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
            for mn, m in self.named_modules():
                for pn, p in m.named_parameters():
                    fpn = '%s.%s' % (mn, pn) if mn else pn # full param name

                    if pn.endswith('bias'):
                        # all biases will not be decayed
                        no_decay.add(fpn)
                    elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
                        # weights of whitelist modules will be weight decayed
                        decay.add(fpn)
                    elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
                        # weights of blacklist modules will NOT be weight decayed
                        no_decay.add(fpn)
            # special case the position embedding parameter as not decayed
            module_name = ['condition_emb', 'content_emb']
            pos_emb_name = ['pos_emb', 'width_emb', 'height_emb', 'pad_emb', 'token_type_emb']
            for mn in module_name:
                if hasattr(self, mn) and getattr(self, mn) is not None:
                    for pn in pos_emb_name:
                        if hasattr(getattr(self, mn), pn):
                            if isinstance(getattr(getattr(self, mn), pn), torch.nn.Parameter):
                                no_decay.add('{}.{}'.format(mn, pn))

            # validate that we considered every parameter
            param_dict = {pn: p for pn, p in self.transformer.named_parameters()}# if p.requires_grad} 
            inter_params = decay & no_decay
            union_params = decay | no_decay
            assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), )
            assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \
                                                        % (str(param_dict.keys() - union_params), )

            # create the pytorch optimizer object
            optim_groups = [
                {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": 0.01},
                {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
            ]
            return optim_groups

    def sample_from_logits(self, logits, return_prob=False):
        # assert abs(logits.sum().item()-logits.size()[0]*logits.size()[2])<0.5
        if abs(logits.sum().item()-logits.size()[0]*logits.size()[2])>0.5:
            print("not good enough")
        assert logits.dim() == 3
        # output_sample = torch.argmax(torch.log(logits)-torch.log(-torch.log(torch.rand(logits.size()).type_as(logits))), dim=1)
        output_sample = torch.argmax(torch.log(logits)-torch.log(-torch.log(torch.rand(logits.size()).type_as(logits))), dim=1)
        if return_prob == True:
            # index = torch.gather(logits, 1, output_sample.unsqueeze(1))
            index_onehot = rearrange(F.one_hot(output_sample, num_classes=self.diffusion_item_number), 'b l c -> b c l')
            return_logits = logits * index_onehot
            return output_sample, return_logits
        # output_sample = logits
        return output_sample


    def forward(
            self,
            input,
            return_loss=False,
            return_logits=True,
            return_att_weight=False,
            is_train=True,
            diffusion_index=None,
            **kwargs):
        """Run the transformer on (noised) tokens and optionally compute the loss.

        Training (``is_train`` true): a timestep is drawn per sample, ``input``
        is one-hot encoded and mixed with the diffusion matrices returned by
        ``self.update_diffusion_matrix()``, noisy tokens are sampled from the
        result and fed to the transformer blocks.
        Inference (``is_train`` false): ``input`` is embedded as-is and every
        sample uses the timestep ``diffusion_index``.

        Args:
            input: ``B x L`` integer token indices.
            return_loss: add the loss terms and debug accuracies to the output.
                Only supported together with ``is_train=True``.
            return_logits: add the ``B x L x n`` logits under ``'logits'``.
            return_att_weight: add the attention weight returned by the last
                block under ``'attention_weight'``.
            is_train: select the training or the inference path (see above).
            diffusion_index: timestep used for the whole batch in inference.
            **kwargs: accepted for call compatibility; not used here.

        Returns:
            dict with the requested entries: ``'logits'``,
            ``'attention_weight'`` and, when ``return_loss`` is set,
            ``'init_acc'``, ``'unchange_acc'``, ``'similar_acc'``,
            ``'final_loss'`` and ``'loss'`` (= main loss + ``final_loss``).

        Raises:
            ValueError: if ``return_loss`` is requested in inference mode, or
                if ``diffusion_index`` is missing in inference mode.
        """
        if not is_train:
            # The loss needs the noised tensors that only the training path
            # builds; fail early with a clear message instead of an
            # UnboundLocalError after the whole forward pass.
            if return_loss:
                raise ValueError("return_loss=True requires is_train=True")
            if diffusion_index is None:
                raise ValueError("diffusion_index is required when is_train=False")

        batch_size = input.shape[0]
        diffusion_a_matrix, diffusion_b_matrix = self.update_diffusion_matrix()

        if is_train:
            timestep = torch.randint(low=0, high=self.diffusion_step, size=(batch_size,))
            # The one-hot (cross-entropy) objective is only eligible when at
            # least one sample of the batch drew timestep 0.
            hit_step_zero = timestep.min().item() == 0
            onehot_input_loss = hit_step_zero and bool(self.onehot_input)

            sample_image = input
            q0 = F.one_hot(sample_image, num_classes=self.diffusion_item_number)
            q0 = rearrange(q0, 'b l c -> b c l')
            # Double precision for the mixing below.
            # NOTE(review): this hard-codes the current CUDA device.
            q0 = q0.type(torch.cuda.DoubleTensor)

            # qt uses the matrices at timestep + 1, qt_1 those at timestep;
            # the coefficients are broadcast over the class and length dims.
            at = diffusion_a_matrix[timestep + 1].unsqueeze(1).unsqueeze(2)
            bt = diffusion_b_matrix[timestep + 1].unsqueeze(1).unsqueeze(2)
            at_1 = diffusion_a_matrix[timestep].unsqueeze(1).unsqueeze(2)
            bt_1 = diffusion_b_matrix[timestep].unsqueeze(1).unsqueeze(2)
            qt = (q0 * at + bt).float()
            qt_1 = (q0 * at_1 + bt_1).float()
            xt, qt_prob = self.sample_from_logits(qt, return_prob=True)

            cont_emb = self.content_emb(xt)
        else:
            timestep = torch.full((batch_size,), diffusion_index).type(torch.LongTensor)
            cont_emb = self.content_emb(input)

        # Forward through the transformer blocks (no conditioning embedding).
        emb = cont_emb
        cond_emb = None
        timestep_cuda = timestep.cuda()  # moved once instead of once per block
        for block in self.blocks:
            emb, att_weight = block(emb, cond_emb, timestep_cuda)  # B x (Ld+Lt) x D, B x (Ld+Lt) x (Ld+Lt)

        logits = self.to_logits(emb)  # B x (Ld+Lt) x n

        out = {}
        if return_logits:
            out['logits'] = logits
        if return_att_weight:
            out['attention_weight'] = att_weight

        if return_loss:
            if not onehot_input_loss:
                logits_bcl = rearrange(logits, 'b l c -> b c l')
                # Build the target distribution from the keep rate of the
                # sampled timestep.
                this_a = self.diffusion_keep_rate[timestep]
                this_b = (1 - this_a) / self.diffusion_item_number
                this_a = this_a.unsqueeze(1).unsqueeze(2).float()
                this_b = this_b.unsqueeze(1).unsqueeze(2).float()

                # qt_prob is zero except at the sampled class, so summing over
                # the class dim recovers the value of qt at xt per position.
                qt_full = qt_prob.sum(dim=1).unsqueeze(dim=1).expand(-1, self.diffusion_item_number, -1)
                xt_onehot = rearrange(F.one_hot(xt, num_classes=self.diffusion_item_number), 'b l c -> b c l')
                qt_1_prob = qt_1 * xt_onehot

                base = this_b * qt_1 / qt_full
                target = base + this_a * qt_1_prob / qt_full
                # Diagnostic: the target should sum to ~1 per (batch, position).
                mass_error = abs(target.sum().item() - target.size()[0] * target.size()[2])
                if mass_error > 1:
                    print(mass_error)
                    print("maybe something wrong here --------------------------------------")

                # Gradients flow into the target only when the keep rate is learnable.
                kl_target = target.float() if self.learnable_keep_rate else target.float().detach()
                loss = F.kl_div(logits_bcl.log_softmax(dim=1), kl_target, reduction='batchmean')
                final_loss = torch.zeros_like(loss)

                # argmax is unaffected by softmax, so the raw logits are used.
                pred_tokens = torch.argmax(logits_bcl, dim=1)
                target_tokens = torch.argmax(target, dim=1)
            else:
                # Only reachable when the batch contains timestep 0 (see above).
                flat_logits = logits.contiguous()
                loss = F.cross_entropy(
                    flat_logits.view(-1, flat_logits.shape[-1]),
                    sample_image.view(-1),
                    ignore_index=self.content_ignore_token)

                # KL between the distribution at index `diffusion_step` of the
                # diffusion matrices and the uniform distribution.
                last_step = torch.full((batch_size,), self.diffusion_step)
                al = diffusion_a_matrix[last_step].unsqueeze(1).unsqueeze(2)
                bl = diffusion_b_matrix[last_step].unsqueeze(1).unsqueeze(2)
                ql = (q0 * al + bl).float()
                uniform = torch.zeros_like(ql).fill_(1 / self.diffusion_item_number).softmax(dim=1)
                final_loss = F.kl_div(ql.log_softmax(dim=1), uniform, reduction='batchmean')
                # Values below the threshold are treated as numerical noise.
                if final_loss.item() > 0.01:
                    final_loss = final_loss * self.final_loss_weight
                else:
                    final_loss = torch.zeros_like(loss)

                # Here the logits are still B x L x C and the target is the
                # clean token sequence itself.
                pred_tokens = torch.argmax(logits, dim=2)
                target_tokens = sample_image

            # Debug accuracies.
            all_number = input.size()[0] * input.size()[1]
            out['init_acc'] = (target_tokens == input).sum() / all_number
            out['unchange_acc'] = (pred_tokens == input).sum() / all_number
            out['similar_acc'] = (pred_tokens == target_tokens).sum() / all_number

            out['final_loss'] = final_loss
            out['loss'] = loss + final_loss
        return out


    def sample(
            self,
            content_token = None,
            filter_ratio = 0.5,
            temperature = 1.0,
            return_att_weight = False,
            return_logits = False,
            content_logits = None,
            **kwargs):
        """Noise the given tokens to an intermediate step, then denoise them.

        ``content_token`` is diffused to step
        ``int(self.diffusion_step * filter_ratio)`` (at least 1 when that
        product truncates to 0) and then passed through ``self.forward`` once
        per remaining step, each time replacing the tokens with the argmax of
        the returned logits.

        Args:
            content_token: ``B x L`` integer token indices; despite the
                ``None`` default it must be provided.
            filter_ratio: fraction of ``self.diffusion_step`` to start from.
                Values above 0.9 additionally enable debug printing on the
                primary process.
            temperature: unused by this method.
            return_att_weight: forwarded to ``self.forward``.
            return_logits: add the logits of the last denoising step to the
                output under ``'logits'``.
            content_logits: unused by this method.
            **kwargs: forwarded to ``self.forward``.

        Returns:
            dict with ``'content_token'`` (the denoised tokens) and, when
            ``return_logits`` is set, ``'logits'``.
        """
        this_input = content_token

        # The per-step statistics are only ever printed under this condition,
        # so they are only collected under it as well; this avoids a blocking
        # `.item()` call on every denoising step.
        verbose = is_primary() and filter_ratio > 0.9
        if verbose:
            print("the self.diffusion_keep_rate is:")
            print(self.diffusion_keep_rate*1000000)
        unchange_rate_list = []
        all_number = this_input.size()[0] * this_input.size()[1]
        batch_size = this_input.size()[0]

        with torch.no_grad():
            start_step = int(self.diffusion_step * filter_ratio)
            if start_step == 0:
                start_step = 1

            # Diffuse the clean tokens to `start_step` and sample from the result.
            diffusion_a_matrix, diffusion_b_matrix = self.update_diffusion_matrix()
            step_index = torch.full((batch_size, ), start_step)
            at = diffusion_a_matrix[step_index].unsqueeze(1).unsqueeze(2)
            bt = diffusion_b_matrix[step_index].unsqueeze(1).unsqueeze(2)
            q0 = F.one_hot(this_input, num_classes=self.diffusion_item_number)
            q0 = rearrange(q0, 'b l c -> b c l')
            # NOTE(review): this hard-codes the current CUDA device.
            q0 = q0.type(torch.cuda.DoubleTensor)
            qt = (q0*at+bt).float()
            this_input = self.sample_from_logits(qt)

            # Denoise from start_step - 1 down to 0, greedily taking the most
            # likely token at every position.
            for diffusion_index in range(start_step-1, -1, -1):
                trans_out = self.forward(
                    this_input,
                    return_loss=False,
                    return_logits=True,
                    return_att_weight=return_att_weight,
                    is_train=False,
                    diffusion_index=diffusion_index,
                    **kwargs)
                logits = trans_out['logits']
                content_token = torch.argmax(logits, dim=2)
                if verbose:
                    unchange_number = (this_input==content_token).sum().item()
                    unchange_rate_list.append(unchange_number/all_number)

                this_input = content_token

            output = {'content_token': content_token}
            if verbose:
                print("the unchange rate list is: ")
                print(unchange_rate_list)
        if return_logits:
            output['logits'] = logits
        return output




