# %%
import os
# os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import tensorflow as tf
import dreamerv2.api as dv2
import common
from common import Config
import envs
import numpy as np
import collections
import matplotlib.pyplot as plt
from dreamerv2.common.replay import convert
import pathlib
import sys
import ruamel.yaml as yaml
import pickle
from functools import partial


from run_goal_cond import make_env, make_eval_fn, make_obs2goal_fn
import gc_agent
import common
import goal_picker
import matplotlib.pyplot as plt


# Number of goals sampled by the goal picker and then evaluated below.
NUM_GOALS = 5
# Run directory of a trained agent; its config.yaml and variables.pkl are loaded.
logdir = pathlib.Path('/home/anonymous/logdir/ant_page_mppi1_s0')
yaml_config = yaml.safe_load((logdir / 'config.yaml').read_text())
# Override the saved training config for this visualization run: zero the
# env-goal percentages and exploration noise, use the "sum" reward shaping
# and force the Greedy goal strategy.
yaml_config["train_env_goal_percent"] = 0.0
yaml_config["planner"]["init_env_goal_percent"] = 0.0
yaml_config["gc_reward_shape"] = "sum"
yaml_config["epsilon_expl_noise"]  = 0.0
yaml_config["goal_strategy"] = "Greedy"
config = common.Config(yaml_config)
# Dotted-key overrides applied on top of the YAML config (update returns a new config).
overrides = {
    # "planner.planner_type": "shooting_cem",
    # "planner.horizon": 150,
    # "planner.std_scale": 10,
    # "planner.final_step_cost": True,
    # "planner.mega_prior": False,
    # "planner.optimization_steps": 5,
    # "planner.repeat_samples": 1,
    "planner.mppi_gamma": 5.0,
}
config = config.update(overrides)


# Training-style env (no goal index) and evaluation env (with goal index).
env = make_env(config, use_goal_idx=False, log_per_goal=True)
eval_env = make_env(config, use_goal_idx=True, log_per_goal=False, eval=True)
# report_render_fn = make_report_render_function(config)
# eval_fn = make_eval_fn(config)
# plot_fn = make_plot_fn(config)
# ep_render_fn = make_ep_render_fn(config)
# cem_vis_fn = make_cem_vis_fn(config)
obs2goal_fn = make_obs2goal_fn(config)
# Run tf.functions eagerly unless the config asks for JIT compilation.
tf.config.run_functions_eagerly(not config.jit)
# tf.data.experimental.enable_debug_mode(not config.jit)
# Refuse to run without a GPU; enable memory growth on every visible GPU.
message = 'No GPU found. To actually train on CPU remove this assert.'
assert tf.config.experimental.list_physical_devices('GPU'), message
for gpu in tf.config.experimental.list_physical_devices('GPU'):
    tf.config.experimental.set_memory_growth(gpu, True)
assert config.precision in (16, 32), config.precision
if config.precision == 16:
    # NOTE(review): `mixed_precision.experimental` is a legacy TF API that is
    # absent from newer TF releases — confirm the pinned TF version supports it.
    from tensorflow.keras.mixed_precision import experimental as prec
    prec.set_policy(prec.Policy('mixed_float16'))

# %%
replay = common.Replay(logdir / 'train_episodes', **config.replay)  # initialize replay buffer
step = common.Counter(replay.stats['total_steps'])  # initialize step counter
dataset = iter(replay.dataset(**config.dataset))

# %%
# Fail fast if there is no checkpoint to visualize. Check before building the
# agent and running the warm-up training step, and raise explicitly because
# `assert` statements are stripped under `python -O`.
checkpoint_path = logdir / 'variables.pkl'
if not checkpoint_path.exists():
    raise FileNotFoundError(f'No checkpoint found at {checkpoint_path}.')
agnt = gc_agent.GCAgent(config, env.obs_space, env.act_space, step, obs2goal_fn, None)
# One training step on a replay batch before loading. Presumably this builds
# the agent's variables so `agnt.load` has something to restore into.
train_agent = common.CarryOverState(agnt.train)
train_agent(next(dataset))
print('Found existing checkpoint.')
agnt.load(checkpoint_path)

# %%
"""Define goal picking """
should_cem_plot = common.Every(1)
goal_picker_cls = getattr(goal_picker, config.goal_strategy)
p_cfg = config.planner

# Sentinel for `planner.init_candidates[0]` meaning "no initial candidates"
# (ugly hack: the config format cannot express None here).
_NO_INIT_CANDIDATES = 123456789.0


def _parse_init_candidates():
    """Return the planner's initial goal candidates as a tensor, or None.

    `p_cfg.init_candidates` is a flat float list holding one or more
    candidates back to back, each of size prod(state shape). This assumes
    goal dim == state dim. The list is split into a
    (num_candidates, goal_dim) tensor.

    Raises:
      ValueError: if the flat list does not split evenly into candidates.
    """
    if p_cfg.init_candidates[0] == _NO_INIT_CANDIDATES:
        return None
    flat = np.array(p_cfg.init_candidates, dtype=np.float32)
    goal_dim = np.prod(env.obs_space[config.state_key].shape)  # assume goal dim = state dim
    # Any whole multiple of goal_dim is valid. Requiring exactly goal_dim
    # would reject lists holding several candidates, which the split below exists to unflatten.
    if len(flat) % goal_dim != 0:
        raise ValueError(f"{len(flat)}, {goal_dim}")
    return tf.convert_to_tensor(np.split(flat, len(flat) // goal_dim))


def _make_goal_dataset():
    """Return an iterator over roughly 10K replay states for the planner, or None."""
    if not p_cfg.sample_replay:
        return None
    length = config.time_limit + 1
    return iter(replay.dataset(batch=10000 // length, length=length))  # take 10K states.


def _make_mega_prior():
    """Return a MEGA goal prior when `planner.mega_prior` is set, else None."""
    if not p_cfg.mega_prior:
        return None
    return goal_picker.MEGA(agnt, replay, env.act_space, config.state_key, config.time_limit + 1, obs2goal_fn)


def _noop_vis_fn(elite_inds, elite_samples, seq, wm):
    """Planner visualization callback; intentionally does nothing in this script."""
    pass


if config.goal_strategy == "Greedy":
    goal_strategy = goal_picker_cls(replay, agnt.wm, agnt._expl_behavior._intr_reward, config.state_key, config.goal_key, 1000)
elif config.goal_strategy == "SampleReplay":
    goal_strategy = goal_picker_cls(agnt.wm, dataset, config.state_key, config.goal_key)
elif config.goal_strategy == "SubgoalPlanner":
    init_cand = _parse_init_candidates()
    vis_fn = _noop_vis_fn
    goal_dataset = _make_goal_dataset()
    mega_prior = _make_mega_prior()
    sample_env_goals_fn = None
    env_goals_percentage = p_cfg.init_env_goal_percent

    goal_strategy = goal_picker_cls(
        agnt.wm,
        agnt._task_behavior.actor,
        agnt._expl_behavior.planner_intr_reward,
        gc_input=config.gc_input,
        obs2goal=obs2goal_fn,
        # Read the env wrapper's `obs_space`, like every other space lookup in this script.
        goal_dim=np.prod(env.obs_space[config.goal_key].shape),
        goal_min=np.array(p_cfg.goal_min, dtype=np.float32),
        goal_max=np.array(p_cfg.goal_max, dtype=np.float32),
        act_space=env.act_space,
        state_key=config.state_key,
        planner=p_cfg.planner_type,
        horizon=p_cfg.horizon,
        batch=p_cfg.batch,
        cem_elite_ratio=p_cfg.cem_elite_ratio,
        optimization_steps=p_cfg.optimization_steps,
        std_scale=p_cfg.std_scale,
        mppi_gamma=p_cfg.mppi_gamma,
        init_candidates=init_cand,
        dataset=goal_dataset,
        evaluate_only=p_cfg.evaluate_only,
        repeat_samples=p_cfg.repeat_samples,
        mega_prior=mega_prior,
        sample_env_goals_fn=sample_env_goals_fn,
        env_goals_percentage=env_goals_percentage,
        vis_fn=vis_fn,
    )
elif config.goal_strategy in {"MEGA", "Skewfit"}:
    goal_strategy = goal_picker_cls(agnt, replay, env.act_space, config.state_key, config.time_limit, obs2goal_fn)
elif config.goal_strategy == "SubgoalPlannerKDE":
    init_cand = _parse_init_candidates()
    # Visualization is disabled. The CEM visualizer (`cem_vis_fn`) and the
    # episode counter (`num_eps`) are not defined in this script, so a
    # callback using them would raise NameError when the planner invokes it.
    vis_fn = _noop_vis_fn
    goal_dataset = _make_goal_dataset()
    mega_prior = _make_mega_prior()

    goal_strategy = goal_picker_cls(
        agnt, replay, env.act_space, config.state_key, config.time_limit, obs2goal_fn,
        gc_input=config.gc_input,
        goal_dim=np.prod(env.obs_space[config.state_key].shape),  # assume goal dim = state dim
        goal_min=np.array(p_cfg.goal_min, dtype=np.float32),
        goal_max=np.array(p_cfg.goal_max, dtype=np.float32),
        planner=p_cfg.planner_type,
        horizon=p_cfg.horizon,
        batch=p_cfg.batch,
        cem_elite_ratio=p_cfg.cem_elite_ratio,
        optimization_steps=p_cfg.optimization_steps,
        std_scale=p_cfg.std_scale,
        init_candidates=init_cand,
        dataset=goal_dataset,
        evaluate_only=p_cfg.evaluate_only,
        repeat_samples=p_cfg.repeat_samples,
        mega_prior=mega_prior,
        vis_fn=vis_fn,
    )
else:
    raise NotImplementedError


# %%

def update_goal_strategy(*args):
    """Refresh the module-level `goal_strategy` according to `config.goal_strategy`.

    - Greedy: recompute replay-buffer priorities.
    - Any "SubgoalPlanner*": flag a new goal-distribution search on the next
      sample, and refresh the MEGA prior's KDE when `planner.mega_prior` is set.
    - MEGA / Skewfit: refresh the KDE.
    - Anything else: do nothing.

    Positional arguments are accepted and ignored.
    """
    strategy_name = config.goal_strategy
    if strategy_name == "Greedy":
        goal_strategy.update_buffer_priorities()
        return
    if "SubgoalPlanner" in strategy_name:
        # The planner searches for a new distribution the next time it is sampled.
        goal_strategy.will_update_next_call = True
        if config.planner.mega_prior:
            goal_strategy.mega.update_kde()
        return
    if strategy_name in {"MEGA", "Skewfit"}:
        goal_strategy.update_kde()

# %% [markdown]
# # Set up scatter plot

# %%
complete_episodes = replay._complete_eps
ep_subsample = 10  # NOTE(review): unused — the old-episode stride below is hard-coded to 5.
step_subsample = 1
batch_size = 50
obs_key = 'observation'
goal_key = 'goal'
cache_path = "visualize_planner_cache.pkl"

wm = agnt.wm
episodes = list(complete_episodes.values())
# Keep the `eval_every` most recent episodes plus every 5th older episode.
num_goals = min(int(config.eval_every), len(episodes))
recent_episodes = episodes[-num_goals:]
if len(episodes) > num_goals:
    old_episodes = episodes[:-num_goals][::5]
    old_episodes.extend(recent_episodes)
    episodes = old_episodes
else:
    episodes = recent_episodes
all_obs = []
value_list = []
goals = []

# Reuse a previous run's results only if they were computed for this logdir.
# Caches without a `logdir` entry, or written for another run, are stale and recomputed.
cache = None
if os.path.isfile(cache_path):
    # NOTE: pickle.load is acceptable only because this script writes the file itself.
    with open(cache_path, "rb") as f:
        cache = pickle.load(f)
    if cache.get("logdir") != str(logdir):
        print(f"Ignoring stale cache {cache_path} (not written for {logdir}).")
        cache = None

if cache is not None:
    all_obs = cache["all_obs"]
    value_list = cache["value_list"]
    goals = cache["goals"]
else:
    value_fn = agnt._expl_behavior.ac._target_critic
    for ep_count, episode in enumerate(episodes):
        # Start a new batch every `batch_size` episodes.
        if ep_count % batch_size == 0:
            chunk = collections.defaultdict(list)
        sequence = {
            k: convert(v[::step_subsample])
            for k, v in episode.items() if not k.startswith('log_')}
        data = wm.preprocess(sequence)
        for key, value in data.items():
            chunk[key].append(value)
        # Encode the batch once it is full or the last episode has been added.
        if ep_count % batch_size == batch_size - 1 or ep_count == len(episodes) - 1:
            chunk = {k: tf.stack(v) for k, v in chunk.items()}
            embed = wm.encoder(chunk)
            post, _ = wm.rssm.observe(embed, chunk['action'], chunk['is_first'], None)
            chunk['feat'] = wm.rssm.get_feat(post)
            value_list.append(value_fn(chunk['feat']).mode())
            all_obs.append(tf.stack(chunk[obs_key]))
            goals.append(tf.stack(chunk[goal_key]))

    # Batches are stacked per episode, so this is (num_episodes, time, obs_dim).
    all_obs = tf.concat(all_obs, axis=0)
    with open(cache_path, "wb") as f:
        cache = dict(all_obs=all_obs, value_list=value_list, goals=goals, logdir=str(logdir))
        pickle.dump(cache, f)

# %%
def plot_goals_over_states(goals, plot_goals=False, color_by_value=False):
    """Scatter the xy position of every cached replay state; save goals_over_states.png.

    Reads the module-level `all_obs`, which is rank 3 with states along the
    last axis; its first two state entries are plotted as x/y. When
    `color_by_value` is set, it also reads the module-level `value_list`.

    Args:
      goals: array/tensor whose last axis holds goal vectors. It is only
        plotted when `plot_goals` is True.
      plot_goals: also scatter the goals' xy (in red) above the states.
      color_by_value: color states by the stored critic values (viridis, with
        a colorbar) instead of plain red.
    """
    fig, p2evalue_ax = plt.subplots(1, 1, figsize=(7, 6))
    obs_xy = tf.reshape(all_obs[:, :, :2], [-1, 2])
    if color_by_value:
        values = tf.concat(value_list, axis=0).numpy().flatten()
        color_kwargs = dict(c=values, cmap="viridis")
    else:
        color_kwargs = dict(c='r')
    p2e_scatter = p2evalue_ax.scatter(
        x=obs_xy[:, 0],
        y=obs_xy[:, 1],
        s=1,
        zorder=2,
        **color_kwargs,
    )
    if plot_goals:
        goal_xy = tf.reshape(goals[..., :2], [-1, 2])
        p2evalue_ax.scatter(x=goal_xy[:, 0], y=goal_xy[:, 1], s=1, c='r', zorder=3)
    # A colorbar is only meaningful when the points are colored by value.
    if color_by_value:
        plt.colorbar(p2e_scatter, ax=p2evalue_ax)
    p2evalue_ax.set(xlim=(-1, 5.25), ylim=(-1, 5.25))
    p2evalue_ax.set_title('p2e value')
    fig.savefig("goals_over_states.png", dpi=150)
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)

# %% [markdown]
# # Default CEM planner

# %%
# Fresh env observation with two leading singleton axes added to every leaf.
# `tf.tensor` does not exist; `tf.convert_to_tensor` is the intended conversion.
obs = tf.nest.map_structure(
    lambda x: tf.expand_dims(tf.expand_dims(tf.convert_to_tensor(x), 0), 0),
    env.reset())
obs = agnt.wm.preprocess(obs)
update_goal_strategy()
if config.goal_strategy == "SubgoalPlanner":
    goal_strategy.will_update_next_call = True
    _ = goal_strategy.search_goal(obs)
    # generate CEM state distribution.
    goals = goal_strategy.sample_goal(NUM_GOALS)[None, None]
elif config.goal_strategy == "MEGA":
    goals = np.array([goal_strategy.sample_goal(obs, state=None) for _ in range(NUM_GOALS)])
elif config.goal_strategy == "Greedy":
    goals = np.array([goal_strategy.get_goal() for _ in range(NUM_GOALS)])
else:
    # Without this, `goals` would silently keep the cached replay goals from the cell above.
    raise NotImplementedError(
        f"Goal sampling is not implemented for goal_strategy={config.goal_strategy!r}.")

# plot_goals_over_states(goals)

# sys.exit(0)
# %% [markdown]
# The planner seems to pick goals near the start state. One possible reason is that the other (non-xy) state dimensions look very "interesting" to the planner's objective and dominate the search.
# We can test this by restricting the CEM search to the xy dimensions only.
#
# Let's look at the imagined trajectories the CEM produces.

# %%
# num_vis = 5
# elite_inds = goal_strategy.elite_inds[:num_vis]
# elite_samples = goal_strategy.elite_samples[:num_vis] # topK x D
# final_seq = goal_strategy.final_seq

# elite_seq = tf.nest.map_structure(lambda x: tf.gather(x, elite_inds, axis=1), final_seq)
# elite_obs = wm.heads['decoder'](wm.rssm.get_feat(elite_seq))['observation'].mode() # T x topk x D

# %%
# plot the goals, and trajectories from planner process.
def plot_elites(goals, trajectories, mega_goal=None):
    """Plot planner elite goals and their imagined trajectories; save plot_elites.png.

    Args:
      goals: elite goal samples. The first two entries of the last axis are
        drawn as small red points.
      trajectories: imagined states, assumed (horizon, num_elites, obs_dim),
        i.e. "T x topk x D" as in the commented-out decoder call. TODO confirm.
        Each elite's xy path is drawn as one blue line.
      mega_goal: optional array with xy in columns 0 and 1, drawn as yellow stars.
    """
    fig, p2evalue_ax = plt.subplots(1, 1, figsize=(7, 6))
    goal_xy = tf.reshape(goals[..., :2], [-1, 2])
    p2evalue_ax.scatter(x=goal_xy[:, 0], y=goal_xy[:, 1], s=1, c='r', zorder=5)
    if mega_goal is not None:
        p2evalue_ax.scatter(
            x=mega_goal[:, 0],
            y=mega_goal[:, 1],
            s=50,
            c='y',
            zorder=5,
            marker='*',
        )
    # Axes.plot draws one line per *column* of 2-D x/y. Keeping the
    # (horizon, num_elites) layout makes each column one elite's path over
    # time. Transposing first would instead join different elites at the same step.
    p2evalue_ax.plot(
        trajectories[:, :, 0],
        trajectories[:, :, 1],
        c='b',
        zorder=4,
        marker='.',
    )
    p2evalue_ax.set(xlim=(-1, 5.25), ylim=(-1, 5.25))
    p2evalue_ax.set_title('elite goals and states')
    fig.savefig("plot_elites.png", dpi=150)
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)

# plot_elites(elite_samples, elite_obs, goal_strategy.mega_sample)


# evaluate the agent on these goals.
# Goal-conditioned driver over the single evaluation env.
eval_driver = common.GCDriver([eval_env], config.goal_key)
# Agent policy bound to mode='eval'.
eval_gcpolicy = partial(agnt.policy, mode='eval')


def expl_policy(obs, state, **kwargs):
    """Query the agent's exploration policy in 'train' mode.

    Extra keyword arguments are accepted and ignored.
    Returns the (actions, state) pair produced by `agnt.expl_policy`.
    """
    act, new_state = agnt.expl_policy(obs, state, mode='train')
    return act, new_state


print('Start evaluation.')
outputs = [common.TerminalOutput()]
# Logger over the shared step counter; presumably scales steps by the action repeat.
logger = common.Logger(step, outputs, multiplier=config.action_repeat)

goal_time_limit = round(config.goal_policy_rollout_percentage * config.time_limit)


def temporal_dist(obs):
    """Compare the agent's temporal-distance estimate for `obs` to the subgoal threshold.

    Args:
      obs: a one-element sequence wrapping an observation structure
        (presumably a dict of arrays — confirm against the driver). Only the
        first element is used.

    Returns:
      (success, metric): `success` is True when the distance is below
      `config.subgoal_threshold`. `metric` holds the distance and success
      (as a float) for logging.
    """
    # TODO: assumes obs list is only 1 element.
    # Add a leading axis to every leaf, then unwrap the single element.
    # (`tf.tensor` does not exist; `tf.convert_to_tensor` is the intended call.)
    obs = tf.nest.map_structure(lambda x: tf.expand_dims(tf.convert_to_tensor(x), 0), obs)[0]
    dist = agnt.temporal_dist(obs).numpy().item()
    success = dist < config.subgoal_threshold
    metric = {"subgoal_dist": dist, "subgoal_success": float(success)}
    return success, metric

def eval_fn(driver, gc_policy, expl_policy, goals, logger):
    """Roll out `gc_policy` toward each goal, plot the visited xy states, and save them.

    Runs 10 episodes per goal (30 for the MEGA strategy). Prints each
    episode's length and return, then the total number of steps spent in the
    upper-left corner (x < 2 and y > 3). Writes `{goal_strategy}_states.png`
    and `{goal_strategy}_visualize.pkl`.

    Args:
      driver: GCDriver used for rollouts; only its first env's episode is read.
      gc_policy: goal-conditioned policy passed to the driver.
      expl_policy: unused (kept for the commented-out explore-then-reach rollout).
      goals: NUM_GOALS goal vectors, in any shape that squeezes/reshapes to
        (NUM_GOALS, goal_dim).
      logger: unused.
    """
    num_eps_per_goal = 30 if config.goal_strategy in {"MEGA"} else 10
    # One goal vector per row; the goal dimension is inferred instead of hard-coded.
    goals = tf.reshape(tf.squeeze(goals), (NUM_GOALS, -1))
    total_in_corner = 0
    all_observations = []
    all_observations_per_goal = []  # NUM_GOALS lists of per-episode xy arrays
    for g_idx, g in enumerate(goals):
        goal_vec = tf.squeeze(g)

        def get_goal(*args, **kwargs):
            return goal_vec

        observations_per_goal = []
        for _ in range(num_eps_per_goal):
            driver.reset()
            driver(gc_policy, get_goal=get_goal, episodes=1)
            # driver(gc_policy, expl_policy, get_goal, episodes=1, goal_time_limit=goal_time_limit, goal_checker=temporal_dist)
            ep = driver._eps[0]  # get episode data of 1st env.
            ep = {k: driver._convert([t[k] for t in ep]) for k in ep[0]}
            score = float(ep['reward'].astype(np.float64).sum())
            # Steps = entries - 1, assuming the first entry is the reset
            # transition (dreamerv2's convention). Note the -1 is outside len().
            length = len(ep['reward']) - 1
            print(f'Eval goal {g_idx} has {length} steps and return {score:.1f}.')
            all_xy = ep["observation"][:, :2]
            all_observations.extend(all_xy)
            observations_per_goal.append(all_xy)
            # upper left is x < 2 and y > 3.
            in_corner = (all_xy[:, 0] < 2.0) & (all_xy[:, 1] > 3.0)
            total_in_corner += in_corner.sum()
        all_observations_per_goal.append(observations_per_goal)

    print(total_in_corner)

    # plot all visited xy states and the goals.
    all_observations = np.array(all_observations)
    fig, traj_ax = plt.subplots(1, 1, figsize=(7, 6))
    traj_ax.scatter(
        all_observations[:, 0],
        all_observations[:, 1],
        c='b',
        zorder=4,
        marker='.',
    )
    traj_ax.scatter(
        goals[:, 0],
        goals[:, 1],
        c='r',
        zorder=4,
        marker='*',
    )
    traj_ax.set(xlim=(-1, 5.25), ylim=(-1, 5.25))
    traj_ax.set_title(f'{config.goal_strategy} States')
    fig.savefig(f"{config.goal_strategy}_states.png", dpi=150)
    plt.close(fig)
    # save data for future reference
    with open(f"{config.goal_strategy}_visualize.pkl", "wb") as f:
        pickle.dump([all_observations_per_goal, goals], f)





# Evaluate the agent on the sampled goals; eval_fn writes the plot and pickle.
eval_fn(eval_driver, eval_gcpolicy, expl_policy, goals=goals, logger=logger)