import numpy as np
import matplotlib.pyplot as plt
import pathlib
from tqdm import tqdm

def load_episodes(directory, ep_capacity=None, minlen=1, ep_subsample=1):
  """Load saved ``.npz`` episodes from ``directory`` into a dictionary.

  Args:
    directory: path object whose ``glob('*.npz')`` yields the episode files.
    ep_capacity: if truthy, only the last ``ep_capacity`` files (after
      sorting by filename) are considered.
    minlen: accepted but not used by this function.
    ep_subsample: step applied to the (possibly truncated) sorted file list;
      every ``ep_subsample``-th file is loaded.

  Returns:
    A dict mapping ``str(filename)`` to a dict of the arrays stored in that
    file. Entries are inserted in sorted-filename order (presumably
    earliest episode first -- this relies on the file naming scheme).
    Files that fail to load are reported on stdout and skipped.
  """
  paths = sorted(directory.glob('*.npz'))
  if ep_capacity:
    # Keep only the tail of the sorted list.
    paths = paths[-ep_capacity:]
  selected = paths[::ep_subsample]

  loaded = {}
  for path in tqdm(selected):
    try:
      with path.open('rb') as handle:
        archive = np.load(handle)
        # Materialize every array while the file handle is still open.
        arrays = {field: archive[field] for field in archive.keys()}
    except Exception as err:
      # Best effort: report the unreadable file and move on.
      print(f'Could not load episode {str(path)}: {err}')
    else:
      loaded[str(path)] = arrays
  return loaded

def plot_state_frequency(save_path, logdir, ep_folder, fig, ep_subsample, plot_goal, plot_dims):
  """Scatter-plot object/gripper positions from saved episodes and save the figure.

  Episodes are loaded from ``<logdir>/<ep_folder>``. Every episode's
  ``observation`` array is concatenated; columns 0:3 are plotted in the "Obj"
  row and columns 3:6 in the "Gripper" row. One column of subplots is drawn
  per entry of ``plot_dims``. When ``plot_goal`` is true, two extra rows show
  the first ``goal`` row of each episode, split into the same column ranges.

  Args:
    save_path: destination passed to ``savefig``.
    logdir: root log directory; ``~`` is expanded.
    ep_folder: sub-folder of ``logdir`` containing the ``.npz`` episodes.
    fig: unused; kept for backward compatibility. A new figure is always
      created.
    ep_subsample: only every ``ep_subsample``-th episode file is loaded.
    plot_goal: whether to add the two goal rows.
    plot_dims: sequence of index pairs ``(i, j)`` with values in {0, 1, 2};
      each pair selects the coordinates plotted on the x and y axes.

  Raises:
    ValueError: if no episode could be loaded.
  """
  directory = pathlib.Path(logdir).expanduser() / ep_folder
  episodes = load_episodes(directory, ep_subsample=ep_subsample)
  if not episodes:
    raise ValueError(f'No episodes could be loaded from {directory}')

  all_observations = np.concatenate([data['observation'] for data in episodes.values()])
  ob_obj_pos = all_observations[:, :3]
  ob_grip_pos = all_observations[:, 3:6]

  plot_dim_name = {0: 'x', 1: 'y', 2: 'z'}

  def plot_axes(axes, data, cmap, title):
    for ax, pd in zip(axes, plot_dims):
      ax.scatter(x=data[:, pd[0]],
        y=data[:, pd[1]],
        s=1,
        c=np.arange(len(data)),  # color encodes the row index
        cmap=cmap,
        zorder=3,
      )
      ax.set_title(f"{title} {plot_dim_name[pd[0]]}{plot_dim_name[pd[1]]}", fontdict={'fontsize':4})

  num_rows = 4 if plot_goal else 2
  num_cols = len(plot_dims)
  figwidth = (4 * num_cols) + 2
  figheight = (4 * num_rows) + 2
  # squeeze=False keeps the axes grid 2-D even for a single column, so each
  # row is always an iterable of Axes.
  figure, axes = plt.subplots(
      num_rows, num_cols, figsize=(figwidth / 3, figheight / 3), squeeze=False)
  try:
    plot_axes(axes[0], ob_obj_pos, 'Blues', "Obj")
    plot_axes(axes[1], ob_grip_pos, 'Reds', "Gripper")
    if plot_goal:
      # Goals are only read when requested, so goal-less episodes still plot.
      # Only the first goal row of each episode is used.
      all_goals = np.concatenate([data['goal'][None, 0] for data in episodes.values()])
      plot_axes(axes[2], all_goals[:, :3], 'Blues', "Goal Obj")
      plot_axes(axes[3], all_goals[:, 3:6], 'Reds', "Goal Gripper")

    figure.subplots_adjust(left=0.1,
                           bottom=0.1,
                           right=0.9,
                           top=0.9,
                           wspace=0.4,
                           hspace=0.4)
    figure.savefig(save_path, dpi=100)
  finally:
    # Close the figure so repeated calls do not accumulate open figures.
    plt.close(figure)

def plot_recent_goals(save_path, logdir, ep_folder,  ep_subsample, num_goals):
  """Plot the goals of the most recent episodes on the xz plane and save the figure.

  Episodes are loaded from ``<logdir>/<ep_folder>`` with
  ``ep_capacity=num_goals``. The first ``goal`` row of each loaded episode is
  used: columns 0:3 are drawn with the 'Blues' colormap ("Goal Obj") and
  columns 3:6 with 'Reds' ("Goal Gripper"), both on the same axes.

  Args:
    save_path: destination passed to ``savefig``.
    logdir: root log directory; ``~`` is expanded.
    ep_folder: sub-folder of ``logdir`` containing the ``.npz`` episodes.
    ep_subsample: subsampling step forwarded to ``load_episodes``.
    num_goals: forwarded to ``load_episodes`` as ``ep_capacity``.

  Raises:
    ValueError: if no episode could be loaded.
  """
  plot_dims = [[0, 2]]
  directory = pathlib.Path(logdir).expanduser() / ep_folder
  episodes = load_episodes(directory, ep_capacity=num_goals, ep_subsample=ep_subsample)
  if not episodes:
    raise ValueError(f'No episodes could be loaded from {directory}')

  # One 1xD row per episode: the first goal entry.
  all_goals = np.concatenate([data['goal'][None, 0] for data in episodes.values()])
  g_obj_pos = all_goals[:, :3]
  g_grip_pos = all_goals[:, 3:6]

  plot_dim_name = {0: 'x', 1: 'y', 2: 'z'}

  def plot_axes(axes, data, cmap, title):
    for ax, pd in zip(axes, plot_dims):
      ax.scatter(x=data[:, pd[0]],
        y=data[:, pd[1]],
        s=1,
        c=np.arange(len(data)),  # color encodes the row index
        cmap=cmap,
        zorder=3,
      )
      ax.set_title(f"{title} {plot_dim_name[pd[0]]}{plot_dim_name[pd[1]]}", fontdict={'fontsize':10})

  fig, g_ax = plt.subplots(1, 1, figsize=(3, 3))
  try:
    # NOTE: both calls title the same axes, so the second title is the one
    # that ends up on the saved figure.
    plot_axes([g_ax], g_obj_pos, 'Blues', "Goal Obj")
    plot_axes([g_ax], g_grip_pos, 'Reds', "Goal Gripper")
    g_ax.set_xlim([1, 1.6]) # obj x axis
    g_ax.set_ylim([0.3, 1.0]) # obj z axis

    fig.savefig(save_path, dpi=100)
  finally:
    # Close the figure so repeated calls do not accumulate open figures.
    plt.close(fig)

def plot_states(save_path, logdir, ep_folder,  ep_subsample):
  """Plot all observed states on the xz plane and save the figure.

  Episodes are loaded from ``<logdir>/<ep_folder>`` and their ``observation``
  arrays concatenated. Columns 0:3 are drawn with the 'Blues' colormap
  ("State Obj", zorder 3) and columns 3:6 with 'Reds' ("State Gripper",
  zorder 2), both on the same axes.

  Args:
    save_path: destination passed to ``savefig``.
    logdir: root log directory; ``~`` is expanded.
    ep_folder: sub-folder of ``logdir`` containing the ``.npz`` episodes.
    ep_subsample: only every ``ep_subsample``-th episode file is loaded.

  Raises:
    ValueError: if no episode could be loaded.
  """
  plot_dims = [[0, 2]]
  directory = pathlib.Path(logdir).expanduser() / ep_folder
  episodes = load_episodes(directory, ep_subsample=ep_subsample)
  if not episodes:
    raise ValueError(f'No episodes could be loaded from {directory}')

  all_observations = np.concatenate([data['observation'] for data in episodes.values()])
  ob_obj_pos = all_observations[:, :3]
  ob_grip_pos = all_observations[:, 3:6]

  plot_dim_name = {0: 'x', 1: 'y', 2: 'z'}

  def plot_axes(axes, data, cmap, title, zorder):
    for ax, pd in zip(axes, plot_dims):
      ax.scatter(x=data[:, pd[0]],
        y=data[:, pd[1]],
        s=1,
        c=np.arange(len(data)),  # color encodes the row index
        cmap=cmap,
        zorder=zorder,
      )
      ax.set_title(f"{title} {plot_dim_name[pd[0]]}{plot_dim_name[pd[1]]}", fontdict={'fontsize':10})

  fig, ob_ax = plt.subplots(1, 1, figsize=(3, 3))
  try:
    # NOTE: both calls title the same axes, so the second title is the one
    # that ends up on the saved figure.
    plot_axes([ob_ax], ob_obj_pos, 'Blues', "State Obj", 3)
    plot_axes([ob_ax], ob_grip_pos, 'Reds', "State Gripper", 2)
    ob_ax.set_xlim([1, 1.6]) # obj/grip x axis
    ob_ax.set_ylim([0.3, 1.0]) # obj/grip z axis

    fig.savefig(save_path, dpi=100)
  finally:
    # Close the figure so repeated calls do not accumulate open figures.
    plt.close(fig)

def plot_goals_and_states(save_path, logdir, ep_folder,  ep_subsample, num_goals):
  """Plot recent goals (left) and observed states (right) on the xz plane.

  Two sets of episodes are loaded from ``<logdir>/<ep_folder>``:
  one with ``ep_capacity=num_goals`` and no subsampling, from which the
  first ``goal`` row of each episode is taken, and one with
  ``ep_subsample`` applied, from which all ``observation`` rows are taken.
  In each panel, columns 0:3 use the 'Blues' colormap and columns 3:6 use
  'Reds'.

  Args:
    save_path: destination passed to ``savefig``.
    logdir: root log directory; ``~`` is expanded.
    ep_folder: sub-folder of ``logdir`` containing the ``.npz`` episodes.
    ep_subsample: subsampling step for the episodes used in the state panel.
    num_goals: forwarded to ``load_episodes`` as ``ep_capacity`` for the
      goal panel.

  Raises:
    ValueError: if no episode could be loaded for either panel.
  """
  plot_dims = [[0, 2]]
  directory = pathlib.Path(logdir).expanduser() / ep_folder
  goal_episodes = load_episodes(directory, ep_capacity=num_goals, ep_subsample=1)
  episodes = load_episodes(directory, ep_subsample=ep_subsample)
  if not episodes or not goal_episodes:
    raise ValueError(f'No episodes could be loaded from {directory}')

  all_observations = np.concatenate([data['observation'] for data in episodes.values()])
  ob_obj_pos = all_observations[:, :3]
  ob_grip_pos = all_observations[:, 3:6]

  # One 1xD row per episode: the first goal entry.
  all_goals = np.concatenate([data['goal'][None, 0] for data in goal_episodes.values()])
  g_obj_pos = all_goals[:, :3]
  g_grip_pos = all_goals[:, 3:6]

  plot_dim_name = {0: 'x', 1: 'y', 2: 'z'}

  def plot_axes(axes, data, cmap, title, zorder):
    for ax, pd in zip(axes, plot_dims):
      ax.scatter(x=data[:, pd[0]],
        y=data[:, pd[1]],
        s=1,
        c=np.arange(len(data)),  # color encodes the row index
        cmap=cmap,
        zorder=zorder,
      )
      ax.set_title(f"{title} {plot_dim_name[pd[0]]}{plot_dim_name[pd[1]]}", fontdict={'fontsize':10})

  fig, (g_ax, ob_ax) = plt.subplots(1, 2, figsize=(7, 3))
  try:
    # NOTE: each panel is titled twice, so the second title of each pair is
    # the one that ends up on the saved figure.
    plot_axes([ob_ax], ob_obj_pos, 'Blues', "Obs Obj", 3)
    plot_axes([ob_ax], ob_grip_pos, 'Reds', "Obs Gripper", 2)

    plot_axes([g_ax], g_obj_pos, 'Blues', "Goal Obj", 3)
    plot_axes([g_ax], g_grip_pos, 'Reds', "Goal Gripper", 3)

    g_ax.set_xlim([1, 1.6]) # obj x axis
    g_ax.set_ylim([0.3, 1.0]) # obj z axis
    ob_ax.set_xlim([1, 1.6]) # obj x axis
    ob_ax.set_ylim([0.3, 1.0]) # obj z axis

    fig.savefig(save_path, dpi=100)
  finally:
    # Close the figure so repeated calls do not accumulate open figures.
    plt.close(fig)

if __name__ == "__main__":
  # Command-line entry point: parse the log location and plotting options,
  # then render one of the plots defined above.
  import argparse

  parser = argparse.ArgumentParser(description='Plot object/gripper states and goals from saved episodes')
  parser.add_argument(
    'logdir',
    type=str,
    help='Path to logdir folder'
  )
  parser.add_argument(
      '--episodes_folder',
      type=str,
      default='train_episodes',
      help='Sub-folder of logdir that holds the .npz episodes'
  )
  parser.add_argument(
      '--save_path',
      default="./pnp_state.png",
      type=str,
      help='File to store the final result'
  )
  parser.add_argument(
      '--ep_subsample',
      type=int,
      default=1,
      help='Load only every N-th episode file'
  )
  parser.add_argument(
      '--plot_goal',
      action='store_true',
      default=False,
      help='Also plot goals (only used by plot_state_frequency)'
  )
  args = parser.parse_args()
  # `fig` and `plot_dims` are only needed by the plot_state_frequency
  # alternative below.
  fig = None
  plot_dims = [[0,1], [0,2], [1,2]] # xy, xz, yz
  # Alternative plots; swap the active call below to use one of them.
  # plot_state_frequency(args.save_path, args.logdir, args.episodes_folder, fig, args.ep_subsample, args.plot_goal, plot_dims)
  # plot_recent_goals(args.save_path, args.logdir, args.episodes_folder, ep_subsample=1, num_goals=10)
  # plot_states(args.save_path, args.logdir, args.episodes_folder, ep_subsample=args.ep_subsample)
  # Honor --ep_subsample (default 1) instead of silently ignoring it.
  plot_goals_and_states(args.save_path, args.logdir, args.episodes_folder, ep_subsample=args.ep_subsample, num_goals=10)