import os
import pickle
import ast
import sys
sys.path.append('..')
from config.predicates import get_predicate_defs
from config.exemplars import get_exemplars
import pandas as pd

def generate_recall_and_completion(row, dir):
    """Build the gold labels and the full reasoning prompt for one example.

    Args:
        row: Subscriptable record providing ``query_relation`` and
            ``set_of_world_used``, plus the story/query fields consumed by
            ``describe_story_and_query``. (Presumably a pandas row -- confirm
            against the caller.)
        dir: Directory that contains ``world_rule_index.pkl``.

    Returns:
        A dict with three keys: ``correct_completion`` (the row's query
        relation), ``correct_recall`` (the row's set of world rules used) and
        ``prompt`` (the assembled prompt text).
    """
    # Gold answers are copied straight from the row.
    result = {
        'correct_completion': row['query_relation'],
        'correct_recall': row['set_of_world_used'],
    }

    definitions_text = get_predicate_defs()

    # NOTE: pickle.load can execute arbitrary code embedded in the file, so
    # `dir` must only ever point at trusted data.
    with open(os.path.join(dir, "world_rule_index.pkl"), "rb") as handle:
        indexed_rules = pickle.load(handle)

    # One "  <index>: <rule>" line per world rule, under a fixed header.
    rules_text = (
        "Here are the world rules. Rules are indexed, and follow the format Head :- Body.\n"
        + "".join(f"  {index}: {body}\n" for index, body in indexed_rules.items())
    )

    rule_guide_text = "\n".join((
        "There are three types of rules:",
        "  A. Definite Rule: Has a head and a body. It means if all atoms in the body are true, then the head is true.",
        "  B. Constraint: Has only a body. It states that the atoms in the body cannot all be true at the same time.",
        "  C. Fact: Has only a head. This atom is always true.",
        "Variables are capitalized and rules with variables hold universally for all substitutions.",
    ))

    # The expected answer layout is shown twice in the prompt: once inside the
    # task description and once more at the very end.
    answer_layout = (
        "  query_label: ...\n"
        "  rules_used: {...}\n"
        "  reasoning: ...\n"
    )

    task_text = (
        "TASKS:\n"
        "You will be given a story made up of predicates describing relationships between entities...\n\n"
        + get_exemplars()
        + "\nOutput format:\n"
        + answer_layout
    )

    story_text = describe_story_and_query(row)

    # Sections and their separators, in the order they appear in the prompt.
    result['prompt'] = "".join([
        definitions_text, "\n\n",
        rule_guide_text, "\n\n",
        rules_text, "\n",
        task_text, "\n",
        story_text, "\n",
        "Your output should be:\n",
        answer_layout,
    ])
    return result

def describe_story_and_query(row):
    """Render a row's story edges and query edge as STORY/QUERY prompt text.

    Reads ``row['story_edges']``, ``row['edge_types']`` and
    ``row['query_edge']``. Each value is passed through ``safe_eval_if_str``,
    so it may be either the object itself or its literal form as a string.

    Returns:
        One string: a "STORY:" section with one sentence per
        (edge, relation) pair, then a "QUERY:" section asking for the
        predicate between the two query entities and for the indexes of the
        world rules needed to derive it.
    """
    edges = safe_eval_if_str(row['story_edges'])
    relations = safe_eval_if_str(row['edge_types'])
    query_pair = safe_eval_if_str(row['query_edge'])

    def sentence(head, tail, rel):
        # An edge from an entity to itself is phrased as "<entity> is a <rel>".
        if head == tail:
            return f"  {head} is a {rel}.\n"
        # These relation names are inserted verbatim between the two entities.
        if rel in {"belongs_to_group", "belongs_to", "has_property"}:
            return f"  {head} {rel} {tail}.\n"
        if rel == "living_in_same_place":
            return f"  {head} is living in the same place as {tail}.\n"
        # Everything else is phrased as "<head> is the <rel> of <tail>".
        return f"  {head} is the {rel} of {tail}.\n"

    story_text = "STORY:\n" + "".join(
        sentence(head, tail, rel)
        for (head, tail), rel in zip(edges, relations)
    )

    first, second = query_pair[0], query_pair[1]
    # The embedded line breaks, four-space indents and the trailing space
    # after "If multiple" are part of the emitted prompt text.
    query_text = (
        f"\nQUERY:\n  What is the predicate between {first} and {second}? "
        f"If a relationship between {first} and {second}\n"
        "    is explicitly given in the story facts, there is some other relationship that is also true, "
        "you need to uncover this unstated predicate. If multiple \n"
        f"    predicates, capture the relationship between {first} and {second}, which is the most specific one.\n"
    )
    rules_text = "  What are the indexes of the world rules you will need to derive this?\n"

    return story_text + "\n" + query_text + rules_text


def safe_eval_if_str(val):
    """Return ``val`` parsed as a Python literal if it is a string, else as-is.

    Strings are parsed with ``ast.literal_eval``, which raises if the text is
    not a valid literal. Any non-string value is returned untouched.
    """
    if not isinstance(val, str):
        return val
    return ast.literal_eval(val)

    