import torch
import seaborn as sns
import pandas as pd
import numpy as np
import gymnasium as gym
from sklearn.neighbors import KernelDensity

# Environment whose terminal reward is the density of a four-peak Gaussian mixture model (GMM)
class Environment(gym.Env):
    """Continuous 2-D walk on the unit square, rewarded by a GMM density.

    The agent starts at (0.5, 0.5). Each action is a length-3 vector:
    ``a[0]`` is an exit flag (the episode ends when it exceeds 0.5) and
    ``a[1:]`` is a displacement that is divided by 5 before being added to
    the current position. An episode also ends when the step budget
    ``max_t`` is reached or when the proposed position leaves ``[0, 1]^2``.
    In every terminating case the position is *not* updated, and the reward
    is the mixture density evaluated at the last accepted position (floored
    at 1e-10). All non-terminal steps give reward 0.

    Observations are float32 vectors with four slots per time step; only
    the four slots belonging to the current step are non-zero (see
    ``_get_obs``).
    """

    # "matplotlib" is kept as an alias because ``render`` historically only
    # reacted to that name, while only "seaborn" was advertised here.
    metadata = {"render_modes": ["seaborn", "matplotlib"]}

    def __init__(self, render_mode=None):
        """Build the reward distribution, the spaces and the evaluation grid.

        Args:
            render_mode: ``None`` (no rendering) or one of
                ``metadata["render_modes"]``.
        """
        # Reward distribution: equally weighted mixture of four isotropic
        # Gaussians (std 0.02) centred near the corners of the unit square.
        mix = torch.distributions.Categorical(torch.ones(4,))
        comp = torch.distributions.Independent(
            torch.distributions.Normal(
                torch.tensor([[.2, .2], [.8, .2], [.2, .8], [.8, .8]]),
                torch.full((4, 2), 0.02),
            ),
            reinterpreted_batch_ndims=1,
        )
        self.gmm = torch.distributions.MixtureSameFamily(mix, comp)

        # Maximum number of time-step slots in an episode.
        self.max_t = 10

        # Four observation slots per time step (10 * 4 = 40 entries).
        self.observation_space = gym.spaces.Box(
            shape=(4 * self.max_t,), low=0, high=.5, dtype=np.float32
        )
        self.action_space = gym.spaces.Box(
            shape=(3,), low=-2, high=2, dtype=np.float32
        )

        # Current position and time step.
        self.s = np.array([.5, .5])
        self.t = 0

        # 100 x 100 grid over the unit square, used by ``get_error``; kept
        # both as a NumPy array (for the KDE) and as a tensor (for the GMM).
        x = np.linspace(0, 1, 100)
        y = np.linspace(0, 1, 100)
        X, Y = np.meshgrid(x, y)
        self.data = np.vstack([X.ravel(), Y.ravel()]).T
        self.data2 = torch.tensor(self.data)

        self.render_mode = render_mode

    def reset(self, seed=None, options=None):
        """Return the agent to (0.5, 0.5) at time 0.

        Returns:
            The initial observation and an (empty) info dict.
        """
        super().reset(seed=seed)

        self.s = np.array([.5, .5])
        self.t = 0

        return self._get_obs(), self._get_info()

    def step(self, a):
        """Advance the environment by one action.

        The ending state is the one before the timeout, or the last one
        inside the bounding box.

        Args:
            a: Length-3 array-like ``[exit_flag, dx, dy]``.

        Returns:
            ``(observation, reward, terminated, truncated, info)``. The
            reward is always a Python float; on termination ``info``
            contains ``{'augmented_rew': 1}``.
        """
        a = np.asarray(a)  # also accept plain lists / tuples
        exit_flag = a[0]
        candidate = self.s + a[1:] / 5

        timed_out = self.t + 1 >= self.max_t
        out_of_bounds = np.any(candidate > 1) or np.any(candidate < 0)
        if timed_out or out_of_bounds or exit_flag > 0.5:
            # Score the last accepted position, not the rejected candidate.
            density = torch.exp(self.gmm.log_prob(torch.tensor(self.s)))
            # Convert to a float so the reward type does not depend on
            # which side of the floor wins.
            reward = max(float(density), 1e-10)
            return self._get_obs(), reward, True, False, {'augmented_rew': 1}

        self.t += 1
        self.s = candidate
        return self._get_obs(), 0, False, False, {}

    def _get_obs(self):
        """Encode the current position and time step as an observation.

        The four slots for step ``t`` hold, in order, the positive parts of
        ``x - .5``, ``y - .5``, ``.5 - x`` and ``.5 - y``, each shifted by
        1e-8 so that active slots are distinguishable from the zero padding.
        """
        obs = np.zeros(self.observation_space.shape, dtype=np.float32)
        start = self.t * 4
        offset = self.s - .5
        obs[start:start + 2] = np.maximum(offset, 0) + 1e-8
        obs[start + 2:start + 4] = np.maximum(-offset, 0) + 1e-8
        return obs

    def _get_info(self):
        """Return auxiliary information (currently always empty)."""
        return {}

    def get_state(self, obs):
        """Invert ``_get_obs`` for a batch of observations.

        Args:
            obs: 2-D batch of observations, one per row. Each row must
                contain exactly four active slots.

        Returns:
            Array with one ``(x, y)`` position per row.
        """
        shifted = obs - 1e-8
        # After removing the 1e-8 marker only the active slots are >= 0;
        # boolean selection walks the batch row by row.
        active = shifted[shifted >= 0].reshape(len(obs), 4)
        return .5 + active[:, :2] - active[:, 2:]

    def get_error(self, samples):
        """Estimate the KL divergence from the true GMM to the sample density.

        A Gaussian KDE is fitted to ``samples`` and
        ``sum(q * (log q - log p)) / n_grid`` is evaluated on the stored
        grid, where ``q`` is the GMM density and ``p`` the KDE density.

        Args:
            samples: Array-like of 2-D points, one per row.

        Returns:
            The estimate as a Python float.
        """
        kde = KernelDensity(kernel='gaussian', bandwidth='scott').fit(samples)
        # ``score_samples`` yields log-densities as a NumPy array; convert
        # explicitly instead of mixing ndarray and Tensor operands.
        log_p = torch.as_tensor(kde.score_samples(self.data))
        log_q = self.gmm.log_prob(self.data2)
        kl_terms = torch.exp(log_q) * (log_q - log_p)
        return kl_terms.sum().item() / len(self.data)

    def render(self):
        """Plot 10,000 draws from the reward distribution as a pair plot.

        Does nothing (and returns ``None``) unless ``render_mode`` is one of
        ``metadata["render_modes"]``.

        Returns:
            The seaborn grid object, or ``None`` when rendering is disabled.
        """
        if self.render_mode not in self.metadata["render_modes"]:
            return None

        # Sample from the target GMM (not from any policy).
        samples = self.gmm.sample((10000,))
        data = pd.DataFrame(samples.numpy())
        g = sns.pairplot(data, kind="hist", corner=True)
        g.axes[1, 0].set_xlim(0, 1)
        g.axes[1, 0].set_ylim(0, 1)
        return g
            

if __name__ == "__main__":
    from tqdm import trange

    # =========================================================================
    # Show sparsity: run episodes with random actions and count how many end
    # with a terminal reward above 1e-3.
    np.random.seed(42)
    torch.manual_seed(42)

    num_samples = 100_000
    env = Environment()
    rs = []
    for _ in trange(num_samples):
        # Reset before the first step of every episode.
        env.reset()
        done = False
        while not done:
            # Sample the displacement from a normal distribution (std 0.5).
            action = np.random.normal(size=(2,)) / 2
            # Redraw any component that falls outside [-2, 2].
            out_of_range = np.abs(action) > 2
            while out_of_range.any():
                n_redraw = int(out_of_range.sum())
                action[out_of_range] = np.random.normal(size=(n_redraw,)) / 2
                out_of_range = np.abs(action) > 2
            # Exit flag is 0 or 1 with equal probability.
            exit_flag = np.random.randint(2, size=(1,))
            action = np.concatenate((exit_flag, action))
            _, r, done, _, _ = env.step(action)
        # The terminal reward may arrive as a 0-dim tensor or a float;
        # store plain floats so the array below is homogeneous.
        rs.append(float(r))

    rs = np.array(rs)
    print(np.sum(rs > 1e-3))