import copy
from large_rl.commons.utils import logging
from random import randint
import random
import time
from sklearn.manifold import TSNE
from sklearn.preprocessing import MinMaxScaler
import torch
import numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
print(plt.style.available)
# Matplotlib 3.6 renamed the bundled seaborn style sheets to "seaborn-v0_8-*"
# and later releases dropped the old names, so use whichever spelling this
# installation actually ships.
for _style_name in ('seaborn-v0_8-pastel', 'seaborn-pastel'):
    if _style_name in plt.style.available:
        plt.style.use(_style_name)
        break
else:
    # Neither spelling is listed: fall back to the original call so the
    # failure mode (Matplotlib's own error) is unchanged.
    plt.style.use('seaborn-pastel')
import os
from sklearn.kernel_approximation import Nystroem
from mpl_toolkits.mplot3d import Axes3D

from large_rl.commons.launcher import launch_embedding, launch_env, launch_agent
from large_rl.commons.seeds import set_randomSeed
from large_rl.commons.args import get_all_args

def to_true_action(actions_2d, feature_map, args, rec_sim_emb=None):
    """Turn 2-D plotting coordinates into actions of the real action space.

    Args:
        actions_2d: One row per action to produce.
        feature_map: Object exposing ``fit_transform``; only used when the
            environment is not a recsim one.
        args: Dict holding at least ``"env_name"``.
        rec_sim_emb: Object exposing ``embedding_np``; required when the
            environment name starts with ``"recsim"``.

    Returns:
        For recsim environments, rows of ``rec_sim_emb.embedding_np`` drawn at
        random (one per row of ``actions_2d``, via ``np.random.choice``).
        Otherwise the feature-mapped actions, min-max scaled and then
        stretched with ``* (_high - _low) + _low``.

    NOTE: the non-recsim path reads the module globals ``_low`` and ``_high``,
    which are only assigned in the ``__main__`` section of this file.
    """
    if not args["env_name"].startswith("recsim"):
        projected = feature_map.fit_transform(actions_2d)
        # Normalise with the scaler's default range, then map onto the
        # environment's action bounds.
        unit_scaled = MinMaxScaler().fit_transform(projected)
        return unit_scaled * (_high - _low) + _low

    # recsim: no projection is attempted; each requested action is simply a
    # randomly picked row of the item-embedding table.
    num_items = len(rec_sim_emb.embedding_np)
    true_action_ids = np.random.choice(range(num_items), size=len(actions_2d))
    return rec_sim_emb.embedding_np[true_action_ids]


def get_correspond_q_val(critic_network, state, actions, device):
    """Run ``critic_network`` on the concatenation of ``state`` and ``actions``.

    Args:
        critic_network: Callable taking a single tensor.
        state: Tensor concatenated with the actions along the last dimension.
        actions: Array-like converted with ``torch.tensor``.
        device: Device the action tensor and the critic input are moved to.

    Returns:
        The critic output as a NumPy array (detached and moved to the CPU).
    """
    action_tensor = torch.tensor(actions).to(device)
    critic_input = torch.cat([state, action_tensor], dim=-1).to(device)
    return critic_network(critic_input).detach().cpu().numpy()

def get_color(index):
    """Return one fixed colour taken from the ``index``-th sequential colormap.

    Args:
        index: Position in the list of colormap names below (for example
            1 selects 'Blues' and 4 selects 'Reds'). An out-of-range index
            raises ``IndexError``.

    Returns:
        The colour obtained by evaluating the selected colormap at 0.7.
    """
    color_ids = ['Purples', 'Blues', 'Greens', 'Oranges', 'Reds',
                 'YlOrBr', 'YlOrRd', 'OrRd', 'PuRd', 'RdPu', 'BuPu',
                 'GnBu', 'PuBu', 'YlGnBu', 'PuBuGn', 'BuGn', 'YlGn']
    # plt.get_cmap rather than plt.cm.get_cmap: the latter was deprecated in
    # Matplotlib 3.7 and removed in 3.9, while pyplot.get_cmap is still there.
    flat_cmap = plt.get_cmap(color_ids[index])
    # 0.7 lands in the darker part of these light-to-dark gradients.
    return flat_cmap(0.7)

def plot_line_2d(x, y, ax, label=None):
    """Draw ``y`` against ``x`` on ``ax`` as a line with dot markers.

    The line colour is ``get_color(1)``; ``label`` is forwarded to ``ax.plot``.
    """
    line_style = dict(label=label, c=get_color(1), marker='.', zorder=10, alpha=.85)
    ax.plot(x, y, **line_style)

def plot_point(x, y, ax, point_color, txt_label=None, x_shift=0, y_shift=-1, font_size=20):
    """Mark the point ``(x, y)`` on ``ax`` and write ``txt_label`` next to it.

    Args:
        x, y: Coordinates of the marker; the text is anchored at the same spot.
        ax: Axes-like object providing ``scatter`` and ``text``.
        point_color: Colour used for both the marker and the text.
        txt_label: Text to draw.
        x_shift, y_shift: Accepted for call-site compatibility but currently
            not applied to the text position.
        font_size: Size of the text.
    """
    marker_style = {"s": 40, "c": point_color, "marker": '.', "zorder": 10, "alpha": .85}
    ax.scatter(x, y, **marker_style)

    text_style = {"size": font_size, "zorder": 10, "color": point_color}
    ax.text(x, y, txt_label, **text_style)

def add_hline(x, y, text="q", font_size=20, y_txt_shift=0, x_txt_shift=0):
    """Draw a dashed blue horizontal line at height ``y`` and label it.

    Both the line and the label are drawn through ``plt`` (pyplot's current
    axes), not through an explicitly passed axes object.

    Args:
        x: Horizontal position of the label before shifting.
        y: Height of the line, and of the label before shifting.
        text: Label text.
        font_size: Font size of the label.
        y_txt_shift, x_txt_shift: Offsets added to the label position only.
    """
    label_x = x + x_txt_shift
    label_y = y + y_txt_shift
    plt.axhline(y=y, color='b', linestyle='--', alpha=.5)
    plt.text(x=label_x, y=label_y, s=text, color='b', va='center', ha='left', fontsize=font_size)

def gen_critic_plot_2d(true_actions, q, sampled_actions_ids, plot_save_dir, q_scale, assigned_ax,
                       ar_critic=False, sampled_action_index=None, independent_ar_plots=False, previous_q_value=None,
                       only_main_critic=False):
    """Plot one critic's Q-values over the first action dimension and save a PDF.

    Args:
        true_actions: 2-D array of actions; only column 0 is used as the x-axis.
        q: Q-values, one per row of ``true_actions`` (reshaped to a column).
        sampled_actions_ids: Indices into ``true_actions`` of the candidate
            actions that get highlighted.
        plot_save_dir: Directory the PDF is written to.
        q_scale: Optional ``(low, high)`` y-axis limits; ``None`` keeps
            Matplotlib's automatic limits.
        assigned_ax: Axes to draw on.
        ar_critic: Whether ``q`` comes from the autoregressive critic; this
            changes the line label, the title and the output file name.
        sampled_action_index: If given, only that candidate is marked.
            If ``None``, every candidate is marked and gets a horizontal
            reference line at its Q-value.
        independent_ar_plots: With ``ar_critic``, save one file per
            ``sampled_action_index`` instead of a single shared file.
        previous_q_value: Optional height of a reference line; only used when
            ``sampled_action_index`` is given.
        only_main_critic: Adds ``"_after_max"`` to the per-index file name.

    Side effects:
        Reads the module-level ``args`` dict (``"env_name"`` and
        ``"reacher_validity_type"``) for the title. Depending on the flags,
        saves the figure to ``plot_save_dir`` and clears it. In the shared
        autoregressive mode the figure is only saved (and cleared) on the call
        for the last candidate index.
    """
    q = q.reshape(-1, 1)
    ax = assigned_ax
    # add_hline and the tight_layout/savefig/clf calls below all act on
    # pyplot's implicit "current" figure and axes. Make the axes we were
    # handed the current one so they cannot land on a different subplot.
    plt.sca(ax)

    X = true_actions[:, 0]
    Y = q[:, 0]
    plot_line_2d(X, Y, ax, label=f"ar critic {sampled_action_index}" if ar_critic else "main critic")

    point_color = get_color(4)
    font_size = 15
    hline_x = np.min(X)
    if sampled_action_index is not None:
        # Highlight a single candidate action.
        id_in_true_actions = sampled_actions_ids[sampled_action_index]
        x = true_actions[id_in_true_actions, 0]
        y = q[id_in_true_actions, 0]
        plot_point(x, y, ax, point_color, txt_label=f'$a_{sampled_action_index}$', font_size=font_size)
        if previous_q_value is not None:
            if sampled_action_index == 1:
                surface_label_text = "$Q$($a_0$)"
            else:
                surface_label_text = "Max($Q$($a_0$), $Q$($a_1$))"
            add_hline(hline_x, previous_q_value, text=surface_label_text, font_size=font_size)
    else:
        # Highlight every candidate and draw a reference line at its Q-value.
        x = true_actions[sampled_actions_ids, 0]
        y = q[sampled_actions_ids, 0]
        last_index = len(sampled_actions_ids) - 1
        for i in range(len(sampled_actions_ids)):
            if i == last_index:
                plot_point(x[i], y[i], ax, point_color, txt_label=f'$a_{i}$', x_shift=-0.5, font_size=font_size)
            else:
                plot_point(x[i], y[i], ax, point_color, txt_label=f'$a_{i}$', font_size=font_size)
            # The label offsets used to come from hard-coded three-element
            # lists of zeros, which broke for more than three candidates;
            # the defaults of add_hline are the same zero offsets.
            add_hline(hline_x, y[i], text=f'$Q(a_{i})$', font_size=font_size)

    # remove numbers
    ax.set_xticks([])
    ax.set_yticks([])
    # set labels
    ax.set_xlabel('Action Space', fontsize=font_size)
    ax.set_ylabel('Q-value', fontsize=font_size)

    if q_scale is not None:
        ax.set_ylim([q_scale[0], q_scale[1]])

    # Title: second "-"-separated token of the env name, with "_" turned into
    # spaces and each word capitalised, plus a difficulty tag.
    title_env_name = args["env_name"].split("-")[1]
    title_env_name = title_env_name.replace("_", " ").title()
    if args['reacher_validity_type'] == 'box':
        title_env_name += "-Hard"
    else:
        title_env_name += "-Easy"
    # NOTE(review): this is a truthiness test, so index 0 (like None) gets the
    # "Primary Critic" title even when ar_critic is True - confirm intended.
    if ar_critic and sampled_action_index:
        # "\\Psi" keeps the backslash literal; a bare "\P" is an invalid
        # escape sequence in a normal (non-raw) string.
        ax.set_title(f'{title_env_name} $\\Psi_{sampled_action_index}$', fontsize=font_size)
    else:
        ax.set_title(f'{title_env_name} Primary Critic', fontsize=font_size)

    # Decide whether (and where) this call writes the figure out.
    if not ar_critic:
        save_path = f"{plot_save_dir}/main_critic_plot_2d.pdf"
    elif independent_ar_plots:
        suffix = "_after_max" if only_main_critic else ""
        save_path = f"{plot_save_dir}/ar_critic_plot_{sampled_action_index}_2d{suffix}.pdf"
    elif sampled_action_index == len(sampled_actions_ids) - 1:
        save_path = f"{plot_save_dir}/ar_critic_plot_2d.pdf"
    else:
        # Shared figure that is not complete yet: keep drawing, save later.
        return

    plt.tight_layout()
    plt.savefig(save_path, format="pdf")
    plt.clf()



def create_plot_save_dir(timestep, env_name):
    """Create (if needed) and return the output directory for one timestep.

    Args:
        timestep: Value used as the innermost directory name.
        env_name: Value used as the directory name below the plot root.

    Returns:
        The path ``./src/3Dplots/<env_name>/<timestep>`` (relative to the
        current working directory).
    """
    root_dir = "./src/3Dplots/"
    sub_dir = os.path.join(root_dir, f"{env_name}", f"{timestep}")
    # A single makedirs call creates the root as an intermediate directory,
    # and exist_ok avoids the race between checking and creating.
    os.makedirs(sub_dir, exist_ok=True)
    return sub_dir


def preprocess_state(agent, s, main_critic=False):
    """Return the state in the form the selected critic should receive.

    For the "mine" and "recsim-data" environments (compared case-insensitively)
    with ``args['mw_obs_flatten']`` falsy, ``s`` is passed through one of the
    agent's observation encoders and moved to ``device``; in every other case
    ``s`` is returned unchanged.

    Args:
        agent: Object exposing ``main_ref_critic_obs_enc`` and
            ``main_ar_critic_obs_enc``.
        s: State to (possibly) encode.
        main_critic: Selects ``main_ref_critic_obs_enc`` when True, otherwise
            ``main_ar_critic_obs_enc``.

    NOTE: reads the module globals ``args`` and ``device``, which are assigned
    in the ``__main__`` section of this file.
    """
    uses_raw_state = (
        args["env_name"].lower() not in ["mine", "recsim-data"]
        or args['mw_obs_flatten']
    )
    if uses_raw_state:
        return s

    encoder = agent.main_ref_critic_obs_enc if main_critic else agent.main_ar_critic_obs_enc
    return encoder(s).to(device)

if __name__ == "__main__":
    # Script entry point: load a trained agent, step it through the
    # environment and, for every step where some candidate action's Q-value is
    # sufficiently above that of the first candidate, save 2-D plots of the
    # primary critic and of each autoregressive critic.
    #
    # NOTE: args, device, _high and _low stay module-level names on purpose:
    # the helper functions above read them as globals.
    independent_ar_plots = True
    concat_ar_plots = False
    args = get_all_args()
    args = vars(args)
    device = args.get("device", "cpu")
    set_randomSeed(seed=args["seed"])
    q_threshold = 1e-1

    # set environment
    args["num_envs"] = 1
    env = launch_env(args=args)
    env_name = args["env_name"]
    args["_max_l2_dist"] = np.linalg.norm(np.ones(args["reacher_action_shape"]) * 2)
    args["env_max_action"] = float(env.action_space[0].high[0])
    action_dim = env.action_space[0].shape[0]
    _high, _low = env.action_space[0].high, env.action_space[0].low

    dict_embedding = {'item': None, 'task': None}
    logging("======== FINISHED: get embedding ========")

    args["TD3_policy_noise"] *= args["env_max_action"]
    args["TD3_noise_clip"] *= args["env_max_action"]

    # agent load
    prefix = args["method_name"]
    suffix = args["env_name"].split('-')[1]
    args["agent_load_path"] = f"./weights/{prefix}-{suffix}"
    # Checkpoint epoch to load, per method and per environment suffix.
    load_epochs_by_method = {
        'savo_refined': {
            "ant": 18899,
            "half_cheetah": 19699,
            "hopper": 21199,
            "inverted_double_pendulum": 21299,
            "inverted_pendulum": 20799,
            "walker2d": 20499,
        },
        'savo_threshold_refined': {
            "ant": 34099,
            "half_cheetah": 34199,
            "hopper": 37299,
            "inverted_double_pendulum": 39699,
            "inverted_pendulum": 39699,
            "walker2d": 35599,
        },
    }
    if prefix not in load_epochs_by_method:
        raise ValueError(f"prefix {prefix} not supported")
    dict_load_epoch = load_epochs_by_method[prefix]
    # Environments without an entry fall back to epoch 0.
    args["agent_load_epoch"] = dict_load_epoch.get(suffix, 0)
    agent = launch_agent(args=args, env=env)
    agent.load_model(epoch=args['agent_load_epoch'])
    logging("======== FINISHED: get agent ========")

    # Action grid for the x-axis: num_samples points evenly spaced between
    # the lower and upper action bounds.
    num_samples = 10000
    true_actions_basic = np.linspace(_low, _high, num=num_samples).astype(np.float32)

    timesteps = 10
    s = env.reset()
    for timestep in range(timesteps):
        # Work on a copy: grid rows are overwritten with the sampled actions.
        true_actions = copy.deepcopy(true_actions_basic)
        # get state
        s = torch.tensor(s).to(device)
        action = agent.select_action(obs=s, act_embed_base=dict_embedding["item"], epsilon={"actor": 0.0, "critic": 0.0})
        # get next state
        # NOTE(review): the done flag `d` is never inspected; presumably the
        # environment wrapper resets itself - confirm.
        s_next, _, d, _ = env.step(action["action"])
        logging("======== FINISHED: get action ========")

        sampled_actions = action["query"][0, :, action_dim:]
        # Snap every sampled action onto the grid: find its nearest grid row
        # and, unless it matches exactly, overwrite that row with the action.
        sampled_action_ids = []
        for sampled_action in sampled_actions:
            distance = np.linalg.norm(sampled_action - true_actions, axis=1)
            sampled_action_id = int(np.argmin(distance))
            if distance[sampled_action_id] > 0:
                if sampled_action_id in sampled_action_ids:
                    # The nearest row already holds an earlier sampled action.
                    # Take the closest row that is still free instead of a
                    # random -1/0 offset, which could keep the collision or
                    # wrap around to index -1.
                    sampled_action_id = next(
                        int(candidate) for candidate in np.argsort(distance)
                        if int(candidate) not in sampled_action_ids
                    )
                true_actions[sampled_action_id] = sampled_action
            sampled_action_ids.append(sampled_action_id)
        sampled_actions_ids = np.array(sampled_action_ids)
        num_sampled = len(sampled_actions_ids)

        print(sampled_actions_ids)
        logging("======== FINISHED: get sampled_actions ========")

        # Repeat the first state num_samples times so that every grid action
        # can be paired with the same state.
        s = s.cpu().numpy()
        s_dim_len = len(s.shape)
        # reps is (num_samples, 1, ..., 1) with s_dim_len - 1 trailing ones.
        reps = (num_samples, ) + (1, ) * (s_dim_len - 1)
        s = torch.tensor(np.tile(A=s[0, :], reps=reps).astype(np.float32)).to(device)
        logging("======== FINISHED: get state ========")

        # Primary critic over the whole grid.
        _s = preprocess_state(agent, s, main_critic=True)
        q_main = get_correspond_q_val(agent.main_ref_critic, _s, true_actions, device)
        logging("======== FINISHED: get main q values ========")

        q_max = np.max(q_main[sampled_actions_ids])
        q0 = q_main[sampled_actions_ids[0]]
        # Skip this step unless the best sampled action improves on the first
        # one by at least q_threshold, relative to |Q(a_0)|.
        if (q_max - q0) / abs(q0) < q_threshold:
            s = s_next
            continue

        # Autoregressive critic: for position i, positions < i hold the
        # already-sampled actions, position i sweeps the grid, and the
        # remaining positions stay zero.
        q_ars = []
        _s = preprocess_state(agent, s, main_critic=False)
        _s = _s[:, None, :]
        for i in range(num_sampled):
            list_actions = np.zeros((true_actions.shape[0], num_sampled, action_dim))
            # Broadcasts the (i, action_dim) block of earlier actions across
            # all grid rows.
            list_actions[:, :i, :] = true_actions[sampled_actions_ids[:i]]
            list_actions[:, i, :] = true_actions
            action_seq = torch.tensor(list_actions, device=device).float()
            with torch.no_grad():
                Q = agent.main_ar_critic(state=_s,
                                         action_seq=action_seq,
                                         alternate_conditioning=None,
                                         true_action=None)
            q_ars.append(Q[:, i].detach().cpu().numpy())

        q_ars = np.array(q_ars)
        logging("======== FINISHED: get ar q values ========")

        # align q value scale for all plots
        q_scale = (np.min([np.min(q_main), np.min(q_ars)]), np.max([np.max(q_main), np.max(q_ars)]))

        # plot 2d plot
        plot_save_dir = create_plot_save_dir(timestep, env_name)
        figsize = (6, 6)
        # main critic
        fig = plt.figure(figsize=figsize)
        ax = fig.add_subplot(1, 1, 1, facecolor="w")
        gen_critic_plot_2d(true_actions, q_main, sampled_actions_ids, plot_save_dir, q_scale, assigned_ax=ax)

        # Autoregressive critics. The concatenated and shared-axes modes used
        # to be rendered twice with identical inputs, overwriting the same
        # file; they are rendered once here.
        if concat_ar_plots:
            # One panel per sampled action; squeeze=False keeps axs 2-D even
            # for a single panel.
            fig, axs = plt.subplots(1, num_sampled, figsize=(5 * num_sampled, 5), squeeze=False)
            for i in range(num_sampled):
                gen_critic_plot_2d(true_actions, q_ars[i], sampled_actions_ids, plot_save_dir, q_scale, axs[0, i],
                                   ar_critic=True, sampled_action_index=i)
        elif independent_ar_plots:
            # Two sets of per-index figures: first the autoregressive critic
            # itself, then ("_after_max" files) the primary critic clipped
            # from below at the previous action's Q-value.
            for after_max in (False, True):
                for i in range(num_sampled):
                    q_ar = q_ars[i]
                    previous_q_value = q_main[sampled_actions_ids[i - 1]] if i > 0 else None
                    if after_max and i > 0:
                        # NOTE(review): this uses Q(a_{i-1}) only, while the
                        # reference-line label for index 2 reads
                        # Max(Q(a_0), Q(a_1)) - confirm which is intended.
                        q_ar = np.maximum(q_main, previous_q_value)
                    fig = plt.figure(figsize=figsize)
                    ax = fig.add_subplot(1, 1, 1, facecolor="w")
                    gen_critic_plot_2d(true_actions, q_ar, sampled_actions_ids, plot_save_dir, q_scale, ax,
                                       ar_critic=True, sampled_action_index=i, independent_ar_plots=True,
                                       previous_q_value=previous_q_value, only_main_critic=after_max)
        else:
            # All autoregressive curves share one axes; the figure is saved
            # by the call for the last index.
            fig = plt.figure(figsize=figsize)
            ax = fig.add_subplot(1, 1, 1, facecolor="w")
            for i in range(num_sampled):
                gen_critic_plot_2d(true_actions, q_ars[i], sampled_actions_ids, plot_save_dir, q_scale, ax,
                                   ar_critic=True, sampled_action_index=i)

        # Everything for this timestep has been written; release the figures
        # so they do not pile up across timesteps.
        plt.close("all")

        s = s_next
