import json
import os
import random
import sys

# Make the project root importable BEFORE importing project packages: a
# sys.path entry only affects imports executed after it is added. Assumes this
# script lives one directory below the project root (the parent of this
# file's directory).
project_root = os.path.dirname(
    os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, project_root)

from tqdm import tqdm  # noqa: E402
from synthesizer.expression import ConstantExpression, IdentityExpression, LinearExpression, BinaryExpression  # noqa: E402
from synthesizer.theory import Theory  # noqa: E402
from synthesizer.template import TemplateFactory  # noqa: E402
from synthesizer.pool import PoolFactory  # noqa: E402


def _write_theories(path, count, hardness, desc, pool_factory, template_factory,
                    pool_config, theory_args, theory_kwargs):
    """Generate ``count`` theories and write them to ``path`` as JSON lines.

    Each theory is built from freshly requested entity/attribute/relation
    pools and tagged with an ``id`` of the form ``"<hardness>-<1-based index>"``.

    Args:
        path: Output ``.jsonl`` file. Its parent directory is created if
            missing, and the file itself is overwritten.
        count: Number of theories to generate.
        hardness: Label used as the ``id`` prefix.
        desc: Progress-bar description shown by ``tqdm``.
        pool_factory: Provides the entity/attribute/relation pools.
        template_factory: Forwarded to ``Theory``.
        pool_config: Mapping with ``entity_num``, ``attribute_num`` and
            ``relation_num`` pool sizes.
        theory_args: Positional ``Theory`` arguments that follow the pools
            (fact count, rule count, depth).
        theory_kwargs: Keyword arguments forwarded to ``Theory``.
    """
    # open(..., "w") does not create missing directories.
    os.makedirs(os.path.dirname(path), exist_ok=True)
    with open(path, "w", encoding='utf-8') as f:
        for i in tqdm(range(count), desc=desc):
            # Keep the original request order (entities, attributes, relations)
            # in case the pools draw from a shared random state.
            entities = pool_factory.get_entity_pool(pool_config["entity_num"])
            attributes = pool_factory.get_attribute_pool(
                pool_config["attribute_num"])
            relations = pool_factory.get_relation_pool(
                pool_config["relation_num"])

            theory = Theory(template_factory, entities, attributes, relations,
                            *theory_args, **theory_kwargs)
            data = theory.to_json()
            data["id"] = f"{hardness}-{i+1}"
            f.write(json.dumps(data, ensure_ascii=False) + "\n")


def main():
    """Generate the hard numerical theory datasets for depths 7 through 10.

    For each depth ``d``, writes 100 theories to
    ``<project_root>/raw-data/depth-<d>.jsonl`` and 10 theories to
    ``<project_root>/raw-prompt/depth-<d>.jsonl``, one JSON object per line.
    Pools and templates are read from ``<project_root>/resources``. Paths are
    anchored to ``project_root`` (the parent of this script's directory), not
    to the current working directory.
    """
    logical_hard_config = {"entity_num": 30, "attribute_num": 40, "relation_num": 40,
                           "depth_interval": (7, 20), "condition_interval": (3, 6)}
    numerical_hard_expression = {
        "normal": {ConstantExpression: 0, IdentityExpression: 0, LinearExpression: 1, BinaryExpression: 1},
        "binary": {ConstantExpression: 1, IdentityExpression: 1, LinearExpression: 1}
    }
    numerical_hard_config = {
        "interval": (-100, 100), "expression_weights": numerical_hard_expression}

    resources_dir = os.path.join(project_root, "resources")
    pool_factory = PoolFactory(os.path.join(resources_dir, "pools.json"))
    template_factory = TemplateFactory(
        os.path.join(resources_dir, "templates.json"))

    # These Theory keyword arguments do not vary with depth.
    theory_kwargs = {
        "condition_num_interval": logical_hard_config["condition_interval"],
        "expression_weights": numerical_hard_config["expression_weights"],
        "interval": numerical_hard_config["interval"],
    }

    # hard numerical
    # NOTE(review): the depth range is hard-coded; "depth_interval" above is
    # not read anywhere in this script.
    for depth in range(7, 11):
        hardness = f"depth-{depth}"
        # Positional Theory args after the pools: fact_num, rule_num, depth.
        theory_args = (depth * 15, depth * 5, depth)

        _write_theories(
            os.path.join(project_root, "raw-data", f"{hardness}.jsonl"),
            100, hardness, f"Generating {hardness} theories",
            pool_factory, template_factory, logical_hard_config,
            theory_args, theory_kwargs)

        _write_theories(
            os.path.join(project_root, "raw-prompt", f"{hardness}.jsonl"),
            10, hardness, f"Generating {hardness} prompts",
            pool_factory, template_factory, logical_hard_config,
            theory_args, theory_kwargs)


if __name__ == "__main__":
    # Only generate datasets when executed directly, never on import.
    main()
