import torch
import numpy as np

class ImageWorld():
    """
    Environment whose states index into a stack of images.

    Every axis of ``images`` except the trailing three is a state dimension, and the
    observation for state ``(s_0, ..., s_{k-1})`` is ``images[s_0, ..., s_{k-1}]``
    (an array made of the trailing three axes).

    Each state dimension ``d`` has two actions: ``2*d`` moves coordinate ``d`` by
    ``+1`` and ``2*d + 1`` moves it by ``-1``. Coordinates wrap around modulo the
    size of their dimension.
    """

    class action_space():
        """Discrete action space holding ``n`` actions numbered ``0 .. n-1``."""

        def __init__(self, n_actions):
            self.n = n_actions

        def sample(self, k=1):
            """Return an int64 tensor of shape ``(k,)`` with actions drawn uniformly from ``[0, n)``."""
            return torch.randint(0, self.n, (k,))

    class observation_space():
        """Placeholder observation space; ``shape`` is left as an empty list."""

        def __init__(self):
            self.shape = []

    def __init__(self, images: np.ndarray):
        """
        :param images: array whose leading ``images.ndim - 3`` axes are state dimensions
            and whose trailing three axes form one observation.
        """
        self.images = images
        # Sizes of the state dimensions (all axes except the trailing three).
        self.state_space = self.images.shape[:-3]
        # Two actions (+1 / -1) per state dimension. Assigning on the instance shadows
        # the nested class only for this instance.
        self.action_space = self.action_space(2 * (images.ndim - 3))
        self.observation_space = self.observation_space()

    def reset(self, state_init=None):
        """
        Set the current state and return its observation.

        :param state_init: optional sequence with one coordinate per state dimension.
            It is copied, so later calls to ``step`` never modify the caller's object.
            If ``None``, each coordinate is drawn uniformly at random.
        :return: observation tensor for the new state.
        """
        if state_init is None:
            # Select at random a position within the state space
            self.state = np.array([np.random.randint(0, dim_size) for dim_size in self.state_space])
        else:
            # Copy: step() updates self.state in place. Keeping a reference would mutate
            # the caller's array, and a tuple would fail on item assignment.
            self.state = np.array(state_init)
        return self.get_observation()

    def get_observation(self):
        """
        Return the image for the current state as a tensor.

        ``torch.from_numpy`` shares memory with ``self.images``, so in-place edits to
        the returned tensor also modify the stored images.
        """
        return torch.from_numpy(self.images[tuple(self.state)])

    def step(self, action):
        """
        Move one unit along a state dimension, with wrap-around, and return the new observation.

        Action ``a`` changes coordinate ``a // 2``: by ``+1`` if ``a`` is even, by ``-1`` if odd.
        Actions 0/1 therefore act on dimension 0, 2/3 on dimension 1, and so on, which covers
        every action the action space can produce.

        :param action: int-like action in ``[0, 2 * len(self.state_space))``. A one-element
            tensor, such as the output of ``action_space.sample()``, is accepted.
        :return: observation tensor for the new state.
        :raises ValueError: if the action is outside the action space.
        """
        action = int(action)
        n_dims = len(self.state_space)
        if not 0 <= action < 2 * n_dims:
            raise ValueError(f"action {action} is outside the action space [0, {2 * n_dims})")
        dim, is_decrement = divmod(action, 2)
        delta = -1 if is_decrement else 1
        self.state[dim] = (self.state[dim] + delta) % self.state_space[dim]
        return self.get_observation()


class ObjectImageWorld():
    """
    Environment in which the only transition is a rotation.

    Actions only move state coordinate 1 (the rotation), so every other coordinate
    changes only on ``reset``. Each reset can therefore land on a new object
    (presumably the object index is coordinate 0 — confirm against the data layout).

    Action ``0`` moves coordinate 1 by ``+1`` and action ``1`` moves it by ``-1``,
    wrapping around modulo the size of that dimension.
    """

    class action_space():
        """Discrete action space holding ``n`` actions numbered ``0 .. n-1``."""

        def __init__(self, n_actions):
            self.n = n_actions

        def sample(self, k=1):
            """Return an int64 tensor of shape ``(k,)`` with actions drawn uniformly from ``[0, n)``."""
            return torch.randint(0, self.n, (k,))

    class observation_space():
        """Placeholder observation space; ``shape`` is left as an empty list."""

        def __init__(self):
            self.shape = []

    def __init__(self, images: np.ndarray):
        """
        :param images: array whose leading ``images.ndim - 3`` axes are state dimensions
            (at least two are needed for ``step``) and whose trailing three axes form
            one observation.
        """
        self.images = images
        # Sizes of the state dimensions (all axes except the trailing three).
        self.state_space = self.images.shape[:-3]
        # Two actions: rotate forward / backward along coordinate 1.
        self.action_space = self.action_space(2)
        self.observation_space = self.observation_space()

    def reset(self, state_init=None):
        """
        Set the current state and return its observation.

        :param state_init: optional sequence with one coordinate per state dimension.
            It is copied, so later calls to ``step`` never modify the caller's object.
            If ``None``, each coordinate is drawn uniformly at random.
        :return: observation tensor for the new state.
        """
        if state_init is None:
            # Select at random a position within the state space
            self.state = np.array([np.random.randint(0, dim_size) for dim_size in self.state_space])
        else:
            # Copy: step() updates self.state in place. Keeping a reference would mutate
            # the caller's array, and a tuple would fail on item assignment.
            self.state = np.array(state_init)
        return self.get_observation()

    def get_observation(self):
        """
        Return the image for the current state as a tensor.

        ``torch.from_numpy`` shares memory with ``self.images``, so in-place edits to
        the returned tensor also modify the stored images.
        """
        return torch.from_numpy(self.images[tuple(self.state)])

    def step(self, action):
        """
        Rotate by one unit along coordinate 1, with wrap-around, and return the new observation.

        :param action: ``0`` (+1) or ``1`` (-1). A one-element tensor, such as the output
            of ``action_space.sample()``, is accepted.
        :return: observation tensor for the new state.
        :raises ValueError: if the action is not 0 or 1.
        """
        action = int(action)
        if action not in (0, 1):
            raise ValueError(f"action {action} is outside the action space [0, 2)")
        delta = 1 if action == 0 else -1
        self.state[1] = (self.state[1] + delta) % self.state_space[1]
        return self.get_observation()
