#!/usr/bin/python

# Author: Baruch Sterin <sterin@berkeley.edu>
# Simple Python AIG package

from __future__ import annotations

import random
from collections import deque
from collections.abc import Iterator
from typing import Dict, List, Optional, Tuple, Union

import networkx as nx
import numpy as np
import torch
from future.utils import iteritems
from torchrl.envs import EnvBase

from .aig_env import AIGEnv


class _Node(object):

    # Node types

    CONST0 = 0
    PI = 1
    LATCH = 2
    AND = 3
    BUFFER = 4

    # Latch initialization

    INIT_ZERO = 0
    INIT_ONE = 1
    INIT_NONDET = 2

    def __init__(self, node_type, left=0, right=0):
        self._type = node_type
        self._left = left
        self._right = right

    # creation

    @staticmethod
    def make_const0():
        return _Node(_Node.CONST0)

    @staticmethod
    def make_pi(pi_id):
        return _Node(_Node.PI, pi_id, 0)

    @staticmethod
    def make_latch(l_id, init, next=None):
        return _Node(_Node.LATCH, l_id, (init, next))

    @staticmethod
    def make_and(left, right):
        return _Node(_Node.AND, left, right)

    @staticmethod
    def make_buffer(buf_id, buf_in):
        return _Node(_Node.BUFFER, buf_id, buf_in)

    # query type

    def is_const0(self):
        return self._type == _Node.CONST0

    def is_pi(self):
        return self._type == _Node.PI

    def is_and(self):
        return self._type == _Node.AND

    def is_buffer(self):
        return self._type == _Node.BUFFER

    def is_latch(self):
        return self._type == _Node.LATCH

    def is_nonterminal(self):
        return self._type in (_Node.AND, _Node.BUFFER)

    def get_fanins(self):
        if self._type == _Node.AND:
            return [self._left, self._right]
        elif self._type == _Node.BUFFER:
            return [self._right]
        else:
            return []

    def get_seq_fanins(self):
        if self._type == _Node.AND:
            return [self._left, self._right]
        elif self._type == _Node.BUFFER:
            return [self._right]
        elif self._type == _Node.LATCH:
            return [self._right[1]]
        else:
            return []

    # AND gates

    def get_left(self):
        assert self.is_and()
        return self._left

    def get_right(self):
        assert self.is_and()
        return self._right

    # Buffer

    def get_buf_id(self):
        return self._left

    def get_buf_in(self):
        assert self.is_buffer()
        return self._right

    def set_buf_in(self, f):
        assert self.is_buffer()
        self._right = f

    def convert_buf_to_pi(self, pi_id):
        assert self.is_buffer()
        self._type = _Node.PI
        self._left = pi_id
        self._right = 0

    # PIs

    def get_pi_id(self):
        assert self.is_pi()
        return self._left

    def get_latch_id(self):
        assert self.is_latch()
        return self._left

    # Latches

    def get_init(self):
        assert self.is_latch()
        return self._right[0]

    def get_next(self):
        assert self.is_latch()
        return self._right[1]

    def set_init(self, init):
        assert self.is_latch()
        self._right = (init, self._right[1])

    def set_next(self, f):
        assert self.is_latch()
        self._right = (self._right[0], f)

    def __repr__(self):
        type = "ERROR"
        if self._type == _Node.AND:
            type = "AND"
        elif self._type == _Node.BUFFER:
            type = "BUFFER"
        elif self._type == _Node.CONST0:
            type = "CONST0"
        elif self._type == _Node.LATCH:
            type = "LATCH"
        elif self._type == _Node.PI:
            type = "PI"
        return "<pyaig.aig._Node _type=%s, _left=%s, _right=%s>" % (
            type,
            str(self._left),
            str(self._right),
        )

    # def __eq__(self, other: _Learned_Node):
    #     if other is None:
    #         return False
    #     return self.node_id.__eq__(other.node_id)

    def __lt__(self, other: _Learned_Node):
        return self.node_id < other.node_id

    # def __le__(self, other: _Learned_Node):
    #     return self.node_id <= other.node_id

    # def __ne__(self, other: _Learned_Node):
    #     return self.node_id != other.node_id

    # def __gt__(self, other: _Learned_Node):
    #     return self.node_id > other.node_id

    # def __ge__(self, other: _Learned_Node):
    #     return self.node_id >= other.node_id


class _Learned_Node(_Node):
    """AIG node that links directly to its fanin/fanout node objects.

    Unlike ``_Node``, fanins are references to other ``_Learned_Node``
    objects and the polarity of each fanin edge is kept separately
    (``1`` = plain, ``-1`` = negated). Each node may carry a boolean truth
    table (a ``torch.Tensor``) that is recomputed whenever a fanin or an edge
    polarity changes.

    A PO node keeps ``left`` and ``right`` pointing at the same driver with
    the same polarity (see ``make_po``).
    """

    CONST0 = 0
    PI = 1
    # LATCH (2) and BUFFER (4) are not used by learned nodes.
    AND = 3
    PO = 5

    def __init__(
        self,
        node_type: int,
        node_id: int,
        left: _Learned_Node | None = None,
        right: _Learned_Node | None = None,
        left_edge_type: int | None = None,
        right_edge_type: int | None = None,
        truth_table: torch.Tensor | None = None,
    ):
        """Store the node data; no truth table or level is computed here."""
        super().__init__(node_type, left, right)
        self._node_id = node_id
        self._truth_table = truth_table
        self._left_edge_type = left_edge_type  # 1 or -1
        self._negated_left_edge = 0
        self._right_edge_type = right_edge_type  # 1 or -1
        self._negated_right_edge = 0
        # fanout node -> polarity of the edge towards that fanout
        self._fanout_type = {}
        # fanout node id -> fanout node
        self._fanout_id_to_object = {}
        self._level = 0

    # Read access

    @property
    def node_type(self) -> int:
        return self._type

    @property
    def truth_table(self) -> torch.Tensor | None:
        return self._truth_table

    @property
    def node_id(self) -> int:
        return self._node_id

    @property
    def left(self) -> _Learned_Node | None:
        return self._left

    @property
    def right(self) -> _Learned_Node | None:
        return self._right

    @property
    def left_edge_type(self) -> int | None:
        return self._left_edge_type

    @property
    def right_edge_type(self) -> int | None:
        return self._right_edge_type

    @property
    def fanout_type(self) -> Dict[_Learned_Node, int]:
        return self._fanout_type

    @property
    def fanout_id_to_object(self) -> Dict[int, _Learned_Node]:
        return self._fanout_id_to_object

    @property
    def level(self) -> int:
        return self._level

    # Write access; changing a fanin or a polarity refreshes the truth table

    @truth_table.setter
    def truth_table(self, truth_table: torch.Tensor) -> None:
        self._truth_table = truth_table

    @node_id.setter
    def node_id(self, node_id: int) -> None:
        self._node_id = node_id

    @left.setter
    def left(self, left: _Learned_Node) -> None:
        self._left = left
        self.calculate_truth_table()

    @right.setter
    def right(self, right: _Learned_Node) -> None:
        self._right = right
        self.calculate_truth_table()

    @left_edge_type.setter
    def left_edge_type(self, left_edge_type: int) -> None:
        self._left_edge_type = left_edge_type
        self.calculate_truth_table()

    @right_edge_type.setter
    def right_edge_type(self, right_edge_type: int) -> None:
        self._right_edge_type = right_edge_type
        self.calculate_truth_table()

    @level.setter
    def level(self, level: int) -> None:
        self._level = level

    # Creation

    @staticmethod
    def make_po(
        node_id: int,
        input: _Learned_Node | None = None,
        edge_type: int | None = None,
        truth_table: torch.Tensor | None = None,
    ) -> _Learned_Node:
        """Return a PO node; ``input``/``edge_type`` fill both fanin slots."""
        return _Learned_Node(
            _Learned_Node.PO, node_id, input, input, edge_type, edge_type, truth_table
        )

    @staticmethod
    def make_pi(
        node_id: int, truth_table: Optional[torch.Tensor] = None
    ) -> _Learned_Node:
        """Return a PI node without fanins."""
        return _Learned_Node(
            _Learned_Node.PI, node_id, None, None, None, None, truth_table
        )

    @staticmethod
    def make_and(
        node_id: int,
        left: _Learned_Node,
        right: _Learned_Node,
        left_edge_type: int,
        right_edge_type: int,
    ) -> _Learned_Node:
        """Return an AND node with its truth table and level already computed."""
        node = _Learned_Node(
            _Learned_Node.AND,
            node_id,
            left,
            right,
            left_edge_type,
            right_edge_type,
            None,
        )
        node.calculate_truth_table()
        node.update_level()
        return node

    @staticmethod
    def make_const0(truth_table_size: int | None = None) -> _Learned_Node:
        """Return the constant-false node (id 0).

        When ``truth_table_size`` is given, the node gets an all-False boolean
        truth table of that length; otherwise it has no truth table.
        """
        truth_table = None
        if truth_table_size is not None:
            truth_table = torch.zeros(truth_table_size, dtype=torch.bool)
        return _Learned_Node(
            _Learned_Node.CONST0, 0, None, None, None, None, truth_table
        )

    # Edge manipulation

    def set_left_edge(self, left: _Learned_Node, left_edge_type: int) -> None:
        """Connect ``left`` with the given polarity as the left fanin.

        For PO nodes the right slot is set to the same driver *before* the
        fanins are ordered, so both slots always agree. For other nodes the
        fanins are then ordered via ``swap_edges``. The truth table and the
        level are refreshed.
        """
        self._left = left
        self._left_edge_type = left_edge_type
        if self._type == self.PO:
            self._right = left
            self._right_edge_type = left_edge_type
        self.calculate_truth_table()
        self.swap_edges()
        self.update_level()

    def set_right_edge(self, right: _Learned_Node, right_edge_type: int) -> None:
        """Connect ``right`` with the given polarity as the right fanin.

        Mirror image of ``set_left_edge``; PO nodes get the same driver in
        both slots.
        """
        self._right = right
        self._right_edge_type = right_edge_type
        if self._type == self.PO:
            self._left = right
            self._left_edge_type = right_edge_type
        self.calculate_truth_table()
        self.swap_edges()
        self.update_level()

    def update_edge_type(self, node: _Learned_Node, edge_type: int) -> None:
        """Change the polarity of the edge between this node and ``node``.

        ``node`` is looked up as left fanin, then as right fanin, then among
        the fanouts; the first match is updated and nothing happens when there
        is none. On a PO both (identical) fanin slots are updated together.
        """
        if node == self._left:
            self._left_edge_type = edge_type
            if self._type == self.PO:
                # keep the mirrored slot of a PO in sync
                self._right_edge_type = edge_type
            self.calculate_truth_table()
        elif node == self._right:
            self.right_edge_type = edge_type
        elif node in self._fanout_type:
            self._fanout_type[node] = edge_type

    def swap_edges(self) -> None:
        """Exchange the fanins when the left one has the larger node id.

        Nothing happens when a fanin is missing or when the left fanin has
        node id 0.
        """
        left, right = self._left, self._right
        if left is None or right is None:
            return
        if left.node_id and left.node_id > right.node_id:
            self._left, self._right = right, left
            self._left_edge_type, self._right_edge_type = (
                self._right_edge_type,
                self._left_edge_type,
            )
            self.calculate_truth_table()

    def get_left(self) -> int | None:
        """Return the node id of the left fanin, or None when it is unset."""
        if self._left is None:
            return None
        return self._left.node_id

    def get_right(self) -> int | None:
        """Return the node id of the right fanin, or None when it is unset."""
        if self._right is None:
            return None
        return self._right.node_id

    # Fanouts

    def add_fanout(self, target: _Learned_Node, edge_type: int) -> None:
        """Register ``target`` as a fanout reached through an ``edge_type`` edge."""
        self._fanout_type[target] = edge_type
        self._fanout_id_to_object[target.node_id] = target

    def fanout_size(self) -> int:
        """Return the number of registered fanouts."""
        return len(self._fanout_id_to_object)

    def delete_fanout(self, node: _Learned_Node | int) -> None:
        """Remove a fanout, given as node object or as node id.

        Raises ``KeyError`` when the fanout is not registered.
        """
        if isinstance(node, int):
            node_id = node
            node = self._fanout_id_to_object[node_id]
        else:
            node_id = node.node_id
        del self._fanout_id_to_object[node_id]
        del self._fanout_type[node]

    # Derived data

    def calculate_truth_table(self, force: bool = False) -> None:
        """Recompute the truth table from the fanins.

        PI nodes are left untouched. If any fanin, polarity or fanin truth
        table is missing, the truth table is reset to None. A PO with
        ``force`` set copies its (possibly negated) left fanin; every other
        case ANDs the two fanin tables after applying their polarities
        (polarities are expected to be 1 or -1).
        """
        if self.is_pi():
            return

        left, right = self._left, self._right
        if (
            left is None
            or right is None
            or self._left_edge_type is None
            or self._right_edge_type is None
            or left.truth_table is None
            or right.truth_table is None
        ):
            self._truth_table = None
            return

        left_tt = left.truth_table
        if self._left_edge_type == -1:
            left_tt = ~left_tt

        if self._type == _Learned_Node.PO and force:
            self._truth_table = left_tt
            return

        right_tt = right.truth_table
        if self._right_edge_type == -1:
            right_tt = ~right_tt
        self._truth_table = left_tt & right_tt

    def update_level(self):
        """Set the level to one more than the deepest connected fanin.

        The level is left unchanged when no fanin is connected.
        """
        if self._left is not None:
            self._level = self._left.level + 1
        if self._right is not None and self._right.level + 1 > self._level:
            self._level = self._right.level + 1

    # Container protocol over the fanouts

    def __getitem__(self, node: _Learned_Node | int):
        """Return the fanout node for an int id, or the edge polarity for a node."""
        if isinstance(node, int):
            return self._fanout_id_to_object[node]
        return self._fanout_type[node]

    def __setitem__(self, node, edge_type):
        """Set the polarity of the edge towards fanout ``node``."""
        self._fanout_type[node] = edge_type

    @staticmethod
    def _signed_fanin(node: _Learned_Node | None, edge_type: int | None) -> str:
        """Format a fanin as polarity * node id, or "None" when it is not wired."""
        if node is None or edge_type is None:
            return "None"
        return str(edge_type * node.node_id)

    def __repr__(self):
        type_names = {
            _Learned_Node.AND: "AND",
            _Learned_Node.CONST0: "CONST0",
            _Learned_Node.PI: "PI",
            _Learned_Node.PO: "PO",
        }
        type_name = type_names.get(self._type, "UNKNOWN")

        # Terminals have no fanins to show.
        if self._type in (_Learned_Node.PI, _Learned_Node.CONST0):
            return "<pyaig.aig._Learned_Node _type=%s, _node_id=%s>" % (
                type_name,
                str(self.node_id),
            )

        return (
            "<pyaig.aig._Learned_Node _type=%s, _node_id=%s, _left=%s, _right=%s>"
            % (
                type_name,
                str(self.node_id),
                self._signed_fanin(self._left, self._left_edge_type),
                self._signed_fanin(self._right, self._right_edge_type),
            )
        )

    def __iter__(self) -> Iterator[int]:
        """Iterate over a snapshot of the fanout node ids.

        Each call returns an independent iterator, so nested loops over the
        same node work and fanouts may be modified while iterating. The
        snapshot is also stored on the node so that the legacy ``__next__``
        keeps working after a call to ``iter(node)``.
        """
        self._keys = list(self._fanout_id_to_object)
        self._idx = 0
        return iter(self._keys)

    def __next__(self) -> int:
        """Return the next fanout id of the snapshot taken by the last ``iter(node)``."""
        if self._idx == len(self._keys):
            raise StopIteration
        self._idx += 1
        return self._keys[self._idx - 1]


class AIG(object):
    """And-Inverter Graph addressed through integer literals.

    A literal ``f`` refers to the node at index ``f >> 1``; its lowest bit is
    the negation flag. Literal 0 is constant false and literal 1 is constant
    true.
    """

    # map AIG nodes to AIG nodes, take negation into account

    class fmap(object):
        """Literal-to-literal map that takes negation into account.

        Entries are stored for positive literals; accessing a negated literal
        negates the mapped value accordingly.
        """

        def __init__(self, fs=(), negate_if_negated=None, zero=None):
            """Create the map.

            ``fs`` is an iterable of ``(f, g)`` pairs, ``negate_if_negated``
            an optional replacement for ``AIG.negate_if_negated`` and ``zero``
            the image of constant 0 (constant 0 by default).
            """
            self.negate_if_negated = (
                negate_if_negated if negate_if_negated else AIG.negate_if_negated
            )
            zero = AIG.get_const0() if zero is None else zero
            self.m = {AIG.get_const0(): zero}
            if fs:
                self.update(fs)

        def __getitem__(self, f):
            return self.negate_if_negated(self.m[AIG.get_positive(f)], f)

        def __setitem__(self, f, g):
            self.m[AIG.get_positive(f)] = self.negate_if_negated(g, f)

        def __contains__(self, f):
            return AIG.get_positive(f) in self.m

        def __delitem__(self, f):
            del self.m[AIG.get_positive(f)]

        def iteritems(self):
            """Return an iterator over the stored ``(positive literal, image)`` pairs."""
            return iter(self.m.items())

        def update(self, fs):
            """Add every ``(f, g)`` pair of ``fs`` to the map."""
            self.m.update(
                (AIG.get_positive(f), self.negate_if_negated(g, f)) for f, g in fs
            )

    class fset(object):
        """Set of literals that ignores negation."""

        def __init__(self, fs=()):
            self.s = set(AIG.get_positive(f) for f in fs)

        def __contains__(self, f):
            return AIG.get_positive(f) in self.s

        def __len__(self):
            return len(self.s)

        def __iter__(self):
            return iter(self.s)

        def add(self, f):
            """Insert ``f`` and return True if it was already present."""
            f = AIG.get_positive(f)
            already_present = f in self.s
            self.s.add(f)
            return already_present

        def remove(self, f):
            """Remove ``f``; raises ``KeyError`` when it is absent."""
            return self.s.remove(AIG.get_positive(f))

    # PO types

    OUTPUT = 0
    BAD_STATES = 1
    CONSTRAINT = 2
    JUSTICE = 3
    FAIRNESS = 4

    # Latch initialization

    INIT_ZERO = _Node.INIT_ZERO
    INIT_ONE = _Node.INIT_ONE
    INIT_NONDET = _Node.INIT_NONDET

    def __init__(self, name=None):
        """Create an AIG that contains only the constant node."""
        self._name = name
        self._strash = {}  # (AND, left, right) -> literal of the existing gate
        self._pis = []
        self._latches = []
        self._buffers = []  # entry is -1 once a buffer was converted to a PI
        self._pos = []  # (fanin literal, PO type)
        self._justice = []
        self._nodes = []
        self._name_to_id = {}
        self._id_to_name = {}
        self._name_to_po = {}
        self._po_to_name = {}
        self._fanouts = {}

        self._nodes.append(_Node.make_const0())

    def deref(self, f):
        """Return the node object behind literal ``f``."""
        return self._nodes[f >> 1]

    def name(self):
        """Return the name given at construction."""
        return self._name

    # Create basic objects

    @staticmethod
    def get_const(c):
        """Return constant 1 if ``c`` is truthy, constant 0 otherwise."""
        if c:
            return AIG.get_const1()
        return AIG.get_const0()

    @staticmethod
    def get_const0():
        return 0

    @staticmethod
    def get_const1():
        return 1

    def create_pi(self, name=None):
        """Append a primary input and return its literal."""
        fn = len(self._nodes) << 1
        self._nodes.append(_Node.make_pi(len(self._pis)))
        self._pis.append(fn)

        if name is not None:
            self.set_name(fn, name)

        return fn

    def create_latch(self, name=None, init=INIT_ZERO, next=None):
        """Append a latch and return its literal."""
        fn = len(self._nodes) << 1
        self._nodes.append(_Node.make_latch(len(self._latches), init, next))
        self._latches.append(fn)

        if name is not None:
            self.set_name(fn, name)

        return fn

    def create_and(self, left, right):
        """Return a literal for ``left AND right``.

        Trivial cases are simplified and structurally identical gates are
        shared through the strash table; a node is only created otherwise.
        """
        # normalise so that left >= right
        if left < right:
            left, right = right, left

        if right == 0:
            return 0

        if right == 1:
            return left

        if left == right:
            return right

        if left == (right ^ 1):
            return 0

        key = (_Node.AND, left, right)

        existing = self._strash.get(key)
        if existing is not None:
            return existing

        f = len(self._nodes) << 1
        self._nodes.append(_Node.make_and(left, right))
        self._strash[key] = f

        return f

    def create_buffer(self, buf_in=0, name=None):
        """Append a buffer driven by ``buf_in`` and return its literal."""
        f = len(self._nodes) << 1
        self._nodes.append(_Node.make_buffer(len(self._buffers), buf_in))
        self._buffers.append(f)

        if name is not None:
            self.set_name(f, name)

        return f

    def convert_buf_to_pi(self, buf):
        """Turn the buffer ``buf`` into a primary input, in place."""
        assert self.is_buffer(buf)
        assert self.get_buf_in(buf) >= 0

        n = self.deref(buf)
        # mark the slot as dead; get_buffers() skips negative entries
        self._buffers[n.get_buf_id()] = -1
        n.convert_buf_to_pi(len(self._pis))
        self._pis.append(buf)

    def create_po(self, f=0, name=None, po_type=OUTPUT):
        """Append a primary output driven by ``f`` and return its id."""
        po_id = len(self._pos)
        self._pos.append((f, po_type))

        if name is not None:
            self.set_po_name(po_id, name)

        return po_id

    def create_justice(self, po_ids):
        """Register a justice property over ``po_ids`` and return its id.

        Every listed PO must have type ``JUSTICE``.
        """
        po_ids = list(po_ids)

        j_id = len(self._justice)

        for po_id in po_ids:
            assert self.get_po_type(po_id) == AIG.JUSTICE

        self._justice.append(po_ids)

        return j_id

    def remove_justice(self):
        """Drop all justice properties, turning their POs into plain outputs."""
        for po_ids in self._justice:
            for po_id in po_ids:
                self.set_po_type(po_id, AIG.OUTPUT)

        self._justice = []

    # Names

    def set_name(self, f, name):
        """Name the positive literal ``f``; both name and literal must be unused."""
        assert not self.is_negated(f)
        assert name not in self._name_to_id
        assert f not in self._id_to_name

        self._name_to_id[name] = f
        self._id_to_name[f] = name

    def get_id_by_name(self, name):
        return self._name_to_id[name]

    def has_name(self, f):
        return f in self._id_to_name

    def name_exists(self, n):
        return n in self._name_to_id

    def get_name_by_id(self, f):
        return self._id_to_name[f]

    def remove_name(self, f):
        """Remove the name attached to literal ``f``."""
        assert self.has_name(f)
        name = self.get_name_by_id(f)

        del self._id_to_name[f]
        del self._name_to_id[name]

    def iter_names(self):
        """Return an iterator over ``(literal, name)`` pairs."""
        return iter(self._id_to_name.items())

    def fill_pi_names(self, replace=False, template="I_{}"):
        """Give every unnamed PI a fresh name built from ``template``.

        With ``replace`` set, existing PI names are discarded first.
        """
        if replace:
            for pi in self.get_pis():
                if self.has_name(pi):
                    self.remove_name(pi)

        uid = 0

        for pi in self.get_pis():
            if self.has_name(pi):
                continue
            # advance the counter until the generated name is unused
            while True:
                name = template.format(uid)
                uid += 1
                if not self.name_exists(name):
                    break
            self.set_name(pi, name)

    # PO names

    def set_po_name(self, po, name):
        """Name the PO with id ``po``; both name and PO must be unused."""
        assert 0 <= po < len(self._pos)
        assert name not in self._name_to_po
        assert po not in self._po_to_name

        self._name_to_po[name] = po
        self._po_to_name[po] = name

    def get_po_by_name(self, name):
        return self._name_to_po[name]

    def po_has_name(self, po):
        return po in self._po_to_name

    def name_has_po(self, po):
        """Return True if ``po`` (a name, despite the parameter name) names a PO."""
        return po in self._name_to_po

    def remove_po_name(self, po):
        """Remove the name attached to the PO with id ``po``."""
        assert self.po_has_name(po)
        name = self.get_name_by_po(po)
        del self._name_to_po[name]
        del self._po_to_name[po]

    def get_name_by_po(self, po):
        return self._po_to_name[po]

    def iter_po_names(self):
        """Yield ``(po_id, fanin literal, name)`` for every named PO."""
        return (
            (po_id, self.get_po_fanin(po_id), po_name)
            for po_id, po_name in self._po_to_name.items()
        )

    def fill_po_names(self, replace=False, template="O_{}"):
        """Give every unnamed PO a fresh name built from ``template``.

        With ``replace`` set, all existing PO names are discarded first.
        """
        if replace:
            self._name_to_po.clear()
            self._po_to_name.clear()

        po_names = set(name for _, _, name in self.iter_po_names())

        uid = 0
        for po_id, _, _ in self.get_pos():
            if self.po_has_name(po_id):
                continue
            while True:
                name = template.format(uid)
                uid += 1
                if name not in po_names:
                    break
            self.set_po_name(po_id, name)

    # Query IDs

    @staticmethod
    def get_id(f):
        """Return the node index encoded in literal ``f``."""
        return f >> 1

    def is_const0(self, f):
        return self.deref(f).is_const0()

    def is_pi(self, f):
        return self.deref(f).is_pi()

    def is_latch(self, f):
        return self.deref(f).is_latch()

    def is_and(self, f):
        return self.deref(f).is_and()

    def is_buffer(self, f):
        return self.deref(f).is_buffer()

    # PIs

    def get_pi_by_id(self, pi_id):
        return self._pis[pi_id]

    # Get/Set next for latches

    def set_init(self, l, init):
        assert not self.is_negated(l)
        assert self.is_latch(l)
        self.deref(l).set_init(init)

    def set_next(self, l, f):
        assert not self.is_negated(l)
        assert self.is_latch(l)
        self.deref(l).set_next(f)

    def get_init(self, l):
        assert not self.is_negated(l)
        assert self.is_latch(l)
        return self.deref(l).get_init()

    def get_next(self, l):
        assert not self.is_negated(l)
        assert self.is_latch(l)
        return self.deref(l).get_next()

    # And gate

    def get_and_fanins(self, f):
        """Return ``(left, right)`` of the AND gate ``f``."""
        assert self.is_and(f)
        n = self.deref(f)
        return (n.get_left(), n.get_right())

    def get_and_left(self, f):
        assert self.is_and(f)
        return self.deref(f).get_left()

    def get_and_right(self, f):
        assert self.is_and(f)
        return self.deref(f).get_right()

    # Buffer

    def get_buf_in(self, b):
        return self.deref(b).get_buf_in()

    def set_buf_in(self, b, f):
        """Set the input of buffer ``b`` to ``f``; requires ``b > f``."""
        assert b > f
        return self.deref(b).set_buf_in(f)

    def get_buf_id(self, b):
        return self.deref(b).get_buf_id()

    def skip_buf(self, b):
        """Follow a chain of buffers from ``b`` and return the first non-buffer literal."""
        while self.is_buffer(b):
            b = AIG.negate_if_negated(self.get_buf_in(b), b)
        return b

    # Fanins

    def get_fanins(self, f):
        return self.deref(f).get_fanins()

    def get_positive_fanins(self, f):
        """Yield the combinational fanins of ``f`` as positive literals."""
        return (self.get_positive(fi) for fi in self.deref(f).get_fanins())

    def get_positive_seq_fanins(self, f):
        """Yield the sequential fanins of ``f`` as positive literals."""
        return (self.get_positive(fi) for fi in self.deref(f).get_seq_fanins())

    # PO fanins

    def get_po_type(self, po):
        assert 0 <= po < len(self._pos)
        return self._pos[po][1]

    def get_po_fanin(self, po):
        assert 0 <= po < len(self._pos)
        return self._pos[po][0]

    def set_po_fanin(self, po, f):
        assert 0 <= po < len(self._pos)
        self._pos[po] = (f, self._pos[po][1])

    def set_po_type(self, po, po_type):
        assert 0 <= po < len(self._pos)
        self._pos[po] = (self._pos[po][0], po_type)

    # Justice

    def get_justice_pos(self, j_id):
        assert 0 <= j_id < len(self._justice)
        return (po for po in self._justice[j_id])

    def set_justice_pos(self, j_id, po_ids):
        """Replace the POs of justice property ``j_id``.

        ``po_ids`` is materialised first so that one-shot iterables survive
        the type check below.
        """
        assert 0 <= j_id < len(self._justice)
        po_ids = list(po_ids)
        for po_id in po_ids:
            assert self.get_po_type(po_id) == AIG.JUSTICE
        self._justice[j_id] = po_ids

    # Negation

    @staticmethod
    def is_negated(f):
        return (f & 1) != 0

    @staticmethod
    def get_positive(f):
        return f & ~1

    @staticmethod
    def negate(f):
        return f ^ 1

    @staticmethod
    def negate_if(f, c):
        """Return ``f`` negated when ``c`` is truthy, ``f`` otherwise."""
        return f ^ 1 if c else f

    @staticmethod
    def positive_if(f, c):
        """Return ``f`` when ``c`` is truthy, ``f`` negated otherwise."""
        return f if c else f ^ 1

    @staticmethod
    def negate_if_negated(f, c):
        """Return ``f`` negated when the literal ``c`` is negated."""
        return f ^ (c & 1)

    # Higher-level boolean operations

    def create_nand(self, left, right):
        return self.negate(self.create_and(left, right))

    def create_or(self, left, right):
        return self.negate(self.create_and(self.negate(left), self.negate(right)))

    def create_nor(self, left, right):
        return self.negate(self.create_or(left, right))

    def create_xor(self, left, right):
        return self.create_or(
            self.create_and(left, self.negate(right)),
            self.create_and(self.negate(left), right),
        )

    def create_iff(self, left, right):
        return self.negate(self.create_xor(left, right))

    def create_implies(self, left, right):
        return self.create_or(self.negate(left), right)

    def create_ite(self, f_if, f_then, f_else):
        return self.create_or(
            self.create_and(f_if, f_then), self.create_and(self.negate(f_if), f_else)
        )

    # Object numbers

    def n_pis(self):
        return len(self._pis)

    def n_latches(self):
        return len(self._latches)

    def n_ands(self):
        return self.n_nonterminals() - self.n_buffers()

    def n_nonterminals(self):
        return len(self._nodes) - 1 - self.n_latches() - self.n_pis()

    def n_pos(self):
        return len(self._pos)

    def n_pos_by_type(self, type):
        return sum(1 for _ in self.get_pos_by_type(type))

    def n_justice(self):
        return len(self._justice)

    def n_buffers(self):
        """Return the number of live buffers.

        Buffers converted to PIs (slot set to -1) are not counted; they are
        already included in ``n_pis()``, so counting them again would make
        ``n_ands()`` too small.
        """
        return sum(1 for b in self._buffers if b >= 0)

    # Object access as iterators (use list() to get a copy)

    def construction_order(self):
        """Yield the positive literal of every node except the constant, in creation order."""
        return (i << 1 for i in range(1, len(self._nodes)))

    def construction_order_deref(self):
        return ((f, self.deref(f)) for f in self.construction_order())

    def get_pis(self):
        return (i << 1 for i, n in enumerate(self._nodes) if n.is_pi())

    def get_latches(self):
        return (l for l in self._latches)

    def get_buffers(self):
        return (b for b in self._buffers if b >= 0)

    def get_and_gates(self):
        return (i << 1 for i, n in enumerate(self._nodes) if n.is_and())

    def get_pos(self):
        """Yield ``(po_id, fanin literal, po_type)`` for every PO."""
        return (
            (po_id, po_fanin, po_type)
            for po_id, (po_fanin, po_type) in enumerate(self._pos)
        )

    def get_pos_by_type(self, type):
        return (
            (po_id, po_fanin, po_type)
            for po_id, po_fanin, po_type in self.get_pos()
            if po_type == type
        )

    def get_po_fanins(self):
        return (po for _, po, _ in self.get_pos())

    def get_po_fanins_by_type(self, type):
        return (po for _, po, po_type in self.get_pos() if po_type == type)

    def get_justice_properties(self):
        return ((i, po_ids) for i, po_ids in enumerate(self._justice))

    def get_nonterminals(self):
        return (i << 1 for i, n in enumerate(self._nodes) if n.is_nonterminal())

    # Python special methods

    def __len__(self):
        return len(self._nodes)

    # return the cone of 'roots', stop at 'stop'

    def get_cone(self, roots, stop=(), fanins=get_positive_fanins):
        """Return the sorted positive literals reachable from ``roots``.

        Traversal follows ``fanins(self, literal)`` and does not enter
        literals contained in ``stop``.
        """
        visited = set()
        pending = list(roots)

        while pending:
            cur = self.get_positive(pending.pop())

            if cur in visited or cur in stop:
                continue

            visited.add(cur)
            pending.extend(fi for fi in fanins(self, cur) if fi not in visited)

        return sorted(visited)

    # return the sequential cone of roots

    def get_seq_cone(self, roots, stop=()):
        """Like ``get_cone`` but also follows latch next-state inputs."""
        return self.get_cone(roots, stop, fanins=AIG.get_positive_seq_fanins)

    def topological_sort(self, roots, stop=()):
        """topologically sort the combinatorial cone of 'roots', stop at 'stop'"""

        def fanins(f):
            if f in stop:
                return []
            return list(self.get_positive_fanins(f))

        visited = AIG.fset()
        dfs_stack = []

        for root in roots:

            if visited.add(root):
                continue

            dfs_stack.append((root, fanins(root)))

            while dfs_stack:

                cur, remaining = dfs_stack[-1]

                if not remaining:
                    # all fanins were emitted, so the node itself may follow
                    dfs_stack.pop()
                    yield cur
                    continue

                d = remaining.pop()

                if visited.add(d):
                    continue

                dfs_stack.append((d, [fi for fi in fanins(d) if fi not in visited]))

    def clean(self, pos=None, justice_pos=None):
        """return a new AIG, containing only the cone of the POs, removing buffers while attempting to preserve names

        ``pos`` / ``justice_pos`` select the PO ids and justice ids to keep
        (all of them by default). Buffers inside the cone are not supported
        and raise ``AssertionError``.
        """

        aig = AIG()
        M = AIG.fmap()

        def visit(f, af):
            # carry the name over; a negated image gets a "~" prefix
            if self.has_name(f):
                if AIG.is_negated(af):
                    aig.set_name(AIG.get_positive(af), "~%s" % self.get_name_by_id(f))
                else:
                    aig.set_name(af, self.get_name_by_id(f))
            M[f] = af

        if pos is None:
            pos = range(len(self._pos))

        pos = set(pos)

        if justice_pos is None:
            justice_pos = range(len(self._justice))

        for j in justice_pos:
            pos.update(self._justice[j])

        cone = self.get_seq_cone(self.get_po_fanin(po_id) for po_id in pos)
        in_cone = set(cone)

        for f in self.topological_sort(cone):

            n = self.deref(f)

            if n.is_pi():
                visit(f, aig.create_pi())

            elif n.is_and():
                visit(f, aig.create_and(M[n.get_left()], M[n.get_right()]))

            elif n.is_latch():
                visit(f, aig.create_latch(init=n.get_init()))

            elif n.is_buffer():
                # explicit raise so the failure is the same with and without -O
                raise AssertionError("clean() does not support buffers in the cone")

        for l in self.get_latches():
            if l in in_cone:
                aig.set_next(M[l], M[self.get_next(l)])

        po_map = {}

        for po_id in pos:
            po_f = self.get_po_fanin(po_id)
            po_map[po_id] = aig.create_po(
                M[po_f],
                self.get_name_by_po(po_id) if self.po_has_name(po_id) else None,
                po_type=self.get_po_type(po_id),
            )

        for j in justice_pos:
            aig.create_justice([po_map[j_po] for j_po in self._justice[j]])

        return aig

    def compose(self, src, M, copy_pos=True):
        """rebuild the AIG 'src' inside 'self', connecting the two AIGs using 'M'

        Literals of ``src`` already contained in ``M`` are reused; every other
        node is recreated in ``self`` and recorded in ``M``. ``M`` is indexed
        with possibly negated literals, so it is presumably an ``AIG.fmap``.
        """

        for f in src.construction_order():

            if f in M:
                continue

            n = src.deref(f)

            if n.is_pi():
                M[f] = self.create_pi()

            elif n.is_and():
                M[f] = self.create_and(M[n.get_left()], M[n.get_right()])

            elif n.is_latch():
                M[f] = self.create_latch(init=n.get_init())

            elif n.is_buffer():
                M[f] = self.create_buffer()

        # buffer inputs and latch next-states are wired once everything exists
        for b in src.get_buffers():
            self.set_buf_in(M[b], M[src.get_buf_in(b)])

        for l in src.get_latches():
            self.set_next(M[l], M[src.get_next(l)])

        if copy_pos:
            for _, po_fanin, po_type in src.get_pos():
                self.create_po(M[po_fanin], po_type=po_type)

    def cutpoint(self, f):
        """Turn the named buffer ``f`` into a primary input."""
        assert self.is_buffer(f)
        assert self.has_name(f)

        self.convert_buf_to_pi(f)

    def build_fanouts(self):
        """Record, for every node, the nodes that use it as a fanin."""
        for f in self.construction_order():
            for g in self.get_positive_fanins(f):
                self._fanouts.setdefault(g, set()).add(f)

    def get_fanouts(self, fs):
        """Return the union of the recorded fanouts of the literals in ``fs``.

        Literals without recorded fanouts contribute nothing; call
        ``build_fanouts`` first to populate the table.
        """
        res = set()

        for f in fs:
            res.update(self._fanouts.get(f, ()))

        return res

    def conjunction(self, fs):
        """Return the AND of all literals in ``fs`` (constant 1 when empty)."""
        res = self.get_const1()

        for f in fs:
            res = self.create_and(res, f)

        return res

    def balanced_conjunction(self, fs):
        """Return the AND of the sequence ``fs``, built as a balanced tree."""
        N = len(fs)

        if N < 2:
            return self.conjunction(fs)

        half = N // 2  # integer division: slice bounds must be ints
        return self.create_and(
            self.balanced_conjunction(fs[:half]),
            self.balanced_conjunction(fs[half:]),
        )

    def disjunction(self, fs):
        """Return the OR of all literals in ``fs`` (constant 0 when empty)."""
        res = self.get_const0()

        for f in fs:
            res = self.create_or(res, f)

        return res

    def balanced_disjunction(self, fs):
        """Return the OR of the sequence ``fs``, built as a balanced tree."""
        N = len(fs)

        if N < 2:
            return self.disjunction(fs)

        half = N // 2  # integer division: slice bounds must be ints
        return self.create_or(
            self.balanced_disjunction(fs[:half]),
            self.balanced_disjunction(fs[half:]),
        )

    def large_xor(self, fs):
        """Return the XOR of all literals in ``fs`` (constant 0 when empty)."""
        res = self.get_const0()

        for f in fs:
            res = self.create_xor(res, f)

        return res

    def mux(self, select, args):
        """Return, per column of ``args``, the OR over ``select[i] AND args[i][column]``."""
        return [
            self.disjunction(self.create_and(s, c) for s, c in zip(select, col))
            for col in zip(*args)
        ]

    def create_constraint(self, f, name=None):
        """Create a PO of type ``CONSTRAINT`` driven by ``f``; return its id."""
        return self.create_po(f, name=name, po_type=AIG.CONSTRAINT)

    def create_property(self, f, name=None):
        """Create a ``BAD_STATES`` PO driven by the negation of ``f``; return its id."""
        return self.create_po(AIG.negate(f), name=name, po_type=AIG.BAD_STATES)

    def create_bad_states(self, f, name=None):
        """Create a PO of type ``BAD_STATES`` driven by ``f``; return its id."""
        return self.create_po(f, name=name, po_type=AIG.BAD_STATES)


class Learned_AIG(AIG):
    def __init__(
        self,
        n_pis: int,
        n_pos: int,
        truth_tables: Optional[List[torch.Tensor] | torch.Tensor],
        truth_table_size: Optional[int] = None,
        name: Optional[str] = None,
        pi_names: Optional[List[str]] = None,
        po_names: Optional[List[str]] = None,
        skip_truth_tables: bool = False,
    ) -> None:
        """Create the constant node, ``n_pis`` inputs and ``n_pos`` outputs.

        Args:
            n_pis: number of primary inputs to create.
            n_pos: number of primary outputs to create.
            truth_tables: target truth table(s) for the outputs; a single
                tensor is wrapped in a one-element list.  A list is replaced
                by ``None`` entries when ``skip_truth_tables`` is true.
            truth_table_size: accepted for interface compatibility; the size
                actually used is always ``2 ** n_pis``.
            name: forwarded to the base-class constructor.
            pi_names: optional input names; defaults to ``1..n_pis``.
            po_names: optional output names; defaults to ``-1..-n_pos``.
            skip_truth_tables: when true, no truth tables are assigned to the
                constant node or the inputs at construction time.

        Raises:
            AssertionError: if a name list's length does not match its count.
        """
        super().__init__(name)
        self._nodes: List[_Learned_Node] = []
        self._pis: List[_Learned_Node] = []
        self._pos: List[_Learned_Node] = []
        self._id_to_object: Dict[int, _Learned_Node] = {}
        self._node_truth_tables: List[torch.Tensor] = []
        self._po_truth_tables: List[torch.Tensor] = []
        self._next_available_node_id: int = 0

        if pi_names is not None:
            assert n_pis == len(pi_names)
        if po_names is not None:
            # Must be checked against po_names (was compared to pi_names).
            assert n_pos == len(po_names)

        self._instantiated_truth_tables: bool = not skip_truth_tables

        self._truth_table_size = 2 ** (n_pis)

        # Create the const node (always gets node id 0).
        self.__create_const()

        # Create the PIs.
        for i in range(n_pis):
            if pi_names is None:
                self.__create_pi(name=(i + 1))
            else:
                self.__create_pi(name=pi_names[i])

        # Assign truth tables to the PIs if necessary.
        if self._instantiated_truth_tables:
            self.assign_pi_tts()

        if isinstance(truth_tables, torch.Tensor):
            truth_tables = [truth_tables]
        elif not self._instantiated_truth_tables or truth_tables is None:
            truth_tables = [None] * n_pos

        # Create the POs.
        for i in range(n_pos):
            if po_names is None:
                self.__create_po(name=-(i + 1), truth_table=truth_tables[i])
            else:
                self.__create_po(name=po_names[i], truth_table=truth_tables[i])

    def __create_pi(
        self, name: int | str, truth_table: torch.Tensor | int | None = None
    ) -> _Learned_Node:
        """Create a primary input, register it and give it ``name``.

        Args:
            name: name recorded for the new input via ``set_name``.
            truth_table: optional truth table.  A tensor is used as given; an
                int is expanded to a boolean tensor of
                ``self._truth_table_size`` bits, most significant bit first.

        Returns:
            The newly created PI node.
        """
        pi_id = self._next_available_node_id
        self._next_available_node_id += 1
        # Identity test: ``!= None`` on a tensor is element-wise and cannot be
        # used as a condition.
        if truth_table is not None and self._truth_table_size is not None:
            if isinstance(truth_table, torch.Tensor):
                truth_table = self.create_truth_table(bin=truth_table)
            else:
                # create_truth_table() takes no width argument, so pad here.
                width = self._truth_table_size
                truth_table = torch.tensor(
                    [bit == "1" for bit in format(truth_table, "0{}b".format(width))]
                )
        node = _Learned_Node.make_pi(node_id=pi_id, truth_table=truth_table)
        self._id_to_object[pi_id] = node
        self._nodes.append(node)
        self._pis.append(node)
        self.set_name(pi_id, name)

        return node

    def __create_po(
        self, name: int | str, truth_table: torch.Tensor | None
    ) -> _Learned_Node:
        """Create an unconnected primary output and register it under ``name``.

        Outputs receive negative ids: the first is ``-1``, the second ``-2``,
        and so on, based on how many outputs already exist.
        """
        new_id = -len(self._pos) - 1
        po_node = _Learned_Node.make_po(
            node_id=new_id, input=None, edge_type=None, truth_table=truth_table
        )
        self._pos.append(po_node)
        self._id_to_object[new_id] = po_node
        self.set_po_name(new_id, name)

        return po_node

    def __create_const(self) -> _Learned_Node:
        """Create the constant-0 node under the next free node id.

        The truth-table size is passed to ``make_const0`` only when truth
        tables are being instantiated; otherwise ``None`` is passed.
        """
        const_id = self._next_available_node_id
        self._next_available_node_id = const_id + 1

        tt_size = self._truth_table_size if self._instantiated_truth_tables else None
        const_node = _Learned_Node.make_const0(tt_size)

        self._nodes.append(const_node)
        self._id_to_object[const_id] = const_node
        return const_node

    def create_and(
        self,
        left: _Learned_Node | int,
        right: _Learned_Node | int,
        left_edge_type: int,
        right_edge_type: int,
    ) -> _Learned_Node:
        """Return the AND node over ``left`` and ``right``, creating it if new.

        Integer operands are resolved through ``self._id_to_object``.  The
        operands (with their edge types) are swapped when the left node id is
        greater than the right one, and the result is looked up in
        ``self._strash`` so an identical gate is reused rather than
        duplicated.
        """
        lhs = self._id_to_object[left] if isinstance(left, int) else left
        rhs = self._id_to_object[right] if isinstance(right, int) else right
        lhs_edge, rhs_edge = left_edge_type, right_edge_type

        # Order the operands by node id so both argument orders share one key.
        if lhs.node_id > rhs.node_id:
            lhs, rhs = rhs, lhs
            lhs_edge, rhs_edge = rhs_edge, lhs_edge

        strash_key = (_Learned_Node.AND, id(lhs), id(rhs), lhs_edge, rhs_edge)
        if strash_key in self._strash:
            return self._strash[strash_key]

        new_id = self._next_available_node_id
        self._next_available_node_id = new_id + 1
        gate = _Learned_Node.make_and(new_id, lhs, rhs, lhs_edge, rhs_edge)

        self._id_to_object[new_id] = gate
        self._strash[strash_key] = gate
        self._nodes.append(gate)
        lhs.add_fanout(gate, lhs_edge)
        rhs.add_fanout(gate, rhs_edge)

        return gate

    def set_left_edge(
        self, source: _Learned_Node | int, target: _Learned_Node | int, edge_type: int
    ) -> None:
        if isinstance(source, int):
            source = self._id_to_object[source]
        if isinstance(target, int):
            target = self._id_to_object[target]
        target.set_left_edge(source, edge_type)
        source.add_fanout(target, edge_type)

    def set_right_edge(
        self, source: _Learned_Node | int, target: _Learned_Node | int, edge_type: int
    ) -> None:
        if isinstance(source, int):
            source = self._id_to_object[source]
        if isinstance(target, int):
            target = self._id_to_object[target]
        target.set_right_edge(source, edge_type)
        source.add_fanout(target, edge_type)

    def set_po_edge(
        self, source: _Learned_Node | int, po: _Learned_Node | int, edge_type: int
    ) -> None:
        if isinstance(source, int):
            source = self._id_to_object[source]
        if isinstance(po, int):
            po = self._id_to_object[po]
        po.set_right_edge(source, edge_type)
        source.add_fanout(po, edge_type)

    def is_negated(self, idx) -> bool:
        """Report every literal as negated; ``idx`` is ignored.

        NOTE(review): this looks like a placeholder override rather than a
        real polarity test -- confirm the intent before relying on it.
        """
        return True

    def instantiate_truth_tables(self) -> None:
        self.assign_const_tts()
        self.assign_pi_tts()

        for node in self._nodes:
            if node.is_and():
                node.calculate_truth_table(force=True)
        for po in self._pos:
            po.calculate_truth_table(force=True)
        self._instantiated_truth_tables = True

    def assign_pi_tts(self) -> None:
        for i in range(len(self._pis)):
            bits = 1 << i
            res = ~(~0 << bits)
            mask_bits = bits << 1
            for _ in range(len(self._pis) - (i + 1)):

                res |= res << mask_bits
                mask_bits <<= 1
            # self.cofactor_masks[0].append( res )
            # self.cofactor_masks[1].append( res << bits )
            self._pis[i].truth_table = self.create_truth_table(bin=res << bits)
            # self._pis[i].truth_table = self.create_truth_table(bin=res << bits, bits=self._truth_table_size)
        # self.all_consts = [ _truth_table(self, self.mask*c) for c in (0, 1) ]
        # self.all_vars = [ [_truth_table(self, self.cofactor_masks[c][i]) for i in range(N)] for c in (0, 1) ]

    def assign_const_tts(self) -> None:
        self._nodes[0].truth_table = torch.zeros(
            self._truth_table_size, dtype=torch.bool
        )

    def create_and_nodes_from_actions(
        self,
        actions: torch.Tensor | list[list[int]],
        const_node: bool = False,
    ) -> None:
        if isinstance(actions, torch.Tensor):
            # print(actions.shape)
            # actions = actions.reshape(actions.shape[1:])
            actions = actions.T.tolist()
            # print(actions)
        assert isinstance(actions, list)
        edges = deque(actions)

        while len(edges) > 0:
            edge_type, left, right = edges.popleft()
            if not const_node:
                left += 1
                right += 1
            if left in self._id_to_object and right in self._id_to_object:
                left_edge, right_edge = self.edge_type_decoder(edge_type)
                self.create_and(left, right, left_edge, right_edge)
            else:
                edges.append([edge_type, left, right])

            self.create_and(left, right, left_edge, right_edge)

    @staticmethod
    def from_adj_matrix(
        n_pis: int,
        adj_matrix: torch.Tensor,
        n_pos: int = 1,
        truth_tables: list[torch.Tensor] | torch.Tensor | None = None,
        pi_names: list[str] | None = None,
        po_names: list[str] | None = None,
        name: str | None = None,
    ) -> Learned_AIG:
        """Build a ``Learned_AIG`` from the non-zero entries of ``adj_matrix``.

        Every non-zero index of ``adj_matrix`` is passed as one action to
        ``create_and_nodes_from_actions``.  Nodes left without fanout are
        treated as output candidates and, when ``truth_tables`` is given,
        connected to each PO whose target equals the candidate's truth table
        (edge type 1) or its complement (edge type -1).

        Args:
            n_pis: number of primary inputs.
            adj_matrix: tensor whose non-zero indices encode the AND nodes;
                presumably indexed ``[edge_type, left, right]`` -- confirm
                against the caller.
            n_pos: number of primary outputs.
            truth_tables: target truth table(s) for the outputs.
            pi_names: optional input names.
            po_names: optional output names.
            name: optional AIG name.

        Raises:
            AssertionError: if the number of fanout-free nodes is not
                ``n_pos``.
        """
        aig = Learned_AIG(
            n_pis=n_pis,
            n_pos=n_pos,
            truth_tables=truth_tables,
            name=name,
            pi_names=pi_names,
            po_names=po_names,
            skip_truth_tables=True,
        )
        aig.assign_const_tts()
        aig.assign_pi_tts()
        aig.create_and_nodes_from_actions(adj_matrix.nonzero().tolist())

        # Keep the node objects: their truth tables are needed below.
        potential_pos = []
        for node in aig._nodes[1:]:
            node.calculate_truth_table(force=True)
            if node.fanout_size() == 0:
                potential_pos.append(node)
        assert len(potential_pos) == n_pos

        if truth_tables is not None:
            if isinstance(truth_tables, torch.Tensor):
                truth_tables = [truth_tables]

            # The constructor discards a list of targets when
            # skip_truth_tables=True, so attach them to the POs here.
            for po, target in zip(aig._pos, truth_tables):
                po.truth_table = target

            for node in potential_pos:
                actual = node.truth_table.view(-1)
                for po in aig._pos:
                    if po.truth_table is None:
                        continue
                    wanted = po.truth_table.view(-1)
                    # torch.equal gives a single bool; ``==`` is element-wise.
                    if torch.equal(wanted, actual):
                        aig.set_po_edge(node, po, 1)
                    elif torch.equal(wanted, ~actual):
                        aig.set_po_edge(node, po, -1)
        return aig

    @staticmethod
    def from_aig_env(aig_env: AIGEnv | EnvBase) -> Learned_AIG:
        """Rebuild a ``Learned_AIG`` from the state held by an AIG environment.

        Reads ``num_inputs``, ``target``, ``edge_type``, ``left`` and
        ``right`` from ``aig_env.state``, replays the recorded actions, and
        connects every fanout-free node to each PO whose truth table equals
        the node's truth table (edge type 1) or its complement (edge type -1).
        """
        state = aig_env.state
        aig = Learned_AIG(
            n_pis=int(state["num_inputs"].item()),
            n_pos=int(aig_env.n_pos.item()),
            truth_tables=[state["target"]],
            skip_truth_tables=True,
        )
        aig.assign_const_tts()
        aig.assign_pi_tts()
        # The target is attached to the PO with id -1 only.
        aig[-1].truth_table = state["target"]

        action_rows = [state[key] for key in ("edge_type", "left", "right")]
        stacked_actions = torch.stack(action_rows, dim=0)
        aig.create_and_nodes_from_actions(stacked_actions.int(), aig_env.const_node)

        dangling_ids = []
        for candidate in aig._nodes[1:]:
            candidate.calculate_truth_table(force=True)
            if candidate.fanout_size() == 0:
                dangling_ids.append(candidate.node_id)

        for node_id in dangling_ids:
            node = aig[node_id]
            for po in aig._pos:
                wanted = po.truth_table.view(-1)
                if torch.equal(wanted, node.truth_table.view(-1)):
                    aig.set_po_edge(node, po, 1)
                elif torch.equal(wanted, ~node.truth_table.view(-1)):  # type: ignore
                    aig.set_po_edge(node, po, -1)

        return aig

    @staticmethod
    def edge_type_decoder(edge_type: int) -> Tuple[int, int]:
        match edge_type:
            case 0:
                return (1, 1)
            case 1:
                return (1, -1)
            case 2:
                return (-1, 1)
            case _:
                return (-1, -1)

    @classmethod
    def create_truth_table(cls, bin: int | torch.Tensor | None) -> torch.Tensor:
        if type(bin) == torch.Tensor:
            return bin
        elif bin == None:
            return None
        # np_array = (((bin & (1 << np.arange(bits, dtype=np.uint64))[::-1])) > 0).astype(bool) # Creates a np.array of bit array representing an int
        return torch.tensor([bit == "1" for bit in list("{:03b}".format(bin))])

    def to_networkx(self) -> nx.DiGraph:
        """Export the AIG as a ``networkx.DiGraph``.

        Node 0 is tagged ``CONST``, ids ``1..len(pis)`` are tagged ``PI`` and
        ids ``-1..-len(pos)`` are tagged ``PO``; any other node that appears
        through an edge is tagged ``AND``.  Every edge carries its polarity in
        the ``edge_type`` attribute.
        """
        graph = nx.DiGraph()

        pi_count = len(self._pis)
        po_count = len(self._pos)
        graph.add_node(0, node_type="CONST")
        graph.add_nodes_from(range(1, pi_count + 1), node_type="PI")
        graph.add_nodes_from(range(-1, -(po_count + 1), -1), node_type="PO")

        terminal_types = (_Learned_Node.PI, _Learned_Node.CONST0)
        for gate in self._nodes:
            if gate.node_type in terminal_types:
                continue
            graph.add_edge(
                gate.left.node_id, gate.node_id, edge_type=gate.left_edge_type
            )
            graph.add_edge(
                gate.right.node_id, gate.node_id, edge_type=gate.right_edge_type
            )

        # Nodes created implicitly by add_edge have no type yet.
        for node_id in list(graph.nodes):
            graph.nodes[node_id].setdefault("node_type", "AND")

        for output in self._pos:
            if output.left != None:
                graph.add_edge(
                    output.left.node_id, output.node_id, edge_type=output.left_edge_type
                )
        return graph

    def draw(self) -> None:
        """Draw the AIG with networkx using a graphviz ``dot`` layout.

        The constant node and the inputs are green squares, outputs are red
        diamonds and AND nodes are blue circles.  Edges with ``edge_type`` 1
        are solid, edges with ``edge_type`` -1 are dashed.  Only the constant
        node, inputs and outputs are labelled.
        """
        graph = self.to_networkx()

        kinds = nx.get_node_attributes(graph, "node_type")

        def of_kind(kind):
            return [node for node, value in kinds.items() if value == kind]

        pis = of_kind("PI")
        pos = of_kind("PO")
        const = of_kind("CONST")
        ands = list(set(graph.nodes()) - set(pis) - set(pos) - set(const))

        polarity = nx.get_edge_attributes(graph, "edge_type")
        normal = [edge for edge, sign in polarity.items() if sign == 1]
        negated = [edge for edge, sign in polarity.items() if sign == -1]

        position = nx.nx_agraph.pygraphviz_layout(graph, prog="dot")

        labels = {0: 0}
        labels.update((node, self._id_to_name[node]) for node in pis)
        labels.update((node, self._po_to_name[node]) for node in pos)

        node_styles = (
            (const, "green", "s", "CONST"),
            (pis, "green", "s", "PI"),
            (pos, "red", "d", "PO"),
            (ands, "blue", "o", "AND"),
        )
        for members, color, shape, legend in node_styles:
            nx.draw_networkx_nodes(
                graph,
                position,
                nodelist=members,
                node_color=color,
                node_shape=shape,
                label=legend,
            )

        nx.draw_networkx_labels(graph, position, labels=labels)

        nx.draw_networkx_edges(graph, position, edgelist=normal)
        nx.draw_networkx_edges(graph, position, edgelist=negated, style="--")

    def set_name(self, node_id: int, name: int | str) -> None:
        # assert name not in self._name_to_id
        assert node_id not in self._id_to_name

        if not isinstance(name, str):
            name = str(name)

        self._name_to_id[name] = node_id
        self._id_to_name[node_id] = name

    def get_name(self, node: int | _Learned_Node) -> str:
        """Return the name recorded for a node, given the node or its id.

        Raises:
            KeyError: if no name was recorded for the node.
        """
        node_id = node.node_id if isinstance(node, _Learned_Node) else node
        return self._id_to_name[node_id]

    def set_po_name(self, po_id: int, name: int | str) -> None:
        assert name not in self._name_to_po
        assert po_id not in self._po_to_name

        self._name_to_po[name] = po_id
        self._po_to_name[po_id] = name

    def n_ands(self):
        """Returns the number of AND gates excluding the PIs and CONST node

        Returns:
            int: number of AND nodes
        """
        return len(self._nodes) - 1 - self.n_pis()

    def get_pos(self) -> Iterator[tuple[int, int, int]]:
        """Lazily yield ``(po id, driver node id, po node type)`` per output.

        The result is a generator, so it can be consumed only once.  A PO
        whose ``left`` is unset fails when its tuple is produced, because
        ``po.left.node_id`` is dereferenced.
        """
        return ((po.node_id, po.left.node_id, po.node_type) for po in self._pos)

    def __getitem__(self, node_id: int) -> _Learned_Node:
        """Return the node or PO registered under ``node_id``.

        Raises ``KeyError`` when the id is unknown.
        """
        return self._id_to_object[node_id]

    def __iter__(self) -> Iterator[_Learned_Node]:
        """Reset the iteration cursor and return ``self`` as the iterator.

        The cursor lives on the instance, so starting a second iteration
        (e.g. a nested loop over the same AIG) restarts the first one.
        """
        self._idx = 0
        return self

    def __next__(self) -> _Learned_Node:
        if self._idx == len(self._nodes):
            raise StopIteration
        self._idx += 1
        return self._nodes[self._idx - 1]

    @staticmethod
    def read_aig(path: str, skip_truth_tables: bool = True):
        """Load a binary AIGER (``aig``) file into a ``Learned_AIG``.

        The file is parsed into a plain ``AIG`` (``old_aig``) and, in
        parallel, the inputs, AND gates and ordinary outputs are mirrored
        into the returned ``Learned_AIG``.  Latches, bad-state, constraint,
        justice and fairness outputs are only added to ``old_aig``.  The
        symbol table at the end of the file is not read.

        Args:
            path: path of the binary AIGER file.
            skip_truth_tables: when false, PO truth tables are computed while
                loading.  Forced to true for files with more than 16 inputs.

        Returns:
            The populated ``Learned_AIG``.

        Raises:
            AssertionError: if the header does not start with ``aig``.
        """
        old_aig = AIG()
        # Low bit of an old literal -> edge polarity used by Learned_AIG.
        edge_type_map = {1: -1, 0: 1}

        # Old literals by AIGER variable index (index 0 is the constant).
        vars = []
        nexts = []

        pos_output = []
        pos_bad_states = []
        pos_constraint = []
        pos_justice = []
        pos_fairness = []

        # Old positive literal -> node id in the new AIG.
        old_to_new_id = {0: 0}

        def lit(x):
            return old_aig.negate_if(vars[x >> 1], x & 0x1)

        def parse_latch(line):
            tokens = line.strip().split(b" ")
            next = int(tokens[0])
            init = 0
            if len(tokens) == 2:
                # The file is read in binary mode, so tokens are bytes.
                if tokens[1] == b"0":
                    init = Learned_AIG.INIT_ZERO
                elif tokens[1] == b"1":
                    init = Learned_AIG.INIT_ONE
                else:
                    init = Learned_AIG.INIT_NONDET
            return (next, init)

        # The context manager closes the file even if parsing fails.
        with open(path, "rb") as fin:

            def decode():
                # Variable-length integer: 7 payload bits per byte, least
                # significant group first, high bit set on all but the last.
                i = 0
                res = 0
                while True:
                    c = ord(fin.read(1))
                    res |= (c & 0x7F) << (7 * i)
                    if (c & 0x80) == 0:
                        break
                    i += 1
                return res

            header = fin.readline().split()
            assert header[0] == b"aig"

            args = [int(t) for t in header[1:]]
            (_, I, L, O, A) = args[:5]

            B = args[5] if len(args) > 5 else 0
            C = args[6] if len(args) > 6 else 0
            J = args[7] if len(args) > 7 else 0
            F = args[8] if len(args) > 8 else 0

            if I > 16:
                skip_truth_tables = True

            new_aig = Learned_AIG(
                n_pis=I,
                n_pos=O,
                truth_tables=None,
                truth_table_size=None,
                pi_names=None,
                po_names=None,
                skip_truth_tables=skip_truth_tables,
            )

            vars.append(old_aig.get_const0())

            for i in range(I):
                vars.append(old_aig.create_pi())
                old_to_new_id[vars[-1]] = i + 1

            for i in range(L):  # Obsolete
                vars.append(old_aig.create_latch())
                nexts.append(parse_latch(fin.readline()))

            for i in range(O):
                pos_output.append(int(fin.readline()))

            for i in range(B):  # Obsolete
                pos_bad_states.append(int(fin.readline()))

            for i in range(C):  # Obsolete
                pos_constraint.append(int(fin.readline()))

            n_j_pos = []

            for i in range(J):  # Obsolete
                n_j_pos.append(int(fin.readline()))

            for n in n_j_pos:  # Obsolete
                pos_justice.append([int(fin.readline()) for _ in range(n)])

            for i in range(F):  # Obsolete
                pos_fairness.append(int(fin.readline()))

            for i in range(I + L + 1, I + L + A + 1):
                d1 = decode()
                d2 = decode()
                g = i << 1
                # The two fanin literals are stored as deltas: the first is
                # g - d1 and the second is g - d1 - d2.  A literal's variable
                # index is lit >> 1 and its low bit marks negation.
                lit1 = lit(g - d1)
                lit2 = lit(g - d1 - d2)
                vars.append(old_aig.create_and(lit1, lit2))
                new_node = new_aig.create_and(
                    old_to_new_id[old_aig.get_positive(lit1)],
                    old_to_new_id[old_aig.get_positive(lit2)],
                    edge_type_map[lit1 & 1],
                    edge_type_map[lit2 & 1],
                )
                old_to_new_id[vars[-1]] = new_node.node_id

        for l, v in enumerate(range(I + 1, I + L + 1)):  # Obsolete
            old_aig.set_init(vars[v], nexts[l][1])
            old_aig.set_next(vars[v], lit(nexts[l][0]))

        for i, po in enumerate(pos_output):
            po_lit = lit(po)
            old_aig.create_po(po_lit, po_type=AIG.OUTPUT)
            new_source_id = old_to_new_id[old_aig.get_positive(po_lit)]
            new_aig.set_po_edge(new_source_id, -(i + 1), edge_type_map[po_lit & 1])
            if not skip_truth_tables:
                new_aig[-(i + 1)].calculate_truth_table(force=True)

        # Iterate the bad-state literals themselves; the loop used to reuse
        # the last ordinary output literal for every bad-state output.
        for po in pos_bad_states:
            old_aig.create_po(lit(po), po_type=AIG.BAD_STATES)

        for po in pos_constraint:
            old_aig.create_po(lit(po), po_type=AIG.CONSTRAINT)

        for pos in pos_justice:
            po_ids = [old_aig.create_po(lit(po), po_type=AIG.JUSTICE) for po in pos]
            old_aig.create_justice(po_ids)

        for po in pos_fairness:
            old_aig.create_po(lit(po), po_type=AIG.FAIRNESS)

        # NOTE: the optional symbol table (i/l/o/b/c/f name lines) is not
        # parsed; inputs and outputs keep their default numeric names.

        return new_aig

    def write_aig(self, path) -> dict[int, int]:
        """Write the AIG to ``path`` in the binary AIGER format.

        The header, one line per primary output and the delta-encoded AND
        gates are written.  Latch lines, the bad-state / constraint / justice
        / fairness sections and the symbol table are not emitted, although
        their counts are still placed in the header when non-zero.

        The whole file is assembled in memory first, so ``path`` is touched
        only after encoding has succeeded.

        Args:
            path: destination file path.

        Returns:
            The mapping from node id to the positive AIGER literal assigned
            to it.
        """
        # Node id -> positive AIGER literal.  The constant maps to literal 0.
        map_aiger = {0: 0}
        aiger_i = 1

        for pi in self._pis:
            map_aiger[pi.node_id] = aiger_i << 1
            aiger_i += 1

        for l in self.get_latches():  # Obsolete
            map_aiger[l] = aiger_i << 1
            aiger_i += 1

        and_nodes = [n for n in self._nodes if n.node_type == _Learned_Node.AND]
        for n in and_nodes:
            map_aiger[n.node_id] = aiger_i << 1
            aiger_i += 1

        _bytes = bytearray()

        def _encode(x):
            # Variable-length integer: 7 payload bits per byte, least
            # significant group first, high bit set on all but the last.
            while (x & ~0x7F) > 0:
                _bytes.append((x & 0x7F) | 0x80)
                x >>= 7
            _bytes.append(x)

        I = self.n_pis()
        L = self.n_latches()
        O = len(self._pos)
        A = self.n_nonterminals()
        B = self.n_pos_by_type(Learned_AIG.BAD_STATES)
        C = self.n_pos_by_type(Learned_AIG.CONSTRAINT)
        J = self.n_justice()
        F = self.n_pos_by_type(Learned_AIG.FAIRNESS)

        M = I + L + A
        _bytes.extend(b"aig %d %d %d %d %d" % (M, I, L, O, A))

        # Optional header fields are positional: a field must be written
        # whenever any later field is non-zero.
        if B + C + J + F > 0:
            _bytes.extend(b" %d" % B)

        if C + J + F > 0:
            _bytes.extend(b" %d" % C)

        if J + F > 0:
            _bytes.extend(b" %d" % J)

        if F > 0:
            _bytes.extend(b" %d" % F)

        _bytes.extend(b"\n")

        for po in self._pos:
            po_lit = map_aiger[po.left.node_id]
            if po.left_edge_type == -1:
                po_lit += 1
            _bytes.extend(b"%d\n" % po_lit)

        for n in and_nodes:
            # The gate's own literal comes from the same map as its fanins,
            # so the deltas stay consistent even when latches shift the ids.
            lhs = map_aiger[n.node_id]
            al = map_aiger[n.left.node_id]
            ar = map_aiger[n.right.node_id]

            if n.left_edge_type == -1:
                al += 1
            if n.right_edge_type == -1:
                ar += 1
            # The larger fanin literal is encoded first.
            if al < ar:
                al, ar = ar, al
            _encode(lhs - al)
            _encode(al - ar)

        # The context manager closes the file even if the write fails.
        with open(path, "wb") as fout:
            fout.write(_bytes)

        return map_aiger

    def prepare_data(
        self, embedding_size: int | None = None
    ) -> tuple[
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
    ]:
        if not self._instantiated_truth_tables:
            self.instantiate_truth_tables()

        self.collect_truth_tables(embedding_size)

        (
            actions,
            edge_type_idx,
            left_parent_idx,
            right_parent_idx,
        ) = self.collect_actions()
        # a = torch.stack(self._node_truth_tables).to(torch.float32)
        # b = torch.stack(self._po_truth_tables).to(torch.float32)
        return (
            torch.stack(self._node_truth_tables),
            torch.stack(self._po_truth_tables),
            actions,
            edge_type_idx,
            left_parent_idx,
            right_parent_idx,
        )

    def collect_truth_tables(self, embedding_size: int | None = None) -> None:
        self._node_truth_tables.clear()
        self._po_truth_tables.clear()

        if not self._instantiated_truth_tables:
            self.instantiate_truth_tables()

        repeat_factor = 1
        if embedding_size != None:
            repeat_factor = embedding_size // self._truth_table_size

        for node in self._nodes:
            if repeat_factor > 1:
                self._node_truth_tables.append(node.truth_table.repeat(repeat_factor))
            else:
                self._node_truth_tables.append(node.truth_table)

        for po in self._pos:
            if repeat_factor > 1:
                self._po_truth_tables.append(po.truth_table.repeat(repeat_factor))
            else:
                self._po_truth_tables.append(po.truth_table)

    def update_node_truth_tables(self) -> None:
        repeat_factor = (
            torch.numel(self._node_truth_tables[-1]) // self._truth_table_size
        )
        for node in self._nodes[len(self._node_truth_tables) :]:
            if repeat_factor > 1:
                self._node_truth_tables.append(node.truth_table.repeat(repeat_factor))
            else:
                self._node_truth_tables.append(node.truth_table)

    def get_truth_tables(self) -> torch.Tensor:
        return (
            torch.stack(self._node_truth_tables + self._po_truth_tables)
            .to(torch.float32)
            .unsqueeze(0)
        )

    def get_action_mask(self) -> torch.Tensor:
        (
            actions,
            edge_type_idx,
            left_parent_idx,
            right_parent_idx,
        ) = self.collect_actions()
        src_mask = torch.full(
            (len(self._nodes), len(self._nodes)),
            torch.finfo(torch.float32).min,
            dtype=torch.float32,
        )
        src_mask = torch.triu(src_mask, diagonal=0).T
        src_mask = src_mask.repeat(4, 1, 1)
        src_mask[edge_type_idx, left_parent_idx, right_parent_idx] = float("-inf")
        return src_mask

    def collect_actions(
        self,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Describe every AND node in ``self._nodes`` as an action triple.

        Returns:
            ``(actions, edge_type_idx, left_parent_idx, right_parent_idx)``.
            ``actions`` is a ``(4, N, N)`` boolean tensor (``N`` = number of
            nodes) that is ``True`` at each ``[edge type, left id, right id]``
            triple of an AND node. The three index tensors are int32 and list
            those triples in node order.

        The edge-type code is the position of ``(left_edge_type,
        right_edge_type)`` in ``(1, 1), (1, -1), (-1, 1), (-1, -1)``. A node
        whose edge types match none of these contributes parent ids but no
        code.
        """
        polarity_pairs = ((1, 1), (1, -1), (-1, 1), (-1, -1))
        codes = []
        left_ids = []
        right_ids = []

        for node in self._nodes:
            if not node.is_and():
                continue
            for code, (left_pol, right_pol) in enumerate(polarity_pairs):
                if node.left_edge_type == left_pol and node.right_edge_type == right_pol:
                    codes.append(code)
                    break
            left_ids.append(node.left.node_id)
            right_ids.append(node.right.node_id)

        node_count = len(self._nodes)
        actions = torch.zeros(4, node_count, node_count, dtype=torch.bool)
        edge_type_idx = torch.tensor(codes, dtype=torch.int32)
        left_parent_idx = torch.tensor(left_ids, dtype=torch.int32)
        right_parent_idx = torch.tensor(right_ids, dtype=torch.int32)
        actions[edge_type_idx, left_parent_idx, right_parent_idx] = True
        return actions, edge_type_idx, left_parent_idx, right_parent_idx

    def create_and_from_tensor(
        self, action: torch.Tensor, temperature: float = 0.000001
    ) -> _Learned_Node:
        """Sample an action from ``action`` and create the matching AND node.

        ``action`` is flattened, divided by ``temperature``, passed through a
        softmax and sampled once with ``torch.multinomial``. The sampled flat
        index is decoded as ``(edge type, left, right)`` assuming a layout of
        ``4 x N x N`` with ``N = len(self._nodes)``.

        Args:
            action: Scores over all actions; any shape that flattens to the
                ``4 * N * N`` layout. Non-contiguous tensors are accepted.
            temperature: Softmax temperature. The tiny default makes the
                distribution extremely peaked.

        Returns:
            Whatever ``self.create_and`` returns for the decoded
            ``(left, right, left_edge_type, right_edge_type)``.
        """
        edge_type = [(1, 1), (1, -1), (-1, 1), (-1, -1)]
        node_count = len(self._nodes)

        # reshape() rather than squeeze().view(): identical for contiguous
        # input, but it also works when `action` is a permuted/transposed view.
        logits = action.reshape(-1) / temperature
        probabilities = torch.nn.functional.softmax(logits, dim=-1)
        flat_idx = int(torch.multinomial(probabilities, 1).item())

        # Undo the row-major flattening: edge type is the slowest axis.
        polarity, pair = divmod(flat_idx, node_count * node_count)
        left, right = divmod(pair, node_count)

        left_edge_type, right_edge_type = edge_type[polarity]
        return self.create_and(left, right, left_edge_type, right_edge_type)

    def create_and_from_tensor_max(self, action: torch.Tensor) -> _Learned_Node:
        """Create the AND node for the highest-scoring entry of ``action``.

        All positions of the squeezed ``action`` equal to its maximum are
        located. When several positions tie, one is picked uniformly with
        ``random.randint``. The chosen position is read as
        ``(edge type, left, right)`` and forwarded to ``self.create_and``.
        """
        edge_type = [(1, 1), (1, -1), (-1, 1), (-1, -1)]
        peak = torch.max(action)
        idx = torch.nonzero(action.squeeze() == peak).squeeze()
        # A 2-D result means more than one position reached the maximum.
        if idx.dim() == 2:
            pick = random.randint(0, idx.shape[0] - 1)
            idx = idx[pick]
        left_edge_type, right_edge_type = edge_type[idx[0]]
        return self.create_and(
            idx[1].item(), idx[2].item(), left_edge_type, right_edge_type
        )

    def clean_up(self) -> None:
        """Delete AND nodes that have no fanout, then renumber the survivors.

        Each AND node in ``self._nodes`` whose ``fanout_size()`` is zero is
        passed to ``delete_node``, which records it in a shared dict. That dict
        is finally handed to ``remap_nodes``.
        """
        deleted_nodes = {}
        for candidate in self._nodes:
            if not candidate.is_and():
                continue
            if candidate.fanout_size() == 0:
                self.delete_node(candidate, deleted_nodes)
        self.remap_nodes(deleted_nodes)

    def remap_nodes(self, deleted_nodes: dict[int, int]) -> None:
        """Rebuild ``self._nodes`` without deleted AND nodes and renumber the rest.

        Non-AND nodes are always kept and their ids are left untouched. An AND
        node is dropped when its ``node_id`` is a key of ``deleted_nodes``. A
        surviving AND node whose ``node_id`` differs from its position in the
        rebuilt list is:

        1. registered under the new id in ``self._id_to_object`` (and in
           ``self._id_to_name`` when it has a name),
        2. removed from and re-added to the fanout of each non-``None`` parent,
        3. assigned the new id.

        Entries stored under the old id in ``self._id_to_object`` and
        ``self._id_to_name`` are not removed.
        """
        kept = []
        for node in self._nodes:
            if node.is_and():
                if node.node_id in deleted_nodes:
                    continue

                # The next free id always equals the number of nodes kept so far.
                slot = len(kept)
                if node.node_id != slot:
                    previous_id = node.node_id
                    self._id_to_object[slot] = node
                    if previous_id in self._id_to_name:
                        self._id_to_name[slot] = self._id_to_name[previous_id]

                    # NOTE(review): the parents' fanout is refreshed while the
                    # node still carries its old id. If fanouts are keyed by
                    # id, confirm this ordering is intended.
                    left_parent = node.left
                    if left_parent is not None:
                        left_parent.delete_fanout(node)
                        left_parent.add_fanout(node, node.left_edge_type)

                    right_parent = node.right
                    if right_parent is not None:
                        right_parent.delete_fanout(node)
                        right_parent.add_fanout(node, node.right_edge_type)

                    node.node_id = slot

            kept.append(node)
        self._nodes = kept

    def delete_node(self, node: _Learned_Node, deleted_nodes: dict[int, int]) -> None:
        """Mark ``node`` as deleted and detach it from its parents.

        ``node.node_id`` is recorded in ``deleted_nodes``. The node is then
        removed from the fanout of its left and right parents (in that order,
        skipping ``None``). Any parent left with no fanout is deleted
        recursively. The recursion does not check the parent's node type.
        """
        deleted_nodes[node.node_id] = 1
        # Capture both parents up front, then handle left before right.
        for parent in (node.left, node.right):
            if parent is None:
                continue
            parent.delete_fanout(node)
            if parent.fanout_size() == 0:
                self.delete_node(parent, deleted_nodes)
