# %%
import os
# os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import tensorflow as tf
import dreamerv2.api as dv2
import common
from common import Config
import envs
import numpy as np
import collections
import matplotlib.pyplot as plt
from dreamerv2.common.replay import convert
import pathlib
import sys
import ruamel.yaml as yaml
import pickle


from run_goal_cond import make_env, make_eval_fn, make_obs2goal_fn
import gc_agent
import common
import goal_picker
import matplotlib.pyplot as plt

# %%
# Run directory of the trained agent. The config saved in that directory is
# the baseline for this session.
logdir = pathlib.Path('/home/anonymous/logdir/kitchen_mppi2_3repeat_4two1exp_s1')
yaml_config = yaml.safe_load((logdir / 'config.yaml').read_text())
# yaml_config["planner"]["mppi_gamma"] = 10.0

# Planner settings layered on top of the saved config for this visualization.
# Commented-out entries are alternative knobs kept for quick experiments.
overrides = {
    "planner.planner_type": "shooting_mppi",
    "planner.horizon": 75,
    "planner.std_scale": 1.0,
    # "planner.final_step_cost": True,
    "planner.mega_prior": True,
    # "planner.optimization_steps": 5,
    "planner.repeat_samples": 0,
    "planner.mppi_gamma": 5.0,
    # Goal search bounds: 20 entries each, one bound per entry.
    'planner.goal_min': [-2.5] * 20,
    'planner.goal_max': [2.5] * 20,
    # 'planner.sample_replay': True,
}
# import ipdb; ipdb.set_trace()
config = common.Config(yaml_config).update(overrides)


# Environment and observation->goal mapping built from the final config.
env = make_env(config, use_goal_idx=False, log_per_goal=True)
# Other helpers that this session leaves disabled:
# eval_env = make_env(config, use_goal_idx=True, log_per_goal=False, eval=True)
# report_render_fn = make_report_render_function(config)
# eval_fn = make_eval_fn(config)
# plot_fn = make_plot_fn(config)
# ep_render_fn = make_ep_render_fn(config)
# cem_vis_fn = make_cem_vis_fn(config)
obs2goal_fn = make_obs2goal_fn(config)

# TensorFlow runtime setup: eager execution when jit is off, a GPU is
# required, and memory growth is enabled on every visible GPU.
tf.config.run_functions_eagerly(not config.jit)
# tf.data.experimental.enable_debug_mode(not config.jit)
gpus = tf.config.experimental.list_physical_devices('GPU')
message = 'No GPU found. To actually train on CPU remove this assert.'
assert gpus, message
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)

# Only 16- and 32-bit precision are accepted; 16 switches Keras to the
# 'mixed_float16' policy.
assert config.precision in (16, 32), config.precision
if config.precision == 16:
    from tensorflow.keras.mixed_precision import experimental as prec
    mixed_policy = prec.Policy('mixed_float16')
    prec.set_policy(mixed_policy)

# %%
# Fail fast: everything below is only useful with a trained checkpoint, so
# check for it before loading the replay buffer and building the agent.
# An explicit exception is used instead of `assert`, which `python -O` strips.
checkpoint_path = logdir / 'variables.pkl'
if not checkpoint_path.exists():
    raise FileNotFoundError(f'No agent checkpoint found at {checkpoint_path}')

# Replay buffer over the episodes stored under the run directory.
replay = common.Replay(logdir / 'train_episodes', **config.replay)
# The step counter starts from the total step count the buffer reports.
step = common.Counter(replay.stats['total_steps'])
# Iterator over batches drawn from the replay buffer.
dataset = iter(replay.dataset(**config.dataset))

# %%
agnt = gc_agent.GCAgent(config, env.obs_space, env.act_space, step, obs2goal_fn, None)
train_agent = common.CarryOverState(agnt.train)
# One training call on a single batch before loading. Presumably this creates
# the agent's variables so the checkpoint can be restored into them
# (TODO confirm against GCAgent.load).
train_agent(next(dataset))
print('Found existing checkpoint.')
agnt.load(checkpoint_path)

# %%
"""Define goal picking """
# Build the goal-picking strategy named by `config.goal_strategy`. The class
# of that name is looked up on the `goal_picker` module and constructed with
# strategy-specific arguments.
should_cem_plot = common.Every(1)
goal_picker_cls = getattr(goal_picker, config.goal_strategy)
p_cfg = config.planner
if config.goal_strategy == "Greedy":
    goal_strategy = goal_picker_cls(
        replay, agnt.wm, agnt._expl_behavior._intr_reward,
        config.state_key, config.goal_key, 1000)
elif config.goal_strategy == "SampleReplay":
    goal_strategy = goal_picker_cls(agnt.wm, dataset, config.state_key, config.goal_key)
elif config.goal_strategy == "SubgoalPlanner":
    if p_cfg.init_candidates[0] == 123456789.0:  # ugly hack for specifying no init cand.
        init_cand = None
    else:
        init_cand = np.array(p_cfg.init_candidates, dtype=np.float32)
        # Unflatten the list of init candidates into (num_candidates, goal_dim).
        # Assume goal dim = state dim. `int(...)` because np.prod of an empty
        # shape returns a float.
        goal_dim = int(np.prod(env.obs_space[config.state_key].shape))
        # The flat list may hold several candidates back to back, so its
        # length has to be a whole multiple of the goal dimension.
        if len(init_cand) % goal_dim != 0:
            raise ValueError(
                f"init_candidates has {len(init_cand)} values, which is not a "
                f"multiple of the goal dimension {goal_dim}")
        init_cand = tf.convert_to_tensor(init_cand.reshape(-1, goal_dim))

    def vis_fn(elite_inds, elite_samples, seq, wm):
        """No-op visualization callback handed to the planner."""
        pass

    # Optional dataset of replayed states (about 10K per batch) for the planner.
    goal_dataset = None
    if p_cfg.sample_replay:
        goal_dataset = iter(replay.dataset(batch=10000//(config.time_limit+1), length=config.time_limit+1)) # take 10K states.

    # Optional MEGA prior handed to the planner.
    mega_prior = None
    if p_cfg.mega_prior:
        mega_prior = goal_picker.MEGA(agnt, replay, env.act_space, config.state_key, config.time_limit+1, obs2goal_fn)

    sample_env_goals_fn = None
    env_goals_percentage = p_cfg.init_env_goal_percent

    goal_strategy = goal_picker_cls(
        agnt.wm,
        agnt._task_behavior.actor,
        agnt._expl_behavior.planner_intr_reward,
        gc_input=config.gc_input,
        obs2goal=obs2goal_fn,
        # NOTE(review): this reads `env.observation_space` while the rest of
        # the script uses `env.obs_space` - confirm the env exposes both.
        goal_dim=np.prod(env.observation_space[config.goal_key].shape),
        goal_min=np.array(p_cfg.goal_min, dtype=np.float32),
        goal_max=np.array(p_cfg.goal_max, dtype=np.float32),
        act_space=env.act_space,
        state_key=config.state_key,
        planner=p_cfg.planner_type,
        horizon=p_cfg.horizon,
        batch=p_cfg.batch,
        cem_elite_ratio=p_cfg.cem_elite_ratio,
        optimization_steps=p_cfg.optimization_steps,
        std_scale=p_cfg.std_scale,
        mppi_gamma=p_cfg.mppi_gamma,
        init_candidates=init_cand,
        dataset=goal_dataset,
        evaluate_only=p_cfg.evaluate_only,
        repeat_samples=p_cfg.repeat_samples,
        mega_prior=mega_prior,
        sample_env_goals_fn=sample_env_goals_fn,
        env_goals_percentage=env_goals_percentage,
        vis_fn=vis_fn,
    )
else:
    raise NotImplementedError(
        f"goal strategy {config.goal_strategy!r} is not supported by this script")


# %%

def update_goal_strategy(*args):
    """Refresh the state of the module-level ``goal_strategy``.

    What is refreshed depends on ``config.goal_strategy``:

    * ``"Greedy"``: the buffer priorities are updated.
    * any name containing ``"SubgoalPlanner"``: the planner is flagged to
      search again on its next call, and its MEGA KDE is updated when
      ``config.planner.mega_prior`` is set.
    * ``"MEGA"`` / ``"Skewfit"``: the strategy's KDE is updated.

    Any other strategy name is a no-op. Positional arguments are accepted
    and ignored. Returns ``None``.
    """
    strategy_name = config.goal_strategy

    if strategy_name == "Greedy":
        goal_strategy.update_buffer_priorities()
        return

    if "SubgoalPlanner" in strategy_name:
        # The planner will search for a new distribution the next time we sample.
        goal_strategy.will_update_next_call = True
        if config.planner.mega_prior:
            goal_strategy.mega.update_kde()
        return

    if strategy_name in ("MEGA", "Skewfit"):
        goal_strategy.update_kde()

# %% [markdown]
# # Set up scatter plot

# %%
# Gather the states visited in replayed episodes together with the exploration
# critic's value for each of them; these form the background of the plot.
complete_episodes = replay._complete_eps
# NOTE(review): ep_subsample is never read; the older episodes below are
# thinned with a hard-coded stride of 5 - confirm which value is intended.
ep_subsample = 10
step_subsample = 1
batch_size = 50
obs_key = 'state'
goal_key = 'goal'

wm = agnt.wm
episodes = list(complete_episodes.values())
# Keep the most recent `num_goals` episodes in full and every 5th of the
# older ones.
num_goals = min(int(config.eval_every), len(episodes))
recent_episodes = episodes[-num_goals:]
if len(episodes) > num_goals:
    old_episodes = episodes[:-num_goals][::5]
    old_episodes.extend(recent_episodes)
    episodes = old_episodes
else:
    episodes = recent_episodes
all_observations = []
value_list = []
all_goals = []

# NOTE: the cache file name is fixed and relative to the working directory, so
# it is not tied to `logdir`; delete it when switching runs or checkpoints.
# It is a pickle, so only load a cache file this script wrote itself.
cache_path = pathlib.Path("visualize_planner_cache.pkl")
if cache_path.is_file():
    with cache_path.open("rb") as f:
        cache = pickle.load(f)
    all_observations = cache["all_obs"]
    value_list = cache["value_list"]
    all_goals = cache["goals"]

else:
    if not episodes:
        raise RuntimeError("The replay buffer has no complete episodes to visualize.")

    # Loop-invariant: the target critic of the exploration behavior.
    value_fn = agnt._expl_behavior.ac._target_critic
    for start in range(0, len(episodes), batch_size):
        # 1. Collect up to `batch_size` preprocessed episodes, key by key.
        chunk = collections.defaultdict(list)
        for episode in episodes[start:start + batch_size]:
            sequence = {
                k: convert(v[::step_subsample])
                for k, v in episode.items() if not k.startswith('log_')}
            data = wm.preprocess(sequence)
            for key, tensor in data.items():
                chunk[key].append(tensor)

        # 2. Stack the episodes along a new leading axis and run them through
        #    the world model to get features, then through the critic.
        chunk = {k: tf.stack(v) for k, v in chunk.items()}
        embed = wm.encoder(chunk)
        post, _ = wm.rssm.observe(embed, chunk['action'], chunk['is_first'], None)
        chunk['feat'] = wm.rssm.get_feat(post)
        value = value_fn(chunk['feat']).mode()
        value_list.append(value)

        all_observations.append(tf.stack(chunk[obs_key]))
        all_goals.append(tf.stack(chunk[goal_key]))

    # Merge the per-batch observation tensors along the episode axis and cache
    # everything so later runs can skip the forward passes.
    all_observations = tf.concat(all_observations, axis=0)
    with cache_path.open("wb") as f:
        cache = dict(all_obs=all_observations, value_list=value_list, goals=all_goals)
        pickle.dump(cache, f)

# %%
def _object_xy(obj, points, obs_idxs):
    """Return flat (x, y) scatter coordinates of ``obj`` for every row of ``points``.

    ``points`` is a 2-D array and ``obs_idxs`` lists the feature columns that
    belong to ``obj``.
    """
    if obj == "joints":
        # One horizontal line per joint, spaced 0.6 apart starting at y=-2.4,
        # so all joints fit inside the [-3, 3] y-range of the axes.
        x = points[:, obs_idxs]
        y = (np.ones_like(x) * (np.arange(len(obs_idxs)) * 0.6)) - 2.4
        return x.reshape(-1), y.reshape(-1)
    if obj == "kettle":  # only plot xy dims.
        data = points[:, obs_idxs[:2]]
    elif obj in {"microwave", "slide_cabinet"}:
        # Both are 1-D and share an axis: microwave on the line y=0.25,
        # slide_cabinet on the line y=0.75.
        x = points[:, obs_idxs]
        y = np.ones_like(x) * (0.25 if obj == "microwave" else 0.75)
        data = np.hstack([x, y])
    else:
        data = points[:, obs_idxs]
    return data[:, 0], data[:, 1]


def plot_goals_over_states(all_observations, all_goals):
    """Scatter sampled goals on top of visited states, one panel per object.

    The visited states are colored by the values in the module-level
    ``value_list``; goals are drawn above them in red (blue for the microwave,
    which shares a panel with the slide cabinet). The figure is saved to
    ``goals_over_states.png``.

    Args:
        all_observations: array-like whose first axis is concatenated away and
            whose remaining leading axes are flattened, leaving rows of
            features. Feature columns 0-19 are read. The number of rows must
            match the number of entries in ``value_list``.
        all_goals: array-like of goals with the same feature layout; leading
            axes are flattened into rows.

    Returns:
        None.
    """
    import matplotlib.patches as patches

    all_observations = np.concatenate(all_observations)
    all_observations = all_observations.reshape(-1, all_observations.shape[-1])
    # Flatten to (num_goals, D). Unlike np.squeeze this also keeps a single
    # goal two-dimensional.
    all_goals = np.asarray(all_goals)
    all_goals = all_goals.reshape(-1, all_goals.shape[-1])
    fig, all_axes = plt.subplots(1, 6, figsize=(13, 2))
    value_axes = all_axes

    # microwave and slide_cabinet deliberately share the third panel.
    obj_to_ax = {
        "bottom_burner": value_axes[0],
        "light_switch": value_axes[1],
        "slide_cabinet": value_axes[2],
        "hinge_cabinet": value_axes[3],
        "microwave": value_axes[2],
        "kettle": value_axes[4],
        "joints": value_axes[5],
    }
    # Feature columns of each object inside an observation / goal vector.
    object_obs_idxs = {
        'bottom_burner': [9, 10],
        'light_switch': [11, 12],
        'slide_cabinet': [13],
        'hinge_cabinet': [14, 15],
        'microwave': [16],
        'kettle': [17, 18, 19],
        'joints': list(range(9)),
    }
    # One value per observation row.
    values = tf.concat(value_list, axis=0)
    values = values.numpy().flatten()
    # pyplot.get_cmap is available across Matplotlib versions, whereas
    # plt.cm.get_cmap is deprecated and removed in recent releases.
    cm = plt.get_cmap("viridis")
    for obj, ax in obj_to_ax.items():
        obs_idxs = object_obs_idxs[obj]
        shares_axis = obj in {"microwave", "slide_cabinet"}
        title = "microwave (lower), slide_cabinet" if shares_axis else obj
        ax.set_title(title, fontsize=6)

        # Visited states, colored by value.
        x, y = _object_xy(obj, all_observations, obs_idxs)
        if obj == "joints":
            # Each row contributes one point per joint, so repeat its value.
            # A separate name keeps `values` intact for every other panel
            # regardless of dict order.
            point_values = np.repeat(values, len(obs_idxs))
        else:
            point_values = values
        p2e_scatter = ax.scatter(
            x=x,
            y=y,
            s=1,
            c=point_values,
            cmap=cm,
            zorder=4,
        )

        # Sampled goals, drawn above the states.
        color = 'blue' if obj == "microwave" else 'red'
        goal_x, goal_y = _object_xy(obj, all_goals, obs_idxs)
        ax.scatter(
            x=goal_x,
            y=goal_y,
            s=1,
            c=color,
            zorder=10,
        )

        ax.set_xlim([-3.0, 3.0])  # assume obs are normalized.
        ax.set_ylim([-3.0, 3.0])
        # ax.axes.get_xaxis().set_visible(False)
        ax.axes.get_yaxis().set_visible(False)
        # Blue outline of the unit square [0, 1] x [0, 1].
        rect = patches.Rectangle((0, 0), 1, 1, linewidth=1, edgecolor='b', facecolor='none')
        ax.add_patch(rect)

    # The colorbar is attached to the last panel, using its state scatter.
    plt.colorbar(p2e_scatter, ax=ax)

    fig.canvas.draw()
    plt.savefig("goals_over_states.png", dpi=200)
    # plt.show()
    # plt.show()

# %% [markdown]
# # Default CEM planner

# %%
# Reset the env and convert every entry of the observation to a tensor with
# two extra leading singleton axes (presumably batch and time - confirm
# against wm.preprocess). `tf.convert_to_tensor` is used because TensorFlow
# has no `tf.tensor`.
obs = tf.nest.map_structure(
    lambda x: tf.expand_dims(tf.expand_dims(tf.convert_to_tensor(x), 0), 0),
    env.reset())
obs = agnt.wm.preprocess(obs)
# Refresh the strategy and make sure the planner searches on this call.
update_goal_strategy()
goal_strategy.will_update_next_call = True
_ = goal_strategy.search_goal(obs)

# %%
# generate CEM state distribution.
# 100 goals are sampled and given two leading singleton axes before plotting.
all_goals = goal_strategy.sample_goal(100)[None, None]
# import ipdb; ipdb.set_trace()
plot_goals_over_states(all_observations, all_goals)

# sys.exit(0)

# %% [markdown]
# It seems to be focusing on the start. One reason could be that the other 28 dimensions are very interesting.
# We can test by restricting the CEM to just search the xy dimensions.
#
# Let's see what imagined trajectories are coming from the CEM.

# %%
# import ipdb; ipdb.set_trace()
# num_vis = 5
# elite_inds = goal_strategy.elite_inds[:num_vis]
# elite_samples = goal_strategy.elite_samples[:num_vis] # topK x D
# final_seq = goal_strategy.final_seq

# elite_seq = tf.nest.map_structure(lambda x: tf.gather(x, elite_inds, axis=1), final_seq)
# elite_obs = wm.heads['decoder'](wm.rssm.get_feat(elite_seq))['observation'].mode() # T x topk x D

# %%
# plot the goals, and trajectories from planner process.
# def plot_elites(goals, trajectories, mega_goal=None):
#     # fig, (state_ax,  p2evalue_ax) = plt.subplots(1, 2, figsize=(1, 3))
#     fig, p2evalue_ax = plt.subplots(1, 1, figsize=(1, 3))
#     # state_ax.set(xlim=(-1, 5.25), ylim=(-1, 5.25))
#     goal_time_limit = round(config.goal_policy_rollout_percentage * config.time_limit)
#     # obs_list = tf.concat(all_obs, axis = 0)
#     obs_list = all_obs
#     before = obs_list[:,:goal_time_limit,:]
#     before = before[:,:,:2]
#     ep_order_before = tf.range(before.shape[0])[:, None]
#     ep_order_before = tf.repeat(ep_order_before, before.shape[1], axis=1)
#     before = tf.reshape(before, [before.shape[0]*before.shape[1], 2])
#     after = obs_list[:,goal_time_limit:,:]
#     after = after[:,:,:2]
#     ep_order_after = tf.range(after.shape[0])[:, None]
#     ep_order_after = tf.repeat(ep_order_after, after.shape[1], axis=1)
#     after = tf.reshape(after, [after.shape[0]*after.shape[1], 2])
#     # obs_list = tf.concat(obs, axis = 0)
#     # obs_list = obs_list[:,:,:2]
#     # ep_order = tf.range(obs_list.shape[0])[:, None] # Num_ep x 1
#     # ep_order = tf.repeat(ep_order, obs_list.shape[1], axis=1) #  Num_ep x T
#     # obs_list = tf.reshape(obs_list, [obs_list.shape[0]*obs_list.shape[1], 2])
#     # ep_order = tf.reshape(ep_order, [ep_order.shape[0]*ep_order.shape[1]])
#     ep_order_before = tf.reshape(ep_order_before, [ep_order_before.shape[0]*ep_order_before.shape[1]])
#     ep_order_after = tf.reshape(ep_order_after, [ep_order_after.shape[0]*ep_order_after.shape[1]])
#     goal_list = goals[...,:2]
#     goal_list = tf.reshape(goal_list, [-1, 2])

#     obs_list = obs_list[:, :, :2]
#     obs_list = tf.reshape(obs_list, [obs_list.shape[0]*obs_list.shape[1], 2])
#     values = tf.concat(value_list, axis = 0)
#     values = values.numpy().flatten()
#     cm = plt.cm.get_cmap("viridis")
#     # p2e_scatter = p2evalue_ax.scatter(
#     # x=obs_list[:,0],
#     # y=obs_list[:,1],
#     # s=1,
#     # c=values,
#     # cmap=cm,
#     # zorder=2,
#     # )

#     p2evalue_ax.scatter(
#     x=goal_list[:,0],
#     y=goal_list[:,1],
#     s=1,
#     c='r',
#     zorder=5,
#     )
#     if mega_goal is not None:
#         p2evalue_ax.scatter(
#         x=mega_goal[:,0],
#         y=mega_goal[:,1],
#         s=50,
#         c='y',
#         zorder=5,
#         marker='*',
#         )

#     # print(trajectories.shape)
#     # only start and end state of trajectories.
#     print(trajectories.shape)
#     # trajectories = tf.stack([trajectories[0,:,:], trajectories[-1, :, :]], axis=0)
#     # print(trajectories.shape)
#     # trajectories = tf.reshape(trajectories, (-1, 29))
#     # print(trajectories.shape)
#     # p2evalue_ax.scatter(
#     # x=trajectories[:,0],
#     # y=trajectories[:,1],
#     # s=1,
#     # c='b',
#     # zorder=4,
#     # )
#     trajectories = tf.transpose(trajectories, (1,0,2))
#     # (num_vis,horizon,29)
#     p2evalue_ax.plot(
#         trajectories[:,:,0],
#         trajectories[:,:,1],
#         c='b',
#         zorder=4,
#         marker='.'
#     )

#     # plt.colorbar(p2e_scatter, ax=p2evalue_ax)
#     p2evalue_ax.set(xlim=(-1, 5.25), ylim=(-1, 5.25))
#     p2evalue_ax.set_title('elite goals and states')

#     fig = plt.gcf()
#     fig.set_size_inches(7, 6)
#     # plt.show()
#     plt.savefig("plot_elites.png", dpi=150)

# plot_elites(elite_samples, elite_obs, goal_strategy.mega_sample)