# -> BUILD EVALUATOR
import re
from typing import Set, Union

import regex

from zendo.game.rule_grammar import Node, human_readable


def evaluate(ast: Node, structure: str):
    """
    Evaluate an AST node of a parsed rule against a structure.

    Called on the root ``rule`` node, this returns True if the structure
    respects the rule and False otherwise. The same function is used
    recursively on sub-nodes, whose results are not booleans:

    - ``rule`` / ``prop``: ``bool``
    - ``quantity``: a predicate taking a count and returning ``bool``
    - ``num``: ``int`` parsed from the node's literal
    - ``rel``: the node's literal (the relation keyword)
    - ``obj``: a regular-expression pattern string
    - ``color`` / ``shape``: the ``human_readable`` entry for the node's literal

    :param ast: the AST of a rule, or one of its sub-nodes
    :param structure: a valid structure to evaluate

    :return: see above; a bool for ``rule`` and ``prop`` nodes
    :raises AssertionError: if ``ast.name`` is not a known node type
    """
    name = ast.name
    if name == "rule":
        return _evaluate_rule(ast, structure)
    if name == "prop":
        return _evaluate_prop(ast, structure)
    if name == "rel":
        return _evaluate_rel(ast, structure)
    if name == "quantity":
        return _evaluate_quantity(ast, structure)
    if name == "num":
        return _evaluate_num(ast, structure)
    if name == "obj":
        return _evaluate_obj(ast, structure)
    if name == "color":
        return _evaluate_color(ast, structure)
    if name == "shape":
        return _evaluate_shape(ast, structure)
    # Explicit raise instead of ``assert False``: asserts are stripped under
    # ``python -O``, which would make an unknown node silently evaluate to None.
    raise AssertionError(f"unknown node type: {name!r}")


def _evaluate_rule(node: Node, structure: str):
    """
    Evaluate a ``rule`` node.

    A rule is either a single child (no literal), or an ``and`` / ``or``
    combination of its first two children. Both operators short-circuit,
    so the second child is only evaluated when needed.

    :param node: a ``rule`` node
    :param structure: the structure to evaluate

    :return: the truth value of the rule for ``structure``
    :raises AssertionError: if the node matches none of the shapes above
    """
    if node.literal is None and len(node.children) == 1:
        return evaluate(node.children[0], structure)
    if node.literal == "and":
        left, right = node.children[0], node.children[1]
        return evaluate(left, structure) and evaluate(right, structure)
    if node.literal == "or":
        left, right = node.children[0], node.children[1]
        return evaluate(left, structure) or evaluate(right, structure)
    # Explicit raise: a bare ``assert False`` vanishes under ``python -O`` and
    # the malformed rule would then silently evaluate to None.
    raise AssertionError(
        f"malformed rule node: literal={node.literal!r}, "
        f"{len(node.children)} children"
    )


def _regify_set(charset: Union[Set[str], str]) -> str:
    """
    Turn a set of symbols into a regex character class; pass anything else through.

    Set members are regex-escaped, so symbols such as ``]``, ``^`` or ``-``
    are matched literally instead of breaking the class. They are also
    sorted, so the same set always yields the same pattern. An empty set
    yields ``(?!)``, a pattern that never matches, instead of the invalid ``[]``.

    :param charset: a set/frozenset of symbols, or an already-built pattern string
    :return: a pattern string (non-set inputs are returned unchanged)
    """
    if not isinstance(charset, (set, frozenset)):
        return charset
    if not charset:
        return "(?!)"
    return "[" + "".join(re.escape(symbol) for symbol in sorted(charset)) + "]"


def _relation_pattern(rel: str, obj_left: str, obj_right: str) -> str:
    """
    Build a pattern matching each ``obj_left`` that stands in ``rel`` to an ``obj_right``.

    ``obj_right`` appears only inside lookaround assertions, so a match
    consumes just the left-hand piece. Two left-hand pieces that share a
    neighbour are therefore both counted.

    :param rel: the relation keyword
    :param obj_left: pattern for the pieces being counted
    :param obj_right: pattern for the pieces they are related to
    :return: the combined pattern string
    :raises RuntimeError: if ``rel`` is not a known relation
    """
    if rel == "surrounded_by":
        # immediately preceded AND immediately followed by a right-hand piece
        return f"(?<=(?:{obj_right})){obj_left}(?=(?:{obj_right}))"
    if rel == "touching":
        # immediately preceded OR immediately followed by a right-hand piece
        return f"(?:(?<=(?:{obj_right})){obj_left}|{obj_left}(?=(?:{obj_right})))"
    if rel == "at_the_right_of":
        # a right-hand piece occurs somewhere earlier; this is a variable-width
        # lookbehind, which the stdlib ``re`` rejects but ``regex`` supports
        return f"(?<=(?:{obj_right}.*)){obj_left}"
    if rel == "at_the_left_of":
        # a right-hand piece occurs somewhere later
        return f"{obj_left}(?=(?:.*{obj_right}))"
    raise RuntimeError(f"unknown relation: {rel!r}")


def _evaluate_prop(node: Node, structure: str):
    """
    Evaluate a ``prop`` node by counting matching pieces and applying its quantity.

    Two node shapes are supported:

    - 2 children ``(quantity, obj)``: count occurrences of ``obj`` in ``structure``;
    - 4 children ``(quantity, obj, rel, obj)``: count occurrences of the left
      object that stand in relation ``rel`` to the right object.

    :param node: a ``prop`` node
    :param structure: the structure to evaluate
    :return: the quantity predicate applied to the count
    :raises RuntimeError: if the relation is unknown
    :raises AssertionError: if the node has neither 2 nor 4 children
    """
    children = node.children
    if len(children) == 2:
        quantity = evaluate(children[0], structure)
        obj = evaluate(children[1], structure)
        return quantity(len(re.findall(obj, structure)))

    if len(children) == 4:
        quantity = evaluate(children[0], structure)
        obj_left = evaluate(children[1], structure)
        rel = evaluate(children[2], structure)
        obj_right = evaluate(children[3], structure)
        pattern = _relation_pattern(rel, obj_left, obj_right)
        # ``regex`` rather than ``re``: "at_the_right_of" needs variable-width lookbehind
        return quantity(len(regex.findall(pattern, structure)))

    # Explicit raise: ``assert False`` would be stripped under ``python -O``.
    raise AssertionError(f"prop node must have 2 or 4 children, got {len(children)}")


def _evaluate_rel(node: Node, structure: str):
    """
    Return the relation keyword stored on a ``rel`` node.

    ``structure`` is unused; it is accepted so that every evaluator shares
    the same signature.
    """
    relation = node.literal
    return relation


def _evaluate_quantity(node: Node, structure: str):
    """
    Turn a ``quantity`` node into a predicate over a count.

    ``zero`` (with no children) accepts only 0. ``exactly``, ``at_least``
    and ``at_most`` compare the count against ``evaluate(node.children[0])``,
    presumably a ``num`` node.

    :param node: a ``quantity`` node
    :param structure: the structure being evaluated (forwarded to the child)
    :return: a callable taking a count and returning a bool
    :raises AssertionError: if the quantity keyword is not recognised
    """
    if len(node.children) == 0 and node.literal == "zero":
        return lambda count: count == 0

    num = evaluate(node.children[0], structure)
    if node.literal == "exactly":
        return lambda count: count == num
    if node.literal == "at_least":
        return lambda count: count >= num
    if node.literal == "at_most":
        return lambda count: count <= num
    # Explicit raise: with ``assert False`` stripped under ``python -O`` this
    # would return None and fail later as "NoneType is not callable".
    raise AssertionError(f"unknown quantity: {node.literal!r}")


def _evaluate_num(node: Node, structure: str):
    """
    Parse the integer literal carried by a ``num`` node.

    ``structure`` is unused; it only keeps the evaluator signatures uniform.
    """
    text = node.literal
    return int(text)


def _evaluate_obj(node: Node, structure: str):
    """
    Turn an ``obj`` node into a pattern string for the pieces it describes.

    - 1 child: the child is evaluated and passed to ``_regify_set``, which
      turns a set result into a character class.
    - 2 children (color, shape): the intersection of the two
      ``human_readable`` sets must contain exactly one symbol. That symbol
      is regex-escaped so it is matched literally.

    :param node: an ``obj`` node
    :param structure: the structure being evaluated (forwarded to the child)
    :return: a pattern string
    :raises AssertionError: if the node has neither 1 nor 2 children, or if
        a color/shape pair does not identify exactly one symbol
    """
    if len(node.children) == 1:
        piece = evaluate(node.children[0], structure)
    elif len(node.children) == 2:
        color = _evaluate_color(node.children[0], structure)
        shape = _evaluate_shape(node.children[1], structure)
        candidates = color.intersection(shape)
        # Explicit check rather than ``assert``: under ``python -O`` the assert
        # would vanish and ``pop`` would silently pick an arbitrary symbol.
        if len(candidates) != 1:
            raise AssertionError(
                f"color/shape pair must identify exactly one piece, got {len(candidates)}"
            )
        piece = re.escape(candidates.pop())
    else:
        raise AssertionError(f"obj node must have 1 or 2 children, got {len(node.children)}")
    return _regify_set(piece)


def _evaluate_color(node: Node, structure: str) -> Set[str]:
    """
    Look up the color named by ``node.literal`` in ``human_readable``.

    The result is presumably the set of structure symbols having that color;
    confirm against the definition of ``human_readable``. ``structure`` is
    unused.
    """
    color_name = node.literal
    return human_readable[color_name]


def _evaluate_shape(node: Node, structure: str) -> Set[str]:
    """
    Look up the shape named by ``node.literal`` in ``human_readable``.

    The result is presumably the set of structure symbols having that shape;
    confirm against the definition of ``human_readable``. ``structure`` is
    unused.
    """
    shape_name = node.literal
    return human_readable[shape_name]