import copy
from large_rl.commons.utils import logging
from random import randint
import random
import time
from sklearn.manifold import TSNE
from sklearn.preprocessing import MinMaxScaler
import torch
import numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import os
from sklearn.kernel_approximation import Nystroem
from mpl_toolkits.mplot3d import Axes3D

from large_rl.commons.launcher import launch_embedding, launch_env, launch_agent
from large_rl.commons.seeds import set_randomSeed
from large_rl.commons.args import get_all_args

def to_true_action(actions_2d, feature_map, args, rec_sim_emb=None, low=None, high=None):
    """Map a grid of 2-D points to actions in the environment's action space.

    For ``recsim*`` environments the grid is not projected at all: for every
    row of ``actions_2d`` one item embedding is drawn uniformly at random
    (with replacement) from ``rec_sim_emb.embedding_np``.

    For every other environment, ``feature_map.fit_transform`` lifts the 2-D
    points to the action dimensionality. Each resulting column is min-max
    scaled to [0, 1] and then mapped affinely onto ``[low, high]``.

    Args:
        actions_2d: array of 2-D grid points, one per row.
        feature_map: object with a ``fit_transform`` method (the script passes
            a ``Nystroem``); unused for recsim environments.
        args: config dict; only ``args["env_name"]`` is read.
        rec_sim_emb: object exposing ``embedding_np``; required for recsim
            environments.
        low, high: action-space bounds for the non-recsim branch. When omitted
            they fall back to the module-level ``_low`` / ``_high`` assigned by
            the ``__main__`` script, matching the original behaviour.

    Returns:
        Array with one action per row of ``actions_2d``.
    """
    if args["env_name"].startswith("recsim"):
        # Random item ids (with replacement) stand in for the projected grid.
        true_action_ids = np.random.choice(range(len(rec_sim_emb.embedding_np)), size=len(actions_2d))
        return rec_sim_emb.embedding_np[true_action_ids]

    # Only consult the script-level globals when the caller gave no bounds, so
    # the function is usable outside the ``__main__`` script.
    if low is None:
        low = _low
    if high is None:
        high = _high

    true_actions = feature_map.fit_transform(actions_2d)
    # Rescale every column to [0, 1] ...
    true_actions = MinMaxScaler().fit_transform(true_actions)
    # ... then stretch [0, 1] onto the action-space box [low, high].
    return true_actions * (high - low) + low


def get_correspond_q_val(critic_network, state, actions, device):
    """Evaluate ``critic_network`` on ``state`` concatenated with each action.

    Args:
        critic_network: callable that takes a tensor whose rows are
            ``[state, action]`` (concatenated along the last dim) and returns
            Q-values.
        state: tensor whose rows are paired row-wise with ``actions``.
        actions: array-like (or tensor) of actions with the same number of
            rows as ``state``.
        device: torch device the evaluation runs on.

    Returns:
        The critic output as a numpy array.
    """
    # Cast the actions to the state's dtype. numpy actions are often float64
    # (e.g. after sklearn's MinMaxScaler) while the state and the critic
    # weights are float32. Concatenating them would yield a float64 input
    # that the float32 critic rejects.
    action_tensor = torch.as_tensor(actions, dtype=state.dtype, device=device)
    inputs = torch.cat([state.to(device), action_tensor], dim=-1)
    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        q = critic_network(inputs)
    return q.detach().cpu().numpy()


def gen_main_critic_plot(a2d_x, a2d_y, num_samples_sqrt, q, actions_2d, sampled_actions_ids, timestep, plot_save_dir, sampled_actions=None):
    """Save a 3-D surface of the main critic's Q-values over the 2-D action grid.

    Args:
        a2d_x, a2d_y: 1-D grid coordinates used to build the surface mesh.
        num_samples_sqrt: grid side length; ``q`` must hold
            ``num_samples_sqrt ** 2`` values, laid out in
            ``np.meshgrid(a2d_x, a2d_y)`` order.
        q: Q-values, one per grid point.
        actions_2d: 2-D grid points (one per row) aligned with ``q``.
        sampled_actions_ids: row indices into ``actions_2d`` / ``q`` of the
            agent's candidate actions; drawn in red and labelled ``Q0``,
            ``Q1``, ...
        timestep: used in the output file name.
        plot_save_dir: directory the image is written to.
        sampled_actions: only compared against ``None``; when ``None`` the
            candidates are not highlighted.
    """
    fig = plt.figure(figsize=(5, 5))
    try:
        ax = fig.add_subplot(111, projection="3d", facecolor="w")
        X, Y = np.meshgrid(a2d_x, a2d_y)
        Z = q.reshape(num_samples_sqrt, num_samples_sqrt)
        ax.plot_surface(X, Y, Z, rstride=1, cstride=1, cmap='viridis', edgecolor='none', alpha=.85)

        if sampled_actions is not None:
            sampled_actions_2d = actions_2d[sampled_actions_ids]
            x, y = sampled_actions_2d[:, 0], sampled_actions_2d[:, 1]
            z = q[sampled_actions_ids]
            ax.scatter(x, y, z, s=40, c='r', marker='.', zorder=10, alpha=.85)
            for i in range(len(sampled_actions_ids)):
                ax.text(x[i], y[i], z[i], f'Q{i}', size=10, zorder=10, color='r')
        ax.set_xlabel('Action Space')
        ax.set_ylabel('')
        ax.set_zlabel('Q-value')
        ax.set_title('Q-value Distribution')

        fig.tight_layout()
        fig.savefig(f"{plot_save_dir}/main_critic_plot_{timestep}")
    finally:
        # plt.clf() only clears the figure and leaves it registered with
        # pyplot. Over a long loop this leaks memory, so close it explicitly.
        plt.close(fig)


def gen_critic_plot_2d(true_actions, q, sampled_actions_ids, timestep, plot_save_dir, ar_critic=False, sampled_action_index=None):
    """Save a line plot of Q-value against the first action dimension.

    Args:
        true_actions: actions, one per row; column 0 is used as the x axis.
        q: Q-values aligned with ``true_actions`` (flattened to one column).
        sampled_actions_ids: row indices of the agent's candidate actions.
        timestep: unused. The file names are fixed (``main_critic_plot`` /
            ``ar_critic_plot_<index>``); the script passes a per-timestep
            ``plot_save_dir`` instead.
        plot_save_dir: directory the image is written to.
        ar_critic: when True, only the candidate at ``sampled_action_index``
            is highlighted and the file is named after that index. Otherwise
            every candidate is highlighted.
        sampled_action_index: position in ``sampled_actions_ids`` to
            highlight when ``ar_critic`` is True.
    """
    q = q.reshape(-1, 1)
    fig = plt.figure(figsize=(5, 5))
    try:
        ax = fig.add_subplot(111, facecolor="w")
        ax.plot(true_actions[:, 0], q[:, 0], c='b', marker='.', zorder=10, alpha=.85)

        if ar_critic:
            id_in_true_actions = sampled_actions_ids[sampled_action_index]
            x = true_actions[id_in_true_actions, 0]
            y = q[id_in_true_actions, 0]
            ax.scatter(x, y, s=40, c='r', marker='.', zorder=10, alpha=.85)
            ax.text(x, y, f'Q{sampled_action_index}', size=10, zorder=10, color='r')
        else:
            sampled_actions = true_actions[sampled_actions_ids]
            x = sampled_actions[:, 0]
            y = q[sampled_actions_ids, 0]
            ax.scatter(x, y, s=40, c='r', marker='.', zorder=10, alpha=.85)
            for i in range(len(sampled_actions_ids)):
                ax.text(x[i], y[i], f'Q{i}', size=10, zorder=10, color='r')

        ax.set_xlabel('Action Space')
        ax.set_ylabel('Q-value')
        ax.set_title('Q-value Distribution')

        fig.tight_layout()
        if ar_critic:
            fig.savefig(f"{plot_save_dir}/ar_critic_plot_{sampled_action_index}")
        else:
            fig.savefig(f"{plot_save_dir}/main_critic_plot")
    finally:
        # Close, not just clear: otherwise every call leaves an open pyplot
        # figure behind.
        plt.close(fig)


def gen_ar_critic_plot(a2d_x, a2d_y, num_samples_sqrt, q, actions_2d, sampled_actions_id,
                       timestep, plot_save_dir, sampled_action=None, q_index=None):
    """Save a 3-D surface of one autoregressive-critic head over the 2-D action grid.

    Args:
        a2d_x, a2d_y: 1-D grid coordinates used to build the surface mesh.
        num_samples_sqrt: grid side length; ``q`` must hold
            ``num_samples_sqrt ** 2`` values.
        q: Q-values of the head being plotted, one per grid point.
        actions_2d: 2-D grid points (one per row) aligned with ``q``.
        sampled_actions_id: single row index of the candidate to highlight.
        timestep: unused; the script passes a per-timestep ``plot_save_dir``.
        plot_save_dir: directory the image is written to.
        sampled_action: only compared against ``None``; when ``None`` the
            candidate is not highlighted.
        q_index: list position of the candidate; used in the label and the
            file name.
    """
    fig = plt.figure(figsize=(5, 5))
    try:
        ax = fig.add_subplot(111, projection="3d", facecolor="w")
        X, Y = np.meshgrid(a2d_x, a2d_y)
        Z = q.reshape(num_samples_sqrt, num_samples_sqrt)
        ax.plot_surface(X, Y, Z, rstride=1, cstride=1, cmap='viridis', edgecolor='none', alpha=.85)

        if sampled_action is not None:
            sampled_actions_2d = actions_2d[sampled_actions_id]
            x, y = sampled_actions_2d[0], sampled_actions_2d[1]
            z = q[sampled_actions_id]
            ax.scatter(x, y, z, s=40, c='r', marker='.', zorder=10, alpha=.85)
            ax.text(x, y, z, f'Q{q_index}', size=10, zorder=10, color='r')
        ax.set_xlabel('Action Space')
        ax.set_ylabel('')
        ax.set_zlabel('Q-value')
        ax.set_title('Q-value Distribution')

        fig.tight_layout()
        fig.savefig(f"{plot_save_dir}/ar_critic_plot{q_index}")
    finally:
        # Release the figure; plt.clf() alone would keep it open in pyplot.
        plt.close(fig)


if __name__ == "__main__":
    # Toggles for the two kinds of plots produced at every timestep.
    generate_main_q_val = True
    generate_ar_q_val = True
    start_time = time.time()
    set_randomSeed(seed=randint(0, 1000000))
    args = vars(get_all_args())

    if args["env_name"] == "recsim-data":
        DATASET_PATH = ""
        args["recsim_data_dir"] = os.path.join(DATASET_PATH, args["recsim_data_dir"])
        args["user_embedding_path"] = os.path.join(DATASET_PATH, args["recsim_data_dir"], "user_attr.npy")
        args["item_embedding_path"] = os.path.join(DATASET_PATH, args["recsim_data_dir"], "trained_weight/item.npy")
        args["save_dir"] = os.path.join(DATASET_PATH, args["save_dir"])

    # NOTE: hard-coded checkpoint; edit these to visualise a different run.
    args["agent_load_path"] = './model_log/11711539-25-inverted_pendulum-BOX-savo'
    args['agent_load_epoch'] = 39799
    device = args.get("device", "cpu")
    args["num_envs"] = 1
    env = launch_env(args=args)
    args["env_max_action"] = 1.
    env_name = args["env_name"]
    if env_name == "mine":
        args["_max_l2_dist"] = np.linalg.norm(np.ones(args["mw_action_dim"]) * 2)
    if env_name.startswith("mujoco"):
        args["_max_l2_dist"] = np.linalg.norm(np.ones(args["reacher_action_shape"]) * 2)
        args["env_max_action"] = float(env.action_space[0].high[0])
    if env_name.lower().startswith("recsim"):
        args["_max_l2_dist"] = np.linalg.norm(np.ones(args["recsim_dim_embed"]) * 2)

    # Expand the action embedding with random "noisy" dimensions; the original
    # dims are squashed to [-0.01, 0.01] first.
    if args["env_dim_extra"] > 0 and args["env_name"] != "recsim-data":
        _emb = np.random.random(size=(env.act_embedding.shape[0], args["env_dim_extra"]))
        __emb = MinMaxScaler(feature_range=(-0.01, 0.01)).fit_transform(env.act_embedding)
        _emb = np.hstack([__emb, _emb])
        if args["env_act_emb_tSNE"]:
            logging("======== START: t-SNE on Act Emb ========")
            # NOTE(review): `n_iter` is renamed `max_iter` in newer scikit-learn — confirm installed version.
            _emb = TSNE(n_components=__emb.shape[-1],
                        perplexity=3,
                        random_state=0,
                        method="exact",
                        n_iter=1000,
                        n_jobs=-1).fit_transform(_emb)
    else:
        _emb = env.act_embedding

    if args["recsim_if_tsne_embed"]:  # for Dual-tSNE
        _emb = TSNE(n_components=_emb.shape[-1],
                    perplexity=3,
                    random_state=0,
                    method="exact",
                    n_iter=1000,
                    n_jobs=-1).fit_transform(_emb)

    if env_name in ["recsim", "mine", "recsim-data"]:
        dict_embedding = launch_embedding(args=args)
        dict_embedding["item"].load(embedding=_emb)
        dict_embedding["task"].load(embedding=env.task_embedding)
    elif env_name.startswith("mujoco"):
        dict_embedding = {'item': None, 'task': None}
    else:
        raise NotImplementedError
    logging("======== FINISHED: get embedding ========")

    args["TD3_policy_noise"] *= args["env_max_action"]
    args["TD3_noise_clip"] *= args["env_max_action"]
    agent = launch_agent(args=args, env=env)
    agent.load_model(epoch=args['agent_load_epoch'])
    logging("======== FINISHED: get agent ========")

    plot_save_dir_original = "./src/3Dplots/"
    os.makedirs(plot_save_dir_original, exist_ok=True)

    # `_high` / `_low` are module-level: `to_true_action` reads them.
    if args["env_name"].startswith("recsim"):
        action_dim = dict_embedding["item"].shape[1]
        _high, _low = args["num_all_actions"], 0
    else:
        action_dim = env.action_space[0].shape[0]
        _high, _low = env.action_space[0].high, env.action_space[0].low

    # Build the grid of actions the critic is evaluated on.
    num_samples_sqrt = 100
    num_samples = num_samples_sqrt ** 2
    if args["env_name"] == "mujoco-inverted_pendulum":
        # Action space is used directly: evenly spaced points in [low, high].
        true_actions_basic = np.linspace(_low, _high, num=num_samples).astype(np.float32)
    else:
        # A num_samples_sqrt x num_samples_sqrt grid over [-1, 1]^2 ...
        a2d_x = np.linspace(-1, 1, num=num_samples_sqrt).astype(np.float32)
        a2d_y = np.linspace(-1, 1, num=num_samples_sqrt).astype(np.float32)
        actions_2d = np.vstack(np.meshgrid(a2d_x, a2d_y)).reshape(2, -1).T
        # ... lifted into the real action space.
        feature_map_nystroem = Nystroem(gamma=.2, random_state=1, n_components=action_dim)
        true_actions_basic = to_true_action(actions_2d, feature_map_nystroem, args=args, rec_sim_emb=dict_embedding["item"])

    timesteps = 1000
    s = env.reset()
    for timestep in range(timesteps):
        plot_save_dir = os.path.join(plot_save_dir_original, f"{timestep}")
        os.makedirs(plot_save_dir, exist_ok=True)

        # Fresh copy: rows are overwritten below with the agent's sampled actions.
        true_actions = copy.deepcopy(true_actions_basic)
        s = torch.tensor(s).to(device)
        action = agent.select_action(obs=s, act_embed_base=dict_embedding["item"], epsilon={"actor": 0.0, "critic": 0.0})
        # NOTE(review): the done flag `d` is ignored and env.reset() is never
        # called again — confirm the env auto-resets.
        s_next, _, d, _ = env.step(action["action"])
        logging("======== FINISHED: get action ========")

        # Locate each candidate action inside `true_actions`, so its Q-value
        # can be read from the same q array used for the surface.
        if args["env_name"].startswith("recsim"):
            # `np.int` was removed in NumPy 1.24; the builtin int is equivalent.
            sampled_actions_ids_original = action["topk_act"][0, :, 0].astype(int)
            sampled_actions = dict_embedding["item"].embedding_np[sampled_actions_ids_original]
            sampled_action_ids = []
            for sampled_action in sampled_actions:
                distance = np.linalg.norm(sampled_action - true_actions, axis=1)
                if np.min(distance) > 0:
                    # Item is not among the grid rows: plant it at a random row.
                    sampled_action_id = np.random.choice(range(len(true_actions)))
                    true_actions[sampled_action_id] = sampled_action
                else:
                    sampled_action_id = np.argmin(distance)
                sampled_action_ids.append(sampled_action_id)
            sampled_actions_ids = np.array(sampled_action_ids)
        else:
            sampled_actions = action["query"][0, :, action_dim:]
            sampled_action_ids = []
            for sampled_action in sampled_actions:
                distance = np.linalg.norm(sampled_action - true_actions, axis=1)
                if np.min(distance) > 0:
                    # Off-grid: overwrite the nearest grid row that no earlier
                    # candidate has claimed. The previous random +/-0 "nudge"
                    # could overwrite an earlier candidate's row.
                    sampled_action_id = next(
                        (int(c) for c in np.argsort(distance) if c not in sampled_action_ids),
                        int(np.argmin(distance)),
                    )
                    true_actions[sampled_action_id] = sampled_action
                else:
                    sampled_action_id = np.argmin(distance)
                sampled_action_ids.append(sampled_action_id)
            sampled_actions_ids = np.array(sampled_action_ids)

            print(sampled_actions_ids)
            sampled_actions = true_actions[sampled_actions_ids]
        logging("======== FINISHED: get sampled_actions ========")

        # Tile the first state row so there is one copy per evaluated action.
        s = s.cpu().numpy()
        reps = (num_samples, ) + (1, ) * (len(s.shape) - 1)
        s = torch.tensor(np.tile(A=s[0, :], reps=reps).astype(np.float32)).to(device)

        if generate_main_q_val:
            if args["env_name"].lower() in ["mine", "recsim-data"] and not args['mw_obs_flatten']:
                _s = agent.main_ref_critic_obs_enc(s).to(device)
            else:
                _s = s
            critic_network = agent.main_ref_critic
            q = get_correspond_q_val(critic_network, _s, true_actions, device)
            logging("======== FINISHED: get q values ========")
            if args["env_name"] != "mujoco-inverted_pendulum":
                gen_main_critic_plot(a2d_x, a2d_y, num_samples_sqrt, q, actions_2d, sampled_actions_ids,
                                     timestep, plot_save_dir, sampled_actions=sampled_actions)
            else:
                gen_critic_plot_2d(true_actions, q, sampled_actions_ids,
                                   timestep, plot_save_dir)

            print("Finished in {:.2f} seconds".format(time.time() - start_time))

        if generate_ar_q_val:
            # The state encoding does not depend on the list position, so
            # compute it once per timestep.
            if args["env_name"].lower() in ["mine", "recsim-data"] and not args['mw_obs_flatten']:
                _s_ar = agent.main_ar_critic_obs_enc(s).to(device)
            else:
                _s_ar = s
            _s_ar = _s_ar[:, None, :]
            for i in range(len(sampled_actions_ids)):
                # Action list for head i: positions < i hold the agent's earlier
                # candidates (broadcast to every grid row); position i sweeps
                # the grid. Later positions stay zero.
                list_actions = np.zeros((true_actions.shape[0], len(sampled_actions_ids), action_dim))
                list_actions[:, :i, :] = true_actions[sampled_actions_ids[:i]]
                list_actions[:, i, :] = true_actions
                action_seq = torch.tensor(list_actions, device=device).float()
                with torch.no_grad():
                    Q = agent.main_ar_critic(state=_s_ar,
                                             action_seq=action_seq,
                                             alternate_conditioning=None,
                                             true_action=None)
                q = Q[:, i].detach().cpu().numpy()

                logging("======== FINISHED: get q values ========")
                if args["env_name"] != "mujoco-inverted_pendulum":
                    gen_ar_critic_plot(a2d_x, a2d_y, num_samples_sqrt, q, actions_2d, sampled_actions_ids[i],
                                       timestep, plot_save_dir, sampled_action=sampled_actions[i], q_index=i)
                else:
                    gen_critic_plot_2d(true_actions, q, sampled_actions_ids, timestep, plot_save_dir,
                                       ar_critic=True, sampled_action_index=i)

        s = s_next
