import gymnasium as gym
from gymnasium import spaces
from itertools import product
import numpy as np


class FactorisedWrapper(gym.Wrapper):
    """Discretise a 1-D continuous action space into a ``MultiDiscrete`` one.

    Each action dimension ``i`` is replaced by ``bin_size[i]`` evenly spaced
    values between ``low[i]`` and ``high[i]``, both endpoints included, as
    produced by ``np.linspace``. The agent picks one bin index per dimension.
    :meth:`step` maps those indices back to continuous values before
    forwarding them to the wrapped environment.

    Args:
        env: Environment whose ``action_space`` exposes ``shape``, ``low`` and
            ``high`` (e.g. ``gymnasium.spaces.Box``) and is one-dimensional.
        bin_size: Number of bins per dimension. Either a single integer used
            for every dimension, or a list / tuple / ndarray with exactly one
            entry per dimension.

    Raises:
        TypeError: If ``bin_size`` is neither an integer nor a sequence.
        ValueError: If the action space is not 1-D, if its bounds are not all
            finite, or if ``bin_size`` has the wrong length.
    """

    def __init__(self, env, bin_size=3):
        super().__init__(env)
        inner_space = self.env.action_space
        if len(inner_space.shape) != 1:
            raise ValueError(
                "FactorisedWrapper expects a 1-D action space, "
                f"got shape {inner_space.shape}"
            )
        self.num_subaction_spaces = inner_space.shape[0]

        # np.integer is accepted so sizes read from numpy arrays/configs work;
        # previously they left self.bin_size unset and crashed further down.
        if isinstance(bin_size, (int, np.integer)):
            self.bin_size = [int(bin_size)] * self.num_subaction_spaces
        elif isinstance(bin_size, (list, tuple, np.ndarray)):
            if len(bin_size) != self.num_subaction_spaces:
                raise ValueError(
                    f"bin_size has {len(bin_size)} entries but the action "
                    f"space has {self.num_subaction_spaces} dimensions"
                )
            # Lists/arrays are stored unchanged; tuples are normalised to lists.
            self.bin_size = list(bin_size) if isinstance(bin_size, tuple) else bin_size
        else:
            raise TypeError(
                "bin_size must be an int or a list/tuple/ndarray of ints, "
                f"got {type(bin_size).__name__}"
            )

        lows = np.asarray(inner_space.low)
        highs = np.asarray(inner_space.high)
        # np.linspace over an infinite interval yields -inf/nan/inf bins,
        # which would be fed to the environment as actions.
        if not (np.all(np.isfinite(lows)) and np.all(np.isfinite(highs))):
            raise ValueError(
                "FactorisedWrapper requires finite action bounds; "
                f"got low={lows}, high={highs}"
            )

        # action_lookups[dim][bin_index] -> continuous value for that dimension.
        self.action_lookups = {}
        for dim, (low, high) in enumerate(zip(lows, highs)):
            bins = np.linspace(low, high, self.bin_size[dim])
            self.action_lookups[dim] = dict(enumerate(bins))
        self.action_space = spaces.MultiDiscrete(self.bin_size)

    def step(self, action):
        """Translate bin indices to continuous values and step the inner env.

        Args:
            action: One bin index per action dimension (a ``MultiDiscrete``
                sample).

        Returns:
            Whatever the wrapped environment's ``step`` returns.
        """
        # Match the inner space's dtype so e.g. a float32 Box accepts the
        # action in ``contains`` checks; fall back to inference if absent.
        target_dtype = getattr(self.env.action_space, "dtype", None)
        continuous = np.asarray(self.get_continuous_action(action), dtype=target_dtype)
        return super().step(continuous)

    def get_continuous_action(self, action):
        """Map per-dimension bin indices to their continuous values.

        Args:
            action: Iterable of bin indices, one per action dimension.

        Returns:
            A list with the continuous value selected for each dimension.

        Raises:
            ValueError: If ``action`` does not have one index per dimension.
            KeyError: If an index is outside ``range(bin_size[dim])``.
        """
        indices = list(action)
        if len(indices) != self.num_subaction_spaces:
            raise ValueError(
                f"expected {self.num_subaction_spaces} bin indices, got {len(indices)}"
            )
        return [self.action_lookups[dim][idx] for dim, idx in enumerate(indices)]

def make_env(env, bin_size=3):
    """Return *env* wrapped in a :class:`FactorisedWrapper`.

    Args:
        env: The environment whose continuous action space will be discretised.
        bin_size: Bins per action dimension. Either a single int or one entry
            per dimension; forwarded unchanged to ``FactorisedWrapper``.

    Returns:
        The wrapping ``FactorisedWrapper`` instance.
    """
    wrapped_env = FactorisedWrapper(env, bin_size=bin_size)
    return wrapped_env

