import numpy as np
import torch as T
import torch.nn as nn

from src.scm.distribution.continuous_distribution import UniformDistribution
from src.scm.scm import SCM
from src.ds.causal_graph import CausalGraph


class CTM(SCM):
    """SCM whose mechanisms are randomly drawn, region-based binary lookup tables.

    For every variable ``V`` of the causal graph, :meth:`get_ctm_func` draws a
    set of random axis-aligned boxes ("regions") over the exogenous parents of
    ``V`` and one random 0/1 table per region, plus one default table. When the
    mechanism is evaluated, the table is chosen by the region that contains
    the exogenous values and the table row by the binary-encoded values of the
    endogenous parents.

    Args:
        cg: Causal graph. It is iterated for variable names and must expose
            ``c2``, ``v2c2[V]`` (exogenous parents of ``V``) and ``pa[V]``
            (endogenous parents of ``V``).
        v_size: Optional mapping variable -> number of output bits. Variables
            that are not listed use ``default_v_size``.
        default_v_size: Output width for variables missing from ``v_size``.
        regions: Base number of random regions per variable.
        c2_scale: A variable with ``k`` exogenous parents gets
            ``int(regions * c2_scale ** k)`` regions.
        batch_size: If not ``None``, :meth:`sample` draws in chunks of at most
            this many rows.
        seed: Seed for the ``numpy.random.RandomState`` that draws the
            mechanisms.
    """

    def __init__(self, cg, v_size=None, default_v_size=1, regions=4, c2_scale=1.0, batch_size=None, seed=None):
        # ``None`` stands in for the former mutable ``{}`` default.
        v_size = {} if v_size is None else v_size

        self.cg = cg
        self.u_size = {k: 1 for k in self.cg.c2}
        self.v_size = {k: v_size.get(k, default_v_size) for k in self.cg}
        self.region_count = {k: int(regions * (c2_scale ** len(self.cg.v2c2[k]))) for k in self.cg}
        self.batch_size = batch_size

        # RandomState(seed=None) behaves like RandomState(), so no branch is needed.
        self.rand_state = np.random.RandomState(seed=seed)

        super().__init__(
            v=list(cg),
            f={V: self.get_ctm_func(V) for V in cg},
            pu=UniformDistribution(self.cg.c2, self.u_size))

    def get_ctm_func(self, V):
        """Draw a random mechanism for variable ``V``.

        Returns:
            A function ``ctm_func(v_raw, u_raw)`` taking dicts of tensors and
            returning a CPU long tensor of shape ``(n, v_size[V])``.
        """
        # Sorted so the parent order (and hence the row encoding) is deterministic.
        v_pa = sorted(self.cg.pa[V])
        u_pa = sorted(self.cg.v2c2[V])

        # One table row per joint binary assignment of the endogenous parents.
        outcomes = 2 ** sum(self.v_size[k] for k in v_pa)
        table_shape = (outcomes, self.v_size[V])

        regions = []
        tables = []
        for _ in range(self.region_count[V]):
            # One [low, high) interval per exogenous parent; plain floats keep
            # the later tensor comparisons free of numpy-scalar dispatch.
            regions.append([sorted(self.rand_state.rand(2).tolist()) for _ in u_pa])
            tables.append(self.rand_state.binomial(1, 0.5, size=table_shape))
        # Last table is the fallback for rows that fall in no region.
        tables.append(self.rand_state.binomial(1, 0.5, size=table_shape))
        # Stack first: building a tensor from a list of ndarrays is slow.
        region_outputs = T.from_numpy(np.stack(tables)).long()

        def ctm_func(v_raw, u_raw):
            # Only the parents of V are read, so only those are moved to CPU.
            v = {k: v_raw[k].cpu() for k in v_pa}
            u = {k: val.cpu() for k, val in u_raw.items()}

            batch_ref = u if u else v_raw
            if not batch_ref:
                raise ValueError("Cannot infer batch size: both v_raw and u_raw are empty")
            n = len(next(iter(batch_ref.values())))

            # Start every row at the default table; later regions override
            # earlier ones when boxes overlap.
            # NOTE(review): assumes each u tensor broadcasts against (n, 1) — confirm.
            region_found = T.full((n, 1), len(regions), dtype=T.long)
            for i, region in enumerate(regions):
                in_region = T.ones((n, 1), dtype=T.bool)
                for (low, high), u_name in zip(region, u_pa):
                    in_region &= (low <= u[u_name]) & (u[u_name] < high)
                region_found[in_region] = i

            # Squeeze only dim 1: a bare squeeze() would also drop the batch
            # axis when n == 1 and break the indexing below.
            used_func = region_outputs[region_found.squeeze(1)]  # (n, outcomes, out)

            if not v_pa:
                # outcomes == 1 here, so dim 1 is a singleton.
                return used_func.squeeze(1)

            # Read the concatenated parent bits as a big-endian integer row index.
            v_arr = T.cat([v[k] for k in v_pa], dim=1).long()
            v_ind = T.zeros(n, dtype=T.long)
            for col in range(v_arr.shape[1]):
                v_ind = 2 * v_ind + v_arr[:, col]
            return used_func[T.arange(n), v_ind]

        return ctm_func

    def sample(self, n=None, u=None, do=None, select=None):
        """Sample from the model, optionally in chunks of ``self.batch_size`` rows.

        Exactly one of ``n`` and ``u`` must be given. With ``batch_size`` unset
        the call is forwarded unchanged to the parent class. Otherwise ``u``
        and every tensor in ``do`` are sliced along dim 0, each chunk is
        sampled through the parent class, and the results are concatenated.

        Returns:
            Dict mapping each selected variable to its sampled tensor.
        """
        # ``None`` stands in for the former mutable ``{}`` default.
        do = {} if do is None else do

        if self.batch_size is None:
            return super().sample(n=n, u=u, do=do, select=select)

        assert not set(do.keys()).difference(self.v)
        assert (n is None) != (u is None)

        total = len(next(iter(u.values()))) if n is None else n
        if total == 0:
            # No chunk would be produced and T.cat([]) raises, so let the
            # parent class decide how to handle an empty request.
            return super().sample(n=n, u=u, do=do, select=select)

        chunks = {k: [] for k in (self.v if select is None else select)}
        for start in range(0, total, self.batch_size):
            stop = min(start + self.batch_size, total)
            if n is None:
                batch_n = None
                batch_u = {k: u[k][start:stop] for k in u}
            else:
                batch_n = stop - start
                batch_u = None
            batch_do = {k: do[k][start:stop] for k in do}

            batch = super().sample(n=batch_n, u=batch_u, do=batch_do, select=select)
            for name in batch:
                chunks[name].append(batch[name])

        return {k: T.cat(chunks[k], dim=0) for k in chunks}


class RelationalCTM(SCM):
    """
    Relational CTM with template-level shared mechanisms and optional aggregation
    for variable-arity relational parents.

    Each node of the graph gets its function in one of two ways:

    * If ``role_aggregators`` has an entry for ``role_template_of(node)``, the
      node is a deterministic aggregate of its parents (see
      :meth:`_get_aggregator` for the supported names).
    * Otherwise the node reuses the random mechanism of its template
      (``template_of(node)``), with its own parents renamed positionally to
      the parents of the template's representative node.

    Args:
        cg: Causal graph. It is iterated for node names and must expose
            ``c2``, ``v2c2[node]`` and ``pa[node]``.
        template_of: Callable mapping a node to its template identifier.
        reps: Mapping template -> representative node; the representative's
            parents define the shape of the template mechanism.
        v_size: Optional mapping node -> number of output bits.
        default_v_size: Output width for nodes missing from ``v_size``.
        regions: Base number of random regions per template.
        c2_scale: A template whose representative has ``k`` exogenous parents
            gets ``int(regions * c2_scale ** k)`` regions.
        batch_size: Stored on the instance; not used by the code in this class.
        seed: Seed for the ``numpy.random.RandomState`` that draws mechanisms.
        template_funcs: Optional pre-built mapping template -> dict with keys
            ``"func"``, ``"parent_keys"``, ``"u_keys"`` and ``"agg_templates"``
            (the structure returned by :meth:`_build_template_func`). When
            given, no new mechanisms are drawn.
        agg_templates: Mapping template -> parent templates whose members are
            summarised by a count instead of being wired in individually.
        agg_sizes: Mapping parent template -> number of bits used to encode
            that count (default 1).
        agg_mode: Stored on the instance; not used by the code in this class.
        role_aggregators: Mapping role key -> aggregator name.
        role_template_of: Callable mapping a node to its role key (or
            ``None``). Defaults to :meth:`_default_role_template_of`.
    """

    # Output width and saturation value of the "count" role aggregator.
    _COUNT_BITS = 3
    _COUNT_CAP = 5

    def __init__(
        self,
        cg,
        template_of,
        reps,
        v_size=None,
        default_v_size=1,
        regions=4,
        c2_scale=1.0,
        batch_size=None,
        seed=None,
        template_funcs=None,
        agg_templates=None,
        agg_sizes=None,
        agg_mode="count",
        role_aggregators=None,
        role_template_of=None,
    ):
        # ``None`` stands in for the former mutable ``{}`` default.
        v_size = {} if v_size is None else v_size

        self.cg = cg
        self.template_of = template_of
        self.reps = reps
        self.u_size = {k: 1 for k in self.cg.c2}
        self.v_size = {k: v_size.get(k, default_v_size) for k in self.cg}
        self.batch_size = batch_size
        self.agg_templates = agg_templates or {}
        self.agg_sizes = agg_sizes or {}
        self.agg_mode = agg_mode
        self.role_aggregators = role_aggregators or {}
        self.role_template_of = role_template_of or self._default_role_template_of

        # RandomState(seed=None) behaves like RandomState(), so no branch is needed.
        self.rand_state = np.random.RandomState(seed=seed)

        if template_funcs is None:
            self.template_funcs = {}
            for template, rep in self.reps.items():
                self.template_funcs[template] = self._build_template_func(
                    template=template,
                    rep=rep,
                    regions=regions,
                    c2_scale=c2_scale,
                )
        else:
            self.template_funcs = template_funcs

        f = {}
        for node in cg:
            # Role nodes take precedence over template mechanisms.
            agg_name = self.role_aggregators.get(self.role_template_of(node))
            if agg_name is not None:
                f[node] = self._build_role_func(node, agg_name)
                continue
            template = self.template_of(node)
            if template not in self.template_funcs:
                raise ValueError(f"No template function for node {node}")
            f[node] = self._wrap_node_func(node, template)

        super().__init__(
            v=list(cg),
            f=f,
            pu=UniformDistribution(self.cg.c2, self.u_size),
        )

    @staticmethod
    def _first_tensor(*mappings):
        """Return the first value of the first non-empty mapping.

        Used to infer the batch size (and device) of a call.

        Raises:
            ValueError: if every mapping is empty.
        """
        for mapping in mappings:
            if mapping:
                return next(iter(mapping.values()))
        raise ValueError("Cannot infer batch size: no input tensors were provided")

    @staticmethod
    def _count_to_bits(count, bits):
        """Encode a 1-D integer tensor as ``bits`` binary columns (MSB first).

        With ``bits == 1`` the result is the indicator ``count > 0``; otherwise
        only the low ``bits`` bits of ``count`` are kept.
        """
        if bits == 1:
            return (count > 0).long().unsqueeze(1)
        # Build the mask on count's device so GPU inputs do not mix devices.
        mask = 2 ** T.arange(bits - 1, -1, -1, device=count.device)
        return count.unsqueeze(-1).bitwise_and(mask).ne(0).long()

    def _build_template_func(self, template, rep, regions, c2_scale):
        """Draw the shared random mechanism of ``template`` from its representative.

        Returns:
            Dict with the mechanism (``"func"``), the input names it expects
            (``"parent_keys"``, ``"u_keys"``) and the aggregated parent
            templates (``"agg_templates"``).
        """
        agg_parent_templates = list(self.agg_templates.get(template, []))
        # Sorted, as in CTM: cg.pa[...] may be an unordered collection, and the
        # parent order fixes both the row encoding and the positional pairing
        # done in _wrap_node_func.
        direct_parents = [
            p for p in sorted(self.cg.pa[rep])
            if self.template_of(p) not in agg_parent_templates
        ]
        agg_keys = [f"__agg_{t}" for t in agg_parent_templates]
        parent_keys = direct_parents + agg_keys

        input_bits = {k: self.v_size[k] for k in direct_parents}
        for t in agg_parent_templates:
            input_bits[f"__agg_{t}"] = self.agg_sizes.get(t, 1)

        u_pa = sorted(self.cg.v2c2[rep])
        region_count = int(regions * (c2_scale ** len(u_pa)))

        # One table row per joint binary assignment of all inputs.
        outcomes = 2 ** sum(input_bits[k] for k in parent_keys)
        table_shape = (outcomes, self.v_size[rep])

        boxes = []
        tables = []
        for _ in range(region_count):
            # One [low, high) interval per exogenous parent, stored as floats.
            boxes.append([sorted(self.rand_state.rand(2).tolist()) for _ in u_pa])
            tables.append(self.rand_state.binomial(1, 0.5, size=table_shape))
        # Last table is the fallback for rows that fall in no region.
        tables.append(self.rand_state.binomial(1, 0.5, size=table_shape))
        # Stack first: building a tensor from a list of ndarrays is slow.
        region_outputs = T.from_numpy(np.stack(tables)).long()

        def template_func(v_raw, u_raw):
            v = {k: v_raw[k].cpu() for k in parent_keys}
            u = {k: val.cpu() for k, val in u_raw.items()}

            # Prefer u for the batch size; fall back to v when the node has
            # no exogenous parents of its own.
            n = len(self._first_tensor(u, v))

            # Start every row at the default table; later regions override
            # earlier ones when boxes overlap.
            # NOTE(review): assumes each u tensor broadcasts against (n, 1) — confirm.
            region_found = T.full((n, 1), len(boxes), dtype=T.long)
            for i, box in enumerate(boxes):
                in_region = T.ones((n, 1), dtype=T.bool)
                for (low, high), u_name in zip(box, u_pa):
                    in_region &= (low <= u[u_name]) & (u[u_name] < high)
                region_found[in_region] = i

            # Squeeze only dim 1: a bare squeeze() would also drop the batch
            # axis when n == 1 and break the indexing below.
            used_func = region_outputs[region_found.squeeze(1)]  # (n, outcomes, out)

            if not parent_keys:
                # outcomes == 1 here, so dim 1 is a singleton.
                return used_func.squeeze(1)

            # Read the concatenated input bits as a big-endian integer row index.
            v_arr = T.cat([v[k] for k in parent_keys], dim=1).long()
            v_ind = T.zeros(n, dtype=T.long)
            for col in range(v_arr.shape[1]):
                v_ind = 2 * v_ind + v_arr[:, col]
            return used_func[T.arange(n), v_ind]

        return {
            "func": template_func,
            "parent_keys": parent_keys,
            "u_keys": u_pa,
            "agg_templates": agg_parent_templates,
        }

    def _wrap_node_func(self, node, template):
        """Adapt the shared mechanism of ``template`` to the concrete ``node``.

        The node's exogenous and direct parents are paired by position with
        those of the template's representative. Parents belonging to an
        aggregated template are replaced by a binary-encoded count.

        Raises:
            ValueError: if the node and the representative differ in their
                number of exogenous or direct parents.
        """
        tmpl = self.template_funcs[template]
        rep = self.reps[template]

        rep_u = sorted(self.cg.v2c2[rep])
        cur_u = sorted(self.cg.v2c2[node])
        if len(rep_u) != len(cur_u):
            raise ValueError(
                f"Incompatible exogenous parents for template {template}: {node}"
            )
        u_key_map = dict(zip(cur_u, rep_u))

        agg_parent_templates = tmpl["agg_templates"]
        node_parents = sorted(self.cg.pa[node])
        rep_direct_pa = [
            p for p in sorted(self.cg.pa[rep])
            if self.template_of(p) not in agg_parent_templates
        ]
        direct_pa = [
            p for p in node_parents if self.template_of(p) not in agg_parent_templates
        ]
        if len(rep_direct_pa) != len(direct_pa):
            raise ValueError(
                f"Incompatible direct parents for template {template}: {node}"
            )
        pa_key_map = dict(zip(direct_pa, rep_direct_pa))

        agg_bits = {t: self.agg_sizes.get(t, 1) for t in agg_parent_templates}
        # The graph is fixed, so group the aggregated parents once here
        # instead of on every call.
        agg_members = {
            t: [p for p in node_parents if self.template_of(p) == t]
            for t in agg_parent_templates
        }

        def node_func(v_raw, u_raw):
            v_inputs = {rep_key: v_raw[cur_key] for cur_key, rep_key in pa_key_map.items()}

            for agg_template in agg_parent_templates:
                members = agg_members[agg_template]
                if members:
                    count = T.cat([v_raw[p].long() for p in members], dim=1).sum(dim=1)
                else:
                    # No parent of this template: the count is zero for every row.
                    ref = self._first_tensor(v_raw, u_raw)
                    count = T.zeros((ref.shape[0],), dtype=T.long, device=ref.device)
                v_inputs[f"__agg_{agg_template}"] = self._count_to_bits(
                    count, agg_bits[agg_template]
                )

            u_inputs = {rep_key: u_raw[cur_key] for cur_key, rep_key in u_key_map.items()}
            return tmpl["func"](v_inputs, u_inputs)

        return node_func

    def _default_role_template_of(self, node):
        """Return ``"R_<c>_<d>"`` for names of the form ``R_<b>_<c>_<d>...``, else ``None``."""
        parts = node.split("_")
        if len(parts) >= 4 and parts[0] == "R":
            return f"R_{parts[2]}_{parts[3]}"
        return None

    def _build_role_func(self, node, agg_name):
        """Build a deterministic function that aggregates the parents of ``node``.

        With no parents the function returns zeros: ``_COUNT_BITS`` columns
        for the ``"count"`` aggregator, one column otherwise. The result is
        always cast to long.
        """
        parents = sorted(self.cg.pa[node])
        agg = self._get_aggregator(agg_name)
        is_count = agg_name is not None and str(agg_name).lower() == "count"
        empty_width = self._COUNT_BITS if is_count else 1

        def role_func(v_raw, u_raw):
            if parents:
                vals = T.cat([v_raw[p].float() for p in parents], dim=1)
                return agg(vals).long()

            ref = self._first_tensor(v_raw, u_raw)
            return T.zeros((ref.shape[0], empty_width), device=ref.device).long()

        return role_func

    def _get_aggregator(self, name):
        """Return a callable reducing an ``(n, k)`` float tensor along dim 1.

        Supported names (case-insensitive): ``or``/``max``, ``and``/``min``,
        ``sum``, ``count``, ``mean``, ``strict_maj``, ``weak_maj``. ``None``
        is treated as ``"or"``. ``count`` returns ``_COUNT_BITS`` columns;
        every other aggregator returns one column.

        Raises:
            ValueError: for an unknown aggregator name.
        """
        if name is None:
            name = "or"
        key = str(name).lower()

        if key in {"or", "max"}:
            return lambda x: T.max(x, dim=1).values.unsqueeze(1)
        if key in {"and", "min"}:
            return lambda x: T.min(x, dim=1).values.unsqueeze(1)
        if key == "sum":
            return lambda x: T.sum(x, dim=1).unsqueeze(1)
        if key == "count":
            def count_agg(x):
                # Number of non-zero parents, saturated at _COUNT_CAP.
                count = T.sum(x != 0, dim=1).clamp_max(self._COUNT_CAP)
                return self._count_to_bits(count, self._COUNT_BITS).float()

            return count_agg
        if key == "mean":
            return lambda x: T.mean(x, dim=1).unsqueeze(1)
        if key == "strict_maj":
            return lambda x: (T.sum(x, dim=1) > (x.shape[1] / 2.0)).float().unsqueeze(1)
        if key == "weak_maj":
            return lambda x: (T.sum(x, dim=1) >= (x.shape[1] / 2.0)).float().unsqueeze(1)

        raise ValueError(f"Unknown aggregator '{name}'")


if __name__ == "__main__":
    # Earlier manual check of the plain CTM, kept for reference:
    # cg = CausalGraph.read("../../dat/cg/zid_a.cg")
    # m = CTM(cg, v_size={}, regions=20, c2_scale=1.0, batch_size=100000)
    # result = m(20)
    # print(result)
    # for k in result:
    #     print("{}: {}".format(k, result[k].shape))
    # print(m(10, do={'X': T.ones((10, 1), dtype=T.long)}))

    # Demo of RelationalCTM on a bipartite graph: every X node points at
    # every Y node, and nodes sharing a name suffix share a template.
    x_nodes = ["O1_X", "O2_X"]
    y_nodes = ["O1_Y", "O2_Y"]
    demo_graph = CausalGraph(
        x_nodes + y_nodes,
        directed_edges=[(x, y) for x in x_nodes for y in y_nodes],
    )

    def suffix_template(node):
        """Template id is everything after the first underscore."""
        return node.split("_", 1)[1]

    model = RelationalCTM(
        demo_graph,
        template_of=suffix_template,
        reps={"X": "O1_X", "Y": "O1_Y"},
        regions=10,
        c2_scale=1.0,
    )

    # Both Y nodes should report the same underlying template function.
    func_ids = {name: id(entry["func"]) for name, entry in model.template_funcs.items()}
    print("Template func ids:", func_ids)
    for node in y_nodes:
        tmpl_name = suffix_template(node)
        print("Node:", node, "template:", tmpl_name, "template func id:", id(model.template_funcs[tmpl_name]["func"]))

    drawn = model(5)
    print("RelationalCTM sample:", {name: values.shape for name, values in drawn.items()})
