from time import time
import numpy as np
import torch
import torch.nn as nn


class AllPairsNoisyEvenFasterBlackBoxNet(nn.Module):
    def __init__(
        self,
        hidden_size=512,
        in_size=1000,
        girth=512,
        num_digits=21,
        dbase=10.0,
        device="cpu",
    ):
        super().__init__()
        _in_size = num_digits * (in_size + 0)
        self.mlp = nn.Sequential(
            nn.Linear(_in_size, girth),
            nn.RMSNorm(girth),
            nn.GELU(),
            # nn.Linear(girth, girth),
            # nn.RMSNorm(girth),
            # nn.GELU(),
            nn.Linear(girth, hidden_size),
            nn.RMSNorm(hidden_size),
            nn.GELU(),
        )
        self.relation = nn.Sequential(
            nn.Linear(hidden_size, girth),
            nn.RMSNorm(girth),
            nn.GELU(),
            nn.Linear(girth, girth),
            nn.RMSNorm(girth),
            nn.GELU(),
            nn.Linear(girth, girth),
            nn.RMSNorm(girth),
            nn.GELU(),
            nn.Linear(girth, 1)
        )
        assert num_digits > 0
        self.num_digits = num_digits
        pos_digits = (num_digits // 2) + (num_digits % 2)
        neg_digits = num_digits // 2
        self._dbase = dbase
        self._fdt = (self._dbase ** torch.arange(-neg_digits, pos_digits))
        self._fdt = self._fdt.reshape(1, -1)
        self._fdt = self._fdt.float().to(device)

    def _add_noise(self, x):
        """ x.shape: (B, D)
        """
        # noise_level = np.random.choice([0, 1e-3, 1e-2, 1e-1])
        noise_level = 1e-1
        if noise_level == 0:
            return x
        std = noise_level * ((x ** 2).mean(dim=-1) ** 0.5)
        std = std.unsqueeze(-1)
        noise = torch.normal(
            torch.zeros(x.shape).to(x.device), std
        ).to(x.device)
        x = x + noise
        return x

    def _to_digits(self, x):
        """ x.shape: (B, D)
        """
        B, D = x.shape
        x = x.reshape(-1, 1) @ self._fdt
        x = x.sign() * (x.abs() % self._dbase)
        x = x.reshape(B, D * self.num_digits) # (B, D * d)
        return x

    def _digit_norm(self, x):
        """ x.shape: (B, D)
        """
        x = x / self._dbase
        return x

    def encode(self, x):
        """ x.shape: (B, D)
        """
        x = self._to_digits(x)
        x = self._digit_norm(x)
        z = self.mlp(x) # (B, clen + 1)
        return z

    def run_tests(self, z):
        """ z.shape: (B, clen + 1)
        """
        z = z.unsqueeze(1) # (B, 1, clen + 1)
        z = (z - z.permute(1, 0, 2)) # (B, B, clen + 1)
        tests = self.relation(z)[:, :, 0] # (B, B)
        return tests

    def tgt_run_tests(self, z):
        """ z.shape: (B, clen + 1)
        """
        z = (z.clone()[-1:] - z) # (B, clen + 1)
        tests = self.relation(z)[:, 0] # (B,)
        return tests

    def forward(self, x):
        """ x.shape: (B, D)
        """
        z = self.encode(x) # (B, clen + 1)
        tests = self.run_tests(z) # (B, B)
        tests = nn.functional.sigmoid(tests) # (B, B)
        return tests

    def tgt_forward(self, x):
        """ x.shape: (B, D)
            x[-1] is the only 'target' element
        """
        z = self.encode(x) # (B, clen + 1)
        tests = self.tgt_run_tests(z) # (B,)
        # tests = nn.functional.sigmoid(tests) # (B,)
        return tests

    def predict(self, x, th=0.5):
        test = self.forward(x) # (B, B)
        test = (test >= th).long() # (B, B)
        return test

class ClfProbe(nn.Module):
    """Linear probe that pools a sequence down to one logit per batch item.

    Args:
        in_size: feature size ``D`` of the input sequence.
        aggregation: ``"last"`` keeps the final position, ``"first"`` keeps
            position 0, and any other value averages over the sequence axis.
    """

    def __init__(self, in_size, aggregation="last"):
        super().__init__()
        self.agg = aggregation
        self.linear = nn.Linear(in_size, 1)

    def forward(self, x):
        """Pool ``x`` of shape (B, L, D) and return logits of shape (B,)."""
        mode = self.agg
        if mode == "first":
            pooled = x.clone()[:, 0]
        elif mode == "last":
            pooled = x.clone()[:, -1]
        else:
            pooled = x.mean(dim=1)
        logits = self.linear(pooled)  # (B, 1)
        return logits.squeeze(-1)

class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional-encoding table of shape (max_len, 1, d_model).

    Even feature indices hold ``sin(pos * w_k)`` and odd indices hold
    ``cos(pos * w_k)``, with ``w_k = exp(-2k * ln(10000) / d_model)``.

    Args:
        d_model: feature size; may be odd.
        max_len: number of positions in the table.
        device: device on which the table is created.
    """

    def __init__(self, d_model, max_len=32, device="cpu"):
        super().__init__()
        position = torch.arange(max_len).unsqueeze(1)  # (max_len, 1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2) * (-np.log(10000.0) / d_model)
        )  # (ceil(d_model / 2),)
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        # There are only floor(d_model / 2) odd slots, so trim div_term;
        # without the slice an odd d_model raises a shape-mismatch error.
        pe[:, 0, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # Non-persistent buffer: follows .to()/.cuda() with the parent module
        # but stays out of state_dict (as the former plain attribute did).
        self.register_buffer("pe", pe.to(device), persistent=False)

    def indexate(self, indices):
        """Look up encodings for arbitrary integer positions.

        Args:
            indices: integer tensor of shape (B, L) with values in
                ``[0, max_len)``.

        Returns:
            Tensor of shape (B, L, d_model).
        """
        return self.pe[:, 0][indices.to(self.pe.device)]

    def forward(self, x):
        """Add the encodings for positions ``0..L-1`` to ``x``.

        Arguments:
            x: Tensor, shape ``[batch_size, seq_len, embedding_dim]`` with
                ``seq_len <= max_len``.
        """
        # pe[:L] is (L, 1, D); transposing to (1, L, D) broadcasts over batch.
        return x + self.pe[: x.shape[1]].transpose(0, 1)

class AllPairsSymNumNet(nn.Module):
    """Score every (numeric row, token sequence) pair with a transformer.

    Numeric rows are digit-expanded and embedded by ``num_encoder``; token
    sequences are embedded by ``sym_encoder``. Pair ``(i, j)`` is the
    sequence embedding of ``j`` followed by the numeric embedding of ``i``,
    scored by ``relation`` (positional encoding, transformer encoder and a
    probe on the last position).

    Args:
        num_in_size: number of raw numeric features per row.
        num_girth: hidden width of ``num_encoder``.
        num_digits: number of powers of ``num_dbase`` per feature; must be
            positive.
        num_dbase: base of the digit expansion.
        sym_dmodel: transformer width (also the numeric embedding width).
        sym_nhead: number of attention heads.
        sym_num_layers: number of transformer encoder layers.
        device: device for the non-learned buffers.
    """

    def __init__(
        self,
        num_in_size=1000,
        num_girth=512,
        num_digits=21,
        num_dbase=10.0,
        sym_dmodel=512,
        sym_nhead=4,
        sym_num_layers=2,
        device="cpu",
    ):
        super().__init__()
        if num_digits <= 0:
            raise ValueError(f"num_digits must be positive, got {num_digits}")
        # Bounds both the token vocabulary (Embedding size) and the sequence
        # length L + 1 accepted by the positional encoding.
        vocab_upper_bound = 32
        self.sym_encoder = nn.Sequential(
            nn.Embedding(
                vocab_upper_bound,
                sym_dmodel,
                padding_idx=0,
                sparse=False,
            ),
            nn.RMSNorm(sym_dmodel),
            nn.GELU(),
            nn.Linear(sym_dmodel, sym_dmodel),
            nn.RMSNorm(sym_dmodel),
            nn.GELU(),
        )
        _in_size = num_digits * num_in_size
        num_hidden_size = sym_dmodel
        self.num_encoder = nn.Sequential(
            nn.Linear(_in_size, num_girth),
            nn.RMSNorm(num_girth),
            nn.GELU(),
            nn.Linear(num_girth, num_hidden_size),
            nn.RMSNorm(num_hidden_size),
            nn.GELU(),
        )
        _encoder_layer = nn.TransformerEncoderLayer(
            d_model=sym_dmodel,
            nhead=sym_nhead,
            batch_first=True,
            activation="gelu",
            dropout=0.0,
        )
        self.relation = nn.Sequential(
            PositionalEncoding(
                sym_dmodel,
                max_len=vocab_upper_bound,
                device=device,
            ),
            nn.TransformerEncoder(
                _encoder_layer,
                num_layers=sym_num_layers,
            ),
            ClfProbe(sym_dmodel, aggregation="last"),
        )
        self.num_digits = num_digits
        self.num_dbase = num_dbase
        # Scales num_dbase**k for k in [-floor(n/2), ceil(n/2)), shape (1, n),
        # kept as a non-persistent buffer so it moves with the module.
        pos_digits = (num_digits + 1) // 2
        neg_digits = num_digits // 2
        fdt = self.num_dbase ** torch.arange(-neg_digits, pos_digits)
        self.register_buffer(
            "_fdt", fdt.reshape(1, -1).float().to(device), persistent=False
        )

    def _to_digits(self, x):
        """Expand (B, D) features into (B, D * num_digits) scaled remainders.

        Each entry is ``sign(x * s) * (|x * s| mod num_dbase)`` for a scale
        ``s`` in ``_fdt``; the remainder is not floored.
        """
        B, D = x.shape
        # Broadcasting multiply rather than a K=1 matmul (same values, immune
        # to reduced-precision matmul settings).
        scaled = x.reshape(-1, 1) * self._fdt  # (B * D, num_digits)
        digits = scaled.sign() * (scaled.abs() % self.num_dbase)
        return digits.reshape(B, D * self.num_digits)

    def _digit_norm(self, x):
        """Rescale digit features by ``1 / num_dbase``."""
        return x / self.num_dbase

    def encode_num(self, x):
        """Embed numeric rows ``x`` of shape (B, num_in_size) into (B, dmodel)."""
        return self.num_encoder(self._digit_norm(self._to_digits(x)))

    def encode_sym(self, x):
        """Embed integer token ids ``x`` of shape (B, L) into (B, L, dmodel)."""
        return self.sym_encoder(x)

    def run_tests(self, z_sym, z_num):
        """Score all pairs: ``[i, j]`` uses sequence ``j`` and number ``i``.

        Args:
            z_sym: sequence embeddings of shape (B, L, D).
            z_num: numeric embeddings of shape (B, D).

        Returns:
            Logits of shape (B, B). Diagonal pairs are fed an all-zero input.
        """
        B = z_sym.shape[0]
        seq = z_sym.unsqueeze(0).expand(B, -1, -1, -1)  # [i, j] -> z_sym[j]
        num = z_num[:, None, None, :].expand(-1, B, 1, -1)  # [i, j] -> z_num[i]
        z_cat = torch.cat([seq, num], dim=2)  # (B, B, L + 1, D)
        _, _, L_, D = z_cat.shape
        # Zero out self-pairs (i == i) entirely.
        off_diag = 1.0 - torch.eye(B, device=z_cat.device, dtype=z_cat.dtype)
        z_cat = z_cat * off_diag[:, :, None, None]
        return self.relation(z_cat.reshape(B * B, L_, D)).reshape(B, B)

    def tgt_run_tests(self, z_sym, z_num):
        """Score every sequence against a single numeric embedding.

        Args:
            z_sym: sequence embeddings of shape (B, L, D).
            z_num: one numeric embedding of shape (D,).

        Returns:
            Logits of shape (B,).
        """
        B = z_sym.shape[0]
        num = z_num.reshape(1, 1, -1).expand(B, 1, -1)  # (B, 1, D)
        z_cat = torch.cat([z_sym, num], dim=1)  # (B, L + 1, D)
        return self.relation(z_cat)

    def forward(self, x_sym, x_num):
        """Return pairwise probabilities (B, B).

        Args:
            x_sym: integer token ids of shape (B, L).
            x_num: numeric rows of shape (B, num_in_size).
        """
        z_sym = self.encode_sym(x_sym)  # (B, L, dmodel)
        z_num = self.encode_num(x_num)  # (B, dmodel)
        return torch.sigmoid(self.run_tests(z_sym, z_num))

    def tgt_forward(self, x_sym, x_num):
        """Return logits (B,) of every sequence against one numeric row.

        Args:
            x_sym: integer token ids of shape (B, L).
            x_num: a single numeric row of shape (num_in_size,).
        """
        z_sym = self.encode_sym(x_sym)  # (B, L, dmodel)
        z_num = self.encode_num(x_num.unsqueeze(0)).squeeze(0)  # (dmodel,)
        return self.tgt_run_tests(z_sym, z_num)

    def predict(self, x_sym, x_num, th=0.5):
        """Return hard pairwise decisions ``forward(...) >= th`` as int64 (B, B).

        Args:
            x_sym: integer token ids of shape (B, L).
            x_num: numeric rows of shape (B, num_in_size).
            th: decision threshold on the probabilities.
        """
        return (self.forward(x_sym, x_num) >= th).long()

class AllPairsTreeNet(nn.Module):
    """Score every ordered pair of items with a transformer over L positions.

    ``x_num`` has shape (B, L, D): L numeric rows per batch item. Rows are
    digit-expanded and embedded by ``num_encoder``. Pair ``(i, j)`` feeds
    the sequence ``z[i, 0] - z[j, l]`` (l = 0..L-1), plus optional
    parent-position and expression embeddings of item ``j``, through the
    transformer and ``ClfProbe``.

    Args:
        num_in_size: number of raw numeric features ``D``.
        num_girth: hidden width of ``num_encoder``.
        num_digits: number of powers of ``num_dbase`` per feature; must be
            positive.
        num_dbase: base of the digit expansion.
        num_enc_two_layers: if True ``num_encoder`` has two Linear layers,
            otherwise three.
        sym_dmodel: transformer / embedding width.
        sym_nhead: number of attention heads.
        sym_num_layers: number of transformer encoder layers.
        clf_agg: pooling mode passed to ``ClfProbe``.
        reinit_trf: use ReLU in the transformer and re-initialise its Linear
            layers with Kaiming-uniform weights and zero biases.
        use_expr_embs: create an embedding table for ``expr_ids``.
        sort_by_results: permute the D features of every row by the
            ascending order of row 0 before encoding.
        device: device for the non-learned buffers.
    """

    def __init__(
        self,
        num_in_size=1000,
        num_girth=512,
        num_digits=21,
        num_dbase=10.0,
        num_enc_two_layers=True,
        sym_dmodel=512,
        sym_nhead=4,
        sym_num_layers=2,
        clf_agg="last",
        reinit_trf=False,
        use_expr_embs=False,
        sort_by_results=False,
        device="cpu",
    ):
        super().__init__()
        if num_digits <= 0:
            raise ValueError(f"num_digits must be positive, got {num_digits}")
        max_prefix_len = 32
        _in_size = num_digits * num_in_size
        num_hidden_size = sym_dmodel
        if num_enc_two_layers:
            self.num_encoder = nn.Sequential(
                nn.Linear(_in_size, num_girth),
                nn.RMSNorm(num_girth),
                nn.GELU(),
                nn.Linear(num_girth, num_hidden_size),
                nn.RMSNorm(num_hidden_size),
                nn.GELU(),
            )
        else:
            self.num_encoder = nn.Sequential(
                nn.Linear(_in_size, num_girth),
                nn.RMSNorm(num_girth),
                nn.GELU(),
                nn.Linear(num_girth, num_girth),
                nn.RMSNorm(num_girth),
                nn.GELU(),
                nn.Linear(num_girth, num_hidden_size),
                nn.RMSNorm(num_hidden_size),
                nn.GELU(),
            )
        self._pos_enc = PositionalEncoding(
            sym_dmodel,
            max_len=max_prefix_len,
            device=device,
        )
        self._expr_embs = None
        if use_expr_embs:
            max_vocab_len = 32
            self._expr_embs = nn.Embedding(
                max_vocab_len,
                num_hidden_size,
                padding_idx=0,
                sparse=False,
            )
        _trf_activation = "relu" if reinit_trf else "gelu"
        _encoder_layer = nn.TransformerEncoderLayer(
            d_model=sym_dmodel,
            nhead=sym_nhead,
            batch_first=True,
            activation=_trf_activation,
            dropout=0.0,
        )
        self._tr_enc = nn.TransformerEncoder(
            _encoder_layer,
            num_layers=sym_num_layers,
        )
        if reinit_trf:
            # Kaiming-uniform weights / zero biases for every Linear inside
            # each transformer encoder layer.
            for module in self._tr_enc.modules():
                if isinstance(module, nn.TransformerEncoderLayer):
                    for submodule in module.modules():
                        if isinstance(submodule, nn.Linear):
                            nn.init.kaiming_uniform_(
                                submodule.weight,
                                nonlinearity=_trf_activation,
                            )
                            if submodule.bias is not None:
                                nn.init.constant_(submodule.bias, 0)
        self._clf = ClfProbe(sym_dmodel, aggregation=clf_agg)
        self._sort_res = sort_by_results
        self.num_digits = num_digits
        self.num_dbase = num_dbase
        # Scales num_dbase**k for k in [-floor(n/2), ceil(n/2)), shape (1, n),
        # kept as a non-persistent buffer so it moves with the module.
        pos_digits = (num_digits + 1) // 2
        neg_digits = num_digits // 2
        fdt = self.num_dbase ** torch.arange(-neg_digits, pos_digits)
        self.register_buffer(
            "_fdt", fdt.reshape(1, -1).float().to(device), persistent=False
        )

    def _to_digits(self, x):
        """Expand (B, L, D) features into (B, L, D * num_digits) remainders.

        Each entry is ``sign(x * s) * (|x * s| mod num_dbase)`` for a scale
        ``s`` in ``_fdt``; the remainder is not floored.
        """
        B, L, D = x.shape
        # Broadcasting multiply rather than a K=1 matmul (same values, immune
        # to reduced-precision matmul settings).
        scaled = x.reshape(-1, 1) * self._fdt  # (B * L * D, num_digits)
        digits = scaled.sign() * (scaled.abs() % self.num_dbase)
        return digits.reshape(B, L, D * self.num_digits)

    def _digit_norm(self, x):
        """Rescale digit features by ``1 / num_dbase``."""
        return x / self.num_dbase

    def encode_num(self, x):
        """Embed numeric rows ``x`` of shape (B, L, D) into (B, L, dmodel)."""
        if self._sort_res:
            L = x.shape[1]
            # Permute all L rows of an item by the ascending order of its row 0.
            order = x[:, :1].sort(dim=-1).indices.repeat(1, L, 1)  # (B, L, D)
            x = torch.gather(x, -1, order)
        return self.num_encoder(self._digit_norm(self._to_digits(x)))

    def relation(self, z_num, expr_ids=None, parent_ids=None):
        """Run the transformer and probe on a batch of sequences.

        Args:
            z_num: sequences of shape (B, L, D).
            expr_ids: optional integer ids of shape (B, L); requires
                ``use_expr_embs=True``.
            parent_ids: optional integer positions of shape (B, L), encoded
                with the positional table.

        Returns:
            Logits of shape (B,).

        Raises:
            ValueError: if ``expr_ids`` is given but no expression embedding
                table was created.
        """
        if expr_ids is not None and self._expr_embs is None:
            raise ValueError(
                "expr_ids were given but the model was built with "
                "use_expr_embs=False"
            )
        z_num = self._pos_enc(z_num)  # (B, L, D)
        if parent_ids is not None:
            z_num = z_num + self._pos_enc.indexate(parent_ids)
        if expr_ids is not None:
            z_num = z_num + self._expr_embs(expr_ids)
        z_num = self._tr_enc(z_num)  # (B, L, D)
        return self._clf(z_num)

    def run_tests(self, z_num, expr_ids=None, parent_ids=None):
        """Score all ordered pairs of items.

        Args:
            z_num: embeddings of shape (B, L, D).
            expr_ids, parent_ids: optional ids of shape (B, L) for item ``j``.

        Returns:
            Logits of shape (B, B).
        """
        B, L, D = z_num.shape
        # [i, j, l] = z_num[i, 0] - z_num[j, l]
        pairs = z_num[:, :1].unsqueeze(1) - z_num.unsqueeze(0)  # (B, B, L, D)
        pairs = pairs.reshape(B * B, L, D)
        # Ids belong to the column item j, so tile them along the row axis i.
        if parent_ids is not None:
            parent_ids = parent_ids.unsqueeze(0).expand(B, -1, -1).reshape(B * B, L)
        if expr_ids is not None:
            expr_ids = expr_ids.unsqueeze(0).expand(B, -1, -1).reshape(B * B, L)
        tests = self.relation(
            pairs,
            expr_ids=expr_ids,
            parent_ids=parent_ids,
        )
        return tests.reshape(B, B)

    def tgt_run_tests(self, z_num, z_tgt, expr_ids=None, parent_ids=None):
        """Score every item against a target embedding.

        Args:
            z_num: embeddings of shape (B, L, D).
            z_tgt: target embedding of shape (1, 1, D).

        Returns:
            Logits of shape (B,).
        """
        diffs = z_tgt - z_num  # (B, L, D)
        B = diffs.shape[0]
        tests = self.relation(
            diffs,
            expr_ids=expr_ids,
            parent_ids=parent_ids,
        )
        return tests.reshape(B)

    def forward(self, x_num, expr_ids=None, parent_ids=None):
        """Return pairwise probabilities (B, B) for ``x_num`` of shape (B, L, D)."""
        z_num = self.encode_num(x_num)  # (B, L, dmodel)
        tests = self.run_tests(
            z_num,
            expr_ids=expr_ids,
            parent_ids=parent_ids,
        )
        return torch.sigmoid(tests)

    def tgt_forward(self, x_num, x_tgt, expr_ids=None, parent_ids=None):
        """Return logits (B,) of every item against a target row.

        Args:
            x_num: numeric rows of shape (B, L, D).
            x_tgt: target row of shape (D,).
        """
        D = x_tgt.shape[0]
        z_num = self.encode_num(x_num)  # (B, L, dmodel)
        z_tgt = self.encode_num(x_tgt.reshape(1, 1, D))  # (1, 1, dmodel)
        return self.tgt_run_tests(
            z_num,
            z_tgt,
            expr_ids=expr_ids,
            parent_ids=parent_ids,
        )

    def predict(self, x_num, th=0.5, expr_ids=None, parent_ids=None):
        """Return hard pairwise decisions ``forward(...) >= th`` as int64 (B, B).

        Args:
            x_num: numeric rows of shape (B, L, D).
            th: decision threshold on the probabilities.
        """
        probs = self.forward(
            x_num,
            expr_ids=expr_ids,
            parent_ids=parent_ids,
        )
        return (probs >= th).long()

class AllPairsSingleDomainTreeNet(nn.Module):
    """Pairwise scorer over first-row embeddings with tree/expression context.

    Only row 0 of each item's (L, D) numeric block is encoded. Pair
    ``(i, j)`` feeds ``z_i - z_j`` (tiled over L positions) plus positional,
    parent-position, expression and optional domain embeddings through a
    transformer and a ``ClfProbe``.

    Args:
        num_in_size: number of raw numeric features ``D``.
        num_girth: hidden width of ``num_encoder``.
        num_digits: number of powers of ``num_dbase`` per feature; must be
            positive.
        num_dbase: base of the digit expansion.
        sym_dmodel: transformer / embedding width.
        sym_nhead: number of attention heads.
        sym_num_layers: number of transformer encoder layers.
        clf_agg: pooling mode passed to ``ClfProbe``.
        sort_by_results: permute each row's features by the ascending order
            of row 0 before encoding.
        domain_embeddings: create a domain embedding table (128 entries,
            index 127 is padding).
        use_diffs: append first differences of the features; takes
            precedence over ``sorts``.
        sorts: optional list of index orders; the first differences of the
            features taken in each order are appended.
        device: device for the non-learned buffers.
    """

    def __init__(
        self,
        num_in_size=1000,
        num_girth=512,
        num_digits=21,
        num_dbase=10.0,
        sym_dmodel=512,
        sym_nhead=4,
        sym_num_layers=2,
        clf_agg="first",
        sort_by_results=False,
        domain_embeddings=False,
        use_diffs=False,
        sorts=None,
        device="cpu",
    ):
        super().__init__()
        if num_digits <= 0:
            raise ValueError(f"num_digits must be positive, got {num_digits}")
        max_prefix_len = 32
        _in_size = num_digits * num_in_size
        # Extra encoder inputs for the difference features built in
        # encode_num (each difference vector has num_in_size - 1 entries).
        if use_diffs:
            _in_size += num_digits * (num_in_size - 1)
        elif sorts is not None:
            _in_size += num_digits * len(sorts) * (num_in_size - 1)
        self._use_diffs = use_diffs
        self._sorts = sorts
        num_hidden_size = sym_dmodel
        self.num_encoder = nn.Sequential(
            nn.Linear(_in_size, num_girth),
            nn.RMSNorm(num_girth),
            nn.GELU(),
            nn.Linear(num_girth, num_girth),
            nn.RMSNorm(num_girth),
            nn.GELU(),
            nn.Linear(num_girth, num_hidden_size),
            nn.RMSNorm(num_hidden_size),
            nn.GELU(),
        )
        self._pos_enc = PositionalEncoding(
            sym_dmodel,
            max_len=max_prefix_len,
            device=device,
        )
        max_vocab_len = 32
        self._expr_embs = nn.Embedding(
            max_vocab_len,
            num_hidden_size,
            padding_idx=0,
            sparse=False,
        )
        if domain_embeddings:
            max_num_domains = 128
            self._domain_embs = nn.Embedding(
                max_num_domains,
                num_hidden_size,
                padding_idx=(max_num_domains - 1),
                sparse=False,
            )
        else:
            # Explicit None so relation() can report a clear error.
            self._domain_embs = None
        _trf_activation = "relu"
        _encoder_layer = nn.TransformerEncoderLayer(
            d_model=sym_dmodel,
            nhead=sym_nhead,
            batch_first=True,
            activation=_trf_activation,
            dropout=0.0,
        )
        self._tr_enc = nn.TransformerEncoder(
            _encoder_layer,
            num_layers=sym_num_layers,
        )
        # Kaiming-uniform weights / zero biases for every Linear inside each
        # transformer encoder layer.
        for module in self._tr_enc.modules():
            if isinstance(module, nn.TransformerEncoderLayer):
                for submodule in module.modules():
                    if isinstance(submodule, nn.Linear):
                        nn.init.kaiming_uniform_(
                            submodule.weight,
                            nonlinearity=_trf_activation,
                        )
                        if submodule.bias is not None:
                            nn.init.constant_(submodule.bias, 0)
        self._clf = ClfProbe(sym_dmodel, aggregation=clf_agg)
        self._sort_res = sort_by_results
        self.num_digits = num_digits
        self.num_dbase = num_dbase
        # Scales num_dbase**k for k in [-floor(n/2), ceil(n/2)), shape (1, n),
        # kept as a non-persistent buffer so it moves with the module.
        pos_digits = (num_digits + 1) // 2
        neg_digits = num_digits // 2
        fdt = self.num_dbase ** torch.arange(-neg_digits, pos_digits)
        self.register_buffer(
            "_fdt", fdt.reshape(1, -1).float().to(device), persistent=False
        )

    def _to_digits(self, x):
        """Expand (B, L, D) features into (B, L, D * num_digits) remainders.

        Each entry is ``sign(x * s) * (|x * s| mod num_dbase)`` for a scale
        ``s`` in ``_fdt``; the remainder is not floored.
        """
        B, L, D = x.shape
        # Broadcasting multiply rather than a K=1 matmul (same values, immune
        # to reduced-precision matmul settings).
        scaled = x.reshape(-1, 1) * self._fdt
        digits = scaled.sign() * (scaled.abs() % self.num_dbase)
        return digits.reshape(B, L, D * self.num_digits)

    def _digit_norm(self, x):
        """Rescale digit features by ``1 / num_dbase``."""
        return x / self.num_dbase

    def encode_num(self, x):
        """Embed numeric rows ``x`` of shape (B, L, D) into (B, L, dmodel)."""
        if self._sort_res:
            L = x.shape[1]
            # Permute all L rows of an item by the ascending order of its row 0.
            order = x[:, :1].sort(dim=-1).indices.repeat(1, L, 1)  # (B, L, D)
            x = torch.gather(x, -1, order)
        if self._use_diffs:
            x = torch.cat([x, torch.diff(x, n=1, dim=-1)], dim=-1)  # (B, L, 2D - 1)
        elif self._sorts is not None:
            diffs = [torch.diff(x[:, :, ids], n=1, dim=-1) for ids in self._sorts]
            x = torch.cat([x] + diffs, dim=-1)  # (B, L, D + k * (D - 1))
        return self.num_encoder(self._digit_norm(self._to_digits(x)))

    def relation(self, z_num, expr_ids, parent_ids, domain_ids=None):
        """Run the transformer and probe on a batch of sequences.

        Args:
            z_num: sequences of shape (B, L, D).
            expr_ids: integer ids of shape (B, L).
            parent_ids: integer positions of shape (B, L).
            domain_ids: optional single domain id (one element); requires
                ``domain_embeddings=True``.

        Returns:
            Logits of shape (B,).

        Raises:
            ValueError: if ``domain_ids`` is given but no domain embedding
                table was created.
        """
        if domain_ids is not None and self._domain_embs is None:
            raise ValueError(
                "domain_ids were given but the model was built with "
                "domain_embeddings=False"
            )
        z_num = self._pos_enc(z_num)  # (B, L, D)
        z_num = z_num + self._pos_enc.indexate(parent_ids)
        z_num = z_num + self._expr_embs(expr_ids)
        if domain_ids is not None:
            D = z_num.shape[-1]
            z_num = z_num + self._domain_embs(domain_ids).reshape(1, 1, D)
        z_num = self._tr_enc(z_num)  # (B, L, D)
        return self._clf(z_num)

    def run_tests(self, z_num, expr_ids, parent_ids, domain_ids=None):
        """Score all ordered pairs using each item's position-0 embedding.

        Args:
            z_num: embeddings of shape (B, L, D).
            expr_ids, parent_ids: ids of shape (B, L) for item ``j``.

        Returns:
            Logits of shape (B, B).
        """
        B, L, D = z_num.shape
        first = z_num[:, 0]  # (B, D)
        # row - col: [i, j] = z[i, 0] - z[j, 0], tiled over the L positions.
        z_diff = (first.unsqueeze(1) - first.unsqueeze(0)).unsqueeze(2)
        z_diff = z_diff.expand(-1, -1, L, -1).reshape(B * B, L, D)
        # Ids belong to the column item j, so tile them along the row axis i.
        parent_ids = parent_ids.unsqueeze(0).expand(B, -1, -1).reshape(B * B, L)
        expr_ids = expr_ids.unsqueeze(0).expand(B, -1, -1).reshape(B * B, L)
        tests = self.relation(
            z_diff,
            expr_ids,
            parent_ids,
            domain_ids=domain_ids,
        )
        return tests.reshape(B, B)

    def tgt_run_tests(self, z_num, z_tgt, expr_ids, parent_ids, domain_ids=None):
        """Score every item against a target embedding.

        Args:
            z_num: embeddings of shape (B, L, D).
            z_tgt: target embedding of shape (1, 1, D).

        Returns:
            Logits of shape (B,).
        """
        B, L, D = z_num.shape
        # [j] = z_tgt[0] - z[j, 0], tiled over the L positions.
        diff = z_tgt[:, 0] - z_num[:, 0]  # (B, D)
        diff = diff.unsqueeze(1).expand(-1, L, -1)  # (B, L, D)
        tests = self.relation(
            diff,
            expr_ids,
            parent_ids,
            domain_ids=domain_ids,
        )
        return tests.reshape(B)

    def forward(self, x_num, expr_ids, parent_ids, domain_ids=None):
        """Return pairwise probabilities (B, B) for ``x_num`` of shape (B, L, D)."""
        L = x_num.shape[1]
        # Only row 0 is encoded; its embedding is tiled over all L positions.
        z_num = self.encode_num(x_num[:, :1]).expand(-1, L, -1)  # (B, L, dmodel)
        tests = self.run_tests(
            z_num,
            expr_ids,
            parent_ids,
            domain_ids=domain_ids,
        )
        return torch.sigmoid(tests)

    def tgt_forward(self, x_num, x_tgt, expr_ids, parent_ids, domain_ids=None):
        """Return logits (B,) of every item against a target row.

        Args:
            x_num: numeric rows of shape (B, L, D).
            x_tgt: target row of shape (D,).
        """
        L, D = x_num.shape[1], x_num.shape[2]
        z_num = self.encode_num(x_num[:, :1]).expand(-1, L, -1)  # (B, L, dmodel)
        z_tgt = self.encode_num(x_tgt.reshape(1, 1, D))  # (1, 1, dmodel)
        return self.tgt_run_tests(
            z_num,
            z_tgt,
            expr_ids,
            parent_ids,
            domain_ids=domain_ids,
        )

    def predict(self, x_num, expr_ids, parent_ids, domain_ids=None, th=0.5):
        """Return hard pairwise decisions ``forward(...) >= th`` as int64 (B, B).

        Args:
            x_num: numeric rows of shape (B, L, D).
            th: decision threshold on the probabilities.
        """
        probs = self.forward(
            x_num,
            expr_ids,
            parent_ids,
            domain_ids=domain_ids,
        )
        return (probs >= th).long()

class AllPairsStatsTreeNet(nn.Module):
    """Pairwise scorer whose numeric features can be augmented with row stats.

    Only row 0 of each item's (L, D) numeric block is encoded. Its features,
    optional sorted first differences and (by default) ``_basic_stats`` are
    digit-expanded and embedded. Pair ``(i, j)`` feeds ``z_i - z_j`` (tiled
    over L positions) plus positional, parent-position and expression
    embeddings through a transformer and a ``ClfProbe``.

    Args:
        num_in_size: number of raw numeric features ``D``.
        num_girth: hidden width of ``num_encoder``.
        num_digits: number of powers of ``num_dbase`` per feature; must be
            positive.
        num_dbase: base of the digit expansion.
        sym_dmodel: transformer / embedding width.
        sym_nhead: number of attention heads.
        sym_num_layers: number of transformer encoder layers.
        clf_agg: pooling mode passed to ``ClfProbe``.
        sorts: optional list of index orders; the first differences of the
            features taken in each order are appended.
        device: device for the non-learned buffers.
        use_stats: keyword-only; append the ``_NUM_STATS`` values of
            ``_basic_stats`` of each raw row before digit expansion.
    """

    _NUM_STATS = 10  # number of columns produced by _basic_stats

    def __init__(
        self,
        num_in_size=1000,
        num_girth=512,
        num_digits=21,
        num_dbase=10.0,
        sym_dmodel=512,
        sym_nhead=4,
        sym_num_layers=2,
        clf_agg="first",
        sorts=None,
        device="cpu",
        *,
        use_stats=True,
    ):
        super().__init__()
        if num_digits <= 0:
            raise ValueError(f"num_digits must be positive, got {num_digits}")
        max_prefix_len = 32
        _in_size = num_digits * num_in_size
        # Extra encoder inputs for the features appended in encode_num.
        if sorts is not None:
            _in_size += num_digits * len(sorts) * (num_in_size - 1)
        if use_stats:
            _in_size += num_digits * self._NUM_STATS
        self._sorts = sorts
        self._use_stats = use_stats
        num_hidden_size = sym_dmodel
        self.num_encoder = nn.Sequential(
            nn.Linear(_in_size, num_girth),
            nn.RMSNorm(num_girth),
            nn.GELU(),
            nn.Linear(num_girth, num_girth),
            nn.RMSNorm(num_girth),
            nn.GELU(),
            nn.Linear(num_girth, num_hidden_size),
            nn.RMSNorm(num_hidden_size),
            nn.GELU(),
        )
        self._pos_enc = PositionalEncoding(
            sym_dmodel,
            max_len=max_prefix_len,
            device=device,
        )
        max_vocab_len = 32
        self._expr_embs = nn.Embedding(
            max_vocab_len,
            num_hidden_size,
            padding_idx=0,
            sparse=False,
        )
        _trf_activation = "relu"
        _encoder_layer = nn.TransformerEncoderLayer(
            d_model=sym_dmodel,
            nhead=sym_nhead,
            batch_first=True,
            activation=_trf_activation,
            dropout=0.0,
        )
        self._tr_enc = nn.TransformerEncoder(
            _encoder_layer,
            num_layers=sym_num_layers,
        )
        # Kaiming-uniform weights / zero biases for every Linear inside each
        # transformer encoder layer.
        for module in self._tr_enc.modules():
            if isinstance(module, nn.TransformerEncoderLayer):
                for submodule in module.modules():
                    if isinstance(submodule, nn.Linear):
                        nn.init.kaiming_uniform_(
                            submodule.weight,
                            nonlinearity=_trf_activation,
                        )
                        if submodule.bias is not None:
                            nn.init.constant_(submodule.bias, 0)
        self._clf = ClfProbe(sym_dmodel, aggregation=clf_agg)
        self.num_digits = num_digits
        self.num_dbase = num_dbase
        # Scales num_dbase**k for k in [-floor(n/2), ceil(n/2)), shape (1, n),
        # kept as a non-persistent buffer so it moves with the module.
        pos_digits = (num_digits + 1) // 2
        neg_digits = num_digits // 2
        fdt = self.num_dbase ** torch.arange(-neg_digits, pos_digits)
        self.register_buffer(
            "_fdt", fdt.reshape(1, -1).float().to(device), persistent=False
        )

    def _to_digits(self, x):
        """Expand (B, L, D) features into (B, L, D * num_digits) remainders.

        Each entry is ``sign(x * s) * (|x * s| mod num_dbase)`` for a scale
        ``s`` in ``_fdt``; the remainder is not floored.
        """
        B, L, D = x.shape
        # Broadcasting multiply rather than a K=1 matmul (same values, immune
        # to reduced-precision matmul settings).
        scaled = x.reshape(-1, 1) * self._fdt
        digits = scaled.sign() * (scaled.abs() % self.num_dbase)
        return digits.reshape(B, L, D * self.num_digits)

    def _digit_norm(self, x):
        """Rescale digit features by ``1 / num_dbase``."""
        return x / self.num_dbase

    def _basic_stats(self, x):
        """Return 10 summary statistics over the last axis.

        Columns, in order: sum, mean, population std, min, max, range,
        median (``torch.median``), variance, skewness, excess kurtosis.
        A ``1e-8`` term keeps the skew/kurtosis denominators non-zero.

        Args:
            x: tensor of shape (*, D).

        Returns:
            Tensor of shape (*, 10).
        """
        mean = x.mean(dim=-1)
        std = x.std(dim=-1, unbiased=False)
        xmin = x.min(dim=-1).values
        xmax = x.max(dim=-1).values
        centered = x - mean.unsqueeze(-1)
        m2 = centered.pow(2).mean(dim=-1)
        m3 = centered.pow(3).mean(dim=-1)
        m4 = centered.pow(4).mean(dim=-1)
        skew = m3 / (m2 ** 1.5 + 1e-8)
        kurt = m4 / (m2 ** 2 + 1e-8) - 3.0
        return torch.stack(
            (
                x.sum(dim=-1),
                mean,
                std,
                xmin,
                xmax,
                xmax - xmin,
                x.median(dim=-1).values,
                std ** 2,
                skew,
                kurt,
            ),
            dim=-1,
        )

    def encode_num(self, x):
        """Embed numeric rows ``x`` of shape (B, L, D) into (B, L, dmodel).

        Features: the raw values, then first differences for each order in
        ``sorts``, then (if ``use_stats``) ``_basic_stats`` of the raw values.

        NOTE(review): this replaces an unfinished ``x = ...`` placeholder
        that made every call fail. Appending stats of the raw row is an
        assumed design choice; confirm, or build with ``use_stats=False``.
        """
        feats = [x]
        if self._sorts is not None:
            feats.extend(
                torch.diff(x[:, :, ids], n=1, dim=-1) for ids in self._sorts
            )
        if self._use_stats:
            feats.append(self._basic_stats(x))  # (B, L, _NUM_STATS)
        x = torch.cat(feats, dim=-1)
        return self.num_encoder(self._digit_norm(self._to_digits(x)))

    def relation(self, z_num, expr_ids, parent_ids, domain_ids=None):
        """Run the transformer and probe on a batch of sequences.

        Args:
            z_num: sequences of shape (B, L, D).
            expr_ids: integer ids of shape (B, L).
            parent_ids: integer positions of shape (B, L).
            domain_ids: accepted for signature compatibility; ignored.

        Returns:
            Logits of shape (B,).
        """
        z_num = self._pos_enc(z_num)  # (B, L, D)
        z_num = z_num + self._pos_enc.indexate(parent_ids)
        z_num = z_num + self._expr_embs(expr_ids)
        z_num = self._tr_enc(z_num)  # (B, L, D)
        return self._clf(z_num)

    def run_tests(self, z_num, expr_ids, parent_ids, domain_ids=None):
        """Score all ordered pairs using each item's position-0 embedding.

        Args:
            z_num: embeddings of shape (B, L, D).
            expr_ids, parent_ids: ids of shape (B, L) for item ``j``.

        Returns:
            Logits of shape (B, B).
        """
        B, L, D = z_num.shape
        first = z_num[:, 0]  # (B, D)
        # row - col: [i, j] = z[i, 0] - z[j, 0], tiled over the L positions.
        z_diff = (first.unsqueeze(1) - first.unsqueeze(0)).unsqueeze(2)
        z_diff = z_diff.expand(-1, -1, L, -1).reshape(B * B, L, D)
        # Ids belong to the column item j, so tile them along the row axis i.
        parent_ids = parent_ids.unsqueeze(0).expand(B, -1, -1).reshape(B * B, L)
        expr_ids = expr_ids.unsqueeze(0).expand(B, -1, -1).reshape(B * B, L)
        tests = self.relation(
            z_diff,
            expr_ids,
            parent_ids,
            domain_ids=domain_ids,
        )
        return tests.reshape(B, B)

    def tgt_run_tests(self, z_num, z_tgt, expr_ids, parent_ids, domain_ids=None):
        """Score every item against a target embedding.

        Args:
            z_num: embeddings of shape (B, L, D).
            z_tgt: target embedding of shape (1, 1, D).

        Returns:
            Logits of shape (B,).
        """
        B, L, D = z_num.shape
        # [j] = z_tgt[0] - z[j, 0], tiled over the L positions.
        diff = z_tgt[:, 0] - z_num[:, 0]  # (B, D)
        diff = diff.unsqueeze(1).expand(-1, L, -1)  # (B, L, D)
        tests = self.relation(
            diff,
            expr_ids,
            parent_ids,
            domain_ids=domain_ids,
        )
        return tests.reshape(B)

    def forward(self, x_num, expr_ids, parent_ids, domain_ids=None):
        """Return pairwise probabilities (B, B) for ``x_num`` of shape (B, L, D)."""
        L = x_num.shape[1]
        # Only row 0 is encoded; its embedding is tiled over all L positions.
        z_num = self.encode_num(x_num[:, :1]).expand(-1, L, -1)  # (B, L, dmodel)
        tests = self.run_tests(
            z_num,
            expr_ids,
            parent_ids,
            domain_ids=domain_ids,
        )
        return torch.sigmoid(tests)

    def tgt_forward(self, x_num, x_tgt, expr_ids, parent_ids, domain_ids=None):
        """Return logits (B,) of every item against a target row.

        Args:
            x_num: numeric rows of shape (B, L, D).
            x_tgt: target row of shape (D,).
        """
        L, D = x_num.shape[1], x_num.shape[2]
        z_num = self.encode_num(x_num[:, :1]).expand(-1, L, -1)  # (B, L, dmodel)
        z_tgt = self.encode_num(x_tgt.reshape(1, 1, D))  # (1, 1, dmodel)
        return self.tgt_run_tests(
            z_num,
            z_tgt,
            expr_ids,
            parent_ids,
            domain_ids=domain_ids,
        )

    def predict(self, x_num, expr_ids, parent_ids, domain_ids=None, th=0.5):
        """Return hard pairwise decisions ``forward(...) >= th`` as int64 (B, B).

        Args:
            x_num: numeric rows of shape (B, L, D).
            th: decision threshold on the probabilities.
        """
        probs = self.forward(
            x_num,
            expr_ids,
            parent_ids,
            domain_ids=domain_ids,
        )
        return (probs >= th).long()

class AllPairsSimpresNet(nn.Module):
    def __init__(
        self,
        num_in_size=1000,
        num_girth=512,
        num_digits=21,
        num_dbase=10.0,
        sym_dmodel=512,
        sym_girth=1024,
        max_len=12,
        device="cpu",
    ):
        super().__init__()
        self._max_len = max_len
        _in_size = num_digits * (num_in_size + 0)
        num_hidden_size = sym_dmodel
        self.num_encoder = nn.Sequential(
            nn.Linear(_in_size, num_girth),
            nn.RMSNorm(num_girth),
            nn.GELU(),
            nn.Linear(num_girth, num_girth),
            nn.RMSNorm(num_girth),
            nn.GELU(),
            nn.Linear(num_girth, num_hidden_size),
            nn.RMSNorm(num_hidden_size),
            nn.GELU(),
        )
        rel_input_size = num_hidden_size * max_len
        self._rel = nn.Sequential(
            nn.Linear(rel_input_size, sym_girth),
            nn.RMSNorm(sym_girth),
            nn.GELU(),
            nn.Linear(sym_girth, sym_girth),
            nn.RMSNorm(sym_girth),
            nn.GELU(),
            nn.Linear(sym_girth, 1),
        )
        assert num_digits > 0
        self.num_digits = num_digits
        pos_digits = (num_digits // 2) + (num_digits % 2)
        neg_digits = num_digits // 2
        self.num_dbase = num_dbase
        self._fdt = (self.num_dbase ** torch.arange(-neg_digits, pos_digits))
        self._fdt = self._fdt.reshape(1, -1)
        self._fdt = self._fdt.float().to(device)

    def _to_digits(self, x):
        """ x.shape: (B, L, D)
        """
        B, L, D = x.shape
        x = x.reshape(-1, 1) @ self._fdt
        x = x.sign() * (x.abs() % self.num_dbase)
        x = x.reshape(B, L, D * self.num_digits) # (B, L, D * d)
        return x

    def _digit_norm(self, x):
        """ x.shape: (B, L, D)
        """
        x = x / self.num_dbase
        return x

    def encode_num(self, x):
        """ x.shape: (B, L, BIG_D)
        """
        x = self._to_digits(x)
        x = self._digit_norm(x)
        z = self.num_encoder(x) # (B, L, D*)
        return z

    def relation(self, z_num, expr_ids=None, parent_ids=None, domain_ids=None):
        """ z_num.shape: (B, L, D)
            parent_ids.shape: (B, L)
            domain_ids.shape: (,)
        """
        clf = self._rel(z_num) # (B,)
        return clf

    def run_tests(self, z_num, expr_ids=None, parent_ids=None, domain_ids=None):
        """ z_num.shape: (B, M, D)
        """
        B, M, D = z_num.shape
        ####### row - col; row --| col
        z_row = (
            z_num
            .clone()
            [:, :1]
            .repeat(1, M, 1)
            .unsqueeze(1)
        ) # (B, 1, M, D)
        z_col = z_num.unsqueeze(0)
        z_diff = (z_row - z_col).reshape(B, B, M * D) # (B, B, M, D)
        ####### columns must be repeated, not rows
        tests = self.relation(z_diff).reshape(B, B) # (B, B)
        return tests

    def tgt_run_tests(self, z_num, z_tgt, expr_ids=None, parent_ids=None, domain_ids=None):
        """ z_num.shape: (B, M, D)
            z_tgt.shape: (1, 1, D)
        """
        B, M, D = z_num.shape
        ####### row - col; row "is preceded by" col
        z_num = (z_tgt - z_num) # (B, M, D)
        z_num = z_num.reshape(B, M * D)
        tests = self.relation(z_num).reshape(B) # (B,)
        return tests

    def forward(self, x_num, expr_ids=None, parent_ids=None, domain_ids=None):
        """ x_num.shape: (B, L, D)
        """
        B, L, D = x_num.shape
        x_num = x_num.clone()[:, :self._max_len]
        z_num = self.encode_num(x_num) # (B, M, dmodel)
        tests = self.run_tests(z_num) # (B, B)
        tests = nn.functional.sigmoid(tests) # (B, B)
        return tests

    def tgt_forward(self, x_num, x_tgt, expr_ids=None, parent_ids=None, domain_ids=None):
        """ x_num.shape: (B, L, D)
            x_tgt.shape: (D,)
        """
        B, L, D = x_num.shape
        x_num = x_num.clone()[:, :self._max_len]
        z_num = self.encode_num(x_num) # (B, M, dmodel)
        z_tgt = self.encode_num(x_tgt.reshape(1, 1, D)) # (1, 1, dmodel)
        tests = self.tgt_run_tests(z_num, z_tgt) # (B,)
        return tests

    def predict(self, x_num, expr_ids=None, parent_ids=None, domain_ids=None, th=0.5):
        """ x_sym.shape: (B, L, V)
            x_num.shape: (B, D)
        """
        test = self.forward(x_num) # (B, B)
        test = (test >= th).long() # (B, B)
        return test

class AllPairsMultiDomainTreeNet(nn.Module):
    """Pairwise relation classifier conditioned on an expression tree and domain.

    Numeric inputs, augmented with first differences along each sort order,
    are digit-encoded by an MLP. For a pair of samples, the difference of
    their position-0 encodings is broadcast over the tree tokens
    (``expr_ids``/``parent_ids``). It is optionally shifted by a domain
    embedding and then scored by a transformer encoder plus ``ClfProbe``.
    """

    def __init__(
        self,
        num_in_size=1000,
        num_girth=512,
        num_digits=21,
        num_dbase=10.0,
        sym_dmodel=512,
        sym_nhead=4,
        sym_num_layers=2,
        clf_agg="first",
        num_sorts=1,
        device="cpu",
    ):
        super().__init__()
        assert num_sorts > 0
        max_prefix_len = 32
        # Raw features plus one (D - 1)-long first-difference vector per sort order.
        _in_size = num_digits * (num_in_size + num_sorts * (num_in_size - 1))
        num_hidden_size = sym_dmodel
        self.num_encoder = nn.Sequential(
            nn.Linear(_in_size, num_girth),
            nn.RMSNorm(num_girth),
            nn.GELU(),
            nn.Linear(num_girth, num_girth),
            nn.RMSNorm(num_girth),
            nn.GELU(),
            nn.Linear(num_girth, num_hidden_size),
            nn.RMSNorm(num_hidden_size),
            nn.GELU(),
        )
        self._pos_enc = PositionalEncoding(
            sym_dmodel,
            max_len=max_prefix_len,
            device=device,
        )
        max_vocab_len = 32
        self._expr_embs = nn.Embedding(
            max_vocab_len,
            num_hidden_size,
            padding_idx=0,
            sparse=False,
        )
        max_num_domains = 80
        # The last domain id is reserved as padding (its embedding stays zero).
        self._domain_embs = nn.Embedding(
            max_num_domains,
            num_hidden_size,
            padding_idx=(max_num_domains - 1),
            sparse=False,
        )
        _trf_activation = "relu"
        _encoder_layer = nn.TransformerEncoderLayer(
            d_model=sym_dmodel,
            nhead=sym_nhead,
            batch_first=True,
            activation=_trf_activation,
            dropout=0.0,
        )
        self._tr_enc = nn.TransformerEncoder(
            _encoder_layer,
            num_layers=sym_num_layers,
        )
        ####### Manual initialization of transformer encoder layers
        # Kaiming-uniform (for ReLU) weights and zero bias on every nn.Linear
        # inside each encoder layer.
        for module in self._tr_enc.modules():
            if isinstance(module, nn.TransformerEncoderLayer):
                for submodule in module.modules():
                    if isinstance(submodule, nn.Linear):
                        nn.init.kaiming_uniform_(
                            submodule.weight,
                            nonlinearity=_trf_activation,
                        )
                        if submodule.bias is not None:
                            nn.init.constant_(submodule.bias, 0)
        self._clf = ClfProbe(sym_dmodel, aggregation=clf_agg)
        assert num_digits > 0
        self.num_digits = num_digits
        pos_digits = (num_digits // 2) + (num_digits % 2)
        neg_digits = num_digits // 2
        self.num_dbase = num_dbase
        # Powers dbase**k for k in [-neg_digits, pos_digits), shape (1, num_digits).
        # A non-persistent buffer follows model.to(device/dtype) like the
        # parameters do, but is not written to the state_dict, so the set of
        # checkpoint keys is unchanged.
        fdt = self.num_dbase ** torch.arange(-neg_digits, pos_digits)
        self.register_buffer(
            "_fdt",
            fdt.reshape(1, -1).float().to(device),
            persistent=False,
        )

    def _to_digits(self, x):
        """Expand every value into ``num_digits`` base-``num_dbase`` "digits".

        x.shape: (B, L, D) -> (B, L, D * num_digits). Each value v becomes
        sign(v) * (|v| * dbase**k mod dbase) for every power k in ``_fdt``.
        No flooring is applied.
        """
        B, L, D = x.shape
        x = x.reshape(-1, 1) @ self._fdt
        x = x.sign() * (x.abs() % self.num_dbase)
        x = x.reshape(B, L, D * self.num_digits) # (B, L, D * d)
        return x

    def _digit_norm(self, x):
        """Scale digits from (-dbase, dbase) into (-1, 1). Shape is unchanged."""
        x = x / self.num_dbase
        return x

    def encode_num(self, x, sorts=None):
        """Encode numeric inputs position-wise.

        x.shape: (B, L, BIG_D). For each index sequence in ``sorts``, the first
        differences of x reordered by it are appended. The encoder input width
        assumes ``num_sorts`` sequences of length BIG_D == num_in_size.
        Returns (B, L, sym_dmodel).
        """
        if sorts is not None:
            x = torch.concat(
                [x] + [torch.diff(x[:, :, _sort_ids], n=1, dim=-1) for _sort_ids in sorts],
                dim=-1
            ) # (B, L, D + k * (D - 1))
        x = self._to_digits(x)
        x = self._digit_norm(x)
        z = self.num_encoder(x) # (B, L, D*)
        return z

    def relation(self, z_num, expr_ids, parent_ids, domain_ids=None):
        """Classify each tree-annotated sequence.

        z_num.shape: (N, L, D); expr_ids/parent_ids.shape: (N, L).
        Adds the positional encoding, the encoding looked up by parent id, the
        expression-token embedding and, if given, one domain embedding
        (``domain_ids`` must select a single embedding; it is broadcast as
        (1, 1, D)). Returns the ClfProbe output (one logit per sequence,
        reshaped by the callers).
        """
        z_num = self._pos_enc(z_num) # (N, L, D)
        z_num = z_num + self._pos_enc.indexate(parent_ids) # (N, L, D)
        z_num = z_num + self._expr_embs(expr_ids) # (N, L, D)
        if domain_ids is not None:
            D = z_num.shape[-1]
            z_num = z_num + self._domain_embs(domain_ids).reshape(1, 1, D) # (N, L, D)
        z_num = self._tr_enc(z_num) # (N, L, D)
        clf = self._clf(z_num) # (N,)
        return clf

    def run_tests(self, z_num, expr_ids, parent_ids, domain_ids=None):
        """Score all ordered pairs of the batch.

        z_num.shape: (B, L, D); parent_ids/expr_ids.shape: (B, L).
        Entry [i, j] of the (B, B) logits classifies z_i - z_j (position 0 of
        each encoding, tiled over L) with the tree of column sample j.
        """
        B, L, D = z_num.shape
        z_num = z_num.clone()[:, 0] # (B, D)
        ####### row - col; row --| col
        z_diff = (
            z_num.unsqueeze(1)
            - z_num.unsqueeze(0)
        ) # (B, B, D)
        z_diff = z_diff.unsqueeze(2) # (B, B, 1, D)
        z_diff = z_diff.repeat(1, 1, L, 1) # (B, B, L, D)
        z_diff = z_diff.reshape(-1, L, D) # (B * B, L, D)
        ####### columns must be repeated, not rows
        parent_ids = parent_ids.unsqueeze(0).repeat(B, 1, 1) # (B, B, L)
        parent_ids = parent_ids.reshape(-1, L) # (B * B, L)
        expr_ids = expr_ids.unsqueeze(0).repeat(B, 1, 1) # (B, B, L)
        expr_ids = expr_ids.reshape(-1, L) # (B * B, L)
        tests = self.relation(
            z_diff,
            expr_ids,
            parent_ids,
            domain_ids=domain_ids,
        ).reshape(B, B) # (B, B)
        return tests

    def tgt_run_tests(self, z_num, z_tgt, expr_ids, parent_ids, domain_ids=None):
        """Score one target encoding against every sample.

        z_num.shape: (B, L, D); z_tgt.shape: (1, 1, D).
        Returns (B,) logits for z_tgt - z_num[b, 0], tiled over L.
        """
        B, L, D = z_num.shape
        z_num = z_num.clone()[:, 0] # (B, D)
        z_tgt = z_tgt.clone()[:, 0] # (1, D)
        ####### row - col; row "is preceded by" col
        z_num = (z_tgt - z_num) # (B, D)
        z_num = z_num.unsqueeze(1) # (B, 1, D)
        z_num = z_num.repeat(1, L, 1) # (B, L, D)
        tests = self.relation(
            z_num,
            expr_ids,
            parent_ids,
            domain_ids=domain_ids,
        ).reshape(B) # (B,)
        return tests

    def forward(self, x_num, expr_ids, parent_ids, domain_ids=None, sorts=None):
        """Return (B, B) pairwise probabilities.

        x_num.shape: (B, L, D); only position 0 is encoded, then tiled over L.
        """
        B, L, D = x_num.shape
        x_num = x_num.clone()[:, :1]
        z_num = self.encode_num(x_num, sorts=sorts) # (B, 1, dmodel)
        z_num = z_num.repeat(1, L, 1) # (B, L, dmodel)
        tests = self.run_tests(
            z_num,
            expr_ids,
            parent_ids,
            domain_ids=domain_ids,
        ) # (B, B)
        tests = nn.functional.sigmoid(tests) # (B, B)
        return tests

    def tgt_forward(self, x_num, x_tgt, expr_ids, parent_ids, domain_ids=None, sorts=None):
        """Return (B,) raw logits (no sigmoid) scoring x_tgt against each sample.

        x_num.shape: (B, L, D); x_tgt holds D values.
        """
        B, L, D = x_num.shape
        x_num = x_num.clone()[:, :1]
        z_num = self.encode_num(x_num, sorts=sorts) # (B, 1, dmodel)
        z_num = z_num.repeat(1, L, 1) # (B, L, dmodel)
        z_tgt = self.encode_num(x_tgt.reshape(1, 1, D), sorts=sorts) # (1, 1, dmodel)
        tests = self.tgt_run_tests(
            z_num,
            z_tgt,
            expr_ids,
            parent_ids,
            domain_ids=domain_ids,
        ) # (B,)
        return tests

    def predict(self, x_num, expr_ids, parent_ids, domain_ids=None, sorts=None, th=0.5):
        """Binarize ``forward``: (B, B) int64, 1 where probability >= th.

        x_num.shape: (B, L, D).
        """
        test = self.forward(
            x_num,
            expr_ids,
            parent_ids,
            domain_ids=domain_ids,
            sorts=sorts,
        ) # (B, B)
        test = (test >= th).long() # (B, B)
        return test

class AllPairsUniversalTreeNet(nn.Module):
    """Pairwise relation classifier over an expression tree (no domain input).

    Numeric inputs are reordered by each index sequence in ``sorts`` and the
    reordered copies are concatenated before digit-encoding. For a pair of
    samples, the difference of their position-0 encodings is broadcast over
    the tree tokens and scored by a transformer encoder plus ``ClfProbe``.
    ``domain_ids`` is accepted for interface compatibility and ignored.
    """

    def __init__(
        self,
        num_in_size=1000,
        num_girth=512,
        num_digits=21,
        num_dbase=10.0,
        sym_dmodel=512,
        sym_nhead=4,
        sym_num_layers=2,
        clf_agg="first",
        num_sorts=1,
        device="cpu",
    ):
        super().__init__()
        assert num_sorts > 0
        max_prefix_len = 32
        # One reordered copy of the num_in_size features per sort order.
        _in_size = num_digits * (num_in_size * num_sorts)
        num_hidden_size = sym_dmodel
        self.num_encoder = nn.Sequential(
            nn.Linear(_in_size, num_girth),
            nn.RMSNorm(num_girth),
            nn.GELU(),
            nn.Linear(num_girth, num_girth),
            nn.RMSNorm(num_girth),
            nn.GELU(),
            nn.Linear(num_girth, num_hidden_size),
            nn.RMSNorm(num_hidden_size),
            nn.GELU(),
        )
        self._pos_enc = PositionalEncoding(
            sym_dmodel,
            max_len=max_prefix_len,
            device=device,
        )
        max_vocab_len = 32
        self._expr_embs = nn.Embedding(
            max_vocab_len,
            num_hidden_size,
            padding_idx=0,
            sparse=False,
        )
        _trf_activation = "relu"
        _encoder_layer = nn.TransformerEncoderLayer(
            d_model=sym_dmodel,
            nhead=sym_nhead,
            batch_first=True,
            activation=_trf_activation,
            dropout=0.0,
        )
        self._tr_enc = nn.TransformerEncoder(
            _encoder_layer,
            num_layers=sym_num_layers,
        )
        ####### Manual initialization of transformer encoder layers
        # Kaiming-uniform (for ReLU) weights and zero bias on every nn.Linear
        # inside each encoder layer.
        for module in self._tr_enc.modules():
            if isinstance(module, nn.TransformerEncoderLayer):
                for submodule in module.modules():
                    if isinstance(submodule, nn.Linear):
                        nn.init.kaiming_uniform_(
                            submodule.weight,
                            nonlinearity=_trf_activation,
                        )
                        if submodule.bias is not None:
                            nn.init.constant_(submodule.bias, 0)
        self._clf = ClfProbe(sym_dmodel, aggregation=clf_agg)
        assert num_digits > 0
        self.num_digits = num_digits
        pos_digits = (num_digits // 2) + (num_digits % 2)
        neg_digits = num_digits // 2
        self.num_dbase = num_dbase
        # Powers dbase**k for k in [-neg_digits, pos_digits), shape (1, num_digits).
        # A non-persistent buffer follows model.to(device/dtype) like the
        # parameters do, but is not written to the state_dict, so the set of
        # checkpoint keys is unchanged.
        fdt = self.num_dbase ** torch.arange(-neg_digits, pos_digits)
        self.register_buffer(
            "_fdt",
            fdt.reshape(1, -1).float().to(device),
            persistent=False,
        )

    def _to_digits(self, x):
        """Expand every value into ``num_digits`` base-``num_dbase`` "digits".

        x.shape: (B, L, D) -> (B, L, D * num_digits). Each value v becomes
        sign(v) * (|v| * dbase**k mod dbase) for every power k in ``_fdt``.
        No flooring is applied.
        """
        B, L, D = x.shape
        x = x.reshape(-1, 1) @ self._fdt
        x = x.sign() * (x.abs() % self.num_dbase)
        x = x.reshape(B, L, D * self.num_digits) # (B, L, D * d)
        return x

    def _digit_norm(self, x):
        """Scale digits from (-dbase, dbase) into (-1, 1). Shape is unchanged."""
        x = x / self.num_dbase
        return x

    def encode_num(self, x, sorts=None):
        """Encode numeric inputs position-wise.

        x.shape: (B, L, BIG_D). When ``sorts`` is given, x is replaced by the
        concatenation of x[:, :, ids] for each ids in ``sorts``. The encoder
        input width assumes num_sorts sequences of num_in_size indices.
        Returns (B, L, sym_dmodel).
        """
        if sorts is not None:
            x = torch.concat(
                [x[:, :, _sort_ids] for _sort_ids in sorts],
                dim=-1
            ) # (B, L, k * D)
        x = self._to_digits(x)
        x = self._digit_norm(x)
        z = self.num_encoder(x) # (B, L, D*)
        return z

    def relation(self, z_num, expr_ids, parent_ids, domain_ids=None):
        """Classify each tree-annotated sequence.

        z_num.shape: (N, L, D); expr_ids/parent_ids.shape: (N, L).
        Adds the positional encoding, the encoding looked up by parent id and
        the expression-token embedding, then runs the transformer and
        ClfProbe. ``domain_ids`` is ignored.
        """
        z_num = self._pos_enc(z_num) # (N, L, D)
        z_num = z_num + self._pos_enc.indexate(parent_ids) # (N, L, D)
        z_num = z_num + self._expr_embs(expr_ids) # (N, L, D)
        z_num = self._tr_enc(z_num) # (N, L, D)
        clf = self._clf(z_num) # (N,)
        return clf

    def run_tests(self, z_num, expr_ids, parent_ids, domain_ids=None):
        """Score all ordered pairs of the batch.

        z_num.shape: (B, L, D); parent_ids/expr_ids.shape: (B, L).
        Entry [i, j] of the (B, B) logits classifies z_i - z_j (position 0 of
        each encoding, tiled over L) with the tree of column sample j.
        """
        B, L, D = z_num.shape
        z_num = z_num.clone()[:, 0] # (B, D)
        ####### row - col; row --| col
        z_diff = (
            z_num.unsqueeze(1)
            - z_num.unsqueeze(0)
        ) # (B, B, D)
        z_diff = z_diff.unsqueeze(2) # (B, B, 1, D)
        z_diff = z_diff.repeat(1, 1, L, 1) # (B, B, L, D)
        z_diff = z_diff.reshape(-1, L, D) # (B * B, L, D)
        ####### columns must be repeated, not rows
        parent_ids = parent_ids.unsqueeze(0).repeat(B, 1, 1) # (B, B, L)
        parent_ids = parent_ids.reshape(-1, L) # (B * B, L)
        expr_ids = expr_ids.unsqueeze(0).repeat(B, 1, 1) # (B, B, L)
        expr_ids = expr_ids.reshape(-1, L) # (B * B, L)
        tests = self.relation(
            z_diff,
            expr_ids,
            parent_ids,
        ).reshape(B, B) # (B, B)
        return tests

    def tgt_run_tests(self, z_num, z_tgt, expr_ids, parent_ids, domain_ids=None):
        """Score one target encoding against every sample.

        z_num.shape: (B, L, D); z_tgt.shape: (1, 1, D).
        Returns (B,) logits for z_tgt - z_num[b, 0], tiled over L.
        """
        B, L, D = z_num.shape
        z_num = z_num.clone()[:, 0] # (B, D)
        z_tgt = z_tgt.clone()[:, 0] # (1, D)
        ####### row - col; row "is preceded by" col
        z_num = (z_tgt - z_num) # (B, D)
        z_num = z_num.unsqueeze(1) # (B, 1, D)
        z_num = z_num.repeat(1, L, 1) # (B, L, D)
        tests = self.relation(
            z_num,
            expr_ids,
            parent_ids,
        ).reshape(B) # (B,)
        return tests

    def forward(self, x_num, expr_ids, parent_ids, sorts=None, domain_ids=None):
        """Return (B, B) pairwise probabilities.

        x_num.shape: (B, L, D).
        """
        B, L, D = x_num.shape
        # run_tests only reads position 0 of each encoding, so encode just that
        # position and tile it over L instead of encoding all L positions.
        x_num = x_num.clone()[:, :1]
        z_num = self.encode_num(x_num, sorts=sorts) # (B, 1, dmodel)
        z_num = z_num.repeat(1, L, 1) # (B, L, dmodel)
        tests = self.run_tests(
            z_num,
            expr_ids,
            parent_ids,
        ) # (B, B)
        tests = nn.functional.sigmoid(tests) # (B, B)
        return tests

    def tgt_forward(self, x_num, x_tgt, expr_ids, parent_ids, sorts=None, domain_ids=None):
        """Return (B,) raw logits (no sigmoid) scoring x_tgt against each sample.

        x_num.shape: (B, L, D); x_tgt holds D values.
        """
        B, L, D = x_num.shape
        # tgt_run_tests only reads position 0, so only that position is encoded.
        x_num = x_num.clone()[:, :1]
        z_num = self.encode_num(x_num, sorts=sorts) # (B, 1, dmodel)
        z_num = z_num.repeat(1, L, 1) # (B, L, dmodel)
        z_tgt = self.encode_num(x_tgt.reshape(1, 1, D), sorts=sorts) # (1, 1, dmodel)
        tests = self.tgt_run_tests(
            z_num,
            z_tgt,
            expr_ids,
            parent_ids,
        ) # (B,)
        return tests

    def predict(self, x_num, expr_ids, parent_ids, sorts=None, th=0.5, domain_ids=None):
        """Binarize ``forward``: (B, B) int64, 1 where probability >= th.

        x_num.shape: (B, L, D).
        """
        test = self.forward(
            x_num,
            expr_ids,
            parent_ids,
            sorts=sorts,
        ) # (B, B)
        test = (test >= th).long() # (B, B)
        return test

def _bin_and_average(x: torch.Tensor, idx: torch.Tensor, K: int):
    """
    x   : (*, D)
    idx : (D,)  inteiros de 0 .. K-1
    K   : número de bins de saída
    return: tensor (*, K) com médias
    """
    *batch_dims, D = x.shape                # tamanhos antes da última dimensão
    device = x.device
    # 1) Soma por bin
    out_shape = (*batch_dims, K)
    sums = torch.zeros(out_shape, device=device, dtype=x.dtype)
    # expand idx para ter mesmo número de dimensões que x
    expand_idx = idx.view((1,)*len(batch_dims) + (-1,))   # shape -> (1,1,...,D)
    expand_idx = expand_idx.expand(*batch_dims, D)        # agora (*, D)
    sums.scatter_add_(dim=-1, index=expand_idx, src=x)
    # 2) Contagem por bin
    ones = torch.ones_like(x)
    counts = torch.zeros(out_shape, device=device, dtype=x.dtype)
    counts.scatter_add_(dim=-1, index=expand_idx, src=ones)
    # 3) Média: soma / contagem   (evitando divisão por zero)
    # Onde count==0 queremos 0.  Usamos where para não produzir NaNs/Infs.
    mean = torch.where(counts > 0, sums / counts, torch.zeros_like(sums))
    return mean

class AllPairsBinnedTreeNet(nn.Module):
    """Pairwise relation classifier over an expression tree with binned inputs.

    Numeric inputs are averaged into ``num_bins`` bins (``sorts`` gives the bin
    of each feature) before digit-encoding. For a pair of samples, the
    difference of their position-0 encodings is broadcast over the tree
    tokens and scored by a transformer encoder plus ``ClfProbe``.
    ``num_in_size`` is unused and ``num_sorts`` is only validated (> 0); both
    are kept for interface compatibility. ``domain_ids`` is ignored.
    """

    def __init__(
        self,
        num_in_size=1000,
        num_girth=512,
        num_digits=21,
        num_dbase=10.0,
        sym_dmodel=512,
        sym_nhead=4,
        sym_num_layers=2,
        clf_agg="first",
        num_sorts=1,
        num_bins=1000,
        device="cpu",
    ):
        super().__init__()
        assert num_sorts > 0
        max_prefix_len = 32
        # The encoder sees one value per bin, each expanded into num_digits digits.
        _in_size = num_digits * num_bins
        num_hidden_size = sym_dmodel
        self._num_bins = num_bins
        self.num_encoder = nn.Sequential(
            nn.Linear(_in_size, num_girth),
            nn.RMSNorm(num_girth),
            nn.GELU(),
            nn.Linear(num_girth, num_girth),
            nn.RMSNorm(num_girth),
            nn.GELU(),
            nn.Linear(num_girth, num_hidden_size),
            nn.RMSNorm(num_hidden_size),
            nn.GELU(),
        )
        self._pos_enc = PositionalEncoding(
            sym_dmodel,
            max_len=max_prefix_len,
            device=device,
        )
        max_vocab_len = 32
        self._expr_embs = nn.Embedding(
            max_vocab_len,
            num_hidden_size,
            padding_idx=0,
            sparse=False,
        )
        _trf_activation = "relu"
        _encoder_layer = nn.TransformerEncoderLayer(
            d_model=sym_dmodel,
            nhead=sym_nhead,
            batch_first=True,
            activation=_trf_activation,
            dropout=0.0,
        )
        self._tr_enc = nn.TransformerEncoder(
            _encoder_layer,
            num_layers=sym_num_layers,
        )
        ####### Manual initialization of transformer encoder layers
        # Kaiming-uniform (for ReLU) weights and zero bias on every nn.Linear
        # inside each encoder layer.
        for module in self._tr_enc.modules():
            if isinstance(module, nn.TransformerEncoderLayer):
                for submodule in module.modules():
                    if isinstance(submodule, nn.Linear):
                        nn.init.kaiming_uniform_(
                            submodule.weight,
                            nonlinearity=_trf_activation,
                        )
                        if submodule.bias is not None:
                            nn.init.constant_(submodule.bias, 0)
        self._clf = ClfProbe(sym_dmodel, aggregation=clf_agg)
        assert num_digits > 0
        self.num_digits = num_digits
        pos_digits = (num_digits // 2) + (num_digits % 2)
        neg_digits = num_digits // 2
        self.num_dbase = num_dbase
        # Powers dbase**k for k in [-neg_digits, pos_digits), shape (1, num_digits).
        # A non-persistent buffer follows model.to(device/dtype) like the
        # parameters do, but is not written to the state_dict, so the set of
        # checkpoint keys is unchanged.
        fdt = self.num_dbase ** torch.arange(-neg_digits, pos_digits)
        self.register_buffer(
            "_fdt",
            fdt.reshape(1, -1).float().to(device),
            persistent=False,
        )

    def _to_digits(self, x):
        """Expand every value into ``num_digits`` base-``num_dbase`` "digits".

        x.shape: (B, L, D) -> (B, L, D * num_digits). Each value v becomes
        sign(v) * (|v| * dbase**k mod dbase) for every power k in ``_fdt``.
        No flooring is applied.
        """
        B, L, D = x.shape
        x = x.reshape(-1, 1) @ self._fdt
        x = x.sign() * (x.abs() % self.num_dbase)
        x = x.reshape(B, L, D * self.num_digits) # (B, L, D * d)
        return x

    def _digit_norm(self, x):
        """Scale digits from (-dbase, dbase) into (-1, 1). Shape is unchanged."""
        x = x / self.num_dbase
        return x

    def encode_num(self, x, sorts=None):
        """Encode numeric inputs position-wise.

        x.shape: (B, L, BIG_D). ``sorts`` (a.k.a. bins) maps each of the BIG_D
        features to a bin in [0, num_bins); features sharing a bin are
        averaged. If ``sorts`` is None, x is encoded as is, so BIG_D must then
        equal num_bins. Returns (B, L, sym_dmodel).
        """
        if sorts is not None:
            x = _bin_and_average(x, sorts, self._num_bins) # (B, L, num_bins)
        x = self._to_digits(x)
        x = self._digit_norm(x)
        z = self.num_encoder(x) # (B, L, D*)
        return z

    def relation(self, z_num, expr_ids, parent_ids, domain_ids=None):
        """Classify each tree-annotated sequence.

        z_num.shape: (N, L, D); expr_ids/parent_ids.shape: (N, L).
        Adds the positional encoding, the encoding looked up by parent id and
        the expression-token embedding, then runs the transformer and
        ClfProbe. ``domain_ids`` is ignored.
        """
        z_num = self._pos_enc(z_num) # (N, L, D)
        z_num = z_num + self._pos_enc.indexate(parent_ids) # (N, L, D)
        z_num = z_num + self._expr_embs(expr_ids) # (N, L, D)
        z_num = self._tr_enc(z_num) # (N, L, D)
        clf = self._clf(z_num) # (N,)
        return clf

    def run_tests(self, z_num, expr_ids, parent_ids, domain_ids=None):
        """Score all ordered pairs of the batch.

        z_num.shape: (B, L, D); parent_ids/expr_ids.shape: (B, L).
        Entry [i, j] of the (B, B) logits classifies z_i - z_j (position 0 of
        each encoding, tiled over L) with the tree of column sample j.
        """
        B, L, D = z_num.shape
        z_num = z_num.clone()[:, 0] # (B, D)
        ####### row - col; row --| col
        z_diff = (
            z_num.unsqueeze(1)
            - z_num.unsqueeze(0)
        ) # (B, B, D)
        z_diff = z_diff.unsqueeze(2) # (B, B, 1, D)
        z_diff = z_diff.repeat(1, 1, L, 1) # (B, B, L, D)
        z_diff = z_diff.reshape(-1, L, D) # (B * B, L, D)
        ####### columns must be repeated, not rows
        parent_ids = parent_ids.unsqueeze(0).repeat(B, 1, 1) # (B, B, L)
        parent_ids = parent_ids.reshape(-1, L) # (B * B, L)
        expr_ids = expr_ids.unsqueeze(0).repeat(B, 1, 1) # (B, B, L)
        expr_ids = expr_ids.reshape(-1, L) # (B * B, L)
        tests = self.relation(
            z_diff,
            expr_ids,
            parent_ids,
        ).reshape(B, B) # (B, B)
        return tests

    def tgt_run_tests(self, z_num, z_tgt, expr_ids, parent_ids, domain_ids=None):
        """Score one target encoding against every sample.

        z_num.shape: (B, L, D); z_tgt.shape: (1, 1, D).
        Returns (B,) logits for z_tgt - z_num[b, 0], tiled over L.
        """
        B, L, D = z_num.shape
        z_num = z_num.clone()[:, 0] # (B, D)
        z_tgt = z_tgt.clone()[:, 0] # (1, D)
        ####### row - col; row "is preceded by" col
        z_num = (z_tgt - z_num) # (B, D)
        z_num = z_num.unsqueeze(1) # (B, 1, D)
        z_num = z_num.repeat(1, L, 1) # (B, L, D)
        tests = self.relation(
            z_num,
            expr_ids,
            parent_ids,
        ).reshape(B) # (B,)
        return tests

    def forward(self, x_num, expr_ids, parent_ids, sorts=None, domain_ids=None):
        """Return (B, B) pairwise probabilities.

        x_num.shape: (B, L, D); only position 0 is encoded, then tiled over L.
        """
        B, L, D = x_num.shape
        x_num = x_num.clone()[:, :1]
        z_num = self.encode_num(x_num, sorts=sorts) # (B, 1, dmodel)
        z_num = z_num.repeat(1, L, 1) # (B, L, dmodel)
        tests = self.run_tests(
            z_num,
            expr_ids,
            parent_ids,
        ) # (B, B)
        tests = nn.functional.sigmoid(tests) # (B, B)
        return tests

    def tgt_forward(self, x_num, x_tgt, expr_ids, parent_ids, sorts=None, domain_ids=None):
        """Return (B,) raw logits (no sigmoid) scoring x_tgt against each sample.

        x_num.shape: (B, L, D); x_tgt holds D values.
        """
        B, L, D = x_num.shape
        x_num = x_num.clone()[:, :1]
        z_num = self.encode_num(x_num, sorts=sorts) # (B, 1, dmodel)
        z_num = z_num.repeat(1, L, 1) # (B, L, dmodel)
        z_tgt = self.encode_num(x_tgt.reshape(1, 1, D), sorts=sorts) # (1, 1, dmodel)
        tests = self.tgt_run_tests(
            z_num,
            z_tgt,
            expr_ids,
            parent_ids,
        ) # (B,)
        return tests

    def predict(self, x_num, expr_ids, parent_ids, sorts=None, th=0.5, domain_ids=None):
        """Binarize ``forward``: (B, B) int64, 1 where probability >= th.

        x_num.shape: (B, L, D).
        """
        test = self.forward(
            x_num,
            expr_ids,
            parent_ids,
            sorts=sorts,
        ) # (B, B)
        test = (test >= th).long() # (B, B)
        return test

class AllPairsExpertsTreeNet(nn.Module):
    """Pairwise relation classifier with one transformer "expert" per task.

    A shared digit-MLP encodes numeric inputs, optionally augmented with
    first differences (``use_diffs``) or differences along fixed sort orders
    (``sorts``). Every method takes an ``expert_id`` that selects which
    transformer+``ClfProbe`` head from ``self._tr_enc`` scores the pairs.
    """

    def __init__(
        self,
        num_in_size=1000,
        num_girth=512,
        num_digits=21,
        num_dbase=10.0,
        sym_dmodel=512,
        sym_nhead=4,
        sym_num_layers=2,
        clf_agg="mean",
        use_diffs=False,
        sorts=None,
        num_experts=32,
        device="cpu",
    ):
        super().__init__()
        max_prefix_len = 32
        _in_size = num_digits * (num_in_size + 0)
        # Extra encoder width for the difference features appended in encode_num.
        if use_diffs:
            _in_size += num_digits * (num_in_size - 1)
        elif sorts is not None:
            _in_size += num_digits * len(sorts) * (num_in_size - 1)
        self._use_diffs = use_diffs
        self._sorts = sorts
        num_hidden_size = sym_dmodel
        self.num_encoder = nn.Sequential(
            nn.Linear(_in_size, num_girth),
            nn.RMSNorm(num_girth),
            nn.GELU(),
            nn.Linear(num_girth, num_girth),
            nn.RMSNorm(num_girth),
            nn.GELU(),
            nn.Linear(num_girth, num_hidden_size),
            nn.RMSNorm(num_hidden_size),
            nn.GELU(),
        )
        self._pos_enc = PositionalEncoding(
            sym_dmodel,
            max_len=max_prefix_len,
            device=device,
        )
        max_vocab_len = 32
        self._expr_embs = nn.Embedding(
            max_vocab_len,
            num_hidden_size,
            padding_idx=0,
            sparse=False,
        )
        _trf_activation = "relu"
        experts = []
        for _ in range(num_experts):
            tr_expert = nn.Sequential(
                nn.TransformerEncoder(
                    nn.TransformerEncoderLayer(
                        d_model=sym_dmodel,
                        nhead=sym_nhead,
                        batch_first=True,
                        activation=_trf_activation,
                        dropout=0.0,
                    ),
                    num_layers=sym_num_layers,
                ),
                ClfProbe(sym_dmodel, aggregation=clf_agg),
            )
            experts.append(tr_expert)
        self._tr_enc = nn.ModuleList(experts)
        ####### Manual initialization of transformer encoder layers
        # Kaiming-uniform (for ReLU) weights and zero bias on every nn.Linear
        # inside each encoder layer. modules() already walks all descendants,
        # so each layer is visited exactly once.
        for module in self._tr_enc.modules():
            if isinstance(module, nn.TransformerEncoderLayer):
                for submodule in module.modules():
                    if isinstance(submodule, nn.Linear):
                        nn.init.kaiming_uniform_(
                            submodule.weight,
                            nonlinearity=_trf_activation,
                        )
                        if submodule.bias is not None:
                            nn.init.constant_(submodule.bias, 0)
        assert num_digits > 0
        self.num_digits = num_digits
        pos_digits = (num_digits // 2) + (num_digits % 2)
        neg_digits = num_digits // 2
        self.num_dbase = num_dbase
        # Powers dbase**k for k in [-neg_digits, pos_digits), shape (1, num_digits).
        # A non-persistent buffer follows model.to(device/dtype) like the
        # parameters do, but is not written to the state_dict, so the set of
        # checkpoint keys is unchanged.
        fdt = self.num_dbase ** torch.arange(-neg_digits, pos_digits)
        self.register_buffer(
            "_fdt",
            fdt.reshape(1, -1).float().to(device),
            persistent=False,
        )

    def _to_digits(self, x):
        """Expand every value into ``num_digits`` base-``num_dbase`` "digits".

        x.shape: (B, L, D) -> (B, L, D * num_digits). Each value v becomes
        sign(v) * (|v| * dbase**k mod dbase) for every power k in ``_fdt``.
        No flooring is applied.
        """
        B, L, D = x.shape
        x = x.reshape(-1, 1) @ self._fdt
        x = x.sign() * (x.abs() % self.num_dbase)
        x = x.reshape(B, L, D * self.num_digits) # (B, L, D * d)
        return x

    def _digit_norm(self, x):
        """Scale digits from (-dbase, dbase) into (-1, 1). Shape is unchanged."""
        x = x / self.num_dbase
        return x

    def encode_num(self, x):
        """Encode numeric inputs position-wise.

        x.shape: (B, L, BIG_D). With ``use_diffs`` the first differences of x
        are appended. Otherwise, with ``sorts``, the first differences of x
        reordered by each index sequence are appended.
        Returns (B, L, sym_dmodel).
        """
        if self._use_diffs:
            x = torch.concat(
                [x, torch.diff(x, n=1, dim=-1)],
                dim=-1,
            ) # (B, L, D + (D - 1))
        elif self._sorts is not None:
            x = torch.concat(
                [x] + [torch.diff(x[:, :, _sort_ids], n=1, dim=-1) for _sort_ids in self._sorts],
                dim=-1
            ) # (B, L, D + k * (D - 1))
        x = self._to_digits(x)
        x = self._digit_norm(x)
        z = self.num_encoder(x) # (B, L, D*)
        return z

    def relation(self, z_num, expr_ids, parent_ids, expert_id):
        """Classify each tree-annotated sequence with the selected expert.

        z_num.shape: (N, L, D); expr_ids/parent_ids.shape: (N, L);
        expert_id: integer index into ``self._tr_enc``.
        """
        z_num = self._pos_enc(z_num) # (N, L, D)
        z_num = z_num + self._pos_enc.indexate(parent_ids) # (N, L, D)
        z_num = z_num + self._expr_embs(expr_ids) # (N, L, D)
        ####### Use expert from "expert_id"
        clf = self._tr_enc[expert_id](z_num) # (N,)
        return clf

    def run_tests(self, z_num, expr_ids, parent_ids, expert_id):
        """Score all ordered pairs of the batch with one expert.

        z_num.shape: (B, L, D); parent_ids/expr_ids.shape: (B, L).
        Entry [i, j] of the (B, B) logits classifies z_i - z_j (position 0 of
        each encoding, tiled over L) with the tree of column sample j.
        """
        B, L, D = z_num.shape
        z_num = z_num.clone()[:, 0] # (B, D)
        ####### row - col; row --| col
        z_diff = (
            z_num.unsqueeze(1)
            - z_num.unsqueeze(0)
        ) # (B, B, D)
        z_diff = z_diff.unsqueeze(2) # (B, B, 1, D)
        z_diff = z_diff.repeat(1, 1, L, 1) # (B, B, L, D)
        z_diff = z_diff.reshape(-1, L, D) # (B * B, L, D)
        ####### columns must be repeated, not rows
        parent_ids = parent_ids.unsqueeze(0).repeat(B, 1, 1) # (B, B, L)
        parent_ids = parent_ids.reshape(-1, L) # (B * B, L)
        expr_ids = expr_ids.unsqueeze(0).repeat(B, 1, 1) # (B, B, L)
        expr_ids = expr_ids.reshape(-1, L) # (B * B, L)
        tests = self.relation(
            z_diff,
            expr_ids,
            parent_ids,
            expert_id,
        ).reshape(B, B) # (B, B)
        return tests

    def tgt_run_tests(self, z_num, z_tgt, expr_ids, parent_ids, expert_id):
        """Score one target encoding against every sample with one expert.

        z_num.shape: (B, L, D); z_tgt.shape: (1, 1, D).
        Returns (B,) logits for z_tgt - z_num[b, 0], tiled over L.
        """
        B, L, D = z_num.shape
        z_num = z_num.clone()[:, 0] # (B, D)
        z_tgt = z_tgt.clone()[:, 0] # (1, D)
        ####### row - col; row "is preceded by" col
        z_num = (z_tgt - z_num) # (B, D)
        z_num = z_num.unsqueeze(1) # (B, 1, D)
        z_num = z_num.repeat(1, L, 1) # (B, L, D)
        tests = self.relation(
            z_num,
            expr_ids,
            parent_ids,
            expert_id,
        ).reshape(B) # (B,)
        return tests

    def forward(self, x_num, expr_ids, parent_ids, expert_id):
        """Return (B, B) pairwise probabilities from expert ``expert_id``.

        x_num.shape: (B, L, D); only position 0 is encoded, then tiled over L.
        """
        B, L, D = x_num.shape
        x_num = x_num.clone()[:, :1]
        z_num = self.encode_num(x_num) # (B, 1, dmodel)
        z_num = z_num.repeat(1, L, 1) # (B, L, dmodel)
        tests = self.run_tests(
            z_num,
            expr_ids,
            parent_ids,
            expert_id,
        ) # (B, B)
        tests = nn.functional.sigmoid(tests) # (B, B)
        return tests

    def tgt_forward(self, x_num, x_tgt, expr_ids, parent_ids, expert_id):
        """Return (B,) raw logits (no sigmoid) scoring x_tgt against each sample.

        x_num.shape: (B, L, D); x_tgt holds D values.
        """
        B, L, D = x_num.shape
        x_num = x_num.clone()[:, :1]
        z_num = self.encode_num(x_num) # (B, 1, dmodel)
        z_num = z_num.repeat(1, L, 1) # (B, L, dmodel)
        z_tgt = self.encode_num(x_tgt.reshape(1, 1, D)) # (1, 1, dmodel)
        tests = self.tgt_run_tests(
            z_num,
            z_tgt,
            expr_ids,
            parent_ids,
            expert_id,
        ) # (B,)
        return tests

    def predict(self, x_num, expr_ids, parent_ids, expert_id, th=0.5):
        """Binarize ``forward``: (B, B) int64, 1 where probability >= th.

        x_num.shape: (B, L, D).
        """
        test = self.forward(
            x_num,
            expr_ids,
            parent_ids,
            expert_id,
        ) # (B, B)
        test = (test >= th).long() # (B, B)
        return test

class AllPairsConvTreeNet(nn.Module):
    """Pairwise relation scorer: digit-expansion MLP + 1-D convolutional stack.

    Each scalar of ``x_num`` is expanded into ``num_digits`` features
    (``_to_digits``) and embedded by ``num_encoder``. Every pair of batch
    items is then compared by subtraction (``run_tests``). Each resulting
    difference sequence is scored by ``relation``: positional encoding, six
    Conv1d(kernel_size=3) + GELU layers, then ``ClfProbe``.

    Args:
        num_in_size: scalars per position. It must equal the last dimension
            of ``x_num``, because ``_to_digits`` yields ``D * num_digits``
            features and the first Linear expects ``num_digits * num_in_size``.
        num_girth: hidden width of the numeric MLP.
        num_digits: features generated per scalar; must be > 0.
        num_dbase: base of the digit expansion.
        sym_dmodel: embedding width used by the encoder output, the conv
            stack and the probe.
        clf_agg: ``aggregation`` argument forwarded to ``ClfProbe``.
        device: device the digit-scale buffer ``_fdt`` is created on.
        max_prefix_len: ``max_len`` forwarded to ``PositionalEncoding``
            (default 32).
    """

    def __init__(
        self,
        num_in_size=1000,
        num_girth=512,
        num_digits=21,
        num_dbase=10.0,
        sym_dmodel=512,
        clf_agg="first",
        device="cpu",
        max_prefix_len=32,
    ):
        super().__init__()
        assert num_digits > 0
        _in_size = num_digits * num_in_size
        num_hidden_size = sym_dmodel
        self.num_encoder = nn.Sequential(
            nn.Linear(_in_size, num_girth),
            nn.RMSNorm(num_girth),
            nn.GELU(),
            nn.Linear(num_girth, num_girth),
            nn.RMSNorm(num_girth),
            nn.GELU(),
            nn.Linear(num_girth, num_hidden_size),
            nn.RMSNorm(num_hidden_size),
            nn.GELU(),
        )
        self._pos_enc = PositionalEncoding(
            num_hidden_size,
            max_len=max_prefix_len,
            device=device,
        )
        # Six Conv1d + GELU pairs. With kernel_size=3 and padding=1 (stride 1)
        # the sequence length L is preserved. The Conv1d layers sit at even
        # indices (state_dict keys _conv_enc.0, _conv_enc.2, ..., _conv_enc.10).
        conv_layers = []
        for _ in range(6):
            conv_layers.append(
                nn.Conv1d(
                    num_hidden_size, num_hidden_size, kernel_size=3, padding=1
                )
            )
            conv_layers.append(nn.GELU())
        self._conv_enc = nn.Sequential(*conv_layers)
        self._clf = ClfProbe(num_hidden_size, aggregation=clf_agg)
        self.num_digits = num_digits
        pos_digits = (num_digits // 2) + (num_digits % 2)
        neg_digits = num_digits // 2
        self.num_dbase = num_dbase
        # Scales num_dbase ** k for k in [-neg_digits, pos_digits); shape (1, num_digits).
        fdt = self.num_dbase ** torch.arange(-neg_digits, pos_digits)
        fdt = fdt.reshape(1, -1).float().to(device)
        # Non-persistent buffer: it follows model.to()/.cuda() together with
        # the parameters but is excluded from state_dict, so existing
        # checkpoints load unchanged.
        self.register_buffer("_fdt", fdt, persistent=False)

    def _to_digits(self, x):
        """Expand every scalar into ``num_digits`` scaled-remainder features.

        For each scale ``s`` in ``self._fdt``, a scalar ``v`` becomes
        ``sign(v * s) * (|v * s| mod num_dbase)``. The outputs are generally
        not integers: e.g. with base 10 and ``s = 0.1``, 12.0 maps to 1.2.

        Args:
            x: tensor of shape (B, L, D).

        Returns:
            Tensor of shape (B, L, D * num_digits).
        """
        B, L, D = x.shape
        x = x.reshape(-1, 1) @ self._fdt  # (B * L * D, num_digits)
        x = x.sign() * (x.abs() % self.num_dbase)
        x = x.reshape(B, L, D * self.num_digits)  # (B, L, D * d)
        return x

    def _digit_norm(self, x):
        """Divide digit features by ``num_dbase``; shape is unchanged."""
        x = x / self.num_dbase
        return x

    def encode_num(self, x):
        """Embed raw numbers.

        Args:
            x: tensor of shape (B, L, num_in_size).

        Returns:
            Tensor of shape (B, L, sym_dmodel).
        """
        x = self._to_digits(x)
        x = self._digit_norm(x)
        z = self.num_encoder(x)  # (B, L, D)
        return z

    def relation(self, z_num, parent_ids=None):
        """Score a batch of difference sequences.

        Args:
            z_num: tensor of shape (B, L, D).
            parent_ids: optional (B, L) tensor. When given,
                ``self._pos_enc.indexate(parent_ids)`` is added to the
                positionally encoded input (``indexate`` is defined in
                ``PositionalEncoding``, outside this class).

        Returns:
            ``ClfProbe`` output, annotated as (B,).
        """
        z_num = self._pos_enc(z_num)  # (B, L, D)
        if parent_ids is not None:
            z_num = z_num + self._pos_enc.indexate(parent_ids)  # (B, L, D)
        z_num = z_num.permute(0, 2, 1)  # (B, D, L) -- Conv1d is channels-first
        z_num = self._conv_enc(z_num)  # (B, D, L)
        z_num = z_num.permute(0, 2, 1)  # (B, L, D)
        clf = self._clf(z_num)  # (B,)
        return clf

    def run_tests(self, z_num, parent_ids=None):
        """Score all ordered pairs of batch items.

        Pair ``(i, j)`` is the first embedding of item ``i`` (broadcast over
        the ``L`` positions) minus the whole sequence of item ``j``.

        Args:
            z_num: tensor of shape (B, L, D).
            parent_ids: optional (B, L) tensor. Pair ``(i, j)`` uses the
                parent ids of item ``j``.

        Returns:
            Tensor of shape (B, B) with raw ``relation`` scores.
        """
        B, L, D = z_num.shape
        # (B, 1, 1, D) - (1, B, L, D) broadcasts to (B, B, L, D).
        pairs = z_num[:, :1].unsqueeze(1) - z_num.unsqueeze(0)
        pairs = pairs.reshape(-1, L, D)  # (B * B, L, D)
        if parent_ids is not None:
            parent_ids = parent_ids.unsqueeze(0).repeat(B, 1, 1)  # (B, B, L)
            parent_ids = parent_ids.reshape(-1, L)  # (B * B, L)
        tests = self.relation(pairs, parent_ids=parent_ids)
        return tests.reshape(B, B)

    def tgt_run_tests(self, z_num, z_tgt, parent_ids=None):
        """Score every batch item against one target embedding.

        Args:
            z_num: tensor of shape (B, L, D).
            z_tgt: tensor of shape (1, 1, D), broadcast over batch and length.
            parent_ids: optional (B, L) tensor forwarded to ``relation``.

        Returns:
            Tensor of shape (B,) with raw ``relation`` scores.
        """
        z_num = z_tgt - z_num  # (B, L, D)
        B = z_num.shape[0]
        tests = self.relation(z_num, parent_ids=parent_ids).reshape(B)  # (B,)
        return tests

    def forward(self, x_num, parent_ids=None):
        """Return sigmoid pairwise scores of shape (B, B) for ``x_num`` (B, L, D)."""
        z_num = self.encode_num(x_num)  # (B, L, dmodel)
        tests = self.run_tests(z_num, parent_ids=parent_ids)  # (B, B)
        tests = nn.functional.sigmoid(tests)  # (B, B)
        return tests

    def tgt_forward(self, x_num, x_tgt, parent_ids=None):
        """Score each batch item against one target vector.

        Args:
            x_num: tensor of shape (B, L, D).
            x_tgt: target holding ``D`` elements, e.g. shape (D,) or (1, D).
                It is flattened to (1, 1, D).
            parent_ids: optional (B, L) tensor forwarded to ``relation``.

        Returns:
            Raw (B,) scores. Unlike ``forward``, no sigmoid is applied.
        """
        z_num = self.encode_num(x_num)  # (B, L, dmodel)
        # reshape(1, 1, -1) accepts (D,) as before, and also (1, D) or
        # (1, 1, D), for which shape[0] is not D.
        z_tgt = self.encode_num(x_tgt.reshape(1, 1, -1))  # (1, 1, dmodel)
        tests = self.tgt_run_tests(z_num, z_tgt, parent_ids=parent_ids)  # (B,)
        return tests

    def predict(self, x_num, th=0.5, parent_ids=None):
        """Threshold ``forward`` probabilities into hard labels.

        Args:
            x_num: tensor of shape (B, L, D).
            th: probabilities ``>= th`` become 1, all others 0.
            parent_ids: optional (B, L) tensor forwarded to ``forward``.

        Returns:
            Long tensor of shape (B, B).
        """
        test = self.forward(x_num, parent_ids=parent_ids)  # (B, B)
        test = (test >= th).long()  # (B, B)
        return test

class AllPairsCatTreeNet(nn.Module):
    """Pairwise relation scorer built on concatenation + a transformer encoder.

    Each scalar of ``x_num`` is expanded into ``num_digits`` features
    (``_to_digits``) and embedded by ``num_encoder``. For every ordered pair
    ``(i, j)``, the first embedding of item ``i`` is prepended to the whole
    sequence of item ``j`` (``run_tests``). The resulting ``1 + L`` sequence
    is scored by ``relation``: positional encoding, a transformer encoder,
    then ``ClfProbe``.

    Args:
        num_in_size: scalars per position. It must equal the last dimension
            of ``x_num``, because ``_to_digits`` yields ``D * num_digits``
            features and the first Linear expects ``num_digits * num_in_size``.
        num_girth: hidden width of the numeric MLP.
        num_digits: features generated per scalar; must be > 0.
        num_dbase: base of the digit expansion.
        sym_dmodel: embedding width of the encoder output and the transformer.
        sym_nhead: attention heads per transformer layer.
        sym_num_layers: number of transformer encoder layers.
        clf_agg: ``aggregation`` argument forwarded to ``ClfProbe``.
        device: device the digit-scale buffer ``_fdt`` is created on.
        max_prefix_len: ``max_len`` forwarded to ``PositionalEncoding``
            (default 32). NOTE(review): ``relation`` sees sequences of length
            ``1 + L``; presumably that must not exceed this value -- confirm
            against ``PositionalEncoding``.
    """

    def __init__(
        self,
        num_in_size=1000,
        num_girth=512,
        num_digits=21,
        num_dbase=10.0,
        sym_dmodel=512,
        sym_nhead=4,
        sym_num_layers=2,
        clf_agg="last",
        device="cpu",
        max_prefix_len=32,
    ):
        super().__init__()
        assert num_digits > 0
        _in_size = num_digits * num_in_size
        num_hidden_size = sym_dmodel
        self.num_encoder = nn.Sequential(
            nn.Linear(_in_size, num_girth),
            nn.RMSNorm(num_girth),
            nn.GELU(),
            nn.Linear(num_girth, num_hidden_size),
            nn.RMSNorm(num_hidden_size),
            nn.GELU(),
        )
        _encoder_layer = nn.TransformerEncoderLayer(
            d_model=sym_dmodel,
            nhead=sym_nhead,
            batch_first=True,
            activation="relu",
            dropout=0.0,
        )
        _trf_encoder = nn.TransformerEncoder(
            _encoder_layer,
            num_layers=sym_num_layers,
        )
        # Re-initialise every nn.Linear inside the encoder layers with
        # Kaiming-uniform weights (ReLU gain) and zero biases, to match the
        # ReLU activation above. NOTE(review): MultiheadAttention's packed
        # input projection is presumably a bare Parameter rather than an
        # nn.Linear, so it would be left at its default init -- confirm.
        for module in _trf_encoder.modules():
            if isinstance(module, nn.TransformerEncoderLayer):
                for submodule in module.modules():
                    if isinstance(submodule, nn.Linear):
                        nn.init.kaiming_uniform_(submodule.weight, nonlinearity="relu")
                        if submodule.bias is not None:
                            nn.init.constant_(submodule.bias, 0)
        self.relation = nn.Sequential(
            PositionalEncoding(
                sym_dmodel,
                max_len=max_prefix_len,
                device=device,
            ),
            _trf_encoder,
            ClfProbe(sym_dmodel, aggregation=clf_agg),
        )
        self.num_digits = num_digits
        pos_digits = (num_digits // 2) + (num_digits % 2)
        neg_digits = num_digits // 2
        self.num_dbase = num_dbase
        # Scales num_dbase ** k for k in [-neg_digits, pos_digits); shape (1, num_digits).
        fdt = self.num_dbase ** torch.arange(-neg_digits, pos_digits)
        fdt = fdt.reshape(1, -1).float().to(device)
        # Non-persistent buffer: it follows model.to()/.cuda() together with
        # the parameters but is excluded from state_dict, so existing
        # checkpoints load unchanged.
        self.register_buffer("_fdt", fdt, persistent=False)

    def _to_digits(self, x):
        """Expand every scalar into ``num_digits`` scaled-remainder features.

        For each scale ``s`` in ``self._fdt``, a scalar ``v`` becomes
        ``sign(v * s) * (|v * s| mod num_dbase)``. The outputs are generally
        not integers: e.g. with base 10 and ``s = 0.1``, 12.0 maps to 1.2.

        Args:
            x: tensor of shape (B, L, D).

        Returns:
            Tensor of shape (B, L, D * num_digits).
        """
        B, L, D = x.shape
        x = x.reshape(-1, 1) @ self._fdt  # (B * L * D, num_digits)
        x = x.sign() * (x.abs() % self.num_dbase)
        x = x.reshape(B, L, D * self.num_digits)  # (B, L, D * d)
        return x

    def _digit_norm(self, x):
        """Divide digit features by ``num_dbase``; shape is unchanged."""
        x = x / self.num_dbase
        return x

    def encode_num(self, x):
        """Embed raw numbers.

        Args:
            x: tensor of shape (B, L, num_in_size).

        Returns:
            Tensor of shape (B, L, sym_dmodel).
        """
        x = self._to_digits(x)
        x = self._digit_norm(x)
        z = self.num_encoder(x)  # (B, L, D)
        return z

    def run_tests(self, z_num):
        """Score all ordered pairs of batch items.

        Pair ``(i, j)`` is the sequence ``[first embedding of item i,
        whole sequence of item j]`` of length ``1 + L``.

        Args:
            z_num: tensor of shape (B, L, D).

        Returns:
            Tensor of shape (B, B) with raw ``relation`` scores.
        """
        B, L, D = z_num.shape
        heads = z_num[:, :1].unsqueeze(1).expand(B, B, 1, D)  # row i -> [i, j]
        seqs = z_num.unsqueeze(0).expand(B, B, L, D)  # col j -> [i, j]
        z_cat = torch.concat([heads, seqs], dim=2)  # (B, B, 1 + L, D)
        z_cat = z_cat.reshape(B * B, 1 + L, D)
        tests = self.relation(z_cat)
        return tests.reshape(B, B)

    def tgt_run_tests(self, z_num, z_tgt):
        """Score every batch item against one target embedding.

        The target is prepended to each item's sequence.

        Args:
            z_num: tensor of shape (B, L, D).
            z_tgt: tensor of shape (1, 1, D).

        Returns:
            Tensor of shape (B,) with raw ``relation`` scores.
        """
        z_tgt = z_tgt.repeat(z_num.shape[0], 1, 1)  # (B, 1, D)
        z_cat = torch.concat(
            [z_tgt, z_num],
            dim=1,
        )  # (B, 1 + L, D)
        B = z_cat.shape[0]
        tests = self.relation(z_cat).reshape(B)  # (B,)
        return tests

    def forward(self, x_num):
        """Return sigmoid pairwise scores of shape (B, B) for ``x_num`` (B, L, D)."""
        z_num = self.encode_num(x_num)  # (B, L, dmodel)
        tests = self.run_tests(z_num)  # (B, B)
        tests = nn.functional.sigmoid(tests)  # (B, B)
        return tests

    def tgt_forward(self, x_num, x_tgt):
        """Score each batch item against one target vector.

        Args:
            x_num: tensor of shape (B, L, D).
            x_tgt: target holding ``D`` elements, e.g. shape (D,) or (1, D).
                It is flattened to (1, 1, D).

        Returns:
            Raw (B,) scores. Unlike ``forward``, no sigmoid is applied.
        """
        z_num = self.encode_num(x_num)  # (B, L, dmodel)
        # reshape(1, 1, -1) accepts (D,) as before, and also (1, D) or
        # (1, 1, D), for which shape[0] is not D.
        z_tgt = self.encode_num(x_tgt.reshape(1, 1, -1))  # (1, 1, dmodel)
        tests = self.tgt_run_tests(z_num, z_tgt)  # (B,)
        return tests

    def predict(self, x_num, th=0.5):
        """Threshold ``forward`` probabilities into hard labels.

        Args:
            x_num: tensor of shape (B, L, D).
            th: probabilities ``>= th`` become 1, all others 0.

        Returns:
            Long tensor of shape (B, B).
        """
        test = self.forward(x_num)  # (B, B)
        test = (test >= th).long()  # (B, B)
        return test
