from __future__ import annotations
import numpy as np
from llm.constant import *

from argparse import Namespace
from prog_policies.base import dsl_nodes, BaseTask, BaseDSL

from prog_policies.karel.dsl import KarelDSL


class PromptGenerator:
    """Assemble system and user prompts for a Karel task.

    Three prompt flavours are provided, each as a system/user pair:

    * ``python``        -- built from the ``*_PYTHON`` description constants
      and ``USER_PROMPT_TEMPLATE_PYTHON``.
    * ``dsl``           -- built from the ``*_DSL`` description constants,
      ``DSL_DESC`` and ``USER_PROMPT_TEMPLATE_DSL``.
    * ``python_to_dsl`` -- the ``python`` system prompt plus
      ``PYTHON_TO_KAREL``, with ``USER_PROMPT_TEMPLATE_PYTHON_TO_DSL``.

    All prompt text comes from module-level constants. In this module they
    are presumably supplied by ``from llm.constant import *``.
    """

    def __init__(
        self,
        seed: int,
        task: str,
    ) -> None:
        """Store the task configuration.

        Args:
            seed: Seed for ``self.np_rng``. The RNG is not used by any method
                of this class itself; it is kept for callers and subclasses.
            task: Task key. It indexes the ``TASK_*_DESC`` tables, and its
                upper-cased form fills the ``<<task_name>>`` placeholder.
        """
        self.seed: int = seed
        self.task: str = task
        self.np_rng = np.random.RandomState(self.seed)

    # ------------------------------------------------------------------ #
    # Shared helpers
    # ------------------------------------------------------------------ #
    @staticmethod
    def _format_desc_list(header: str, entries) -> str:
        """Return *header* followed by every string in *entries*.

        Each entry, including the last one, is terminated by a newline.
        ``str.join`` avoids the repeated-concatenation cost of ``+=``.
        """
        return header + "".join(entry + "\n" for entry in entries)

    def _fill_user_prompt_template(self, template: str) -> str:
        """Substitute the ``<<...>>`` task placeholders in *template*.

        Replacements are applied in the same fixed order the original
        per-method chains used: task name, map, agent position, goal, return.

        Raises:
            KeyError: if ``self.task`` is missing from one of the
                ``TASK_*_DESC`` tables.
        """
        try:
            substitutions = (
                ("<<task_name>>", self.task.upper()),
                ("<<task_map_desc>>", TASK_MAP_DESC[self.task]),
                ("<<task_agent_position_desc>>", TASK_AGENT_POSITION_DESC[self.task]),
                ("<<task_goal_desc>>", TASK_GOAL_DESC[self.task]),
                ("<<task_return_desc>>", TASK_RETURN_DESC[self.task]),
            )
        except KeyError as err:
            # Same exception type as before, but name the offending task
            # instead of surfacing a bare key.
            raise KeyError(
                f"no prompt description registered for task {self.task!r}"
            ) from err

        prompt = template
        for placeholder, value in substitutions:
            prompt = prompt.replace(placeholder, value)
        return prompt

    # ------------------------------------------------------------------ #
    # Individual prompt sections
    # ------------------------------------------------------------------ #
    def _karel_env_desc_python(self) -> str:
        """Return the environment description used by the Python prompts."""
        return KAREL_ENV_DESC_PYTHON

    def _karel_env_desc_dsl(self) -> str:
        """Return the environment description used by the DSL prompt."""
        return KAREL_ENV_DESC_DSL

    def _action_desc_python(self) -> str:
        """Return the header plus ``ACTION_DESC_LIST_PYTHON``, one per line."""
        return self._format_desc_list(
            "Here are the available actions for the agent:\n",
            ACTION_DESC_LIST_PYTHON,
        )

    # NOTE: "preception" is a misspelling of "perception". The name is kept
    # because callers or subclasses may reference it.
    def _preception_desc_python(self) -> str:
        """Return the header plus ``PERCEPTION_DESC_LIST_PYTHON``, one per line."""
        return self._format_desc_list(
            "Here are the available perceptions of the agent:\n",
            PERCEPTION_DESC_LIST_PYTHON,
        )

    def _action_desc_dsl(self) -> str:
        """Return the header plus ``ACTION_DESC_LIST_DSL``, one per line."""
        return self._format_desc_list(
            "Here are the available actions for the agent:\n",
            ACTION_DESC_LIST_DSL,
        )

    # NOTE: misspelled name kept for compatibility (see above).
    def _preception_desc_dsl(self) -> str:
        """Return the header plus ``PERCEPTION_DESC_LIST_DSL``, one per line."""
        return self._format_desc_list(
            "Here are the available perceptions of the agent:\n",
            PERCEPTION_DESC_LIST_DSL,
        )

    def _dsl_desc(self) -> str:
        """Return the DSL grammar/description section (``DSL_DESC``)."""
        return DSL_DESC

    def _python_limitation(self) -> str:
        """Return the ``PYTHON_LIMITATION`` section."""
        return PYTHON_LIMITATION

    def _python_to_karel(self) -> str:
        """Return the ``PYTHON_TO_KAREL`` section."""
        return PYTHON_TO_KAREL

    # ------------------------------------------------------------------ #
    # Public prompt builders
    # ------------------------------------------------------------------ #
    def get_system_prompt_python_to_dsl(self) -> str:
        """Return the Python system prompt with ``PYTHON_TO_KAREL`` appended.

        Sections are separated by a single newline.
        """
        return "\n".join(
            [
                self._karel_env_desc_python(),
                self._action_desc_python(),
                self._preception_desc_python(),
                self._python_limitation(),
                self._python_to_karel(),
            ]
        )

    def get_user_prompt_python_to_dsl(self) -> str:
        """Return ``USER_PROMPT_TEMPLATE_PYTHON_TO_DSL`` filled for ``self.task``."""
        return self._fill_user_prompt_template(USER_PROMPT_TEMPLATE_PYTHON_TO_DSL)

    def get_system_prompt_python(self) -> str:
        """Return the system prompt for the Python flavour.

        It contains the environment, actions, perceptions and Python
        limitations, separated by single newlines.
        """
        return "\n".join(
            [
                self._karel_env_desc_python(),
                self._action_desc_python(),
                self._preception_desc_python(),
                self._python_limitation(),
            ]
        )

    def get_user_prompt_python(self) -> str:
        """Return ``USER_PROMPT_TEMPLATE_PYTHON`` filled for ``self.task``."""
        return self._fill_user_prompt_template(USER_PROMPT_TEMPLATE_PYTHON)

    def get_system_prompt_dsl(self) -> str:
        """Return the system prompt for the DSL flavour.

        It contains the environment, actions, perceptions and DSL
        description, separated by single newlines.
        """
        return "\n".join(
            [
                self._karel_env_desc_dsl(),
                self._action_desc_dsl(),
                self._preception_desc_dsl(),
                self._dsl_desc(),
            ]
        )

    def get_user_prompt_dsl(self) -> str:
        """Return ``USER_PROMPT_TEMPLATE_DSL`` filled for ``self.task``."""
        return self._fill_user_prompt_template(USER_PROMPT_TEMPLATE_DSL)