import glob
import logging
import math
import os
from dataclasses import dataclass
from typing import Any, Iterable

import hydra
import torch
from data.pyaig import Learned_AIG

from data.pyaig.optimal_aigs import (
    parallel_create_cut_data,
    parallel_create_multiple_cut_data,
    parallel_extract_large_cuts,
    parallel_optimize_cut_data,
)
from hydra.core.config_store import ConfigStore
from omegaconf import MISSING
from tqdm import tqdm


@dataclass
class UtilConfig:
    """Hydra structured-config schema for the utilities dispatched by ``eval``.

    ``eval`` picks a utility by ``f`` and forwards the fields that utility needs.
    """

    # Hydra ``_target_`` key; mandatory (MISSING). Not read by ``eval`` — only
    # relevant if the config is handed to ``hydra.utils.call``/``instantiate``.
    _target_: str = MISSING
    # Path, glob pattern (``*``/``**``), or list of either. The extract_* utilities
    # expand it via ``retrieve_paths``; ``optimize_cuts`` receives it unexpanded.
    data_path: str = MISSING
    # Name of the utility ``eval`` runs, e.g. "extract_cuts" or "optimize_cuts".
    f: str = MISSING
    # Not read by any code visible in this file.
    node_limit: int = 40
    # Worker count forwarded as ``num_workers`` to the parallel helpers.
    workers: int = 4
    # Cut input count used by ``extract_cuts`` / ``extract_multicuts``.
    num_inputs: int = 8
    # Input-count bounds forwarded to ``extract_large_cuts``.
    min_inputs: int = 6
    max_inputs: int = 8
    # Chunk size forwarded to ``extract_cuts`` / ``extract_multicuts``.
    chunk_size: int = 10000
    # Forwarded to ``optimize_cuts``; presumably the ABC executable — confirm.
    abc_path: str = "src/data/abc/build/abc"
    # Not read by any code visible in this file.
    debug: bool = False


@dataclass
class OptimalAIGConfig(UtilConfig):
    """Subclass of ``UtilConfig`` that adds no fields of its own.

    NOTE(review): it only redeclares ``_target_`` with the same MISSING default
    and is neither registered in the ConfigStore nor used elsewhere in this
    file — confirm whether it is still needed.
    """

    _target_: str = MISSING


# Register the UtilConfig schema with Hydra's ConfigStore under the name
# "base_util_config" (presumably referenced from the YAML under ../conf —
# confirm). This runs as a side effect at import time.

cs = ConfigStore.instance()
cs.store(name="base_util_config", node=UtilConfig)


def optimize_cuts(data_path: str, abc_path: str, num_workers: int):
    """Forward to ``parallel_optimize_cut_data`` with the arguments unchanged.

    Args:
        data_path: data location passed through as-is (not expanded by
            ``retrieve_paths``); ``eval`` suggests a glob such as
            ``./data/unoptimized/8_inputs/**/*.aig`` — confirm how the callee
            interprets it.
        abc_path: passed through; presumably the ABC executable path.
        num_workers: worker count passed through to the callee.
    """
    parallel_optimize_cut_data(data_path, abc_path, num_workers)


def extract_large_cuts(
    data_path: str | list[str], min_inputs: int, max_inputs: int, num_workers: int
):
    """Run ``parallel_extract_large_cuts`` over every AIG matched by ``data_path``.

    Args:
        data_path: path, glob pattern, or list of either (see ``retrieve_paths``).
        min_inputs: lower input-count bound forwarded to the callee.
        max_inputs: upper input-count bound forwarded to the callee.
        num_workers: worker count forwarded to the callee.

    The directory ``./data/large/unoptimized`` is passed to the callee as its
    output location.
    """
    aig_files = retrieve_paths(data_path)
    output_dir = "./data/large/unoptimized"
    parallel_extract_large_cuts(
        aig_files, min_inputs, max_inputs, num_workers, output_dir
    )


def retrieve_paths(data_path: str | list[str]) -> list[str]:
    aigs: list[str] = []
    if (
        isinstance(data_path, Iterable)
        and isinstance(data_path[0], str)
        and "*" in data_path[0]
    ):
        tmp = []
        for aig in data_path:
            tmp += glob.glob(aig, recursive=True)
        aigs = tmp
    else:
        assert isinstance(data_path, str)
        if "*" in data_path:
            aigs = glob.glob(data_path, recursive=True)
        else:
            aigs = [data_path]
    return aigs


def extract_truth_tables(
    data_path: str | list[str],
    output_dir: str = "./data/truth_tables/extracted",
):
    """Collect the distinct output truth tables of a set of AIGs and write them out.

    Each file matched by ``data_path`` (see ``retrieve_paths``) is read with
    ``Learned_AIG.read_aig(path, False)``. The ``truth_table`` tensor of its
    last node is cast to ``torch.long`` and its entries are concatenated into a
    string. Unique strings are grouped by ``int(log2(len(string)))``. Each group
    is written one per line, in first-seen order, to
    ``{output_dir}/{n}_inputs.txt``. Totals and the unique/total ratio are
    logged.

    Args:
        data_path: path, glob pattern, or list of either.
        output_dir: directory that receives the ``{n}_inputs.txt`` files.
            It is created if missing.
    """
    log = logging.getLogger(__name__)

    # Occurrence count per distinct truth-table string (insertion-ordered).
    truth_tables: dict[str, int] = {}
    for path in tqdm(retrieve_paths(data_path)):
        aig = Learned_AIG.read_aig(path, False)
        tt = aig[-1].truth_table
        assert isinstance(tt, torch.Tensor)
        tt_str = "".join(str(i) for i in tt.to(torch.long).tolist())
        truth_tables[tt_str] = truth_tables.get(tt_str, 0) + 1

    total = sum(truth_tables.values())
    if total == 0:
        # Nothing matched: avoid the division below and leave no empty output dir.
        log.warning("No truth-tables found for data_path=%s", data_path)
        return

    # Bucket unique tables by input count (table length 2**n -> n inputs).
    by_inputs: dict[int, list[str]] = {}
    for tt_str in truth_tables:
        by_inputs.setdefault(int(math.log2(len(tt_str))), []).append(tt_str)

    os.makedirs(output_dir, exist_ok=True)
    for input_num, tables in by_inputs.items():
        fname = os.path.join(output_dir, f"{input_num}_inputs.txt")
        # Context manager guarantees each file is flushed and closed.
        with open(fname, "w") as f:
            f.writelines(f"{tt_str}\n" for tt_str in tables)

    log.info(
        f"Total number of truth-tables: {total}\n"
        f"Number of unique truth-tables: {len(truth_tables)}\n"
        f"Ratio of unique to total: {round(len(truth_tables)/total, 2)}"
    )


def extract_cuts(
    data_path: str | list[str],
    num_inputs: int,
    num_workers: int | None = None,
    chunk_size: int | None = None,
):
    """Run ``parallel_create_cut_data`` over every AIG matched by ``data_path``.

    Args:
        data_path: path, glob pattern, or list of either (see ``retrieve_paths``).
        num_inputs: cut input count forwarded to the callee.
        num_workers: worker count; ``None`` means a single worker.
        chunk_size: forwarded to the callee unchanged (may be ``None``).

    The directory ``./data/unoptimized`` is passed to the callee as its
    output location.
    """
    aig_files = retrieve_paths(data_path)
    workers = 1 if num_workers is None else num_workers
    parallel_create_cut_data(
        aig_files,
        num_inputs,
        "./data/unoptimized",
        workers,
        chunk_size,
    )


def extract_multicuts(
    data_path: str | list[str],
    num_inputs: int,
    num_workers: int | None = None,
    chunk_size: int | None = None,
):
    """Run ``parallel_create_multiple_cut_data`` over every AIG matched by ``data_path``.

    Args:
        data_path: path, glob pattern, or list of either (see ``retrieve_paths``).
        num_inputs: cut input count forwarded to the callee.
        num_workers: worker count; ``None`` means a single worker.
        chunk_size: forwarded to the callee unchanged (may be ``None``).

    The directory ``./data/unoptimized`` is passed to the callee as its
    output location.
    """
    aig_files = retrieve_paths(data_path)
    workers = 1 if num_workers is None else num_workers
    parallel_create_multiple_cut_data(
        aig_files,
        num_inputs,
        "./data/unoptimized",
        workers,
        chunk_size,
    )


def generate_truth_tabels():
    """Placeholder for truth-table generation; not implemented yet.

    Currently does nothing and returns ``None``. The name, including the
    "tabels" spelling, is left unchanged in case external code refers to it.
    """
    return None


@hydra.main(
    version_base=None,
    config_path="../conf",
    config_name="util_config",
)
def eval(cfg: UtilConfig):
    """Hydra entry point: run the utility named by ``cfg.f``.

    The config is loaded by ``hydra.main`` (``config_name="util_config"``
    under ``config_path="../conf"``). Supported values of ``cfg.f`` and the
    fields each one reads:

    * ``extract_truth_tables`` - ``data_path``
    * ``extract_cuts`` / ``extract_multicuts`` - ``data_path``, ``num_inputs``,
      ``workers``, ``chunk_size``
    * ``optimize_cuts`` - ``data_path``, ``abc_path``, ``workers``
    * ``extract_large_cuts`` - ``data_path``, ``min_inputs``, ``max_inputs``,
      ``workers``

    Raises:
        ValueError: if ``cfg.f`` is not one of the names above, so a typo in
            the config fails loudly instead of doing nothing.
    """
    # NOTE: the name shadows the builtin ``eval``; kept for compatibility with
    # any external references to this module-level entry point.
    if cfg.f == "extract_truth_tables":
        extract_truth_tables(cfg.data_path)
    elif cfg.f == "extract_cuts":
        # e.g. data_path=./data/EPFL/**/*.aig
        extract_cuts(cfg.data_path, cfg.num_inputs, cfg.workers, cfg.chunk_size)
    elif cfg.f == "extract_multicuts":
        # e.g. data_path=./data/EPFL/**/*.aig
        extract_multicuts(cfg.data_path, cfg.num_inputs, cfg.workers, cfg.chunk_size)
    elif cfg.f == "optimize_cuts":
        # e.g. data_path=./data/unoptimized/8_inputs/**/*.aig
        optimize_cuts(cfg.data_path, cfg.abc_path, cfg.workers)
    elif cfg.f == "extract_large_cuts":
        extract_large_cuts(cfg.data_path, cfg.min_inputs, cfg.max_inputs, cfg.workers)
    else:
        known = (
            "extract_truth_tables",
            "extract_cuts",
            "extract_multicuts",
            "optimize_cuts",
            "extract_large_cuts",
        )
        raise ValueError(
            f"Unknown utility f={cfg.f!r}; expected one of: {', '.join(known)}"
        )


if __name__ == "__main__":
    # Called with no arguments: the hydra.main decorator builds and injects cfg.
    eval()
