import re
from collections import defaultdict, deque


def parse_logic_text(text):
    """
    Parse a block of ProntoQA-style statements about group membership and traits.

    The text is split into sentences on ".". A sentence starting with
    "True or false:" is the query ("<Name> is [not] [a|an] <word>."); every
    other sentence is parsed as "<X> is <Y>" or "<X> is not <Y>" after removing
    the quantifiers "each"/"every" and the articles "a"/"an", and mapping
    "are" to "is". Plural words are crudely singularized; the queried name is
    kept verbatim (lower-cased) so names ending in "s" are not mangled.

    Returns a dictionary with keys:
      "text": the input text, unchanged.
      "target": the queried name, capitalized (None if no query matched).
      "query": the queried word, prefixed with "not " for a negated query
               (None if no query matched).
      "positive_traits": sorted names reachable from the target through "is"
               edges, including facts stated directly about the target.
      "negative_traits": sorted names reached as "not <name>", prefix removed.
      "group_membership": {subject: sorted list of objects} for every parsed
               statement; negated objects carry the "not " prefix.
    """
    # === STEP 1: Sentence extraction ===
    # Each run of non-period characters ending in "." is one sentence; a
    # trailing fragment without a period is ignored.
    sentences = re.findall(r"[^.]+[.]", text)
    logic_statements = []
    target = None
    query_trait = None

    def singular(word):
        """Crudely singularize a lower-case plural word."""
        if word.endswith("ses"):  # lorpuses -> lorpus, classes -> class
            return word[:-2]
        if word.endswith("ies"):  # berries -> berry
            return word[:-3] + "y"
        if word.endswith("s") and not word.endswith("us"):  # keep "wumpus"
            return word[:-1]
        return word

    def normalize_name(name):
        return singular(name.strip().lower())

    # === STEP 2: Separate logic statements from the query ===
    # The optional article accepts "Alex is a wumpus." as well as
    # "Alex is not opaque.".
    query_pattern = re.compile(
        r"true or false: (\w+) is (not )?(?:an? )?(\w+)\."
    )
    for sentence in sentences:
        sentence = sentence.strip()
        lowered = sentence.lower()
        if lowered.startswith("true or false:"):
            match = query_pattern.match(lowered)
            if match:
                target = match.group(1).capitalize()
                query_trait = ("not " if match.group(2) else "") + match.group(3)
        else:
            logic_statements.append(sentence)

    # Key under which statements about the target are stored. Callers look the
    # target up as ``target.lower()``, so it must not be singularized.
    target_key = target.lower() if target is not None else None

    def subject_key(name):
        lowered = name.strip().lower()
        return lowered if lowered == target_key else singular(lowered)

    # === STEP 3: Logic processing ===
    # subject -> set of objects (group names, traits, or "not <trait>").
    group_membership = defaultdict(set)
    # Objects stated directly about the target itself.
    target_facts = set()

    for statement in logic_statements:
        s = statement.strip(".").lower()
        # Normalize quantifiers, verb number and articles.
        s = (
            s.replace("each ", "")
            .replace("every ", "")
            .replace(" are ", " is ")
            .replace(" an ", " a ")
            .replace(" a ", " ")
        )

        # "X is not Y" must be tried first: "X is Y" would also match it and
        # capture "not" as Y.
        match = re.match(r"(\w+) is not (\w+)", s)
        if match:
            obj = "not " + normalize_name(match.group(2))
        else:
            match = re.match(r"(\w+) is (\w+)", s)
            if not match:
                continue
            obj = normalize_name(match.group(2))

        subj = subject_key(match.group(1))
        group_membership[subj].add(obj)
        if subj == target_key:
            target_facts.add(obj)

    # === STEP 4: Everything reachable from the target ===
    # Breadth-first closure over "is" edges. ``.get`` avoids inserting empty
    # entries into the defaultdict, which would leak into the returned
    # "group_membership" mapping.
    reachable = set()
    queue = deque(target_facts)
    while queue:
        node = queue.popleft()
        if node in reachable:
            continue
        reachable.add(node)
        queue.extend(group_membership.get(node, ()))

    # === Final output ===
    negated = [t for t in reachable if t.startswith("not ")]
    negative_traits = sorted(t[len("not "):] for t in negated)
    positive_traits = sorted(reachable.difference(negated))

    return {
        "text": text,
        "target": target,
        "query": query_trait,
        "positive_traits": positive_traits,
        "negative_traits": negative_traits,
        "group_membership": {k: sorted(v) for k, v in group_membership.items()},
    }


def print_tree(node, data, prefix=""):
    """
    Print the descendants of *node* as an ASCII tree; *node* itself is not printed.

    Args:
        node: key in *data* whose children are printed.
        data: adjacency mapping ``{node: [child, ...]}``.
        prefix: text prepended to every printed line.

    A child that leads back to one of its ancestors is printed but not
    expanded again, so cyclic data terminates instead of recursing forever.
    """

    def _print(current, indent, ancestors):
        children = data.get(current, [])
        last_index = len(children) - 1
        for i, child in enumerate(children):
            is_last = i == last_index
            print(indent + ("└── " if is_last else "├── ") + child)
            if child in data and child not in ancestors:
                extension = "    " if is_last else "│   "
                _print(child, indent + extension, ancestors | {child})

    _print(node, prefix, frozenset((node,)))


def get_membership(node: str, data: dict, membership: "list | None" = None) -> list:
    """
    Collect the nodes reachable from *node* in depth-first, pre-order order.

    Args:
        node: starting key in *data*; it is not itself added to the result.
        data: adjacency mapping ``{node: [child, ...]}``.
        membership: optional list to extend in place and return; a fresh list
            is used when omitted.

    Returns:
        The list of visited children. A child reachable along several paths
        appears once per path. A child that leads back to one of its
        ancestors is listed but not expanded, so cyclic data terminates.
    """
    if membership is None:
        # A ``[]`` default would be shared between calls and keep growing.
        membership = []

    def _walk(current, ancestors):
        for child in data.get(current, []):
            membership.append(child)
            if child in data and child not in ancestors:
                _walk(child, ancestors | {child})

    _walk(node, frozenset((node,)))
    return membership


def create_rules(text: str, verbose: bool = False) -> dict[tuple[str, ...], str]:
    """
    Build truth-value rules for the statements relevant to the query target.

    The text is parsed with ``parse_logic_text`` and every node reachable from
    the target is collected with ``get_membership``. For each parsed edge
    ``key -> item`` whose *item* is reachable, two complementary rules are
    emitted:

      * positive item: ``(key, "is", item)`` -> "true" and
        ``(key, "is", "not", item)`` -> "false";
      * negated item ``"not <trait>"``: ``(key, "is", "not", trait)`` -> "true"
        and ``(key, "is", trait)`` -> "false".

    Args:
        text: the full problem text, including a "True or false:" query.
        verbose: when true, print the target and its membership tree first.

    Returns:
        Mapping from rule tuples to the strings "true"/"false".

    Raises:
        ValueError: if the parsed text names no target.
    """
    parsed_logic = parse_logic_text(text)
    if parsed_logic["target"] is None:
        raise ValueError("no 'True or false: <name> is ...' query found in text")

    group_membership = parsed_logic["group_membership"]
    target = parsed_logic["target"].lower()

    if verbose:
        print(target)
        print_tree(target, group_membership)

    # Only membership tests are needed, so a set gives O(1) lookups.
    reachable = set(get_membership(target, group_membership, []))

    rules = {}
    for key, items in group_membership.items():
        for item in items:
            if item not in reachable:
                continue
            # Match the exact "not " prefix: a substring test would misread
            # traits such as "knotted" as negations.
            if item.startswith("not "):
                trait = item.removeprefix("not ")
                rules[key, "is", "not", trait] = "true"
                rules[key, "is", trait] = "false"
            else:
                rules[key, "is", item] = "true"
                rules[key, "is", "not", item] = "false"
    return rules
