import os
import sys
project_root = os.path.dirname(
    os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, project_root)

from synthesizer.pool import PoolFactory
from synthesizer.template import TemplateFactory
from synthesizer.theory import Theory
import json
import random
from synthesizer.expression import ConstantExpression, IdentityExpression, LinearExpression, BinaryExpression
from tqdm import tqdm


def _generate_dataset(hardness, pool_factory, template_factory, logical_config,
                      expression_weights, interval, sample_num=5000,
                      output_dir="../raw-data"):
    """Generate ``sample_num`` theories and write them as one JSONL file.

    Each line of ``{output_dir}/{hardness}.jsonl`` is the JSON form of one
    ``Theory`` with an added ``"id"`` field of the form ``{hardness}-{k}``,
    where ``k`` counts from 1.

    Args:
        hardness: Dataset tag; used for the file name, the progress-bar label
            and the per-sample id prefix.
        pool_factory: Object providing ``get_entity_pool``,
            ``get_attribute_pool`` and ``get_relation_pool``.
        template_factory: Passed through to ``Theory`` unchanged.
        logical_config: Dict with the keys ``entity_num``, ``attribute_num``,
            ``relation_num``, ``fact_num``, ``rule_num``, ``depth_interval``
            and ``condition_interval``.
        expression_weights: Passed to ``Theory`` as ``expression_weights``.
        interval: Passed to ``Theory`` as ``interval``.
        sample_num: Number of theories to generate.
        output_dir: Directory receiving the JSONL file; created if missing.
    """
    # Opening a file for writing does not create missing parent directories.
    os.makedirs(output_dir, exist_ok=True)

    depth_low, depth_high = logical_config["depth_interval"]
    output_path = f"{output_dir}/{hardness}.jsonl"
    with open(output_path, "w", encoding='utf-8') as f:
        for i in tqdm(range(sample_num), desc=f"Generating {hardness} theories"):
            # Fresh pools and a fresh depth are drawn for every sample. The
            # order of these calls is kept fixed so that a seeded run stays
            # reproducible if the pools draw from the shared random state.
            entities = pool_factory.get_entity_pool(
                logical_config["entity_num"])
            attributes = pool_factory.get_attribute_pool(
                logical_config["attribute_num"])
            relations = pool_factory.get_relation_pool(
                logical_config["relation_num"])
            depth = random.randint(depth_low, depth_high)

            theory = Theory(
                template_factory, entities, attributes, relations,
                logical_config["fact_num"], logical_config["rule_num"], depth,
                condition_num_interval=logical_config["condition_interval"],
                expression_weights=expression_weights,
                interval=interval)
            data = theory.to_json()
            data["id"] = f"{hardness}-{i+1}"
            f.write(json.dumps(data, ensure_ascii=False) + "\n")


def main():
    """Generate the ``train-el`` and ``train-en`` training sets.

    ``train-el`` combines the easy logical configuration with mixed numerical
    expression weights over the interval (-100, 100). ``train-en`` combines
    the hard logical configuration with the easy numerical configuration.
    Resource and output paths are relative to the current working directory.
    """
    logical_easy_config = {
        "entity_num": 10, "attribute_num": 15, "relation_num": 10,
        "fact_num": 15, "rule_num": 15,
        "depth_interval": (1, 3), "condition_interval": (1, 1),
    }
    logical_hard_config = {
        "entity_num": 10, "attribute_num": 15, "relation_num": 10,
        "fact_num": 15, "rule_num": 15,
        "depth_interval": (4, 6), "condition_interval": (2, 3),
    }

    numerical_easy_config = {
        "interval": (1, 10),
        "expression_weights": {
            "normal": {ConstantExpression: 1, IdentityExpression: 1, LinearExpression: 1},
            "binary": {ConstantExpression: 1, IdentityExpression: 1, LinearExpression: 1},
        },
    }
    # Numerical settings used for the easy-logic set: every expression kind
    # is enabled, with linear and binary expressions weighted double.
    numerical_mixed_config = {
        "interval": (-100, 100),
        "expression_weights": {
            "normal": {
                ConstantExpression: 1,
                IdentityExpression: 1,
                LinearExpression: 2,
                BinaryExpression: 2,
            },
            "binary": {
                ConstantExpression: 1,
                IdentityExpression: 1,
                LinearExpression: 2,
            },
        },
    }

    pool_factory = PoolFactory("../resources/pools.json")
    template_factory = TemplateFactory("../resources/templates.json")

    # easy logic
    _generate_dataset(
        'train-el', pool_factory, template_factory, logical_easy_config,
        numerical_mixed_config["expression_weights"],
        numerical_mixed_config["interval"])

    # easy numerical
    _generate_dataset(
        'train-en', pool_factory, template_factory, logical_hard_config,
        numerical_easy_config["expression_weights"],
        numerical_easy_config["interval"])


# Run the generation only when this file is executed as a script, so that
# importing the module has no side effect beyond the sys.path setup above.
if __name__ == "__main__":
    main()
