from .gymma import GymmaWrapper
from prompts.util import *
from utils.logging import get_logger
# from .multiagentenv import MultiAgentEnv
# Module-level logger from the project's utils.logging helper; used by LLMWrapper.set_func.
logger = get_logger()

class LLMWrapper(GymmaWrapper):
    """Gymma environment wrapper that shapes training rewards with injected LLM-side functions.

    A planning function and a reward function are installed at construction
    (see ``set_func``). ``step_train`` evaluates them on a processed view of the
    pre-step observation and combines the result with the environment reward
    according to ``reward_mode``. ``step_eval`` returns the base environment
    transition unchanged.
    """

    # Reward modes understood by ``step_train``.
    _REWARD_MODES = ("pure", "mixed_constant", "mixed_normalized")
    # Environment families (the ``env_name`` prefix before the first "_") for
    # which ``step_train`` can translate planner tasks into actions.
    _ENV_FAMILIES = ("lbf", "mpe")

    def __init__(
        self,
        key,
        time_limit,
        pretrained_wrapper,
        seed,
        common_reward,
        reward_scalarisation,
        compute_reward_fn,
        planning_fn,
        reward_mode,
        env_name,
        **kwargs,
    ):
        """Create the wrapper.

        Args:
            key, time_limit, pretrained_wrapper, seed, common_reward,
            reward_scalarisation, **kwargs: forwarded unchanged to
                ``GymmaWrapper.__init__``.
            compute_reward_fn: callable installed as ``self.compute_reward``.
            planning_fn: callable installed as ``self.planning_function``.
            reward_mode: one of ``"pure"``, ``"mixed_constant"`` or
                ``"mixed_normalized"``; checked each time ``step_train`` runs.
            env_name: environment name. It selects the module
                ``prompts.env_code.<family>.processed_obs_<env_name>``, where
                ``<family>`` is the text before the first ``"_"``.
        """
        super().__init__(
            key,
            time_limit,
            pretrained_wrapper,
            seed,
            common_reward,
            reward_scalarisation,
            **kwargs,
        )
        # Must be assigned before set_func, whose log message reads it.
        self.reward_mode = reward_mode
        self.set_func(planning_fn, compute_reward_fn)
        self.env_name = env_name
        # (env_name, process_state) pair, filled on first use by _get_process_state.
        self._process_state_cache = None

    def compute_reward(self, observations, tasks):
        """Placeholder; replaced on every instance by ``set_func`` during ``__init__``."""
        raise NotImplementedError

    def planning_function(self, observations):
        """Placeholder; replaced on every instance by ``set_func`` during ``__init__``."""
        raise NotImplementedError

    def set_func(self, planning_function, compute_reward):
        """Install the planning and reward callables used by ``step_train``.

        The callables are stored as instance attributes. They shadow the
        placeholder methods of the same name and are therefore called without
        ``self``. Note that ``step_train`` calls ``compute_reward`` with two
        arguments in ``"pure"`` mode and with three in ``"mixed_normalized"``
        mode.
        """
        self.planning_function = planning_function
        self.compute_reward = compute_reward
        logger.critical(f"Planning function and compute reward function({self.reward_mode} mode) set")

    def _get_process_state(self, family):
        """Return the ``process_state`` function for ``self.env_name``, importing it once per name."""
        cached = getattr(self, "_process_state_cache", None)
        if cached is None or cached[0] != self.env_name:
            process_state = import_function(
                f"prompts.env_code.{family}.processed_obs_{self.env_name}",
                "process_state",
            )
            cached = (self.env_name, process_state)
            self._process_state_cache = cached
        return cached[1]

    @staticmethod
    def _total_reward(value):
        """Collapse a reward mapping (summed over its values) or a scalar into a float."""
        from collections.abc import Mapping

        if isinstance(value, Mapping):
            return float(sum(value.values()))
        return float(value)

    def step_train(self, actions):
        """Step the environment and return the transition with an LLM-shaped reward.

        Args:
            actions: joint action. It is passed unchanged to
                ``GymmaWrapper.step`` and, after ``convert_actions``, to the
                reward functions.

        Returns:
            ``(obs, reward, done, truncated, info)`` from the base ``step``,
            except that ``reward`` is replaced according to ``self.reward_mode``:

            * ``"pure"``: the sum of the values returned by
              ``compute_reward(processed_obs, actions)``. The environment
              reward is discarded.
            * ``"mixed_constant"``: the environment reward plus the summed
              ``constant_reward_signal`` (llm_reward=0.4, penalty=0.1). This
              signal compares the taken actions with the actions derived from
              the planning function.
            * ``"mixed_normalized"``: the environment reward plus the
              ``normalized_reward`` (theta=0.01) of
              ``compute_reward(processed_obs, llm_actions, actions)``.

        Raises:
            NotImplementedError: if ``reward_mode`` is unsupported, or if the
                environment family is not ``"lbf"`` or ``"mpe"``. This is
                raised before the environment is stepped.
        """
        family = self.env_name.split("_")[0]
        # Validate configuration up front so a bad setting does not advance
        # the environment before failing.
        if self.reward_mode not in self._REWARD_MODES:
            raise NotImplementedError(f"Unsupported reward_mode: {self.reward_mode!r}")
        if family not in self._ENV_FAMILIES:
            raise NotImplementedError(f"Unsupported environment family: {family!r}")

        # Capture the observation from before the step; the LLM-side
        # functions are evaluated on it, not on the post-step observation.
        prev_obs = self._obs
        obs, r, done, truncated, info = super().step(actions)

        process_state = self._get_process_state(family)
        processed_obs = process_state(prev_obs)
        actions_ = convert_actions(actions)
        llm_task = self.planning_function(processed_obs)

        if family == "lbf":
            llm_actions = lbf_task_to_actions(llm_task, processed_obs)
        else:  # "mpe" (validated above)
            llm_actions = mpe_task_to_actions(llm_task, processed_obs)

        if self.reward_mode == "pure":
            # Pure LLM reward; the environment reward r is not used.
            reward_dict = self.compute_reward(processed_obs, actions_)
            reward = float(sum(reward_dict.values()))
        elif self.reward_mode == "mixed_constant":
            # Original reward + LLM constant aligned reward.
            reward_dict = constant_reward_signal(
                actions_, llm_actions, llm_reward=0.4, penalty=0.1)
            reward = float(sum(reward_dict.values())) + r
        else:  # "mixed_normalized" (validated above)
            # Original reward + LLM normalized code-gen reward. Use the
            # normalized value; summing reward_dict directly would silently
            # drop the normalization.
            # NOTE(review): normalized_reward comes from prompts.util and is not
            # visible here. A mapping result is summed and anything else is
            # treated as a scalar; confirm its return type.
            reward_dict = self.compute_reward(processed_obs, llm_actions, actions_)
            normalized = normalized_reward(reward_dict, theta=0.01)
            reward = self._total_reward(normalized) + r

        return obs, reward, done, truncated, info

    def step_eval(self, actions):
        """Step the environment and return the base transition with its unmodified reward."""
        return super().step(actions)