import torch as T
import torch.nn as nn


def default_role_template_of(node):
    """Map a node name to its role-template key, or None if it has no role.

    A node name is split on ``"_"``.  Names whose first token is ``"R"`` and
    that have at least four tokens map to ``"R_<third>_<fourth>"``; the second
    token (and anything after the fourth) is ignored.  Every other name yields
    None.
    """
    tokens = node.split("_")
    # Guard clause: only "R_..." names with enough tokens carry a role.
    if len(tokens) < 4 or tokens[0] != "R":
        return None
    return f"R_{tokens[2]}_{tokens[3]}"


class RoleAggregatorModule(nn.Module):
    """Deterministically aggregate a node's parent values.

    The values of ``parents`` (looked up in the ``pa`` mapping passed to
    :meth:`forward`) are concatenated along dim 1 and reduced by the
    aggregator named ``agg_name``.  Supported names (case-insensitive, via
    ``str(agg_name).lower()``): ``or``/``max``, ``and``/``min``, ``sum``,
    ``count``, ``mean``, ``strict_maj`` and ``weak_maj``.

    Args:
        parents: Iterable of keys into ``pa``; their order fixes column order.
        agg_name: Aggregator name.  Unknown names raise ``ValueError`` when
            the module is evaluated with at least one parent column.
    """

    def __init__(self, parents, agg_name):
        super().__init__()
        self.parents = list(parents)
        self.agg_name = agg_name
        # ``count`` results saturate at this value ...
        self._count_max = 5
        # ... and are emitted as this many bits, most-significant bit first.
        self._count_bits = 3

    def forward(self, pa, u, v=None, n=None):
        """Aggregate parent values, or score an observed value against them.

        Args:
            pa: Mapping from node name to tensor.  Parent tensors are expected
                to be ``(batch, k)``; 1-D ``(batch,)`` tensors are treated as a
                single column.
            u: Mapping consulted only to infer the batch size when the node
                has no parent columns and neither ``v`` nor ``pa`` gives one.
            v: Optional observed value for this node.
            n: Fallback batch size for the parentless case.

        Returns:
            If ``v`` is None: the aggregate, shape ``(batch, 1)`` or
            ``(batch, self._count_bits)`` for ``count``.
            Otherwise: a ``(batch,)`` float tensor that is 0 where ``v``
            matches the aggregate within 1e-6 in every column and -1e8
            elsewhere.
        """
        out = self._aggregate(pa, u, v, n)
        if v is None:
            return out
        v_val = v.float().to(out.device)
        if v_val.dim() == 1 and out.shape[1] == 1 and v_val.shape[0] == out.shape[0]:
            # A per-sample 1-D v would otherwise broadcast against the
            # (batch, 1) output into a (batch, batch) comparison and produce
            # wrong per-row match results.
            v_val = v_val.unsqueeze(1)
        diff = T.abs(v_val - out)
        match = T.all(diff < 1e-6, dim=1)
        zeros = T.zeros(out.shape[0], device=out.device)
        neg = T.full((out.shape[0],), -1e8, device=out.device)
        return T.where(match, zeros, neg)

    def _agg_key(self):
        """Return the normalised (lower-cased string) aggregator name."""
        return str(self.agg_name).lower()

    def _aggregate(self, pa, u, v, n):
        """Concatenate parent columns and reduce them; zeros if there are none."""
        vals = None
        if self.parents:
            # All parents are moved to the first parent's device before cat.
            base_device = pa[self.parents[0]].device
            cols = []
            for p in self.parents:
                col = pa[p].float().to(base_device)
                if col.dim() == 1:
                    # Accept (batch,) parent tensors as a single column.
                    col = col.unsqueeze(1)
                cols.append(col)
            vals = T.cat(cols, dim=1)
        if vals is None or vals.shape[1] == 0:
            n_out, device = self._infer_shape(pa, u, v, n)
            width = self._count_bits if self._agg_key() == "count" else 1
            return T.zeros((n_out, width), device=device)
        return self._apply_agg(vals)

    def _apply_agg(self, vals):
        """Reduce a 2-D ``vals`` tensor along dim 1 with the configured aggregator.

        Raises:
            ValueError: if the aggregator name is not recognised.
        """
        key = self._agg_key()
        if key in {"or", "max"}:
            return T.max(vals, dim=1).values.unsqueeze(1)
        if key in {"and", "min"}:
            return T.min(vals, dim=1).values.unsqueeze(1)
        if key == "sum":
            return T.sum(vals, dim=1).unsqueeze(1)
        if key == "count":
            # Saturate at the largest value the bit encoding can represent so
            # that high bits are never silently dropped (wrap-around) when
            # _count_max does not fit in _count_bits.
            limit = min(self._count_max, (1 << max(self._count_bits, 0)) - 1)
            count = T.sum(vals != 0, dim=1).clamp_max(limit)
            return self._count_to_bits(count, self._count_bits).float()
        if key == "mean":
            return T.mean(vals, dim=1).unsqueeze(1)
        if key == "strict_maj":
            return (T.sum(vals, dim=1) > (vals.shape[1] / 2.0)).float().unsqueeze(1)
        if key == "weak_maj":
            return (T.sum(vals, dim=1) >= (vals.shape[1] / 2.0)).float().unsqueeze(1)
        raise ValueError(f"Unknown aggregator '{self.agg_name}'")

    def _count_to_bits(self, count, bits):
        """Encode a 1-D integer ``count`` tensor as ``(len, bits)`` 0/1 longs, MSB first.

        With ``bits == 1`` the single column is ``count > 0``.
        """
        if bits == 1:
            return (count > 0).long().unsqueeze(1)
        mask = 2 ** T.arange(bits - 1, -1, -1, device=count.device)
        return count.unsqueeze(-1).bitwise_and(mask).ne(0).long()

    def _infer_shape(self, pa, u, v, n):
        """Return ``(batch_size, device)`` for the parentless case.

        Preference order: ``v``, the first entry of ``pa``, the first entry of
        ``u``, then ``n`` (default 1) on CPU.
        """
        if v is not None:
            return v.shape[0], v.device
        if pa:
            sample_key = next(iter(pa))
            return pa[sample_key].shape[0], pa[sample_key].device
        if u:
            sample_key = next(iter(u))
            return u[sample_key].shape[0], u[sample_key].device
        if n is None:
            n = 1
        return n, T.device("cpu")


def build_role_modules(cg, role_aggregators, role_template_of=None):
    """Build one RoleAggregatorModule per node of ``cg`` that has an aggregator.

    Args:
        cg: Iterable of node names that also exposes ``cg.pa[node]``, the
            parent list handed to each module.
        role_aggregators: Mapping from role-template key to aggregator name.
        role_template_of: Callable mapping a node name to its role key (or
            None).  Falls back to ``default_role_template_of`` when falsy.

    Returns:
        Dict from node name to its RoleAggregatorModule.  Nodes with no role
        key, or whose role key has no (or a None) aggregator, are omitted.
    """
    template_fn = role_template_of or default_role_template_of
    result = {}
    for node in cg:
        key = template_fn(node)
        # Only look up an aggregator when the node actually has a role key.
        agg = None if key is None else role_aggregators.get(key)
        if agg is not None:
            result[node] = RoleAggregatorModule(cg.pa[node], agg)
    return result
