from copy import deepcopy

import cv2
import gym
import numpy as np


class StickyActionEnv(gym.Wrapper):
    """Repeat the previously executed action with probability ``p``.

    On every ``step`` a uniform draw from NumPy's global RNG decides whether
    the requested action is replaced by the action executed on the previous
    step. Whatever action is actually executed is remembered for next time.

    Args:
        env: The environment to wrap.
        p: Probability of executing the previous action instead of the
            requested one.
    """

    def __init__(self, env, p=0.25):
        super().__init__(env)
        self.p = p
        # Action 0 acts as the "previous" action before anything was executed.
        self.last_action = 0

    def step(self, action):
        repeat_previous = np.random.uniform() < self.p
        executed = self.last_action if repeat_previous else action
        self.last_action = executed
        return self.env.step(executed)

    def reset(self):
        self.last_action = 0
        return self.env.reset()


class RepeatActionEnv(gym.Wrapper):
    """Repeat each action for a fixed number of frames and sum the rewards.

    Repetition stops early as soon as the wrapped environment reports
    ``done``. The observation, ``done`` flag and info of the last executed
    frame are returned together with the summed reward.

    Args:
        env: The environment to wrap.
        skip: Number of frames each action is repeated for. The default of 4
            is the value that used to be hard-coded. Code that scales frame
            budgets by 4 relies on this default.

    Raises:
        ValueError: If ``skip`` is smaller than 1. The step loop would then
            never run and its return values would be undefined.
    """

    def __init__(self, env, skip=4):
        gym.Wrapper.__init__(self, env)
        if skip < 1:
            raise ValueError(f"skip must be >= 1, got {skip}")
        self.skip = skip

    def reset(self, **kwargs):
        return self.env.reset(**kwargs)

    def step(self, action):
        total_reward, done = 0, False
        for _ in range(self.skip):
            state, r, done, info = self.env.step(action)
            total_reward += r
            if done:
                break
        return state, total_reward, done, info


class CellEnv(gym.Wrapper):
    """Base wrapper that attaches one-hot "cell" encodings of frames to ``info``.

    Subclasses implement :meth:`imdownscale`, which maps a raw observation to
    a small 2-D integer grid. Its entries must lie in ``[0, 8)`` because they
    index an 8x8 identity matrix in :meth:`formal_cell_state`. On every step
    the one-hot encodings of the previous and the new grid are stored in
    ``info["cell_state"]`` and ``info["next_cell_state"]``. Subclasses can
    override :meth:`update_info` to refresh their own bookkeeping before the
    new grid is computed.
    """

    prev_cell: np.ndarray
    curr_cell: np.ndarray

    def __init__(self, env):
        gym.Wrapper.__init__(self, env)

    def imdownscale(self, state):
        """Map a raw observation to a cell grid; must be overridden."""
        raise NotImplementedError

    def reset(self):
        obs = self.env.reset()
        first_cell = self.imdownscale(obs)
        # Both references start out pointing at the same grid.
        self.prev_cell = first_cell
        self.curr_cell = first_cell
        return obs

    def formal_cell_state(self, cell_state):  # one-hot 8*11*8
        """One-hot encode a 2-D cell grid with the one-hot axis first.

        An ``(H, W)`` grid becomes a ``uint8`` array of shape ``(8, H, W)``.
        """
        one_hot = np.eye(8, dtype=np.uint8)[cell_state]
        return one_hot.transpose(2, 0, 1)

    def update_info(self, s, r, d, info):
        """Per-step hook run before the new grid is computed; no-op here."""

    def step(self, action):
        obs, reward, done, info = self.env.step(action)
        # Let subclasses update state that imdownscale reads (e.g. the room).
        self.update_info(obs, reward, done, info)
        next_cell = self.imdownscale(obs)
        self.curr_cell = next_cell
        info["next_cell_state"] = self.formal_cell_state(next_cell)
        info["cell_state"] = self.formal_cell_state(self.prev_cell)
        self.prev_cell = next_cell
        return obs, reward, done, info


class PFCellEnv(CellEnv):
    """Cell wrapper for Pitfall!: player silhouette plus the current room index.

    The room index is read from ``info['pf_roomxy'][0]`` on every step. That
    key is expected to come from an inner wrapper such as ``PitfallEnv``. The
    index is written as three base-8 digits into columns 0-2 of grid rows 0
    and 1. Tracking starts in room 11 after every reset.
    """

    def __init__(self, env):
        super().__init__(env)
        self.pf_room = 11

    def imdownscale(self, state):
        """Return an 11x8 ``uint8`` grid: 7 where the player is, 0 elsewhere.

        Player pixels are those whose channel-0 value equals 228. Cells
        ``[0, 0:3]`` and ``[1, 0:3]`` are then overwritten with the base-8
        digits of the current room, most significant first.
        """
        player = np.where(state[:, :, 0] == 228, 255, 0).astype(np.uint8)
        # cv2 dsize is (width, height): the result has 11 rows and 8 columns.
        small = cv2.resize(player, (8, 11), interpolation=cv2.INTER_AREA)
        # Any cell touched by the player becomes 7 (255 / 256 * 8, truncated).
        grid = np.where(small > 0, 7, 0).astype(np.uint8)
        room = self.pf_room
        digits = np.array([room // 64 % 8, room // 8 % 8, room % 8], dtype=np.uint8)
        grid[0:2, 0:3] = digits
        return grid

    def reset(self):
        # Restore the start room before the base class encodes the first frame.
        self.pf_room = 11
        return super().reset()

    def update_info(self, s, r, d, info):
        room_xy = info['pf_roomxy']
        self.pf_room = room_xy[0]

class UNoisePFCellEnv(PFCellEnv):
    """``PFCellEnv`` whose last grid row is replaced by uniform random noise.

    Every call draws 8 fresh integers in ``[0, 8)`` from NumPy's global RNG.
    """

    def imdownscale(self, state):
        grid = super().imdownscale(state)
        noise = np.random.randint(0, 8, size=(8,))
        grid[-1, :] = noise.astype(np.uint8)
        return grid

class TimeNoisePFCellEnv(PFCellEnv):
    """``PFCellEnv`` whose last grid row encodes the episode step counter.

    The counter is zeroed on ``reset`` and incremented in ``update_info``,
    i.e. once per step. Its 8 lowest base-8 digits, most significant first,
    fill the last grid row.
    """

    def __init__(self, env):
        super().__init__(env)
        self.episode_step = 0

    def reset(self):
        # Zero the counter first: the base-class reset encodes the first frame.
        self.episode_step = 0
        return super().reset()

    def update_info(self, s, r, d, info):
        self.episode_step += 1
        super().update_info(s, r, d, info)

    def get_time_idx(self):
        """Return the 8 lowest base-8 digits of the step counter, most significant first."""
        count = self.episode_step
        digits = [(count // 8 ** power) % 8 for power in range(7, -1, -1)]
        return np.array(digits, dtype=np.uint8)

    def imdownscale(self, state):
        grid = super().imdownscale(state)
        grid[-1, :] = self.get_time_idx()
        return grid

class PFS0RDEnv(gym.Wrapper):  # sub-optimal path pruning
    """End Pitfall! episodes that leave the intended route.

    The episode is marked done when either condition holds:

    * the tracked player y position (``info["pf_roomxy"][2]``) exceeds
      ``MyPitfall.ground_y``, which is described there as the ground level;
    * the room index (``info["pf_roomxy"][0]``) drops below 11.
    """

    def __init__(self, env):
        gym.Wrapper.__init__(self, env)

    def reset(self):
        return self.env.reset()

    def step(self, action):
        s, r, d, info = self.env.step(action)
        room_xy = info["pf_roomxy"]
        room, y_pos = room_xy[0], room_xy[2]
        off_route = y_pos > MyPitfall.ground_y or room < 11
        if off_route:
            d = True
        return s, r, d, info

class MyPitfall:  # Copied from the source code of Go-Explore
    """Pitfall! screen-geometry constants and attribute upper bounds."""

    TARGET_SHAPE = None
    MAX_PIX_VALUE = None
    #: Native width of the screen, in pixels.
    screen_width = 160
    #: Native height of the screen, in pixels.
    screen_height = 210
    #: Factor applied to x coordinates because pixels were displayed wider
    #: than they were tall on a television.
    x_repeat = 2
    #: Height, in pixels, of the band at the top of the screen that shows game
    #: information and cannot be reached by the player.
    gui_size = 50
    #: Rough y position of the ground on screen; used to tell whether the
    #: player is above or below ground.
    ground_y = 70
    #: A move along x larger than this within one frame is treated as a room
    #: transition; other large jumps indicate player death and respawn.
    x_jump_threshold = 270
    #: A score increase above this value means a treasure was collected.
    treasure_collected_threshold = 100
    #: Height of the play area below the GUI band.
    game_screen_height = screen_height - gui_size
    nb_rooms = 255
    attr_max = {'treasures': 32,
                'room': nb_rooms}

    @staticmethod
    def get_attr_max(name):
        """Return the upper bound of attribute ``name``.

        ``'x'`` and ``'y'`` are derived from the screen geometry. Any other
        name is looked up in ``attr_max`` and raises ``KeyError`` if unknown.
        """
        derived = {
            'x': MyPitfall.screen_width * MyPitfall.x_repeat,
            'y': MyPitfall.game_screen_height,
        }
        if name in derived:
            return derived[name]
        return MyPitfall.attr_max[name]

class PitfallPos:
    """Mutable record of the player's last known Pitfall! position.

    Attributes:
        room: Room index.
        x: Horizontal position.
        y: Vertical position.
        no_agent: True when the player was not visible in the latest frame.
    """

    def __init__(self, room, x, y, no_agent=False):
        self.room, self.x, self.y = room, x, y
        self.no_agent = no_agent

class PitfallEnv(gym.Wrapper):
    """Track the player's room and position in Pitfall! from raw frames.

    Player pixels are the pixels below the GUI band (``MyPitfall.gui_size``
    rows) whose channel-0 value is 228. Their mean gives ``(x, y)``, with
    ``x`` scaled by ``MyPitfall.x_repeat``. Room changes are inferred from
    large horizontal jumps between frames. On every step ``info['pf_roomxy']``
    receives ``(room, x, y)`` and ``info['episode']['visited_room']`` receives
    a copy of the rooms visited so far in the episode.
    """

    def __init__(self, env):
        gym.Wrapper.__init__(self, env)
        self.pf_pos = None
        self.visited_rooms = set()

    def reset(self):
        obs = self.env.reset()
        # Forget the old position so tracking restarts in room 11.
        self.pf_pos = None
        self.pos_from_unprocessed_state(obs)
        self.visited_rooms.clear()
        return obs

    def pos_from_unprocessed_state(self, unprocessed_state):
        """Update ``self.pf_pos`` from a raw frame; always returns 0.

        If no player pixel is visible, the previous position is kept and
        flagged with ``no_agent=True``.

        Raises:
            AssertionError: If no player pixel is visible and there is no
                previous position (e.g. on the first frame after reset).
        """
        play_area = unprocessed_state[MyPitfall.gui_size:, :, 0]
        rows, cols = np.where(play_area == 228)
        if len(rows) == 0:
            assert self.pf_pos is not None, 'No face pixel and no previous pos'
            self.pf_pos.no_agent = True
            return 0
        y = np.mean(rows)
        x = np.mean(cols * MyPitfall.x_repeat)

        if self.pf_pos is None:
            # First sighting since reset: start in room 11.
            room = 11
            self.pf_pos = PitfallPos(room, x, y, False)
        else:
            # |dx| >= x_jump_threshold counts as crossing a screen edge: a large
            # decrease in x gives +1, a large increase gives -1.
            shift = np.clip(int((self.pf_pos.x - x) / MyPitfall.x_jump_threshold), -1, 1)
            # Below ground (y >= ground_y) the room index moves by 3 per transition.
            delta = shift if y < MyPitfall.ground_y else shift * 3
            room = (self.pf_pos.room + delta) % MyPitfall.nb_rooms

        pos = self.pf_pos
        pos.room, pos.x, pos.y = room, x, y
        pos.no_agent = False
        return 0

    def step(self, action):
        obs, reward, done, info = self.env.step(action)
        self.pos_from_unprocessed_state(obs)
        pos = self.pf_pos
        self.visited_rooms.add(pos.room)
        info.setdefault("episode", {})
        info["episode"].update(visited_room=deepcopy(self.visited_rooms))
        info['pf_roomxy'] = (pos.room, pos.x, pos.y)
        return obs, reward, done, info

class MRVisitMapEnv(gym.Wrapper):
    """Record the rooms visited in Montezuma's Revenge.

    The current room index is read from RAM byte 3 of the underlying ALE
    emulator. Every step publishes it as ``info["mr_room"]`` (also kept in
    ``self.i_room``). A copy of the rooms visited this episode is stored in
    ``info["episode"]["visited_room"]``.
    """

    def __init__(self, env):
        gym.Wrapper.__init__(self, env)
        self.visited_rooms = set()

    def reset(self):
        obs = self.env.reset()
        self.visited_rooms.clear()
        return obs

    def get_room(self):
        """Return the current room index (RAM byte 3)."""
        return self.env.unwrapped.ale.getRAM()[3]

    def step(self, action):
        obs, reward, done, info = self.env.step(action)
        room = self.get_room()
        self.i_room = room
        self.visited_rooms.add(room)
        info["mr_room"] = room
        episode_info = info.setdefault("episode", {})
        episode_info.update(visited_room=deepcopy(self.visited_rooms))
        return obs, reward, done, info


class MRS1RDEnv(CellEnv):
    """Montezuma's Revenge cell wrapper that prunes episodes leaving a fixed route.

    Each level has a hand-written room sequence (``level_path_dict_fastest``).
    For each room, ``req_score`` gives the number of rewards that must be
    collected before moving on. ``update_status`` reports ``done`` when the
    agent:

    * stays in the current room for ``max_room_step`` consecutive steps (also
      reported as ``info["pruned"] = True``);
    * enters the next room of the route without enough rewards;
    * enters a room that is neither the current nor the next one;
    * triggers one of the position checks in room 1.

    ``process_done`` also ends the episode when ``info['ale.lives']`` drops
    below 6. The cell grid additionally encodes, as base-8 digits, the last
    route room reached, the reward count and the level. ``info["mr_room"]``
    must be provided by an inner wrapper such as ``MRVisitMapEnv``.

    Args:
        env: The environment to wrap.
        max_room_step: Step budget per room before the episode is pruned.
    """

    def __init__(self, env,max_room_step=600):
        self.curr_room_max_step=max_room_step
        gym.Wrapper.__init__(self, env)
        # Full routes; kept for reference. update_status only reads
        # level_path_dict_fastest below.
        self.level_path_dict={
            1:{
                "path":[1,2,6,7,13,14, 22,23,22,21,13,12, 11,19,20,19,18,17, 16,15,1],
                "req_score":[2,0,1,1,1,1, 0,3,0,0,0,0, 0,1,2,0,0,2, 0,6]
            },
            2:{
                "path":[1,2,6,7,13,14, 22,23,22,21,13,12, 11,19,20,19,18,17, 16,15,1],
                "req_score":[2,1,0,1,0,1, 0,2,0,1,0,0, 1,0,3,0,0,2, 0,6]
            },
            3:{
                "path":[1,2,6,7,13,14, 22,23,22,21,13,12, 11,19,20,19,18,17, 16,15,1],
                "req_score":[2,0,0,1,0,1, 0,3,0,0,0,0, 0,0,1,0,0,2, 0,6]
            }
        }

        # Route actually enforced. req_score[i] is the number of rewards needed
        # in path[i] before moving on to path[i + 1].
        self.level_path_dict_fastest={
            1:{
                "path":[1,2,6,7,13,14, 22,21, 13,12,11,19,18,17, 16,15,1],
                "req_score":[2,0,1,1,1,1, 0,0, 0,0,0,1,0,2, 0,6],
            },
            2:{
                "path":[1,2,6,7,13,14, 22,21,13,12,11,19, 18,17,16,15,1],
                "req_score":[2,1,0,1,0,1, 0,1,0,0,1,0, 0,2,0,6],
            },
            3:{
                "path":[1,2,6,7,13,12, 11,19,20,19,18,17, 16,15,1],
                "req_score":[2,0,0,1,0,0, 0,0,1,0,0,2, 0,6],
            }
        }
        self.reset_status()

    def imdownscale(self, state):
        """Return an 11x8 ``uint8`` cell grid for ``state``.

        Cells touched by player pixels (channel-0 value 228) are 7, others 0.
        The 4 lowest base-8 digits of the room go to ``[1, 0:4]``, of the
        reward count to ``[0, 0:4]`` and of the level to ``[0, 4:8]``.
        """
        get_idx=lambda x:[((x//8) // 8) // 8 % 8, (x // 8) // 8 % 8, (x // 8) % 8, x % 8]
        actor_state = np.where(state[:, :, 0] == 228, 255, 0).astype(np.uint8)
        # cv2 dsize is (width, height): the result has 11 rows and 8 columns.
        resized_actor = cv2.resize(actor_state, (8, 11), interpolation=cv2.INTER_AREA)
        resized = np.where(resized_actor > 0, 255, 0)
        img = ((resized / 256.0) * 8).astype(np.uint8)
        room_indices =get_idx(self.curr_room)
        score_indices= get_idx(self.curr_reward)
        level_indices=get_idx(self.curr_level)
        img[1, 0:4] = np.array(room_indices, dtype=np.uint8)
        img[0, 0:4] = np.array(score_indices, dtype=np.uint8)
        img[0, 4:8] = np.array(level_indices, dtype=np.uint8)
        return img

    def reset(self):
        # Reset the route bookkeeping *before* the base class encodes the first
        # frame. imdownscale reads curr_room/curr_reward/curr_level; resetting
        # afterwards leaked the previous episode's values into the initial
        # cell, which is reported as info["cell_state"] on the first step.
        self.reset_status()
        return super().reset()

    def reset_status(self):
        """Return to the start of level 1 in room 1 with empty counters."""
        self.curr_level=1
        self.curr_reward=0
        self.curr_room_reward=0
        self.curr_level_room_idx=0
        self.curr_room=1
        self.achieve_room1_right=0
        self.curr_room_remain = self.curr_room_max_step

    def update_status(self,room,score):
        """Advance the route bookkeeping by one step.

        Args:
            room: Room index reported for the new observation.
            score: Reward of this step; only its sign is accumulated.

        Returns:
            ``(done, pruned)``. ``done`` says whether the route rules end the
            episode. ``pruned`` says whether that happened because the per-room
            step budget ran out.
        """
        done=False
        pruned=False
        self.curr_room_reward+=np.sign(score)
        # Levels beyond 3 reuse the level-3 route.
        map_level=self.curr_level if self.curr_level<3 else 3
        path=self.level_path_dict_fastest[map_level]["path"]
        req_score = self.level_path_dict_fastest[map_level]["req_score"]
        if room==self.curr_room:
            # Still in the same room: consume one step of its budget.
            self.curr_room_remain-=1
            if self.curr_room_remain<=0:
                self.curr_room_remain=self.curr_room_max_step
                done=True
                pruned=True
        elif room==path[self.curr_level_room_idx+1]:
            # Entered the next room of the route.
            self.curr_room_remain=self.curr_room_max_step
            self.achieve_room1_right=0
            if self.curr_room_reward>=req_score[self.curr_level_room_idx]:
                self.curr_room_reward=0
                self.curr_room=room
                self.curr_level_room_idx+=1
                # Reaching the last room of the route starts the next level.
                if self.curr_level_room_idx==len(path)-1:
                    self.curr_level+=1
                    self.curr_level_room_idx=0
            else:
                # Left the room before collecting the required rewards.
                done=True
        else:
            # Any other room is off-route.
            self.achieve_room1_right=0
            self.curr_room_remain = self.curr_room_max_step
            done=True
        # Rewards obtained in room 15 are not added to the encoded score.
        if room!=15:
            self.curr_reward+=np.sign(score)
        # NOTE(review): self.curr_cell still holds the grid of the *previous*
        # observation here (step() recomputes it after this call), so the
        # room-1 position checks below lag one step behind.
        if room==1:
            # Remember that the player reached rows 2-4, columns 6-7 of room 1 ...
            if np.sum(self.curr_cell[2:5,6:8])>0:
                self.achieve_room1_right=1
            # ... and end the episode once it leaves that region again.
            if self.achieve_room1_right and not(np.sum(self.curr_cell[2:5,6:8])>0):
                done=True
        # Player in rows 2-3, columns 1-3 of room 1 also ends the episode.
        if room==1 and np.sum(self.curr_cell[2:4,1:4])>0:
            done=True
        return done,pruned

    def process_done(self,d,d0,d1):
        """Combine the env's done flag, the route verdict and the life-loss flag."""
        return d or d0 or d1

    def step(self,action):
        s, r, d, info = self.env.step(action)
        room=info["mr_room"]
        d0,pruned=self.update_status(room,r)
        # NOTE(review): assumes the reported life counter starts at 6, so any
        # lost life ends the episode; confirm against the ALE version used.
        d1=(info['ale.lives'] < 6)
        d=self.process_done(d,d0,d1)
        self.curr_cell = self.imdownscale(s)
        info["next_cell_state"] = self.formal_cell_state(self.curr_cell)
        info["cell_state"] = self.formal_cell_state(self.prev_cell)
        self.prev_cell = self.curr_cell
        info.update({
            # Rewards in room 15 get a 0.1 scale (presumably applied by the learner).
            "mr_reward_scale":1.0 if room!=15 else 0.1,
            "curr_level":self.curr_level,
            "curr_score":self.curr_reward,
            "pruned":pruned,
        })
        return s, r, d, info

class MRS1PostEnv(MRS1RDEnv):
    """``MRS1RDEnv`` variant that keeps the route bookkeeping but never prunes.

    Only the wrapped environment's own ``done`` flag ends an episode. From
    level 3 on, the top 16 rows of the returned observation are zeroed. The
    unmodified frame is always available as ``info["sla_origin_obs"]``.
    """

    def process_obs(self, s):
        """Return a copy of ``s`` with its top 16 rows set to zero."""
        masked = s * 1
        masked[:16, :, :] *= 0
        return masked

    def process_done(self, d, d0, d1):
        # Ignore the route verdict (d0) and the life-loss flag (d1).
        return d

    def step(self, action):
        s, r, d, info = super().step(action)
        info["sla_origin_obs"] = s
        late_level = "curr_level" in info and info["curr_level"] >= 3
        if late_level:
            return self.process_obs(s), r, d, info
        return s, r, d, info


class MRCellEnv(CellEnv):
    """Cell wrapper for Montezuma's Revenge: player silhouette, room and reward count.

    ``mr_room`` is taken from ``info['mr_room']``, expected from an inner
    wrapper such as ``MRVisitMapEnv``. ``mr_reward`` counts steps with a
    positive reward. Each is written as three base-8 digits into columns 0-2:
    the room in row 1, the reward count in row 0.
    """

    mr_room: int
    mr_reward: int

    def __init__(self, env):
        super().__init__(env)
        self.mr_room = 1
        self.mr_reward = 0

    def imdownscale(self, state):
        """Return an 11x8 ``uint8`` grid: 7 where the player is, plus counters."""
        def octal3(value):
            # Three lowest base-8 digits, most significant first.
            return [value // 64 % 8, value // 8 % 8, value % 8]

        player = np.where(state[:, :, 0] == 228, 255, 0).astype(np.uint8)
        # cv2 dsize is (width, height): the result has 11 rows and 8 columns.
        small = cv2.resize(player, (8, 11), interpolation=cv2.INTER_AREA)
        grid = np.where(small > 0, 7, 0).astype(np.uint8)
        grid[1, 0:3] = np.array(octal3(self.mr_room), dtype=np.uint8)
        grid[0, 0:3] = np.array(octal3(self.mr_reward), dtype=np.uint8)
        return grid

    def reset(self):
        # Restore the counters before the base class encodes the first frame.
        self.mr_room = 1
        self.mr_reward = 0
        return super().reset()

    def update_info(self, s, r, d, info):
        self.mr_room = info['mr_room']
        if r > 0:
            self.mr_reward += 1

def make_atari(config):
    """Build a wrapped Atari environment from ``config``.

    ``config`` must provide ``env_name``, ``max_frames_per_episode`` and
    ``env_config`` (a mapping of options). The base environment, which must
    be a ``NoFrameskip`` variant, is wrapped in ``RepeatActionEnv`` and then:

    * Montezuma's Revenge: ``MRVisitMapEnv``. With ``mr_stype == 1`` it also
      gets ``MRS1PostEnv`` (if ``post_process`` is truthy) or ``MRS1RDEnv``,
      both using ``max_room_step`` (default 600).
    * Pitfall: ``PitfallEnv``, then ``StickyActionEnv`` if ``sticky_action``
      is truthy. Next comes a cell wrapper chosen by ``noise_style`` (0:
      ``UNoisePFCellEnv``, 1: ``TimeNoisePFCellEnv``, absent: ``PFCellEnv``).
      Finally ``PFS0RDEnv`` is added when ``pf_stype == 0``.

    Raises:
        NotImplementedError: For any other game, or an unknown ``noise_style``.
        AssertionError: If the environment id is not a ``NoFrameskip`` variant.
    """
    env_id, max_episode_steps = config.env_name, config.max_frames_per_episode
    env = gym.make(env_id)
    # Scaled by 4 because RepeatActionEnv (below) steps the inner env 4 times
    # per agent step.
    env._max_episode_steps = max_episode_steps * 4
    assert 'NoFrameskip' in env.spec.id
    env = RepeatActionEnv(env)

    if "Montezuma" in env_id:
        env_config = config.env_config
        env = MRVisitMapEnv(env)
        if "mr_stype" in env_config and env_config["mr_stype"] == 1:
            use_post = "post_process" in env_config and env_config["post_process"]
            route_cls = MRS1PostEnv if use_post else MRS1RDEnv
            env = route_cls(env, env_config.get("max_room_step", 600))
    elif "Pitfall" in env_id:
        env_config = config.env_config
        env = PitfallEnv(env)
        if "sticky_action" in env_config and env_config["sticky_action"]:
            env = StickyActionEnv(env)

        if "noise_style" not in env_config:
            env = PFCellEnv(env)
        elif env_config["noise_style"] == 0:
            env = UNoisePFCellEnv(env)
        elif env_config["noise_style"] == 1:
            env = TimeNoisePFCellEnv(env)
        else:
            raise NotImplementedError

        if "pf_stype" in env_config and env_config["pf_stype"] == 0:
            env = PFS0RDEnv(env)
    else:
        raise NotImplementedError

    return env

