import re
import os
import sympy
import pandas as pd
import copy
from tot.tasks.base import Task, DATA_PATH
from prompts import prompt_without_history_v2, prompt_without_history_v2_propose_prompt
import sys
sys.path.append('/export/home/code/lam_mcts/tot')
from utils import *
from prompts import *

class Blocksworld(Task):
    """Blocksworld planning task built on the ``Task`` base class.

    Each data item provides an initial block configuration
    (``'real_problem'`` together with ``'participating_blocks'``) and a
    ground-truth plan (``'real_solution'``). ``get_input`` replays that plan
    to derive the goal configuration. The ``*_prompt_wrap`` helpers build
    prompts from the textual initial state ``x``, goal state ``f`` and
    partial answer ``y``.
    """

    def __init__(self, step_count):
        """Load the Blocksworld problems via ``get_blocksworld_data(step_count)``.

        The per-problem attributes (``x``, ``f``, ``item``,
        ``final_block_config``) start as ``None`` and are populated by
        ``get_input``.
        """
        super().__init__()
        self.data = get_blocksworld_data(step_count)
        # Not filled in this class; presumably used by the search code — verify.
        self.value_cache = {}
        self.steps = step_count
        self.stops = None
        # Set by get_input() for the currently selected problem.
        self.x = None  # textual initial state
        self.f = None  # textual goal (final) state
        self.item = None  # raw data item
        self.final_block_config = None  # goal state as returned by the JSON-state helpers

    def __len__(self) -> int:
        """Return the number of loaded problems."""
        return len(self.data)

    def get_input(self, idx: int) -> tuple:
        """Select problem ``idx`` and return ``(initial_state, goal_state)`` as text.

        The goal state is obtained by applying the ground-truth action
        sequence from ``item['real_solution']`` to the initial state.
        Side effects: sets ``self.item``, ``self.x``, ``self.f`` and
        ``self.final_block_config``.

        Raises:
            ValueError: if ``add_action_to_json_state`` rejects an action of
                the ground-truth plan, i.e. the goal state cannot be reached.
        """
        item = self.data[idx]
        self.item = item
        init_block_config = state_text2json(item['real_problem'], item['participating_blocks'])
        # Deep-copy in case add_action_to_json_state mutates its input, so the
        # initial configuration stays intact while the plan is replayed.
        final_block_config = copy.deepcopy(init_block_config)
        for action in real_solution2text(item['real_solution']):
            final_block_config, valid_action = add_action_to_json_state(final_block_config, action)
            # Explicit raise instead of `assert`: asserts are stripped under `python -O`.
            if not valid_action:
                raise ValueError(
                    f'Cannot reach final block config: invalid action {action!r} '
                    f'in ground-truth plan of problem {idx}'
                )

        self.x = state_json2text(init_block_config)
        self.f = state_json2text(final_block_config)
        self.final_block_config = final_block_config

        return self.x, self.f

    @staticmethod
    def standard_prompt_wrap(x: str, y: str = '') -> str:
        """Not supported for this task; always raises ``NotImplementedError``."""
        raise NotImplementedError("standard_prompt_wrap not implemented")

    @staticmethod
    def cot_prompt_wrap(x: str, f: str, y: str = '') -> str:
        """Return the chain-of-thought prompt for ``(x, f)`` with ``y`` appended."""
        return prompt_without_history_v2(x, f) + y

    @staticmethod
    def propose_prompt_wrap(x: str, f: str, y: str = '') -> str:
        """Return the next-step proposal prompt for state ``x``, goal ``f`` and partial plan ``y``."""
        return prompt_without_history_v2_propose_prompt(x, f, y)

    @staticmethod
    def value_prompt_wrap(x: str, f: str, y: str) -> str:
        """Return the prompt asking for an evaluation of the last step of ``y``."""
        return prompt_without_history_v2_value_last_step_prompt(x, f, y)

    @staticmethod
    def value_outputs_unwrap(x: str, y: str, value_outputs: list) -> float:
        """Aggregate sampled value judgements into a single score.

        The verdict of each output is its last non-empty line. It is compared
        case-insensitively against ``'impossible'`` (0.001), ``'likely'`` (1)
        and ``'sure'`` (20). Verdicts that match none of these contribute
        nothing. ``x`` and ``y`` are unused but kept for interface
        compatibility.
        """
        value_map = {'impossible': 0.001, 'likely': 1, 'sure': 20}  # TODO: ad hoc weights
        # Strip before splitting so a trailing newline does not leave an empty
        # last line, then normalise whitespace/case of the verdict itself.
        verdicts = [output.strip().split('\n')[-1].strip().lower() for output in value_outputs]
        return sum(weight * verdicts.count(name) for name, weight in value_map.items())
        
        
        
        