"""
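Generate Zendo datasets. Each dataset's root folder is read from the environment
variable named after the dataset (by default DATASET_S6_startPROP and
DATASET_S6_startGPROP).

S6 refers to the length of the structures (6 symbols each).

-> Remember to run this script with the project root as the working directory!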
"""
import json
from functools import partial
from itertools import product
from multiprocessing import Pool
from pathlib import Path
from typing import List, Optional

from nltk import CFG, Nonterminal
from nltk.parse.generate import generate
from tqdm import tqdm
from zendo.game import evaluate, parser
from zendo.game.rule_generator import generate_all_rules
from zendo.game.rule_grammar import zendo_grammar
from zendo.utils import get_env, load_envs

# # All rules should be tested!
# # When adding some rule here, add it to the tests too.
# static_rules = [
#     "at_least 1 blue",
#     "exactly 2 red block",
#     "at_most 1 pyramid",
#     "exactly 3 block",
#     "exactly 1 block touching blue",
#     "not at_least 1 blue",
#     "at_least 1 blue surrounded_by .",
#     "zero pyramid",
#     "exactly 1 group_of 3 .",
#     "exactly 2 red block at_the_right_of .",
#     "exactly 2 red block at_the_left_of .",
#     "at_least 1 group_of 2 blue touching red",
#     "exactly 2 group_of 2 blue block touching red block",
#     "zero blue",
#     "zero pyramid",
#     "zero blue pyramid",
# ]


def generate_rule_labels(
    i_rule: int, rule: str, root: Path, structures: List[str]
) -> None:
    """
    Generate the labels for all the structures given a single rule.

    Each call writes only into its own folder, so it can be run in parallel
    (it is mapped over ``enumerate(rules)`` by a process pool).

    Files written inside ``root / f"{i_rule:010d}"``:
        rule.txt    - the rule string
        labels.txt  - one "1" (white) or "0" (black) per line, in the same
                      order as ``structures``
        info.txt    - JSON with white/black/total counts and the white percentage

    :param i_rule: index of the rule, used as the 10-digit zero-padded folder name
    :param rule: the rule in str format
    :param root: the root path of the dataset
    :param structures: the list of all the possible structures
    :raises ValueError: if ``rule`` cannot be parsed (the rule is printed first)
    """
    try:
        ast = parser.parse(rule)
    except ValueError:
        # Print the offending rule: when running inside a worker process the
        # traceback alone does not tell which rule failed.
        print(rule)
        raise

    labels = [int(evaluate(ast, structure)) for structure in structures]
    white = sum(labels)
    total = len(labels)

    rule_root = root / f"{i_rule:010d}"
    rule_root.mkdir(parents=True, exist_ok=True)

    (rule_root / "rule.txt").write_text(rule)
    (rule_root / "labels.txt").write_text("\n".join(str(label) for label in labels))

    (rule_root / "info.txt").write_text(
        json.dumps(
            {
                "white": white,
                "black": total - white,
                "total": total,
                # Guard against an empty structure list (avoids ZeroDivisionError).
                "white_perc": white / total * 100 if total else 0.0,
            },
            sort_keys=True,
            indent=4,
        )
    )


def generate_dataset(
    dataset_name: str,
    structures_length: int,
    max_rule_depth: Optional[int] = None,
    max_number_of_rules: Optional[int] = None,
    start_nonterminal: Optional[Nonterminal] = None,
    compute_labels: bool = False,
    workers: int = 12,
    chunksize: int = 16,
) -> None:
    """
    Generate a Zendo dataset, organized as:

        <root>              - dataset root folder (path read from env variable `dataset_name`)
            structures.txt  - one structure of length `structures_length` per line
            rules.txt       - one rule per line, the line index corresponds to the folder
            xxxxxxxxxx      - rule folder (10 digits, zero padded, numerically increasing)
                rule.txt    - the rule in standard str format
                labels.txt  - the evaluation of each structure against the current rule
                              1=white and 0=black, one per line following structures.txt
                info.txt    - some stats about the current rule

    Rule folders are only written when `compute_labels` is True.

    :param dataset_name: name of the env variable holding the dataset root path
    :param structures_length: number of symbols in each structure; all
        5 ** structures_length combinations of "aAbB." are generated
    :param max_rule_depth: forwarded to `generate_all_rules`
    :param max_number_of_rules: forwarded to `generate_all_rules` as `max_n`
    :param start_nonterminal: grammar nonterminal the rules are generated from
    :param compute_labels: if True, evaluate every rule on every structure
    :param workers: number of worker processes used to compute the labels
    :param chunksize: number of rules sent to a worker at a time

    Note: if `compute_labels` comes from a string (e.g. an env variable),
    convert it explicitly, since non-empty strings like "0" are truthy.
    `distutils.util.strtobool` (string -> 1/0) was used for this, but distutils
    was removed in Python 3.12.
    """
    load_envs()
    root = Path(get_env(dataset_name))
    root.mkdir(parents=True, exist_ok=True)

    # Generate all possible structures of the requested length (sorted so the
    # order in structures.txt, and hence in every labels.txt, is deterministic).
    structures = sorted("".join(x) for x in product("aAbB.", repeat=structures_length))

    # Materialize the rules: they are consumed twice (rules.txt and the label
    # pool). A one-shot iterator would leave the pool with nothing to label.
    rules = list(
        generate_all_rules(
            max_rule_depth=max_rule_depth,
            max_n=max_number_of_rules,
            start_nonterminal=start_nonterminal,
        )
    )

    # Save structures and rules in files
    (root / "structures.txt").write_text("\n".join(structures))
    (root / "rules.txt").write_text("\n".join(rules))

    if compute_labels:
        with Pool(processes=workers) as pool:
            pool.starmap(
                partial(generate_rule_labels, root=root, structures=structures),
                enumerate(rules),
                chunksize=chunksize,
            )


def generate_default_datasets():
    """
    Generate the default datasets: structures of length 6, with rules
    generated starting from the PROP and from the GPROP nonterminals.

    Each dataset's root path is read from the env variable
    DATASET_S6_start<NONTERMINAL>. Labels are not computed.
    """
    load_envs()
    structures_length = 6
    for i, start in enumerate(("PROP", "GPROP")):
        dataset_name = f"DATASET_S6_start{start}"
        if i:
            # Blank line between the reports of consecutive datasets.
            print()
        print(
            f"Generating: {dataset_name} in {get_env(dataset_name)}\n"
            f"    - structure length: {structures_length}\n"
            f"    - starting non terminal: {start}"
        )
        generate_dataset(
            dataset_name=dataset_name,
            structures_length=structures_length,
            start_nonterminal=Nonterminal(start),
            compute_labels=False,
            workers=12,
            chunksize=16,
        )


if __name__ == "__main__":
    # Only build the datasets when executed as a script, never on import.
    # The guard is also required by multiprocessing.Pool with the "spawn"
    # start method, which re-imports the main module in each worker.
    generate_default_datasets()