from random import randint
import random
import time
from sklearn.preprocessing import MinMaxScaler
import torch
import numpy as np
import matplotlib.pyplot as plt
import os
from sklearn.kernel_approximation import Nystroem
from mpl_toolkits.mplot3d import Axes3D

from large_rl.commons.launcher import launch_env, launch_agent
from large_rl.commons.seeds import set_randomSeed
from large_rl.commons.args import get_all_args

def to_true_action(actions_2d, feature_map, low=None, high=None):
    """Lift low-dimensional grid points into the environment's action range.

    The points are passed through ``feature_map.fit_transform``, every output
    feature is min-max scaled to [0, 1], and the result is linearly rescaled
    to ``[low, high]``.

    Args:
        actions_2d: Array of points to lift; forwarded unchanged to
            ``feature_map.fit_transform`` (the script passes an (N, 2) grid).
        feature_map: Object exposing ``fit_transform``. Note that it is
            fitted on ``actions_2d`` by this call.
        low: Lower bound(s) of the target action range. When omitted, the
            module-level ``_low`` set by the ``__main__`` block is used.
        high: Upper bound(s) of the target action range. When omitted, the
            module-level ``_high`` set by the ``__main__`` block is used.

    Returns:
        The transformed points rescaled into ``[low, high]``.

    Raises:
        NameError: If a bound is omitted and the corresponding module-level
            fallback has not been defined.
    """
    # Fall back to the script-level bounds so existing two-argument callers
    # keep working; explicit bounds make the helper usable when imported.
    if low is None:
        low = _low
    if high is None:
        high = _high

    true_actions = feature_map.fit_transform(actions_2d)
    # Rescale each feature of true_actions to [0, 1].
    true_actions = MinMaxScaler().fit_transform(true_actions)
    # Map the unit range onto the original action space.
    true_actions = true_actions * (high - low) + low
    return true_actions

if __name__ == "__main__":
    # Plot the critic's Q-values for a single state as a 3-D surface over a
    # 2-D grid of actions, annotate the grid points closest to the entries of
    # action["query"][0], and save the figure to disk.
    start_time = time.time()

    # NOTE: the seed is drawn afresh on every run, so runs are not repeatable.
    set_randomSeed(seed=randint(0, 1000000))
    args = vars(get_all_args())
    # To visualise a trained agent, set a checkpoint here and re-enable the
    # agent.load_model(...) call below.
    # args["agent_load_path"] = './weights/walker-flair/5179243-101-walker-none-flair'
    # args['agent_load_epoch'] = 10099
    args["num_envs"] = 1
    env = launch_env(args=args)
    args["_max_l2_dist"] = np.linalg.norm(np.ones(args["reacher_action_shape"]) * 2)
    # Scale the TD3 noise settings by the environment's action magnitude.
    args["env_max_action"] = float(env.action_space[0].high[0])
    args["TD3_policy_noise"] *= args["env_max_action"]
    args["TD3_noise_clip"] *= args["env_max_action"]
    agent = launch_agent(args=args, env=env)
    # agent.load_model(epoch=args['agent_load_epoch'])

    plot_save_dir = "./src/3Dplots/"
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(plot_save_dir, exist_ok=True)

    # _high / _low must stay module-level names: to_true_action reads them.
    _dim = env.action_space[0].shape[0]
    _high, _low = env.action_space[0].high, env.action_space[0].low

    # Build an evenly spaced num_samples_sqrt x num_samples_sqrt grid on
    # [-1, 1]^2; a2d_x is the first action coordinate, a2d_y the second.
    num_samples_sqrt = 500
    num_samples = num_samples_sqrt ** 2
    a2d_x = np.linspace(-1, 1, num=num_samples_sqrt).astype(np.float32)
    a2d_y = np.linspace(-1, 1, num=num_samples_sqrt).astype(np.float32)
    # The mesh is computed once and reused for both the flattened action
    # list and the surface plot, so the two are guaranteed to line up.
    X, Y = np.meshgrid(a2d_x, a2d_y)
    actions_2d = np.column_stack((X.ravel(), Y.ravel()))

    # Lift the 2-D grid to _dim-dimensional actions inside the env's bounds.
    feature_map_nystroem = Nystroem(gamma=.2, random_state=1, n_components=_dim)
    true_actions = to_true_action(actions_2d, feature_map_nystroem)

    # Query the agent once, greedily (both epsilons are zero), on the reset
    # observation.
    s = env.reset()
    s = torch.tensor(s)
    action = agent.select_action(obs=s, act_embed_base=None, epsilon={"actor": 0.0, "critic": 0.0})

    # Snap every queried action to its nearest neighbour (L2 distance) among
    # the lifted grid actions so it can be drawn on the 2-D grid.
    # NOTE(review): action["query"][0] is presumably the set of candidate
    # actions proposed by the agent -- confirm against select_action.
    queried_actions = action["query"][0]
    sampled_actions_ids = np.array(
        [np.argmin(np.linalg.norm(true_actions - a, axis=1)) for a in queried_actions]
    )

    # Evaluate the critic on (state, action) pairs: the first observation row
    # is repeated once per grid action and concatenated with that action.
    state_batch = np.tile(A=s[0, :], reps=(num_samples, 1)).astype(np.float32)
    critic_input = torch.cat([torch.tensor(state_batch), torch.tensor(true_actions)], dim=-1)
    # Inference only: skip building an autograd graph for num_samples rows.
    with torch.no_grad():
        q = agent.main_ref_critic(critic_input).detach().cpu().numpy()

    fig = plt.figure(figsize=(5, 5))
    ax = fig.add_subplot(111, projection="3d", facecolor="w")
    Z = q.reshape(num_samples_sqrt, num_samples_sqrt)
    ax.plot_surface(X, Y, Z, rstride=1, cstride=1, cmap='viridis', edgecolor='none', alpha=.85)

    # Mark the snapped actions on the surface. q is flattened first so that
    # each z is a scalar even if the critic output carries a trailing
    # singleton axis.
    sampled_actions_2d = actions_2d[sampled_actions_ids]
    x, y = sampled_actions_2d[:, 0], sampled_actions_2d[:, 1]
    z = q.reshape(-1)[sampled_actions_ids]
    ax.scatter(x, y, z, s=40, c='r', marker='.', zorder=10, alpha=.85)
    for i, (x_i, y_i, z_i) in enumerate(zip(x, y, z)):
        ax.text(x_i, y_i, z_i, f'Q{i}', size=10, zorder=10, color='r')

    ax.set_xlabel('Action Space')
    ax.set_ylabel('')
    ax.set_zlabel('Q-value')

    # Tick labels are hidden on all three axes.
    ax.set_yticklabels([])
    ax.set_xticklabels([])
    ax.set_zticklabels([])
    ax.set_title('Q-value Distribution')

    plt.tight_layout()
    # plt.show()
    plt.savefig(os.path.join(plot_save_dir, "3dplot"))
    plt.clf()
    print(f"Finished in {time.time() - start_time:.2f} seconds")