import numpy as np
import matplotlib.pyplot as plt
import os
import sys
import collections
from collections import defaultdict
# import dreamerv2.api as dv2
# import common
# from common.replay import convert, load_episodes
from envs.sibrivalry.toy_maze import PointMaze2D
from envs.sibrivalry.ant_maze import AntMazeEnvFullDownscale
#import tensorflow as tf
from tqdm import tqdm, trange
import pickle
import imageio
import pathlib


# def plot_state_frequency(ep_folder):
#   x_list = []
#   y_list = []
#   for file in tqdm(os.scandir(ep_folder)):
#     data = np.load(file.path)
#     # import ipdb; ipdb.set_trace()
#     observations = data['observation']
#     goal = data['goal']
#     plt.scatter(observations[:,0], observations[:,1], s = 2, c ="blue", alpha = 0.5)
#     plt.scatter(goal[:,0], goal[:,1], s = 2, c ="red", alpha = 0.5)
#     # import ipdb; ipdb.set_trace()
#     # print('iteration')
#   fig.set_size_inches(8, 6)
#   plt.title("State occupancy grid")
#   plt.savefig('states.png', dpi=200)

def plot_states_wbatch(agnt, complete_episodes, ep_subsample=5, step_subsample=1, batch_size = 50, obs_key = 'observation'):
  """Scatter-plot the first two coordinates of ``obs_key`` over stored episodes.

  Every ``ep_subsample``-th episode of ``complete_episodes`` (and every
  ``step_subsample``-th step of it) is preprocessed with ``agnt.wm.preprocess``.
  Episodes are grouped into batches of ``batch_size``; each batch is stacked,
  encoded with ``agnt.wm.encoder`` and passed through ``agnt.wm.rssm.observe``.
  The preprocessed ``obs_key`` values of all batches are then drawn as blue
  points on the current matplotlib figure.

  NOTE(review): ``convert`` and ``tf`` are not imported in this file (their
  imports at the top are commented out), so calling this function raises
  ``NameError`` until those imports are restored.

  Args:
    agnt: agent exposing ``wm`` with ``preprocess``, ``encoder`` and
      ``rssm.observe``.
    complete_episodes: mapping whose values are episodes (dicts of arrays);
      keys starting with ``log_`` are ignored.
    ep_subsample: keep every n-th episode.
    step_subsample: keep every n-th step within an episode.
    batch_size: number of episodes stacked per forward pass.
    obs_key: key of the preprocessed data to plot.

  Returns:
    The current matplotlib figure.

  Raises:
    ValueError: if no episodes are selected.
  """
  wm = agnt.wm
  selected = list(complete_episodes.values())[::ep_subsample]
  if not selected:
    raise ValueError('plot_states_wbatch: no episodes to plot')
  # Hoisted: the original re-sliced the episode list on every iteration.
  last_idx = len(selected) - 1
  obs = []
  for ep_count, episode in enumerate(selected):
    # Start a new batch every `batch_size` episodes.
    if ep_count % batch_size == 0:
      start = ep_count
      chunk = collections.defaultdict(list)
    sequence = {
      k: convert(v[::step_subsample])
      for k, v in episode.items() if not k.startswith('log_')}
    data = wm.preprocess(sequence)
    for key, value in data.items():
      chunk[key].append(value)
    # Forward-pass the batch once it is full or the episode list is exhausted.
    if ep_count % batch_size == batch_size - 1 or ep_count == last_idx:
      print(f'start: {start} end: {ep_count}')
      # NOTE(review): tf.stack needs equal-length episodes within a batch.
      batch = {k: tf.stack(v) for k, v in chunk.items()}
      embed = wm.encoder(batch)
      # The posterior is not needed for plotting; the pass is kept as before.
      wm.rssm.observe(embed, batch['action'], batch['is_first'], None)
      obs.append(batch[obs_key])
  obs_list = tf.concat(obs, axis=0)
  # Flatten (episodes, steps, dim) -> (episodes*steps, dim). The original
  # hard-coded dim=2 and failed for higher-dimensional observations; this
  # assumes a rank-3 tensor, as the original reshape did.
  obs_list = tf.reshape(obs_list, [-1, obs_list.shape[-1]])
  print('plotting')
  plt.scatter(
    x=obs_list[:, 0],
    y=obs_list[:, 1],
    s=1,
    c='blue',
    zorder=3,
    )
  fig = plt.gcf()
  plt.title('states')
  fig.set_size_inches(8, 6)
  return fig

def load_episodes(directory, capacity=None, minlen=1):
  """Read every ``*.npz`` episode file in ``directory`` into a dict.

  Files are processed in sorted filename order (earliest first), so the
  returned dict preserves temporal order. When ``capacity`` is truthy, only
  the newest files are kept: walking backwards, files are added until the
  step counts encoded in their names (the text between the last ``-`` and
  ``.npz``) sum to at least ``capacity``. Files that fail to load are
  reported with ``print`` and skipped.

  Args:
    directory: ``pathlib.Path`` to search for ``*.npz`` files.
    capacity: optional step budget limiting how many recent files are loaded.
    minlen: currently unused.

  Returns:
    Dict mapping each file path (as ``str``) to a dict of its arrays.
  """
  paths = sorted(directory.glob('*.npz'))  # oldest first
  if capacity:
    kept = 0
    total = 0
    for path in paths[::-1]:  # newest first
      total += int(str(path).split('-')[-1][:-4])
      kept += 1
      if total >= capacity:
        break
    paths = paths[-kept:]
  loaded = {}
  for path in paths:
    try:
      with path.open('rb') as handle:
        npz = np.load(handle)
        # Materialize all arrays while the file handle is still open.
        arrays = {name: npz[name] for name in npz.keys()}
    except Exception as e:
      print(f'Could not load episode {str(path)}: {e}')
      continue
    loaded[str(path)] = arrays
  return loaded

def _new_maze_axes(maze):
  """Return a new 8x6 ``(fig, ax)`` with ``maze``'s walls drawn and limits [-1, 11]."""
  fig, ax = plt.subplots(1, 1, figsize=(8, 6))
  maze.maze.plot(ax)  # plot the walls
  ax.set(xlim=(-1, 11), ylim=(-1, 11))
  return fig, ax


def _figure_to_rgb(fig):
  """Render ``fig`` and return a copy of its pixels as an (H, W, 3) uint8 array."""
  fig.canvas.draw()
  # buffer_rgba() replaces tostring_rgb(), which was deprecated in
  # Matplotlib 3.8 and has since been removed. Its shape is the true pixel
  # buffer size, unlike get_width_height() on HiDPI canvases. Drop alpha.
  return np.array(fig.canvas.buffer_rgba(), dtype=np.uint8)[..., :3]


def _render_batch_frame(maze, chunk, title):
  """Draw one batch (blue states, red goals) over the maze and return the frame."""
  data = {k: np.concatenate(v, 0) for k, v in chunk.items()}
  fig, ax = _new_maze_axes(maze)
  ax.scatter(data['obs'][:, 0], data['obs'][:, 1], s=2, c="blue", alpha=0.75)
  ax.scatter(data['goal'][:, 0], data['goal'][:, 1], s=2, c="red", alpha=0.85,
             facecolors='none', edgecolors='r')
  ax.set_title(title)
  fig.set_size_inches(8, 6)
  frame = _figure_to_rgb(fig)
  # Close each frame's figure; the original left one open figure per frame.
  plt.close(fig)
  return frame


def plot_state_frequency(save_path, logdir, ep_folder, fig, ep_len=51, batch_size=1, max_eps=100):
  """Animate visited states and goals from a cached pickle over the point maze.

  Loads ``(all_observations, all_goals)`` from ``<logdir>/maze_plt_data.pkl``
  and slices the observations into consecutive windows of ``ep_len`` rows,
  treating each window as one episode whose goal is the goal at the window's
  first row. Every ``batch_size`` episodes are scattered onto a fresh
  ``PointMaze2D`` wall plot. Each plot becomes one frame written to
  ``save_path`` with ``imageio.mimwrite``. Processing stops once more than
  ``max_eps`` episodes have been consumed.

  Args:
    save_path: output path for the animation (e.g. ``.mp4`` or ``.gif``).
    logdir: directory containing ``maze_plt_data.pkl``.
    ep_folder: unused; kept for backward compatibility.
    fig: unused; a new figure is created for every frame.
    ep_len: rows per episode window (default 51, as before).
    batch_size: episodes drawn per frame (default 1, as before).
    max_eps: stop after more than this many episodes (default 100, as before).
  """
  # TODO: fix this to respect temporal order.
  pkl_path = os.path.join(logdir, "maze_plt_data.pkl")
  # NOTE: pickle.load can execute arbitrary code; only use trusted,
  # locally generated logdirs.
  with open(pkl_path, "rb") as f:
    all_observations, all_goals = pickle.load(f)

  maze = PointMaze2D()
  frames = []
  chunk = defaultdict(list)
  num_eps = 0    # episodes in the current, not yet drawn batch
  total_eps = 0  # episodes consumed so far
  for start in trange(0, len(all_observations), ep_len):
    num_eps += 1
    total_eps += 1
    chunk['obs'].append(all_observations[start:start + ep_len])
    chunk['goal'].append(all_goals[start][None])
    if num_eps >= batch_size:
      frames.append(_render_batch_frame(maze, chunk, f"Ep {total_eps - num_eps} - {total_eps}"))
      chunk = defaultdict(list)
      num_eps = 0
    if total_eps > max_eps:
      break
  # Draw a trailing partial batch instead of dropping it (only possible
  # when batch_size > 1).
  if num_eps:
    frames.append(_render_batch_frame(maze, chunk, f"Ep {total_eps - num_eps} - {total_eps}"))

  imageio.mimwrite(save_path, frames)

def find_umaze_downscale_count(directory, x_min=-1, y_min=-1, x_max=3.25, y_max=1, capacity=None, save_path='top_left_state.png'):
  """Count visits to an axis-aligned box in stored episodes and plot running totals.

  Episodes are loaded with ``load_episodes`` (optionally limited by
  ``capacity``) and walked newest-first. In each episode, every observation
  whose first two coordinates lie inside ``[x_min, x_max] x [y_min, y_max]``
  (bounds inclusive) adds one to a running state count. The episode itself
  adds one to a running episode count if it has at least one such
  observation. Both running totals are plotted against the number of
  episodes processed, on twin y-axes, in a new figure saved to ``save_path``.

  Args:
    directory: episode directory (``~`` is expanded).
    x_min, y_min, x_max, y_max: inclusive box bounds.
    capacity: optional step budget forwarded to ``load_episodes``.
    save_path: output image path (default keeps the previous file name).

  Returns:
    Total number of in-box observations over all loaded episodes.
  """
  directory = pathlib.Path(directory).expanduser()
  episodes = list(load_episodes(directory, capacity).values())
  count_list = []
  ep_in_list = []
  count = 0
  ep_in = 0
  for ep in reversed(episodes):
    obs = np.asarray(ep['observation'])
    n_inside = 0
    if len(obs):
      x, y = obs[:, 0], obs[:, 1]
      n_inside = int(np.count_nonzero(
        (x_min <= x) & (x <= x_max) & (y_min <= y) & (y <= y_max)))
    count += n_inside
    if n_inside:
      ep_in += 1
    count_list.append(count)
    ep_in_list.append(ep_in)
  # Use a fresh figure so repeated calls do not draw on top of each other.
  fig, ax1 = plt.subplots()
  l1, = ax1.plot(count_list, color='red')
  ax2 = ax1.twinx()
  l2, = ax2.plot(ep_in_list, color='orange')
  ax2.legend([l1, l2], ["state count", "episodes count"])
  # Label ax1: plt.xlabel() targeted the twin axes, whose x-axis is hidden,
  # so the label never appeared.
  ax1.set_xlabel('episodes')
  fig.savefig(save_path)
  return count








  # plt.scatter(all_observations[:,0], all_observations[:,1], s = 2, c ="blue", alpha = 0.75)
  # plt.scatter(all_goals[:,0], all_goals[:,1], s = 2, c ="red", alpha = 0.85, facecolors='none', edgecolors='r')
  # with open(pkl_path, "wb") as f:
  #   pickle.dump([all_observations, all_goals], f)
  # fig.set_size_inches(8, 6)
  # plt.title("State occupancy grid")
  # plt.savefig('scatterplot.png', dpi=200)


if __name__ == "__main__":
  # Script entry point: count how often training episodes of the downscaled
  # U-maze run visited the maze's top-left region.
  episodes_folder = '/home/anonymous/logdir/umazefulldownscale_ar1_dyndist_sgp_32_32_mlp_eval_env_20/train_episodes'
  top_left_box = {'x_min': -1, 'y_min': 3.4, 'x_max': 1.5, 'y_max': 5}
  t_left = find_umaze_downscale_count(episodes_folder, **top_left_box)

  # ---- Disabled alternatives, kept for reference ----
  # Bottom-left region and ratio:
  # b_left = find_umaze_downscale_count(episodes_folder, capacity=50000)
  # print(b_left / t_left)
  #
  # Other run directories:
  # episodes_folder = pathlib.Path(episodes_folder)
  # logdir = '/home/anonymous/logdir/mega_pointmaze_10p2e_1ar_dyndist_2pol_goalmppi'
  # logdir = '/home/anonymous/logdir/mega_pointmaze_10p2e_1ar_dyndist_2pol_goalcem'
  # logdir = '/home/anonymous/logdir/mega_pointmaze_10p2e_1ar_dyndist_2pol_goalmega'
  # episodes_folder = 'train_episodes'
  #
  # Render ant-maze episodes (goal image stacked above each frame) to video:
  # maze = PointMaze2D()
  # fig, ax = plt.subplots(1, 1, figsize=(1, 1))
  # maze.maze.plot(ax) # plot the walls
  # episodes = load_episodes(episodes_folder, capacity=50000)
  # episodes = list(episodes.values())
  # env = AntMazeEnvFullDownscale()
  # i = 0
  # for ep in episodes:
  #   all_img = []
  #   executions = []
  #   all_goal = []
  #   goal = ep['goal']
  #   env.maze.wrapped_env.set_state(goal[0][:15], np.zeros_like(goal[0][:14]))
  #   env.maze.wrapped_env.sim.forward()
  #   goal_img = env.render(mode = 'rgb_array')
  #   for obs in ep['observation']:
  #     env.maze.wrapped_env.set_state(obs[:15], np.zeros_like(obs[:14]))
  #     env.maze.wrapped_env.sim.forward()
  #     img = env.render(mode = 'rgb_array')
  #     whole_img = np.vstack((goal_img, img))
  #     all_img.append(whole_img)
  #   imageio.mimwrite(os.path.join('/home/anonymous/projects',f'umazefulldownscale_ar1_dyndist_sgp_32_32_mlp_eval_env_20_wgoal_{i}.mp4'), all_img, macro_block_size = 1)
  #   i = i + 1
  #
  # State-frequency animation:
  # fig = None
  # save_path = 'mppi.mp4'
  # plot_state_frequency(save_path, logdir, episodes_folder, fig)
  # plt.show()