import math
from typing import Optional
import logging

import numpy as np
from sklearn.metrics import precision_recall_fscore_support, balanced_accuracy_score, accuracy_score

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical, RelaxedBernoulli, Bernoulli
from torch.distributions.relaxed_categorical import ExpRelaxedCategorical

from utils import ExponentialScheduler
from utils import (
    deduplicate_list,
    sorting_permutation,
    is_ascending,
    no_duplicate,
    inverse_permutation,
    apply_permutation,
    format_unground_atom,
    format_unground_rule
)


logger = logging.getLogger(__name__)


def _poisson_binomial_log_cdf(
    logits: torch.Tensor,  # batch_size x N
    K: int
):
    """Build truncated Poisson-binomial log-PMF / log-CDF tables.

    Each row of ``logits`` parameterises ``N`` independent Bernoulli
    variables. Entry ``[b, i, k]`` of the PMF table is the log-probability
    that exactly ``k`` of the first ``i`` variables of row ``b`` are on.
    Counts larger than ``K`` are not tracked.

    Args:
        logits: Bernoulli logits of shape ``batch_size x N``.
        K: Largest success count kept in the tables.

    Returns:
        A tuple ``(log_total, log_pmf, log_cdf)``:

        * ``log_total`` (``batch_size``): log-probability that at most ``K``
          of all ``N`` variables are on.
        * ``log_pmf`` (``batch_size x (N + 1) x (K + 1)``): the PMF table.
        * ``log_cdf`` (same shape): cumulative version of ``log_pmf`` along
          the count axis, computed from a detached copy, so it carries no
          gradient.
    """
    batch_size = logits.shape[0]
    num_vars = logits.shape[1]

    # log(1 - p) and log(p) for every variable, computed once up front.
    log_off = - F.softplus(logits)
    log_on = - F.softplus(- logits)

    table = torch.full(
        (batch_size, num_vars + 1, K + 1),
        float("-inf"),
        device=logits.device
    )
    # Zero variables yield zero successes with probability one.
    table[:, 0, 0] = 0
    for n in range(1, num_vars + 1):
        off_n = log_off[:, n-1]
        on_n = log_on[:, n-1]
        upper = min(n, K + 1)

        # k = 0: every variable seen so far is off.
        table[:, n, 0] = table[:, n-1, 0] + off_n
        # 0 < k < upper: variable n is either off (count unchanged) or on (count + 1).
        if upper > 1:
            table[:, n, 1: upper] = torch.logaddexp(
                table[:, n-1, 1: upper] + off_n.unsqueeze(-1),
                table[:, n-1, : upper-1] + on_n.unsqueeze(-1)
            )
        # k = n: all variables on. Written separately so the logaddexp above
        # never receives the -inf entry table[:, n-1, n].
        if n <= K:
            table[:, n, n] = table[:, n-1, n-1] + on_n

    frozen = table.detach()
    cumulative = [
        torch.logsumexp(frozen[:, :, : k+1], dim=-1)
        for k in range(K + 1)
    ]
    log_cdf = torch.stack(cumulative, dim=2)

    log_total = torch.logsumexp(table[:, num_vars], dim=-1)
    return log_total, table, log_cdf


def _sample_body_atoms(
    body_atom_logits: torch.Tensor,  # body_atom_logits: num_rules x max_occurrence_in_body x total_predicates
    max_body_atoms: int,
    num_samples: int
):
    """Sample rule bodies that contain at most ``max_body_atoms`` atoms each.

    The logits are flattened to ``num_rules x N`` independent Bernoulli
    variables and samples are drawn from their joint distribution conditioned
    on at most ``max_body_atoms`` of them being on. Variables are visited from
    the last to the first; variable ``i`` is switched on with probability
    ``p_i * P(at most j - 1 of the first i - 1 on) / P(at most j of the first i on)``,
    where ``j`` is the number of atoms the sample may still use.

    Args:
        body_atom_logits: Bernoulli logits, flattened from dim 1 onwards.
        max_body_atoms: Upper bound on the number of selected atoms per rule.
        num_samples: Number of independent samples drawn per rule.

    Returns:
        A tuple ``(samples, log_cdf)``. ``samples`` is a 0/1 float tensor of
        shape ``num_samples x num_rules x N`` whose last axis follows the
        order of the flattened logits. ``log_cdf`` is the detached table of
        shape ``num_rules x (N + 1) x (max_body_atoms + 1)`` returned by
        ``_poisson_binomial_log_cdf``.
    """
    body_atom_logits = body_atom_logits.flatten(start_dim=1)
    # Only the detached CDF table is needed for sampling.
    # NOTE(review): the first return value of _poisson_binomial_log_cdf (the
    # differentiable log-probability of "at most max_body_atoms on") used to
    # be bound to `log_cdf` and immediately overwritten. Confirm callers of
    # this function really want the detached table as second return value.
    _, _, log_cdf = _poisson_binomial_log_cdf(body_atom_logits, max_body_atoms)

    R, N = body_atom_logits.shape
    device = body_atom_logits.device
    # Loop invariants, built once on the right device.
    rule_index = torch.arange(R, device=device).unsqueeze(0)  # 1 x R
    log_p = F.logsigmoid(body_atom_logits.detach())  # R x N
    n_already_sampled = torch.zeros((num_samples, R), dtype=torch.long, device=device)

    all_samples = []
    for i in range(N, 0, -1):
        # Remaining budget of every (sample, rule) pair.
        j = max_body_atoms - n_already_sampled
        # When j == 0 the index j - 1 == -1 wraps to the last column; that
        # value is discarded by the mask below.
        temp = log_p[:, i-1].unsqueeze(0) + log_cdf[rule_index, i-1, j-1] - log_cdf[rule_index, i, j]
        temp = temp.exp().clamp(0, 1)
        # No budget left: the variable must stay off.
        sample_prob = temp.masked_fill(j < 1, 0)
        samples = Bernoulli(probs=sample_prob).sample((1, ))
        all_samples.append(samples)
        n_already_sampled += samples.squeeze(0).to(dtype=torch.long)

    # Samples were collected for columns N-1, ..., 0. Restore the column
    # order of the logits; otherwise sample k would be stored at position
    # N - 1 - k and be scored against the wrong logit by the caller.
    all_samples.reverse()
    all_samples = torch.cat(all_samples, dim=0).permute(1, 2, 0).contiguous()
    assert torch.all(all_samples.sum(dim=-1) <= max_body_atoms)
    return all_samples, log_cdf


def _extract_body_atom(
    body_atom_logits: torch.Tensor,  # num_rules x max_occurrence_in_body x total_predicates
    max_body_atoms: int
):
    """Deterministically pick the body atoms of every rule.

    For each rule, the ``max_body_atoms`` largest logits (over all occurrence
    slots and predicates) are candidates; a candidate is kept unless its logit
    is ``<= 0``.

    Args:
        body_atom_logits: Logits of shape
            ``num_rules x max_occurrence_in_body x total_predicates``.
        max_body_atoms: Maximum number of atoms selected per rule.

    Returns:
        Boolean tensor with the same shape as ``body_atom_logits``; ``True``
        marks a selected atom.
    """
    num_rules, max_occurrence_in_body, total_predicates = body_atom_logits.shape
    flat_logits = body_atom_logits.flatten(start_dim=1)

    # Column indices of the top-ranked logits of every rule.
    ranked = torch.sort(flat_logits, dim=-1, descending=True).indices
    top_slots = ranked[:, : max_body_atoms]

    chosen = torch.zeros_like(flat_logits, dtype=torch.bool)
    chosen.scatter_(1, top_slots, True)
    # Drop candidates whose logit is not strictly positive.
    chosen = chosen.masked_fill(flat_logits <= 0, False)

    return chosen.view(num_rules, max_occurrence_in_body, total_predicates)
    

def _remove_irrelevant_vars_in_head(
    head_vars_log_prob,
    head_atom,
    head_arities_tensor
):
    """Keep only the variable log-probabilities of the sampled head predicate.

    The last axis of ``head_vars_log_prob`` is the concatenation of the
    argument slots of every head predicate. Slots are summed per predicate,
    and every predicate other than the one selected in ``head_atom`` is set to
    zero.

    Args:
        head_vars_log_prob: 4-D tensor whose last axis has
            ``head_arities_tensor.sum()`` entries; dims 0 and 2 must match the
            two dims of ``head_atom``.
        head_atom: 2-D integer tensor with the index of the chosen head
            predicate.
        head_arities_tensor: 1-D integer tensor with the arity of every head
            predicate. It may live on a different device than the
            log-probabilities.

    Returns:
        Tensor shaped like ``head_vars_log_prob`` except that the last axis
        has one entry per head predicate.
    """
    device = head_vars_log_prob.device
    # The arities tensor may have been created on another device (for example
    # as a plain CPU attribute of a module that was moved to the GPU);
    # repeat_interleave needs both operands on the same device.
    head_arities_tensor = head_arities_tensor.to(device)
    num_predicates = head_arities_tensor.shape[0]
    # chunk_ids[s] is the predicate that argument slot s belongs to.
    chunk_ids = torch.arange(num_predicates, device=device).repeat_interleave(head_arities_tensor)

    expand_shape = list(head_vars_log_prob.shape[:-1]) + [num_predicates]
    sumed_head_vars_log_prob = torch.zeros(expand_shape, dtype=head_vars_log_prob.dtype, device=device)
    sumed_head_vars_log_prob = sumed_head_vars_log_prob.scatter_add(-1, chunk_ids.expand_as(head_vars_log_prob), head_vars_log_prob)

    num_sample_atom, num_rules = head_atom.shape
    X, Y = torch.meshgrid(torch.arange(num_sample_atom, device=device), torch.arange(num_rules, device=device), indexing="ij")
    # Start by masking everything, then un-mask the selected predicate of
    # every (sample, rule) pair across the whole dim 1.
    mask = torch.ones_like(sumed_head_vars_log_prob, dtype=torch.bool, device=device)
    mask[X, :, Y, head_atom] = False
    sumed_head_vars_log_prob = sumed_head_vars_log_prob.masked_fill(mask, 0)

    return sumed_head_vars_log_prob


def _remove_irrelevant_vars_in_body(
    body_vars_log_prob,  # num_sample_atoms x num_sample_vars x num_rules x max_occurrence_in_body x sum_body_arities
    body_atom,  # num_sample_atoms x num_rules x max_occurrence_in_body x num_predicates
    body_arities_tensor
):
    """Keep only the variable log-probabilities of the selected body atoms.

    The last axis of ``body_vars_log_prob`` is the concatenation of the
    argument slots of every predicate. Slots are summed per predicate, and
    entries whose atom is not selected in ``body_atom`` are set to zero.

    Args:
        body_vars_log_prob: Log-probabilities of the sampled variables.
        body_atom: Boolean selection mask of the body atoms; it is broadcast
            over dim 1 of ``body_vars_log_prob``.
        body_arities_tensor: 1-D integer tensor with the arity of every
            predicate. It may live on a different device than the
            log-probabilities.

    Returns:
        Tensor shaped like ``body_vars_log_prob`` except that the last axis
        has one entry per predicate.
    """
    device = body_vars_log_prob.device
    # The arities tensor may have been created on another device (for example
    # as a plain CPU attribute of a module that was moved to the GPU);
    # repeat_interleave needs both operands on the same device.
    body_arities_tensor = body_arities_tensor.to(device)
    num_predicates = body_arities_tensor.shape[0]
    # chunk_ids[s] is the predicate that argument slot s belongs to.
    chunk_ids = torch.arange(num_predicates, device=device).repeat_interleave(body_arities_tensor)

    expand_shape = list(body_vars_log_prob.shape[:-1]) + [num_predicates]
    sumed_body_vars_log_prob = torch.zeros(expand_shape, dtype=body_vars_log_prob.dtype, device=device)
    sumed_body_vars_log_prob = sumed_body_vars_log_prob.scatter_add(-1, chunk_ids.expand_as(body_vars_log_prob), body_vars_log_prob)

    # Zero out every atom that was not sampled into the body.
    sumed_body_vars_log_prob = sumed_body_vars_log_prob.masked_fill(~body_atom.unsqueeze(1), 0)

    return sumed_body_vars_log_prob


class ProbProgram(nn.Module):
    """Learnable distribution over logic programs with a fixed number of rules.

    Every rule owns embeddings for its variables, its head atom and its body
    atoms. ``forward`` turns them into logits and either samples (training
    mode) or greedily decodes (evaluation mode) one head atom, a bounded set
    of body atoms and one variable per argument slot for every rule.
    """

    def __init__(
        self,
        cwa: bool,
        num_background_predicates: int,
        num_aux_predicates: int,
        num_target_predicates: int,
        background_arities: list[int],
        aux_arities: list[int],
        target_arities: list[int],
        max_occurrence_in_body: int,
        num_rules: int,
        predicate_embed_dim: int,
        variable_embed_dim: int,
        rule_head_atom_embed_dim: int,
        rule_body_atom_embed_dim: int,
        max_variables: int,
        max_body_atoms: int,
        remove_irrelevant_vars: bool
    ) -> None:
        """Create the embeddings and projection matrices of all rules.

        Args:
            cwa: When true, only auxiliary and target predicates may appear
                in rule heads; otherwise every predicate may.
            num_background_predicates: Number of background predicates.
            num_aux_predicates: Number of auxiliary predicates.
            num_target_predicates: Number of target predicates.
            background_arities: Arity of every background predicate.
            aux_arities: Arity of every auxiliary predicate.
            target_arities: Arity of every target predicate.
            max_occurrence_in_body: Number of slots in which one predicate
                may appear in a rule body.
            num_rules: Number of rules in the program.
            predicate_embed_dim: Size of a predicate embedding.
            variable_embed_dim: Size of a variable embedding.
            rule_head_atom_embed_dim: Size of the per-rule head-atom query.
            rule_body_atom_embed_dim: Size of the per-rule body-atom query.
            max_variables: Number of distinct variables available per rule.
            max_body_atoms: Maximum number of atoms in a rule body.
            remove_irrelevant_vars: When true, variable log-probabilities of
                atoms that were not selected are zeroed in training mode.
        """
        super().__init__()

        self.cwa = cwa

        self.num_background_predicates = num_background_predicates
        self.num_aux_predicates = num_aux_predicates
        self.num_target_predicates = num_target_predicates
        self.num_predicates = self.num_background_predicates + self.num_aux_predicates + self.num_target_predicates
        self.num_head_predicates = (self.num_aux_predicates + self.num_target_predicates) if cwa else self.num_predicates
        self.max_arity = max(background_arities + aux_arities + target_arities)
        self.head_arities = (aux_arities + target_arities) if self.cwa else (background_arities + aux_arities + target_arities)
        # Plain attributes (not buffers), so they do not follow the module
        # across devices; forward() moves them where they are needed.
        self.head_arities_tensor = torch.tensor(self.head_arities)
        self.body_arities = background_arities + aux_arities + target_arities
        self.body_arities_tensor = torch.tensor(self.body_arities)
        self.sum_head_arities = sum(self.head_arities)
        self.sum_body_arities = sum(self.body_arities)
        self.atom_vars_embed_dim = variable_embed_dim * self.max_arity
        self.atom_embed_dim = predicate_embed_dim + variable_embed_dim * self.max_arity

        # rules' parameters
        self.rule_vars_embed = nn.Parameter(torch.randn([num_rules, max_variables, variable_embed_dim]))
        self.rule_head_atom_embed = nn.Parameter(torch.randn([num_rules, rule_head_atom_embed_dim]))
        self.rule_body_atom_embed = nn.Parameter(torch.randn([num_rules, rule_body_atom_embed_dim]))

        # background + aux + target parameters
        self.predicate_embed = nn.Parameter(torch.randn([
            self.num_predicates,
            predicate_embed_dim
        ]))
        self.head_vars_embed = nn.Parameter(torch.randn([
            num_rules,
            self.sum_head_arities,
            variable_embed_dim
        ]))
        self.body_vars_embed = nn.Parameter(torch.randn([
            num_rules,
            max_occurrence_in_body,
            self.sum_body_arities,
            variable_embed_dim
        ]))
        self.map_to_head_atom = nn.Parameter(torch.empty(num_rules, self.atom_embed_dim, rule_head_atom_embed_dim, dtype=torch.float32))
        self.map_to_body_atom = nn.Parameter(torch.empty(num_rules, self.atom_embed_dim, rule_body_atom_embed_dim, dtype=torch.float32))
        stdv = 1. / math.sqrt(self.atom_embed_dim)
        nn.init.uniform_(self.map_to_head_atom.data, -stdv, stdv)
        nn.init.uniform_(self.map_to_body_atom.data, -stdv, stdv)

        self.num_rules = num_rules
        self.max_occurrence_in_body = max_occurrence_in_body
        self.max_body_atoms = max_body_atoms

        self.remove_irrelevant_vars = remove_irrelevant_vars

    def forward(
        self,
        num_sample_vars: Optional[int] = None,
        num_sample_atoms: Optional[int] = None
    ):
        """Sample (training) or greedily decode (evaluation) a program.

        Args:
            num_sample_vars: Number of variable assignments sampled per atom
                sample. Required in training mode, ignored otherwise.
            num_sample_atoms: Number of atom selections sampled. Required in
                training mode, ignored otherwise.

        Returns:
            A nested dict keyed by ``"vars"`` / ``"atom"`` and then
            ``"head"`` / ``"body"``. Every leaf dict holds ``"samples"`` as
            nested Python lists. In training mode it additionally holds
            ``"logits"``, ``"log_prob"`` and ``"entropy"`` tensors, and the
            body-atom leaf also holds ``"log_prob_cdf"``.

        Raises:
            ValueError: In training mode when a sample count is missing.
        """
        if self.training and (num_sample_vars is None or num_sample_atoms is None):
            raise ValueError(
                "num_sample_vars and num_sample_atoms are required in training mode"
            )

        # ================================================
        #   Variable Selection
        # ================================================
        # self.rule_vars_embed: num_rules x max_variables x variable_embed_dim
        # self.head_vars_embed: num_rules x sum_head_arities x variable_embed_dim
        # head_vars_logits: num_rules x sum_head_arities x max_variables
        head_vars_logits = self.head_vars_embed @ self.rule_vars_embed.transpose(-1, -2)
        # self.body_vars_embed: num_rules x max_occurrence_in_body x sum_body_arities x variable_embed_dim
        # body_vars_logits: num_rules x max_occurrence_in_body x sum_body_arities x max_variables
        body_vars_logits = self.body_vars_embed @ self.rule_vars_embed.unsqueeze(1).transpose(-1, -2)
        if self.training:
            head_vars_dist = Categorical(logits=head_vars_logits)
            # head_vars: num_sample_atoms x num_sample_vars x num_rules x sum_head_arities
            head_vars = head_vars_dist.sample((num_sample_atoms, num_sample_vars))
            # head_vars_log_prob: num_sample_atoms x num_sample_vars x num_rules x sum_head_arities
            head_vars_log_prob = head_vars_dist.log_prob(head_vars)
            # head_vars_entropy: num_rules x sum_head_arities
            head_vars_entropy = head_vars_dist.entropy()

            body_vars_dist = Categorical(logits=body_vars_logits)
            # body_vars: num_sample_atoms x num_sample_vars x num_rules x max_occurrence_in_body x sum_body_arities
            body_vars = body_vars_dist.sample((num_sample_atoms, num_sample_vars))
            # body_vars_log_prob: num_sample_atoms x num_sample_vars x num_rules x max_occurrence_in_body x sum_body_arities
            body_vars_log_prob = body_vars_dist.log_prob(body_vars)
            # body_vars_entropy: num_rules x max_occurrence_in_body x sum_body_arities
            body_vars_entropy = body_vars_dist.entropy()
        else:
            head_vars = head_vars_logits.argmax(dim=-1)
            body_vars = body_vars_logits.argmax(dim=-1)
        head_vars, body_vars = head_vars.tolist(), body_vars.tolist()

        # ================================================
        #   Atom Unification
        # ================================================
        # An atom embedding is the predicate embedding concatenated with the
        # embeddings of its argument slots, zero-padded up to max_arity slots.
        # self.head_vars_embed: num_rules x sum_head_arities x variable_embed_dim
        head_atom_vars_embed = torch.split(self.head_vars_embed, self.head_arities, dim=-2)
        head_padded_atom_vars_embed = torch.zeros(
            [
                self.num_rules,
                self.num_head_predicates,
                self.atom_vars_embed_dim
            ],
            dtype=self.predicate_embed.dtype,
            device=self.predicate_embed.device
        )
        for i, tensor in enumerate(head_atom_vars_embed):
            temp = tensor.flatten(start_dim=-2)
            head_padded_atom_vars_embed[:, i, : temp.shape[-1]] = temp
        # Under CWA the head predicates are the trailing (aux + target) ones.
        predicate_embed = self.predicate_embed[- self.num_head_predicates: ] if self.cwa else self.predicate_embed
        # head_atom_embed: num_rules x num_head_predicates x atom_embed_dim
        head_atom_embed = torch.cat(
            [predicate_embed.expand(head_padded_atom_vars_embed.shape[:1] + (-1, -1)), head_padded_atom_vars_embed],
            dim=-1
        )
        # head_atom_embed: num_rules x num_head_predicates x rule_head_atom_embed_dim
        head_atom_embed = head_atom_embed @ self.map_to_head_atom

        # self.body_vars_embed: num_rules x max_occurrence_in_body x sum_body_arities x variable_embed_dim
        body_atom_vars_embed = torch.split(self.body_vars_embed, self.body_arities, dim=-2)
        body_padded_atom_vars_embed = torch.zeros(
            [
                self.num_rules,
                self.max_occurrence_in_body,
                self.num_predicates,
                self.atom_vars_embed_dim
            ],
            dtype=self.predicate_embed.dtype,
            device=self.predicate_embed.device
        )
        for i, tensor in enumerate(body_atom_vars_embed):
            temp = tensor.flatten(start_dim=-2)
            body_padded_atom_vars_embed[:, :, i, : temp.shape[-1]] = temp
        # body_atom_embed: num_rules x max_occurrence_in_body x num_predicates x atom_embed_dim
        body_atom_embed = torch.cat(
            [self.predicate_embed.expand(body_padded_atom_vars_embed.shape[:2] + (-1, -1)), body_padded_atom_vars_embed],
            dim=-1
        )
        # body_atom_embed: num_rules x max_occurrence_in_body x num_predicates x rule_body_atom_embed_dim
        body_atom_embed = body_atom_embed @ self.map_to_body_atom.unsqueeze(1)

        # ================================================
        #   Head Atom Selection
        # ================================================
        # self.rule_head_atom_embed.unsqueeze(-1): num_rules x rule_head_atom_embed_dim x 1
        # head_atom_logits: num_rules x num_head_predicates
        head_atom_logits = torch.matmul(
            head_atom_embed,
            self.rule_head_atom_embed.unsqueeze(-1)
        ).squeeze(-1)
        if self.training:
            head_atom_dist = Categorical(logits=head_atom_logits)
            head_atom = head_atom_dist.sample((num_sample_atoms, ))
            head_atom_log_prob = head_atom_dist.log_prob(head_atom)
            head_atom_entropy = head_atom_dist.entropy()

            if self.remove_irrelevant_vars:
                head_vars_log_prob = _remove_irrelevant_vars_in_head(
                    head_vars_log_prob,
                    head_atom,
                    # The arities tensor is not a buffer; move it explicitly.
                    self.head_arities_tensor.to(head_vars_log_prob.device)
                )
        else:
            head_atom = head_atom_logits.argmax(dim=-1)
        head_atom = head_atom.tolist()

        # ================================================
        #   Body Atom Selection
        # ================================================
        # self.rule_body_atom_embed.unsqueeze(1).unsqueeze(-1): num_rules x 1 x rule_body_atom_embed_dim x 1
        # body_atom_logits: num_rules x max_occurrence_in_body x num_predicates
        body_atom_logits = torch.matmul(
            body_atom_embed,
            self.rule_body_atom_embed.unsqueeze(1).unsqueeze(-1)
        ).squeeze(-1)
        if self.training:
            body_atom_dist = Bernoulli(logits=body_atom_logits)
            body_atom, log_prob_cdf = _sample_body_atoms(body_atom_logits, self.max_body_atoms, num_sample_atoms)
            body_atom = body_atom.reshape((num_sample_atoms, ) + body_atom_logits.shape)
            body_atom_log_prob = body_atom_dist.log_prob(body_atom)
            body_atom_entropy = body_atom_dist.entropy()
            body_atom = body_atom.to(dtype=bool)

            if self.remove_irrelevant_vars:
                body_vars_log_prob = _remove_irrelevant_vars_in_body(
                    body_vars_log_prob,
                    body_atom,
                    # The arities tensor is not a buffer; move it explicitly.
                    self.body_arities_tensor.to(body_vars_log_prob.device)
                )
        else:
            body_atom = _extract_body_atom(body_atom_logits, self.max_body_atoms)
        body_atom = body_atom.tolist()

        if self.training:
            return {
                "vars": {
                    "head": {
                        "logits": head_vars_logits,
                        "samples": head_vars,
                        "log_prob": head_vars_log_prob,
                        "entropy": head_vars_entropy,
                    },
                    "body": {
                        "logits": body_vars_logits,
                        "samples": body_vars,
                        "log_prob": body_vars_log_prob,
                        "entropy": body_vars_entropy,
                    }
                },
                "atom": {
                    "head": {
                        "logits": head_atom_logits,
                        "samples": head_atom,
                        "log_prob": head_atom_log_prob,
                        "entropy": head_atom_entropy,
                    },
                    "body": {
                        "logits": body_atom_logits,
                        "samples": body_atom,
                        "log_prob": body_atom_log_prob,
                        "entropy": body_atom_entropy,
                        "log_prob_cdf": log_prob_cdf,
                    }
                }
            }
        else:
            return {
                "vars": {"head": {"samples": head_vars}, "body": {"samples": body_vars}},
                "atom": {"head": {"samples": head_atom}, "body": {"samples": body_atom}}
            }


def conjunction(
    atoms: list[torch.Tensor],
    variables: list[list[int]],  # # max_occurrence x sum_arities
    body_atom: list[list[int]]  # max_occurrence x num_predicates
):
    """Conjoin the selected body atoms of one rule.

    ``atoms[i]`` is the valuation tensor of predicate ``i`` (one dimension per
    argument). ``body_atom[occ][i]`` tells whether predicate ``i`` is used in
    occurrence slot ``occ``; if so, its arguments are bound to the slice of
    ``variables[occ]`` that belongs to predicate ``i``. The selected atoms are
    combined with ``torch.einsum`` over their variable ids.

    Returns:
        ``(body, body_vars)``: the combined tensor and the variable id of each
        of its dimensions. With no selected atom, ``body`` is a 0-dim ``True``
        tensor and ``body_vars`` is empty.
    """
    operands: list[torch.Tensor] = []
    operand_vars: list[list[int]] = []
    num_occurrences = len(body_atom)
    offset = 0  # start of the current predicate's slice inside variables[occ]
    for pred_idx, pred in enumerate(atoms):
        arity = pred.ndim
        for occ in range(num_occurrences):
            if not body_atom[occ][pred_idx]:
                continue

            operands.append(pred)
            operand_vars.append(variables[occ][offset: offset + arity])

        offset += arity
    assert offset == len(variables[0])

    num_operands = len(operands)
    if num_operands == 0:
        # Empty body: trivially true and binds no variable.
        body = torch.tensor(True, dtype=torch.bool, device=atoms[0].device)
        body_vars: list[int] = []
    elif num_operands == 1:
        # A single atom: only collapse repeated variables.
        body_vars = deduplicate_list(operand_vars[0])
        body = torch.einsum(operands[0], operand_vars[0], body_vars)
    else:
        # Fold the atoms in one at a time, keeping every variable seen so far.
        body, body_vars = operands[0], operand_vars[0]
        for k in range(1, num_operands):
            merged_vars = deduplicate_list(body_vars + operand_vars[k])
            body = torch.einsum(body, body_vars, operands[k], operand_vars[k], merged_vars)
            body_vars = merged_vars

    return body, body_vars


def implication(
    body: torch.Tensor,
    body_vars: list[int],
    head_vars: list[int]
):
    """Project a rule-body valuation onto the argument slots of the head.

    Args:
        body: Body valuation with one dimension per entry of ``body_vars``
            (presumably boolean, since it is written into a ``torch.bool``
            tensor below — TODO confirm).
        body_vars: Variable id carried by each dimension of ``body``. The
            caller's list is not modified.
        head_vars: Variable id of each argument slot of the head atom.

    Returns:
        A tensor with one dimension per entry of ``head_vars``. Slots whose
        variable does not occur in the body get size 1, so the result
        broadcasts against the head predicate.
    """
    # Work on a copy so the caller's list is left untouched.
    body_vars = body_vars.copy()

    # Dimensions of variables that do not occur in the head are reduced away
    # with a max over the dimension.
    dim_max: list[int] = []
    for dim, vars in enumerate(body_vars):
        if vars not in head_vars:
            dim_max.append(dim)
    if dim_max:
        body = body.amax(dim=dim_max)
        # Pop from the back so the remaining indices stay valid.
        for dim in dim_max[::-1]:
            body_vars.pop(dim)

    # Head variables that are bound by the body, in head order and with
    # repetitions (e.g. p(X, X) lists X twice).
    head_vars_noextra: list[int] = []
    for v in head_vars:
        if v in body_vars:
            head_vars_noextra.append(v)
    # presumably deduplicate_list keeps first occurrences in order; the
    # assert after the permutation below relies on it — verify in utils.
    head_vars_noextra_dedup = deduplicate_list(head_vars_noextra)
        
    if head_vars_noextra_dedup:
        assert body.ndim > 0
        num_constants = body.shape[-1]

        # Reorder the body dimensions to follow the head's variable order.
        permutation = [body_vars.index(x) for x in head_vars_noextra_dedup]
        body = body.permute(permutation)
        body_vars = apply_permutation(body_vars, permutation)
        assert body_vars == head_vars_noextra_dedup

        # A variable repeated in the head: scatter the body into a larger
        # all-False tensor so that only positions where the repeated slots
        # hold the same constant can be set.
        if len(body_vars) < len(head_vars_noextra):
            temp = torch.zeros([num_constants] * len(head_vars_noextra), dtype=torch.bool, device=body.device)
            temp_indexes = list(torch.meshgrid(list(torch.arange(s) for s in body.shape), indexing='ij'))
            assert len(temp_indexes) == len(body_vars)
            indexes_map = {v: ind for ind, v in zip(temp_indexes, body_vars)}
            indexes = []
            for v in head_vars_noextra:
                indexes.append(indexes_map[v])
            body = temp.index_put(indexes, body)
    else:
        # No head variable is bound by the body: the body must be a scalar.
        assert body.ndim == 0
    
    # Head slots with a variable the body does not bind become size-1 dims.
    for dim, v in enumerate(head_vars):
        if v not in body_vars:
            body = body.unsqueeze(dim)
    
    return body


def _one_step_symbolic_forward_chaining(
    step_input: list[torch.Tensor],
    head_vars: list[list[int]],  # num_rules x sum_head_arities
    body_vars: list[list[list[int]]],  # num_rules x max_occurrence_in_body x sum_body_arities
    head_atom: list[int],  # num_rules
    body_atom: list[list[list[int]]],  # num_rules x max_occurrence_in_body x num_predicates
    num_head_predicates: int
):
    """Apply every rule once and return the facts it derives.

    Head predicates are the last ``num_head_predicates`` entries of
    ``step_input``; ``head_atom[r]`` indexes into them. The result starts from
    all-zero tensors shaped like ``step_input`` and only contains what the
    rules derive in this step; facts of ``step_input`` are not carried over.
    """
    derived = [torch.zeros_like(p) for p in step_input]
    num_rules = len(body_atom)
    for rule in range(num_rules):
        # ================================================
        #   Compute body
        # ================================================
        body, bound_vars = conjunction(step_input, body_vars[rule], body_atom[rule])

        # ================================================
        #   Compute head
        # ================================================
        # Walk the head predicates until the selected one is reached, keeping
        # track of where its argument slots start in head_vars[rule].
        target = head_atom[rule]
        offset = 0  # in range(sum_arities)
        for neg_idx in range(- num_head_predicates, 0):
            arity = derived[neg_idx].ndim
            if neg_idx + num_head_predicates < target:
                offset += arity
                continue

            slot_vars = head_vars[rule][offset: offset + arity]
            head = implication(body, bound_vars, slot_vars)

            derived[neg_idx] = torch.maximum(derived[neg_idx], head)

            break

    return derived


@torch.jit.script
def _forward_chaining(
    inference_steps: int,
    all_predicates: list[torch.Tensor],
    num_targets: int,
    targets_label: torch.Tensor,
    num_head_predicates: int,
    head_vars: list[list[list[list[int]]]],  # num_sample_atoms x num_sample_vars x num_rules x sum_head_arities
    body_vars: list[list[list[list[list[int]]]]],  # num_sample_atoms x num_sample_vars x num_rules x max_occurrence_in_body x sum_body_arities
    head_atom: list[list[int]],  # num_sample_atoms x num_rules
    body_atom: list[list[list[list[int]]]],  # num_sample_atoms x num_rules x max_occurrence_in_body x num_predicates
    is_training: bool,
    sampling_bacc_non_parallel_compute: bool
):
    """Run forward chaining for every sampled program and score it.

    For each (atom sample, variable sample) pair the rules are applied for at
    most ``inference_steps`` steps, stopping early once a step adds nothing.
    The last ``num_targets`` predicates are flattened and compared with
    ``targets_label`` to obtain the balanced accuracy.

    The work is issued on a freshly created CUDA stream. ``is_training`` is
    accepted but not read in this function.

    Returns:
        ``(predictions, balanced_acc, all_steps)``. ``all_steps`` has shape
        ``num_sample_atoms x num_sample_vars``. When
        ``sampling_bacc_non_parallel_compute`` is true, ``balanced_acc`` is
        filled per sample and ``predictions`` is an uninitialised placeholder
        of one element; otherwise ``predictions`` holds the flattened target
        predictions of every sample and ``balanced_acc`` is computed from
        them in one batched pass.
    """
    device = all_predicates[0].device
    # Run on a side stream that first waits for work already queued on the
    # current stream.
    cuda_stream = torch.cuda.Stream()
    cuda_stream.wait_stream(torch.cuda.current_stream(device=device))
    with torch.cuda.stream(cuda_stream):
        num_sample_vars = len(head_vars[0])
        num_sample_atoms = len(head_atom)

        # Only one of the two outputs is really filled; the other one is a
        # one-element placeholder so the return signature stays fixed.
        if sampling_bacc_non_parallel_compute:
            balanced_acc = torch.empty((num_sample_atoms, num_sample_vars), dtype=torch.float, device=device)
            predictions = torch.empty((1, ), device=device)
        else:
            balanced_acc = torch.empty((1, ), device=device)
            predictions = torch.empty((num_sample_atoms, num_sample_vars, sum(p.numel() for p in all_predicates[- num_targets: ])), device=device)
        all_steps = torch.zeros((num_sample_atoms, num_sample_vars), dtype=torch.int, device=device)
        for m in range(num_sample_atoms):
            for n in range(num_sample_vars):

                step_input = all_predicates
                step_output = step_input

                step = 0
                for step in range(1, inference_steps + 1):
                    step_output = _one_step_symbolic_forward_chaining(
                        step_input,
                        head_vars[m][n], body_vars[m][n],
                        head_atom[m], body_atom[m],
                        num_head_predicates
                    )
                    
                    # Keep previously known facts and detect the fixpoint.
                    all_equal = True
                    for i in range(len(step_output)):
                        step_output[i] = torch.maximum(step_input[i], step_output[i])

                        all_equal = all_equal & bool((step_output[i] == step_input[i]).all())

                    if all_equal:
                        break

                    step_input = step_output
                if sampling_bacc_non_parallel_compute:
                    # Balanced accuracy = (true positive rate + true negative rate) / 2.
                    p = torch.cat(list(l.flatten() for l in step_output[- num_targets: ]))
                    TP = torch.sum((p == True) & (targets_label == True), dim=-1)
                    FN = torch.sum((p == False) & (targets_label == True), dim=-1)
                    TN = torch.sum((p == False) & (targets_label == False), dim=-1)
                    FP = torch.sum((p == True) & (targets_label == False), dim=-1)
                    balanced_acc[m][n] = (TP / (TP + FN) + TN / (TN + FP)) / 2
                else:
                    predictions[m][n] = torch.cat(list(l.flatten() for l in step_output[- num_targets: ]))
                # Number of the last executed step (0 when inference_steps is 0).
                all_steps[m][n] = step
        
        if not sampling_bacc_non_parallel_compute:
            # Same balanced-accuracy formula, computed for all samples at once.
            TP = torch.sum((predictions == True) & (targets_label == True), dim=-1)
            FN = torch.sum((predictions == False) & (targets_label == True), dim=-1)
            TN = torch.sum((predictions == False) & (targets_label == False), dim=-1)
            FP = torch.sum((predictions == True) & (targets_label == False), dim=-1)
            balanced_acc = (TP / (TP + FN) + TN / (TN + FP)) / 2
    # Make the current stream wait for the side stream before returning.
    torch.cuda.current_stream(device=device).wait_stream(cuda_stream)

    return predictions, balanced_acc, all_steps


@torch.jit.script
def forward_chaining(
    num_concurrency: int,
    inference_steps: int,
    all_predicates: list[torch.Tensor],
    num_targets: int,
    targets_label: torch.Tensor,
    num_head_predicates: int,
    head_vars: list[list[list[list[int]]]],  # num_sample_atoms x num_sample_vars x num_rules x sum_head_arities
    body_vars: list[list[list[list[list[int]]]]],  # num_sample_atoms x num_sample_vars x num_rules x max_occurrence_in_body x sum_body_arities
    head_atom: list[list[int]],  # num_sample_atoms x num_rules
    body_atom: list[list[list[list[int]]]],  # num_sample_atoms x num_rules x max_occurrence_in_body x num_predicates
    is_training: bool,
    sampling_bacc_non_parallel_compute: bool
):
    """Run ``_forward_chaining`` on slices of the atom samples in parallel.

    The atom samples (outermost axis of ``head_vars``, ``body_vars``,
    ``head_atom`` and ``body_atom``) are split into chunks of
    ``ceil(num_sample_atoms / num_concurrency)`` samples. Each chunk is
    submitted with ``torch.jit.fork`` and the results are concatenated along
    dim 0 in submission order.

    Returns:
        ``(predictions, balanced_acc, fc_steps)``, the per-chunk outputs of
        ``_forward_chaining`` concatenated along dim 0.
    """
    num_sample_atoms = len(head_vars)
    # Chunk size so that at most num_concurrency tasks are forked.
    batch_size = math.ceil(num_sample_atoms / num_concurrency)
    futures = [
        torch.jit.fork(
            _forward_chaining,
            inference_steps=inference_steps,
            all_predicates=all_predicates,
            num_targets=num_targets,
            targets_label=targets_label,
            num_head_predicates=num_head_predicates,
            head_vars=head_vars[batch_i: batch_i + batch_size],
            body_vars=body_vars[batch_i: batch_i + batch_size],
            head_atom=head_atom[batch_i: batch_i + batch_size],
            body_atom=body_atom[batch_i: batch_i + batch_size],
            is_training=is_training,
            sampling_bacc_non_parallel_compute=sampling_bacc_non_parallel_compute
        )
        for batch_i in range(0, num_sample_atoms, batch_size)
    ]
    predictions: list[torch.Tensor] = []
    balanced_acc: list[torch.Tensor] = []
    fc_s: list[torch.Tensor] = []
    # Collect the results in the order the chunks were submitted.
    for fut in futures:
        preds, b_acc, steps = torch.jit.wait(fut)

        predictions.append(preds)
        balanced_acc.append(b_acc)
        fc_s.append(steps)        
    balanced_acc = torch.cat(balanced_acc, dim=0)
    predictions = torch.cat(predictions, dim=0)
    fc_steps = torch.cat(fc_s, dim=0)

    return predictions, balanced_acc, fc_steps


class SymbProgram(nn.Module):
    """Symbolic evaluator for the programs decoded by the neural sampler.

    ``forward`` runs forward chaining over a list of predicate valuations and
    scores the result against target labels; ``extract_rules`` renders one
    decoded program as rule strings.
    """

    def __init__(
        self,
        cwa: bool,
        max_train_inference_steps: int,
        max_eval_inference_steps: int,
        background_arities: list[int],
        aux_arities: list[int],
        target_arities: list[int],
        num_concurrency: int,
        sampling_bacc_non_parallel_compute: bool
    ) -> None:
        """Store the evaluation settings and derive the arity tables.

        With ``cwa`` set, only auxiliary and target predicates are head
        predicates; otherwise every predicate is.
        """
        super().__init__()

        self.cwa = cwa
        self.max_train_inference_steps = max_train_inference_steps
        self.max_eval_inference_steps = max_eval_inference_steps

        every_arity = background_arities + aux_arities + target_arities
        if self.cwa:
            self.head_arities = aux_arities + target_arities
        else:
            self.head_arities = background_arities + aux_arities + target_arities
        self.body_arities = every_arity
        self.sum_head_arities = sum(self.head_arities)
        self.sum_body_arities = sum(self.body_arities)
        self.num_head_predicates = len(self.head_arities)

        self.num_concurrency = num_concurrency
        self.sampling_bacc_non_parallel_compute = sampling_bacc_non_parallel_compute

    @torch.no_grad()
    def forward(
        self,
        all_predicates: list[torch.Tensor],
        targets_label: list[torch.Tensor],
        head_vars: list[list[list[list[int]]]],  # num_sample_atoms x num_sample_vars x num_rules x sum_head_arities
        body_vars: list[list[list[list[list[int]]]]],  # num_sample_atoms x num_sample_vars x num_rules x max_occurrence_in_body x sum_body_arities
        head_atom: list[list[int]],  # num_sample_atoms x num_rules
        body_atom: list[list[list[list[int]]]],  # num_sample_atoms x num_rules x max_occurrence_in_body x num_predicates
    ):
        """Evaluate the sampled programs by forward chaining.

        Returns:
            ``(balanced_acc, max_fc_steps)`` in training mode and
            ``(predictions, max_fc_steps)`` otherwise, where ``max_fc_steps``
            is the largest step count reported for any sample.
        """
        num_targets = len(targets_label)
        flat_labels = torch.cat(tuple(label.flatten() for label in targets_label))

        if self.training:
            step_budget = self.max_train_inference_steps
        else:
            step_budget = self.max_eval_inference_steps

        predictions, balanced_acc, fc_steps = forward_chaining(
            self.num_concurrency,
            step_budget,
            all_predicates,
            num_targets,
            flat_labels,
            self.num_head_predicates,
            head_vars, body_vars,
            head_atom, body_atom,
            self.training,
            self.sampling_bacc_non_parallel_compute & self.training
        )

        max_fc_steps = int(fc_steps.amax())

        result = balanced_acc if self.training else predictions
        return result, max_fc_steps

    @torch.no_grad()
    def extract_rules(
        self,
        predicate_names: list[str],
        head_vars: list[list[int]],  # num_rules x sum_head_arities
        body_vars: list[list[list[int]]],  # num_rules x max_occurrence_in_body x sum_body_arities
        head_atom: list[int],  # num_rules
        body_atom: list[list[list[int]]],  # num_rules x max_occurrence_in_body x num_predicates
    ):
        """Render one decoded program as rule strings.

        Returns:
            ``(raw_rules, postprocessed_rules)``. Raw rules are formatted
            exactly as decoded. Post-processed rules have duplicate body atoms
            removed, variables renumbered by first appearance (head first),
            rules whose head also occurs in the body dropped, and the final
            list passed through ``deduplicate_list``.
        """
        num_head_predicates = len(self.head_arities)

        raw_rules, postprocessed_rules = [], []

        for rule in range(len(head_atom)):
            # Head: find the selected head predicate and its argument slots.
            heads = []
            offset = 0
            for j, arity in enumerate(self.head_arities):
                if j < head_atom[rule]:
                    offset += arity
                    continue

                # Head predicates are the trailing entries of predicate_names.
                name = predicate_names[j - num_head_predicates]
                args = tuple(head_vars[rule][offset: offset + arity])
                heads.append((name, args))

                break

            # Body: every (predicate, occurrence) pair that is switched on.
            bodys = []
            offset = 0
            for i, arity in enumerate(self.body_arities):
                name = predicate_names[i]
                for occur in range(len(body_atom[0])):
                    if not body_atom[rule][occur][i]:
                        continue
                    args = tuple(body_vars[rule][occur][offset: offset + arity])
                    bodys.append((name, args))
                offset += arity
            assert offset == self.sum_body_arities

            formatted_heads = [format_unground_atom(*h) for h in heads]
            formatted_bodys = [format_unground_atom(*b) for b in bodys]
            raw_rules.extend(format_unground_rule(formatted_heads, formatted_bodys))

            bodys_dedup = deduplicate_list(bodys)
            for head in heads:
                # Skip rules whose head already appears in the body.
                if any(b == head for b in bodys_dedup):
                    continue

                # Renumber variables in order of first appearance.
                v_map = {}
                for _, atom_args in [head] + bodys_dedup:
                    for v in atom_args:
                        v_map.setdefault(v, len(v_map))

                new_head = (head[0], tuple(v_map.get(v) for v in head[1]))
                new_bodys = [
                    (b_name, tuple(v_map.get(v) for v in b_args))
                    for b_name, b_args in bodys_dedup
                ]

                postprocessed_rule, = format_unground_rule(
                    [format_unground_atom(*new_head)], [format_unground_atom(*b) for b in new_bodys]
                )
                postprocessed_rules.append(postprocessed_rule)
        postprocessed_rules = deduplicate_list(postprocessed_rules)

        return raw_rules, postprocessed_rules
