import copy
import time

import matplotlib
import numpy as np
import torch
from matplotlib.lines import Line2D
from matplotlib.patches import FancyArrowPatch
from mpl_toolkits.mplot3d import proj3d
from sklearn.decomposition import PCA
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import RBF, ConstantKernel as C
from sklearn.manifold import TSNE
from sklearn.preprocessing import MinMaxScaler

from large_rl.commons.utils import logging

# Select the non-interactive Agg backend before pyplot is imported, so figures
# are rendered straight to files and no display is needed.
matplotlib.use('Agg')
import matplotlib.pyplot as plt

# NOTE(review): newer matplotlib releases renamed the bundled seaborn styles
# (e.g. 'seaborn-v0_8-pastel') -- confirm the installed version accepts this name.
plt.style.use('seaborn-pastel')
import os

from large_rl.commons.launcher import launch_embedding, launch_env, launch_agent
from large_rl.commons.seeds import set_randomSeed
from large_rl.commons.args import get_all_args


def actions22d(actions_true):
    """Project action vectors onto their first two principal components.

    Args:
        actions_true: 2-D array-like with one action vector per row; it is
            handed unchanged to ``PCA.fit_transform``.

    Returns:
        The array produced by a 2-component PCA fitted on ``actions_true``
        (one 2-D point per input row).
    """
    reducer = PCA(n_components=2, svd_solver='randomized')
    return reducer.fit_transform(actions_true)


def get_correspond_q_val(critic_network, state, actions, device):
    """Evaluate a critic on (state, action) pairs and return the result as NumPy.

    Args:
        critic_network: Callable that receives the tensor obtained by
            concatenating ``state`` and ``actions`` along the last axis.
        state: Tensor or array-like of states. All dimensions except the last
            must match ``actions`` so the two can be concatenated.
        actions: Tensor or array-like of actions.
        device: Torch device (or device string) on which the critic is run.

    Returns:
        ``numpy.ndarray`` holding the critic output, moved to the CPU.
    """
    # Both inputs are converted to float32 and placed on the target device
    # *before* concatenation; concatenating tensors that live on different
    # devices raises a RuntimeError.
    state = torch.as_tensor(state, dtype=torch.float32, device=device)
    actions = torch.as_tensor(actions, dtype=torch.float32, device=device)
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        q = critic_network(torch.cat([state, actions], dim=-1))
    return q.detach().cpu().numpy()


def plot_surface_3d(x, y, z, ax, color=None, alpha=0.7, label=None):
    """Draw a surface on ``ax``, either colour-mapped or in one flat colour.

    Args:
        x, y, z: Surface coordinates forwarded to ``ax.plot_surface``.
        ax: Axes object exposing ``plot_surface``.
        color: Flat colour for the surface. When ``None`` the surface is
            coloured with the ``'coolwarm'`` colormap instead.
        alpha: Surface opacity.
        label: Label attached to the surface.
    """
    # Identity check on purpose: ``color == None`` is evaluated element-wise
    # for array-like colours and cannot be used as a condition.
    if color is None:
        ax.plot_surface(x, y, z, label=label, cmap='coolwarm', alpha=alpha, edgecolor='none')
    else:
        ax.plot_surface(x, y, z, label=label, color=color, alpha=alpha, edgecolor='none')


# Define the Arrow3D class
class Arrow3D(FancyArrowPatch):
    """2-D fancy arrow whose end points are given in 3-D data coordinates.

    The end points are projected with the projection matrix of the owning 3-D
    axes (``self.axes.M``) each time the arrow is projected or drawn.

    Args:
        xs, ys, zs: Two-element sequences with the start and end coordinate of
            the arrow along each axis.
        *args, **kwargs: Forwarded to ``FancyArrowPatch``.
    """

    def __init__(self, xs, ys, zs, *args, **kwargs):
        # The 2-D positions are placeholders; the real ones are computed from
        # the 3-D vertices in _project().
        super().__init__((0, 0), (0, 0), *args, **kwargs)
        self._verts3d = xs, ys, zs

    def _project(self):
        """Update the 2-D end points from the 3-D vertices; return projected z."""
        xs3d, ys3d, zs3d = self._verts3d
        # The axes' own matrix is used rather than ``renderer.M``, which the
        # renderer does not carry in every matplotlib version.
        xs, ys, zs = proj3d.proj_transform(xs3d, ys3d, zs3d, self.axes.M)
        self.set_positions((xs[0], ys[0]), (xs[1], ys[1]))
        return zs

    def do_3d_projection(self, renderer=None):
        """Projection hook used by newer mplot3d to depth-sort patches."""
        return np.min(self._project())

    def draw(self, renderer):
        """Project the arrow, then draw it as a regular 2-D patch."""
        self._project()
        super().draw(renderer)


def add_point(x, y, z, text, ax, add_dashed_line=True, font_size=20, x_scale=None, y_scale=None, z_max=None,
              add_all_dots_to_main=False):
    """Mark a single 3-D point and, optionally, annotate it.

    Only the first element of ``x``, ``y`` and ``z`` is used.

    Args:
        x, y, z: Indexable containers whose first entries give the point.
        text: Annotation placed above the point at height ``z_max``.
        ax: 3-D axes to draw on.
        add_dashed_line: Draw dashed guide lines from the point to
            ``y_scale[1]`` and to ``x_scale[0]``.
        font_size: Font size of the annotation.
        x_scale, y_scale: Axis extents; indexed only when guide lines are drawn.
        z_max: Height at which the annotation text is anchored.
        add_all_dots_to_main: When true only the dot is drawn (no text, arrow
            or guide lines).
    """
    px, py, pz = x[0], y[0], z[0]
    ax.scatter(px, py, pz, color='k', s=100)

    if add_all_dots_to_main:
        return

    # Label sits directly above the point; an arrow links label and point.
    ax.text(px, py, z_max, text, size=font_size, color='k')
    pointer = Arrow3D([px, px], [py, py], [z_max, pz],
                      mutation_scale=20, arrowstyle='Fancy', color='k', shrinkA=0, shrinkB=0)
    ax.add_artist(pointer)

    if add_dashed_line:
        level = [pz, pz]
        ax.plot([px, px], [y_scale[1], py], level, color="black", linestyle="--")  # Line to y-z plane
        ax.plot([x_scale[0], px], [py, py], level, color="black", linestyle="--")  # Line to x-z plane


def add_surface_label(x_scale, y_scale, z, ax, text="q", font_size=20, z_shift=-30, x_shift=-1):
    """Write ``text`` near the lower x/y corner of a surface at height ``z``.

    Args:
        x_scale, y_scale: Axis extents; only their first entries are used.
        z: Height of the surface being labelled.
        ax: Axes object exposing ``text``.
        text: Label string.
        font_size: Font size of the label.
        z_shift, x_shift: Offsets added to the anchor along z and x.
    """
    anchor_x = x_scale[0] + x_shift
    anchor_y = y_scale[0]
    anchor_z = z + z_shift
    ax.text(anchor_x, anchor_y, anchor_z, text, size=font_size, color='k')


def add_legend(ax, ar_colors):
    """Attach a legend with one dot for the main critic and one per AR critic.

    Args:
        ax: Axes receiving the legend.
        ar_colors: Sequence of marker face colours, one per AR critic.
    """

    def _dot(face_color):
        # Proxy artist: surfaces cannot be used as legend handles directly.
        return Line2D([0], [0], linestyle="none", marker='o', markersize=10, markerfacecolor=face_color)

    handles = [_dot('blue')]
    names = ["main critic"]
    for idx, face_color in enumerate(ar_colors):
        handles.append(_dot(face_color))
        names.append(f"ar critic {idx}")
    ax.legend(handles, names, numpoints=1)


def get_color(index):
    """Return the mid-point colour of the ``index``-th sequential colormap.

    Args:
        index: Integer position in the palette. Indices outside the palette
            wrap around, so negative indices keep their usual "from the end"
            meaning and large indices cycle instead of raising ``IndexError``.

    Returns:
        The colour the selected colormap yields at 0.5.
    """
    color_ids = ('Purples', 'Blues', 'Greens', 'Oranges', 'Reds',
                 'YlOrBr', 'YlOrRd', 'OrRd', 'PuRd', 'RdPu', 'BuPu',
                 'GnBu', 'PuBu', 'YlGnBu', 'PuBuGn', 'BuGn', 'YlGn')
    # pyplot.get_cmap is used instead of cm.get_cmap, which recent matplotlib
    # releases no longer provide.
    flat_cmap = plt.get_cmap(color_ids[index % len(color_ids)])
    return flat_cmap(0.5)  # middle of the gradient


def gen_critic_plot_2d(x, y, sampled_actions_2d, z, sampled_point_z, plot_save_dir, assigned_ax,
                       ar_critic=False, sampled_action_index=None, previous_q_value=None, q_scale=None,
                       flag_threshold=False, args=None, smooth=True, add_all_dots_to_main=False):
    """Render one critic surface on ``assigned_ax`` and save it as a PDF.

    The surface ``z`` over the grid ``(x, y)`` is optionally smoothed with a
    Savitzky-Golay filter along both axes. Depending on the arguments, either
    one sampled action is marked (``sampled_action_index`` given) or one flat
    reference plane per sampled action is added.

    The figure is written and cleared through ``plt.savefig`` / ``plt.clf``,
    which act on pyplot's *current* figure, so the figure owning
    ``assigned_ax`` is expected to be the current one.

    Args:
        x, y: Mesh-grid coordinate arrays.
        sampled_actions_2d: Array indexed as ``[action, coordinate]`` with the
            2-D positions of the sampled actions.
        z: Q-value surface; same shape as ``x`` / ``y``.
        sampled_point_z: Array indexed as ``[action, 0]`` with the Q-value of
            each sampled action.
        plot_save_dir: Directory the PDF is written to.
        assigned_ax: 3-D axes to draw on.
        ar_critic: Whether the surface comes from an AR critic (affects the
            surface label, title and output file name).
        sampled_action_index: Index of the single action to mark, or ``None``.
        previous_q_value: If given (together with ``sampled_action_index``), a
            flat plane at this height is added.
        q_scale: Optional ``(low, high)`` z-axis limits.
        flag_threshold: Append ``_threshold`` to the AR-critic file name.
        args: Mapping with at least ``'env_name'``; ``'reacher_validity_type'``
            is also read for mujoco environments.
        smooth: Apply the Savitzky-Golay smoothing.
        add_all_dots_to_main: Suppress annotations and reference planes and
            use the fixed ``Walker2d`` title / file name.

    Raises:
        ValueError: If ``args`` is missing or names an unsupported environment.
    """
    if args is None:
        raise ValueError("args with an 'env_name' entry is required")

    font_size = 30
    ax = assigned_ax

    # Classify the environment once; matching is case-insensitive, while the
    # title below is built from the name as given.
    env_name = args["env_name"]
    env_key = env_name.lower()
    is_small_grid_env = env_key in ("mine", "recsim-data")
    is_mujoco = env_key.startswith("mujoco")
    if not (is_small_grid_env or is_mujoco or env_key.startswith("recsim")):
        raise ValueError(f"env_name {env_name} not supported")

    surface_label = f"ar critic {sampled_action_index}" if ar_critic else "main critic"
    if smooth:
        from scipy.signal import savgol_filter
        window_length = 5 if is_small_grid_env else 7
        surface = savgol_filter(z, window_length=window_length, polyorder=2, axis=0)  # Apply along one axis
        surface = savgol_filter(surface, window_length=window_length, polyorder=2, axis=1)  # ... then the other
    else:
        surface = z
    plot_surface_3d(x, y, surface, ax, label=surface_label)

    x_scale = [np.min(x), np.max(x)]
    y_scale = [np.min(y), np.max(y)]
    # Offsets applied to the text labels of the flat reference planes.
    x_shift, z_shift = (-1, -30) if is_mujoco else (0, 0)

    if sampled_action_index is not None:
        point_x = [sampled_actions_2d[sampled_action_index, 0]]
        point_y = [sampled_actions_2d[sampled_action_index, 1]]
        point_z = [sampled_point_z[sampled_action_index, 0]]

        # The annotation is anchored 10% of the (unsmoothed) maximum above it.
        z_top = np.max(z)
        add_point(point_x, point_y, point_z, f'$a_{sampled_action_index}$',
                  ax, font_size=font_size, x_scale=x_scale, y_scale=y_scale, z_max=z_top + z_top * 0.1,
                  add_all_dots_to_main=add_all_dots_to_main)

        if previous_q_value is not None:
            plot_surface_3d(x, y, np.ones(z.shape) * previous_q_value, ax,
                            color=get_color(sampled_action_index - 1),
                            alpha=0.4, label=f"q {sampled_action_index - 1}")
            if sampled_action_index == 1:
                surface_label_text = "$Q$($a_0$)"
            else:
                surface_label_text = "Max($Q$($a_0$), $Q$($a_1$))"
            add_surface_label(x_scale, y_scale, previous_q_value, ax,
                              text=surface_label_text,
                              font_size=font_size, z_shift=z_shift, x_shift=x_shift)
    elif not add_all_dots_to_main:
        # One flat reference plane per sampled action. The count is taken from
        # the data passed in rather than from a module-level variable.
        for i, q_value in enumerate(sampled_point_z[:, 0]):
            plot_surface_3d(x, y, np.ones(z.shape) * q_value, ax,
                            color=get_color(i), alpha=0.4, label=f"$q_{i}$")
            add_surface_label(x_scale, y_scale, q_value, ax,
                              text=f"$Q(a_{i})$",
                              font_size=font_size, z_shift=z_shift, x_shift=x_shift)

    # remove numbers
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_zticks([])
    # set labels
    ax.set_xlabel('Action Space', fontsize=font_size)
    ax.set_zlabel('Q-value', fontsize=font_size)

    if q_scale is not None:
        ax.set_zlim([q_scale[0], q_scale[1]])

    if is_mujoco:
        title_env_name = env_name.split("-")[1]
        title_env_name += "-Hard" if args['reacher_validity_type'] == 'box' else "-Easy"
    else:
        title_env_name = env_name
    # replace the '_' with ' ', and capitalize the first letter of each word
    title_env_name = title_env_name.replace("_", " ").title()
    # NOTE(review): index 0 is falsy, so the first AR critic also gets the
    # "Primary Critic" title -- confirm this is intended.
    if ar_critic and sampled_action_index:
        # Raw f-string: "\P" is not a valid escape sequence in a normal string.
        ax.set_title(rf'{title_env_name} $\Psi_{sampled_action_index}$', fontsize=font_size)
    else:
        ax.set_title(f'{title_env_name} Primary Critic', fontsize=font_size)
    if add_all_dots_to_main:
        ax.set_title('Walker2d', fontsize=font_size)

    if ar_critic:
        out_path = f"{plot_save_dir}/ar_critic_plot_{sampled_action_index}"
        if flag_threshold:
            out_path += "_threshold"
        out_path += ".pdf"
    elif add_all_dots_to_main:
        out_path = f"{plot_save_dir}/Walker2d.pdf"
    else:
        out_path = f"{plot_save_dir}/main_critic_plot.pdf"
    plt.tight_layout()
    plt.savefig(out_path, format="pdf")
    plt.clf()


def create_plot_save_dir(timestep, env_name):
    """Create and return a fresh output directory for one set of plots.

    The directory is ``./src/3Dplots/<env_name>/<timestamp>_<timestep>``,
    relative to the current working directory.

    Args:
        timestep: Value appended to the timestamp in the directory name.
        env_name: Name of the sub-directory grouping plots per environment.

    Returns:
        Path of the directory, which exists when the function returns.
    """
    root_dir = "./src/3Dplots/"
    current_time = time.strftime("%Y-%m-%d-%H-%M-%S")
    sub_dir = os.path.join(root_dir, f"{env_name}", f"{current_time}_{timestep}")
    # One call creates every missing parent and tolerates an existing
    # directory, avoiding the check-then-create race.
    os.makedirs(sub_dir, exist_ok=True)
    return sub_dir


def preprocess_state(agent, s, main_critic=False):
    """Encode the observation for a critic when the environment requires it.

    Reads the module-level ``args`` and ``device``. For the ``mine`` and
    ``recsim-data`` environments with ``mw_obs_flatten`` disabled, the
    observation is passed through the agent's observation encoder for the
    reference critic (``main_critic=True``) or for the AR critic; otherwise
    it is returned untouched.

    Args:
        agent: Agent exposing ``main_ref_critic_obs_enc`` and
            ``main_ar_critic_obs_enc``.
        s: Observation batch.
        main_critic: Select the reference-critic encoder instead of the
            AR-critic one.

    Returns:
        The encoded observation moved to ``device``, or ``s`` itself.
    """
    needs_encoding = args["env_name"].lower() in ["mine", "recsim-data"] and not args['mw_obs_flatten']
    if not needs_encoding:
        return s
    encoder = agent.main_ref_critic_obs_enc if main_critic else agent.main_ar_critic_obs_enc
    return encoder(s).to(device)


def gen_z(actions_2d, q, x, y, sampled_actions_2d):
    """Interpolate Q-values over a grid with a Gaussian-process regressor.

    Args:
        actions_2d: 2-D action coordinates used as regression inputs.
        q: Q-value targets, one per row of ``actions_2d``.
        x, y: Mesh-grid coordinate arrays of identical shape.
        sampled_actions_2d: 2-D coordinates at which the fit is also evaluated.

    Returns:
        Tuple ``(z, z_sampled)``: the predictions reshaped to ``x.shape`` and
        the predictions at ``sampled_actions_2d`` as a column vector.
    """
    amplitude = C(1.0, (1e-3, 1e3))
    smoothness = RBF(10, (1e-2, 1e2))
    regressor = GaussianProcessRegressor(kernel=amplitude * smoothness, n_restarts_optimizer=9)
    regressor.fit(actions_2d, q)

    # One row per grid node: (x, y).
    grid_points = np.array([x.ravel(), y.ravel()]).T
    surface = regressor.predict(grid_points).reshape(x.shape)
    at_samples = regressor.predict(sampled_actions_2d).reshape(-1, 1)
    return surface, at_samples


def transform_action_and_q(actions_true, q_main, q_ars, sampled_actions_ids, num_sqrt=100):
    """Build the plotting grid and the interpolated Q surfaces for all critics.

    Args:
        actions_true: Action vectors, one per row.
        q_main: Main-critic Q-values, one per action.
        q_ars: Sequence of Q-value arrays, one per AR critic.
        sampled_actions_ids: Indices (into ``actions_true``) of the sampled
            actions.
        num_sqrt: Number of grid points along each of the two plot axes.

    Returns:
        Tuple ``(x, y, sampled_actions_2d, z_main, z_main_sampled, zs_ar,
        zs_ar_sampled)`` where ``zs_ar`` / ``zs_ar_sampled`` stack the results
        of ``gen_z`` for every entry of ``q_ars``.
    """
    # Step 1: Dimensionality Reduction
    actions_2d = actions22d(actions_true)
    sampled_actions_2d = actions_2d[sampled_actions_ids]

    # Step 2: Create a grid spanning the reduced action space
    first, second = actions_2d[:, 0], actions_2d[:, 1]
    x, y = np.meshgrid(np.linspace(first.min(), first.max(), num_sqrt),
                       np.linspace(second.min(), second.max(), num_sqrt))

    # Step 3: Interpolate each critic on the grid. gen_z fits its own
    # regressor, so no separate fit is performed here.
    z_main, z_main_sampled = gen_z(actions_2d, q_main, x, y, sampled_actions_2d)

    zs_ar = []
    zs_ar_sampled = []
    # Iterate over whatever AR critics were supplied instead of assuming three.
    for q_ar in q_ars:
        z_ar, z_ar_sampled = gen_z(actions_2d, q_ar, x, y, sampled_actions_2d)
        zs_ar.append(z_ar)
        zs_ar_sampled.append(z_ar_sampled)
    return x, y, sampled_actions_2d, z_main, z_main_sampled, np.array(zs_ar), np.array(zs_ar_sampled)


if __name__ == "__main__":
    # Script entry point: load a trained agent, roll it out, and for every
    # timestep whose sampled actions differ enough in Q-value, save 3-D plots
    # of the main critic and of each AR critic.
    #
    # NOTE: this is kept as a flat script (not wrapped in a main() function)
    # because helpers such as preprocess_state() read `args` and `device` as
    # module-level globals.
    args = get_all_args()
    args = vars(args)
    smooth = True
    enable_logging = True
    concat_ar_plots = True
    add_all_dots_to_main = True
    # Minimum relative Q-value gap for a timestep to be plotted.
    q_threshold = 1e-2

    if args["env_name"].lower().startswith("recsim"):
        q_threshold = 5e-2
    elif args["env_name"].lower().startswith("mine"):
        q_threshold = 1e-2

    # num_samples: actions evaluated per state; mesh_grid_sqrt: grid points per plot axis.
    if args["env_name"].lower().startswith("mujoco"):
        num_samples = 100
        mesh_grid_sqrt = 41
    elif args["env_name"].lower().startswith("mine"):
        num_samples = 100
        mesh_grid_sqrt = 11
    elif args["env_name"].lower().startswith("recsim"):
        num_samples = 100
        mesh_grid_sqrt = 41
    else:
        raise ValueError(f"env_name {args['env_name']} not supported")
    device = args.get("device", "cpu")
    set_randomSeed(seed=args["seed"])

    # set environment
    args["num_envs"] = 1
    env_name = args["env_name"]
    if env_name.lower().startswith("recsim"):
        args["_max_l2_dist"] = np.linalg.norm(np.ones(args["recsim_dim_embed"]) * 2)
    env = launch_env(args=args)
    # NOTE(review): eval_env is never used below; the call is kept in case
    # launching a second env has side effects the run depends on -- confirm.
    eval_env = launch_env(args=args)
    args["env_max_action"] = 1.
    if env_name == "mine" and args['mw_show_action_embeddings']:
        env.show_action_embeddings()
    if env_name.startswith("mujoco"):
        args["_max_l2_dist"] = np.linalg.norm(np.ones(args["reacher_action_shape"]) * 2)
        args["env_max_action"] = float(env.action_space[0].high[0])
    if env_name == "mine":
        args["_max_l2_dist"] = np.linalg.norm(np.ones(args["mw_action_dim"]) * 2)
    args["TD3_policy_noise"] *= args["env_max_action"]
    args["TD3_noise_clip"] *= args["env_max_action"]
    logging("======== Env IS READY ========")
    # Expand noisy dimensions
    if args["env_dim_extra"] > 0 and args["env_name"] != "recsim-data":
        _emb = np.random.random(size=(env.act_embedding.shape[0], args["env_dim_extra"]))
        __emb = MinMaxScaler(feature_range=(-0.01, 0.01)).fit_transform(env.act_embedding)
        _emb = np.hstack([__emb, _emb])
        if args["env_act_emb_tSNE"]:
            logging("======== START: t-SNE on Act Emb ========")
            _emb = TSNE(n_components=__emb.shape[-1],
                        perplexity=3,
                        random_state=0,
                        method="exact",
                        n_iter=1000,
                        n_jobs=-1).fit_transform(_emb)
            logging("======== DONE: t-SNE on Act Emb ========")
    else:
        _emb = env.act_embedding

    if env_name in ["recsim", "mine", "recsim-data"]:
        dict_embedding = launch_embedding(args=args)
        dict_embedding["item"].load(embedding=_emb)
        dict_embedding["task"].load(embedding=env.task_embedding)
    elif env_name.startswith("mujoco"):
        dict_embedding = {'item': None, 'task': None}
    else:
        raise NotImplementedError

    if env_name == "mine":
        action_dim = args["mw_action_dim"]
    elif env_name.startswith("mujoco"):
        action_dim = env.action_space[0].shape[0]
        _high, _low = env.action_space[0].high, env.action_space[0].low
    elif env_name.startswith("recsim"):
        action_dim = _emb.shape[-1]
    else:
        raise ValueError(f"env_name {env_name} not supported")

    if enable_logging:
        logging("======== FINISHED: get embedding ========")

    # agent load
    prefix = args["method_name"]
    if env_name.startswith("mujoco"):
        suffix = env_name.split('-')[1]
    elif env_name == "mine":
        suffix = "mine"
    elif env_name == "recsim":
        suffix = "recsim"
    else:
        # Without this branch `suffix` would be unbound (e.g. for "recsim-data").
        raise ValueError(f"no agent weights are registered for env_name {env_name}")
    args["agent_load_path"] = f"./weights/{prefix}-{suffix}"
    # Checkpoint epoch to load for each (method, environment) pair.
    dict_load_epoch = {}
    if prefix == 'savo_refined':
        dict_load_epoch["ant"] = 18899
        dict_load_epoch["half_cheetah"] = 19699
        dict_load_epoch["hopper"] = 21199
        dict_load_epoch["inverted_double_pendulum"] = 21299
        dict_load_epoch["inverted_pendulum"] = 20799
        dict_load_epoch["walker2d"] = 20499
        dict_load_epoch["mine"] = 6099
        dict_load_epoch["recsim"] = 9499
    elif prefix == 'savo_threshold_refined':
        dict_load_epoch["ant"] = 34099
        dict_load_epoch["half_cheetah"] = 34199
        dict_load_epoch["hopper"] = 37299
        dict_load_epoch["inverted_double_pendulum"] = 39699
        dict_load_epoch["inverted_pendulum"] = 39699
        dict_load_epoch["walker2d"] = 35599
    else:
        raise ValueError(f"prefix {prefix} not supported")
    args["agent_load_epoch"] = dict_load_epoch.get(suffix, 0)
    agent = launch_agent(args=args, env=env)
    agent.load_model(epoch=args['agent_load_epoch'])
    if enable_logging:
        logging("======== FINISHED: get agent ========")

    # get 2d action first
    if env_name.startswith("mujoco"):
        # random sample num_samples actions in env.action_space
        actions_true_basic = np.random.uniform(low=_low, high=_high, size=(num_samples, action_dim))
    elif env_name == "mine":
        # since this is a discrete action space, we just need to sample num_samples actions
        actions_true_basic_ids = np.random.choice(range(0, env.action_space[0].n), num_samples, replace=False)
        actions_true_basic = env.act_embedding[actions_true_basic_ids]
    elif env_name.startswith("recsim"):
        actions_true_basic_ids = np.random.choice(range(0, _emb.shape[0]), num_samples, replace=False)
        actions_true_basic = _emb[actions_true_basic_ids]
    timesteps = 10000
    s = env.reset()
    for timestep in range(timesteps):
        # Work on a copy: sampled actions overwrite entries of the pool below.
        true_actions = copy.deepcopy(actions_true_basic)
        # get state
        s = torch.tensor(s).to(device).float()
        action = agent.select_action(obs=s, act_embed_base=dict_embedding["item"],
                                     epsilon={"actor": 0.0, "critic": 0.0})
        # get next state
        s_next, _, d, _ = env.step(action["action"])
        if d:
            s_next = env.reset()
        if enable_logging:
            logging("======== FINISHED: get action ========")

        if env_name.startswith("mujoco"):
            sampled_actions = action["query"][0, :, action_dim:]
        elif env_name.startswith(("mine", "recsim")):
            # Discrete action sets: take the query result returned by the
            # agent's kNN lookup over the item embeddings.
            action_query = action["query"]
            action_query = torch.tensor(action_query).to(device).float()
            act_embed = dict_embedding["item"].embedding_torch[None, ...].repeat(1, 1, 1).detach()
            sampled_actions = agent._perform_kNN(act_embed, action_query, 1)[2][0]
            sampled_actions = sampled_actions.detach().cpu().numpy()
        # find the nearest action in the true_actions
        sampled_action_ids = []
        for sampled_action in sampled_actions:
            distance = np.linalg.norm(sampled_action - true_actions, axis=1)
            min_distance = np.min(distance)
            if min_distance > 0:
                # Not in the pool yet: overwrite the nearest pool entry with it.
                sampled_action_id = np.argmin(distance)
                if sampled_action_id in sampled_action_ids:
                    # NOTE(review): np.random.randint(-1, 1) yields -1 or 0, so a
                    # collision is only sometimes resolved -- confirm the intent.
                    sampled_action_id += 1 * np.random.randint(-1, 1)
                true_actions[sampled_action_id] = sampled_action
            else:
                sampled_action_id = np.argmin(distance)
            sampled_action_ids.append(sampled_action_id)
        sampled_actions_ids = np.array(sampled_action_ids)

        if enable_logging:
            logging("======== FINISHED: get sampled_actions ========")

        # get state s
        s = s.cpu().numpy()
        s_dim_len = len(s.shape)
        # reps equals to (num_samples, 1) if s_dim_len == 1, else (num_samples, 1, ..., 1) where the number of 1s is s.shape[1]
        reps = (num_samples,) + (1,) * (s_dim_len - 1)
        s = torch.tensor(np.tile(A=s[0, :], reps=reps).astype(np.float32)).to(device)
        if enable_logging:
            logging("======== FINISHED: get state ========")

        # gen q_main
        _s = preprocess_state(agent, s, main_critic=True)
        critic_network = agent.main_ref_critic
        # get q values
        q_main = get_correspond_q_val(critic_network, _s,
                                      true_actions, device)

        # if the difference is too marginal, dont plot
        if np.argmax(q_main[sampled_action_ids]) == 0:
            s = s_next
            continue  # dont plot if the best action is the first one
        q_max = np.max(q_main[sampled_action_ids])
        q0 = q_main[sampled_action_ids[0]]
        if (q_max - q0) / abs(q0) < q_threshold:
            print(q_max)
            print(q0)
            print((q_max - q0) / abs(q0))
            s = s_next
            continue  # dont plot if the improvement over the first action is marginal
        print(f"timestep: {timestep}", q_main[sampled_action_ids])
        if enable_logging:
            logging("======== FINISHED: get main q values ========")
        # gen q_ar
        q_ars = []
        _s = preprocess_state(agent, s, main_critic=False)
        _s = _s[:, None, :]
        for i in range(len(sampled_actions_ids)):
            # get actions
            list_actions = np.zeros((true_actions.shape[0], len(sampled_actions_ids), action_dim))  # 10k * 3 * dim
            # for index before i, use the sampled_actions
            for j in range(i):
                sampled_action = true_actions[sampled_actions_ids[j]]  # act_dim
                sampled_action = sampled_action[None, :]  # 1, act_dim
                list_actions[:, j, :] = sampled_action.repeat(true_actions.shape[0], axis=0)  # 10k, 1, act_dim

            # for index i, use the true_actions
            list_actions[:, i, :] = true_actions
            action_seq = torch.tensor(list_actions, device=device).float()
            # get Q values
            with torch.no_grad():
                Q = agent.main_ar_critic(state=_s,
                                         action_seq=action_seq,
                                         alternate_conditioning=None,
                                         true_action=None)
            q_ar = Q[:, i].detach().cpu().numpy()
            q_ars.append(q_ar)

        q_ars = np.array(q_ars)
        if enable_logging:
            logging("======== FINISHED: get ar q values ========")

        # plot 2d plot
        plot_save_dir = create_plot_save_dir(timestep, env_name)
        # transform action space
        ## prepare data for plotting main critic
        x, y, sampled_actions_2d, z_main, z_main_sampled, zs_ar, zs_ar_sampled = transform_action_and_q(
            true_actions, q_main, q_ars, sampled_actions_ids, num_sqrt=mesh_grid_sqrt)
        if enable_logging:
            logging("======== FINISHED: get data for plotting ========")

        # align q value scale for all plots
        q_scale = [np.min([np.min(z_main), np.min(zs_ar)]), np.max([np.max(z_main), np.max(zs_ar)])]

        # main
        fig_size = (10, 10)
        fig = plt.figure(figsize=fig_size)
        ax = fig.add_subplot(1, 1, 1, projection="3d", facecolor="w")
        gen_critic_plot_2d(x, y, sampled_actions_2d, z_main, z_main_sampled, plot_save_dir, ax, q_scale=q_scale,
                           args=args, smooth=smooth)
        # Every figure is closed once saved; otherwise they accumulate over
        # the whole rollout.
        plt.close(fig)
        if enable_logging:
            logging("======== FINISHED: get main critic plot ========")

        # ar
        if not concat_ar_plots:
            fig = plt.figure(figsize=fig_size)
            ax = fig.add_subplot(1, 1, 1, projection="3d", facecolor="w")
            for i in range(len(sampled_actions_ids)):
                gen_critic_plot_2d(x, y, sampled_actions_2d, zs_ar[i], zs_ar_sampled[i],
                                   plot_save_dir, ax, ar_critic=True, q_scale=q_scale, args=args, smooth=smooth)
            plt.close(fig)
        else:
            for i in range(len(sampled_actions_ids)):
                fig = plt.figure(figsize=fig_size)
                ax = fig.add_subplot(1, 1, 1, projection="3d", facecolor="w")
                previous_q = None if i == 0 else z_main_sampled[i - 1]
                gen_critic_plot_2d(x, y, sampled_actions_2d, zs_ar[i], zs_ar_sampled[i],
                                   plot_save_dir, ax, ar_critic=True, sampled_action_index=i,
                                   previous_q_value=previous_q, q_scale=q_scale, args=args, smooth=smooth)
                plt.close(fig)
        if enable_logging:
            logging("======== FINISHED: get ar critic plot ========")

        # threshold
        for i in range(len(sampled_actions_ids)):
            fig = plt.figure(figsize=fig_size)
            ax = fig.add_subplot(1, 1, 1, projection="3d", facecolor="w")
            previous_q = None if i == 0 else z_main_sampled[i - 1]
            if i:
                max_before = np.max(z_main_sampled[:i])
                # clip the AR surface from below at the best Q-value of the earlier actions
                zs_ar_new = np.maximum(zs_ar[i], max_before)
            else:
                zs_ar_new = zs_ar[i]
            gen_critic_plot_2d(x, y, sampled_actions_2d, zs_ar_new, zs_ar_sampled[i],
                               plot_save_dir, ax, ar_critic=True, sampled_action_index=i,
                               previous_q_value=previous_q, q_scale=q_scale, flag_threshold=True, args=args,
                               smooth=smooth)
            plt.close(fig)
        if enable_logging:
            logging("======== FINISHED: get ar critic threshold plot ========")

        # add_dotted_plots for main
        if add_all_dots_to_main:
            # Treat every action in the pool as "sampled" for this last plot.
            sampled_actions_ids = np.array(range(len(true_actions)))
            x, y, sampled_actions_2d, z_main, z_main_sampled, zs_ar, zs_ar_sampled = transform_action_and_q(
                true_actions, q_main, q_ars, sampled_actions_ids, num_sqrt=mesh_grid_sqrt)

            fig_size = (10, 10)
            fig = plt.figure(figsize=fig_size)
            ax = fig.add_subplot(1, 1, 1, projection="3d", facecolor="w")
            gen_critic_plot_2d(x, y, sampled_actions_2d, z_main, z_main_sampled, plot_save_dir, ax, q_scale=q_scale,
                               args=args, smooth=smooth, add_all_dots_to_main=True)
            plt.close(fig)
            if enable_logging:
                logging("======== FINISHED: get main critic plot ========")

        s = s_next
