import os
import sys
import argparse
import math

# Module-level scratch state for the totalizer tree builder. Both names are
# reset by build_totalizer_tree_structure() before every build.
#
# totalizer_nodes_global holds every node of the tree being built, as dicts:
#   {'id': int, 'max_sum': int, 'person_idx': int or None,
#    'child1_id': int or None, 'child2_id': int or None}
# Leaves carry a person_idx and no children; internal nodes have
# person_idx None and two child ids.
totalizer_nodes_global = []
# Next id that _get_new_node_id() will hand out (ids are sequential from 0).
next_node_id_global = 0


def _get_new_node_id():
    """Return a fresh node id and advance the module-level id counter by one."""
    global next_node_id_global
    allocated_id, next_node_id_global = next_node_id_global, next_node_id_global + 1
    return allocated_id

def _build_totalizer_recursive(person_start_idx, person_end_idx):
    """
    Recursively build the totalizer subtree over persons
    ``person_start_idx..person_end_idx`` (1-based, inclusive).

    Node ids are allocated before the children are built (pre-order), while
    node dicts are appended to ``totalizer_nodes_global`` after their children
    (post-order), so every child precedes its parent in that list.

    A range of k persons is split so that the left child covers floor(k/2)
    persons and the right child the remaining ceil(k/2).

    Returns:
        The id of the root node of this subtree.

    Raises:
        ValueError: if the range is empty (end < start). Previously this
        recursed until RecursionError.
    """
    global totalizer_nodes_global
    num_persons_covered = person_end_idx - person_start_idx + 1
    if num_persons_covered < 1:
        raise ValueError(
            f"empty person range [{person_start_idx}, {person_end_idx}]"
        )

    node_id = _get_new_node_id()

    if num_persons_covered == 1:
        # Leaf: stands for a single person, counts at most 1.
        person_idx, child1_id, child2_id = person_start_idx, None, None
    else:
        # Exact integer split (avoids float division for large ranges).
        mid_person_split_idx = person_start_idx + num_persons_covered // 2 - 1
        child1_id = _build_totalizer_recursive(person_start_idx, mid_person_split_idx)
        child2_id = _build_totalizer_recursive(mid_person_split_idx + 1, person_end_idx)
        person_idx = None

    totalizer_nodes_global.append({
        'id': node_id,
        'max_sum': num_persons_covered,  # 1 for a leaf
        'person_idx': person_idx,
        'child1_id': child1_id,
        'child2_id': child2_id,
    })
    return node_id

def build_totalizer_tree_structure(n_inputs):
    """
    Build the complete totalizer tree over ``n_inputs`` leaves (persons 1..n_inputs).

    Resets the module-level ``totalizer_nodes_global`` and
    ``next_node_id_global`` state before building.

    Returns:
        (nodes, root_id): ``nodes`` is a copy of the node-dict list, in the
        order the recursive builder appended them (children before parents,
        root last). ``root_id`` is the id of the root. For ``n_inputs < 1``
        this is ``([], -1)``; previously a negative count recursed until
        RecursionError.
    """
    global totalizer_nodes_global, next_node_id_global
    totalizer_nodes_global = []
    next_node_id_global = 0

    if n_inputs < 1:
        return [], -1

    # A single input yields a lone leaf (id 0, person 1). The recursive builder
    # already produces exactly that, so no special case is needed here.
    root_id = _build_totalizer_recursive(1, n_inputs)
    return list(totalizer_nodes_global), root_id  # copy: callers must not alias the global


def _format_rule(head, body_literals):
    """Render ``head :- l1, l2.``, or the fact ``head.`` when there are no body literals.

    Emitting a fact for an empty body avoids the invalid ``head :- .`` syntax.
    """
    if body_literals:
        return f"{head} :- {', '.join(body_literals)}."
    return f"{head}."


def create_ground_food_program_totalizer(n_people, m_food, add_comments=True):
    """
    Generate a grounded probabilistic ASP program for the food preference problem.

    The number of people choosing each food is counted with a Bailleux-Boufkhad
    totalizer (atoms ``bsg(NodeID, FoodID, K)``, "at least K under NodeID choose
    FoodID"). A food is valid when the root count reaches a strict majority.

    Args:
        n_people: number of people, numbered 1..n_people.
        m_food: number of foods, numbered 1..m_food.
        add_comments: when True, emit ``%`` comment lines describing each section.

    Returns:
        The program as a list of lines, without trailing newlines. When the
        totalizer tree is empty (n_people < 1), the list ends with an
        ``% ERROR`` comment and holds no totalizer/valid_food/chosen_food rules.
    """
    lines = []
    # Strict majority: more than half of the people must choose the food.
    min_choosers = n_people // 2 + 1
    people = range(1, n_people + 1)
    foods = range(1, m_food + 1)

    if add_comments:
        lines.append("% Grounded Food Preference Problem with Bailleux-Boufkhad Cardinality Encoding")
        lines.append(f"% n_people = {n_people}, m_food = {m_food}")
        lines.append(f"% Minimum people choosing a food for it to be valid: floor({n_people}/2) + 1 = {min_choosers}")
        lines.append("")

    # Basic facts
    if add_comments:
        lines.append("% Basic facts: person(ID), food(ID).")
    lines.extend(f"1.0::person({i})." for i in people)
    lines.extend(f"1.0::food({j})." for j in foods)
    lines.append("")

    # Unrolled form of: 1 { person_food_preference(P, F) : food(F) } 1 :- person(P).
    # Each preference holds when every other preference of the same person is
    # false. This is meant to select exactly one food per person. With a single
    # food the rule degenerates to a fact.
    if add_comments:
        lines.append("% Unrolled constraint: For each person, exactly one food preference is true.")
    for i in people:
        for j in foods:
            others = [f"not person_food_preference({i},{k})" for k in foods if k != j]
            lines.append(_format_rule(f"person_food_preference({i},{j})", others))
    lines.append("")

    # chooses/2 mirrors person_food_preference/2. person/food are facts, so they
    # are dropped from the grounded body.
    if add_comments:
        lines.append("% Grounded chooses rules")
    for i in people:
        for j in foods:
            lines.append(f"chooses({i},{j}) :- person_food_preference({i},{j}).")
    lines.append("")

    # One tree shape serves every food; only the food argument of bsg/3 differs.
    tree_nodes, root_id = (
        build_totalizer_tree_structure(n_people) if n_people > 0 else ([], -1)
    )
    if not tree_nodes:
        lines.append("% ERROR: Totalizer tree construction failed.")
        return lines

    # Index nodes once so child lookups are O(1) instead of a linear scan
    # repeated for every internal node and every food.
    nodes_by_id = {node['id']: node for node in tree_nodes}

    # Unary (order) encoding inside each node: bsg(N,F,K) implies bsg(N,F,K-1).
    for node in tree_nodes:
        node_id, max_s = node['id'], node['max_sum']
        for f_id in foods:
            for k in range(2, max_s + 1):
                lines.append(f"bsg({node_id},{f_id},{k - 1}) :- bsg({node_id},{f_id},{k}).")
            for k in range(1, max_s):
                lines.append(f":- not bsg({node_id},{f_id},{k}), bsg({node_id},{f_id},{k + 1}).")
    lines.append("")

    for f_id in foods:
        if add_comments:
            lines.append(f"% Totalizer for food {f_id}")

        # Leaves: count 1 whenever the leaf's person chooses this food.
        for node in tree_nodes:
            if node['person_idx'] is not None:
                leaf_id, person = node['id'], node['person_idx']
                lines.append(f"bsg({leaf_id},{f_id},1) :- chooses({person},{f_id}).")
        lines.append("")

        for node in tree_nodes:
            if node['child1_id'] is None:
                continue  # leaf, handled above
            u_id, u_max = node['id'], node['max_sum']
            lc_node = nodes_by_id[node['child1_id']]
            rc_node = nodes_by_id[node['child2_id']]
            lc_id, rc_id = lc_node['id'], rc_node['id']
            m_lc, m_rc = lc_node['max_sum'], rc_node['max_sum']

            # C1: bsg(U,F,Alpha+Beta) :- bsg(LC,F,Alpha), bsg(RC,F,Beta).
            for alpha in range(1, m_lc + 1):
                for beta in range(1, m_rc + 1):
                    sigma = alpha + beta
                    if sigma <= u_max:
                        lines.append(f"bsg({u_id},{f_id},{sigma}) :- bsg({lc_id},{f_id},{alpha}), bsg({rc_id},{f_id},{beta}).")
            # C1 with Beta = 0: the left child alone.
            for alpha in range(1, m_lc + 1):
                if alpha <= u_max:
                    lines.append(f"bsg({u_id},{f_id},{alpha}) :- bsg({lc_id},{f_id},{alpha}).")
            # C1 with Alpha = 0: the right child alone.
            for beta in range(1, m_rc + 1):
                if beta <= u_max:
                    lines.append(f"bsg({u_id},{f_id},{beta}) :- bsg({rc_id},{f_id},{beta}).")

            # C2: :- bsg(U,F,Alpha+Beta+1), not bsg(LC,F,Alpha+1), not bsg(RC,F,Beta+1),
            # for Alpha in 0..m_lc-1 and Beta in 0..m_rc-1.
            # NOTE(review): the paper's C2 appears to also cover Alpha = m_lc /
            # Beta = m_rc; those boundary cases are not emitted here. Confirm
            # whether they are needed.
            for alpha in range(m_lc):
                for beta in range(m_rc):
                    sigma_p1 = alpha + beta + 1
                    if sigma_p1 <= u_max:
                        lines.append(f":- bsg({u_id},{f_id},{sigma_p1}), not bsg({lc_id},{f_id},{alpha + 1}), not bsg({rc_id},{f_id},{beta + 1}).")
        lines.append("")

    # Comparator: a food is valid when the root count reaches min_choosers.
    if add_comments:
        lines.append(f"% valid_food(F) if sum at root ({root_id}) for F is at least {min_choosers}.")
    for f_id in foods:
        lines.append(f"valid_food({f_id}) :- bsg({root_id},{f_id},{min_choosers}).")
    lines.append("")

    # Unrolled form of: 1 { chosen_food(F) : valid_food(F) } 1.
    # NOTE(review): unlike the choice rule, no constraint is emitted for the case
    # of no valid food. Such answer sets simply contain no chosen_food atom.
    if add_comments:
        lines.append("% Unrolled constraint: exactly one food must be chosen from the valid foods")
    for j in foods:
        body = [f"valid_food({j})"] + [f"not chosen_food({k})" for k in foods if k != j]
        lines.append(_format_rule(f"chosen_food({j})", body))

    return lines

if __name__ == '__main__':
    # Writes one program for every people count N in [n_people_start, n_people_end]
    # to plp/programs/food_totalizer_<N>_<m>/food_totalizer_<N>_<m>.pasp.
    parser = argparse.ArgumentParser(description='Create Grounded Food Preference Problem programs.')
    parser.add_argument('n_people_start', type=int, help='First number of people (n) in the range')
    parser.add_argument('n_people_end', type=int, help='Last number of people (n) in the range, inclusive')
    parser.add_argument('m_food', type=int, help='Number of food items (m)')
    parser.add_argument('-c', action='store_true', help='Add comments to output')

    args = parser.parse_args()

    beginning = args.n_people_start
    end = args.n_people_end  # required positional, so it is always set
    m_foods = args.m_food
    add_comments = args.c

    error = None
    if beginning < 3:
        error = "Error: The number of people must be at least 3."
    elif m_foods < 3:
        # Message matches the enforced bound (it previously said "at least 4").
        error = "Error: The number of food items must be at least 3."
    elif end < beginning:
        error = "Error: n_people_end must be greater than or equal to n_people_start."
    if error is not None:
        # Diagnostics go to stderr; sys.exit does not depend on the site module.
        print(error, file=sys.stderr)
        parser.print_help(sys.stderr)
        sys.exit(1)

    for N in range(beginning, end + 1):
        lines = create_ground_food_program_totalizer(N, m_foods, add_comments)
        directory = f"plp/programs/food_totalizer_{N}_{m_foods}"
        os.makedirs(directory, exist_ok=True)
        file_path = os.path.join(directory, f"food_totalizer_{N}_{m_foods}.pasp")
        with open(file_path, 'w', encoding='utf-8') as f:
            # Terminate the last line so the file is a well-formed text file.
            f.write("\n".join(lines) + "\n")
