import tiktoken
# Default tokenizer for truncate_scratchpad, resolved once at import time.
# NOTE(review): despite the name, this is whatever encoding tiktoken maps
# "text-davinci-003" to (not necessarily the original GPT-2 BPE) — confirm.
gpt2_enc = tiktoken.encoding_for_model("text-davinci-003")


def truncate_scratchpad(scratchpad: str, n_tokens: int = 1600, tokenizer=gpt2_enc) -> str:
    """Shrink a scratchpad under a token budget by eliding its largest observations.

    Lines that start with ``'Observation'`` are replaced, largest first (by
    token count), with ``'<prefix>: [truncated wikipedia excerpt]'``, where
    ``<prefix>`` is the text before the line's first ``':'``. This repeats
    until the joined scratchpad fits in ``n_tokens`` or no observation lines
    remain.

    Args:
        scratchpad: Newline-separated agent transcript.
        n_tokens: Maximum number of tokens allowed in the result.
        tokenizer: Object with an ``encode(str)`` method. It is used both to
            rank observations and to measure the total length, so the budget
            is expressed in this tokenizer's units.

    Returns:
        The possibly truncated scratchpad. If the text is still over budget
        after every observation line has been elided, that best-effort result
        is returned instead of raising.
    """
    lines = scratchpad.split('\n')
    # Indices of observation lines, sorted so the largest is last (cheap pop()).
    candidates = sorted(
        (i for i, line in enumerate(lines) if line.startswith('Observation')),
        key=lambda i: len(tokenizer.encode(lines[i])),
    )
    # Measure with the same tokenizer used for ranking. Stop once nothing is
    # left to elide, because popping an empty list would raise.
    while candidates and len(tokenizer.encode('\n'.join(lines))) > n_tokens:
        idx = candidates.pop()
        lines[idx] = lines[idx].split(':')[0] + ': [truncated wikipedia excerpt]'
    return '\n'.join(lines)


class Direct:
    """Single-shot agent: sends the prompt to the LLM once and returns its reply.

    ``run`` returns the same tuple shape as the planner classes in this module,
    ``(answer, scratchpad, json_log, steps)``, with ``steps`` always 1.
    """

    def __init__(self, llm_model, env) -> None:
        """Store the LLM callable and initialise empty state.

        Args:
            llm_model: Callable invoked with one chat-message dict; its return
                value is used as the answer.
            env: Accepted for signature compatibility with the planners; not
                used by this agent.
        """
        self.llm = llm_model
        self.answer: str = ""
        self.scratchpad: str = ""
        self.json_log = []
        self.is_succ = False
        print("DirectAgent Loaded")

    def run(self, prompt, reset=True):
        """Ask the LLM ``prompt`` once and record the exchange.

        Args:
            prompt: User message text.
            reset: Clear any previous answer/scratchpad/log first.

        Returns:
            ``(answer, scratchpad, json_log, 1)``. The step count is 1 because
            exactly one LLM call is made.
        """
        if reset:
            self.reset()
        print("Direct Question:", prompt)
        # NOTE(review): the planners in this module pass a *list* of messages to
        # their llm; here a single dict is passed. Confirm llm_model accepts both.
        answer = self.llm({"role": "user", "content": prompt})
        self.scratchpad += f'\nAnswer: {answer}'
        self.json_log.append({"role": "user", "content": prompt})
        self.json_log.append({"role": "assistant", "content": answer})
        self.answer = answer
        # Any reply counts as success; no validation is performed on it.
        self.is_succ = True
        return answer, self.scratchpad, self.json_log, 1

    def reset(self) -> None:
        """Forget the previous answer, transcript and success flag."""
        self.answer = ""
        self.scratchpad = ""
        self.json_log = []
        self.is_succ = False

    def is_success(self) -> bool:
        """Return True once ``run`` has produced an answer since the last reset."""
        return self.is_succ


class ReactPlanner:
    """ReAct-style planner alternating LLM Thought/Action turns with env observations.

    Each ``step`` asks the LLM for a thought, then for an action. The action is
    passed to ``env.run`` and the stringified result becomes the observation.
    ``run`` repeats steps until ``env.is_finished()`` is true or ``max_steps``
    steps have been executed.

    The transcript is kept in three parallel forms:
      * ``scratchpad``: plain text with ``Thought[i]:``, ``Action[i]:`` and
        ``Observation[i]:`` headers.
      * ``json_scratchpad``: chat messages passed to the LLM on every call.
      * ``json_log``: one single-key dict per turn, for logging.
    """

    def __init__(self,
                 llm_model,
                 env,
                 max_steps=30,
                 need_print=True
                 ) -> None:
        """Configure the planner.

        Args:
            llm_model: Callable taking the list of chat messages so far and
                returning the model's reply.
            env: Object providing ``reset()``, ``run(action)``,
                ``is_finished()``, ``is_success()`` and ``get_ans()``.
            max_steps: Maximum number of Thought/Action/Observation steps per run.
            need_print: Echo every turn to stdout when True.
        """
        self.llm = llm_model
        self.env = env
        self.max_steps = max_steps
        # Not read by the active code in this class; only the disabled
        # token-budget check described in is_halted would use it.
        self.enc = tiktoken.encoding_for_model("gpt-3.5-turbo")
        self.prompt: str = ""
        self.scratchpad: str = ""
        self.curr_step: int = 1
        # Also (re)set by reset(). Defined here so the attribute exists before
        # the first reset.
        self.finished = False
        self.answer: str = ""
        self.json_log = []
        self.json_scratchpad = []
        self.need_print = need_print

    def run(self, prompt, reset=True):
        """Run steps until the env reports completion or the step budget is spent.

        Args:
            prompt: Task description, sent as the first user message.
            reset: Clear all state (including ``env.reset()``) before starting.

        Returns:
            ``(answer, scratchpad, json_log, steps)``. ``answer`` comes from
            ``env.get_ans()``. ``steps`` is ``curr_step`` after the final
            decrement, which equals the number of steps executed when ``reset``
            is True.
        """
        if reset:
            self.reset()
        self.prompt = prompt
        self.json_scratchpad.append({"role": "user", "content": prompt})
        while not (self.is_halted() or self.is_finished()):
            self.step()
        self.answer = self.env.get_ans()
        # NOTE(review): curr_step is left decremented (== last step number), so
        # a later run(..., reset=False) reuses that step label and can execute
        # one step beyond max_steps. Kept because callers may read curr_step.
        self.curr_step -= 1
        return self.answer, self.scratchpad, self.json_log, self.curr_step

    def step(self) -> None:
        """Run one Thought -> Action -> Observation cycle and advance ``curr_step``."""
        # Think: cue the LLM with "Thought[i]:" and record its reply.
        self.scratchpad += f'\nThought[{self.curr_step}]:'
        self.json_scratchpad.append(
            {"role": "user", "content": f'Thought[{self.curr_step}]: '})
        thought = self.llm(self.json_scratchpad)
        self.scratchpad += ' ' + thought
        self.json_scratchpad.append(
            {"role": "assistant", "content": thought})
        if self.need_print:
            print(f"Thought[{self.curr_step}]:", thought)
        self.json_log.append({f'Thought[{self.curr_step}]': thought})

        # Act: same cue/reply pattern. env.run receives the raw reply, while
        # every transcript stores its str() form.
        self.scratchpad += f'\nAction[{self.curr_step}]:'
        self.json_scratchpad.append(
            {"role": "user", "content": f'Action[{self.curr_step}]: '})
        action = self.llm(self.json_scratchpad)
        self.scratchpad += ' ' + str(action)
        if self.need_print:
            print(f"Action[{self.curr_step}]:", str(action))
        self.json_scratchpad.append(
            {"role": "assistant", "content": str(action)})
        self.json_log.append({f'Action[{self.curr_step}]': str(action)})

        # Observe: execute the action and feed the result back to the LLM as a
        # user message. In the text scratchpad the observation goes on its own
        # line, below the "Observation[i]:" header.
        self.scratchpad += f'\nObservation[{self.curr_step}]:'
        # TODO: env parse action
        observation = self.env.run(action)
        self.json_scratchpad.append(
            {"role": "user", "content": f'Observation[{self.curr_step}]: ' + str(observation)})
        observation = str(observation)
        self.scratchpad += '\n' + observation
        if self.need_print:
            print(f"Observation[{self.curr_step}]:", observation)
        self.json_log.append(
            {f'Observation[{self.curr_step}]': observation})
        # TODO: env observation

        self.curr_step += 1

    def reset(self) -> None:
        """Clear the transcript and counters and reset the environment."""
        self.prompt = ""
        self.scratchpad = ""
        self.curr_step = 1
        self.finished = False
        self.answer = ""
        self.json_log = []
        self.json_scratchpad = []
        self.env.reset()

    def is_finished(self) -> bool:
        """Delegate completion to the environment."""
        return self.env.is_finished()

    def is_halted(self) -> bool:
        """Return True once the step budget (``max_steps``) is exhausted."""
        # A token-budget cutoff (prompt + scratchpad over 20000 tokens measured
        # with self.enc, unless finished) is currently disabled.
        return self.curr_step > self.max_steps

    def is_success(self) -> bool:
        """Delegate success to the environment."""
        return self.env.is_success()


class ActPlanner:
    """Act-only planner: the LLM proposes actions directly, without a Thought turn.

    Each ``step`` asks the LLM for an action, passes it to ``env.run`` and
    records the stringified result as the observation. ``run`` repeats until
    ``env.is_finished()`` is true or ``max_steps`` steps have been executed.

    The transcript is kept in three parallel forms:
      * ``scratchpad``: plain text with ``Action[i]:`` / ``Observation[i]:``
        headers.
      * ``json_scratchpad``: chat messages passed to the LLM on every call.
      * ``json_log``: one single-key dict per turn, for logging.
    """

    def __init__(self,
                 llm_model,
                 env,
                 max_steps=30,
                 need_print=True
                 ) -> None:
        """Configure the planner.

        Args:
            llm_model: Callable taking the list of chat messages so far and
                returning the model's reply.
            env: Object providing ``reset()``, ``run(action)``,
                ``is_finished()``, ``is_success()`` and ``get_ans()``.
            max_steps: Maximum number of Action/Observation steps per run.
            need_print: Echo every turn to stdout when True.
        """
        self.llm = llm_model
        self.env = env
        self.max_steps = max_steps
        # Not read by the active code in this class; only the disabled
        # token-budget check described in is_halted would use it.
        self.enc = tiktoken.encoding_for_model("gpt-3.5-turbo")
        self.prompt: str = ""
        self.scratchpad: str = ""
        self.curr_step: int = 1
        # Also (re)set by reset(). Defined here so the attribute exists before
        # the first reset.
        self.finished = False
        self.answer: str = ""
        self.json_log = []
        self.json_scratchpad = []
        self.need_print = need_print

    def run(self, prompt, reset=True):
        """Run steps until the env reports completion or the step budget is spent.

        Args:
            prompt: Task description, sent as the first user message.
            reset: Clear all state (including ``env.reset()``) before starting.

        Returns:
            ``(answer, scratchpad, json_log, steps)``. ``answer`` comes from
            ``env.get_ans()``. ``steps`` is ``curr_step`` after the final
            decrement, which equals the number of steps executed when ``reset``
            is True.
        """
        if reset:
            self.reset()
        self.prompt = prompt
        self.json_scratchpad.append({"role": "user", "content": prompt})
        while not (self.is_halted() or self.is_finished()):
            self.step()
        self.answer = self.env.get_ans()
        # NOTE(review): curr_step is left decremented (== last step number), so
        # a later run(..., reset=False) reuses that step label and can execute
        # one step beyond max_steps. Kept because callers may read curr_step.
        self.curr_step -= 1
        return self.answer, self.scratchpad, self.json_log, self.curr_step

    def step(self) -> None:
        """Run one Action -> Observation cycle and advance ``curr_step``."""
        # Act: cue the LLM with "Action[i]:". env.run receives the raw reply,
        # while every transcript stores its str() form.
        self.scratchpad += f'\nAction[{self.curr_step}]:'
        self.json_scratchpad.append(
            {"role": "user", "content": f'Action[{self.curr_step}]: '})
        action = self.llm(self.json_scratchpad)
        self.scratchpad += ' ' + str(action)
        if self.need_print:
            print(f"Action[{self.curr_step}]:", str(action))
        self.json_scratchpad.append(
            {"role": "assistant", "content": str(action)})
        self.json_log.append({f'Action[{self.curr_step}]': str(action)})

        # Observe: execute the action and feed the result back to the LLM as a
        # user message. In the text scratchpad the observation goes on its own
        # line, below the "Observation[i]:" header.
        self.scratchpad += f'\nObservation[{self.curr_step}]:'
        observation = self.env.run(action)
        self.json_scratchpad.append(
            {"role": "user", "content": f'Observation[{self.curr_step}]: ' + str(observation)})
        observation = str(observation)
        self.scratchpad += '\n' + observation
        if self.need_print:
            print(f"Observation[{self.curr_step}]:", observation)
        self.json_log.append(
            {f'Observation[{self.curr_step}]': observation})

        self.curr_step += 1

    def reset(self) -> None:
        """Clear the transcript and counters and reset the environment."""
        self.prompt = ""
        self.scratchpad = ""
        self.curr_step = 1
        self.finished = False
        self.answer = ""
        self.json_log = []
        self.json_scratchpad = []
        self.env.reset()

    def is_finished(self) -> bool:
        """Delegate completion to the environment."""
        return self.env.is_finished()

    def is_halted(self) -> bool:
        """Return True once the step budget (``max_steps``) is exhausted."""
        # A token-budget cutoff (prompt + scratchpad over 20000 tokens measured
        # with self.enc, unless finished) is currently disabled.
        return self.curr_step > self.max_steps

    def is_success(self) -> bool:
        """Delegate success to the environment."""
        return self.env.is_success()
