import re
import os
import sympy
import pandas as pd
import copy
from tot.tasks.base import Task, DATA_PATH
import sys
sys.path.append('/export/home/code/lam_mcts/tot')
from utils import *
from prompts import *
import json


def get_gsm8k_dataset(dataset_name='gsm8k',
                      path='/export/home/code/REX_ref/gsm8k/gsm_test_all.jsonl'):
    """Load GSM8K problems from a JSONL file.

    Every non-blank line of the file must be a JSON object with a
    ``question`` and an ``answer`` key. Calculator annotations of the form
    ``<<...>>`` are removed from the answer text. The answer is then split
    on newlines: all lines but the last are the intermediate steps, and the
    text after ``####`` on the last line is the final answer.

    Args:
        dataset_name: Currently unused; kept so existing callers keep working.
        path: Location of the JSONL file. Defaults to the path this loader
            has always read from.

    Returns:
        A list with one ``[question, steps, final_answer]`` list per record,
        where ``steps`` is a list of strings and ``final_answer`` is a
        whitespace-stripped string.

    Raises:
        KeyError: If a record lacks ``question`` or ``answer``.
        IndexError: If the last line of an answer contains no ``####``.
    """
    # Blank lines (typically a trailing newline at end of file) are not valid
    # JSON documents, so skip them instead of letting json.loads fail.
    with open(path, encoding='utf-8') as f:
        records = [json.loads(line) for line in f if line.strip()]

    calc_annotation = re.compile('<<.*?>>')

    dataset = []
    for record in records:
        question = record['question']
        cleaned = calc_annotation.sub('', record['answer'])
        # rstrip so that trailing newlines do not push the '####' line out of
        # the last position.
        answer_lines = cleaned.rstrip().split('\n')
        steps = answer_lines[:-1]
        final_ans = answer_lines[-1].split('####')[1].strip()
        dataset.append([question, steps, final_ans])

    return dataset

class GSM8K(Task):
    """Task wrapper around the GSM8K problems returned by ``get_gsm8k_dataset``.

    The prompt-building static methods delegate to ``prompt_*`` helpers that
    are brought into scope by the star imports at the top of this file.
    """

    def __init__(self):
        super().__init__()
        # Each entry is [question, intermediate_steps, final_answer].
        self.data = get_gsm8k_dataset()
        self.value_cache = {}
        # The attributes below are only initialised here; nothing else in this
        # class reads or writes them.
        # NOTE(review): presumably used by external search code -- confirm.
        self.steps = None
        self.stops = None
        self.x = None
        self.f = None
        self.item = None
        self.final_block_config = None

    def __len__(self) -> int:
        """Return the number of loaded problems."""
        return len(self.data)

    def get_input(self, idx: int) -> tuple:
        """Select problem ``idx`` and return its fields.

        As a side effect the selected fields are stored on the instance as
        ``self.question``, ``self.int_steps`` and ``self.answer``.

        Args:
            idx: Index into ``self.data``.

        Returns:
            The tuple ``(question, int_steps, answer)``.
        """
        item = self.data[idx]
        self.question = item[0]
        self.int_steps = item[1]
        self.answer = item[2]
        return self.question, self.int_steps, self.answer

    @staticmethod
    def standard_prompt_wrap(x: str, y: str = '') -> str:
        """Not supported for this task; always raises ``NotImplementedError``."""
        raise NotImplementedError("standard_prompt_wrap not implemented")

    @staticmethod
    def cot_prompt_wrap(x: str, f: str, y: str = '') -> str:
        """Return ``prompt_without_history_v2(x, f)`` with ``y`` appended."""
        return prompt_without_history_v2(x, f) + y

    @staticmethod
    def propose_prompt_wrap(q: str, y: str = '') -> str:
        """Return the proposal prompt built by ``prompt_with_history_gsm8k_propose``."""
        return prompt_with_history_gsm8k_propose(q, y)

    @staticmethod
    def value_prompt_wrap(q: str, y: str) -> str:
        """Return the evaluation prompt built by ``prompt_with_history_gsm8k_value``."""
        return prompt_with_history_gsm8k_value(q, y)

    @staticmethod
    def value_outputs_unwrap(x: str, y: str, value_outputs: list) -> float:
        """Turn a list of evaluator outputs into a single numeric score.

        The last line of each output is taken as its verdict. Verdicts that
        equal one of the known labels contribute that label's weight; any
        other verdict contributes nothing.

        Args:
            x: Unused; kept for interface compatibility.
            y: Unused; kept for interface compatibility.
            value_outputs: Raw evaluator output strings.

        Returns:
            The weighted sum of recognised verdicts (``0`` when none match).
        """
        value_map = {'impossible': 0.001, 'likely': 1, 'sure': 20}  # TODO: ad hoc

        # Strip before taking the last line: an output that ends with a newline
        # or trailing spaces would otherwise yield an empty or padded verdict
        # and silently drop out of the score.
        verdicts = [
            output.strip().split('\n')[-1].strip() for output in value_outputs
        ]
        return sum(
            weight * verdicts.count(label) for label, weight in value_map.items()
        )
        
        
        
        