import gymnasium as gym

class Lander(gym.Env):
    """Goal-conditioned wrapper around Gymnasium's ``LunarLander-v3``.

    The wrapper delegates simulation to an inner environment and, when a goal
    has been installed through :meth:`set_abstract_states`, augments the
    reward and termination signal with that goal's predicate / shaping
    reward.  On :meth:`reset` it can also roll out previously supplied
    ``(policy, goal)`` pairs so that the episode handed to the caller starts
    from a state in which such a goal predicate already holds.

    The goal objects are expected to expose ``goal.predicate(state)``,
    ``goal.reward(state, last_state)`` and ``goal.reset()``; policies are
    expected to expose ``predict(state)`` returning a pair whose first item is
    the action.  These interfaces are inferred from how the objects are used
    in this class.
    """

    # Bonus added to the environment reward on the step that satisfies the
    # goal predicate.
    SUCCESS_BONUS = 100

    def __init__(self, render_mode=None):
        """Create the wrapped environment.

        Args:
            render_mode: Forwarded unchanged to ``gym.make``.
        """
        super().__init__()

        self.env = gym.make("LunarLander-v3", render_mode=render_mode)
        self.render_mode = render_mode

        self.observation_space = self.env.observation_space
        self.action_space = self.env.action_space
        self.last_state = None

        # Give the abstract-state attributes defined defaults so that step()
        # and reset() work before set_abstract_states() has been called.
        self.start = None
        self.goal = None
        self.avoid = None
        self.policies = []

    def step(self, action):
        """Advance the wrapped environment by one step.

        When a goal is installed, ``info['is_success']`` reports whether its
        predicate holds in the new state.  On success the episode is marked
        terminated and ``SUCCESS_BONUS`` is added to the reward; otherwise the
        goal's shaping reward (computed from the new and the previous state)
        is added.  Without a goal the inner environment's result is returned
        untouched.

        Args:
            action: Action forwarded to the wrapped environment.

        Returns:
            The ``(state, reward, terminated, truncated, info)`` tuple.
        """
        state, reward, terminated, truncated, info = self.env.step(action)

        if self.goal is not None:
            success = self.goal.goal.predicate(state)
            if success:
                info['is_success'] = True
                terminated = True
                reward += self.SUCCESS_BONUS
            else:
                info['is_success'] = False
                reward += self.goal.goal.reward(state, self.last_state)

        self.last_state = state
        return state, reward, terminated, truncated, info

    def reset(self, seed=None, options=None):
        """Reset the environment, optionally rolling out prefix policies.

        If no ``(policy, goal)`` pairs are configured, this is a plain reset
        of the wrapped environment.  Otherwise the wrapped environment is
        reset and rolled out repeatedly until a rollout ends in a state that
        satisfies the rolled-out goal's predicate.  The current goal is
        disabled during these rollouts so that :meth:`step` applies no goal
        reward or termination, and is restored afterwards.

        Note that there is no retry limit: if the configured policy can never
        reach its goal this method does not return.

        Args:
            seed: Seed for the wrapped environment.  It is applied to the
                first reset attempt only; re-seeding every retry would
                reproduce the same initial state each time.
            options: Forwarded to the wrapped environment's ``reset``.

        Returns:
            ``(state, info)`` where ``state`` is the state after any rollout
            and ``info`` is the dict produced by the last wrapped reset.
        """
        curr_goal = self.goal
        self.goal = None
        try:
            while True:
                state, info = self.env.reset(seed=seed, options=options)
                seed = None  # retries must draw fresh initial conditions
                if not self.policies:
                    break
                state, found = self._rollout_prefix(state)
                if found:
                    break
        finally:
            # Restore the goal even if a policy or the environment raised.
            self.goal = curr_goal

        if self.goal is not None:
            self.goal.goal.reset()

        self.last_state = state
        return state, info

    def _rollout_prefix(self, state):
        """Roll out configured policies from ``state``; return ``(state, found)``.

        ``found`` is True when the rollout stopped because the goal predicate
        of the policy being executed holds, and False when the episode
        terminated or was truncated first.
        """
        terminated = False
        truncated = False
        for policy, goal in self.policies:
            reached = goal.goal.predicate(state)
            while not reached and not (terminated or truncated):
                action, _ = policy.predict(state)
                state, _, terminated, truncated, _ = self.step(action)
                reached = goal.goal.predicate(state)

            # NOTE(review): the inner loop only ends when the predicate holds
            # or the episode is over, so one of the two returns below always
            # fires for the first pair and later pairs are never rolled out.
            # This mirrors the previous behaviour; confirm whether chaining
            # through every pair was intended.
            if reached:
                return state, True
            if terminated or truncated:
                return state, False
        return state, False

    def set_abstract_states(self, start, goal, avoid, policies):
        """Install the abstract start/goal/avoid states and prefix policies.

        Args:
            start: Stored on the instance as ``self.start``.
            goal: Goal used by :meth:`step` and :meth:`reset`, or ``None``.
            avoid: Stored on the instance as ``self.avoid``.
            policies: Sequence of ``(policy, goal)`` pairs rolled out by
                :meth:`reset`.
        """
        self.start = start
        self.goal = goal
        self.avoid = avoid
        self.policies = policies

    def render(self):
        """Render the wrapped environment and return whatever it returns."""
        return self.env.render()

    def close(self):
        """Release the resources held by the wrapped environment."""
        self.env.close()