import torch

import numpy as np
import gymnasium as gym
from einops import rearrange
from torch.nn.functional import one_hot

class Environment(gym.Env):
    """HyperGrid environment from the GFlowNets paper.

    The agent starts at the origin of an ``ndim``-dimensional grid with
    ``height`` cells per axis. At each step it either increments one
    coordinate by one (actions ``0 .. ndim - 1``) or takes the exit action
    (index ``ndim``), which ends the episode and pays the current cell's
    reward.

    Observations are the concatenation of one one-hot vector of length
    ``height`` per coordinate, i.e. a float32 array of shape
    ``(ndim * height,)``.
    """

    metadata = {"render_modes": []}

    def __init__(self, ndim=4, height=64, R0=1e-10, R1=0, R2=2.0, render_mode=None):
        """Build the grid and precompute the normalized target distribution.

        States are 1-d arrays of length ``ndim`` with values in
        ``{0, 1, ..., height - 1}``.

        Args:
            ndim (int, optional): dimension of the grid. Defaults to 4.
            height (int, optional): number of cells along each axis. Defaults to 64.
            R0 (float, optional): base reward paid in every cell. Defaults to 1e-10.
            R1 (float, optional): bonus paid when every coordinate ``x`` satisfies
                ``|x / (height - 1) - 0.5| > 0.25``. Defaults to 0.
            R2 (float, optional): bonus paid when every coordinate ``x`` satisfies
                ``0.3 < |x / (height - 1) - 0.5| < 0.4``. Defaults to 2.0.
            render_mode: stored as-is; ``render`` is a no-op.

        Raises:
            ValueError: if ``ndim < 1`` or ``height < 2`` (the reward divides
                by ``height - 1``).
        """
        if ndim < 1:
            raise ValueError(f"ndim must be >= 1, got {ndim}")
        if height < 2:
            raise ValueError(f"height must be >= 2, got {height}")

        self.ndim = ndim
        self.height = height
        self.R0 = R0
        self.R1 = R1
        self.R2 = R2

        self.observation_space = gym.spaces.Box(shape=(ndim * height,), low=0, high=1, dtype=np.float32)
        self.action_space = gym.spaces.Discrete(ndim + 1)

        self.state = np.zeros(self.ndim, dtype=np.float32)
        self.render_mode = render_mode

        # Both reward conditions must hold for *every* coordinate, so they
        # factorize over axes. Evaluate them once per axis and AND the results
        # together by broadcasting. This avoids materializing a
        # (height**ndim, ndim) coordinate array (~0.5 GB per temporary for the
        # default 64**4 grid) and works for any ndim, not just 4.
        outer_1d, band_1d = self._axis_conditions(np.arange(self.height))
        grid_shape = [self.height] * self.ndim
        outer = np.ones(grid_shape, dtype=bool)
        band = np.ones(grid_shape, dtype=bool)
        for axis in range(self.ndim):
            axis_shape = [1] * self.ndim
            axis_shape[axis] = self.height
            outer &= outer_1d.reshape(axis_shape)
            band &= band_1d.reshape(axis_shape)

        truth = self.R0 + outer * self.R1 + band * self.R2
        self.truth_dist = truth / truth.sum()

    def _axis_conditions(self, coords):
        """Return the per-coordinate reward masks ``(outer, band)``.

        ``outer`` is ``|c / (height - 1) - 0.5| > 0.25`` and ``band`` is
        ``0.3 < |c / (height - 1) - 0.5| < 0.4``, evaluated elementwise on
        ``coords``.
        """
        ax = np.abs(np.asarray(coords) / (self.height - 1) - 0.5)
        return 0.25 < ax, (0.3 < ax) & (ax < 0.4)

    def reset(self, seed=None, options=None):
        """Move the agent back to the origin and return ``(obs, info)``."""
        super().reset(seed=seed)

        self.state = np.zeros(self.ndim, dtype=np.float32)

        return self._get_obs(), {}

    def _get_obs(self):
        """Encode ``self.state`` as concatenated per-coordinate one-hots.

        Returns:
            float32 NumPy array of shape ``(ndim * height,)``. For each
            coordinate ``i``, slot ``i * height + state[i]`` is 1.
        """
        state_tensor = torch.tensor(self.state).long()
        obs_tensor = one_hot(state_tensor, num_classes=self.height).float()
        return rearrange(obs_tensor, '... h w -> ... (h w)').numpy()

    def get_state(self, obs):
        """Invert the one-hot encoding produced by ``_get_obs``.

        Args:
            obs: observation(s) whose last dimension is ``ndim * height``.

        Returns:
            float32 array of shape ``obs.shape[:-1] + (ndim,)`` holding, for
            each coordinate, the index of the largest entry in its block.
        """
        state = np.zeros(obs.shape[:-1] + (self.ndim,), dtype=np.float32)
        for i in range(self.ndim):
            block = obs[..., i * self.height:(i + 1) * self.height]
            state[..., i] = np.argmax(block, axis=-1)
        return state

    def step(self, action):
        """Apply ``action`` and return the Gymnasium 5-tuple.

        ``action`` is an integer in ``{0, ..., ndim}``:

        * ``0 .. ndim - 1`` increments that coordinate. The reward is 0 and
          the episode continues.
        * ``ndim`` is the exit action. It leaves the state unchanged, ends the
          episode, and pays ``R0 + R1 * all(outer) + R2 * all(band)`` for the
          current cell.

        Raises:
            ValueError: if ``action`` is outside ``{0, ..., ndim}``, or if it
                would move a coordinate past ``height - 1``. The check runs
                before any mutation, so the environment state stays intact.
                Use ``get_forward_action_masks`` to select legal actions.
        """
        if action == self.ndim:
            outer, band = self._axis_conditions(self.state)
            reward = self.R0 + outer.prod(-1) * self.R1 + band.prod(-1) * self.R2
            return self._get_obs(), reward, True, False, {'augmented_rew': 1}

        # Negative indices would otherwise silently wrap around in NumPy.
        if not 0 <= action < self.ndim:
            raise ValueError(f"action must be in [0, {self.ndim}], got {action}")
        # Without this guard the coordinate reaches `height`, and `one_hot`
        # in `_get_obs` fails only after the state has already been corrupted.
        if self.state[action] >= self.height - 1:
            raise ValueError(
                f"action {action} would move past the grid edge (height={self.height}); "
                "mask it out with get_forward_action_masks"
            )

        self.state[action] += 1

        return self._get_obs(), 0, False, False, {}

    def get_forward_action_masks(self, state):
        """Return a boolean mask of feasible forward actions.

        Args:
            state: observation(s) as produced by ``_get_obs``, i.e. with last
                dimension ``ndim * height``. May be a NumPy array or a torch
                tensor.

        Returns:
            Boolean array/tensor of shape ``(..., ndim + 1)``. Entry
            ``i < ndim`` is True unless coordinate ``i`` is already at
            ``height - 1``. The last entry (exit) is always True.
        """
        # One-hot slot meaning "coordinate == height - 1" for each axis.
        top_slots = [(i + 1) * self.height - 1 for i in range(self.ndim)]
        masks = state[..., top_slots] != 1
        if isinstance(state, torch.Tensor):
            exit_mask = torch.ones(
                list(masks.size())[:-1] + [1], device=state.device, dtype=torch.bool
            )
            return torch.cat([masks, exit_mask], dim=-1)
        # dtype=bool keeps the NumPy path consistent with the torch path. The
        # default float64 ones would silently upcast the whole mask to float.
        exit_mask = np.ones(masks.shape[:-1] + (1,), dtype=bool)
        return np.concatenate([masks, exit_mask], axis=-1)

    def get_backward_action_masks(self, state):
        """Return a boolean mask of feasible backward (decrement) actions.

        Args:
            state: observation(s) as produced by ``_get_obs``, with last
                dimension ``ndim * height``.

        Returns:
            Boolean array/tensor of shape ``(..., ndim)``. Entry ``i`` is True
            unless coordinate ``i`` is already 0. There is no exit entry.
        """
        # One-hot slot meaning "coordinate == 0" for each axis.
        bottom_slots = [i * self.height for i in range(self.ndim)]
        return state[..., bottom_slots] != 1

    def get_error(self, samples):
        """Return the L1 distance between the empirical and true distributions.

        Args:
            samples: sequence/array of terminal states of shape ``(N, ndim)``
                (extra trailing columns are ignored). Float coordinates are
                truncated toward zero, as ``int()`` does.

        Returns:
            ``sum(|truth_dist - empirical_dist|)`` over all grid cells.

        Raises:
            ValueError: if ``samples`` is empty (the empirical distribution
                would be 0/0), or if it is not 2-D with at least ``ndim``
                columns.
        """
        coords = np.asarray(samples)
        if coords.size == 0:
            raise ValueError("get_error requires at least one sample")
        if coords.ndim != 2 or coords.shape[1] < self.ndim:
            raise ValueError(f"samples must have shape (N, {self.ndim}), got {coords.shape}")
        coords = coords[:, :self.ndim].astype(np.int64)

        counts = np.zeros([self.height] * self.ndim)
        # np.add.at accumulates repeated indices; fancy-index `+=` would not.
        np.add.at(counts, tuple(coords.T), 1)
        sample_dist = counts / counts.sum()

        return np.abs(self.truth_dist - sample_dist).sum()

    def render(self):
        """Rendering is not implemented; this is a no-op."""
        pass

if __name__ == "__main__":
    from tqdm import trange

    # =========================================================================
    # Sparsity check: roll out uniformly random feasible trajectories from the
    # origin and print how many terminal rewards exceed 1e-3. With the default
    # R0=1e-10, R1=0, R2=2.0 only cells inside the R2 band clear that bar.
    np.random.seed(42)
    num_samples = 100_000
    env = Environment()
    episode_rewards = []
    obs, _ = env.reset()
    for _ in trange(num_samples):
        done = False
        while not done:
            feasible = np.where(env.get_forward_action_masks(obs) == 1)[0]
            # Uniform pick among the feasible actions.
            action = feasible[np.random.choice(len(feasible))]
            obs, reward, done, _, _ = env.step(action)
        episode_rewards.append(reward)
        obs, _ = env.reset()

    print(np.count_nonzero(np.array(episode_rewards) > 1e-3))