import glob
import logging
import math
import os
import random
import tempfile
import time
from dataclasses import dataclass
from typing import Tuple

import data

import hydra
import matplotlib.pyplot as plt

import pandas as pd
import torch
import torch.multiprocessing as mp
import torchrl
from data.pyaig.aig import Learned_AIG
from data.pyaig.aig_env import AIGEnv

from hydra.core.config_store import ConfigStore
from omegaconf import ListConfig, MISSING, OmegaConf

from torch.multiprocessing import Manager
from torchrl.collectors import SyncDataCollector
from torchrl.envs.transforms import Reward2GoTransform

from tqdm import tqdm

from model.model_lib import *
from model.utils import load_snapshot
from optimizer.optimizer_lib import *
from scheduler.scheduler_lib import *
from loss.loss_lib import *
from data.data_lib import *

from boolformer import load_boolformer

# from rl.mcts_policy import *
from misc_utils import time_formatter
from rl import *
import sys

import warnings

with warnings.catch_warnings():
    warnings.filterwarnings("ignore", category=DeprecationWarning)
    import mlflow
    from mlflow.metrics import make_metric, MetricValue
# warnings.filterwarnings("ignore", category=DeprecationWarning)

from rl.utils import generate_AIG

from rl.mcts_policy import *


# sys.setrecursionlimit(100)
@tensorclass(autocast=True)
class Result:
    # One evaluation record per AIG file; converted to a DataFrame row by
    # `df_from_results`.

    # Path of the AIG file the record refers to.
    aig_file: str
    # Number of AND nodes of the resulting AIG; the generation-based tests in
    # this file store -1 here when no valid AIG was produced.
    aig_size: int
    # Truth table of the last node that was generated, as a string of 0/1.
    last_generated_tt: str
    # Truth table that had to be reproduced, as a string of 0/1.
    target_tt: str
    # Truth tables of the nodes of the resulting AIG, one 0/1 string each.
    truth_tables: list[str]
    # Number of AND nodes removed by `clean_up()` (0 when not applicable).
    deleted_nodes: int
    # Wall-clock duration of the generation step, in seconds.
    elapsed_time: float


@dataclass
class EvalAgentConfig:
    """Structured (Hydra) configuration for the agent evaluation script.

    Most fields are consumed outside the part of the file documented here;
    the per-field notes below only state what the visible code shows.
    """

    # Model
    # path: str = "runs/pre_AlmaEncoder_64_v002/"
    # path: str = "runs/pre_AlmaEncoder_64_v045/"
    # path: str = "runs/pre_AlmaEncoder_emb=64_heads=8_v0153/"
    # Run directory; `get_path` also accepts "latest" (or None) to pick the
    # most recently modified directory under runs/.
    path: str = "runs/pre_ShortCircuit_emb=64_heads=8_v0047/"
    # model_name: str = "eval_best_policy_loss_model"
    mlflow_model_uri: str | None = "254946076cb24f3abc539679680f7c17"
    # Checkpoint file name; `get_path` appends ".pt" when it is missing.
    model_name: str = "model_50"
    device: str | None = "cpu"
    device_id: Any = 0
    # Path of the ABC executable used by the ABC-based baselines.
    abc_path: str = "src/data/abc/build/abc"

    # Policy
    # AlphaZero/MCTS settings; must be supplied by the composed config.
    AZ: AlphaZeroConfig = MISSING

    # Env
    # Forwarded to `AIGEnv(..., const_node=...)`.
    const_node: bool = True

    # Evaluation
    aigs_path: str = "data/unoptimized/6_inputs/**/*.aig"
    mlflow_uri: str = "http://localhost:5001"
    shuffle: bool = True
    # Number of worker processes started by `parallel_test_quality`.
    AZ_workers: int = 4
    # Maximum number of nodes the agent may generate per AIG.
    max_nodes: int = 20
    test_limit: int = 500
    retries: int = 1
    seed: int = 42
    # Forwarded to `SimulatedSearchPolicy(reutilize_tree=...)`.
    reutilize_tree: bool = True

    # Quality Testing
    test_quality: bool = False

    # Ablation
    test_ablation: bool = False
    ablation_sims: Any = (1, 2, 4, 8, 16, 32, 64, 128)

    # Debugging
    test: bool = False
    debug: bool = False
    easy_target: bool = True
    # When set, the sequential tests draw the AIGs with matplotlib.
    plot_graph: bool = False
    plot_freq: int = 10


# Register the structured configs with Hydra's ConfigStore so they can be
# referenced by name ("policy/AlphaZero_config", "eval_agent_base_config").
cs = ConfigStore.instance()
cs.store(name="AlphaZero_config", node=AlphaZeroConfig, group="policy")
cs.store(name="eval_agent_base_config", node=EvalAgentConfig)


def tt2str(tt: torch.Tensor) -> str:
    """Render a 1-D truth-table tensor as a string of decimal digits.

    Every element is converted with ``int`` first, so boolean entries come
    out as ``"0"`` / ``"1"``.
    """
    digits = []
    for entry in tt.tolist():
        digits.append(str(int(entry)))
    return "".join(digits)


def boolformer_to_aig(bool_root, num_inputs):
    """Convert a Boolformer expression tree into a single-output Learned_AIG.

    Args:
        bool_root: Root node of the expression tree. Nodes expose ``value``
            and ``children``; the values handled here are "not", "or", and
            (presumably) "and" for binary nodes, plus leaves naming an input.
        num_inputs: Number of primary inputs of the AIG to create.

    Returns:
        The newly built ``Learned_AIG`` with its primary output connected to
        the translated root.
    """
    new_aig = Learned_AIG(num_inputs, 1, None)

    def get_child_node(node):
        # Returns (edge sign, AIG node); a sign of -1 marks an inverted edge
        # (assumed convention of create_and / set_po_edge -- TODO confirm).
        out_edge = 1
        left_edge = 1
        right_edge = 1
        if node.value == "not":
            # Fold the negation into the outgoing edge and continue with the operand.
            out_edge *= -1
            node = node.children[0]
            # NOTE(review): a "not" directly below another "not" is not
            # unwrapped here and would reach the binary branch below.

        if node.value == "or":
            # De Morgan: a OR b == NOT(NOT a AND NOT b), so invert both inputs
            # and the output of the AND node.
            out_edge *= -1
            left_edge = -1
            right_edge = -1

        if len(node.children) == 0:
            # Leaf: the last character of the name is used as the input index,
            # shifted by one (presumably to skip a constant node -- verify).
            # NOTE(review): only single-digit input indices are parsed correctly.
            return (out_edge, new_aig[int(node.value[-1]) + 1])
        else:
            left_edge_trans, left = get_child_node(node.children[0])
            right_edge_trans, right = get_child_node(node.children[1])
            # Combine the inversion requested by the child with the one
            # required by this operator.
            new_node = new_aig.create_and(
                left, right, left_edge_trans * left_edge, right_edge_trans * right_edge
            )
            return (out_edge, new_node)

    out_edge, root = get_child_node(bool_root)
    new_aig.set_po_edge(root, -1, out_edge)
    return new_aig


def init_aig_env_from_file(aig_file: str, aig_env: EnvBase) -> Learned_AIG:
    """Load an AIG file and reset ``aig_env`` so that it targets that AIG.

    The environment is reset with the number of primary inputs of the loaded
    AIG and with the truth table of its last node as target, both placed on
    the environment's device.

    Returns:
        The loaded AIG with its truth tables instantiated.
    """
    loaded_aig = Learned_AIG.read_aig(aig_file)
    loaded_aig.instantiate_truth_tables()

    env_device = aig_env.device
    num_inputs = torch.tensor(
        [loaded_aig.n_pis()], dtype=torch.int32, device=env_device
    )
    target = loaded_aig[-1].truth_table.unsqueeze(0).to(device=env_device)  # type: ignore

    reset_spec = TensorDict(
        {"num_inputs": num_inputs, "target": target},
        batch_size=torch.Size(),
        device=env_device,
    )
    aig_env.reset(reset_spec)
    return loaded_aig


def test_quality_helper(
    model: torch.nn.Module,
    aigs_queue: mp.Queue,
    results_queue: mp.Queue,
    cfg: EvalAgentConfig,
    device: torch.device,
) -> None:
    """Worker loop: regenerate every AIG taken from ``aigs_queue``.

    For each file the environment is reset to the AIG's target truth table,
    ``generate_AIG`` is run, and a ``Result`` is pushed to ``results_queue``.
    The worker returns once no work item arrives within the poll timeout.

    Args:
        model: Model handed to ``generate_AIG``.
        aigs_queue: Queue of AIG file paths shared between all workers.
        results_queue: Queue receiving one ``Result`` per processed file.
        cfg: Evaluation configuration (``const_node``, ``max_nodes``, ``AZ``).
        device: Device the environment and its state are moved to.
    """
    import queue  # local import: only needed for the queue.Empty exception

    aig_env = AIGEnv(model.embedding_size, const_node=cfg.const_node).to(device)
    aig_env.state = aig_env.state.to(device)
    while True:
        # `empty()` followed by a blocking `get()` is racy with several
        # consumers (another worker may take the last item in between, and
        # `empty()` is not reliable for multiprocessing queues), so poll with a
        # timeout and treat a timeout as "no work left".
        try:
            aig_f = aigs_queue.get(timeout=1.0)
        except queue.Empty:
            break

        # prepare AIG environment
        test_aig = init_aig_env_from_file(aig_f, aig_env)

        # Generate new AIG
        start_time = time.time()
        success = generate_AIG(model, aig_env, cfg.max_nodes, cfg.AZ)
        elapsed_time = time.time() - start_time

        # -1 marks a failed generation
        new_size = -1
        deleted_nodes = 0
        if success:
            # The AIG is rebuilt from a CPU copy of the state; the state is
            # moved back afterwards so the next reset runs on `device`.
            aig_env.state = aig_env.state.to("cpu")
            new_aig = Learned_AIG.from_aig_env(aig_env)
            aig_env.state = aig_env.state.to(device)
            size_before_clean = new_aig.n_ands()
            new_aig.clean_up()
            new_size = new_aig.n_ands()
            deleted_nodes = size_before_clean - new_size

        nodes = aig_env.state["nodes"]
        r = Result(
            aig_file=aig_f,
            aig_size=new_size,
            last_generated_tt=tt2str(nodes[..., -1, :].view(-1)),
            target_tt=tt2str(test_aig[-1].truth_table),  # type: ignore
            truth_tables=[tt2str(nodes[i]) for i in range(len(nodes))],  # type: ignore
            deleted_nodes=deleted_nodes,
            elapsed_time=elapsed_time,
        )

        results_queue.put(r)


def parallel_test_quality(
    model: torch.nn.Module,
    aigs: list[str],
    cfg: EvalAgentConfig,
    device_list: list[torch.device] | None = None,
) -> List[Result]:
    """Evaluate ``aigs`` with ``cfg.AZ_workers`` parallel worker processes.

    One copy of ``model`` is created per device; workers are assigned to the
    devices round-robin and run ``test_quality_helper``. The parent process
    collects the results and shows a progress bar.

    Args:
        model: Model to evaluate.
        aigs: Paths of the AIG files to process.
        cfg: Evaluation configuration.
        device_list: Devices to spread the workers over; defaults to the CPU
            when omitted or empty.

    Returns:
        The collected results, in completion order. Fewer than ``len(aigs)``
        results are returned (and a warning is logged) if workers terminate
        before all files were processed.
    """
    if not device_list:
        device_list = [torch.device("cpu")]

    m = Manager()
    results_queue = m.Queue()
    aigs_queue = mp.Queue()

    for aig_f in aigs:
        aigs_queue.put(aig_f)

    models = [
        copy.deepcopy(model).to(device=dev, non_blocking=True) for dev in device_list
    ]

    # launch processes
    processes = []
    for i in range(cfg.AZ_workers):
        slot = i % len(device_list)
        p = mp.Process(
            target=test_quality_helper,
            args=(models[slot], aigs_queue, results_queue, cfg, device_list[slot]),
        )
        p.start()
        # Track every worker (not only the last one) so that all get joined.
        processes.append(p)

    results = []
    expected = len(aigs)

    def _drain() -> int:
        # The parent is the only consumer of the results queue, so checking
        # `empty()` before `get()` cannot race here.
        fetched = 0
        while not results_queue.empty():
            results.append(results_queue.get())
            fetched += 1
        return fetched

    # Main process progress tracking
    with tqdm(unit="truth-table", total=expected) as progress_bar:
        while len(results) < expected:
            # Sample liveness *before* draining: if every worker had already
            # exited, everything it produced is picked up by this drain.
            workers_alive = any(p.is_alive() for p in processes)
            progress_bar.update(_drain())
            progress_bar.refresh()
            if len(results) >= expected or not workers_alive:
                break
            time.sleep(1)

    for p in processes:
        p.join()

    _drain()

    if len(results) < expected:
        logging.getLogger(__name__).warning(
            "Only %d of %d AIGs were processed; a worker may have crashed.",
            len(results),
            expected,
        )

    return results


def simple_test_quality(
    model: torch.nn.Module,
    aigs: list[str],
    cfg: EvalAgentConfig,
    device: torch.device,
) -> List[Result]:
    """Sequentially regenerate every AIG in ``aigs`` with ``generate_AIG``.

    A single environment is reused for all files. With ``cfg.plot_graph`` the
    original and the generated AIG are drawn; with ``cfg.debug`` in addition,
    both are drawn again whenever the generated AIG is larger than the
    original. ``device`` is accepted but not used by this function.

    Returns:
        One ``Result`` per file; ``aig_size`` is -1 when generation failed.
    """
    collected = []
    aig_env = AIGEnv(model.embedding_size, const_node=cfg.const_node)
    for aig_f in tqdm(aigs):
        # Load the reference AIG and point the environment at its target
        test_aig = init_aig_env_from_file(aig_f, aig_env)

        if cfg.plot_graph:
            test_aig.draw()
            plt.show()

        # Timed generation
        started_at = time.time()
        generated_ok = generate_AIG(model, aig_env, cfg.max_nodes, cfg.AZ)
        elapsed_time = time.time() - started_at

        if cfg.plot_graph:
            new_aig = Learned_AIG.from_aig_env(aig_env)
            new_aig.draw()
            plt.show()

        # Size bookkeeping
        orig_size = test_aig.n_ands()
        if generated_ok:
            new_aig = Learned_AIG.from_aig_env(aig_env)
            size_with_dead_nodes = new_aig.n_ands()
            new_aig.clean_up()
            new_size = new_aig.n_ands()
            deleted_nodes = size_with_dead_nodes - new_size
        else:
            new_size = -1
            deleted_nodes = 0

        if cfg.debug and cfg.plot_graph and orig_size < new_size:
            test_aig.draw()
            plt.show()
            new_aig.draw()
            plt.show()

        state = aig_env.state
        node_tts = state["nodes"]
        collected.append(
            Result(
                aig_file=aig_f,
                aig_size=new_size,
                last_generated_tt=tt2str(node_tts[..., -1, :].view(-1)),
                target_tt=tt2str(state["target"].view(-1)),
                truth_tables=[tt2str(node_tts[idx]) for idx in range(len(node_tts))],
                deleted_nodes=deleted_nodes,
                elapsed_time=elapsed_time,
            )  # type: ignore
        )

    return collected


def test_greedy_generation(
    model: torch.nn.Module,
    aigs: list[str],
    cfg: EvalAgentConfig,
    device: torch.device | None = None,
) -> List[Result]:
    """Sequentially regenerate every AIG in ``aigs`` with ``generate_AIG_greedy``.

    A single environment is reused for all files. With ``cfg.plot_graph`` the
    original and the generated AIG are drawn; with ``cfg.debug`` in addition,
    both are drawn again whenever the generated AIG is larger than the
    original. ``device`` is accepted but not used by this function.

    Returns:
        One ``Result`` per file; ``aig_size`` is -1 when generation failed.
    """
    collected = []
    aig_env = AIGEnv(model.embedding_size, const_node=cfg.const_node)
    for aig_f in tqdm(aigs):
        # Load the reference AIG and point the environment at its target
        test_aig = init_aig_env_from_file(aig_f, aig_env)

        if cfg.plot_graph:
            test_aig.draw()
            plt.show()

        # Timed greedy generation
        started_at = time.time()
        generated_ok = generate_AIG_greedy(model, aig_env, cfg.max_nodes, cfg.AZ)
        elapsed_time = time.time() - started_at

        if cfg.plot_graph:
            new_aig = Learned_AIG.from_aig_env(aig_env)
            new_aig.draw()
            plt.show()

        # Size bookkeeping
        orig_size = test_aig.n_ands()
        if generated_ok:
            new_aig = Learned_AIG.from_aig_env(aig_env)
            size_with_dead_nodes = new_aig.n_ands()
            new_aig.clean_up()
            new_size = new_aig.n_ands()
            deleted_nodes = size_with_dead_nodes - new_size
        else:
            new_size = -1
            deleted_nodes = 0

        if cfg.debug and cfg.plot_graph and orig_size < new_size:
            test_aig.draw()
            plt.show()
            new_aig.draw()
            plt.show()

        state = aig_env.state
        node_tts = state["nodes"]
        collected.append(
            Result(
                aig_file=aig_f,
                aig_size=new_size,
                last_generated_tt=tt2str(node_tts[..., -1, :].view(-1)),
                target_tt=tt2str(state["target"].view(-1)),
                truth_tables=[tt2str(node_tts[idx]) for idx in range(len(node_tts))],
                deleted_nodes=deleted_nodes,
                elapsed_time=elapsed_time,
            )  # type: ignore
        )

    return collected


def _mean_or_nan(values):
    """Return the arithmetic mean of ``values``, or NaN for an empty list."""
    return sum(values) / len(values) if values else float("nan")


def _ratio_or_nan(numerator, denominator):
    """Return ``numerator / denominator``, or NaN when the denominator is zero."""
    return numerator / denominator if denominator else float("nan")


def process_results(
    logger: logging.Logger,
    successful_aigs: List[Tuple[str, int, int]],
    failed_aigs: List[Tuple[str, int, int]],
    problematic_aigs: List[Tuple[str, int, int]],
):
    """Log summary statistics of an evaluation run.

    Every entry is a ``(file, original_size, new_size)`` tuple.

    Args:
        logger: Logger receiving the summary (INFO level).
        successful_aigs: AIGs for which generation succeeded.
        failed_aigs: AIGs for which generation failed.
        problematic_aigs: AIGs to exclude from the "without problematic"
            statistics; they are matched by file name.

    Averages and ratios that cannot be computed (empty lists, zero
    denominators) are reported as ``nan`` instead of raising
    ``ZeroDivisionError``. Size reductions are reported as
    ``(original - new) / original``, i.e. positive when the AIG got smaller.
    """
    total_aigs = len(successful_aigs) + len(failed_aigs)

    success_orig_size = [x[1] for x in successful_aigs]
    success_new_size = [x[2] for x in successful_aigs]
    failed_orig_size = [x[1] for x in failed_aigs]

    avg_orig_size = _mean_or_nan(success_orig_size)
    avg_new_size = _mean_or_nan(success_new_size)
    avg_aig_size = _ratio_or_nan(
        sum(success_orig_size) + sum(failed_orig_size), total_aigs
    )

    # Successful AIGs without the problematic ones
    problematic_aig_files = {entry[0] for entry in problematic_aigs}
    clean_successes = [x for x in successful_aigs if x[0] not in problematic_aig_files]
    best_orig_circuits = [x[1] for x in clean_successes]
    best_new_circuits = [x[2] for x in clean_successes]
    avg_best_orig_size = _mean_or_nan(best_orig_circuits)
    avg_best_new_size = _mean_or_nan(best_new_circuits)

    logger.info(f"Successful AIG creation: {len(successful_aigs)}")
    logger.info(f"Failed AIG creation: {len(failed_aigs)}")
    logger.info(f"Problematic AIG: {len(problematic_aigs)}")
    logger.info(f"Success without problematic AIG: {len(best_new_circuits)}")
    logger.info(f"Total tested AIGs: {total_aigs}")
    logger.info(f"Success rate: {_ratio_or_nan(len(successful_aigs), total_aigs)}")
    logger.info(f"Average AIG size: {avg_aig_size}")
    logger.info(f"Average size of generated AIG: {avg_new_size}")
    logger.info(f"Average size of original AIG: {avg_orig_size}")
    # Same sign convention as the "without problematic AIGs" figure below.
    logger.info(
        f"Size reduction {_ratio_or_nan(avg_orig_size - avg_new_size, avg_orig_size)}"
    )

    # NaN compares false, so this block is skipped when there is no data.
    if avg_best_orig_size > 0:
        logger.info(
            f"Average size of generated AIG (without problematic AIGs): {avg_best_new_size}"
        )
        logger.info(
            f"Average size of original AIG (without problematic AIGs): {avg_best_orig_size}"
        )
        logger.info(
            f"Size reduction (without problematic AIGs): {(avg_best_orig_size - avg_best_new_size)/avg_best_orig_size}"
        )


def df_from_results(results: List[Result]) -> pd.DataFrame:
    """Build a DataFrame with one row per ``Result``.

    The columns are, in order: "AIG file", "AIG size", "Last generated TT",
    "Target TT", "Truth Tables", "Deleted nodes" and "Elapsed time".
    """
    column_to_attr = (
        ("AIG file", "aig_file"),
        ("AIG size", "aig_size"),
        ("Last generated TT", "last_generated_tt"),
        ("Target TT", "target_tt"),
        ("Truth Tables", "truth_tables"),
        ("Deleted nodes", "deleted_nodes"),
        ("Elapsed time", "elapsed_time"),
    )
    rows = []
    for record in results:
        rows.append({column: getattr(record, attr) for column, attr in column_to_attr})
    return pd.DataFrame(rows)


def test_mcts(model: torch.nn.Module, cfg, device):
    """Ad-hoc debugging driver for the MCTS policy on one fixed AIG.

    The function reads ``data/unoptimized/6_inputs/voter/14757_0.aig``, runs a
    rollout of the simulated-search policy against its truth table, prints
    the final done/terminated/reward entries, applies the inverse of a
    ``Reward2GoTransform`` and a hindsight-experience-replay augmentation,
    and finally runs the same policy through a two-environment ``SerialEnv``
    and prints that rollout. Nothing is returned.
    """
    actor_value_agent = get_actor_value_model(model)
    test_aig = Learned_AIG.read_aig("data/unoptimized/6_inputs/voter/14757_0.aig")
    test_aig.instantiate_truth_tables()
    aig_env = AIGEnv(model.embedding_size, cfg.const_node).to(device)
    aig_env.state = aig_env.state.to(device)

    # Hand-written 64-entry targets. They are only referenced by the
    # commented-out `"target": target.to(device=device)` alternative below.
    if cfg.easy_target:
        target = torch.tensor(
            [
                0, 1, 0, 1, 0, 1, 0, 1,
                0, 1, 0, 1, 0, 0, 0, 1,
                0, 1, 0, 1, 0, 1, 0, 1,
                0, 1, 0, 1, 0, 0, 0, 1,
                0, 1, 0, 1, 0, 1, 0, 1,
                0, 1, 0, 1, 0, 0, 0, 1,
                0, 1, 0, 1, 0, 1, 0, 1,
                0, 1, 0, 1, 0, 0, 0, 0,
            ],  # fmt: skip
            dtype=torch.bool,
            device=device,
        ).unsqueeze(0)
    else:
        target = torch.tensor(
            [
                0, 1, 1, 1, 0, 1, 0, 1,
                0, 0, 0, 1, 0, 0, 0, 1,
                0, 1, 0, 1, 0, 1, 0, 1,
                0, 1, 0, 1, 0, 1, 0, 1,
                0, 1, 0, 1, 0, 1, 0, 1,
                0, 1, 0, 1, 0, 0, 0, 1,
                0, 1, 0, 1, 0, 1, 0, 1,
                0, 1, 0, 1, 0, 1, 1, 0,
            ],  # fmt: skip
            dtype=torch.bool,
        ).unsqueeze(0)

    # A single leading 1 followed by 63 zeros.
    easy_target = torch.tensor([1] + [0] * 63, dtype=torch.bool).unsqueeze(0)

    def build_reset_td():
        # Reset specification: input count and target truth table of the test AIG.
        return TensorDict(
            {
                "num_inputs": torch.tensor(
                    [test_aig.n_pis()], dtype=torch.int32, device=device
                ),
                "target": test_aig[-1].truth_table.unsqueeze(0).to(device=device),
                # "target": target.to(device=device),
            },
            batch_size=torch.Size(),
            device=device,
        )

    aig_env.reset(build_reset_td())

    # Assemble the search policy
    tree_strategy = UpdateTreeStrategy(
        value_network=actor_value_agent.get_value_operator(),
        use_value_network=cfg.AZ.use_value_network,
    )
    expansion_strategy = AlphaZeroExpansionStrategy(
        policy_module=actor_value_agent.get_policy_operator(),
    )
    selection_strategy = PuctSelectionPolicy(cfg.AZ.c_puct)
    exploration_strategy = ActionExplorationModule()
    mcts_policy = MctsPolicy(
        expansion_strategy=expansion_strategy,
        selection_strategy=selection_strategy,
        exploration_strategy=exploration_strategy,
    )
    policy = SimulatedSearchPolicy(
        policy=mcts_policy,
        tree_updater=tree_strategy,
        env=aig_env,
        num_simulations=cfg.AZ.num_simulations,
        simulation_max_steps=cfg.AZ.simulation_max_steps,
        max_steps=cfg.max_nodes,
        reutilize_tree=cfg.reutilize_tree,
    )

    # Run AlphaZero and report the wall-clock time of the rollout
    rollout_started = time.time()
    rollout = aig_env.rollout(
        policy=policy, max_steps=cfg.max_nodes, return_contiguous=False
    )
    print(time.time() - rollout_started)

    if cfg.plot_graph:
        Learned_AIG.from_aig_env(aig_env).draw()
        plt.show()

    print("Next done:", rollout["next", "done"][-1])  # type: ignore
    print("Next terminated:", rollout["next", "terminated"][-1])  # type: ignore
    print("Next reward:", rollout["next", "reward"][-1])  # type: ignore

    reward_to_go = Reward2GoTransform(0.5, ("next", "reward"), ("next", "reward"))  # type: ignore
    rollout = reward_to_go.inv(rollout)  # type: ignore
    print("Transformed next reward:", rollout["next", "reward"])

    her = HindsightExperienceReplayTransform(
        SubGoalAssigner=AIGSubGoalAssigner(),
        RewardTransform=AIGRewardTransform(),
        PostTransaform=AIGNegateTarget(),
    )
    # The augmented trajectory is currently not inspected any further.
    augmented_traj = her.her_augmentation(rollout)

    def make_env(aig_env, reset_td):
        aig_env.reset(reset_td)
        return aig_env

    rstd = build_reset_td()
    aig_env.reset(rstd)
    env_factory = partial(make_env, aig_env=aig_env, reset_td=rstd)
    parallel_env = torchrl.envs.SerialEnv(2, env_factory)
    p_rollout = parallel_env.rollout(
        policy=policy, max_steps=20, return_contiguous=False
    )

    print(p_rollout)


def get_path(cfg) -> Tuple[str, str]:
    """Resolve the config file and checkpoint file of the run to evaluate.

    Args:
        cfg: Object with ``path`` (run directory, ``"latest"`` or ``None``)
            and ``model_name`` (checkpoint name, ``.pt`` is appended when the
            name does not contain it).

    Returns:
        ``(model_config_path, model_path)``.

    Raises:
        FileNotFoundError: If the latest run is requested but ``runs/``
            contains no run directory.
    """
    if cfg.path is None or cfg.path == "latest":
        run_dirs = glob.glob(os.path.join("runs/", "*/"))
        if not run_dirs:
            raise FileNotFoundError("No run directory found under 'runs/'")
        run_dir = max(run_dirs, key=os.path.getmtime)
    else:
        run_dir = cfg.path

    model_file = cfg.model_name if ".pt" in cfg.model_name else cfg.model_name + ".pt"

    # os.path.join inserts the separator only when needed, so the run
    # directory works with or without a trailing slash.
    model_config_path = os.path.join(run_dir, "config.yaml")
    model_path = os.path.join(run_dir, model_file)
    return model_config_path, model_path


def success(new_size: pd.Series) -> MetricValue:
    """Metric: which entries have a positive size, with their rate and count.

    The per-row scores are the booleans ``new_size > 0``; the aggregates are
    the fraction ("rate") and the number ("count") of true entries.
    """
    is_success = new_size > 0
    n_success = is_success.sum()
    aggregates = {
        "rate": n_success / len(is_success),
        "count": n_success,
    }
    return MetricValue(scores=list(is_success), aggregate_results=aggregates)


def size_reduction(eval_df: pd.DataFrame, _builtin_metrics) -> MetricValue:
    """Metric comparing the "target" column against the "prediction" column.

    ``_builtin_metrics`` is accepted but not used.
    """
    reference_sizes = eval_df["target"]
    predicted_sizes = eval_df["prediction"]
    return stat_comparison(reference_sizes, predicted_sizes)


def float_stat(stat: pd.Series) -> MetricValue:
    """Wrap a numeric series as a metric with mean/std/max/min aggregates."""
    aggregates = {
        name: getattr(stat, name)() for name in ("mean", "std", "max", "min")
    }
    return MetricValue(scores=list(stat), aggregate_results=aggregates)


def aig_size(aig_size: pd.Series) -> MetricValue:
    """Metric with summary statistics of the AIG sizes (see ``float_stat``)."""
    size_metric = float_stat(aig_size)
    return size_metric


def deleted_nodes(deleted_nodes: pd.Series) -> MetricValue:
    """Metric with summary statistics of the deleted-node counts (see ``float_stat``)."""
    deleted_metric = float_stat(deleted_nodes)
    return deleted_metric


def stat_comparison(old_stat: pd.Series, new_stat: pd.Series) -> MetricValue:
    """Metric describing the element-wise difference ``old_stat - new_stat``.

    The scores are the per-row differences. The aggregates hold their mean,
    std, max and min, plus "relative": the mean of the differences divided by
    ``old_stat``.
    """
    delta = old_stat - new_stat
    relative_delta = delta / old_stat

    aggregates = {name: getattr(delta, name)() for name in ("mean", "std", "max", "min")}
    aggregates["relative"] = relative_delta.mean()

    return MetricValue(scores=list(delta), aggregate_results=aggregates)


def time_stat(elapsed_time: pd.Series) -> MetricValue:
    """Metric with summary statistics of the elapsed times (see ``float_stat``)."""
    time_metric = float_stat(elapsed_time)
    return time_metric


def mlflow_log_results(df: pd.DataFrame, logger: logging.Logger):
    """Evaluate ``df`` with MLflow using the custom size/time/deletion metrics.

    The "AIG size" column is used as prediction column, and the metric
    function arguments are mapped to the DataFrame columns through
    ``col_mapping``. ``logger`` is accepted but currently unused; the result
    of ``mlflow.evaluate`` is not inspected.
    """
    # (metric name, evaluation function) -- lower is better for all of them
    metric_specs = (
        ("aig_size", aig_size),
        ("elapsed_time", time_stat),
        ("deleted_nodes", deleted_nodes),
    )
    custom_metrics = [
        make_metric(eval_fn=fn, greater_is_better=False, name=metric_name)
        for metric_name, fn in metric_specs
    ]

    column_mapping = {
        "aig_size": "AIG size",
        "deleted_nodes": "Deleted nodes",
        "elapsed_time": "Elapsed time",
        "aig_file": "AIG file",
    }

    _evaluation = mlflow.evaluate(
        data=df,
        predictions="AIG size",
        extra_metrics=custom_metrics,
        evaluator_config={
            # "metric_prefix": "successful/",
            "col_mapping": column_mapping
        },
    )


def log_model_params(cfg):
    """Log every entry of ``cfg`` as an MLflow parameter.

    Nested ``DictConfig`` values are flattened into dotted names
    (``parent.child``). The elements of a ``ListConfig`` are logged as
    ``parent.<index>`` without descending into them.
    """

    def _log_entry(name, value):
        if isinstance(value, DictConfig):
            # Every child goes through the same dispatch again, so nested
            # dicts and lists are expanded and plain values are logged as-is.
            for key, child in value.items():
                _log_entry(f"{name}.{key}", child)
            return
        if isinstance(value, ListConfig):
            for index, item in enumerate(value):
                mlflow.log_param(f"{name}.{index}", item)
            return
        mlflow.log_param(name, value)

    for top_level_name, top_level_value in cfg.items():  # type: ignore
        _log_entry(top_level_name, top_level_value)


def test_extracted_cut(aig_files: list[str]) -> list[Result]:
    """Baseline: report the size of every AIG file exactly as stored on disk.

    The generated and the target truth table are both the truth table of the
    AIG's last node; no nodes are deleted and no time is measured.
    """

    def _describe(path: str) -> Result:
        aig = Learned_AIG.read_aig(path)
        aig.instantiate_truth_tables()
        output_tt = tt2str(aig[-1].truth_table)  # type: ignore
        return Result(
            aig_file=path,
            aig_size=aig.n_ands(),
            last_generated_tt=output_tt,
            target_tt=output_tt,
            truth_tables=[tt2str(node.truth_table) for node in aig._nodes],  # type: ignore
            deleted_nodes=0,
            elapsed_time=0,
        )

    return [_describe(path) for path in tqdm(aig_files)]


def test_optimized_extracted_cut(aig_files: list[str], abc_path: str):
    """Baseline: optimise every AIG file with an ABC script and report its size.

    Each file is run through the ABC executable at ``abc_path`` (balance /
    rewrite / refactor sequence), written to a temporary directory, read back
    and cleaned up. The exit status of the ABC call is not checked.

    Returns:
        One ``Result`` per file with the size of the optimised AIG.
    """
    collected = []
    with tempfile.TemporaryDirectory() as scratch_dir:
        for aig_f in tqdm(aig_files):
            tmp_f = os.path.join(scratch_dir, os.path.basename(aig_f))

            abc_command = (
                f"{abc_path} -c 'read {aig_f}; balance; rewrite; "
                f"refactor; balance; rewrite; rewrite -z; balance; "
                f"refactor -z; rewrite -z; balance; write {tmp_f};' > /dev/null 2>&1"
            )
            os.system(abc_command)

            optimized = Learned_AIG.read_aig(tmp_f)
            optimized.instantiate_truth_tables()
            optimized.clean_up()

            output_tt = tt2str(optimized[-1].truth_table)  # type: ignore
            collected.append(
                Result(
                    aig_file=aig_f,
                    aig_size=optimized.n_ands(),
                    last_generated_tt=output_tt,
                    target_tt=output_tt,
                    truth_tables=[tt2str(node.truth_table) for node in optimized._nodes],  # type: ignore
                    deleted_nodes=0,
                    elapsed_time=0,
                )
            )
    return collected


def test_abc_generation(aig_files: list[str], abc_path: str) -> list[Result]:
    """Baseline: let ABC synthesise an AIG from each file's target truth table.

    For every file the truth table of the last node is passed to ABC
    (``read_truth``, ``collapse``, ``sop``, ``strash``), the result is written
    to a temporary directory and read back. The ABC call is timed; its exit
    status is not checked.

    Returns:
        One ``Result`` per file describing the AIG produced by ABC.
    """
    collected = []
    with tempfile.TemporaryDirectory() as scratch_dir:
        for aig_f in tqdm(aig_files, desc="ABC Generation"):
            tmp_f = os.path.join(scratch_dir, os.path.basename(aig_f))

            reference = Learned_AIG.read_aig(aig_f)
            reference.instantiate_truth_tables()
            tt = tt2str(reference[-1].truth_table)  # type: ignore

            abc_command = f"{abc_path} -c 'read_truth -x {tt}; collapse; sop; strash; write {tmp_f}; print_stats' > /dev/null 2>&1"
            started_at = time.time()
            os.system(abc_command)
            elapsed_time = time.time() - started_at

            out_aig = Learned_AIG.read_aig(tmp_f)
            out_aig.instantiate_truth_tables()
            collected.append(
                Result(
                    aig_file=aig_f,
                    aig_size=out_aig.n_ands(),
                    last_generated_tt=tt2str(out_aig[-1].truth_table),  # type: ignore
                    target_tt=tt,
                    truth_tables=[tt2str(node.truth_table) for node in out_aig._nodes],  # type: ignore
                    deleted_nodes=0,
                    elapsed_time=elapsed_time,
                )
            )
    return collected


def test_optimized_abc_generation(aig_files: list[str], abc_path: str) -> list[Result]:
    """Baseline: ABC synthesis from the truth table followed by optimisation.

    Like ``test_abc_generation``, but the ABC script additionally runs a
    balance / rewrite / refactor sequence before writing the AIG. The ABC
    call is timed; its exit status is not checked.

    Returns:
        One ``Result`` per file describing the optimised AIG produced by ABC.
    """
    collected = []
    with tempfile.TemporaryDirectory() as scratch_dir:
        for aig_f in tqdm(aig_files, desc="ABC Generation"):
            tmp_f = os.path.join(scratch_dir, os.path.basename(aig_f))

            reference = Learned_AIG.read_aig(aig_f)
            reference.instantiate_truth_tables()
            tt = tt2str(reference[-1].truth_table)  # type: ignore

            abc_command = (
                f"{abc_path} -c 'read_truth -x {tt}; "
                "collapse; sop; strash; balance; rewrite; "
                "refactor; balance; rewrite; rewrite -z; balance; "
                "refactor -z; rewrite -z; balance; "
                f"write {tmp_f}; print_stats' > /dev/null 2>&1"
            )
            started_at = time.time()
            os.system(abc_command)
            elapsed_time = time.time() - started_at

            out_aig = Learned_AIG.read_aig(tmp_f)
            out_aig.instantiate_truth_tables()
            collected.append(
                Result(
                    aig_file=aig_f,
                    aig_size=out_aig.n_ands(),
                    last_generated_tt=tt2str(out_aig[-1].truth_table),  # type: ignore
                    target_tt=tt,
                    truth_tables=[tt2str(node.truth_table) for node in out_aig._nodes],  # type: ignore
                    deleted_nodes=0,
                    elapsed_time=elapsed_time,
                )
            )
    return collected


def test_boolformer(aigs: List[str], embedding_size: int) -> List[Result]:
    """Evaluate the noiseless Boolformer model on each AIG's target function.

    For every file the environment is initialised from the AIG, Boolformer is
    fitted on the environment's inputs and target, and the returned expression
    (if any) is converted to an AIG. The attempt counts as successful when the
    generated truth table equals the target or its bitwise complement.

    Args:
        aigs: Paths of the AIG files to evaluate.
        embedding_size: Embedding size used to build the ``AIGEnv``.

    Returns:
        One ``Result`` per input file, in input order. ``aig_size`` is ``-1``
        when Boolformer produced no expression or a non-matching one.
        ``elapsed_time`` covers only the ``fit`` call.
    """
    boolformer_noiseless = load_boolformer(mode="noiseless")
    aig_env = AIGEnv(embedding_size, const_node=True)
    results = []
    for aig_f in tqdm(aigs):
        test_aig = init_aig_env_from_file(aig_f, aig_env)

        inputs = aig_env._construct_inputs()[1:, :].transpose(0, 1).numpy()
        target = aig_env.state["target"].squeeze().numpy()

        start_time = time.time()
        bool_expr, _, _ = boolformer_noiseless.fit(
            [inputs], [target], verbose=False, beam_size=10, beam_type="search"
        )
        elapsed_time = time.time() - start_time

        # Read the target for *this* file unconditionally. It used to be set
        # only when Boolformer returned an expression, so a failed fit either
        # raised UnboundLocalError (first file) or silently reported the target
        # of a previously processed file.
        target_tt = test_aig[-1].truth_table

        out_aig = None
        out_tt = None
        if bool_expr[0] is not None:
            out_aig = boolformer_to_aig(bool_expr[0], test_aig.n_pis())
            out_aig.instantiate_truth_tables()
            out_tt = out_aig[-1].truth_table

        # A complemented output is accepted as well as an exact match.
        matched = out_tt is not None and (
            torch.equal(out_tt, target_tt) or torch.equal(out_tt, ~target_tt)
        )
        new_size = out_aig.n_ands() if matched else -1

        r = Result(
            aig_file=aig_f,
            aig_size=new_size,
            last_generated_tt=tt2str(out_tt) if out_tt is not None else "",
            target_tt=tt2str(target_tt),
            truth_tables=[tt2str(node.truth_table) for node in out_aig._nodes]
            if out_tt is not None
            else [""],
            deleted_nodes=0,
            elapsed_time=elapsed_time,
        )  # type: ignore
        results.append(r)
    return results


def test_ShortCircuit(
    model: torch.nn.Module,
    aigs: List[str],
    cfg: EvalAgentConfig,
    device_list: list[torch.device],
) -> List[Result]:
    """Run the ShortCircuit evaluation with the strategy selected by ``cfg``.

    * ``cfg.AZ.num_simulations < 2``: greedy generation.
    * ``cfg.AZ_workers > 1``: parallel evaluation over ``device_list``.
    * otherwise: single evaluation on the CPU device.
    """
    if cfg.AZ.num_simulations < 2:
        return test_greedy_generation(model, aigs, cfg)

    if cfg.AZ_workers > 1:
        return parallel_test_quality(model, aigs, cfg, device_list)

    cpu_device = torch.device("cpu")
    return simple_test_quality(model, aigs, cfg, cpu_device)


def clean_results(results: list[Result], successful_aigs) -> List[Result]:
    """Return the results whose ``aig_file`` is contained in ``successful_aigs``.

    The relative order of ``results`` is preserved.
    """
    kept = []
    for result in results:
        if result.aig_file in successful_aigs:
            kept.append(result)
    return kept


def mcts_ablation(model, aigs, cfg, device_list, logger):
    """Evaluate ShortCircuit once per simulation budget in ``cfg.ablation_sims``.

    For each budget, ``cfg.AZ.num_simulations`` is set to that value, the
    ShortCircuit evaluation is run, and the success rate, success count and the
    results of the successful AIGs are logged to a nested MLflow run named
    ``ShortCircuit[<sims>]``.

    Args:
        model: Model passed through to ``test_ShortCircuit``.
        aigs: Paths of the AIG files to evaluate.
        cfg: Evaluation config; ``cfg.AZ.num_simulations`` is overridden while
            the ablation runs and restored afterwards.
        device_list: Devices passed through to ``test_ShortCircuit``.
        logger: Logger passed through to ``mlflow_log_results``.
    """
    original_sims = cfg.AZ.num_simulations
    try:
        for sims in cfg.ablation_sims:
            cfg.AZ.num_simulations = sims
            shortcircuit = test_ShortCircuit(model, aigs, cfg, device_list)

            successful_aigs = {x.aig_file for x in shortcircuit if x.aig_size > 0}
            total = len(shortcircuit)
            # Avoid ZeroDivisionError when the evaluation returned no results.
            success_rate = len(successful_aigs) / total if total else 0.0

            with mlflow.start_run(
                run_name=f"ShortCircuit[{sims}]",
                nested=True,
            ):
                mlflow.log_metric("success/rate", success_rate)
                mlflow.log_metric("success/count", len(successful_aigs))
                shortcircuit_df = df_from_results(
                    clean_results(shortcircuit, successful_aigs)
                )
                mlflow_log_results(shortcircuit_df, logger)
    finally:
        # Do not leak the last ablation budget into the caller's config.
        cfg.AZ.num_simulations = original_sims


def _safe_ratio(numerator: int, denominator: int) -> float:
    """Return ``numerator / denominator``, or ``0.0`` when the denominator is zero."""
    return numerator / denominator if denominator else 0.0


def _log_results_run(run_name: str, results, logger: logging.Logger) -> None:
    """Log ``results`` as a nested MLflow run named ``run_name``."""
    with mlflow.start_run(run_name=run_name, nested=True):
        mlflow_log_results(df_from_results(results), logger)


def test_quality(
    model: torch.nn.Module,
    aigs: List[str],
    cfg: EvalAgentConfig,
    model_cfg: DictConfig,
    device_list: list[torch.device],
    logger: logging.Logger,
) -> None:
    """Run every generator on ``aigs`` and log the comparison to MLflow.

    The extracted cuts, ABC-optimized cuts, Boolformer, ShortCircuit, ABC and
    optimized ABC evaluations are run in that order. Each baseline is then
    logged as three nested runs:

    * ``[clean]``   - restricted to AIGs ShortCircuit solved,
    * ``[overlap]`` - restricted to AIGs both ShortCircuit and Boolformer solved,
    * ``[all]``     - unfiltered.

    Boolformer and ShortCircuit get their own nested runs with success metrics.

    Args:
        model: Model under evaluation; ``model.embedding_size`` is read.
        aigs: Paths of the AIG files to evaluate.
        cfg: Evaluation config (``abc_path``, ``AZ.num_simulations``,
            ``mlflow_model_uri`` are read here).
        model_cfg: Model config; ``model_cfg.model`` is logged as parameters.
        device_list: Devices passed through to ``test_ShortCircuit``.
        logger: Logger for progress messages and ``mlflow_log_results``.
    """
    start_time = time.time()

    # Run all generators (order preserved from the original implementation).
    cuts = test_extracted_cut(aigs)
    optimized_cuts = test_optimized_extracted_cut(aigs, cfg.abc_path)
    bf = test_boolformer(aigs, model.embedding_size)
    shortcircuit = test_ShortCircuit(model, aigs, cfg, device_list)
    abc = test_abc_generation(aigs, cfg.abc_path)
    optimized_abc = test_optimized_abc_generation(aigs, cfg.abc_path)

    elapsed_time = time_formatter(time.time() - start_time, show_ms=False)
    logger.info(f"Evaluation total time: {elapsed_time}")

    # Successful aigs
    successful_aigs = {x.aig_file for x in shortcircuit if x.aig_size > 0}
    bf_successful_aigs = {x.aig_file for x in bf if x.aig_size > 0}
    overlap_aigs = successful_aigs & bf_successful_aigs

    success_rate = _safe_ratio(len(successful_aigs), len(shortcircuit))
    mlflow.log_metric("success/rate", success_rate)
    mlflow.log_metric("success/count", len(successful_aigs))
    mlflow.set_tag("success", success_rate)

    # Average over exactly the entries that are summed. The previous divisor
    # (total minus successes) was zero whenever every AIG succeeded, which
    # aborted the evaluation before any result was logged.
    failed_times = [x.elapsed_time for x in shortcircuit if x.aig_size == -1]
    if failed_times:
        avg_fail_time = sum(failed_times) / len(failed_times)
        logger.info(f"Average time for failed AIGs: {time_formatter(avg_fail_time)}")
    else:
        logger.info("No failed AIGs; average failure time not available")

    logger.info("Logging results...")

    # NOTE: "Cuts [clean]" and "Cuts [all]" previously had their data swapped
    # relative to the other baselines; all four now follow the same scheme.
    baselines = [
        ("Cuts", cuts),
        ("Optimized Cuts", optimized_cuts),
        ("ABC", abc),
        ("Optimized ABC", optimized_abc),
    ]
    for label, baseline_results in baselines:
        _log_results_run(
            f"{label} [clean]",
            clean_results(baseline_results, successful_aigs),
            logger,
        )
        _log_results_run(
            f"{label} [overlap]",
            clean_results(baseline_results, overlap_aigs),
            logger,
        )
        _log_results_run(f"{label} [all]", baseline_results, logger)

    with mlflow.start_run(
        run_name="Boolformer",
        nested=True,
    ):
        boolformer_df = df_from_results(clean_results(bf, overlap_aigs))
        mlflow_log_results(boolformer_df, logger)
        mlflow.log_metric(
            "success/rate", _safe_ratio(len(bf_successful_aigs), len(shortcircuit))
        )
        mlflow.log_metric("success/count", len(bf_successful_aigs))

    model_tags = {
        "model_uri": cfg.mlflow_model_uri,
        "embedding_size": str(model_cfg.model.embedding_size),
        "n_heads": str(model_cfg.model.n_heads),
        "num_layers": str(model_cfg.model.n_layers),
    }

    with mlflow.start_run(
        run_name=f"ShortCircuit[{cfg.AZ.num_simulations}]",
        nested=True,
        tags=model_tags,
    ):
        log_model_params(model_cfg.model)
        mlflow.log_metric("success/rate", success_rate)
        mlflow.log_metric("success/count", len(successful_aigs))
        shortcircuit_df = df_from_results(clean_results(shortcircuit, successful_aigs))
        mlflow_log_results(shortcircuit_df, logger)

    with mlflow.start_run(
        run_name=f"ShortCircuit[{cfg.AZ.num_simulations}][overlap]",
        nested=True,
        tags=model_tags,
    ):
        log_model_params(model_cfg.model)
        # Share of the Boolformer-solved AIGs that ShortCircuit also solved;
        # guarded because Boolformer may solve none.
        mlflow.log_metric(
            "success/rate", _safe_ratio(len(overlap_aigs), len(bf_successful_aigs))
        )
        mlflow.log_metric("success/count", len(overlap_aigs))
        shortcircuit_df = df_from_results(clean_results(shortcircuit, overlap_aigs))
        mlflow_log_results(shortcircuit_df, logger)


@hydra.main(version_base=None, config_path="../conf", config_name="eval_agent_config")
def eval(cfg: EvalAgentConfig):
    """Entry point: load the model and run the configured evaluation.

    Exactly one of ``cfg.test_ablation`` / ``cfg.test_quality`` may be enabled.
    When ``cfg.test`` is set, only ``test_mcts`` is run on the first device and
    the function returns without touching MLflow. Otherwise the AIG files
    matching ``cfg.aigs_path`` are collected, optionally shuffled with
    ``cfg.seed``, truncated to ``cfg.test_limit`` and evaluated inside a parent
    MLflow run.

    Raises:
        ValueError: If both test types are selected, or if no AIG file matches
            ``cfg.aigs_path``.
    """
    logger = logging.getLogger(__name__)

    if int(cfg.test_ablation) + int(cfg.test_quality) > 1:
        raise ValueError("Only one test type can be selected")

    exp_name = ""
    if cfg.test_ablation:
        exp_name = "Ablation Study"
    elif cfg.test_quality:
        exp_name = "Quality Testing"

    # Device setup
    if isinstance(cfg.device_id, int):
        cfg.device_id = [cfg.device_id]

    if cfg.device is None or cfg.device == "cpu":
        device_list = [torch.device("cpu")]
    else:
        device_list = [torch.device(f"{cfg.device}:{i}") for i in cfg.device_id]

    # Load model config
    model_config_path, model_path = get_path(cfg)
    model_cfg = OmegaConf.load(model_config_path)

    # Load model
    model = hydra.utils.instantiate(model_cfg.model)
    snapshot = load_snapshot(model_path)
    model.load_state_dict(snapshot.model_state)
    model.eval()

    # For testing
    if cfg.test:
        start_time = time.time()
        test_mcts(model.to(device_list[0]), cfg, device_list[0])
        elapsed_time = time_formatter(time.time() - start_time, show_ms=False)
        logger.info(f"Test total time: {elapsed_time}")
        return

    # Load AIGs. glob returns paths in file-system dependent order, so sort
    # them first: otherwise neither the seeded shuffle nor the test_limit slice
    # selects the same files across machines or runs.
    aigs = sorted(glob.glob(cfg.aigs_path, recursive=True))
    if not aigs:
        raise ValueError(f"No AIG files match pattern: {cfg.aigs_path}")
    if cfg.shuffle:
        random.Random(cfg.seed).shuffle(aigs)
    aigs = aigs[: cfg.test_limit]

    # Evaluate
    os.environ["HTTP_PROXY"] = ""
    mlflow.set_tracking_uri(cfg.mlflow_uri)
    mlflow.set_experiment("Circuit Generation Evaluation")
    with mlflow.start_run(
        run_name=f"{exp_name}[{cfg.seed}][{cfg.test_limit}][{cfg.AZ.num_simulations}]",
        tags={
            "seed": str(cfg.seed),
            "test_limit": str(cfg.test_limit),
            "num_simulations": str(cfg.AZ.num_simulations),
            "model_uri": cfg.mlflow_model_uri,
            "embedding_size": str(model_cfg.model.embedding_size),
            "n_heads": str(model_cfg.model.n_heads),
            "num_layers": str(model_cfg.model.n_layers),
        },
    ):
        log_model_params(model_cfg.model)
        mlflow.log_text(OmegaConf.to_yaml(cfg), "config.yaml")

        if cfg.test_ablation:
            mcts_ablation(model, aigs, cfg, device_list, logger)

        elif cfg.test_quality:
            test_quality(model, aigs, cfg, model_cfg, device_list, logger)  # type: ignore


# Script entry point: run the evaluation only when this file is executed
# directly. `eval` is called without arguments because it is wrapped by the
# `@hydra.main` decorator above.
if __name__ == "__main__":
    eval()
