from collections import deque

import atari_py
import random

import cv2 as cv
import numpy as np

from utils import logs_handler, process_image

# Human-readable names of the emulator actions. The position in the list is the
# integer action code: Env.__init__ looks names up with ACTIONS[act] for every
# code in the game's minimal action set.
ACTIONS = [
    # no movement / fire only
    'NOOP', 'FIRE',
    # plain movement
    'UP', 'RIGHT', 'LEFT', 'DOWN',
    'UPRIGHT', 'UPLEFT', 'DOWNRIGHT', 'DOWNLEFT',
    # movement combined with fire
    'UPFIRE', 'RIGHTFIRE', 'LEFTFIRE', 'DOWNFIRE',
    'UPRIGHTFIRE', 'UPLEFTFIRE', 'DOWNRIGHTFIRE', 'DOWNLEFTFIRE',
    # placeholder for anything else
    'UNKNOWN',
]

class Env:
    """Atari environment built on ``atari_py.ALEInterface``.

    On top of the raw emulator this wrapper provides:

    * frame skipping, taking a pixel-wise max over the last two skipped frames;
    * image preprocessing through ``process_image.cv_process`` and scaling;
    * optional reward clipping to the sign of the reward;
    * loss-of-life treated as a terminal signal while in training mode;
    * a sliding window (``state_buffer``) holding the most recent processed
      frames, returned stacked along the channel axis.
    """

    def __init__(self, game, img_size, grayscale=False, use_dynamic_range=False, channel_last=False,
                 frame_skip=4, buffer_size=None, max_episode_length=None, clip_reward=False, seed=42):
        """Configure the emulator, load the ROM and allocate the frame window.

        Args:
            game: Game name, resolved to a ROM path by ``atari_py.get_game_path``.
            img_size: Two-element spatial size of a processed frame; it is
                unpacked into the observation shape and forwarded to
                ``process_image.cv_process``.
            grayscale: Read grayscale screens (1 channel) instead of RGB (3).
            use_dynamic_range: Scale frames with ``process_image.normalize``
                instead of dividing by 255.
            channel_last: Put the channel axis last instead of first.
            frame_skip: Number of emulator steps per ``step`` call (> 0).
            buffer_size: Number of processed frames kept in the sliding
                window; a falsy value means 1.
            max_episode_length: Value for ALE's ``max_num_frames_per_episode``;
                a falsy value means 108000.
            clip_reward: Replace every emulator reward by its sign.
            seed: Value for ALE's ``random_seed`` option.
        """
        self.logger = logs_handler.get_logger(f'ale.env.{game}')
        assert frame_skip > 0, 'frame_skip must be greater than 0'
        self.game = game
        self.ale = atari_py.ALEInterface()
        self.ale.setInt('random_seed', seed)
        # 108e3 is a float literal; setInt is given a real integer.
        self.ale.setInt('max_num_frames_per_episode', int(max_episode_length or 108e3))
        self.ale.setFloat('repeat_action_probability', 0)  # Disable sticky actions
        self.ale.setInt('frame_skip', 0)  # Frame skipping is done in step()
        self.ale.setBool('color_averaging', False)
        self.ale.loadROM(atari_py.get_game_path(game))  # ROM loading must be done after setting options
        # getScreenDims() is reversed so it matches the leading axes of a screen array
        self.screen_dims = self.ale.getScreenDims()[::-1]
        self.actions_set = self.ale.getMinimalActionSet()
        # Maps each available action code to its readable name
        self.actions = {int(act): ACTIONS[act] for act in self.actions_set}
        self.lives = 0  # Life counter (used in DeepMind training)
        self.life_termination = False  # Used to check if resetting only from loss of life
        self.training = True  # Consistent with model training mode
        self.img_size = img_size
        self.num_channels = 1 if grayscale else 3
        self.grayscale = grayscale
        self.use_dynamic_range = use_dynamic_range
        self.channel_last = channel_last
        self.frame_skip = frame_skip
        self.clip_reward = clip_reward

        # [stacking frames + sliding window]
        self.buffer_size = buffer_size or 1
        self.state_buffer = deque([], maxlen=self.buffer_size)

        self.logger.info(f'Observation Shape: {self.observation_shape(from_buffer=True)}')
        self.logger.info(f'Actions: {self.actions_set}')

        # reset() and the life-loss path send raw action 0 as a no-op
        assert ACTIONS[0] in self.actions.values(), 'NOOP must be in the minimal action set'

    def isvalid_action(self, action):
        """Return True when ``action`` is a code from the minimal action set."""
        return action in self.actions

    def random_action(self):
        """Return one action code drawn with ``np.random.choice``, as an int."""
        return int(np.random.choice(self.actions_set, size=()))

    def observation_shape(self, from_buffer=True):
        """Return the shape of an observation as a tuple.

        Args:
            from_buffer: When True the channel depth is multiplied by
                ``buffer_size`` (stacked frames); otherwise it is the depth
                of a single processed frame.
        """
        depth = self.num_channels * self.buffer_size if from_buffer else self.num_channels
        if self.channel_last:
            return (*self.img_size, depth)
        return (depth, *self.img_size)

    def reset_buffer(self):
        """Fill the whole sliding window with all-zero float32 frames."""
        frame_shape = self.observation_shape(False)
        for _ in range(self.buffer_size):
            self.state_buffer.append(np.zeros(frame_shape, dtype=np.float32))

    def update_buffer(self, state):
        """Push ``state`` into the window; the deque drops the oldest frame."""
        self.state_buffer.append(state)

    def release_buffer(self):
        """Return the buffered frames stacked along the channel axis.

        The channel axis is the last one when ``channel_last`` is set and the
        first one otherwise, so the result has the shape reported by
        ``observation_shape(from_buffer=True)``.
        """
        channel_axis = -1 if self.channel_last else 0
        return np.concatenate(tuple(self.state_buffer), axis=channel_axis)

    def reset(self, from_buffer=True):
        """Start a new episode (or resume after a lost life) and return the first state.

        Args:
            from_buffer: Return the stacked window instead of the single
                processed frame.
        """
        if self.life_termination:
            self.life_termination = False  # Reset flag
            self.ale.act(0)  # Use a no-op after loss of life
        else:
            # Reset internals
            self.ale.reset_game()
            # Perform 0 to 29 random no-ops before starting
            for _ in range(random.randrange(30)):
                self.ale.act(0)  # Assumes raw action 0 is always no-op
                if self.ale.game_over():
                    self.ale.reset_game()
        # Process and return "initial" state: zero the window, then push the
        # first frame exactly once.
        self.reset_buffer()
        state = self.process_state(self.get_state())
        self.update_buffer(state)
        self.lives = self.ale.lives()
        if from_buffer:
            return self.release_buffer()
        return state

    def get_state(self):
        """Return the current raw emulator screen (grayscale or RGB)."""
        if self.grayscale:
            return self.ale.getScreenGrayscale()
        return self.ale.getScreenRGB()

    def process_state(self, state):
        """Preprocess a raw frame and return it as a scaled float32 array.

        The frame goes through ``process_image.cv_process`` (given
        ``img_size`` and ``channel_last``), is cast to float32 and then either
        normalised with ``process_image.normalize`` or divided by 255.
        """
        state = process_image.cv_process(state, self.img_size, channel_last=self.channel_last)
        state = state.astype('float32')
        if self.use_dynamic_range:
            state = process_image.normalize(state)
        else:
            state /= 255.0
        return state

    def process_reward(self, reward):
        """Return the sign of ``reward`` as a float if clipping is on, else ``reward``."""
        if self.clip_reward:
            reward = float(np.sign(reward))
        return reward

    def step(self, action, from_buffer=True):
        """Apply ``action`` for ``frame_skip`` emulator steps.

        Args:
            action: Action code passed to ``ale.act``.
            from_buffer: Return the stacked window instead of the single
                processed frame.

        Returns:
            ``(state, reward, done)``. ``reward`` is the sum of the (optionally
            clipped) per-frame rewards. In training mode ``done`` is also True
            when a life was lost and at least one life remains.
        """
        # Repeat action 4/k times, max pool over last 2 frames --> [frame-skip]
        # NOTE: if the game ends before the last two skipped frames are
        # reached, the corresponding slots stay zero.
        frame_buffer = np.zeros((2, *self.screen_dims, self.num_channels))
        reward, done = 0, False
        last, second_last = self.frame_skip - 1, self.frame_skip - 2
        for t in range(self.frame_skip):
            # Clipping (if enabled) is applied per emulator frame, before summing
            reward += self.process_reward(self.ale.act(action))
            if t == second_last:
                frame_buffer[0] = self.get_state()
            elif t == last:
                frame_buffer[1] = self.get_state()
            done = self.ale.game_over()
            if done:
                break
        if self.frame_skip == 1:
            # Only slot 1 is ever written when a single frame is played
            state = frame_buffer[1]
        else:
            state = frame_buffer.max(0)
        state = self.process_state(state)
        self.update_buffer(state)
        # Detect loss of life as terminal in training mode
        if self.training:
            lives = self.ale.lives()
            if lives < self.lives and lives > 0:  # Lives > 0 for Q*bert
                self.life_termination = not done  # Only set flag when not truly done
                done = True
            self.lives = lives
        # Return state, reward, done
        if from_buffer:
            state = self.release_buffer()
        return state, reward, done

    # Uses loss of life as terminal signal
    def train(self):
        """Switch to training mode."""
        self.training = True

    # Uses standard terminal signal
    def eval(self):
        """Switch to evaluation mode."""
        self.training = False

    def action_space(self):
        """Return the number of available actions."""
        return len(self.actions)

    def render(self):
        """Show the current RGB screen in an OpenCV window named 'screen'."""
        # Channel order is reversed (RGB -> BGR) before handing it to OpenCV
        cv.imshow('screen', self.ale.getScreenRGB()[:, :, ::-1])
        cv.waitKey(1)

    def close(self):
        """Destroy all OpenCV windows."""
        cv.destroyAllWindows()
