from Common.config.config_general import Config
from Common.env_utils.env_make_utils import make_env

from Common.utils import *


class Worker:
    """Environment worker that steps one env on behalf of a remote policy.

    The worker talks to its peer through ``conn`` (anything exposing
    ``send``/``recv``). It sends the current policy input, receives an action,
    steps the env, and sends back ``(state4policy, reward, done, info)``.

    Besides the policy input (``state4policy``), it keeps per-episode state
    visit counts in ``erir_dict``. These drive the ``erir_mask`` /
    ``eps_nvisit`` entries it adds to ``info``. Subclasses can hook into the
    loop through ``step_extra`` and ``step_env_extra``.
    """

    def __init__(self, id, config: Config):
        """Build the env and start the first episode.

        Args:
            id: Worker identifier. It is added to ``config.seed`` when seeding
                the env and returned by ``__str__``. (It shadows the builtin
                ``id``; the name is kept for caller compatibility.)
            config: Run configuration. The fields read here are ``env_name``,
                ``max_frames_per_episode``, ``state_shape``, ``n_actions``,
                the optional ``seed``, and, later on, ``env_type`` and
                ``use_origin_score``.
        """
        self.id = id
        self.config = config
        self.env_name = self.config.env_name
        self.max_episode_steps = self.config.max_frames_per_episode
        self.state_shape = self.config.state_shape
        self.env = make_env(config)
        if getattr(config, "seed", None) is not None:
            # Offset by the worker id so each worker gets a distinct seed.
            self.env.seed(config.seed + id)

        self.state4policy = np.zeros(self.state_shape, dtype=np.uint8)
        self.n_actions = self.config.n_actions
        self.episode_id = -1
        self.curr_state = None
        # Created before reset(), which refills it with the initial state.
        # It must not be re-created after reset(). Doing so would drop the
        # initial-state visit for the first episode only, making it behave
        # differently from every later episode.
        self.erir_dict = {}
        self.reset()

    def __str__(self):
        return str(self.id)

    def render(self):
        """Delegate rendering to the underlying env."""
        self.env.render()

    def update_state4policy(self, state, is_initial=False):
        """Fold a new raw observation into the policy input.

        Atari: ``stack_states`` combines ``state`` with the current stack.
        ``is_initial`` is forwarded, presumably so the stack is refilled at
        episode start (TODO confirm in ``stack_states``).
        MiniGrid: ``state`` is used as-is.
        Other env types raise ``NotImplementedError``.
        The raw observation is always stored in ``curr_state``.
        """
        env_type = self.config.env_type
        if env_type == "Atari":
            self.state4policy = stack_states(self.state4policy, state, is_initial)
        elif env_type == "MiniGrid":
            self.state4policy = state
        else:
            raise NotImplementedError
        self.curr_state = state

    def get_state4rm(self):
        """Return the part of the policy input used as the state for the rm.

        Atari: the last entry along the first axis, keeping that axis with
        length 1 (presumably the newest frame of the stack).
        MiniGrid: the whole policy input.
        Other env types raise ``NotImplementedError``.
        """
        env_type = self.config.env_type
        if env_type == "Atari":
            return self.state4policy[-1:, ...]
        elif env_type == "MiniGrid":
            return self.state4policy
        else:
            raise NotImplementedError

    def get_state_key(self):
        """Return a hashable key identifying the current state.

        MiniGrid uses the env's full observation (``get_full_obs``) rather
        than the agent's view. Other env types use ``get_state4rm()``. The
        array is flattened into a tuple so it can key ``erir_dict``.
        """
        if self.config.env_type != "MiniGrid":
            state_key = self.get_state4rm()
        else:
            state_key = self.env.get_full_obs()
        return tuple(state_key.reshape(-1).tolist())

    def update_erir(self):
        """Count a visit to the current state within this episode.

        Returns:
            ``(erir_mask, visit_count)``. ``erir_mask`` is 1 the first time
            the state is seen in the episode and 0 afterwards. ``visit_count``
            is the number of visits so far, including this one. For Atari,
            counting is skipped and ``(1, 1)`` is always returned.
        """
        if self.config.env_type == "Atari":
            return 1, 1
        key = self.get_state_key()
        count = self.erir_dict.get(key, 0) + 1
        self.erir_dict[key] = count
        return (1 if count == 1 else 0), count

    def reset(self):
        """Reset the env, start a new episode and count its initial state."""
        state = self.env.reset()
        self.update_state4policy(state, is_initial=True)
        # The initial state is the episode's first visit, so a later revisit
        # reports erir_mask=0.
        self.erir_dict = {self.get_state_key(): 1}
        self.episode_id += 1

    def step_extra(self, conn):
        """Hook called after each transition is sent. Does nothing by default."""
        pass

    def recv_action(self, conn):
        """Block until the peer sends the next action, then return it."""
        return conn.recv()

    def step_env(self, conn, t):
        """Receive an action, step the env and enrich ``info``.

        Assumes the 4-tuple ``(obs, reward, done, info)`` step API with a
        dict ``info``. The keys added to ``info`` are ``gt_reward`` (the
        unclipped reward), ``episode_id``, ``erir_mask``, ``eps_nvisit`` and
        ``episode_step`` (equal to ``t``). The result is passed through
        ``step_env_extra``.
        """
        action = self.recv_action(conn)
        next_state, r, d, info = self.env.step(action)
        info['gt_reward'] = r
        info['episode_id'] = self.episode_id
        # Update the policy input first: update_erir keys on the new state.
        self.update_state4policy(next_state)
        info['erir_mask'], info["eps_nvisit"] = self.update_erir()
        info['episode_step'] = t
        return self.step_env_extra(next_state, r, d, info)

    def step_env_extra(self, next_state, r, d, info):
        """Hook to post-process a transition. Returns it unchanged by default."""
        return next_state, r, d, info

    def step_s(self, conn):
        """Send the current policy input to the peer."""
        conn.send(self.state4policy)

    def step(self, conn):
        """Run the worker loop forever.

        Each iteration:
          1. Sends the policy input.
          2. Steps the env with the received action.
          3. Sends ``(state4policy, reward, done, info)``.

        Unless ``config.use_origin_score`` is set, the reward is clipped to
        ``max(0, sign(r))``. Episodes are also ended when the step counter
        reaches ``max_episode_steps``. The env is reset after any episode end.
        """
        t = 1
        while True:
            self.step_s(conn)
            next_state, r, d, info = self.step_env(conn, t)
            t += 1
            # Equivalent to the former ``t % max_episode_steps == 0``, because
            # t restarts at 1 on every episode end. It does not depend on
            # that invariant, though.
            # NOTE(review): t is bumped after each env step, so truncation
            # happens after max_episode_steps - 1 steps. Possible off-by-one;
            # confirm intent before changing, as it alters episode length.
            if t >= self.max_episode_steps:
                d = True
            reward = r if self.config.use_origin_score else max(0, np.sign(r))
            conn.send((self.state4policy, reward, d, info))
            self.step_extra(conn)
            if d:
                self.reset()
                t = 1
