"""
Validate PDDL domains produced by model rollouts. For each record in the input JSON file, the PDDL code is checked with the external `Validate` tool and the results are saved to an output JSON file.
"""
import argparse
import json
import multiprocessing
import os
import pdb
import random
import re
import subprocess
import sys
import tempfile
# multi thread
import threading
from collections import Counter
from concurrent.futures import ThreadPoolExecutor
from multiprocessing import Pool
from multiprocessing.pool import ThreadPool
from typing import Optional, List

from tqdm import tqdm
def validate(pddl_domain_path: str) -> str:
    """Run the external ``Validate`` tool on a PDDL domain file.

    The command is passed as an argument list rather than a shell string, so a
    path containing spaces or shell metacharacters reaches ``Validate`` as one
    argument and is never interpreted by a shell.

    Args:
        pddl_domain_path: Path of the PDDL domain file to check.

    Returns:
        Everything ``Validate -v`` wrote to stdout, decoded as UTF-8. stderr is
        not captured and the exit status is not inspected.

    Raises:
        FileNotFoundError: If no ``Validate`` executable is found on ``PATH``.
    """
    result = subprocess.run(['Validate', '-v', pddl_domain_path], stdout=subprocess.PIPE)
    return result.stdout.decode('utf-8')

def validate_prob(pddl_domain_path: str, prob: str) -> str:
    """Run the external ``Validate`` tool on a PDDL domain/problem pair.

    The command is passed as an argument list rather than a shell string, so
    paths containing spaces or shell metacharacters reach ``Validate`` intact
    and are never interpreted by a shell.

    Args:
        pddl_domain_path: Path of the PDDL domain file.
        prob: Path of the PDDL problem file to check against that domain.

    Returns:
        Everything ``Validate -v`` wrote to stdout, decoded as UTF-8. stderr is
        not captured and the exit status is not inspected.

    Raises:
        FileNotFoundError: If no ``Validate`` executable is found on ``PATH``.
    """
    result = subprocess.run(['Validate', '-v', pddl_domain_path, prob], stdout=subprocess.PIPE)
    return result.stdout.decode('utf-8')

def validate_str(pddl_domain_str: str) -> str:
    """Validate PDDL domain source given as a string.

    The text is written to a uniquely named temporary ``.pddl`` file, that file
    is passed to ``validate``, and the file is removed again even when
    validation raises.

    Args:
        pddl_domain_str: PDDL domain source code.

    Returns:
        The output returned by ``validate`` for the temporary file.
    """
    # Keep using the project scratch directory when it exists; otherwise fall
    # back to the system default temp directory so the function also works on
    # machines without that mount.
    scratch_dir = '/lustre/fast/fast/txiao/zly/spatial_head/cot/tmp'
    fd, temp_domain_path = tempfile.mkstemp(
        prefix='temp_domain_',
        suffix='.pddl',
        dir=scratch_dir if os.path.isdir(scratch_dir) else None,
    )
    try:
        with os.fdopen(fd, 'w') as f:
            f.write(pddl_domain_str)
        return validate(temp_domain_path)
    finally:
        # Always clean up, including when validate() raises.
        os.remove(temp_domain_path)

def validate_batch(pddl_domain_list: List[str]) -> List[str]:
    """Validate many PDDL domain strings concurrently with a thread pool.

    Args:
        pddl_domain_list: PDDL domain source strings.

    Returns:
        One ``validate_str`` output per input, in input order.
    """
    # The context manager shuts the pool down once all results are collected;
    # the list() call consumes every result before the pool exits.
    with ThreadPool() as pool:
        results = list(tqdm(pool.imap(validate_str, pddl_domain_list), total=len(pddl_domain_list)))
    return results

def preprocess_pddl_batch(pddl_str_list: List[str]) -> List[str]:
    """Extract the PDDL code from each string of a batch.

    For every input, the content of the first ```` ```pddl ```` fenced block is
    returned; if there is none, the first ```` ```lisp ```` fenced block; if
    neither is present the string is returned unchanged.

    Args:
        pddl_str_list: Raw strings, possibly wrapping PDDL in a Markdown fence.

    Returns:
        The extracted PDDL strings, in input order.
    """
    fence_patterns = (
        r'```pddl\n(.*?)\n```',
        r'```lisp\n(.*?)\n```',
    )

    def strip_fence(pddl_str: str) -> str:
        # Patterns are tried in order; DOTALL lets the body span several lines.
        for pattern in fence_patterns:
            match = re.search(pattern, pddl_str, re.DOTALL)
            if match:
                return match.group(1)
        return pddl_str

    return [strip_fence(pddl_str) for pddl_str in pddl_str_list]

class Validator:
    """Run the external ``Validate`` tool over PDDL records loaded from JSON.

    The input file is parsed with ``json.load`` and iterated as a sequence of
    dict records. Which keys a record must carry depends on the method used;
    see the individual methods.

    Attributes:
        input_file: Path of the JSON file the records were loaded from.
        output_file: Path of the JSON file written by ``save``/``save_results``.
        results: Records accumulated by ``validate_json``, ``validate_bc``,
            ``multi_run_1`` and ``multi_run``; written out by ``save``.
        pddl_file: The parsed content of ``input_file``.
    """

    # Preferred scratch directory for temporary PDDL files. When it does not
    # exist the system default temp directory is used instead.
    _TMP_DIR = '/lustre/fast/fast/txiao/zly/spatial_head/cot/tmp'

    # (substring of the record name, domain file) pairs used by validateprob().
    # They are checked in this order and the first match wins.
    _PROBLEM_DOMAINS = (
        ('block', '/lustre/fast/fast/txiao/zly/spatial_head/cot/result/Val/planetarium/block_world_domain.pddl'),
        ('gripper', '/lustre/fast/fast/txiao/zly/spatial_head/cot/result/Val/planetarium/gripper_domain.pddl'),
    )

    # Markdown fences that may wrap the PDDL code, tried in this order.
    _FENCE_PATTERNS = (
        r'```pddl\n(.*?)\n```',
        r'```lisp\n(.*?)\n```',
    )

    def __init__(self, args) -> None:
        """Load the input records.

        Args:
            args: Object with ``input_file`` and ``output_file`` attributes
                (e.g. the namespace returned by ``parse_args``).

        Raises:
            ValueError: If ``output_file`` does not end with ``.json``.
        """
        self.input_file = args.input_file
        self.output_file = args.output_file
        # Raise rather than assert: assertions are stripped under ``python -O``.
        # Checked before reading the (possibly large) input file.
        if not self.output_file.endswith('.json'):
            raise ValueError('The output file must be a json file')
        self.results = []
        with open(self.input_file, 'r') as f:
            self.pddl_file = json.load(f)

    def preprocess_pddl(self, pddl_str: str) -> str:
        """Return the PDDL code wrapped in a Markdown fence, if any.

        The first ```` ```pddl ```` block is preferred, then the first
        ```` ```lisp ```` block; without a fence the input is returned as is.
        """
        for pattern in self._FENCE_PATTERNS:
            match = re.search(pattern, pddl_str, re.DOTALL)
            if match:
                return match.group(1)
        return pddl_str

    @classmethod
    def _run_on_temp_file(cls, prefix: str, content: str, runner):
        """Write ``content`` to a temp ``.pddl`` file and return ``runner(path)``.

        The file is removed afterwards, including when ``runner`` raises.
        """
        tmp_dir = cls._TMP_DIR if os.path.isdir(cls._TMP_DIR) else None
        fd, temp_path = tempfile.mkstemp(prefix=prefix, suffix='.pddl', dir=tmp_dir)
        try:
            with os.fdopen(fd, 'w') as f:
                f.write(content)
            return runner(temp_path)
        finally:
            os.remove(temp_path)

    def validate_str(self, pddl_domain_str: str) -> str:
        """Validate PDDL domain source given as a string.

        Returns:
            The output of the module-level ``validate`` for a temporary file
            holding ``pddl_domain_str``.
        """
        return self._run_on_temp_file('temp_domain_', pddl_domain_str, validate)

    def validate_prob(self, domain_path, prob):
        """Validate PDDL problem source against an existing domain file.

        Args:
            domain_path: Path of the PDDL domain file.
            prob: PDDL problem source code (a string, not a path).

        Returns:
            The output of the module-level ``validate_prob`` for the domain
            and a temporary file holding ``prob``.
        """
        return self._run_on_temp_file(
            'temp_prob_',
            prob,
            lambda prob_path: validate_prob(domain_path, prob_path),
        )

    def validate_batch(self) -> List[str]:
        """Validate the ``domain`` of every loaded record via ``validate_batch``.

        Returns:
            One validator output per record, in input order.
        """
        pddl_str_list = [self.preprocess_pddl(pddl['domain']) for pddl in self.pddl_file]
        return validate_batch(pddl_str_list)

    def multi_run_1(self):
        """Validate all records concurrently, as ``validate_json`` does.

        The new result records are appended to ``self.results`` in input order
        (``Executor.map`` yields results in the order of its inputs), so the
        saved output is deterministic regardless of thread scheduling.
        """
        data = self.pddl_file
        cores_num = multiprocessing.cpu_count()
        with ThreadPoolExecutor(max_workers=cores_num) as executor:
            ordered = list(tqdm(executor.map(self._json_result, data), total=len(data)))
        self.results.extend(ordered)

    def _json_result(self, pddl) -> dict:
        """Validate one record's domain and build its output record.

        Stores the validator output in ``pddl['result']`` as a side effect.
        """
        pddl_code = self.preprocess_pddl(pddl['domain'])
        pddl['result'] = self.validate_str(pddl_code)
        # NOTE: this is a substring test on the whole input path, not only on
        # the file name.
        if 'nl' in self.input_file:
            return {
                "file": pddl['file'],
                "domain": pddl['domain'],
                "result": pddl['result'],
                "nl_description": pddl['nl_description'],
                "response_id": pddl["response_id"],
                "loss": pddl["loss"],
                "old_domain": pddl["old_domain"],
                'rounds': pddl["rounds"],
            }
        return {
            "file": pddl['file'],
            "domain": pddl['domain'],
            "result": pddl['result'],
            "question": pddl['question'],
            "response_id": pddl["response_id"],
        }

    def validate_json(self, pddl):
        """Validate one record and append its output record to ``self.results``.

        When the input path contains ``'nl'`` the record must carry ``file``,
        ``domain``, ``nl_description``, ``response_id``, ``loss``,
        ``old_domain`` and ``rounds``; otherwise ``file``, ``domain``,
        ``question`` and ``response_id``.
        """
        self.results.append(self._json_result(pddl))

    @classmethod
    def _domain_path_for(cls, name: str) -> str:
        """Return the fixed domain file for a record name.

        Raises:
            ValueError: If ``name`` contains none of the known domain keywords.
        """
        for keyword, domain_path in cls._PROBLEM_DOMAINS:
            if keyword in name:
                return domain_path
        raise ValueError(f"Cannot determine the domain for record name {name!r}")

    def validateprob(self) -> List[dict]:
        """Validate each record's ``question`` (a PDDL problem) against its domain.

        The domain file is chosen from the record's ``name``. Records need the
        keys ``id``, ``name``, ``question`` and ``response_id``.

        Returns:
            One output record per input record, in input order.

        Raises:
            ValueError: If a record name maps to no known domain. Names are
                resolved up front, so this is raised before any validation
                subprocess is started.
        """
        domain_paths = [self._domain_path_for(pddl['name']) for pddl in self.pddl_file]
        results = []
        for pddl, domain_path in tqdm(zip(self.pddl_file, domain_paths), total=len(domain_paths)):
            pddl['result'] = self.validate_prob(domain_path, pddl['question'])
            results.append({
                "id": pddl['id'],
                "name": pddl["name"],
                "result": pddl['result'],
                "question": pddl['question'],
                "response_id": pddl["response_id"],
            })
        return results

    @staticmethod
    def _passed(result: str) -> bool:
        """Return True when a stored validator output counts as a pass.

        A pass starts with ``"Type-checking"`` and contains neither ``"fail"``
        nor ``"incorrectly"``.
        """
        return (
            result.startswith("Type-checking")
            and "fail" not in result
            and "incorrectly" not in result
        )

    def gen_bad(self, num_responses: int = 8) -> List[dict]:
        """Return all records of files whose responses all failed validation.

        Records must already carry a ``result`` (validator output) and a
        ``file`` key. A file is selected when exactly ``num_responses`` of its
        records did not pass.

        Args:
            num_responses: Number of failing records a file must have.

        Returns:
            Every loaded record (passing or not) belonging to a selected file,
            in input order.
        """
        failures = Counter(
            pddl['file'] for pddl in self.pddl_file if not self._passed(pddl['result'])
        )
        bad_files = {file for file, count in failures.items() if count == num_responses}
        return [data for data in self.pddl_file if data['file'] in bad_files]

    def select_bad(self, num_responses: int = 16) -> List[dict]:
        """Re-validate the records of files that have a full set of responses.

        Only files with exactly ``num_responses`` records are processed.
        Records need the keys ``file``, ``domain``, ``question``,
        ``nl_description`` and ``response_id``.

        Args:
            num_responses: Number of records a file must have to be processed.

        Returns:
            One output record per processed input record, in input order.
        """
        per_file = Counter(data['file'] for data in self.pddl_file)
        selected = [data for data in self.pddl_file if per_file[data['file']] == num_responses]
        results = []
        for pddl in tqdm(selected):
            pddl_code = self.preprocess_pddl(pddl['domain'])
            pddl['result'] = self.validate_str(pddl_code)
            results.append({
                "file": pddl['file'],
                "domain": pddl['domain'],
                "result": pddl['result'],
                "question": pddl['question'],
                "nl_description": pddl['nl_description'],
                "response_id": pddl["response_id"],
            })
        return results

    def _dump(self, payload) -> None:
        """Write ``payload`` to ``self.output_file`` as JSON with indent 4."""
        with open(self.output_file, 'w') as f:
            json.dump(payload, f, indent=4)

    def save_results(self, results):
        """Write ``results`` to the output file and print where it was saved."""
        self._dump(results)
        print(f"save in {self.output_file}")

    def save(self):
        """Write the accumulated ``self.results`` to the output file."""
        self._dump(self.results)

    def _bc_result(self, data) -> dict:
        """Validate one record's domain and build its output record.

        Stores the validator output in ``data['result']`` as a side effect.
        """
        pddl_code = self.preprocess_pddl(data['domain'])
        data['result'] = self.validate_str(pddl_code)
        return {
            "file": data['file'],
            "domain": data['domain'],
            "result": data['result'],
            'rounds': data["rounds"],
            "response_id": data["response_id"],
        }

    def validate_bc(self, data):
        """Validate one record and append its output record to ``self.results``.

        Records need the keys ``file``, ``domain``, ``rounds`` and
        ``response_id``.
        """
        self.results.append(self._bc_result(data))

    def multi_run(self):
        """Validate all records concurrently, as ``validate_bc`` does.

        The new result records are appended to ``self.results`` in input order
        rather than in thread completion order.
        """
        data = self.pddl_file
        cores_num = multiprocessing.cpu_count()
        with ThreadPoolExecutor(max_workers=cores_num) as executor:
            ordered = list(tqdm(executor.map(self._bc_result, data), total=len(data)))
        self.results.extend(ordered)
        
def parse_args():
    """Parse the command line and return the resulting namespace.

    Two optional string options are recognised, ``--input_file`` and
    ``--output_file``; each falls back to a built-in default path.
    """
    option_specs = (
        ("--input_file", '/lustre/fast/fast/txiao/zly/spatial_head/cot/result/textgrad/coder8_4_nl.json', "Input file"),
        ("--output_file", '/lustre/fast/fast/txiao/zly/spatial_head/cot/result/Val/textgrad/coder8_4_nl.json', "Output file"),
    )
    cli = argparse.ArgumentParser()
    for flag, default_path, help_text in option_specs:
        cli.add_argument(flag, type=str, default=default_path, help=help_text)
    return cli.parse_args()

if __name__ == "__main__":
    # Default pipeline: validate every input record concurrently with
    # multi_run_1(), then write the accumulated results to the output file.
    # Other entry points on Validator (gen_bad, validateprob, multi_run,
    # save_results) can be swapped in here for other kinds of input.
    cli_args = parse_args()
    runner = Validator(cli_args)
    runner.multi_run_1()
    runner.save()

