import json
import itertools
import os
from concurrent.futures import ProcessPoolExecutor
from tqdm import tqdm
import signal

# Path of the input file; main() reads it line by line, parsing each line as JSON.
INPUT_FILE = "/data/raw_data.jsonl"
# Path of the output file; main() writes one JSON object per processed sample.
OUTPUT_FILE = "/data/processed_data.jsonl"

class TimeoutError(Exception):
    """Raised by the SIGALRM handler when a sample exceeds its time budget.

    NOTE: within this module the name shadows Python's builtin
    ``TimeoutError``; this class derives directly from ``Exception``.
    """

def timeout_handler(signum, frame):
    """Signal handler that aborts the running computation.

    Args:
        signum: Number of the delivered signal (unused).
        frame: Stack frame that was interrupted (unused).

    Raises:
        TimeoutError: Always.
    """
    message = "Function call timed out"
    raise TimeoutError(message)

def find_all_decision_paths(node_id, nodes, memo):
    """Enumerate every decision path that can be taken starting at ``node_id``.

    A node is a dict whose ``'options'`` entry is a list of option dicts; an
    option may carry a ``'next_node_id'`` list (a chain of node ids). For each
    non-empty subset of the node's options:

    * the "choice" is the sorted, de-duplicated union of all chain nodes;
    * exploration continues from the last node of every non-empty chain;
    * the continuations of those successors are combined as a cartesian
      product and concatenated in sorted-successor order.

    The number of subsets grows as ``2**len(options) - 1``. There is no cycle
    detection, so a cyclic graph recurses without bound.

    Args:
        node_id: Id of the node to expand.
        nodes: Mapping from node id to node dict.
        memo: Cache of already computed results, keyed by node id. It is
            updated in place for every node that has options.

    Returns:
        A list of paths. Each path is a list of choices and each choice is a
        sorted list of node ids. ``[[]]`` (one empty path) is returned for an
        unknown node or a node without options.
    """
    if node_id in memo:
        return memo[node_id]

    node = nodes.get(node_id)
    options = node.get('options') if node else None
    if not options:
        return [[]]

    # Every non-empty subset of the options, smallest subsets first.
    selections = itertools.chain.from_iterable(
        itertools.combinations(options, size)
        for size in range(1, len(options) + 1)
    )

    paths = []
    for selection in selections:
        chains = [opt.get('next_node_id', []) for opt in selection]
        chosen_nodes = sorted(set(itertools.chain.from_iterable(chains)))
        # Only the tail of each chain is a further decision point.
        successors = sorted({chain_[-1] for chain_ in chains if chain_})

        if not successors:
            paths.append([chosen_nodes])
            continue

        continuations = [
            find_all_decision_paths(successor, nodes, memo)
            for successor in successors
        ]
        for picked in itertools.product(*continuations):
            tail = [step for sub_path in picked for step in sub_path]
            paths.append([chosen_nodes] + tail)

    if not paths:
        # Nothing could be built (e.g. an empty continuation list supplied
        # through ``memo``): fall back to the single empty path, uncached.
        return [[]]

    memo[node_id] = paths
    return paths


def _build_choice_map(start_node_ids, combination, nodes):
    """Map each decision point to the choice made there for one combination."""
    choice_map = {}
    for start_id, decision_path in zip(start_node_ids, combination):
        decision_point = start_id
        for choice in decision_path:
            choice_map[decision_point] = choice

            # Successors reachable through options that overlap the choice.
            followers = set()
            node = nodes.get(decision_point)
            if node:
                for opt in node.get('options', []):
                    successor_chain = opt.get('next_node_id', [])
                    if successor_chain and any(
                        member in choice for member in successor_chain
                    ):
                        followers.add(successor_chain[-1])

            if len(followers) > 1:
                # The path fans out; remaining choices are not attributed.
                break
            if followers:
                decision_point = next(iter(followers))
            # With no follower the decision point is left unchanged.
    return choice_map


def _trace_traversal(start_node_ids, choice_map):
    """Walk the chosen nodes and snapshot [visited, to_visit] after each step."""
    visited = []
    pending = list(start_node_ids)
    snapshots = [[list(visited), list(pending)]]

    while pending:
        node = pending.pop(0)
        if node in visited:
            continue
        visited.append(node)

        already_pending = set(pending)
        fresh = [
            nxt for nxt in choice_map.get(node, [])
            if nxt not in already_pending
        ]
        # Newly chosen nodes go to the front of the pending list.
        pending = fresh + pending
        snapshots.append([list(visited), list(pending)])

    return snapshots


def generate_all_path_logs(tree_data):
    """Build a traversal log for every combination of decision paths.

    Args:
        tree_data: Dict with ``'start_node_ids'`` (ids to start from) and
            ``'nodes'`` (mapping from node id to node dict).

    Returns:
        A list with one log per combination of decision paths (one path per
        start node). A log is a list of ``[visited, to_visit]`` snapshots, the
        first one taken before any node is visited. Returns ``[]`` when some
        start node yields no decision paths at all.
    """
    start_node_ids = tree_data['start_node_ids']
    nodes = tree_data['nodes']

    # A single memo is shared by all start nodes.
    memo = {}
    branches_per_start = [
        find_all_decision_paths(start_id, nodes, memo)
        for start_id in start_node_ids
    ]
    if not all(branches_per_start):
        return []

    logs = []
    for combination in itertools.product(*branches_per_start):
        choice_map = _build_choice_map(start_node_ids, combination, nodes)
        logs.append(_trace_traversal(start_node_ids, choice_map))
    return logs


def process_sample(sample):
    """Attach the de-duplicated traversal logs of a sample's tree to it.

    Args:
        sample: Dict holding the tree under the ``"finegrained_tree"`` key.

    Returns:
        The same ``sample`` dict, updated in place with a ``"trajectory"``
        entry: the path logs with duplicates removed, in order of first
        occurrence.
    """
    path_logs = generate_all_path_logs(sample["finegrained_tree"])

    # Logs are nested lists (unhashable), so their JSON text serves as the
    # identity key; dict insertion order keeps the first occurrence of each.
    first_by_key = {}
    for log in path_logs:
        first_by_key.setdefault(json.dumps(log), log)

    sample["trajectory"] = list(first_by_key.values())
    return sample


def main():
    """Process every sample of INPUT_FILE and write the results to OUTPUT_FILE.

    Each non-blank input line is parsed as one JSON sample. Samples are
    processed one after another with a 5-minute SIGALRM-based time limit per
    sample (where the platform supports it). Samples that time out or raise
    are reported and skipped. Summary statistics are printed and the
    surviving samples are written as JSON lines.
    """
    timeout_seconds = 300

    # SIGALRM is not defined on every platform (e.g. Windows). The alarm is
    # only ever armed when the handler was installed, so an unhandled SIGALRM
    # can never be delivered.
    try:
        signal.signal(signal.SIGALRM, timeout_handler)
    except AttributeError:
        timeout_enabled = False
        print("Warning: 'signal' module not available on this platform. Timeout functionality is disabled.")
    else:
        timeout_enabled = hasattr(signal, 'alarm')

    with open(INPUT_FILE, 'r', encoding='utf-8') as f:
        # Skip blank lines so a stray empty line does not abort the whole run.
        samples = [json.loads(line) for line in f if line.strip()]

    save_list = []

    print("Starting serial processing...")
    for i, sample in enumerate(tqdm(samples, desc="Processing samples")):
        try:
            if timeout_enabled:
                signal.alarm(timeout_seconds)

            save_list.append(process_sample(sample))

        except TimeoutError:
            print(f"\nSkipping sample {i} due to 5-minute timeout.")

        except Exception as e:
            # Best effort: one bad sample must not stop the remaining ones.
            print(f"\nAn error occurred while processing sample {i}: {e}")

        finally:
            # Always cancel the pending alarm before the next sample.
            if timeout_enabled:
                signal.alarm(0)

    print("\nSerial processing finished.")

    if not save_list:
        print("No results were generated.")
        return

    length_list = [len(item["trajectory"]) for item in save_list if "trajectory" in item]
    total_paths = sum(length_list)

    print("\n--- Statistics ---")
    if length_list:
        print(f"Number of successfully processed samples: {len(save_list)}")
        print(f"Max paths per sample: {max(length_list)}")
        print(f"Min paths per sample: {min(length_list)}")
        print(f"Avg paths per sample: {total_paths/len(length_list):.2f}")
        print(f"Total unique paths generated: {total_paths}")

    print(f"\nWriting {len(save_list)} results to {OUTPUT_FILE}...")
    with open(OUTPUT_FILE, 'w', encoding='utf-8') as f:
        for item in save_list:
            f.write(json.dumps(item, ensure_ascii=False) + '\n')
    print("Done.")


# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()