# Copyright (c) 2024-present, Royal Bank of Canada.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

from typing import List

from ml_collections import ConfigDict

from domains import Domain, PDDLEnv
from evaluation import PlanningEvaluator, PlanRatings
from gpt_client import GPTClient
from pddl_utils import PDDLObj, validate_problem_pddl
from utils import wrap_code, mean, harmonic_mean, extract_code
import prompts
from dataclasses import dataclass
import logging
import pdb
import re
import os

@dataclass
class PlanningStrategy:
    """Settings shared by the generation and refinement loops in this module.

    Attributes:
        turns: Upper bound on the number of refinement rounds a loop runs.
        best_of_n: Number of completions requested per round in the
            action-level planning loop.
        rw_feedback: Forwarded to ``PlanningEvaluator`` as ``rw_feedback``.
        bi_rw_feedback: Forwarded to ``PlanningEvaluator`` as
            ``bi_rw_feedback``.
        vlm_path: Forwarded to ``PlanningEvaluator`` as ``vlm_path``.
        multimodal: Forwarded to ``PlanningEvaluator``; when false the
            initial prompt drops its "and an image observation" wording.
        sequence: Forwarded to ``PlanningEvaluator`` as ``sequence``.
    """

    turns: int = 5
    best_of_n: int = 1
    rw_feedback: bool = True
    bi_rw_feedback: bool = True
    vlm_path: str = ""
    multimodal: bool = False
    sequence: bool = False


# Sampling temperature passed to the chat client when several completions are
# requested for the same turn (see _get_best_of_n_responses).
STOCHASTIC_TEMPERATURE = 0.7
# NOTE(review): no use of this constant is visible in this part of the file.
DETERMINISTIC_TEMPERATURE = 0.0

# System message for the domain-generation chats opened in this module. It
# tells the model to edit the PDDL only through the two python function
# interfaces listed in the message.
SYSTEM_MESSAGE = """You are a helpful assistant, skilled in producing Planning Domain Definition Language (PDDL) code of environments.
You are only allowed to modify the PDDL code using the following two python function interfaces:

```python
add_or_update_predicates(predicates: List[str])
modify_action(action_name: str, new_preconditions: List[str], new_effects: List[str])
```
"""

def generate_one_domain_file(
        problem_translation_candidates: List[str],
        context_domain: Domain,
        target_domain: Domain,
        gpt_client: GPTClient,
        pddl_env: PDDLEnv,
        planning_strategy: PlanningStrategy,
        task_index: int,
        exp_flags: ConfigDict
):
    """Request one domain file from the model and check whether it is valid.

    A one-shot prompt is built from the target domain's natural-language
    description, its PDDL template and the single generated problem file. The
    prompt is sent in a fresh chat and the reply is handed to
    ``PlanningEvaluator.rate_domain_valid``.

    Args:
        problem_translation_candidates: Must hold exactly one generated
            problem PDDL string.
        context_domain: Only ``name`` is read, to choose the in-context
            example (maze example for ``'maze'``, blocks-world otherwise).
        target_domain: Provides the domain description, template, domain
            PDDL and the task tuple (problem PDDL and image path are used).
        gpt_client: Chat client used to open and complete the conversation.
        pddl_env: Passed through to ``PlanningEvaluator``.
        planning_strategy: Supplies the evaluator flags and ``multimodal``.
        task_index: Index handed to ``target_domain.get_task``.
        exp_flags: Passed through to ``PlanningEvaluator``.

    Returns:
        ``(is_domain_valid, conv_id, new_domain_pddl, gpt_output)``.
    """
    print('Starting generating domain file screening')
    assert len(problem_translation_candidates) == 1
    generated_problem = problem_translation_candidates[0]

    domain_nl_block = wrap_code(target_domain.get_domain_nl(), lang='markdown')
    target_domain_pddl = target_domain.get_domain_pddl()
    template_pddl = target_domain.get_domain_template_pddl()
    template_block = wrap_code(template_pddl, lang='pddl')
    # The original inline note says the task's problem PDDL may be None here.
    task = target_domain.get_task(task_index)
    target_problem_pddl, img_path = task[0], task[3]
    generated_problem_block = wrap_code(generated_problem, lang='pddl')
    encoded_image = prompts.encode_image(img_path)

    is_maze = context_domain.name == 'maze'
    example = prompts.MAZE_EXAMPLE if is_maze else prompts.BLOCKS_WORLD_EXAMPLE

    prompt_fields = dict(
        context_domain_name=context_domain.name,
        target_domain_name=target_domain.name,
        context_shot_example=example,
        target_domain_nl=domain_nl_block,
        target_domain_template_pddl=template_block,
        target_problem_pddl=generated_problem_block,
    )
    init_prompt = prompts.ONE_SHOT_INIT_PROMPT_TEMPLATE.format(**prompt_fields)
    if not planning_strategy.multimodal:
        # Text-only runs drop the mention of the image from the prompt.
        init_prompt = init_prompt.replace('and an image observation', '')

    pddl_obj = PDDLObj.from_pddl_str(template_pddl, domain_pddl_template=template_pddl)
    conv_id, _ = gpt_client.make_new_chat(system_message=SYSTEM_MESSAGE)

    # NOTE(review): the conversation id returned by complete_one_chat is
    # discarded; the id returned to the caller is the one from make_new_chat.
    _, gpt_output, _ = gpt_client.complete_one_chat(conv_id, (init_prompt, encoded_image))

    # The two positional ``None`` values mirror the original call; the second
    # one presumably stands in for a predicate descriptor (TODO confirm).
    evaluator = PlanningEvaluator(
        pddl_env, target_domain_pddl, target_problem_pddl, generated_problem, None,
        planning_strategy.rw_feedback, planning_strategy.vlm_path, None,
        exp_flags=exp_flags,
        bi_rw_feedback=planning_strategy.bi_rw_feedback,
        multimodal=planning_strategy.multimodal,
        sequence=planning_strategy.sequence,
    )
    is_domain_valid, new_domain_pddl = evaluator.rate_domain_valid(pddl_obj, gpt_output, img_path)

    for item in (is_domain_valid, new_domain_pddl, gpt_output):
        print(item)
    return is_domain_valid, conv_id, new_domain_pddl, gpt_output

def rate_and_update_files(planning_strategy, pddl_env, target_domain, task_index, conv_id, gpt_client, target_gen_problem_pddl, target_gen_domain_pddl, exp_flags, run_exp_dir, othervlm=False):
    """Iteratively rate a (problem, domain) pair and ask the model to repair both.

    Every round builds a ``PlanningEvaluator`` for the current problem file,
    rates the current pair with ``rate_domain_modification_feedback`` and
    stops as soon as the evaluation reports ``solution_found``. Otherwise a
    new chat is opened, the best pair seen so far plus the evaluator's error
    message are sent to the model, and the "New problem file:" /
    "New domain file:" sections of the reply become the next candidates. The
    prompt and the reply of each round are written to
    ``<run_exp_dir>/update_records/<step>_in.txt`` and ``<step>_out.txt``.

    Args:
        planning_strategy: Supplies ``turns`` and the evaluator flags.
        pddl_env: Passed through to ``PlanningEvaluator``.
        target_domain: Provides the domain PDDL, template, descriptions and
            the task tuple.
        task_index: Index of the task inside ``target_domain``.
        conv_id: Accepted for interface compatibility; a new conversation is
            opened in every repair round.
        gpt_client: Chat client used for the repair rounds.
        target_gen_problem_pddl: Initial generated problem file.
        target_gen_domain_pddl: Initial generated domain code.
        exp_flags: Passed through to ``PlanningEvaluator``.
        run_exp_dir: Directory under which ``update_records`` is created.
        othervlm: ``'4o'`` hands ``gpt_client`` to the evaluator,
            ``'vilau'`` hands it the string ``'vila-u'``; any other value
            hands it ``None``.

    Returns:
        ``(best_generated_problem, new_domain_pddl, first_solve, step,
        vlm_runtime, vlm_numcall)``. ``first_solve`` is true only when the
        first round already found a solution, ``step`` is the last round
        executed (``0`` when ``turns < 1``), and the two lists hold the
        evaluator's ``vlm_runtime`` / ``vlm_numcall`` for each round.
    """
    print('Starting rating and updating both files')
    target_domain_pddl = target_domain.get_domain_pddl()
    target_problem_pddl, target_problem_nl, _, img_path = target_domain.get_task(task_index)
    target_initial_image_observation = prompts.encode_image(img_path)
    target_domain_template_pddl = target_domain.get_domain_template_pddl()
    pddl_obj = PDDLObj.from_pddl_str(target_domain_template_pddl, domain_pddl_template=target_domain_template_pddl)
    turns = planning_strategy.turns

    best_rating = float('-inf')
    best_generated_problem = target_gen_problem_pddl
    best_generated_domain = target_gen_domain_pddl
    best_obj = None
    first_solve = False
    vlm_runtime = []
    vlm_numcall = []

    # The evaluator's client does not change between rounds, so pick it once.
    if othervlm == '4o':
        input_gpt_client = gpt_client
    elif othervlm == 'vilau':
        input_gpt_client = 'vila-u'
    else:
        input_gpt_client = None

    # Keeps ``step`` defined for the return value when no round runs.
    step = 0
    for step in range(1, turns + 1):
        print('starting round ', str(step))
        planning_evaluator = PlanningEvaluator(
            pddl_env, target_domain_pddl, target_problem_pddl, target_gen_problem_pddl, target_problem_nl,
            planning_strategy.rw_feedback, planning_strategy.vlm_path, None,
            exp_flags=exp_flags, bi_rw_feedback=planning_strategy.bi_rw_feedback,
            multimodal=planning_strategy.multimodal, sequence=planning_strategy.sequence,
            gpt_client=input_gpt_client
        )
        # First rate the current (problem, domain) pair.
        planning_evaluation = planning_evaluator.rate_domain_modification_feedback(
            pddl_obj, target_gen_problem_pddl, target_gen_domain_pddl, img_path, step, best_rating
        )
        pddl_obj = planning_evaluation.new_pddl_obj
        err_msg = planning_evaluation.error_msg
        rating = planning_evaluation.rating
        if err_msg is not None and len(err_msg) > 0:
            maybe_error = f"The environment returned the following error:\n\n{err_msg}\n\n"
        else:
            maybe_error = ""
        logging.info(f"Generated Domain Rating: {rating}")
        vlm_runtime.append(planning_evaluator.vlm_runtime)
        vlm_numcall.append(planning_evaluator.vlm_numcall)
        # ``>=`` lets a later pair with an equal rating replace the earlier one.
        if rating >= best_rating:
            print('better rating, update!')
            best_rating = rating
            best_generated_problem = target_gen_problem_pddl
            best_generated_domain = target_gen_domain_pddl
            best_obj = pddl_obj
        if planning_evaluation.solution_found:
            if step == 1:
                print('first met')
                first_solve = True
            print('Solution Found!!')
            break

        system_prompt, few_shot_messages, user_input = prompts.get_two_files_modification_messages(
            target_domain_nl=wrap_code(target_domain.get_domain_nl(), lang='markdown'),
            target_problem_nl=wrap_code(target_domain.get_task_nl(task_index), lang='markdown'),
            target_domain_template=wrap_code(target_domain.get_domain_template_pddl(), lang='pddl'),
            problem_pddl=wrap_code(best_generated_problem, lang='pddl'),
            # The domain is sent as a python block, as in the original call.
            domain_pddl=wrap_code(best_generated_domain, lang='python'),
            execution_feedback=maybe_error,
            problem_img=target_initial_image_observation,
            multimodal=planning_strategy.multimodal
        )
        conv_id, _ = gpt_client.make_new_chat(system_message=system_prompt)
        gpt_client.add_chat_messages(conv_id, few_shot_messages)
        print(user_input[0])
        conv_id, gpt_output, _ = gpt_client.complete_one_chat(conv_id, user_input)
        print(gpt_output)

        # Persist the exchange; UTF-8 is explicit because model output may
        # contain characters the platform default encoding cannot represent.
        os.makedirs(os.path.join(run_exp_dir, 'update_records'), exist_ok=True)
        with open(os.path.join(run_exp_dir, f'update_records/{step}_in.txt'), "w", encoding="utf-8") as file:
            file.write(user_input[0])
        with open(os.path.join(run_exp_dir, f'update_records/{step}_out.txt'), "w", encoding="utf-8") as file:
            file.write(gpt_output)

        # A reply that omits a section marker keeps the current file instead
        # of aborting the whole run with an IndexError.
        problem_sections = gpt_output.split('New problem file:')
        if len(problem_sections) > 1:
            updated_problem_pddl = problem_sections[1].split('New domain file')[0].replace('[', '').replace(']', '')
            if 'N/A' not in updated_problem_pddl:
                try:
                    # NOTE(review): the extracted problem is kept even when
                    # validation raises afterwards; confirm this is intended.
                    target_gen_problem_pddl = extract_code(updated_problem_pddl, lang='pddl')
                    validate_problem_pddl(target_gen_problem_pddl)
                except Exception as e:
                    print(e)
        else:
            logging.warning("Round %s: reply has no 'New problem file:' section; keeping the current problem file.", step)

        domain_sections = gpt_output.split('New domain file:')
        if len(domain_sections) > 1:
            updated_domain_pddl = domain_sections[1]
            if 'N/A' not in updated_domain_pddl:
                target_gen_domain_pddl = updated_domain_pddl
        else:
            logging.warning("Round %s: reply has no 'New domain file:' section; keeping the current domain file.", step)

    if best_obj is None:
        # No round recorded a best object (e.g. ``turns < 1``); fall back to
        # the object parsed from the template.
        best_obj = pddl_obj
    new_domain_pddl = best_obj.to_str()
    return best_generated_problem, new_domain_pddl, first_solve, step, vlm_runtime, vlm_numcall

def rate_and_update_files_nofeedback(planning_strategy, pddl_env, target_domain, task_index, conv_id, gpt_client, target_gen_problem_pddl, target_gen_domain_pddl, exp_flags, run_exp_dir):
    """Iteratively rate and repair a (problem, domain) pair without feedback rating.

    Every round builds a ``PlanningEvaluator`` for the current problem file,
    rates the current pair with ``rate_domain_modification_nofeedback`` and
    stops as soon as the rating equals ``1``. Otherwise a new chat is opened,
    the best pair seen so far plus the evaluator's raw error message are sent
    to the model, and the "New problem file:" / "New domain file:" sections
    of the reply become the next candidates. The prompt and the reply of each
    round are written to ``<run_exp_dir>/update_records/<step>_in.txt`` and
    ``<step>_out.txt``.

    Args:
        planning_strategy: Supplies ``turns`` and the evaluator flags.
        pddl_env: Passed through to ``PlanningEvaluator``.
        target_domain: Provides the domain PDDL, template, descriptions and
            the task tuple.
        task_index: Index of the task inside ``target_domain``.
        conv_id: Accepted for interface compatibility; a new conversation is
            opened in every repair round.
        gpt_client: Chat client used for the repair rounds.
        target_gen_problem_pddl: Initial generated problem file.
        target_gen_domain_pddl: Initial generated domain code.
        exp_flags: Passed through to ``PlanningEvaluator``.
        run_exp_dir: Directory under which ``update_records`` is created.

    Returns:
        ``(best_generated_problem, new_domain_pddl, first_solve, step,
        vlm_runtime, vlm_numcall)``. ``first_solve`` is true only when the
        first round already reached rating ``1``, ``step`` is the last round
        executed (``0`` when ``turns < 1``), and the two lists hold the
        evaluator's ``vlm_runtime`` / ``vlm_numcall`` for each round.
    """
    print('Starting rating and updating both files')
    target_domain_pddl = target_domain.get_domain_pddl()
    target_problem_pddl, target_problem_nl, _, img_path = target_domain.get_task(task_index)
    target_initial_image_observation = prompts.encode_image(img_path)
    target_domain_template_pddl = target_domain.get_domain_template_pddl()
    pddl_obj = PDDLObj.from_pddl_str(target_domain_template_pddl, domain_pddl_template=target_domain_template_pddl)
    turns = planning_strategy.turns

    best_rating = float('-inf')
    best_generated_problem = target_gen_problem_pddl
    best_generated_domain = target_gen_domain_pddl
    best_obj = None
    first_solve = False
    vlm_runtime = []
    vlm_numcall = []

    # Keeps ``step`` defined for the return value when no round runs.
    step = 0
    for step in range(1, turns + 1):
        print('starting round ', str(step))
        planning_evaluator = PlanningEvaluator(
            pddl_env, target_domain_pddl, target_problem_pddl, target_gen_problem_pddl, target_problem_nl,
            planning_strategy.rw_feedback, planning_strategy.vlm_path, None,
            exp_flags=exp_flags, bi_rw_feedback=planning_strategy.bi_rw_feedback,
            multimodal=planning_strategy.multimodal, sequence=planning_strategy.sequence
        )
        # First rate the current (problem, domain) pair.
        planning_evaluation = planning_evaluator.rate_domain_modification_nofeedback(
            pddl_obj, target_gen_problem_pddl, target_gen_domain_pddl, img_path, step, best_rating
        )
        pddl_obj = planning_evaluation.new_pddl_obj
        err_msg = planning_evaluation.error_msg
        rating = planning_evaluation.rating
        # Unlike the feedback variant, the error text is forwarded unwrapped.
        if err_msg is not None and len(err_msg) > 0:
            maybe_error = err_msg
        else:
            maybe_error = ""
        logging.info(f"Generated Domain Rating: {rating}")
        vlm_runtime.append(planning_evaluator.vlm_runtime)
        vlm_numcall.append(planning_evaluator.vlm_numcall)
        # ``>=`` lets a later pair with an equal rating replace the earlier one.
        if rating >= best_rating:
            print('better rating, update!')
            best_rating = rating
            best_generated_problem = target_gen_problem_pddl
            best_generated_domain = target_gen_domain_pddl
            best_obj = pddl_obj
        if rating == 1:
            if step == 1:
                print('first met')
                first_solve = True
            print('Solution Found!!')
            break

        system_prompt, few_shot_messages, user_input = prompts.get_two_files_modification_messages(
            target_domain_nl=wrap_code(target_domain.get_domain_nl(), lang='markdown'),
            target_problem_nl=wrap_code(target_domain.get_task_nl(task_index), lang='markdown'),
            target_domain_template=wrap_code(target_domain.get_domain_template_pddl(), lang='pddl'),
            problem_pddl=wrap_code(best_generated_problem, lang='pddl'),
            # The domain is sent as a python block, as in the original call.
            domain_pddl=wrap_code(best_generated_domain, lang='python'),
            execution_feedback=maybe_error,
            problem_img=target_initial_image_observation,
            multimodal=planning_strategy.multimodal
        )
        conv_id, _ = gpt_client.make_new_chat(system_message=system_prompt)
        gpt_client.add_chat_messages(conv_id, few_shot_messages)
        print(user_input[0])
        conv_id, gpt_output, _ = gpt_client.complete_one_chat(conv_id, user_input)
        print(gpt_output)

        # Persist the exchange; UTF-8 is explicit because model output may
        # contain characters the platform default encoding cannot represent.
        os.makedirs(os.path.join(run_exp_dir, 'update_records'), exist_ok=True)
        with open(os.path.join(run_exp_dir, f'update_records/{step}_in.txt'), "w", encoding="utf-8") as file:
            file.write(user_input[0])
        with open(os.path.join(run_exp_dir, f'update_records/{step}_out.txt'), "w", encoding="utf-8") as file:
            file.write(gpt_output)

        # A reply that omits a section marker keeps the current file instead
        # of aborting the whole run with an IndexError.
        problem_sections = gpt_output.split('New problem file:')
        if len(problem_sections) > 1:
            updated_problem_pddl = problem_sections[1].split('New domain file')[0].replace('[', '').replace(']', '')
            if 'N/A' not in updated_problem_pddl:
                try:
                    # NOTE(review): the extracted problem is kept even when
                    # validation raises afterwards; confirm this is intended.
                    target_gen_problem_pddl = extract_code(updated_problem_pddl, lang='pddl')
                    validate_problem_pddl(target_gen_problem_pddl)
                except Exception as e:
                    print(e)
        else:
            logging.warning("Round %s: reply has no 'New problem file:' section; keeping the current problem file.", step)

        domain_sections = gpt_output.split('New domain file:')
        if len(domain_sections) > 1:
            updated_domain_pddl = domain_sections[1]
            if 'N/A' not in updated_domain_pddl:
                target_gen_domain_pddl = updated_domain_pddl
        else:
            logging.warning("Round %s: reply has no 'New domain file:' section; keeping the current domain file.", step)

    if best_obj is None:
        # No round recorded a best object (e.g. ``turns < 1``); fall back to
        # the object parsed from the template.
        best_obj = pddl_obj
    new_domain_pddl = best_obj.to_str()
    return best_generated_problem, new_domain_pddl, first_solve, step, vlm_runtime, vlm_numcall
    
def evaluate_planning_on_problem_candidates(
        problem_translation_candidates: List[str],
        context_domain: Domain,
        target_domain: Domain,
        gpt_client: GPTClient,
        pddl_env: PDDLEnv,
        planning_strategy: PlanningStrategy,
        task_index: int,
        exp_flags: ConfigDict
):
    """Run action-level planning for each problem candidate and keep the best.

    Candidates are evaluated in order with ``evaluate_action_level_planning``;
    the loop ends early once a candidate's rating equals
    ``PlanRatings.SOLUTION_FOUND``.

    Returns:
        ``(all_ratings, return_params, all_aux)``. ``all_ratings`` lists the
        best rating of every evaluated candidate. ``return_params`` is
        ``(best_rating, best_generated_pddl, candidate)`` for the strictly
        highest-rated candidate, or ``None`` if none was recorded.
        ``all_aux`` carries the per-candidate aux dicts under
        ``'problem_candidates_aux'`` and the winner's index under
        ``'best_candidate_idx'`` (``-1`` when there is no winner).
    """
    ratings_so_far = []
    top_rating = float('-inf')
    winner = None
    all_aux = {'problem_candidates_aux': [], 'best_candidate_idx': -1}
    total = len(problem_translation_candidates)

    for idx, candidate in enumerate(problem_translation_candidates):
        # "<position>/<total>" label shared by every log line of this round.
        tag = f"{idx + 1}/{total}"
        logging.info(f"Evaluating candidate {tag}")
        candidate_rating, candidate_domain_pddl, candidate_aux = evaluate_action_level_planning(
            context_domain=context_domain,
            target_domain=target_domain,
            target_gen_problem_pddl=candidate,
            gpt_client=gpt_client,
            pddl_env=pddl_env,
            planning_strategy=planning_strategy,
            task_index=task_index,
            exp_flags=exp_flags
        )
        candidate_aux['gen_problem_pddl'] = candidate
        logging.info(f"Best rating for candidate {tag}: {candidate_rating}")
        logging.info(f"Candidate {tag}: {candidate}")
        logging.info(f"Best generated PDDL for candidate {tag}: {candidate_domain_pddl}")

        # Strict comparison: on a tie the earlier candidate stays the winner.
        if candidate_rating > top_rating:
            top_rating = candidate_rating
            winner = (candidate_rating, candidate_domain_pddl, candidate)
            all_aux['best_candidate_idx'] = idx

        all_aux['problem_candidates_aux'].append(candidate_aux)
        ratings_so_far.append(candidate_rating)

        if candidate_rating != PlanRatings.SOLUTION_FOUND:
            continue
        logging.info(f"Solution found for candidate {tag}")
        logging.info("Stopping early since a solution was found.")
        break

    return ratings_so_far, winner, all_aux


def evaluate_action_level_planning(
        context_domain: Domain,
        target_domain: Domain,
        target_gen_problem_pddl: str,
        gpt_client: GPTClient,
        pddl_env: PDDLEnv,
        planning_strategy: PlanningStrategy,
        task_index: int,
        exp_flags: ConfigDict
):
    """Refine a domain file over several chat turns for one generated problem.

    The first turn sends the one-shot initial prompt; each later turn sends
    the evaluator's error (if any) together with the current domain PDDL and
    asks for new code. The loop ends after ``planning_strategy.turns`` turns
    or as soon as an evaluation reports ``solution_found``.

    Returns:
        ``(best_rating, best_generated_pddl, aux)`` where ``aux`` holds
        ``best_conv_id``, ``best_rating`` and ``best_generated_domain_pddl``.
        When no turn improves on the initial state the rating is ``-inf``
        and both the PDDL and the conversation id are empty strings.
    """
    domain_nl_block = wrap_code(target_domain.get_domain_nl(), lang='markdown')
    # The original inline notes say the domain PDDL is "not provided in our
    # case" and that the task's problem PDDL may be None.
    target_domain_pddl = target_domain.get_domain_pddl()
    template_pddl = target_domain.get_domain_template_pddl()
    template_block = wrap_code(template_pddl, lang='pddl')
    target_problem_pddl, target_problem_nl, _, img_path = target_domain.get_task(task_index)
    problem_block = wrap_code(target_gen_problem_pddl, lang='pddl')
    encoded_image = prompts.encode_image(img_path)

    is_maze = context_domain.name == 'maze'
    example = prompts.MAZE_EXAMPLE if is_maze else prompts.BLOCKS_WORLD_EXAMPLE

    init_prompt = prompts.ONE_SHOT_INIT_PROMPT_TEMPLATE.format(
        context_domain_name=context_domain.name,
        target_domain_name=target_domain.name,
        context_shot_example=example,
        target_domain_nl=domain_nl_block,
        target_domain_template_pddl=template_block,
        target_problem_pddl=problem_block,
    )
    if not planning_strategy.multimodal:
        # Text-only runs drop the mention of the image from the prompt.
        init_prompt = init_prompt.replace('and an image observation', '')

    pddl_obj = PDDLObj.from_pddl_str(template_pddl, domain_pddl_template=template_pddl)
    planning_evaluator = PlanningEvaluator(
        pddl_env, target_domain_pddl, target_problem_pddl, target_gen_problem_pddl, target_problem_nl,
        planning_strategy.rw_feedback, planning_strategy.vlm_path, None,
        exp_flags=exp_flags,
        bi_rw_feedback=planning_strategy.bi_rw_feedback,
        multimodal=planning_strategy.multimodal,
        sequence=planning_strategy.sequence,
    )

    best_rating = float('-inf')
    best_generated_pddl = ""
    best_conv_id = ""
    conv_id, _ = gpt_client.make_new_chat(system_message=SYSTEM_MESSAGE)
    user_input = (init_prompt, encoded_image)

    for step in range(1, planning_strategy.turns + 1):
        conv_id, planning_evaluation, _ = _get_best_of_n_responses(
            gpt_client, planning_evaluator, pddl_obj, conv_id, user_input,
            planning_strategy.best_of_n, img_path, step
        )
        pddl_obj = planning_evaluation.new_pddl_obj
        current_domain_pddl = pddl_obj.to_str()
        err_msg = planning_evaluation.error_msg
        rating = planning_evaluation.rating

        has_error = err_msg is not None and len(err_msg) > 0
        maybe_error = (
            f"The environment returned the following error:\n\n{err_msg}\n\n" if has_error else ""
        )
        logging.info(f"Generated Domain Rating: {rating}")

        # Strict comparison: an equal rating does not replace the earlier best.
        if rating > best_rating:
            best_rating, best_generated_pddl, best_conv_id = rating, current_domain_pddl, conv_id
        if planning_evaluation.solution_found:
            break

        # Follow-up prompt for the next turn, continuing the same conversation.
        user_input = (
            f"Incorrect. {maybe_error}Please reason about the issue with your generated code. "
            f"The current domain pddl is as follows:\n\n{wrap_code(current_domain_pddl, lang='pddl')}\n\n"
            "In your response, please generate a new code to fix the issue.",
            encoded_image,
        )

    aux = {
        "best_conv_id": best_conv_id,
        "best_rating": best_rating,
        "best_generated_domain_pddl": best_generated_pddl,
    }
    logging.info(f"Best rating: {best_rating} with conversation id: {best_conv_id}")
    return best_rating, best_generated_pddl, aux


def evaluate_all_tasks(
        pddl_env: PDDLEnv,
        target_domain_pddl: str,
        target_domain_problem_pddls: List[str],
        target_gen_domain_pddl: str,
        target_gen_problem_pddls: List[str],
        other_task_image_paths:List[str],
        vlm_path: str,
        multimodal: bool,
        sequence: bool,
        exp_flags: ConfigDict,
):
    """Evaluate the generated domain on all tasks via plan generation.

    Every argument is forwarded unchanged, in the same order, to
    ``_evaluate_all_tasks_plan_gen`` and its result is returned as is. The
    random-walk evaluation is not invoked from here.
    """
    return _evaluate_all_tasks_plan_gen(
        pddl_env,
        target_domain_pddl,
        target_domain_problem_pddls,
        target_gen_domain_pddl,
        target_gen_problem_pddls,
        other_task_image_paths,
        vlm_path,
        multimodal,
        sequence,
        exp_flags,
    )
    
def _evaluate_all_tasks_plan_gen(
        pddl_env: PDDLEnv,
        target_domain_pddl: str,
        target_domain_problem_pddls: List[str],
        target_gen_domain_pddl: str,
        target_gen_problem_pddls: List[str],
        other_task_image_paths,
        vlm_path: str,
        multimodal: bool,
        sequence: bool,
        exp_flags: ConfigDict,
):
    assert len(target_domain_problem_pddls) == len(target_gen_problem_pddls)
    n_valids = 0
    valid = []
    count = 1
    for t_gen_p, t_gt_p in zip(target_gen_problem_pddls, target_domain_problem_pddls):
        print('evaluating task')
        gen_plan, is_domain_valid, error_msg = pddl_env.search_plan(target_gen_domain_pddl, t_gen_p)
        print(error_msg)
        # pdb.set_trace()
        if gen_plan is not None:
            print('generated plan')
            print(gen_plan)
            # t_gt_p = (re.sub(r'pos-(\d+)-(\d+)', r'pos-\2-\1', t_gt_p))
            if 'maze' in target_domain_pddl:
                gen_plan_convert = gen_plan.replace('pos', 'loc')
            elif 'sokoban' in target_domain_pddl:
                gen_plan_convert = gen_plan.replace('box-1', 'stone-01').replace('box-2', 'stone-02').replace(' left', ' dir-left').replace(' right', ' dir-right').replace(' up', ' dir-up').replace(' down', ' dir-down').replace('move ', 'move player-01 ').replace('push-to-goal ', 'push-to-goal player-01 ').replace('push-to-nongoal ', 'push-to-nongoal player-01 ')
                gen_plan_convert = gen_plan_convert.replace('east', 'right').replace('west', 'left').replace('south', 'down').replace('north', 'up')
                gen_plan_convert = re.sub(r'pos-(\d+)-(\d+)', r'pos-\2-\1', gen_plan_convert)
                print('converted plan', gen_plan_convert)
            elif 'package' in target_domain_pddl:
                gen_plan_convert = gen_plan.replace('pkg1', 'pkg-1').replace('pkg2', 'pkg-2').replace('package1', 'pkg-1').replace('package2', 'pkg-2')
            elif 'overcooked' in target_domain_pddl:
                gen_plan_convert = gen_plan.replace('tomato ', 'tomato1 ').replace('lettuce ', 'lettuce1 ').replace('board1 ', 'chopping-board1 ').replace('onion ', 'onion1 ')
            else:
                gen_plan_convert = gen_plan
            print(gen_plan_convert)
            is_plan_valid, val_message = pddl_env.validate_plan(target_domain_pddl, t_gt_p, gen_plan_convert)
            print(val_message)
            # pdb.set_trace()
            if is_plan_valid:
                print('plan valid')
                n_valids += 1
                valid.append(count)
        count+=1
    return (n_valids / len(target_domain_problem_pddls)), valid


def evaluate_all_tasks_random_walk(
        pddl_env: PDDLEnv,
        target_domain_pddl: str,
        target_domain_problem_pddls: List[str],
        target_gen_domain_pddl: str,
        target_gen_problem_pddls: List[str],
        other_task_image_paths,
        vlm_path: str,
        multimodal: bool,
        sequence: bool,
        exp_flags: ConfigDict,
):
    """Score the generated domain on all tasks with random-walk evaluation.

    One ``PlanningEvaluator`` is built per (generated problem, target
    problem, image path) triple and asked to run
    ``evaluate_generated_domain_with_random_walks``. The two fractions it
    returns are averaged over the tasks with ``mean`` and combined with
    ``harmonic_mean``.

    Returns:
        ``(final_score, t_to_gen_frac, gen_to_t_frac)``: the harmonic mean
        followed by the two per-direction averages.

    Raises:
        AssertionError: If the two problem lists differ in length.
    """
    assert len(target_domain_problem_pddls) == len(target_gen_problem_pddls)

    # Placeholder predicate descriptor handed to every evaluator.
    dummy_pred_desc = 'def describe_predicate(*args, **kwargs): return ("", "")'
    per_task_t_to_gen = []
    per_task_gen_to_t = []

    tasks = zip(target_gen_problem_pddls, target_domain_problem_pddls, other_task_image_paths)
    for generated_problem, target_problem, image_path in tasks:
        evaluator = PlanningEvaluator(
            env=pddl_env,
            target_domain_pddl=target_domain_pddl,
            target_problem_pddl=target_problem,
            target_gen_problem_pddl=generated_problem,
            rw_feedback=True,
            vlm_path=vlm_path,
            predicate_descriptor_py=dummy_pred_desc,
            exp_flags=exp_flags,
            bi_rw_feedback=True,
            multimodal=multimodal,
            sequence=sequence,
        )
        outcome = evaluator.evaluate_generated_domain_with_random_walks(
            target_gen_domain_pddl, image_path, 0
        )
        # The first element of the result is not used here.
        _, task_t_to_gen, task_gen_to_t = outcome
        per_task_t_to_gen.append(task_t_to_gen)
        per_task_gen_to_t.append(task_gen_to_t)

    avg_t_to_gen = mean(per_task_t_to_gen)
    avg_gen_to_t = mean(per_task_gen_to_t)
    final_score = harmonic_mean(avg_t_to_gen, avg_gen_to_t)
    logging.info(f"Random walk scores on all tasks: {final_score}")

    return final_score, avg_t_to_gen, avg_gen_to_t


def _get_best_of_n_responses(gpt_client, planning_evaluator, pddl_obj, conv_id, user_input, n_completions, img_path, turn):
    if n_completions == 1:
        best_conv_id, gpt_output, _ = gpt_client.complete_one_chat(conv_id, user_input)
        # pdb.set_trace()
        planning_evaluation = planning_evaluator.rate_domain_modification(
            pddl_obj, gpt_output, img_path, turn
        )
        return best_conv_id, planning_evaluation, {"all_conv_ids": [best_conv_id],
                                                   "all_ratings": [planning_evaluation.rating]}
    else:
        conv_ids, gpt_outputs, _ = gpt_client.complete_n_chats(
            conv_id, user_input, n_completions, temp=STOCHASTIC_TEMPERATURE
        )
        all_evaluations = []
        best_evaluation = None
        best_conv_id = None
        for i in range(n_completions):
            gpt_output = gpt_outputs[i]
            planning_evaluation = planning_evaluator.rate_domain_modification(
                pddl_obj, gpt_output, img_path, turn
            )
            all_evaluations.append(planning_evaluation)
            logging.info(f"Rating for completion {i}: {planning_evaluation.rating}")
            if best_evaluation is None or planning_evaluation.rating > best_evaluation.rating:
                best_evaluation = planning_evaluation
                best_conv_id = conv_ids[i]
        return best_conv_id, best_evaluation, {"all_conv_ids": conv_ids,
                                               "all_ratings": [e.rating for e in all_evaluations]}
