import gym
import dreamerv2.api as dv2
from dreamerv2.gc_agent import GCAgent
import common
import tensorflow as tf
from common import Config
import envs
import pathlib
import numpy as np
import matplotlib.pyplot as plt
from run_goal_cond import make_env, make_report_render_function, make_eval_fn
import os
from dreamerv2.common.replay import convert, load_episodes
from dreamerv2.expl import Plan2Explore
import torch
import collections
import pickle
from tqdm import tqdm
import ruamel.yaml as yaml

# 1. Load the weights of the gc_agent to get the dynamical distance network.

# create the config dict
# create the environment

def plot_heatmap_p2e(
    ep_subsample=5,
    step_subsample=5,
    episode_folder='/home/anonymous/logdir/mega_pointmaze_35bs_dyndist_3p2e_400mlp_3train_5action/train_episodes',
    agnt=None):
  """Scatter-plot the exploration intrinsic reward over visited states.

  Episodes are processed one at a time: each (sub-sampled) episode is
  preprocessed by the world model, encoded, passed through the RSSM and then
  scored with `agnt._expl_behavior._intr_reward`. The first two observation
  features are used as the x / y coordinates of the scatter plot, which is
  saved to 'p2p_map_.png' in the current working directory.

  Args:
    ep_subsample: keep every `ep_subsample`-th episode.
    step_subsample: keep every `step_subsample`-th step of each episode.
    episode_folder: directory passed to `load_episodes`; `~` is expanded.
    agnt: agent exposing `wm` and `_expl_behavior`. When omitted, a
      module-level `agnt` is used (the behaviour of earlier versions).

  Returns:
    Tuple `(x_list, y_list, rewards)` of flat per-step sequences.

  Raises:
    NameError: no agent was passed and no module-level `agnt` exists.
    ValueError: no episode was selected, so there is nothing to plot.
  """
  if agnt is None:
    # Backward compatibility: this function used to read a global `agnt`.
    try:
      agnt = globals()['agnt']
    except KeyError:
      raise NameError(
        "plot_heatmap_p2e needs an agent: pass `agnt=...` or define a "
        "module-level `agnt`") from None
  wm = agnt.wm
  reward_fn = agnt._expl_behavior._intr_reward
  episode_folder = pathlib.Path(episode_folder).expanduser()
  episodes = list(load_episodes(episode_folder).values())
  obs_xlist = []
  obs_ylist = []
  reward_list = []
  for episode in episodes[::ep_subsample]:
    sequence = {
      k: convert(v[::step_subsample])
      for k, v in episode.items() if not k.startswith('log_')}
    data = wm.preprocess(sequence)
    # Add a leading batch axis of size 1 to every entry.
    data = {k: tf.expand_dims(v, axis=0) for k, v in data.items()}
    embed = wm.encoder(data)
    post, _ = wm.rssm.observe(embed, data['action'], data['is_first'], None)
    data['feat'] = wm.rssm.get_feat(post)
    reward = reward_fn(data).reshape((-1, ))
    # Flatten per episode so episodes of different lengths can be concatenated.
    obs_xlist.append(tf.reshape(data['observation'][:, :, 0], [-1]))
    obs_ylist.append(tf.reshape(data['observation'][:, :, 1], [-1]))
    reward_list.append(reward)
  if not reward_list:
    raise ValueError(f'no episodes found in {episode_folder}')
  x_list = tf.concat(obs_xlist, axis=0)
  y_list = tf.concat(obs_ylist, axis=0)
  rewards = tf.concat(reward_list, axis=0)
  # pyplot.get_cmap is used because matplotlib.cm.get_cmap was removed in
  # recent matplotlib releases.
  cm = plt.get_cmap("viridis")
  print('plotting')
  plt.scatter(
    x=x_list,
    y=y_list,
    s=1,
    c=rewards,
    cmap=cm,
    zorder=3,
    )
  fig = plt.gcf()
  fig.set_size_inches(8, 6)
  plt.colorbar()
  plt.title('plan2explore reward')
  plt.savefig('p2p_map_.png', dpi=300)
  return x_list, y_list, rewards

def plot_heatmap_p2e_wbatch(ep_folder, ep_subsample=5, step_subsample=5, batch_size=50, obs_key='observation', agnt=None):
  """Scatter-plot the exploration intrinsic reward, evaluating episodes in batches.

  Every `ep_subsample`-th episode in `ep_folder` is preprocessed by the world
  model and stacked into chunks of at most `batch_size` episodes. Each chunk
  is encoded, passed through the RSSM and scored with
  `agnt._expl_behavior._intr_reward`. The first two features of `obs_key` are
  the x / y coordinates of the plot, saved to 'p2p_map_50batch_old.png' in the
  current working directory.

  Episodes in a chunk are combined with `tf.stack`, so they must have equal
  length after step sub-sampling.

  Args:
    ep_folder: directory passed to `load_episodes`; `~` is expanded.
    ep_subsample: keep every `ep_subsample`-th episode.
    step_subsample: keep every `step_subsample`-th step of each episode.
    batch_size: maximum number of episodes per forward pass.
    obs_key: key of the preprocessed data holding the coordinates to plot.
    agnt: agent exposing `wm` and `_expl_behavior`. When omitted, a
      module-level `agnt` is used (the behaviour of earlier versions).

  Returns:
    Tuple `(x, y, rewards)` of flat per-step sequences.

  Raises:
    NameError: no agent was passed and no module-level `agnt` exists.
    ValueError: no episode was selected, so there is nothing to plot.
  """
  if agnt is None:
    # Backward compatibility: this function used to read a global `agnt`.
    try:
      agnt = globals()['agnt']
    except KeyError:
      raise NameError(
        "plot_heatmap_p2e_wbatch needs an agent: pass `agnt=...` or define "
        "a module-level `agnt`") from None
  # 1. Load all episodes
  wm = agnt.wm
  reward_fn = agnt._expl_behavior._intr_reward
  episode_folder = pathlib.Path(ep_folder).expanduser()
  selected = list(load_episodes(episode_folder).values())[::ep_subsample]
  if not selected:
    raise ValueError(f'no episodes found in {episode_folder}')
  last_index = len(selected) - 1
  obs = []
  reward_list = []
  for ep_count, episode in enumerate(selected):
    # 2. Start a fresh chunk at every batch boundary
    if ep_count % batch_size == 0:
      start = ep_count
      chunk = collections.defaultdict(list)
    sequence = {
      k: convert(v[::step_subsample])
      for k, v in episode.items() if not k.startswith('log_')}
    data = wm.preprocess(sequence)
    for key, value in data.items():
      chunk[key].append(value)
    # 3. Forward pass once the chunk is full or the last episode was added
    if ep_count % batch_size == batch_size - 1 or ep_count == last_index:
      print(f'start: {start} end: {ep_count}')
      batch = {k: tf.stack(v) for k, v in chunk.items()}
      embed = wm.encoder(batch)
      post, _ = wm.rssm.observe(embed, batch['action'], batch['is_first'], None)
      batch['feat'] = wm.rssm.get_feat(post)
      reward = reward_fn(batch).reshape((-1, ))
      obs.append(batch[obs_key])
      reward_list.append(reward)
  # 4. Plotting
  obs_list = tf.concat(obs, axis=0)
  # Merge the leading (episode, step) axes and keep the feature axis as is, so
  # observations with more than two features still plot their first two.
  obs_list = tf.reshape(obs_list, [-1, obs_list.shape[-1]])
  rewards = tf.concat(reward_list, axis=0)
  # pyplot.get_cmap is used because matplotlib.cm.get_cmap was removed in
  # recent matplotlib releases.
  cm = plt.get_cmap("viridis")
  print('plotting')
  plt.scatter(
    x=obs_list[:, 0],
    y=obs_list[:, 1],
    s=1,
    c=rewards,
    cmap=cm,
    zorder=3,
    )
  fig = plt.gcf()
  fig.set_size_inches(8, 6)
  plt.colorbar()
  plt.title('plan2explore reward')
  plt.savefig('p2p_map_50batch_old.png', dpi=300)
  return obs_list[:, 0], obs_list[:, 1], rewards

def plot_value_fn_p2e_wbatch(agnt, ep_folder, ep_subsample=5, step_subsample=5, batch_size=50, obs_key='observation'):
  """Scatter-plot the exploration critic's value over visited states, in batches.

  Every `ep_subsample`-th episode in `ep_folder` is preprocessed by the world
  model and stacked into chunks of at most `batch_size` episodes. Each chunk
  is encoded, passed through the RSSM, and the resulting features are scored
  with `agnt._expl_behavior.ac._target_critic(...).mode()`. The first two
  features of `obs_key` are the x / y coordinates of the plot, saved as
  'p2e_val_fn_reward' in the current working directory.

  Episodes in a chunk are combined with `tf.stack`, so they must have equal
  length after step sub-sampling.

  Args:
    agnt: agent exposing `wm` and `_expl_behavior.ac._target_critic`.
    ep_folder: directory passed to `load_episodes`; `~` is expanded.
    ep_subsample: keep every `ep_subsample`-th episode.
    step_subsample: keep every `step_subsample`-th step of each episode.
    batch_size: maximum number of episodes per forward pass.
    obs_key: key of the preprocessed data holding the coordinates to plot.

  Returns:
    Tuple `(x, y, values)`; `values` is a flat NumPy array.

  Raises:
    ValueError: no episode was selected, so there is nothing to plot.
  """
  # 1. Load all episodes
  wm = agnt.wm
  value_fn = agnt._expl_behavior.ac._target_critic
  episode_folder = pathlib.Path(ep_folder).expanduser()
  selected = list(load_episodes(episode_folder).values())[::ep_subsample]
  if not selected:
    raise ValueError(f'no episodes found in {episode_folder}')
  last_index = len(selected) - 1
  obs = []
  value_list = []
  for ep_count, episode in enumerate(selected):
    # 2. Start a fresh chunk at every batch boundary
    if ep_count % batch_size == 0:
      chunk = collections.defaultdict(list)
    sequence = {
      k: convert(v[::step_subsample])
      for k, v in episode.items() if not k.startswith('log_')}
    data = wm.preprocess(sequence)
    for key, value in data.items():
      chunk[key].append(value)
    # 3. Forward pass once the chunk is full or the last episode was added
    if ep_count % batch_size == batch_size - 1 or ep_count == last_index:
      batch = {k: tf.stack(v) for k, v in chunk.items()}
      embed = wm.encoder(batch)
      post, _ = wm.rssm.observe(embed, batch['action'], batch['is_first'], None)
      batch['feat'] = wm.rssm.get_feat(post)
      obs.append(batch[obs_key])
      value_list.append(value_fn(batch['feat']).mode())
  # 4. Plotting
  obs_list = tf.concat(obs, axis=0)
  # Merge the leading (episode, step) axes and keep the feature axis as is, so
  # observations with more than two features still plot their first two.
  obs_list = tf.reshape(obs_list, [-1, obs_list.shape[-1]])
  values = tf.concat(value_list, axis=0)
  values = values.numpy().flatten()
  # pyplot.get_cmap is used because matplotlib.cm.get_cmap was removed in
  # recent matplotlib releases.
  cm = plt.get_cmap("viridis")
  print('plotting')
  plt.scatter(
    x=obs_list[:, 0],
    y=obs_list[:, 1],
    s=1,
    c=values,
    cmap=cm,
    zorder=3,
    )
  fig = plt.gcf()
  fig.set_size_inches(8, 6)
  plt.colorbar()
  plt.title('plan2explore value fn reward')
  plt.savefig('p2e_val_fn_reward', dpi=300)
  return obs_list[:, 0], obs_list[:, 1], values


def plot_heatmap(env, agnt, x_min, x_max, y_min, y_max, x_div, y_div, goal_x, goal_y):
  """Draw the agent's temporal distance to one goal as a filled contour plot.

  A regular `x_div` x `y_div` grid over [x_min, x_max] x [y_min, y_max] is
  paired with the fixed goal `(goal_x, goal_y)` and passed to
  `agnt.temporal_dist`; the result is contoured on top of the maze walls
  drawn by `env._env._env._env.maze.plot`. The figure is left open on
  pyplot's current figure (nothing is saved or returned).

  Args:
    env: wrapped environment; the maze is read from `env._env._env._env`.
    agnt: agent exposing `temporal_dist(obs)`.
    x_min, x_max, y_min, y_max: bounds of the evaluation grid.
    x_div, y_div: number of grid points along x and y.
    goal_x, goal_y: goal coordinates; must lie inside the grid bounds.

  Raises:
    ValueError: the goal is outside the bounds or the bounds are empty.
  """
  # Validate first so that bad arguments do not leave a half-drawn figure open.
  if goal_x < x_min or goal_x > x_max:
    raise ValueError("invalid goal x coord")
  if goal_y < y_min or goal_y > y_max:
    raise ValueError("invalid goal y coord")
  if x_min >= x_max or y_min >= y_max:
    raise ValueError("invalid coordinate ranges")
  maze = env._env._env._env
  fig, ax = plt.subplots(1, 1, figsize=(1, 1))
  maze.maze.plot(ax) # plot the walls
  x = np.linspace(x_min, x_max, x_div)
  y = np.linspace(y_min, y_max, y_div)
  X, Y = np.meshgrid(x, y)
  # One (x, y) row per grid point.
  XY = np.stack([X, Y], axis=-1).reshape(x_div * y_div, 2)
  # The same goal repeated for every grid point.
  goal_vec = np.zeros((x_div * y_div, 2))
  goal_vec[:, 0] += goal_x
  goal_vec[:, 1] += goal_y
  num_points = XY.shape[0]
  obs = {
    "observation": XY,
    "goal": goal_vec,
    "reward": np.zeros(num_points),
    "discount": np.ones(num_points),
    "is_terminal": np.zeros(num_points),
  }
  temporal_dist = agnt.temporal_dist(obs)
  plt.tricontourf(XY[:, 0], XY[:, 1], temporal_dist, zorder=1)
  plt.colorbar()
  # NOTE(review): the title mentions (0,0) although distances are evaluated
  # over the whole grid - confirm the intended wording.
  plt.title(f"Temp. Dist between (0,0) and ({goal_x},{goal_y})")
  fig.set_size_inches(8, 6)
 
def plot_state_frequency(ep_folder, fig):
  """Scatter-plot every observation stored in the episode files of a folder.

  Each entry of `ep_folder` is opened with `np.load` and its 'observation'
  and 'goal' arrays are collected. The first two observation columns are
  drawn on pyplot's current axes, the concatenated observations and goals are
  pickled to 'maze_plt_data.pkl', and the plot is saved to 'scatterplot.png'
  (both in the current working directory).

  Args:
    ep_folder: directory whose entries are all loadable episode archives.
    fig: figure to resize to 8 x 6 inches.

  Raises:
    ValueError: `ep_folder` contains no entries.
  """
  # Close the directory iterator deterministically instead of relying on GC.
  with os.scandir(ep_folder) as entries:
    files = list(entries)
  if not files:
    raise ValueError(f'no episode files found in {ep_folder}')
  all_observations = []
  all_goals = []
  for file in tqdm(files):
    # The context manager releases the archive's file handle right away; the
    # arrays read inside the block stay valid afterwards.
    with np.load(file.path) as data:
      all_observations.append(data['observation'])
      all_goals.append(data['goal'])

  all_observations = np.concatenate(all_observations)
  all_goals = np.concatenate(all_goals)
  plt.scatter(all_observations[:,0], all_observations[:,1], s = 2, c ="blue", alpha = 0.75)
  with open("maze_plt_data.pkl", "wb") as f:
    pickle.dump([all_observations, all_goals], f)
  fig.set_size_inches(8, 6)
  plt.title("State occupancy grid")
  plt.savefig('scatterplot.png', dpi=200)

def plot_states_reward(env, agnt, episode_folder):
  """Draw the maze walls, then overlay the exploration value plot on top.

  Uses every episode and every step (no sub-sampling) and evaluates the
  episodes 50 at a time via `plot_value_fn_p2e_wbatch`.
  """
  wrapped_env = env._env._env._env
  _, axis = plt.subplots(figsize=(1, 1))
  wrapped_env.maze.plot(axis)  # walls first, so the scatter is drawn over them
  plot_value_fn_p2e_wbatch(
    agnt,
    episode_folder,
    ep_subsample=1,
    step_subsample=1,
    batch_size=50,
  )

def plot_grid_heatmap(env, agnt):
  """Plot the agent's temporal distance for a 3 x 3 grid of goals.

  Goals are spaced evenly over [0, 9] on both axes. For each goal, a
  100 x 100 grid of states over [-0.5, 9.5]^2 is paired with that goal and
  passed to `agnt.temporal_dist`; the result is contoured in its own subplot
  on top of the maze walls, with the goal marked by a red star. The figure is
  saved to 'scatterplot.png' in the current working directory.

  Args:
    env: wrapped environment; the maze is read from `env._env._env._env`.
    agnt: agent exposing `temporal_dist(obs)`.
  """
  maze = env._env._env._env
  x_lims = y_lims = [0.0, 9.0]
  goal_grid_size = 3
  x_min, x_max = -0.5, 9.5
  y_min, y_max = -0.5, 9.5
  x_div = y_div = 100
  # The evaluation grid does not depend on the goal, so build it once.
  x = np.linspace(x_min, x_max, x_div)
  y = np.linspace(y_min, y_max, y_div)
  X, Y = np.meshgrid(x, y)
  XY = np.stack([X, Y], axis=-1).reshape(x_div * y_div, 2)
  num_points = XY.shape[0]
  fig, axes = plt.subplots(goal_grid_size, goal_grid_size, figsize=(7, 7), sharex='all', sharey='all')
  for i, gx in enumerate(np.linspace(x_lims[0], x_lims[1], goal_grid_size)):
    for j, gy in enumerate(np.linspace(y_lims[0], y_lims[1], goal_grid_size)):
      ax = axes[i,j]
      maze.maze.plot(ax) # plot the walls
      ax.set_title(f"Dist between (0,0) and ({gx:.1f},{gy:.1f})", fontsize=5)
      # The same goal repeated for every grid point.
      goal_vec = np.zeros((num_points, 2))
      goal_vec[:, 0] += gx
      goal_vec[:, 1] += gy
      obs = {
        "observation": XY,
        "goal": goal_vec,
        "reward": np.zeros(num_points),
        "discount": np.ones(num_points),
        "is_terminal": np.zeros(num_points),
      }
      temporal_dist = agnt.temporal_dist(obs)
      ax.tricontourf(XY[:, 0], XY[:, 1], temporal_dist, zorder=1)
      ax.scatter(x=[gx], y=[gy], c="r", marker="*", s=20, zorder=2)
      ax.axis('off')
  plt.savefig('scatterplot.png', dpi=200)
  print("done")

def plot_grid_heatmap_downscale_umaze(agnt, gx, gy):
  """Plot the agent's temporal distance to one goal for 29-dim observations.

  A 100 x 100 grid over [-1, 5.25]^2 supplies the first two observation
  features; the remaining 27 features are held at fixed values (13 constants
  followed by 14 zeros) for both the observation and the goal. The output of
  `agnt.temporal_dist` is drawn as a filled contour plot and saved to
  'scatterplot.png' in the current working directory.

  Args:
    agnt: agent exposing `temporal_dist(obs)`.
    gx: first feature (x) of the goal.
    gy: second feature (y) of the goal.
  """
  # Fixed values for features 2..28.
  # NOTE(review): presumably a reference body pose with zero velocities -
  # confirm against the environment's observation layout.
  other_dims = np.concatenate([[6.08193526e-01,  9.87496030e-01,
  1.82685311e-03, -6.82827458e-03,  1.57485326e-01,  5.14617396e-02,
  1.22386603e+00, -6.58701813e-02, -1.06980319e+00,  5.09069276e-01,
 -1.15506861e+00,  5.25953435e-01,  7.11716520e-01], np.zeros(14)])
  x_min, x_max = -1, 5.25
  y_min, y_max = -1, 5.25
  x_div = y_div = 100
  x = np.linspace(x_min, x_max, x_div)
  y = np.linspace(y_min, y_max, y_div)
  X, Y = np.meshgrid(x, y)
  XY = np.stack([X, Y], axis=-1).reshape(x_div * y_div, 2)
  num_points = XY.shape[0]
  # Append the fixed features to every grid point.
  XY_plus = np.hstack((XY, np.tile(other_dims, (num_points, 1))))
  # The same goal repeated for every grid point.
  goal_vec = np.zeros((num_points, 29))
  goal_vec[:, 0] += gx
  goal_vec[:, 1] += gy
  goal_vec[:, 2:] += other_dims
  obs = {
    "observation": XY_plus,
    "goal": goal_vec,
    "reward": np.zeros(num_points),
    "discount": np.ones(num_points),
    "is_terminal": np.zeros(num_points),
  }
  temporal_dist = agnt.temporal_dist(obs)
  # `zorder` (not `z`) is the keyword that controls the drawing order.
  plt.tricontourf(XY[:, 0], XY[:, 1], temporal_dist, zorder=1)
  plt.colorbar()
  plt.title("Output of learned reward function")
  plt.savefig('scatterplot.png', dpi=200)
  print("done")

def main():
  """Restore a trained goal-conditioned agent and plot its temporal distance.

  General workflow:
  1. Set the variables below (config path, plotting call).
  2. Go to configs.yaml and change hyperparameters.

  The config is read from a hard-coded run directory, the environment and a
  `GCAgent` are built from it, the agent is trained on one batch, the saved
  weights are loaded from `<logdir>/variables.pkl`, and
  `plot_grid_heatmap_downscale_umaze` is called for the goal (0, 4.2).
  """
  config_path = pathlib.Path(
      '/home/anonymous/logdir/umazefulldownscale_data_collect/config.yaml')
  # The YAML(typ='safe') object replaces ruamel.yaml's legacy module-level
  # safe_load function.
  configs = yaml.YAML(typ='safe').load(config_path.read_text())
  config = common.Config(configs)

  # ========= SETUP ENVIRONMENTS ========
  # Built once; the earlier duplicate construction was discarded unused.
  env = make_env(config, use_goal_idx=False, log_per_goal=True)
  # eval_env = make_env(config, use_goal_idx=True, log_per_goal=False, eval=True)
  # report_render_fn = make_report_render_function(config)
  # eval_fn = make_eval_fn(config)

  # ========= SETUP TF2 and GPU ========
  tf.config.run_functions_eagerly(not config.jit)
  # tf.data.experimental.enable_debug_mode(not config.jit)
  message = 'No GPU found. To actually train on CPU remove this assert.'
  assert tf.config.experimental.list_physical_devices('GPU'), message
  for gpu in tf.config.experimental.list_physical_devices('GPU'):
    tf.config.experimental.set_memory_growth(gpu, True)
  assert config.precision in (16, 32), config.precision
  if config.precision == 16:
    from tensorflow.keras.mixed_precision import experimental as prec
    prec.set_policy(prec.Policy('mixed_float16'))
  logdir = pathlib.Path(config.logdir).expanduser()
  logdir.mkdir(parents=True, exist_ok=True)
  config.save(logdir / 'config.yaml')
  print(config, '\n')
  print('Logdir', logdir)
  replay = common.Replay(logdir / 'train_episodes', **config.replay) # initialize replay buffer
  step = common.Counter(replay.stats['total_steps']) # initialize step counter
  dataset = iter(replay.dataset(**config.dataset))

  agnt = GCAgent(config, env.obs_space, env.act_space, step, None)
  # One training call is made before the saved weights are loaded.
  # NOTE(review): presumably this creates the agent's variables - confirm.
  train_agent = common.CarryOverState(agnt.train)
  train_agent(next(dataset))
  agnt.load(logdir / 'variables.pkl')
  # plot_states_reward(env, agnt, logdir / "train_episodes")
  plot_grid_heatmap_downscale_umaze(agnt, 0, 4.2)
# Run the plotting workflow only when executed as a script, not on import.
if __name__ == "__main__":
  main()