import numpy as np
import matplotlib.pyplot as plt
import gym
import numpy as np
import matplotlib.pyplot as plt
import sys
sys.path.append("../../rl_zoo3/")
# from action_space_transform_wrappers import ActionRedundancyWrapper
# from stable_baselines3.common.buffers import ReplayBuffer, RolloutBuffer, DictRolloutBuffer
import torch
from scipy.special import betainc, betaincinv
from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union
import torch as th
from gym import spaces
from torch import nn

from stable_baselines3.common.policies import BasePolicy#, ContinuousCritic

from stable_baselines3.common.preprocessing import get_action_dim
from stable_baselines3.common.torch_layers import (
    BaseFeaturesExtractor,
    CombinedExtractor,
    FlattenExtractor,
    NatureCNN,
    create_mlp,
    get_actor_critic_arch,
)
from stable_baselines3.common.type_aliases import Schedule
import numpy as np
from torch.nn import functional as F

from stable_baselines3.common.save_util import load_from_pkl, save_to_pkl
from torch.utils.tensorboard import SummaryWriter

# Select the torch device once for the whole script: the GPU when CUDA is
# usable, the CPU otherwise, and report the choice on stdout.
device = th.device("cuda" if th.cuda.is_available() else "cpu")
print(
    "CUDA is available. Using GPU."
    if device.type == "cuda"
    else "CUDA is not available. Using CPU."
)

# `action_bins + 1` evenly spaced values spanning [-1, 1], placed on `device`.
action_bins = 20
bins = th.linspace(-1, 1, steps=action_bins + 1).to(device)

# Location of the pickled replay buffer; the two parts are concatenated below.
# NOTE(review): both are empty placeholders here - fill them in before running.
path = ""
root_path = ""
# NOTE(review): presumably this unpickles the file (judging by the helper's
# name), so only load buffers that come from a trusted source.
replay_buffer = load_from_pkl(root_path + path, verbose=2)

obs_shape = replay_buffer.obs_shape
action_dim = replay_buffer.action_dim

# The buffer's `actions` array is allocated for the full capacity, so a buffer
# that is not yet full still holds all-zero rows past the write position.
# Keep only the transitions that were actually stored; otherwise the density
# estimate gets a spurious spike at 0.
n_stored = replay_buffer.size()

# Flatten to (n_samples, action_dim) explicitly. A bare `.squeeze()` would
# also drop the action axis when `action_dim == 1` and would leave a 3-D array
# when the buffer was collected from several environments.
data = replay_buffer.actions[:n_stored].reshape(-1, action_dim)

densities = np.zeros_like(data)

# Fit an independent 1-D kernel density estimate to every action dimension,
# plot each estimate in its own subplot, and accumulate the product of the
# per-dimension densities at one probe point of the evaluation grid.
plt.figure(figsize=(15, 10))

from sklearn.neighbors import KernelDensity
import time

# Index (into the evaluation grid `x`) of the probe point, and the running
# product of the per-dimension densities at that point.
test_index = 10
test_prob = 1

# A 1-D action array is treated as a single action dimension so that the
# per-column loop below works for it too.
samples = data.reshape(-1, 1) if data.ndim == 1 else data
n_dims = samples.shape[1]

# Size the subplot grid from the number of dimensions instead of a fixed 4x2
# layout, which raises as soon as there are more than 8 dimensions.
n_cols = 2
n_rows = max(1, (n_dims + n_cols - 1) // n_cols)

# The evaluation grid is the same for every dimension, so build it once.
x = np.linspace(-1, 1, 100)
grid = x.reshape(-1, 1)

for i in range(n_dims):

    kde = KernelDensity(kernel='tophat', bandwidth=0.05).fit(samples[:, i].reshape(-1, 1))

    # `score_samples` returns log-densities; exponentiate to get densities.
    density = np.exp(kde.score_samples(grid))

    ith_test_prob = density[test_index]
    print(f"{i}th prob: {ith_test_prob!s}")
    test_prob = test_prob * ith_test_prob
    print(f"current_prob:{test_prob!s}")

    plt.subplot(n_rows, n_cols, i + 1)
    plt.plot(x, density)
    plt.title(f'Dimension {i + 1} KDE')
    plt.grid(True)

# Disabled experiment: a single joint KDE over all dimensions, evaluated at the
# point whose every coordinate equals x[test_index].
# kde = KernelDensity(kernel='gaussian', bandwidth=0.05).fit(data)
# x = np.linspace(-1, 1, 100)
# test_x = [x[test_index] for _ in range(data.shape[1])]
# test_x = np.array(test_x).reshape(1,-1)
#
# density = np.exp(kde.score_samples(test_x))
# print("High dim evaluate: "+str(density))


plt.tight_layout()
plt.show()