import random
import re

from etr_case_generator.ontology import Ontology
from pyetr import View, PredicateAtom, ArbitraryObject, State
from typing import cast


def _pick_unused_name(candidates, obj_map: dict[str, str], kind: str) -> str:
    """Randomly pick one of `candidates` that is not already a value in `obj_map`.

    Raises:
        ValueError: If every candidate is already used as a value in `obj_map`.
    """
    # Build the lookup once instead of rescanning the dict values per candidate.
    taken = set(obj_map.values())
    available = [name for name in candidates if name not in taken]
    if not available:
        raise ValueError(
            f"Ontology has no unused {kind} left to assign "
            f"({len(taken)} name(s) already in use)."
        )
    return random.sample(available, k=1)[0]


def atom_to_natural_language(atom: PredicateAtom, obj_map: dict[str, str], ontology: Ontology) -> str:
    """Render a unary predicate atom as "<term> is [not] <predicate>".

    Names already recorded in `obj_map` are reused. Any predicate or
    non-arbitrary term that has no entry yet is assigned a random, not yet
    used name from the ontology, and that assignment is written back into
    `obj_map` (the dict is mutated in place).

    Args:
        atom (PredicateAtom): The atom to convert. Its predicate must have
            arity 1.
        obj_map (dict[str, str]): Map from predicate/term names to the natural
            language names chosen for them. Updated in place.
        ontology (Ontology): Source of candidate names; `ontology.predicates`
            (via their `name`) and `ontology.objects` are read.

    Returns:
        str: The sentence fragment, with whitespace normalized to single
            spaces.

    Raises:
        ValueError: If the predicate is not unary, or if the ontology has no
            unused predicate/object name left to assign.
    """
    if atom.predicate.arity != 1:
        raise ValueError("Currently only working with unary predicates.")

    # A falsy `verifier` flag on the predicate is rendered as a negation.
    negation = "" if atom.predicate.verifier else "not"

    # Safe because only unary predicates get this far.
    term = atom.terms[0]

    # Predicate is of the form "x is P".
    predicate_name = atom.predicate.name
    if predicate_name in obj_map:
        predicate_nl = obj_map[predicate_name]
    else:
        # Predicates are all arity 1 here, so the bare name is enough.
        predicate_nl = _pick_unused_name(
            [p.name for p in ontology.predicates], obj_map, "predicate"
        )
        obj_map[predicate_name] = predicate_nl

    if isinstance(term, ArbitraryObject):
        # For now, arbitrary terms are shown as their variable name, uppercased.
        term_nl = str(term).upper()
    else:
        term_name = str(term)
        if term_name in obj_map:
            term_nl = obj_map[term_name]
        else:
            term_nl = _pick_unused_name(ontology.objects, obj_map, "object")
            obj_map[term_name] = term_nl

    # Split/join collapses the double space left behind by an empty negation.
    words = [term_nl, "is", negation, predicate_nl]
    return " ".join(" ".join(words).split())


def state_to_natural_language(state: State, obj_map: dict[str, str], ontology: Ontology) -> str:
    """Render every atom of a state and join the results with " and ".

    Atoms are converted in the state's iteration order, so names are assigned
    into `obj_map` in that order. The rendered atoms are then reordered so
    that negated atoms come last.

    Args:
        state (State): The state whose atoms are rendered. Each atom is
            treated as a PredicateAtom.
        obj_map (dict[str, str]): Name map passed through to
            `atom_to_natural_language`, which may add entries to it.
        ontology (Ontology): Ontology passed through to
            `atom_to_natural_language`.

    Returns:
        str: The conjunction of the rendered atoms, or "" for an empty state.
    """
    rendered: list[tuple[bool, str]] = []
    for atom in state:
        predicate_atom = cast(PredicateAtom, atom)
        text = atom_to_natural_language(predicate_atom, obj_map, ontology)
        # Record negation from the atom itself: the rendered text has the form
        # "<term> is not <predicate>", so it never *starts* with "not" and a
        # prefix test on the string cannot detect it.
        rendered.append((not predicate_atom.predicate.verifier, text))

    # Put negated atoms last -- this helps the natural language not read
    # ambiguously, e.g. like "there is not an ace and a ten". The sort is
    # stable, so atoms of the same polarity keep their relative order.
    rendered.sort(key=lambda pair: pair[0])

    return " and ".join(text for _, text in rendered)


def view_to_natural_language(
    ontology: Ontology, view: View, obj_map: dict[str, str] = None
) -> tuple[str, dict[str, str]]:
    """Convert a View into a natural language string.

    The result is assembled from three pieces: a quantifier prefix
    ("for all X, " / "there is some X such that "), an optional
    "if <supposition>, then " clause, and the stage. Within the stage and the
    supposition, alternative states are joined with ", or " and prefixed with
    "either " when there is more than one. No ending punctuation is added.

    Args:
        ontology (Ontology): The ontology that natural language names are
            drawn from.
        view (View): The view to convert.
        obj_map (dict[str, str]): A map from variable names to objects in the
            ontology. A fresh empty map is used when None is passed.

    Returns:
        str: A string describing the View in natural language.
        dict[str, str]: The object map after this conversion, including any
            names assigned along the way. Pass it to later calls to render
            several views under the same variable interpretations.
    """
    mapping = {} if obj_map is None else obj_map

    def describe(states) -> str:
        # Render each state and join the alternatives into one disjunction.
        rendered = [
            state_to_natural_language(state, mapping, ontology) for state in states
        ]
        joined = ", or ".join(rendered)
        if len(rendered) > 1:
            return "either " + joined
        return joined

    # Collect the variables that need a quantifier phrase.
    universals: frozenset = view.dependency_relation.universals
    existentials: frozenset = view.dependency_relation.existentials

    # Build the quantifier prefix, following the order in which the variables
    # appear in the view's string form before its first "{".
    quantifier_parts: list[str] = []
    if len(universals | existentials) > 0:
        universal_names = [u.name for u in universals]
        existential_names = [e.name for e in existentials]
        header = view.to_str().split("{")[0]
        for token in re.split(r"[ ∃∀]", header):
            if token == "":
                continue  # Drop empty strings produced by the split
            if token in universal_names:
                quantifier_parts.append(f"for all {token.upper()}, ")
            elif token in existential_names:
                quantifier_parts.append(f"there is some {token.upper()} such that ")
    quantifier_prefix = "".join(quantifier_parts)

    # The stage is rendered before the supposition so that names are assigned
    # into the map in that order.
    body = describe(view.stage)

    # TODO: this SetOfStates object should have an .empty method
    supposition = view.supposition
    if not (supposition.is_verum or len(supposition) == 0):
        body = "if " + describe(supposition) + ", then " + body

    return quantifier_prefix + body, mapping
