import os
import sys
# Resolve the parent of the directory that contains this file and put it at
# the front of the module search path. This has to run before the
# `synthesizer` imports below (presumably that package lives in this parent
# directory -- confirm against the repository layout).
project_root = os.path.dirname(
    os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, project_root)

from synthesizer.pool import PoolFactory
from synthesizer.template import TemplateFactory
from synthesizer.theory import Theory
import json
import random
from synthesizer.expression import ConstantExpression, IdentityExpression, LinearExpression, BinaryExpression
from tqdm import tqdm


def _generate_theories(path, count, hardness, pool_factory, template_factory,
                       logical_config, numerical_config):
    """Generate ``count`` theories and write them to ``path`` as JSON Lines.

    Each theory is built from freshly drawn entity/attribute/relation pools
    and a reasoning depth drawn uniformly (inclusive) from
    ``logical_config["depth_interval"]``. Every record is the theory's
    ``to_json()`` payload with an added ``"id"`` of the form
    ``"<hardness>-<1-based index>"``.

    Args:
        path: Destination ``.jsonl`` file; it is overwritten if it exists.
        count: Number of theories to generate.
        hardness: Label used in the record ids and the progress-bar text.
        pool_factory: Object providing ``get_entity_pool``,
            ``get_attribute_pool`` and ``get_relation_pool``.
        template_factory: Passed through to ``Theory`` unchanged.
        logical_config: Dict with the keys ``entity_num``, ``attribute_num``,
            ``relation_num``, ``fact_num``, ``rule_num``, ``depth_interval``
            and ``condition_interval``.
        numerical_config: Dict with the keys ``interval`` and
            ``expression_weights``.
    """
    depth_low, depth_high = logical_config["depth_interval"]

    # Opening a file in "w" mode does not create missing parent directories,
    # so make sure the output directory exists before generating anything.
    out_dir = os.path.dirname(path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    with open(path, "w", encoding='utf-8') as f:
        for i in tqdm(range(count), desc=f"Generating {hardness} theories"):
            # Keep this call order (pools, then depth, then Theory) so the
            # sequence of random draws matches the original script.
            entities = pool_factory.get_entity_pool(
                logical_config["entity_num"])
            attributes = pool_factory.get_attribute_pool(
                logical_config["attribute_num"])
            relations = pool_factory.get_relation_pool(
                logical_config["relation_num"])
            depth = random.randint(depth_low, depth_high)

            theory = Theory(
                template_factory, entities, attributes, relations,
                logical_config["fact_num"], logical_config["rule_num"], depth,
                condition_num_interval=logical_config["condition_interval"],
                expression_weights=numerical_config["expression_weights"],
                interval=numerical_config["interval"])
            data = theory.to_json()
            data["id"] = f"{hardness}-{i + 1}"
            f.write(json.dumps(data, ensure_ascii=False) + "\n")


def main():
    """Generate the raw theory datasets for every hardness combination.

    For each pairing of a logical config (easy ``el`` / hard ``hl``) with a
    numerical config (easy ``en`` / hard ``hn``), two JSON Lines files are
    written: ``../raw-data/<hardness>.jsonl`` with 500 theories and
    ``../raw-prompt/<hardness>.jsonl`` with 10 theories, where ``<hardness>``
    is e.g. ``el-en``.

    All resource and output paths are relative, so they are resolved against
    the current working directory of the process.
    """
    logical_easy_config = {"entity_num": 10, "attribute_num": 15, "relation_num": 10,
                           "fact_num": 15, "rule_num": 15, "depth_interval": (1, 3), "condition_interval": (1, 1)}
    logical_hard_config = {"entity_num": 10, "attribute_num": 15, "relation_num": 10,
                           "fact_num": 15, "rule_num": 15, "depth_interval": (4, 6), "condition_interval": (2, 3)}
    # Per-expression-class weights handed to Theory as `expression_weights`.
    # NOTE(review): the exact meaning of the "normal" / "binary" groups is
    # defined by the synthesizer package -- confirm there.
    numerical_easy_expression = {
        "normal": {ConstantExpression: 1, IdentityExpression: 1, LinearExpression: 1},
        "binary": {ConstantExpression: 1, IdentityExpression: 1, LinearExpression: 1}
    }
    numerical_hard_expression = {
        "normal": {ConstantExpression: 0, IdentityExpression: 0, LinearExpression: 1, BinaryExpression: 1},
        "binary": {ConstantExpression: 1, IdentityExpression: 1, LinearExpression: 1}
    }
    numerical_easy_config = {"interval": (
        1, 10), "expression_weights": numerical_easy_expression}
    numerical_hard_config = {
        "interval": (-100, 100), "expression_weights": numerical_hard_expression}
    logical_configs = [("el", logical_easy_config),
                       ("hl", logical_hard_config)]
    numerical_configs = [("en", numerical_easy_config),
                         ("hn", numerical_hard_config)]
    pool_factory = PoolFactory("../resources/pools.json")
    template_factory = TemplateFactory("../resources/templates.json")

    # (output directory, number of theories), processed in this order for
    # every hardness combination.
    outputs = (("../raw-data", 500), ("../raw-prompt", 10))

    for logical_hardness, logical_config in logical_configs:
        for numerical_hardness, numerical_config in numerical_configs:
            hardness = f"{logical_hardness}-{numerical_hardness}"
            for out_dir, count in outputs:
                _generate_theories(
                    f"{out_dir}/{hardness}.jsonl", count, hardness,
                    pool_factory, template_factory,
                    logical_config, numerical_config)


# Run the generation only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()
