import os
import pickle
import random
import warnings
import math
import gym
import cv2

from abc import abstractmethod
from collections import deque, defaultdict
from copy import copy
from atariari.benchmark.wrapper import AtariARIWrapper

# import gym_super_mario_bros
# from nes_py.wrappers import BinarySpaceToDiscreteSpaceEnv
# from gym_super_mario_bros.actions import SIMPLE_MOVEMENT, COMPLEX_MOVEMENT

from torch.multiprocessing import Pipe, Process

from model import *
from config import *
from PIL import Image

# Silence RuntimeWarnings whose message starts with "overflow encountered in scalar subtract".
# NOTE(review): presumably emitted by fixed-width integer arithmetic on RAM-derived values — confirm.
warnings.filterwarnings("ignore", category=RuntimeWarning, message="overflow encountered in scalar subtract")


# Absolute path of the directory that contains this file.
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))

# Settings read from `default_config` (presumably supplied by the star import from `config` — TODO confirm).
train_method = default_config['TrainMethod']
# Per-episode step budget, converted to int.
max_step_per_episode = int(default_config['MaxStepPerEpisode'])

class Environment(Process):
    """Common interface for environment worker processes.

    Each hook below is marked with ``@abstractmethod`` and has an empty body
    that returns ``None``; concrete environments are expected to override them.

    NOTE(review): the class does not declare an ABC metaclass, so the abstract
    markers are likely advisory only — confirm before relying on enforcement.
    """

    @abstractmethod
    def run(self):
        """Entry point executed by the process; empty in the base class."""

    @abstractmethod
    def reset(self):
        """Start a new episode; empty in the base class."""

    @abstractmethod
    def pre_proc(self, x):
        """Pre-process a raw observation ``x``; empty in the base class."""

    @abstractmethod
    def get_init_state(self, x):
        """Initialise the state from observation ``x``; empty in the base class."""


def unwrap(env):
    """Return the innermost environment reachable from ``env``.

    At each level an ``unwrapped`` attribute wins and is returned directly.
    Otherwise the search descends through ``env`` and then ``leg_env``; an
    object exposing none of these attributes is returned as-is.
    """
    current = env
    while True:
        if hasattr(current, "unwrapped"):
            return current.unwrapped
        for attr in ("env", "leg_env"):
            if hasattr(current, attr):
                # Descend one wrapper level and re-run the checks on it.
                current = getattr(current, attr)
                break
        else:
            # No known wrapper attribute left: this is the innermost object.
            return current


class MaxAndSkipEnv(gym.Wrapper):
    """Frame-skip wrapper that repeats an action and max-pools the last two frames."""

    def __init__(self, env, is_render, skip=4):
        """Return only every `skip`-th frame"""
        super().__init__(env)
        self.is_render = is_render  # stored only; not consulted by this wrapper
        self._skip = skip
        # Two-slot buffer holding the most recent raw observations, used for
        # the element-wise max across time steps.
        self._obs_buffer = np.zeros((2, *env.observation_space.shape), dtype=np.uint8)

    def step(self, action):
        """Repeat action, sum reward, and max over last observations."""
        reward_sum = 0.0
        done = None
        first_kept = self._skip - 2  # index of the second-to-last repeated frame
        for frame_idx in range(self._skip):
            obs, reward, done, info = self.env.step(action)
            # Only the final two frames of the skip window are buffered:
            # slot 0 <- second-to-last frame, slot 1 <- last frame.
            slot = frame_idx - first_kept
            if slot in (0, 1):
                self._obs_buffer[slot] = obs
            reward_sum += reward
            if done:
                break
        # The buffer is not cleared between calls, so after an early `done`
        # it may still hold frames from a previous call; that observation is
        # treated as irrelevant.
        pooled = self._obs_buffer.max(axis=0)
        return pooled, reward_sum, done, info

    def reset(self, **kwargs):
        """Forward ``reset`` (and any keyword arguments) to the wrapped env."""
        return self.env.reset(**kwargs)


class MontezumaInfoWrapper(gym.Wrapper):
    """Gym wrapper that tracks room-level information for Montezuma's Revenge.

    The current room is read from emulator RAM at ``room_address`` and exposed
    as ``info['current_room']`` on every step.  When
    ``should_calc_additional_metrics`` is set, per-room counters (time spent,
    survival time, visits, lives lost, rooms "done") are added to ``info``.
    The wrapper can also snapshot emulator states per room (``room_saving``)
    and start episodes from such snapshots (``use_state_loading``).

    The wrapped env must provide ``labels()`` and put a ``'labels'`` dict into
    ``info`` with the keys ``room_number``, ``player_x``, ``player_y`` and
    ``items_in_inventory_count`` (all read by this class).
    """

    # Number of rooms the per-room counters are pre-populated for.
    _NUM_ROOMS = 24

    def __init__(self, env, room_address, use_state_loading=False, load_room=10, room_saving=False, should_calc_additional_metrics=False):
        """Wrap ``env``.

        Args:
            env: Environment to wrap.
            room_address: RAM index holding the current room number.
            use_state_loading: If true, ``reset`` restores a saved emulator state.
            load_room: Room number whose saved states are candidates for loading.
            room_saving: If true, snapshot the emulator state for every new
                (room, lives, return) combination encountered.
            should_calc_additional_metrics: If true, ``step`` adds the extended
                per-room metrics to ``info``.
        """
        super().__init__(env)
        self.room_address = room_address
        self.visited_rooms = defaultdict(int)
        self.use_state_loading = use_state_loading
        self.load_room = load_room
        self.room_saving = room_saving

        self.cumulative_reward = 0
        self.room_visitation = None
        self.rooms_survived = None
        self.rooms_lives_lost = None
        self.prev_room = None
        self.current_room_survive = None
        self.prev_lives = unwrap(self.env).ale.lives()
        self.rooms_done = None

        # Snapshot taken when the current room was entered.
        self.prev_inventory = None
        self.prev_position = None
        self.prev_cumulative_reward = 0

        # Snapshot taken on the most recent step spent in the current room.
        self.step_before_inventory = None
        self.step_before_position = None
        self.step_before_cumulative_reward = 0

        self.possible_loading_states = self.possible_states()
        self.reset_metrics()
        self.should_calc_additional_metrics = should_calc_additional_metrics

    @staticmethod
    def _states_dir():
        """Return the directory used for both saving and loading snapshots."""
        return os.path.join(ROOT_DIR, "game_states")

    @staticmethod
    def is_within_radius(current_pos, prev_pos, radius=5):
        """Return whether two player positions are at most ``radius`` apart.

        Args:
            current_pos: Mapping with ``'player_x'`` and ``'player_y'``.
            prev_pos: Mapping with ``'player_x'`` and ``'player_y'``.
            radius: Maximum Euclidean distance counted as "within".
        """
        # Coordinates may arrive as fixed-width unsigned integers (RAM bytes);
        # subtracting those directly wraps around (e.g. 5 - 10 -> 251) and
        # yields a bogus distance, so convert to plain floats first.
        dx = float(current_pos['player_x']) - float(prev_pos['player_x'])
        dy = float(current_pos['player_y']) - float(prev_pos['player_y'])
        return math.hypot(dx, dy) <= radius

    def _check_if_room_done(self):
        """Heuristically decide whether the room just left counts as "done".

        True if, between entering the room and the last step inside it, the
        cumulative reward changed, the player moved outside the radius of the
        entry position, or the inventory count changed.
        """
        if self.step_before_cumulative_reward != self.prev_cumulative_reward:
            return True
        if not self.is_within_radius(self.prev_position, self.step_before_position):
            return True
        if self.prev_inventory != self.step_before_inventory:
            return True
        return False

    def _update_step_before_info(self, info):
        """Record inventory, position and return as of the latest in-room step."""
        labels = info['labels']
        self.step_before_inventory = labels['items_in_inventory_count']
        self.step_before_position = {'player_x': labels['player_x'], 'player_y': labels['player_y']}
        self.step_before_cumulative_reward = self.cumulative_reward

    def _update_new_room_info(self, info):
        """Record inventory, position and return at the moment a room is entered."""
        labels = info['labels']
        self.prev_inventory = labels['items_in_inventory_count']
        self.prev_position = {'player_x': labels['player_x'], 'player_y': labels['player_y']}
        self.prev_cumulative_reward = self.cumulative_reward

    def possible_states(self):
        """Return paths of saved states whose file name contains ``room_<load_room>_``.

        Returns an empty list when the snapshot directory does not exist, so
        that constructing the wrapper does not require it.
        """
        states_path = self._states_dir()
        if not os.path.isdir(states_path):
            return []
        return [
            os.path.join(states_path, elem)
            for elem in os.listdir(states_path)
            if f"room_{self.load_room}_" in elem
        ]

    def reset_metrics(self):
        """Re-initialise all per-episode trackers from the env's current labels."""
        labels = self.labels()
        info = {"labels": labels}
        self._update_new_room_info(info)
        self._update_step_before_info(info)

        self.prev_room = labels["room_number"]

        self.visited_rooms.clear()
        self.current_room_survive = 0
        self.rooms_survived = {i: 0 for i in range(self._NUM_ROOMS)}
        self.room_visitation = {i: 0 for i in range(self._NUM_ROOMS)}
        self.rooms_lives_lost = {i: 0 for i in range(self._NUM_ROOMS)}
        self.rooms_done = {i: 0 for i in range(self._NUM_ROOMS)}
        self.cumulative_reward = 0
        self.prev_lives = unwrap(self.env).ale.lives()

    def _save_room_state(self, room):
        """Pickle the emulator state once per (room, lives, return) combination."""
        lives = unwrap(self.env).ale.lives()
        states_path = self._states_dir()
        # Written to the same directory `possible_states` reads from, so saved
        # snapshots can actually be found again regardless of the working dir.
        state_file_path = os.path.join(
            states_path, f"room_{room}_lives_{lives}_return_{self.cumulative_reward}.pkl"
        )
        if os.path.exists(state_file_path):
            return
        os.makedirs(states_path, exist_ok=True)
        save_room = self.env.clone_state()
        with open(state_file_path, "wb") as state_file:
            pickle.dump(save_room, state_file)

    def get_current_room(self):
        """Read the current room number from RAM, snapshotting it if enabled."""
        ram = unwrap(self.env).ale.getRAM()
        assert len(ram) == 128
        room = int(ram[self.room_address])

        if self.room_saving:
            self._save_room_state(room)
        return room

    def step(self, action):
        """Step the wrapped env and add room information to ``info``.

        Trackers are reset as soon as the wrapped env reports ``done``; the
        ``info`` returned for that step still carries the final values.
        """
        obs, rew, done, info = self.env.step(action)

        if self.should_calc_additional_metrics:
            self._calc_additional_metrics(done, info, rew)
        else:
            info['current_room'] = self.get_current_room()

        if done:
            self.reset_metrics()
        return obs, rew, done, info

    def _calc_additional_metrics(self, done, info, rew):
        """Update the per-room trackers and publish copies of them in ``info``."""
        self.cumulative_reward += rew
        current_room = self.get_current_room()
        self.visited_rooms[current_room] += 1  # Increment the time spent in the current room
        current_lives = info['lives']
        if current_room == self.prev_room:
            if done:
                self.rooms_survived[current_room] = self.current_room_survive
            else:
                self.current_room_survive += 1
            self._update_step_before_info(info)

        else:  # There is probably no chance of dying just after entering the new room
            if self._check_if_room_done():
                self.rooms_done[self.prev_room] += 1
            self._update_new_room_info(info)
            self.room_visitation[current_room] += 1
            self.prev_room = current_room
            self.current_room_survive = 0
            self.rooms_survived[self.prev_room] = 0
        if current_lives != self.prev_lives:
            self.rooms_lives_lost[current_room] += 1
            self.prev_lives = current_lives
        if 'episode' not in info:
            info['episode'] = {}
        # Copies are handed out so later tracker updates do not mutate `info`.
        info['episode'].update(visited_rooms=set(self.visited_rooms))
        info['episode'].update(visited_rooms_full=copy(self.visited_rooms))
        info['current_room'] = current_room
        info['rooms_survived'] = copy(self.rooms_survived)
        info['room_visitation'] = copy(self.room_visitation)
        info['lives_lost'] = copy(self.rooms_lives_lost)
        info['rooms_done'] = copy(self.rooms_done)
        info['inventory'] = info['labels']['items_in_inventory_count']

    def reset(self, load_state_path: str = None):
        """Reset the env and the trackers, optionally restoring a saved state.

        Args:
            load_state_path: Snapshot to restore when state loading is enabled;
                a random candidate from ``possible_loading_states`` is used
                when omitted.

        Returns:
            The first observation of the new episode.

        Raises:
            FileNotFoundError: State loading is enabled, no path was given and
                no candidate snapshot exists.
        """
        obs = self.env.reset()
        if not self.use_state_loading:
            # Trackers must be cleared here as well: a caller may end an
            # episode itself (e.g. on a step limit) without this wrapper ever
            # seeing `done`, which would otherwise leak metrics across episodes.
            self.reset_metrics()
            return obs
        if load_state_path is None:
            if not self.possible_loading_states:
                raise FileNotFoundError(
                    f"No saved game state for room {self.load_room} in {self._states_dir()}"
                )
            load_state_path = random.choice(self.possible_loading_states)
        print(f"Loading game state: {load_state_path}")
        # SECURITY: pickle executes arbitrary code on load; only restore
        # snapshot files that come from a trusted source.
        with open(load_state_path, "rb") as state_file:
            load_state = pickle.load(state_file)
        unwrap(self.env).ale.restoreState(load_state)
        self.reset_metrics()
        return unwrap(self.env)._get_obs()

class AtariEnvironment(Environment):
    """Worker process that drives one Atari env and serves stacked frames.

    Observations are converted to grayscale, resized to ``h`` x ``w`` and kept
    in a rolling buffer of ``history_size`` frames.  In ``run`` the process
    receives actions over ``child_conn`` and sends back transitions; ``step``
    offers the same logic for in-process use.
    """

    def __init__(
            self,
            env_id,
            is_render,
            env_idx,
            child_conn,
            history_size=4,
            h=84,
            w=84,
            life_done=True,
            sticky_action=True,
            p=0.25,
            writer=None,
            use_state_loading=False,
            load_room=10,
            room_saving=True,
            should_calc_additional_metrics=False):
        """Build the wrapped env and reset it once.

        Args:
            env_id: Gym environment id.
            is_render: If true, the env is created with ``render_mode='human'``.
            env_idx: Index of this worker; extended metrics and writer logging
                are only active for index 0.
            child_conn: Connection used by ``run`` to receive actions and send
                transitions.
            history_size: Number of stacked frames.
            h, w: Height and width of the pre-processed frames.
            life_done: Accepted for interface compatibility; not used here.
            sticky_action: If true, repeat the previous action with probability ``p``.
            p: Sticky-action probability.
            writer: Optional object with an ``add_scalar(tag, value, step)`` method.
            use_state_loading, load_room, room_saving,
            should_calc_additional_metrics: Forwarded to ``MontezumaInfoWrapper``
                for Montezuma environments.
        """
        super(AtariEnvironment, self).__init__()
        self.writer = writer
        self.daemon = True
        self.env = MaxAndSkipEnv(gym.make(env_id, render_mode='human' if is_render else None), is_render)
        if "Montezuma" in env_id:
            self.env = MontezumaInfoWrapper(
                AtariARIWrapper(self.env),
                room_address=3,  # RAM index the wrapper reads the room number from
                use_state_loading=use_state_loading,
                load_room=load_room,
                room_saving=room_saving,
                should_calc_additional_metrics=should_calc_additional_metrics if env_idx == 0 else False
            )
        self.env_id = env_id
        self.is_render = is_render
        self.env_idx = env_idx
        self.steps = 0
        self.episode = 0
        self.rall = 0
        self.recent_rlist = deque(maxlen=100)
        self.child_conn = child_conn

        self.sticky_action = sticky_action
        self.last_action = 0
        self.p = p

        self.history_size = history_size
        self.history = np.zeros([history_size, h, w])
        self.h = h
        self.w = w

        self.reset()

    def _advance(self, action):
        """Apply one action and update history, return and step counters.

        Shared by ``run`` and ``step``.  Returns ``(reward, done, info)`` where
        ``done`` is also forced to true once the step budget is exceeded.
        """
        if 'Breakout' in self.env_id:
            action += 1

        # Sticky action: with probability p keep repeating the previous action.
        if self.sticky_action:
            if np.random.rand() <= self.p:
                action = self.last_action
            self.last_action = action

        s, reward, done, info = self.env.step(action)

        if max_step_per_episode < self.steps:
            done = True

        # Shift the frame stack by one and append the newest frame; written
        # with relative indices so any history_size works, not just 4.
        self.history[:-1, :, :] = self.history[1:, :, :]
        self.history[-1, :, :] = self.pre_proc(s)

        self.rall += reward
        self.steps += 1
        return reward, done, info

    def _log_episode_metrics(self, info):
        """Write the per-room episode metrics found in ``info`` to the writer."""
        rooms_dict = info.get('episode', {}).get('visited_rooms_full', {})
        rooms_survived = info.get('rooms_survived', {})
        rooms_visitation = info.get('room_visitation', {})
        lives_lost = info.get('lives_lost', {})
        rooms_done = info.get('rooms_done', {})

        for room, survived_time in rooms_survived.items():
            self.writer.add_scalar(f"time_spent/{room}", rooms_dict.get(room, 0), self.episode)
            self.writer.add_scalar(f"survived/{room}", survived_time, self.episode)
            self.writer.add_scalar(f"visitation/{room}", rooms_visitation.get(room, 0), self.episode)
            self.writer.add_scalar(f"lives_lost/{room}", lives_lost.get(room, 0), self.episode)
            self.writer.add_scalar(f"rooms_done/{room}", rooms_done.get(room, 0), self.episode)

        # 'inventory' is a single count, not a per-room mapping, so it is
        # logged once (previously the 'rooms_done' values were logged under
        # the inventory tag by mistake).
        inventory = info.get('inventory')
        if inventory is not None:
            self.writer.add_scalar("inventory/items_count", inventory, self.episode)

    def run(self):
        """Serve transitions over ``child_conn`` forever.

        For every received action one message is sent back:
        ``[history, reward, force_done, done, log_reward, current_room]``.
        ``current_room`` is ``None`` when the env does not report a room.
        """
        super(AtariEnvironment, self).run()
        while True:
            action = self.child_conn.recv()
            reward, done, info = self._advance(action)

            log_reward = reward
            force_done = done

            if done:
                self.recent_rlist.append(self.rall)
                print("[Episode {}({})] Step: {}  Reward: {}  Recent Reward: {}  Visited Room: [{}]".format(
                    self.episode, self.env_idx, self.steps, self.rall, np.mean(self.recent_rlist),
                    info.get('episode', {}).get('visited_rooms', {})))
                if self.writer and self.env_idx == 0:
                    self._log_episode_metrics(info)

                self.history = self.reset()

            # Only the Montezuma wrapper adds 'current_room'; use .get so
            # other games do not fail with KeyError.
            self.child_conn.send(
                [self.history[:, :, :], reward, force_done, done, log_reward, info.get('current_room')])

    def step(self, action):
        """Apply ``action`` in-process.

        Returns:
            ``(history, reward, force_done, done, log_reward, info)``; the
            history is already that of the next episode when ``done`` is true.
        """
        reward, done, info = self._advance(action)

        log_reward = reward
        force_done = done

        if done:
            self.recent_rlist.append(self.rall)
            print(
                f"[Episode {self.episode}] Step: {self.steps}  Reward: {self.rall}  Recent Reward: {np.mean(self.recent_rlist)}"
            )

            self.history = self.reset()

        return self.history[:, :, :], reward, force_done, done, log_reward, info

    def reset(self):
        """Start a new episode and return the freshly initialised frame stack."""
        self.last_action = 0
        self.steps = 0
        self.episode += 1
        self.rall = 0
        s = self.env.reset()
        # Pass the raw frame: get_init_state does the pre-processing itself,
        # so pre-processing here as well would process the frame twice.
        self.get_init_state(s)
        return self.history[:, :, :]

    def pre_proc(self, X):
        """Convert a frame to float32 grayscale and resize it to ``h`` x ``w``."""
        X = np.array(Image.fromarray(X).convert('L')).astype('float32')
        # cv2.resize takes dsize as (width, height); the result then has shape
        # (h, w), matching the history buffer.
        x = cv2.resize(X, (self.w, self.h))
        return x

    def get_init_state(self, s):
        """Fill every slot of the frame stack with the pre-processed frame ``s``."""
        frame = self.pre_proc(s)
        for i in range(self.history_size):
            self.history[i, :, :] = frame


# class MarioEnvironment(Process):
#     def __init__(
#             self,
#             env_id,
#             is_render,
#             env_idx,
#             child_conn,
#             history_size=4,
#             life_done=False,
#             h=84,
#             w=84, movement=COMPLEX_MOVEMENT, sticky_action=True,
#             p=0.25):
#         super(MarioEnvironment, self).__init__()
#         self.daemon = True
#         self.env = BinarySpaceToDiscreteSpaceEnv(
#             gym_super_mario_bros.make(env_id), COMPLEX_MOVEMENT)
#
#         self.is_render = is_render
#         self.env_idx = env_idx
#         self.steps = 0
#         self.episode = 0
#         self.rall = 0
#         self.recent_rlist = deque(maxlen=100)
#         self.child_conn = child_conn
#
#         self.life_done = life_done
#         self.sticky_action = sticky_action
#         self.last_action = 0
#         self.p = p
#
#         self.history_size = history_size
#         self.history = np.zeros([history_size, h, w])
#         self.h = h
#         self.w = w
#
#         self.reset()
#
#     def run(self):
#         super(MarioEnvironment, self).run()
#         while True:
#             action = self.child_conn.recv()
#             if self.is_render:
#                 self.env.render()
#
#             # sticky action
#             if self.sticky_action:
#                 if np.random.rand() <= self.p:
#                     action = self.last_action
#                 self.last_action = action
#
#             # 4 frame skip
#             reward = 0.0
#             done = None
#             for i in range(4):
#                 obs, r, done, info = self.env.step(action)
#                 if self.is_render:
#                     self.env.render()
#                 reward += r
#                 if done:
#                     break
#
#             # when Mario loses life, changes the state to the terminal
#             # state.
#             if self.life_done:
#                 if self.lives > info['life'] and info['life'] > 0:
#                     force_done = True
#                     self.lives = info['life']
#                 else:
#                     force_done = done
#                     self.lives = info['life']
#             else:
#                 force_done = done
#
#             # reward range -15 ~ 15
#             log_reward = reward / 15
#             self.rall += log_reward
#
#             r = int(info.get('flag_get', False))
#
#             self.history[:3, :, :] = self.history[1:, :, :]
#             self.history[3, :, :] = self.pre_proc(obs)
#
#             self.steps += 1
#
#             if done:
#                 self.recent_rlist.append(self.rall)
#                 print(
#                     "[Episode {}({})] Step: {}  Reward: {}  Recent Reward: {}  Stage: {} current x:{}   max x:{}".format(
#                         self.episode,
#                         self.env_idx,
#                         self.steps,
#                         self.rall,
#                         np.mean(
#                             self.recent_rlist),
#                         info['stage'],
#                         info['x_pos'],
#                         self.max_pos))
#
#                 self.history = self.reset()
#
#             self.child_conn.send([self.history[:, :, :], r, force_done, done, log_reward])
#
#     def reset(self):
#         self.last_action = 0
#         self.steps = 0
#         self.episode += 1
#         self.rall = 0
#         self.lives = 3
#         self.stage = 1
#         self.max_pos = 0
#         self.get_init_state(self.env.reset())
#         return self.history[:, :, :]
#
#     def pre_proc(self, X):
#         # grayscaling
#         x = cv2.cvtColor(X, cv2.COLOR_RGB2GRAY)
#         # resize
#         x = cv2.resize(x, (self.h, self.w))
#
#         return x
#
#     def get_init_state(self, s):
#         for i in range(self.history_size):
#             self.history[i, :, :] = self.pre_proc(s)
