from sympy import im
from RoboMemory.BaseModules.BaseEnv import BaseEnv
from RoboMemory.APIs import AlfredAPI
from RoboMemory.APIs.APIprocesser import APIProcesser        
from RoboMemory.Envs.EmbodiedBenchClient import EmbodiedEnvClient
from typing import Literal
from RoboMemory.agent_utils import save_image_from_base64
from RoboMemory.env_utils import get_objects, ALFRED_ACTION_TYPES, ALFREDValidEvalSets


class ALFREDEnv(BaseEnv):
    """EB-ALFRED environment wrapper backed by a remote ``EmbodiedEnvClient``.

    Agent actions arrive as the text of an ``AlfredAPI`` call. The call is
    evaluated to an instruction string, which is mapped (case-insensitively)
    to the integer action id that the client expects.

    ``reset()`` must be called before ``step()``: it populates the action
    lookup table (``self.actions``), the object list (``self.objects``) and
    the task instruction (``self.task``).
    """

    def __init__(
            self,
            selected_index: int,
            eval_set_index: int,
            base64_image=True,
            image_save_path=None,
            detection_box=False,
            resolution=500,
            down_sample_ratio=1.0,
            base_url="",
            port=12435,

            # API Processor
            function_filter_list: list | None = None,
        ):
        """Create the environment client and the API processor.

        Args:
            selected_index: Episode index inside the chosen eval set; it is the
                only entry passed to the client as ``selected_indexes``.
            eval_set_index: Index into ``ALFREDValidEvalSets`` selecting the
                evaluation split.
            base64_image: If True, observations are returned as base64 strings;
                otherwise they are saved via ``save_image_from_base64`` and the
                returned image name is used instead.
            image_save_path: Where images are saved when ``base64_image`` is
                False (required in that case).
            detection_box: Forwarded to the env config.
            resolution: Forwarded to the env config.
            down_sample_ratio: Forwarded to the env config.
            base_url: Base URL of the remote environment server.
            port: Port of the remote environment server.
            function_filter_list: Forwarded to ``APIProcesser``. ``None`` (the
                default) means an empty list. A fresh list is created per
                instance, so no mutable default is shared between instances.
        """
        super().__init__(base64_image, image_save_path)

        if function_filter_list is None:
            function_filter_list = []

        self.action_types = ALFRED_ACTION_TYPES

        # Populated by reset() from the environment's skill set.
        self.objects = None

        self.eval_set = ALFREDValidEvalSets[eval_set_index]

        env_config = {
            "env_type": "EB-ALFRED",
            "eval_set": self.eval_set,
            # Experiment name, e.g. "base_1" for eval set "base" and index 1
            # (presumably used by the server to name logs).
            "exp_name": f"{self.eval_set}_{selected_index}",
            "down_sample_ratio": down_sample_ratio,
            # Run only the single selected episode.
            "selected_indexes": [selected_index],
            "detection_box": detection_box,
            "resolution": resolution
        }

        self.client = EmbodiedEnvClient(
            env_config=env_config,
            base_url=base_url,
            port=port
        )
        # Casefolded action string -> action id; built by reset().
        self.actions: dict | None = None
        # Task instruction text; set by reset().
        self.task: str = ""

        self.api_processor = APIProcesser(api_file_path=AlfredAPI.__file__, function_filter_list=function_filter_list)

    def __action_to_dict(self, actions: list[str]) -> dict[str, int]:
        """Map each casefolded action string to its index in ``actions``.

        Keys are casefolded so that lookups in ``step()`` are case-insensitive.
        If two actions casefold to the same key, the later index wins.
        """
        return {action.casefold(): i for i, action in enumerate(actions)}

    def __preprocess_image(self, image_base64):
        """Convert a base64 observation into the configured output form.

        Stores the result in ``self.obs_img`` (presumably read elsewhere so the
        frame is not re-described redundantly) and returns it.

        Returns:
            The base64 string itself when ``self.base64_image`` is True,
            otherwise the image name returned by ``save_image_from_base64``.

        Raises:
            ValueError: If images must be saved but no ``image_save_path`` was
                provided.
        """
        if self.base64_image:
            self.obs_img = image_base64
            return image_base64

        if self.image_save_path is None:
            raise ValueError("If you need to save image, you have to provide image saving path!!!")
        image_name = save_image_from_base64(image_base64, self.image_save_path)
        self.obs_img = image_name

        return image_name

    def render(self) -> bytes | str:
        """Fetch the current frame from the client and preprocess it."""
        image_base64 = self.client.render()

        return self.__preprocess_image(image_base64)

    def step(self, fn_call: str, subgoal: str = None) -> tuple:
        """Execute one ``AlfredAPI`` call in the environment.

        Args:
            fn_call: Text of an ``AlfredAPI`` call, evaluated as
                ``AlfredAPI.<fn_call>``. It must evaluate to an instruction
                string that is present in the current skill set.
            subgoal: Not used by this implementation; accepted for interface
                compatibility.

        Returns:
            ``(image, score, done, truncated, feedback)``, where ``truncated``
            mirrors ``done``. If the call cannot be evaluated, the instruction
            is not in the skill set, or the client step fails, a failure
            ``feedback`` dict is returned with ``score=0`` and ``done=False``,
            and the current frame is re-rendered.
        """
        try:
            # SECURITY NOTE(review): eval() executes arbitrary Python. fn_call
            # is presumably model-generated text; consider parsing it with
            # ``ast`` and dispatching to AlfredAPI functions explicitly.
            instruction_string: str = eval(f"AlfredAPI.{fn_call}")
            # Match the casefolded keys built by __action_to_dict.
            instruction_string = instruction_string.casefold()
            try:
                instruction_id = self.actions[instruction_string]
            except (KeyError, TypeError) as err:
                # KeyError: unknown instruction.
                # TypeError: self.actions is None (reset() not called yet).
                raise ValueError(f"{instruction_string} is not in the list") from err

            obs, score, done, feedback = self.client.step(instruction_id)

            image = self.__preprocess_image(obs)

        except Exception as err:
            # Best-effort: report failure to the agent instead of crashing,
            # but surface the actual cause for debugging.
            error_message = "action or object does not exist."
            print(f"{error_message} ({type(err).__name__}: {err})")
            feedback = {
                "env_feedback": error_message,
                'task_success': 0.0,
                'task_progress': 0.0
            }
            image = self.render()

            score = 0
            done = False

        truncated = done

        return image, score, done, truncated, feedback

    def reset(self) -> tuple:
        """Reset the episode and rebuild the per-episode action/object/task state.

        Returns:
            ``(obs, score, done, info)`` exactly as returned by the client.
            NOTE(review): unlike ``step``/``render``, ``obs`` here is not passed
            through image preprocessing; confirm callers expect the raw value.
        """
        obs, score, done, info = self.client.reset()

        # Build the case-insensitive action lookup from the env's skill set.
        self.actions = self.__action_to_dict(info["skill_set"])

        self.objects = get_objects(self.action_types, info["skill_set"])

        # Natural-language task instruction for this episode.
        self.task = info['instruction']

        return obs, score, done, info

    def close(self) -> dict[str, bool]:
        """Close the remote environment and return the client's result."""
        return self.client.close()


