import re
import ast
import importlib.util
from pathlib import Path

# Load ACTION_PARAM_MAP / SWITCH_ACTIONS / USE_ACTIONS by executing the file at
# this path directly, instead of importing it as part of a package.
# NOTE(review): the path is relative, so it is resolved against the current
# working directory, not this file's directory — confirm callers always run
# from the project root.
action_param_map_path = Path("memory/long_term_memory/action_param_map.py")
spec = importlib.util.spec_from_file_location("action_param_map", action_param_map_path)
action_param_map = importlib.util.module_from_spec(spec)
# Runs the loaded file's top-level code now, at import time of this module;
# raises (e.g. FileNotFoundError) if the file is missing.
spec.loader.exec_module(action_param_map)

# ACTION_PARAM_MAP: node type -> parameter names whose quoted values parse_input
#                   copies out of a node's name into "<param>_obj" entries.
# SWITCH_ACTIONS:   node types that modify_nodes skips entirely.
# USE_ACTIONS:      node types that modify_nodes treats as "use" steps.
ACTION_PARAM_MAP = action_param_map.ACTION_PARAM_MAP
SWITCH_ACTIONS = action_param_map.SWITCH_ACTIONS
USE_ACTIONS = action_param_map.USE_ACTIONS

class Node:
    """One atomic operation (pick, place, cut, ...) in the task graph.

    Besides the values passed in, each instance starts with an empty
    ``successors`` list, ``earliest_start`` of 0 and ``predecessors`` equal to
    the number of dependencies given.
    """

    def __init__(self, node_id, node_type, name, arm_num, take_time, dependencies, location, target_obj=None, source_obj=None,delay_after=0):
        # --- identity / description ---
        self.id = node_id            # unique node id
        self.type = node_type        # action type, e.g. pick / place / cut
        self.name = name             # full task text, including object parameters
        self.location = location     # where the task happens, e.g. fridge / table

        # --- objects involved (None when not applicable) ---
        self.target_obj = target_obj  # object the task picks
        self.source_obj = source_obj  # object the task places / uses

        # --- resources and timing ---
        self.arm_num = arm_num          # arms needed (1 or 2)
        self.take_time = take_time      # execution duration in seconds
        self.delay_after = delay_after  # waiting time after the task finishes
        self.earliest_start = 0         # earliest start allowed by dependencies

        # --- graph links ---
        self.dependencies = dependencies           # ids of preceding tasks (stored as given)
        self.predecessors = len(dependencies)      # count of unfinished preceding tasks
        self.successors = []                       # ids of tasks that depend on this one

def parse_input(text, *, param_map=None):
    """Parse a line-oriented plan description into raw per-node attributes.

    A line starting with ``node_`` (e.g. ``node_3:``) opens a new node, and the
    following ``key: value`` lines fill it in. The recognised keys are:

    * ``type``, ``name``, ``arm_num`` and ``take_time``;
    * ``edge``: a Python literal, parsed with ``ast.literal_eval``;
    * ``delay_after``.

    A node whose type is ``task_completion`` is discarded, along with any
    lines that follow it before the next ``node_`` header. Blank lines, and
    lines that appear before any header, are ignored.

    Once all lines are read, parameters are extracted from each node's name.
    For every parameter listed for the node's type in ``param_map``, the
    quoted value of ``<param>="..."`` is stored as ``<param>_obj``. For
    example, ``target="apple"`` gives ``target_obj``. Because this runs after
    parsing, it works whether ``type:`` comes before or after ``name:``.

    Args:
        text: Raw plan text.
        param_map: Mapping of node type -> iterable of parameter names.
            Defaults to the module-level ``ACTION_PARAM_MAP``.

    Returns:
        Dict keyed by node label (e.g. ``"node_3"``). Each value holds the keys
        ``type``, ``name``, ``arm_num``, ``take_time``, ``edge`` and
        ``delay_after``, plus any extracted ``<param>_obj`` entries.
    """
    if param_map is None:
        param_map = ACTION_PARAM_MAP

    nodes_data = {}
    current_node = None

    for raw_line in text.split('\n'):
        line = raw_line.strip()
        if not line:
            continue

        if line.startswith('node_'):
            current_node = line.split(':')[0].strip()
            nodes_data[current_node] = {
                'type': '', 'name': '', 'arm_num': 0,
                'take_time': 0, 'edge': [], 'delay_after': 0,
            }
            continue

        # Lines belonging to a discarded task_completion node (or preceding
        # the first header) have no node to attach to; without this guard,
        # indexing nodes_data[None] would raise KeyError.
        if current_node is None:
            continue

        attrs = nodes_data[current_node]
        if line.startswith('type:'):
            node_type = line.split(':', 1)[1].strip()
            if node_type == 'task_completion':
                del nodes_data[current_node]
                current_node = None
            else:
                attrs['type'] = node_type
        elif line.startswith('delay_after:'):
            attrs['delay_after'] = int(line.split(':')[1].strip())
        elif line.startswith('name:'):
            attrs['name'] = line.split(':', 1)[1].strip()
        elif line.startswith('arm_num:'):
            attrs['arm_num'] = int(line.split(':')[1])
        elif line.startswith('take_time:'):
            attrs['take_time'] = int(line.split(':')[1])
        elif line.startswith('edge:'):
            attrs['edge'] = ast.literal_eval(line.split(':', 1)[1].strip())

    # Extract object parameters only after every line is known, so the result
    # does not depend on whether `type:` precedes `name:` in the input.
    for attrs in nodes_data.values():
        for param in param_map.get(attrs['type'], []):
            # re.escape: parameter names are data, not regex syntax.
            match = re.search(rf'{re.escape(param)}="([^"]+)"', attrs['name'])
            if match:
                attrs[f'{param}_obj'] = match.group(1)

    return nodes_data

def build_nodes(nodes_data):
    """Turn parsed node attributes into linked Node objects.

    Each label ``node_<N>`` becomes the integer id ``N``. A node's location is:

    * for pick/place, the quoted ``source`` value in its name;
    * for flap_open/flap_close, the quoted ``target`` value in its name;
    * for cut, ``'cutting_board'``;
    * otherwise, or when the expected parameter is missing, ``""``.

    Raw ``edge`` references are resolved through the same label -> id mapping.
    Each node is then appended to the ``successors`` of every node it depends on.

    Args:
        nodes_data: Dict produced by ``parse_input``.

    Returns:
        Dict mapping integer node id -> Node.

    Raises:
        KeyError: if an edge references a node label that was not parsed.
    """
    # "node_7" -> 7
    id_mapping = {label: int(label.split('_')[1]) for label in nodes_data}

    # Which quoted parameter of the name holds the location, per node type.
    location_param = {
        'pick': 'source', 'place': 'source',
        'flap_open': 'target', 'flap_close': 'target',
    }

    nodes = {}
    for label, attrs in nodes_data.items():
        node_type = attrs['type']
        location = 'cutting_board' if node_type == 'cut' else ""
        param = location_param.get(node_type)
        if param is not None:
            found = re.search(rf'{param}="([^"]+)"', attrs['name'])
            if found:
                location = found.group(1)

        node_id = id_mapping[label]
        nodes[node_id] = Node(
            node_id=node_id,
            node_type=node_type,
            name=attrs['name'],
            arm_num=attrs['arm_num'],
            take_time=attrs['take_time'],
            dependencies=attrs['edge'],
            location=location,
            target_obj=attrs.get('target_obj'),
            source_obj=attrs.get('source_obj'),
            delay_after=attrs['delay_after'],
        )

    # Resolve raw edge references to ids and wire up successor links.
    for node in nodes.values():
        resolved = [id_mapping[f"node_{dep}"] for dep in node.dependencies]
        node.dependencies = resolved
        node.predecessors = len(resolved)
        for dep_id in resolved:
            nodes[dep_id].successors.append(node.id)

    return nodes

# === Problem Flags ===
# These flags are used to detect issues in task dependency logic.

# | Flag          | Meaning                                                              |
# |---------------|----------------------------------------------------------------------|
# | problem_flag1 | A non-pick node depends on a 'place' node that involves a different  |
# |               | object.                                                              |
# | problem_flag2 | In a pick-use-place sequence, the 'place' node depends directly on   |
# |               | the 'pick' node instead of the last 'use' node.                      |
# | problem_flag3 | A non-pick node depends on a 'use' node involving a different        |
# |               | object.                                                              |
# Nodes whose type is in SWITCH_ACTIONS are never checked.

# Results written by modify_nodes() through `global` declarations inside that
# function. (A `global` statement has no effect at module scope, so none is
# needed here.)
problem_flag1 = 0  # 1 if any node violates rule 1, else 0
problem_flag2 = 0  # 1 if any node violates rule 2, else 0
problem_flag3 = 0  # 1 if any node violates rule 3, else 0

problem1 = []  # Node ids violating problem_flag1
problem2 = []  # Node ids violating problem_flag2
problem3 = []  # Node ids violating problem_flag3


def modify_nodes(nodes, *, switch_actions=None, use_actions=None):
    """Scan the task graph for the dependency problems tracked by the problem flags.

    Every node whose type is not a switch action is compared with each of its
    direct dependencies. Pick nodes are never flagged. For each dependency,
    the first matching rule wins:

    * rule 1: the dependency is a 'place' node whose ``source_obj`` differs
      from this node's ``source_obj``;
    * rule 2: this node is a 'place' whose ``source_obj`` equals the
      ``target_obj`` of a 'pick' dependency. That pick also feeds a use
      action on the same object, so the place should depend on the use;
    * rule 3: the dependency is a use action whose ``source_obj`` differs
      from this node's ``source_obj``.

    Results go to the module globals ``problem_flag1..3`` (0 or 1) and
    ``problem1..3`` (sorted lists of offending node ids). All six are reset
    on every call, so results never leak from a previously checked graph.
    If anything was found, the three lists are printed.

    Args:
        nodes: Mapping of node id -> node (as returned by ``build_nodes``).
            Each node needs the attributes ``type``, ``name``,
            ``dependencies``, ``successors``, ``source_obj`` and ``target_obj``.
        switch_actions: Node types to skip. Defaults to ``SWITCH_ACTIONS``.
        use_actions: Node types treated as "use" actions. Defaults to
            ``USE_ACTIONS``.

    Returns:
        1 if any problem was found, otherwise 0.
    """
    global problem_flag1, problem_flag2, problem_flag3
    global problem1, problem2, problem3

    if switch_actions is None:
        switch_actions = SWITCH_ACTIONS
    if use_actions is None:
        use_actions = USE_ACTIONS

    found1, found2, found3 = set(), set(), set()

    # Walk the ids that actually exist rather than assuming exactly 1..N:
    # ids can have gaps (parse_input drops task_completion nodes).
    for node_id in sorted(nodes):
        node = nodes[node_id]
        if node.type in switch_actions:
            continue
        # NOTE(review): pick detection uses `name`, whereas place detection
        # below uses `type`; confirm names always start with the action type.
        is_pick = node.name.startswith("pick")
        # Iterate the dependency list itself, not `predecessors`: the latter
        # is documented as the count of *unfinished* predecessors and may
        # drift from len(dependencies) once scheduling starts.
        for pre_id in node.dependencies:
            pre = nodes[pre_id]
            pre_is_pick = pre.name.startswith("pick")
            # True when `pre` is a pick whose picked object is consumed by a
            # use action among its successors.
            pick_feeds_use = pre_is_pick and any(
                nodes[succ_id].type in use_actions
                and nodes[succ_id].source_obj == pre.target_obj
                for succ_id in pre.successors
            )
            if is_pick:
                # NOTE(review): pick nodes are exempt from every rule; this
                # print looks like leftover debug output.
                print(node.name)
            elif node.source_obj != pre.source_obj and pre.type.startswith("place"):
                found1.add(node_id)
            elif (pick_feeds_use and node.source_obj == pre.target_obj
                  and node.type.startswith("place")):
                found2.add(node_id)
            elif pre.type in use_actions and node.source_obj != pre.source_obj:
                found3.add(node_id)

    problem_flag1 = 1 if found1 else 0
    problem_flag2 = 1 if found2 else 0
    problem_flag3 = 1 if found3 else 0
    problem1, problem2, problem3 = sorted(found1), sorted(found2), sorted(found3)

    if problem_flag1 or problem_flag2 or problem_flag3:
        print("Problematic nodes found:")
        print(problem1)
        print(problem2)
        print(problem3)
        return 1
    return 0