class BaseNet(nn.Module):
    """Two-stream transformer base model conditioned on a timestep and a text or action label.

    ``forward`` splits its input into two streams (``x_a``, ``x_b``) along the last
    dimension, embeds each, and runs them through a stack of shared blocks. Each block
    receives one stream as its first argument and the other stream as its second
    argument. Finally ``forward`` projects both streams and concatenates the results.

    NOTE(review): ``forward`` and ``text_process`` read several attributes that this
    class never creates (``input_feats``, ``task``, ``mode``, ``label_emb``,
    ``time_emb``, ``motion_embed``, ``out``, ``text_emb_dim``, ``ch_hidden``,
    ``clip_transformer``, ``token_embedding``, ``positional_embedding``,
    ``ln_final``, ``clipTransEncoder``, ``clip_ln``, ``dtype``). Presumably a
    subclass sets them; confirm. Conversely, ``embed_timestep``, ``embed_text``,
    ``embed_history``, ``embed_noise`` and ``output_process`` are built in
    ``__init__`` but are not used by the methods of this class.
    """

    def __init__(self, h_dim=256, ff_size=1024, num_layers=4, num_heads=4, dropout=0.1, activation="gelu",
                 clip_dim=512, history_shape=(2, 276), noise_shape=(1, 128),
                 **kargs):
        """Build the input embeddings, the transformer stack and the output projection.

        Args:
            h_dim: model width shared by all embeddings and transformer blocks.
            ff_size: feed-forward width passed to each ``TransformerBlock``.
            num_layers: number of ``TransformerBlock``s in ``self.blocks``.
            num_heads: ``num_heads`` passed to each ``TransformerBlock``.
            dropout: dropout passed to the positional encoder and to each block.
            activation: stored on the instance; not used inside this class.
            clip_dim: input width of ``embed_text``.
            history_shape: its last entry is the input width of ``embed_history``.
            noise_shape: its last entry is the input width of ``embed_noise`` and
                the output width of ``output_process``.
            **kargs: only ``cond_mask_prob`` (default ``0.``) is read here.
        """
        super().__init__()
        self.h_dim = h_dim
        self.ff_size = ff_size
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.dropout = dropout
        self.activation = activation

        self.history_shape = history_shape
        self.noise_shape = noise_shape
        self.clip_dim = clip_dim

        self.cond_mask_prob = kargs.get('cond_mask_prob', 0.)

        # input embeddings
        self.sequence_pos_encoder = PositionalEncoding(self.h_dim, self.dropout)
        self.embed_timestep = TimestepEmbedder(self.h_dim, self.sequence_pos_encoder)
        self.embed_text = nn.Linear(self.clip_dim, self.h_dim)
        self.embed_history = nn.Linear(self.history_shape[-1], self.h_dim)
        self.embed_noise = nn.Linear(self.noise_shape[-1], self.h_dim)

        # Transformer stack. The blocks consume the positional encoder's output,
        # so their latent width must be the same h_dim used above.
        self.blocks = nn.ModuleList(
            TransformerBlock(num_heads=num_heads, latent_dim=self.h_dim, dropout=dropout, ff_size=ff_size)
            for _ in range(self.num_layers)
        )

        # output projection
        self.output_process = nn.Linear(self.h_dim, self.noise_shape[-1])

    def text_process(self, text):
        """Encode raw text with the CLIP text tower and return one vector per sample.

        The frozen CLIP part runs under ``torch.no_grad()``. ``clipTransEncoder`` and
        ``clip_ln`` run with gradients enabled. For each sample, the returned row is
        taken at the position of the largest token id (``text.argmax(dim=-1)``).

        Args:
            text: input accepted by ``clip.tokenize`` (truncated to CLIP's context).

        Returns:
            Tensor with one row per sample (first dimension = batch).
        """
        device = next(self.clip_transformer.parameters()).device

        with torch.no_grad():
            text = clip.tokenize(text, truncate=True).to(device)
            x = self.token_embedding(text).type(self.dtype)  # [batch_size, n_ctx, d_model]
            pe_tokens = x + self.positional_embedding.type(self.dtype)
            x = pe_tokens.permute(1, 0, 2)  # NLD -> LND
            x = self.clip_transformer(x)
            x = x.permute(1, 0, 2)  # LND -> NLD
            clip_out = self.ln_final(x).type(self.dtype)

        out = self.clipTransEncoder(clip_out)
        out = self.clip_ln(out)

        # Pick, per sample, the position holding the highest token id.
        cond = out[torch.arange(x.shape[0]), text.argmax(dim=-1)]

        return cond

    def mask_cond(self, cond, cond_mask_prob = 0.1, force_mask=False):
        """Zero out whole conditioning rows, either forcibly or at random.

        Args:
            cond: tensor whose first dimension is the batch.
            cond_mask_prob: per-sample probability of zeroing that sample's row.
            force_mask: if True, return an all-zero condition for every sample,
                regardless of ``cond_mask_prob``.

        Returns:
            Tensor with the same shape as ``cond``.
        """
        if force_mask:
            return torch.zeros_like(cond)
        if cond_mask_prob > 0.:
            bs = cond.shape[0]
            # One Bernoulli draw per sample, broadcast over all trailing dims.
            # 1 -> use null cond, 0 -> keep the real cond.
            mask = torch.bernoulli(torch.ones(bs, device=cond.device) * cond_mask_prob)
            mask = mask.view([bs] + [1] * len(cond.shape[1:]))
            return cond * (1. - mask)
        return cond

    def forward(self, x, label, t, mask=None):
        """Run both streams through the shared transformer stack.

        Args:
            x: ``(B, T, input_feats * 2)``. Stream A is ``x[..., :input_feats]`` and
                stream B is the remainder.
            label: for ``task == 'T2M'``, text passed to ``text_process``. For
                ``task == 'A2M'``, a tensor reshaped with ``view(-1)`` and passed to
                ``label_emb``. ``None`` means a zero condition.
            t: passed to ``self.time_emb``.
            mask: optional. ``mask[..., 0]`` is used as a ``(B, T)`` mask in which
                values ``> 0.5`` mark valid frames. ``None`` treats all frames as valid.

        Returns:
            ``torch.cat([out(h_a), out(h_b)], dim=-1)``
            (per the original comment: ``B, T, input_feats*2``).

        Raises:
            ValueError: if ``self.task`` is neither ``'T2M'`` nor ``'A2M'``.
        """
        B, T, _ = x.shape
        x_a, x_b = x[..., :self.input_feats], x[..., self.input_feats:]

        if self.task == 'T2M':
            if label is None:
                cond = torch.zeros((B, self.text_emb_dim), dtype=x.dtype, device=x.device)
            else:
                cond = self.text_process(label)
                if self.mode == "train":
                    cond = self.mask_cond(cond, 0.1)
            label_emb = self.label_emb(cond)
        elif self.task == 'A2M':
            if label is None:
                label_emb = torch.zeros((B, self.ch_hidden), dtype=x.dtype, device=x.device)
            else:
                label_emb = self.label_emb(label.view(-1))
            if self.mode == "train":
                label_emb = self.mask_cond(label_emb, 0.1)
        else:
            # Without this branch, label_emb would be unbound below.
            raise ValueError(f"Unsupported task {self.task!r}; expected 'T2M' or 'A2M'")

        # A single label embedding is broadcast across the batch (used for testing).
        if label_emb.shape[0] == 1:
            label_emb = label_emb.repeat(B, 1)

        emb = label_emb + self.time_emb(t)

        h_a = self.sequence_pos_encoder(self.motion_embed(x_a))
        h_b = self.sequence_pos_encoder(self.motion_embed(x_b))

        if mask is None:
            mask = torch.ones(B, T, device=x.device)
        else:
            mask = mask[..., 0]
        # True where the frame is not valid (mask <= 0.5).
        key_padding_mask = ~(mask > 0.5)

        for block in self.blocks:
            # Both right-hand calls see the previous layer's h_a and h_b, so the
            # two streams are updated symmetrically within each layer.
            h_a, h_b = (block(h_a, h_b, emb, key_padding_mask),
                        block(h_b, h_a, emb, key_padding_mask))

        output_a = self.out(h_a)
        output_b = self.out(h_b)

        return torch.cat([output_a, output_b], dim=-1)