import cv2
import gym
import matplotlib.pyplot as plt
import numpy as np
import pickle
import PIL
import torch
import torchvision.transforms
import pickle

from .. import ramops

# Module-level torchvision converters shared by the code below.
# convertToPIL is used by AtariEnv.deprocess (CxHxW uint8 tensor -> PIL image).
# convertToTensor is not referenced in the code shown here -- presumably kept
# for observation preprocessing elsewhere; verify before removing.
convertToTensor = torchvision.transforms.ToTensor()
convertToPIL = torchvision.transforms.ToPILImage()

class AtariEnv:
    """Wrapped Gym Atari environment.

    Builds a ``<Game>NoFrameskip-v4`` env with ``gym.make`` and adds helpers
    for reading frames and RAM directly from the ALE emulator
    (``self.env.env.ale``), plus saving/restoring emulator state. Attribute
    lookups that fail on this object are forwarded to the wrapped env.

    NOTE(review): ``step`` unpacks a 4-tuple and ``render`` passes ``mode=``,
    which matches the pre-0.26 gym step/render API -- confirm against the
    installed gym version.
    """

    def __init__(self, name='', seed=None, render_mode="rgb_array"):
        """Create the environment and reset it.

        Args:
            name: short game key; one of 'breakout', 'freeway', 'montezuma',
                'pacman', 'pong'.
            seed: optional seed, forwarded to ``reset``.
            render_mode: forwarded to ``gym.make``.

        Raises:
            AssertionError: if ``name`` is not one of the known game keys.
        """
        envs = {
            'breakout': 'Breakout',
            'freeway': 'Freeway',
            'montezuma': 'MontezumaRevenge',
            'pacman': 'MsPacman',
            'pong': 'Pong',
        }
        assert name in envs, 'unknown game {!r}; expected one of {}'.format(name, sorted(envs))
        self.env = gym.make('{}NoFrameskip-v4'.format(envs[name]), render_mode=render_mode)
        self.action_space = self.env.action_space
        self.observation_space = self.env.observation_space
        self.n_actions = self.env.action_space.n
        self.reset(seed=seed)

    def __getattr__(self, name):
        """Forward attribute lookups that fail on this object to the wrapped env."""
        # NOTE: Forwarding to underlying env's attributes as this is not a proper wrapper.
        # Fetch ``env`` without going through ``self.env``: if ``env`` is not set yet
        # (e.g. during __init__ or unpickling), plain attribute access would
        # re-enter __getattr__ and recurse forever.
        internal_env = self.__class__.__getattribute__(self, 'env')
        # Use getattr() rather than internal_env.__getattribute__(name): the
        # latter skips the wrapped object's own __getattr__ (which gym wrappers
        # use to forward to their inner env), so attributes living further down
        # the wrapper chain were unreachable.
        return getattr(internal_env, name)

    @property
    def unwrapped(self):
        """The innermost (unwrapped) gym environment."""
        return self.env.unwrapped

    def reset(self, seed=None):
        """Reset the episode and this wrapper's bookkeeping.

        Args:
            seed: if not None, passed to the underlying ``env.reset`` so the
                episode that starts now is the seeded one.
        """
        # Previously env.seed(seed) was called *after* env.reset(), so the
        # reset itself was never seeded. Passing the seed to reset() fixes the
        # ordering and avoids the deprecated Env.seed() API.
        if seed is None:
            self.env.reset()
        else:
            self.env.reset(seed=seed)
        self._previous_frame = None
        self.timestep = 0

    def close(self, *args, **kwargs):
        """Close the wrapped env (arguments are passed through)."""
        return self.env.close(*args, **kwargs)

    def step(self, *args, **kwargs):
        """Advance the env one step and return its ``(ob, reward, done, info)``."""
        self.timestep += 1
        ob, reward, done, info = self.env.step(*args, **kwargs)
        return ob, reward, done, info

    def deprocess(self, img):
        """Undo the tensor conversion process to get back to an 8-bit PIL image.

        Assumes ``img`` is a CxHxW tensor/array scaled to [-1, 1] -- TODO
        confirm against the preprocessing that produced it.
        """
        # torch.as_tensor avoids the copy-construct warning that torch.tensor()
        # emits when handed an existing tensor; detach/cpu let the PIL
        # conversion work on grad-tracking or GPU tensors.
        img = torch.as_tensor(img).detach().cpu()
        # Map [-1, 1] -> [0, 255]. Clamp before the uint8 cast so small
        # excursions outside [-1, 1] saturate instead of wrapping around.
        img = ((img + 1) / 2 * 255).clamp(0, 255).to(torch.uint8)
        img = convertToPIL(img)  # CxHxW [0,255] => HxWxC [0,255]
        # Uncomment if you want to upscale back to original size
        # img = cv2.resize(np.array(img), (160,210), interpolation=cv2.INTER_LINEAR)
        # img = PIL.Image.fromarray(img) # Not required, but helps pyplot get the colors right
        return img

    def render(self, mode='human', **kwargs):
        """Draw the current frame via the wrapped env's ``render``.

        NOTE(review): gym >= 0.26 ``render`` takes no ``mode`` argument (the
        mode is fixed by ``render_mode`` at construction) -- confirm version.
        """
        return self.env.render(mode=mode, **kwargs)

    def getGrayscaleFrame(self):
        """Return the current screen as a (210, 160, 1) uint8 grayscale array."""
        # ALE fills the caller-provided buffer in place. The 210x160 size is
        # hard-coded; assumes every supported game uses that screen size.
        img_gray = np.empty([210, 160, 1], dtype=np.uint8)
        self.env.env.ale.getScreenGrayscale(img_gray)
        return img_gray

    def getRGBFrame(self):
        """Return the current screen as a (210, 160, 3) uint8 RGB array."""
        img_rgb = np.empty([210, 160, 3], dtype=np.uint8)
        self.env.env.ale.getScreenRGB(img_rgb)
        return img_rgb

    def getFrame(self):
        """Get the current frame as an RGB array."""
        return self.getRGBFrame()

    def getRAM(self):
        """Return the emulator's RAM as reported by ``ale.getRAM()``."""
        return self.env.env.ale.getRAM()

    def parseRAM(self, ram):
        """Decode a RAM snapshot into a game-specific state dict.

        The generic wrapper has no RAM layout; presumably a game-specific
        subclass overrides this.

        Raises:
            NotImplementedError: always, in this base implementation.
        """
        # ``NotImplemented`` is a comparison sentinel, not an exception class;
        # raising it produced a TypeError instead of the intended signal.
        raise NotImplementedError

    def getState(self):
        """Read the RAM and decode it with ``parseRAM``."""
        ram = self.getRAM()
        state = self.parseRAM(ram)
        return state

    def printRam(self):
        """Print the current contents of the RAM."""
        print(self.getRAM())

    def setRAM(self, ramIndex, value):
        """Set RAM byte ``ramIndex`` to ``value`` by patching the encoded state.

        The serialized emulator state is cloned, patched, and restored.
        """
        state = self.env.env.ale.cloneState()
        coded = self.env.env.ale.encodeState(state)
        # The RAM section of the underlying state appears to be an array of
        # integers being used to represent an array of bytes. This means that
        # only every 4th element actually does anything; the rest are
        # meaningless 0's.
        arrIndex = ramIndex * 4
        # The RAM portion of the underlying state is offset 155 elements from
        # the beginning.
        arrIndex += 155
        coded[arrIndex] = value
        state2 = self.env.env.ale.decodeState(coded)
        self.env.env.ale.restoreState(state2)

    def save(self, filePath):
        """Pickle the emulator's encoded state to ``filePath``."""
        state = self.env.env.ale.cloneState()
        coded = self.env.env.ale.encodeState(state)
        # ``with`` closes (and flushes) the file; it was previously leaked.
        with open(filePath, "wb") as f:
            pickle.dump(coded, f)

    def load(self, filePath):
        """Restore emulator state previously written by ``save``.

        Security: uses ``pickle.load``, which can execute arbitrary code --
        only load files you trust.
        """
        with open(filePath, "rb") as f:
            coded = pickle.load(f)
        state = self.env.env.ale.decodeState(coded)
        self.env.env.ale.restoreState(state)

def testGetFrame():
    """Check that getFrame() tracks the raw observation returned by step().

    For Freeway the screen changes on the first step, so the pre-step frame
    must differ from the observation. AtariEnv.step returns the env's
    observation unmodified, so the post-step frame must equal it.
    """
    env = AtariEnv(name='freeway')
    try:
        frame0 = env.getFrame()
        ob, _, _, _ = env.step(0)
        frame1 = env.getFrame()
        # The screen (for freeway) should have changed after the first env.step()
        assert np.any(ob != frame0)
        # But with no preprocessing, ob should be equivalent to the next getFrame()
        assert np.all(ob == frame1)
    finally:
        # Release the emulator even when an assertion fails (it was never closed).
        env.close()

def testPreprocess():
    """Check observation preprocessing: tensor type, channels-first layout,
    and a retained channel axis for grayscale.

    NOTE(review): this relies on an ``ob_shape`` attribute and a
    ``grayscale=`` constructor argument that AtariEnv in this module does not
    define itself (``ob_shape`` lookups would fall through to the wrapped gym
    env), and AtariEnv.step returns the raw env observation. The test looks
    stale against the current wrapper -- confirm before enabling it in main().
    """
    color_env = AtariEnv(name='pacman')
    color_ob, _, _, _ = color_env.step(0)
    # Observation must already be a torch tensor...
    assert type(color_ob) is torch.Tensor
    # ...laid out channels-first: (C, H, W) with ob_shape given as (W, H).
    assert color_ob.size() == (3, color_env.ob_shape[1], color_env.ob_shape[0])
    # Visual sanity check, if wanted: plt.imshow(color_env.deprocess(color_ob)); plt.show()

    gray_env = AtariEnv(name='pacman', grayscale=True)
    gray_ob, _, _, _ = gray_env.step(action=0)
    # Grayscale still keeps an explicit (singleton) channel axis.
    assert gray_ob.size() == (1, gray_env.ob_shape[1], gray_env.ob_shape[0])
    # Visual check that it is gray: plt.imshow(gray_env.deprocess(gray_ob)); plt.show()

def testSetRam():
    """Check that ramops.setByte on byte "d5" takes effect after one step.

    A second, control episode without the write confirms that the value 234
    is not reached by chance. NOTE(review): this covers the same scenario as
    testSetRAM in this module.
    """
    # ``action`` was never defined in this function, so the first
    # env.step(action) raised NameError; use action 0 as testSetRAM does.
    action = 0

    env = AtariEnv(name='pacman')
    env.step(0)

    assert (ramops.getByte(env.getRAM(), "d5") != 234)
    ramops.setByte(env, "d5", 234)
    # the ram doesn't change until we step
    assert (ramops.getByte(env.getRAM(), "d5") != 234)
    ob, reward, done, _ = env.step(action)
    # after one step, it should be set how we want.
    assert (ramops.getByte(env.getRAM(), "d5") == 234)
    env.close()

    # Repeat the same thing *without* setting the ram byte, and make sure it
    # doesn't just happen to equal that value anyways
    env = AtariEnv(name='pacman')
    env.step(0)
    assert (ramops.getByte(env.getRAM(), "d5") != 234)
    ob, reward, done, _ = env.step(action)
    assert (ramops.getByte(env.getRAM(), "d5") != 234)
    env.close()

def testLoad():
    """Restore MsPacman state from "savefile.txt" and compare the RAM.

    NOTE(review): testLoad is defined twice in this module with the same
    body; the later definition is the one bound at import time.
    """
    # Expected 128-byte RAM contents one step after loading the saved state.
    desiredResult = np.array([
        0, 34, 113, 113, 51, 3, 110, 94, 82, 45, 42, 0, 57, 66, 69, 98, 98, 0, 0, 3, 0, 0, 1, 0, 0,
        2, 4, 5, 162, 8, 207, 0, 45, 1, 0, 51, 38, 232, 0, 75, 0, 49, 37, 4, 0, 120, 0, 82, 111,
        130, 0, 134, 1, 222, 0, 1, 3, 0, 6, 80, 255, 255, 0, 255, 255, 80, 255, 255, 80, 255, 255,
        80, 255, 255, 80, 191, 191, 80, 191, 191, 80, 191, 191, 80, 245, 170, 80, 255, 255, 80,
        255, 255, 80, 255, 255, 0, 255, 255, 80, 255, 255, 20, 223, 43, 217, 91, 217, 123, 217,
        123, 217, 123, 217, 123, 217, 221, 0, 63, 0, 6, 96, 0, 0, 2, 21, 251, 146, 215
    ])
    env = AtariEnv(name='pacman')
    try:
        env.load("savefile.txt")
        env.step(0)
        curRam = env.getRAM()
        # Checks the length as well as every value and reports all mismatching
        # indices; the old per-element loop ignored a too-short RAM and hit
        # IndexError on a too-long one.
        np.testing.assert_array_equal(curRam, desiredResult)
    finally:
        # The env was previously never closed.
        env.close()

def testGetState():
    """Sanity-check the parsed RAM state of a freshly created MsPacman env.

    NOTE(review): relies on AtariEnv.parseRAM, which the AtariEnv class in
    this module does not implement -- presumably a game-specific override
    supplies it; confirm before enabling this test.
    """
    env = AtariEnv(name='pacman')
    state = env.getState()

    # Starting screen index.
    assert state['screen'] == 0

    # All ghosts and Pac-Man share one starting x coordinate.
    assert state['orange_x'] == state['blue_x'] == state['pink_x'] == state['red_x'] == state['pac_x']
    # Orange, blue, and pink share a y; vertical ordering is pac > pink > red.
    assert state['orange_y'] == state['blue_y'] == state['pink_y']
    assert state['pac_y'] > state['pink_y'] > state['red_y']

    # Missing dots: the 3x2 center block plus the 4 power pellets.
    dots = state['dots']
    assert np.all(dots[5:8, 8:10] == False)
    assert np.sum(dots == False) == (6 + 4)

    # No ghost is edible; only the red one starts outside the jail.
    status_keys = [key for key in state.keys() if '_status' in key]
    for key in status_keys:
        ghost = state[key]
        assert ghost['edible'] == False
        expected_mode = 'free' if 'red' in key else 'jail'
        assert ghost['mode'] == expected_mode

def testSetRAM():
    """Verify a ramops.setByte write to byte "d5" appears only after the next
    step, and that the value 234 does not show up on its own without it."""
    def read_d5(e):
        # Current value of RAM byte "d5" for env ``e``.
        return ramops.getByte(e.getRAM(), "d5")

    # Scenario 1: write 234 into d5 and watch it land after one step.
    env = AtariEnv(name='pacman')
    env.step(0)
    action = 0
    assert read_d5(env) != 234
    ramops.setByte(env, "d5", 234)
    assert read_d5(env) != 234  # not visible until the emulator steps
    ob, reward, done, _ = env.step(action)
    assert read_d5(env) == 234
    env.close()

    # Scenario 2 (control): same steps without the write; 234 must not appear.
    env = AtariEnv(name='pacman')
    env.step(0)
    assert read_d5(env) != 234
    ob, reward, done, _ = env.step(action)
    assert read_d5(env) != 234
    env.close()

def testLoad():
    """Restore MsPacman state from "savefile.txt" and compare the RAM.

    NOTE(review): this redefinition shadows an earlier testLoad in this
    module that has the same body.
    """
    # Expected 128-byte RAM contents one step after loading the saved state.
    desiredResult = np.array([
        0, 34, 113, 113, 51, 3, 110, 94, 82, 45, 42, 0, 57, 66, 69, 98, 98, 0, 0, 3, 0, 0, 1, 0, 0,
        2, 4, 5, 162, 8, 207, 0, 45, 1, 0, 51, 38, 232, 0, 75, 0, 49, 37, 4, 0, 120, 0, 82, 111,
        130, 0, 134, 1, 222, 0, 1, 3, 0, 6, 80, 255, 255, 0, 255, 255, 80, 255, 255, 80, 255, 255,
        80, 255, 255, 80, 191, 191, 80, 191, 191, 80, 191, 191, 80, 245, 170, 80, 255, 255, 80,
        255, 255, 80, 255, 255, 0, 255, 255, 80, 255, 255, 20, 223, 43, 217, 91, 217, 123, 217,
        123, 217, 123, 217, 123, 217, 221, 0, 63, 0, 6, 96, 0, 0, 2, 21, 251, 146, 215
    ])
    env = AtariEnv(name='pacman')
    try:
        env.load("savefile.txt")
        env.step(0)
        curRam = env.getRAM()
        # Checks the length as well as every value and reports all mismatching
        # indices; the old per-element loop ignored a too-short RAM and hit
        # IndexError on a too-long one.
        np.testing.assert_array_equal(curRam, desiredResult)
    finally:
        # The env was previously never closed.
        env.close()

def main():
    """Run the currently enabled self-tests, then report completion."""
    # Disabled for now: testGetFrame, testPreprocess, testGetState, testSetRAM.
    # Add them to this tuple to run them.
    enabled_tests = (testLoad,)
    for run_test in enabled_tests:
        run_test()
    print('Testing complete.')

# Run the enabled self-tests when executed directly. Because of the relative
# `from .. import ramops` import, this must be run as part of its package
# (e.g. `python -m <package>.<module>`), not as a standalone script.
if __name__ == '__main__':
    main()
