import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import clip
import loralib as lora


class DenoiserMLP(nn.Module):
    """MLP denoiser conditioned on timestep, text embedding and motion history.

    The timestep embedding, the (optionally masked) text embedding, the
    flattened history and the flattened noisy sample are concatenated,
    projected to ``h_dim`` and decoded by an ``MLPBlock`` back to
    ``noise_shape``.
    """

    def __init__(self,
                 h_dim=512, n_blocks=2, dropout: float = 0.1, activation="gelu",
                 clip_dim=512, history_shape=(2, 276), noise_shape=(1, 128),
                 **kargs):
        """Build the network.

        Optional ``kargs``: ``cond_mask_prob`` (probability of dropping the
        text condition while training) and ``his_mask_prob`` (stored and
        printed only; this class never reads it afterwards).
        """
        super().__init__()
        # Architecture hyper-parameters.
        self.h_dim = h_dim
        self.dropout = dropout
        self.n_blocks = n_blocks
        self.activation = activation

        # Shapes of the conditioning inputs and of the denoised output.
        self.history_shape = history_shape
        self.noise_shape = noise_shape
        self.clip_dim = clip_dim

        # Condition-dropping probabilities (both are echoed to stdout).
        self.cond_mask_prob = kargs.get('cond_mask_prob', 0.)
        print('cond_mask_prob:', self.cond_mask_prob)
        self.his_mask_prob = kargs.get('his_mask_prob', 0.)
        print('his_mask_prob:', self.his_mask_prob)

        self.sequence_pos_encoder = PositionalEncoding(self.h_dim, self.dropout)
        self.embed_timestep = TimestepEmbedder(self.h_dim, self.sequence_pos_encoder)

        # Width of the concatenated [time | text | history | noise] vector.
        flat_history = np.prod(history_shape)
        flat_noise = np.prod(noise_shape)
        input_dim = self.h_dim + self.clip_dim + flat_history + flat_noise
        self.input_project = nn.Linear(input_dim, self.h_dim)

        self.mlp = MLPBlock(h_dim=h_dim,
                            out_dim=np.prod(noise_shape),
                            n_blocks=n_blocks,
                            actfun=activation)

    def parameters_wo_clip(self):
        """Return all parameters whose name does not start with ``'clip_model.'``."""
        kept = []
        for name, param in self.named_parameters():
            if name.startswith('clip_model.'):
                continue
            kept.append(param)
        return kept

    def mask_cond(self, cond, force_mask=False):
        """Drop the condition per sample.

        cond: [bs, d] (must be 2-D)
        Returns zeros when ``force_mask`` is set; while training with
        ``cond_mask_prob > 0`` each row is zeroed with that probability;
        otherwise ``cond`` is returned untouched.
        """
        batch, _ = cond.shape
        if force_mask:
            return torch.zeros_like(cond)
        if not (self.training and self.cond_mask_prob > 0.):
            return cond
        # 1 -> use null condition, 0 -> keep the real condition
        drop_prob = torch.ones(batch, device=cond.device) * self.cond_mask_prob
        drop = torch.bernoulli(drop_prob).view(batch, 1)
        return cond * (1. - drop)

    def forward(self, x_t, timesteps, y=None):
        """
        x_t: [B, T=1, D]
        timesteps: [batch_size] (int)
        y: dict providing 'history_motion_normalized', 'text_embedding' and
           optionally 'uncond' (forces the text condition to zero)
        returns: [B, noise_shape[0], noise_shape[1]]
        """
        n_samples = x_t.shape[0]
        time_part = self.embed_timestep(timesteps).squeeze(0)  # [bs, h_dim]
        history_part = y['history_motion_normalized'].reshape(
            n_samples, np.prod(self.history_shape))  # [bs, History * nfeats]
        text_part = self.mask_cond(y['text_embedding'],
                                   force_mask=y.get('uncond', False))  # [bs, clip_dim]
        noise_part = x_t.reshape(n_samples, np.prod(self.noise_shape))  # [bs, noise_dim]

        pieces = (time_part, text_part, history_part, noise_part)
        hidden = self.input_project(torch.cat(pieces, dim=1))  # [bs, h_dim]
        denoised = self.mlp(hidden)  # [bs, noise_dim]
        return denoised.reshape(n_samples, *self.noise_shape)

class DenoiserTransformer(nn.Module):
    """Transformer denoiser conditioned on timestep, text and motion history.

    The noisy sample is embedded, appended after the embedded history of the
    primary person and processed by a stack of ``TransformerBlock`` layers.
    Optional conditioning (enabled through constructor flags) covers an
    individual text embedding, the partner's history plus interaction text,
    previously generated latents and a global progress embedding.

    Recognised ``**kargs`` (all optional): ``use_indi_text``, ``attention_sep``,
    ``inter_first``, ``text_first``, ``cond_mask_prob``, ``his_mask_prob``,
    ``interaction_mask_prob``, ``shared_mask``, ``use_pre_latent``,
    ``pre_max_len``, ``use_extra_pe``, ``use_step_pe``, ``merge_his``,
    ``merge_partner_his``.
    """

    def __init__(self, h_dim=512, ff_size=1024, num_layers=4, num_heads=4, dropout=0.1, activation="gelu",
                 clip_dim=512, history_shape=(2, 276), noise_shape=(1, 128), dim_rel_pose=9, text_ca=False, text_sep=False, use_inter=True,
                 **kargs):
        super().__init__()
        self.h_dim = h_dim
        self.ff_size = ff_size
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.dropout = dropout
        self.activation = activation

        self.history_shape = history_shape
        self.noise_shape = noise_shape
        self.clip_dim = clip_dim
        # Stored only; the pre-latent embedding below uses its own constant.
        self.dim_rel_pose = dim_rel_pose

        self.text_ca = text_ca
        self.text_sep = text_sep
        self.use_inter = use_inter
        self.use_indi_text = kargs.get('use_indi_text', False)
        self.attention_sep = kargs.get('attention_sep', False)

        self.inter_first = kargs.get('inter_first', True)
        self.text_first = kargs.get('text_first', True)

        # Condition-dropping probabilities. A negative his/interaction value
        # additionally exempts that condition from forced ('uncond') masking.
        self.cond_mask_prob = kargs.get('cond_mask_prob', 0.)
        self.his_mask_prob = kargs.get('his_mask_prob', 0.)
        self.interaction_mask_prob = kargs.get('interaction_mask_prob', 0.)
        self.shared_mask = kargs.get('shared_mask', True)

        self.use_pre_latent = kargs.get('use_pre_latent', False)

        self.use_extra_pe = kargs.get('use_extra_pe', False)
        self.use_step_pe = kargs.get('use_step_pe', True)
        self.merge_his = kargs.get('merge_his', True)
        self.merge_partner_his = kargs.get('merge_partner_his', True)

        # input embeddings (creation order is kept stable so that seeded
        # initialisation is reproducible)
        self.sequence_pos_encoder = PositionalEncoding(self.h_dim, self.dropout)
        self.embed_timestep = TimestepEmbedder(self.h_dim, self.sequence_pos_encoder)
        if self.use_extra_pe:
            self.embed_global = RatioFourierEncoder(self.h_dim)
        if self.use_indi_text:
            self.embed_text = nn.Linear(self.clip_dim, self.h_dim)
        self.embed_history = nn.Linear(self.history_shape[-1], self.h_dim)
        self.embed_noise = nn.Linear(self.noise_shape[-1], self.h_dim)
        if self.use_pre_latent:
            dim_reltrans = 9
            self.pre_max_len = kargs.get('pre_max_len', 10)
            self.embed_pre_latent = nn.Linear(self.noise_shape[-1]+dim_reltrans, self.h_dim)
            self.pre_pos_encoder = PositionalEncoding(self.h_dim, self.dropout)
        if self.use_inter:
            self.embed_history_b = nn.Linear(self.history_shape[-1], self.h_dim)
            self.embed_text_inter = nn.Linear(self.clip_dim, self.h_dim)
            # self.zero_init_linear(self.embed_history_b)
            # self.zero_init_linear(self.embed_text_inter)

        self.blocks = nn.ModuleList()
        for _ in range(self.num_layers):
            self.blocks.append(TransformerBlock(latent_dim=h_dim,
                                                num_heads=num_heads,
                                                dropout=dropout,
                                                ff_size=ff_size,
                                                activation=self.activation,
                                                text_ca=self.text_ca,
                                                text_sep=self.text_sep,
                                                use_inter=self.use_inter,
                                                inter_first=self.inter_first,
                                                text_first=self.text_first,
                                                use_indi_text=self.use_indi_text,
                                                attention_sep=self.attention_sep,
                                                use_pre_latent=self.use_pre_latent,
                                                merge_his=self.merge_his,
                                                merge_partner_his=self.merge_partner_his,))
        # output projection
        self.output_process = nn.Linear(self.h_dim, self.noise_shape[-1])

    def zero_init_linear(self, linear_layer):
        """Set the weight (and bias, if present) of ``linear_layer`` to zero."""
        nn.init.zeros_(linear_layer.weight)
        if linear_layer.bias is not None:
            nn.init.zeros_(linear_layer.bias)

    def parameters_wo_clip(self):
        """Return all parameters whose name does not start with ``'clip_model.'``."""
        return [p for name, p in self.named_parameters() if not name.startswith('clip_model.')]

    def mask_cond(self, cond, mask_prob, force_mask=False):
        """Sample a per-sample drop mask for a condition.

        Only the batch size and device of ``cond`` are used.
        Returns a float tensor of shape [bs, 1, 1]: 1 -> use the null
        condition, 0 -> keep the real condition. ``force_mask`` drops every
        sample; otherwise samples are dropped with ``mask_prob`` while
        training and never dropped in eval mode.
        """
        bs = cond.shape[0]
        if force_mask:
            mask = torch.ones((bs, 1), device=cond.device)
        elif self.training and mask_prob > 0.:
            mask = torch.bernoulli(torch.ones((bs, 1), device=cond.device) * mask_prob)  # 1-> use null_cond, 0-> use real cond
        else:
            mask = torch.zeros((bs, 1), device=cond.device)
        mask = mask.unsqueeze(-1)
        return mask

    @staticmethod
    def _apply_mask(tensor, mask):
        """Zero the samples of ``tensor`` that ``mask`` marks as dropped.

        ``mask`` holds one value per sample (as produced by ``mask_cond``). It
        is reshaped to ``[bs, 1, ..., 1]`` with the rank of ``tensor`` so the
        product never broadcasts a 2-D ``[bs, D]`` tensor up to ``[bs, bs, D]``.
        """
        per_sample = mask.reshape(mask.shape[0], *([1] * (tensor.dim() - 1)))
        return tensor * (1. - per_sample)

    @staticmethod
    def _resolve_padding_mask(mask, emb, device):
        """Return ``mask`` unchanged, or an all-False bool mask shaped like ``emb.shape[:-1]``."""
        if mask is None:
            return torch.zeros(emb.shape[:-1], dtype=torch.bool, device=device)
        return mask

    def forward(self, x_t, timesteps, y=None):
        """
        x_t: [B, T=1, D]
        timesteps: [batch_size] (int)
        y: dict. Required: 'history_motion_normalized'. Read when present /
           enabled: 'text_embedding', 'text_embedding_inter', 'text_mask',
           'text_mask_inter', 'primitive_padding_mask',
           'history_motion_normalized_b', 'pre_latent', 'pre_reltrans',
           'uncond', 'start_frame', 'total_frames'.
        returns: [B, noise_shape[0], noise_shape[-1]]
        """
        bs = x_t.shape[0]
        device = x_t.device
        mask = y.get('primitive_padding_mask', None)

        text_padding_mask = y.get('text_mask', None)
        text_padding_mask_inter = y.get('text_mask_inter', None)

        # timestep
        emb_time = self.embed_timestep(timesteps).squeeze(0)  # [1, bs, h_dim] -> [bs, h_dim]

        # text conditions
        text_indi = y.get('text_embedding', None)
        text_inter = y.get('text_embedding_inter', None)
        if text_indi is None:
            text_indi = torch.zeros_like(y['text_embedding_inter']).to(device)
        if self.text_ca and not self.text_sep:
            # Cross-attention consumes a token axis: [bs, D] -> [bs, 1, D].
            if self.use_indi_text:
                text_indi = text_indi.unsqueeze(1)
            if text_inter is not None:
                text_inter = text_inter.unsqueeze(1)

        # history
        history_a = y['history_motion_normalized']  # presumably [bs, his_length, nfeats] — TODO confirm
        history_b = y.get('history_motion_normalized_b', None)
        if history_b is None:
            history_b = torch.zeros_like(history_a).to(device)

        # pre latent: embed, then crop/pad to exactly pre_max_len tokens
        emb_pre_latent = None
        pre_latent_padding_mask = None
        if self.use_pre_latent:
            pre_latent = y.get('pre_latent', None)
            pre_reltrans = y.get('pre_reltrans', None)
            if pre_latent is None:
                # No context: every slot is padding.
                emb_pre_latent = torch.zeros((bs, self.pre_max_len, self.h_dim), device=device, dtype=x_t.dtype)
                pre_latent_padding_mask = torch.ones((bs, self.pre_max_len), dtype=torch.bool, device=device)
            else:
                emb_pre_latent = self.embed_pre_latent(torch.cat([pre_latent, pre_reltrans], dim=-1))
                emb_pre_latent = self.pre_pos_encoder(emb_pre_latent.permute(1, 0, 2)).permute(1, 0, 2)
                B, L, D = emb_pre_latent.shape
                if L >= self.pre_max_len:
                    # Keep the most recent pre_max_len entries.
                    emb_pre_latent = emb_pre_latent[:, -self.pre_max_len:, :]
                    L = self.pre_max_len
                else:
                    padded_latent = torch.zeros(B, self.pre_max_len-L, D, device=emb_pre_latent.device, dtype=emb_pre_latent.dtype)
                    emb_pre_latent = torch.cat([emb_pre_latent, padded_latent], dim=1)  # [B, pre_max_len, h_dim]
                pre_latent_padding_mask = torch.ones(B, self.pre_max_len, dtype=torch.bool, device=emb_pre_latent.device)
                pre_latent_padding_mask[:, :L] = False

        # masking
        force_mask = y.get('uncond', False)
        if self.shared_mask:
            # One mask drops every condition of a sample together.
            cond_mask = self.mask_cond(history_a, self.cond_mask_prob, force_mask)
            emb_history_a = self.embed_history(self._apply_mask(history_a, cond_mask))    # [bs, his_length, h_dim]
            if self.use_indi_text:
                emb_text = self.embed_text(self._apply_mask(text_indi, cond_mask))
            if self.use_inter:
                emb_history_b = self.embed_history_b(self._apply_mask(history_b, cond_mask))    # [bs, his_length, h_dim]
                emb_text_inter = self.embed_text_inter(self._apply_mask(text_inter, cond_mask))
        else:
            # Independent masks; the sampling order is kept for RNG reproducibility.
            text_mask = self.mask_cond(text_indi, self.cond_mask_prob, force_mask)
            if self.use_indi_text:
                emb_text = self.embed_text(self._apply_mask(text_indi, text_mask))
            force_mask_his = False if self.his_mask_prob < 0 else force_mask
            his_mask = self.mask_cond(history_a, self.his_mask_prob, force_mask_his)
            emb_history_a = self.embed_history(self._apply_mask(history_a, his_mask))             # [bs, his_length, h_dim]
            if self.use_inter:
                force_mask_inter = False if self.interaction_mask_prob < 0 else force_mask
                inter_mask = self.mask_cond(history_b, self.interaction_mask_prob, force_mask_inter)
                emb_history_b = self.embed_history_b(self._apply_mask(history_b, inter_mask))           # [bs, his_length, h_dim]
                emb_text_inter = self.embed_text_inter(self._apply_mask(text_inter, inter_mask))

        if getattr(self, 'use_extra_pe', False):
            start = y['start_frame'].to(device=device, dtype=torch.long)   # (B,)
            total_frame = y['total_frames'].to(device, dtype=torch.long)
            emb_global = self.embed_global((start/total_frame).unsqueeze(-1))
            force_mask_global = False if self.his_mask_prob < 0 else force_mask
            global_mask = self.mask_cond(emb_global, self.his_mask_prob, force_mask_global)
            emb_global = self._apply_mask(emb_global, global_mask)
        else:
            emb_global = None

        # condition: timestep, plus pooled text when text is not cross-attended
        if self.text_ca:
            emb = emb_time
        elif self.use_indi_text:
            emb = emb_time + emb_text
        elif self.use_inter:
            emb = emb_time + emb_text_inter
        else:
            # No text branch is enabled at all: condition on the timestep only.
            emb = emb_time

        # future motion
        x_emb = self.embed_noise(x_t)

        # concat history
        x_emb = torch.cat([emb_history_a, x_emb], dim=1)
        if self.use_step_pe:
            h_a_prev = self.sequence_pos_encoder(x_emb.permute(1, 0, 2)).permute(1, 0, 2)
        else:
            h_a_prev = x_emb

        key_padding_mask = self._resolve_padding_mask(mask, h_a_prev, device)
        if self.use_indi_text:
            text_padding_mask = self._resolve_padding_mask(text_padding_mask, emb_text, device)

        inter_dict = None
        inter_pad = None
        if self.use_inter:
            text_padding_mask_inter = self._resolve_padding_mask(text_padding_mask_inter, emb_text_inter, device)
            inter_dict = {
                'text_chunks_emb': emb_text_inter,
                'motion_hist_emb': emb_history_b,
            }
            inter_pad = {
                'hist': torch.zeros(*emb_history_b.shape[:-1], dtype=torch.bool, device=device),
                'inter_text': text_padding_mask_inter,
            }

        # Individual text reaches the blocks only when it is cross-attended.
        text_tokens = emb_text if self.use_indi_text and self.text_ca else None

        h_a = h_a_prev
        for block in self.blocks:
            h_a = block(h_a, emb, emb_global, text_tokens, key_padding_mask, text_padding_mask, inter_dict, inter_pad, emb_pre_latent, pre_latent_padding_mask)

        # Only the trailing noise_shape[0] tokens correspond to the noisy input.
        output = self.output_process(h_a[:, -self.noise_shape[0]:])

        return output

class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding followed by dropout.

    The table is stored in the buffer ``pe`` with shape
    ``[max_len, 1, d_model]``: even feature columns hold sines, odd columns
    hold cosines. Both even and odd ``d_model`` are supported.
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
        angles = position * div_term  # [max_len, ceil(d_model / 2)]

        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one cosine column fewer than sine columns,
        # so the last frequency is dropped instead of raising a shape error.
        pe[:, 1::2] = torch.cos(angles[:, :d_model // 2])
        pe = pe.unsqueeze(1)  # [max_len, 1, d_model]

        self.register_buffer('pe', pe)

    def forward(self, x, start=0, positions=None):
        """Add positional encodings to ``x`` and apply dropout.

        x: [T, B, d_model] (sequence-first)
        start: offset of the first position when ``positions`` is None
        positions: optional explicit integer indices into the table; the
            indexed result has its dim 2 squeezed, so indices are expected to
            be shaped like [T, B] (or broadcastable to ``x``).
        """
        # Original author's note: "not used in the final model" — presumably
        # refers to the explicit-`positions` path; TODO confirm.
        if positions is None:
            x = x + self.pe[start:start + x.shape[0]]
        else:
            pos = torch.as_tensor(positions, dtype=torch.long, device=self.pe.device)
            x = x + self.pe[pos].squeeze(2)                      # [T,B,1,h] -> [T,B,h]
        return self.dropout(x)

class TimestepEmbedder(nn.Module):
    """Map integer diffusion timesteps to learned embeddings.

    Each timestep indexes the ``pe`` table of the supplied positional encoder;
    the looked-up row is then refined by a Linear -> SiLU -> Linear stack.
    """

    def __init__(self, h_dim, sequence_pos_encoder):
        super().__init__()
        self.h_dim = h_dim
        # Only its `pe` table is read in forward().
        self.sequence_pos_encoder = sequence_pos_encoder

        width = self.h_dim
        stack = [
            nn.Linear(self.h_dim, width),
            nn.SiLU(),
            nn.Linear(width, width),
        ]
        self.time_embed = nn.Sequential(*stack)

    def forward(self, timesteps):
        """Return the embedding of ``timesteps`` with the first two axes swapped.

        With a ``pe`` table of shape [max_len, 1, h_dim] and ``timesteps`` of
        shape [bs], the result is [1, bs, h_dim].
        """
        table = self.sequence_pos_encoder.pe
        looked_up = table[timesteps]
        refined = self.time_embed(looked_up)
        return refined.permute(1, 0, 2)

class RatioFourierEncoder(nn.Module):
    """Encode a scalar ratio with Fourier features followed by a small MLP.

    ``num_freq`` frequencies are spaced linearly between ``f_min`` and
    ``f_max`` and scaled by 2*pi; they live in the non-persistent buffer
    ``freqs`` (so they are not part of the state dict).
    """

    def __init__(self, out_dim, num_freq=16, f_min=1.0, f_max=16.0, dropout=0.1):
        super().__init__()
        angular = torch.linspace(f_min, f_max, num_freq) * (2*3.1415926)
        self.register_buffer("freqs", angular, persistent=False)

        hidden = out_dim*2
        self.proj = nn.Sequential(
            nn.Linear(num_freq*2, hidden),
            nn.SiLU(),
            nn.Linear(hidden, out_dim),
        )
        # Stored as a plain float; forward() does not apply any dropout.
        self.dropout = dropout

    def forward(self, r):
        """
        r: (..., 1) ratio; values are clamped to [1e-6, 1 - 1e-6]
        returns: (..., out_dim)
        """
        ratio = r.clamp(1e-6, 1-1e-6)
        # Align the frequency vector with the trailing axis of `ratio`.
        leading_ones = [1] * (ratio.dim() - 1)
        phase = ratio * self.freqs.view(*leading_ones, -1)
        features = torch.cat([torch.sin(phase), torch.cos(phase)], dim=-1)
        return self.proj(features)

class MLP(nn.Module):
    """Stack of fully connected layers, each followed by the same activation.

    Args:
        in_dim: size of the input features.
        h_dims: output size of each layer, in order; the last entry is also
            exposed as ``out_dim``.
        activation: one of 'tanh', 'relu', 'sigmoid', 'gelu', 'lrelu'.
        use_lora: build the layers with ``lora.Linear`` instead of ``nn.Linear``.
        lora_rank: rank passed to ``lora.Linear`` when ``use_lora`` is set.

    Raises:
        ValueError: if ``activation`` is not one of the supported names.
    """

    # Name -> zero-argument factory returning the activation callable.
    _ACTIVATIONS = {
        'tanh': lambda: torch.tanh,
        'relu': lambda: torch.relu,
        'sigmoid': lambda: torch.sigmoid,
        'gelu': torch.nn.GELU,
        'lrelu': torch.nn.LeakyReLU,
    }

    def __init__(self, in_dim,
                h_dims=(128, 128), activation='tanh', use_lora=False, lora_rank=16):
        super().__init__()
        # Fail at construction time rather than with an AttributeError on the
        # first forward pass.
        if activation not in self._ACTIVATIONS:
            raise ValueError(
                f"Unsupported activation {activation!r}; "
                f"expected one of {sorted(self._ACTIVATIONS)}")
        self.activation = self._ACTIVATIONS[activation]()

        h_dims = tuple(h_dims)
        self.out_dim = h_dims[-1]
        self.layers = nn.ModuleList()
        in_dim_ = in_dim
        for h_dim in h_dims:
            layer = lora.Linear(in_dim_, h_dim, r=lora_rank) if use_lora else nn.Linear(in_dim_, h_dim)
            self.layers.append(layer)
            in_dim_ = h_dim

    def forward(self, x):
        """Apply every layer followed by the activation (the last layer included)."""
        for fc in self.layers:
            x = self.activation(fc(x))
        return x

class MLPBlock(nn.Module):
    """Sequence of two-layer ``MLP`` blocks with optional skip connections,
    followed by a linear output projection.

    NOTE(review): ``use_lora`` / ``lora_rank`` only affect ``out_fc``; they
    are not forwarded to the inner ``MLP`` blocks — confirm this is intended.
    """

    def __init__(self, h_dim, out_dim, n_blocks, actfun='relu', residual=True, use_lora=False, lora_rank=16):
        super().__init__()
        self.residual = residual

        # Each inner block keeps the width at h_dim (two fc layers per block).
        stages = []
        for _ in range(n_blocks):
            stages.append(MLP(h_dim, h_dims=(h_dim, h_dim), activation=actfun))
        self.layers = nn.ModuleList(stages)

        if use_lora:
            self.out_fc = lora.Linear(h_dim, out_dim, r=lora_rank)
        else:
            self.out_fc = nn.Linear(h_dim, out_dim)

    def forward(self, x):
        """Run the blocks (adding the block input back when ``residual``) and project."""
        hidden = x
        for stage in self.layers:
            if self.residual:
                skip = hidden
            else:
                skip = 0
            hidden = stage(hidden) + skip
        return self.out_fc(hidden)

def zero_module(module):
    """
    Set every parameter of `module` to zero in place and hand the same
    module back, so the call can wrap a constructor expression.
    """
    with torch.no_grad():
        for param in module.parameters():
            param.zero_()
    return module

class AdaLN(nn.Module):
    """Adaptive layer norm: a parameter-free LayerNorm whose scale and shift
    are predicted from a conditioning embedding.

    The predicting linear layer is zero-initialised, so a freshly built
    module reduces to the plain LayerNorm.
    """

    def __init__(self, latent_dim, embed_dim=None):
        super().__init__()
        # The conditioning width defaults to the feature width.
        cond_dim = latent_dim if embed_dim is None else embed_dim
        modulation = zero_module(nn.Linear(cond_dim, 2 * latent_dim, bias=True))
        self.emb_layers = nn.Sequential(nn.SiLU(), modulation)
        self.norm = nn.LayerNorm(latent_dim, elementwise_affine=False, eps=1e-6)

    def forward(self, h, emb):
        """
        h: B, T, D
        emb: B, D
        returns: B, T, D
        """
        # [B, 2D] -> scale [B, D] and shift [B, D]
        scale, shift = self.emb_layers(emb).chunk(2, dim=-1)
        normed = self.norm(h)
        # A singleton time axis lets the modulation apply to every token.
        return normed * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)

class VanillaSelfAttention(nn.Module):
    """Multi-head self-attention applied to an AdaLN-modulated input.

    Returns the attention output only; the residual connection is left to
    the caller.
    """

    def __init__(self, latent_dim, num_head, dropout, embed_dim=None):
        super().__init__()
        self.num_head = num_head
        self.norm = AdaLN(latent_dim, embed_dim)
        self.norm_g = AdaLN(latent_dim, embed_dim)
        self.attention = nn.MultiheadAttention(
            latent_dim,
            num_head,
            dropout=dropout,
            batch_first=True,
            add_zero_attn=True,
        )

    def forward(self, x, emb, emb_g, key_padding_mask=None):
        """
        x: B, T, D
        emb: conditioning for the first AdaLN
        emb_g: optional second conditioning; when given, a second AdaLN is
               applied on top of the first
        key_padding_mask: forwarded unchanged to the attention layer
        """
        tokens = self.norm(x, emb)
        if emb_g is not None:
            tokens = self.norm_g(tokens, emb_g)
        attended, _ = self.attention(
            tokens, tokens, tokens,
            attn_mask=None,
            key_padding_mask=key_padding_mask,
            need_weights=False,
        )
        return attended

class CrossAttention(nn.Module):
    """Multi-head cross-attention from AdaLN-modulated queries to
    AdaLN-modulated context tokens.

    Returns the attention output only; the residual connection is left to
    the caller.
    """

    def __init__(self, latent_dim, xf_latent_dim, num_head, dropout, embed_dim=None):
        super().__init__()
        self.num_head = num_head
        # Query-side norms.
        self.norm = AdaLN(latent_dim, embed_dim)
        self.norm_g = AdaLN(latent_dim, embed_dim)
        # Context-side norms.
        self.xf_norm = AdaLN(xf_latent_dim, embed_dim)
        self.xf_norm_g = AdaLN(xf_latent_dim, embed_dim)
        self.attention = nn.MultiheadAttention(
            latent_dim,
            num_head,
            kdim=xf_latent_dim,
            vdim=xf_latent_dim,
            dropout=dropout,
            batch_first=True,
            add_zero_attn=True,
        )

    def forward(self, x, xf, emb, emb_g, key_padding_mask=None, query_padding_mask=None):
        """
        x: B, T, D   (queries)
        xf: B, N, L  (context used as keys and values)
        key_padding_mask: forwarded unchanged to the attention layer
        query_padding_mask: optional B, T mask; a marked query row is
            expanded so that it masks every context position for that query
        """
        attn_mask = None
        if query_padding_mask is not None:
            n_query, n_context = x.shape[1], xf.shape[1]
            # [B, T] -> [B, heads, T, N] -> [B * heads, T, N]
            per_head = query_padding_mask[:, None, :, None].expand(-1, self.num_head, -1, n_context)
            attn_mask = per_head.reshape(-1, n_query, n_context)

        query = self.norm(x, emb)
        if emb_g is not None:
            query = self.norm_g(query, emb_g)

        context = self.xf_norm(xf, emb)
        if emb_g is not None:
            context = self.xf_norm_g(context, emb_g)

        attended, _ = self.attention(
            query, context, context,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask,
            need_weights=False,
        )
        return attended

class FFN(nn.Module):
    """Position-wise feed-forward layer with optional AdaLN conditioning.

    The second linear layer is zero-initialised, so a freshly built module
    outputs zeros. The residual connection is left to the caller.

    Args:
        latent_dim: feature size of the input and output.
        ffn_dim: hidden size.
        dropout: dropout probability applied after the activation.
        activation: one of 'tanh', 'relu', 'sigmoid', 'gelu', 'lrelu'.
        embed_dim: conditioning width handed to the AdaLN layers.

    Raises:
        ValueError: if ``activation`` is not one of the supported names.
    """

    # Name -> zero-argument factory returning the activation callable.
    _ACTIVATIONS = {
        'tanh': lambda: torch.tanh,
        'relu': lambda: torch.relu,
        'sigmoid': lambda: torch.sigmoid,
        'gelu': nn.GELU,
        'lrelu': nn.LeakyReLU,
    }

    def __init__(self, latent_dim, ffn_dim, dropout, activation, embed_dim=None):
        super().__init__()
        # Validate first: previously an unknown name left `self.activation`
        # unset and only failed on the first forward pass.
        if activation not in self._ACTIVATIONS:
            raise ValueError(
                f"Unsupported activation {activation!r}; "
                f"expected one of {sorted(self._ACTIVATIONS)}")
        self.norm = AdaLN(latent_dim, embed_dim)
        self.norm_g = AdaLN(latent_dim, embed_dim)
        self.linear1 = nn.Linear(latent_dim, ffn_dim, bias=True)
        self.linear2 = zero_module(nn.Linear(ffn_dim, latent_dim, bias=True))
        self.activation = self._ACTIVATIONS[activation]()
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, emb=None, emb_g=None):
        """Apply the optional AdaLN stages, then linear -> activation -> dropout -> linear.

        x: input features; emb / emb_g: optional conditionings, each enabling
        its own AdaLN stage when not None.
        """
        x_norm = x
        if emb is not None:
            x_norm = self.norm(x_norm, emb)
        if emb_g is not None:
            x_norm = self.norm_g(x_norm, emb_g)
        y = self.linear2(self.dropout(self.activation(self.linear1(x_norm))))
        return y

class TransformerBlock(nn.Module):
    def __init__(self,
                 latent_dim=256,
                 num_heads=4,
                 ff_size=1024,
                 dropout=0.,
                 activation="relu", 
                 text_ca=False,
                 text_sep=False,
                 use_inter=False,
                 text_first=True,
                 **kargs):
        super().__init__()
        self.latent_dim = latent_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.text_ca = text_ca
        self.text_sep = text_sep
        self.use_inter = use_inter
        self.text_first = text_first
        self.use_indi_text = kargs.get('use_indi_text', True)
        self.attention_sep = kargs.get('attention_sep', False)
        self.use_pre_latent = kargs.get('use_pre_latent', False)
        self.merge_his = kargs.get('merge_his', True)
        self.merge_partner_his = kargs.get('merge_partner_his', True)

        self.sa_block = VanillaSelfAttention(latent_dim, num_heads, dropout)
        if self.use_pre_latent:
            self.pre_latent_ca_block = CrossAttention(latent_dim, latent_dim, num_heads, dropout, latent_dim)
        if self.use_indi_text and self.text_ca:
            self.text_ca_block = CrossAttention(latent_dim, latent_dim, num_heads, dropout, latent_dim)
        if self.use_inter:
            if self.attention_sep:
                self.his_ca = CrossAttention(latent_dim, latent_dim, num_heads, dropout, latent_dim)
                if self.text_ca:
                    self.inter_text_ca = CrossAttention(latent_dim, latent_dim, num_heads, dropout, latent_dim)
            else:
                self.inter_ca = CrossAttention(latent_dim, latent_dim, num_heads, dropout, latent_dim)
            self.inter_pe = PositionalEncoding(latent_dim, dropout)
            self.control_proj = nn.Linear(latent_dim, latent_dim)
            
        if self.text_sep:
            self.text_pe = PositionalEncoding(latent_dim, dropout)
        
        self.ffn = FFN(latent_dim=latent_dim, 
                       ffn_dim=ff_size, 
                       dropout=dropout,
                       activation=activation, 
                       embed_dim=latent_dim)

    def apply_text_ca(self, x, text_emb, emb, emb_g, text_padding_mask, key_padding_mask):
        if self.text_sep:
            text_emb = self.text_pe(text_emb.permute(1, 0, 2)).permute(1, 0, 2)
        x = x + self.text_ca_block(x, text_emb, emb, emb_g, text_padding_mask, key_padding_mask)
        return x

    def apply_interca(self, x, emb, emb_g, inter_dict, inter_pad, key_padding_mask):
        motion_b = inter_dict['motion_hist_emb']              # [B, T, D]
        text_inter = inter_dict['text_chunks_emb']            # [B, L, D]
        if self.attention_sep:
            if self.text_first:
                if self.text_ca:
                    if self.text_sep:
                        text_inter = self.text_pe(text_inter.permute(1, 0, 2)).permute(1, 0, 2)
                    h = self.inter_text_ca(x, text_inter, emb, emb_g, inter_pad['inter_text'], key_padding_mask)
                else:
                    h = x
                if self.merge_partner_his:
                    h = self.his_ca(h, motion_b, emb, emb_g, inter_pad['hist'], key_padding_mask)
            else:
                if self.merge_partner_his:
                    h = self.his_ca(x, motion_b, emb, emb_g, inter_pad['hist'], key_padding_mask)
                else:
                    h = x
                if self.text_ca:
                    h = self.inter_text_ca(h, text_inter, emb, emb_g, inter_pad['inter_text'], key_padding_mask)
        else:
            if self.text_ca:
                emb_inter = torch.cat([motion_b, text_inter], dim=1)  # [B, T+L, D]
                emb_inter = self.inter_pe(emb_inter.permute(1, 0, 2)).permute(1, 0, 2)  # [B, T+L, D]
                inter_padding_mask = torch.cat([
                    inter_pad['hist'], 
                    inter_pad['inter_text']
                ], dim=1)
                h = self.inter_ca(x, emb_inter, emb, emb_g, inter_padding_mask, key_padding_mask)
            else:
                emb_inter = motion_b  # [B, T, D]
                emb_inter = self.inter_pe(emb_inter.permute(1, 0, 2)).permute(1, 0, 2)  # [B, T+1+L, D]
                inter_padding_mask = inter_pad['hist']
                h = self.inter_ca(x, emb_inter, emb, emb_g, inter_padding_mask, key_padding_mask)
        h = self.control_proj(h)
        return h
    
    def forward(self, x, emb=None, emb_g=None, text_emb=None, key_padding_mask=None, text_padding_mask=None, inter_dict=None, inter_pad=None, emb_pre_latent=None, pre_latent_padding_mask=None):
        if self.use_indi_text and self.text_ca and self.text_first:
            x = self.apply_text_ca(x, text_emb, emb, emb_g, text_padding_mask, key_padding_mask)

        if self.merge_his:
            h1 = self.sa_block(x, emb, emb_g, key_padding_mask)
            x = h1 + x
        
        if self.use_pre_latent and emb_pre_latent is not None:
            x = x + self.pre_latent_ca_block(x, emb_pre_latent, emb, emb_g, pre_latent_padding_mask, key_padding_mask)

        if self.use_indi_text and self.text_ca and not self.text_first:
            x = self.apply_text_ca(x, text_emb, emb, emb_g, text_padding_mask, key_padding_mask)

        if self.use_inter and inter_dict is not None:
            control_res = self.apply_interca(x, emb, emb_g, inter_dict, inter_pad, key_padding_mask)
            x = x + control_res
        
        out = self.ffn(x, emb, emb_g)
        out = out + x
        return out
