"""
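Collect the fast-downward plans for the synthetic-domain (barman) problems into a ground-truth JSON file.
The commented-out code at the bottom measures the portion of problems that fast-downward can solve.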
"""

import os
import pdb
import subprocess
import multiprocessing
from concurrent.futures import ThreadPoolExecutor
from tqdm import tqdm
import json
import argparse

def parse_args(argv=None):
    """Parse the command-line options for the ground-truth collection script.

    Args:
        argv: list of argument strings to parse. ``None`` (the default)
            parses ``sys.argv[1:]``, exactly like
            ``argparse.ArgumentParser.parse_args``.

    Returns:
        argparse.Namespace with ``domain_file``, ``problem_dir``,
        ``temp_dir`` and ``domain_type`` attributes.
    """
    parser = argparse.ArgumentParser(
        description='Collect fast-downward plans for PDDL problems into a JSON file.'
    )
    parser.add_argument(
        '--domain_file', type=str,
        default='/lustre/fast/fast/txiao/zly/spatial_head/cot/benchmark/llm-pddl/domains/barman/syn_domain.pddl',
        help='PDDL domain file the problems are written against',
    )
    parser.add_argument(
        '--problem_dir', type=str,
        default='/lustre/fast/fast/txiao/zly/spatial_head/cot/benchmark/llm-pddl/domains/barman/',
        help='directory containing the *.pddl problem files',
    )
    parser.add_argument(
        '--temp_dir', type=str,
        default='/lustre/fast/fast/txiao/zly/spatial_head/cot/tmp',
        help='directory holding plan files named plan_syn_domain_<problem>[.N]',
    )
    parser.add_argument(
        '--domain_type', type=str, default='barman',
        help='domain label stored in every result record',
    )
    return parser.parse_args(argv)

class Get_gt:
    """Collect pre-computed fast-downward plans for every problem in a directory.

    Plans must already exist in ``temp_dir`` as ``plan_syn_domain_<problem>``
    or ``plan_syn_domain_<problem>.<N>``. The numbered form is what
    fast-downward's anytime aliases (e.g. ``seq-sat-fdss-2018``) write, one
    file per plan found. Each problem yields one record in ``self.results``.
    """

    def __init__(self, args) -> None:
        # args: any object exposing domain_file, problem_dir and temp_dir
        # (e.g. the namespace from parse_args()); domain_type is optional.
        self.domain_file = args.domain_file
        self.problem_dir = args.problem_dir
        self.temp_dir = args.temp_dir
        self.domain_type = getattr(args, 'domain_type', 'barman')
        self.results = []

    def _find_plan_files(self, prob_name):
        """Return the plan file names for *prob_name*, ordered by plan index.

        Only ``plan_syn_domain_<prob_name>`` (index 0) and
        ``plan_syn_domain_<prob_name>.<N>`` match, so problem ``p01`` never
        picks up plans belonging to ``p010``. The sort is numeric, so ``.10``
        comes after ``.2``.
        """
        prefix = f'plan_syn_domain_{prob_name}'
        indexed = []
        for name in os.listdir(self.temp_dir):
            if not name.startswith(prefix):
                continue
            suffix = name[len(prefix):]
            if suffix == '':
                indexed.append((0, name))
            elif suffix[0] == '.' and suffix[1:].isdecimal():
                indexed.append((int(suffix[1:]), name))
        indexed.sort()
        return [name for _, name in indexed]

    def _build_result(self, prob_file):
        """Build the result record for *prob_file* without storing it.

        ``plan`` is the text of the highest-numbered plan file, or ``''``
        when ``temp_dir`` holds no plan for this problem (unsolved).
        """
        prob_name = os.path.splitext(prob_file)[0]
        plan_files = self._find_plan_files(prob_name)
        if plan_files:
            plan_path = os.path.join(self.temp_dir, plan_files[-1])
            with open(plan_path, 'r') as f:
                plan = f.read()
        else:
            # Record unsolved problems instead of halting a worker thread.
            print(f'no plan found for {prob_name}')
            plan = ''
        return {
            'domain_type': self.domain_type,
            'domain_file': self.domain_file,
            'problem_name': prob_name,
            'plan': plan,
        }

    def get_gt(self, prob_file):
        """Collect the plan for *prob_file*, append it to ``self.results`` and return it."""
        result = self._build_result(prob_file)
        self.results.append(result)
        return result

    def multi_run(self):
        """Collect records for every problem file in ``problem_dir`` in parallel.

        Problem files are the ``*.pddl`` files whose name does not contain
        ``domain.pddl``. Records are appended in sorted file order.
        """
        prob_files = sorted(
            prob_file for prob_file in os.listdir(self.problem_dir)
            if prob_file.endswith('.pddl') and 'domain.pddl' not in prob_file
        )
        cores_num = multiprocessing.cpu_count()
        with ThreadPoolExecutor(max_workers=cores_num) as executor:
            # executor.map yields results in input order, so the output does
            # not depend on thread scheduling.
            self.results.extend(
                tqdm(executor.map(self._build_result, prob_files), total=len(prob_files))
            )

    def save(self, output_file):
        """Write ``self.results`` to *output_file* as indented JSON."""
        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump(self.results, f, indent=4)

if __name__ == '__main__':
    # Collect every problem's plan and write the records as JSON.
    args = parse_args()
    get_gt = Get_gt(args)
    get_gt.multi_run()
    # Save next to the plan files. With the default --temp_dir this is the
    # previously hard-coded .../cot/tmp/syn_barman_gt.json.
    domain_type = getattr(args, 'domain_type', 'barman')
    get_gt.save(os.path.join(args.temp_dir, f'syn_{domain_type}_gt.json'))












    # count the ratio of the problems that have plans
    # def process_problem(args):
    #     domain, problem, domain_path, temp_dir = args
    #     problem_path = os.path.join(os.path.dirname(domain_path), problem)
    #     plan_file = os.path.join(temp_dir, f'plan_{domain}_{problem}')
        
    #     if os.path.exists(plan_file):
    #         return True, domain, problem, None
    #     # solve the problem with fast-downward (A* search with the LM-cut heuristic)
    #     cmd = f'python3 /lustre/fast/fast/txiao/zly/downward/fast-downward.py --plan-file {plan_file} --search-time-limit 20s {domain_path} {problem_path} --search "astar(lmcut())" '
    #     subprocess.run(cmd, shell=True)
    #     # pdb.set_trace()
    #     if os.path.exists(plan_file):
    #         with open(plan_file, 'r') as f:
    #             plan = f.read()
    #         return True, domain, problem, plan
    #     return False, domain, problem, None

    # test_dir = '/lustre/fast/fast/txiao/zly/spatial_head/cot/benchmark/llm-pddl/domains/'
    # temp_dir = '/lustre/fast/fast/txiao/zly/spatial_head/cot/tmp'
    # plan_count = 0
    # total_count = 0
    
    # # Prepare all problems for parallel processing
    # all_problems = []
    # for domain in os.listdir(test_dir):
    #     domain_path = os.path.join(test_dir, domain, 'syn_domain.pddl')
    #     if not os.path.exists(domain_path):
    #         continue
    #     problems = [p for p in os.listdir(os.path.dirname(domain_path)) if p.startswith('p') and p.endswith('.pddl')]
    #     all_problems.extend([(domain, problem, domain_path, temp_dir) for problem in problems])

    # total_count = len(all_problems)
    
    # # Process problems in parallel
    # cores = multiprocessing.cpu_count()
    # with multiprocessing.Pool(processes=cores) as pool:
    #     results = list(tqdm(pool.imap(process_problem, all_problems), total=total_count))
    
    # # Process results
    # for has_plan, domain, problem, plan in results:
    #     if has_plan:
    #         plan_count += 1
    #         if plan:  # Only print if plan was actually generated this run
    #             print(f'{domain} {problem} {plan}')
    #     else:
    #         print(f'{domain} {problem} no plan')

    # print(f'plan_count: {plan_count}, total_count: {total_count}, ratio: {plan_count/total_count}')