import argparse
import tensorflow as tf
from tensorflow.python.summary.summary_iterator import summary_iterator
import imageio
import cv2
import numpy as np
from tqdm import tqdm
import pathlib
import dreamerv2.api as dv2
import common
import collections

# Command-line interface. Only `logdir` is currently used; the commented-out
# options below belonged to the (now commented-out) GIF-export code later in
# this file and are kept for reference.
parser = argparse.ArgumentParser(
  description='Log per-object height statistics from saved training '
              'episodes to TensorBoard'
)
parser.add_argument(
  'logdir',
  type=str,
  help='Run directory: episodes are read from <logdir>/train_episodes/*.npz '
       'and TensorBoard summaries are written to <logdir>'
)
# parser.add_argument(
#     '--method',
#     type=str,
# )
# parser.add_argument(
#     '--output',
#     default="./out.mp4",
#     type=str,
#     help='File to store the final result'
# )
# parser.add_argument(
#     '--start',
#     type=int,
#     help='First image in the gif (corresponds to tensorboard step)'
# )
# parser.add_argument(
#     '--stop',
#     type=int,
#     help='Last image in the gif (corresponds to tensorboard step)'
# )

args = parser.parse_args()
logdir = args.logdir

# Summaries are written through dreamerv2's TensorBoard writer, rooted at the
# same run directory the episodes are read from.
tb_output = common.TensorBoardOutput(logdir)

ep_folder = "train_episodes"
directory = pathlib.Path(logdir).expanduser() / ep_folder
# Sorted file names are treated as temporal order, earliest first
# (presumably the names start with a timestamp — confirm against the writer).
filenames = sorted(directory.glob('*.npz'))
# if ep_capacity:
#   filenames = filenames[-ep_capacity:]
ep_subsample = 1  # load/log only every ep_subsample-th episode
num_obj = 2  # number of equal-width object blocks after observation column 5
num_steps = 0  # cumulative episode length; used as the TensorBoard step
# Running per-object count of steps above z_threshold, over logged episodes.
all_total_obj_above = collections.defaultdict(float)
# total_obj_above = 0
z_threshold = 0.5
for ep_index, filename in enumerate(tqdm(filenames)):
  # The episode length is the last '-'-separated field of the file stem,
  # e.g. '<...>-<length>.npz'. Parse it from the name alone, not the full
  # path, so dashes in parent directories cannot interfere.
  try:
    length = int(filename.stem.rsplit('-', 1)[-1])
  except ValueError as e:
    print(f'Could not load episode {str(filename)}: {e}')
    continue
  # Advance the step counter for EVERY episode, including those skipped by
  # subsampling, so the x-axis stays in cumulative environment steps.
  num_steps += length
  if ep_index % ep_subsample:
    continue

  # Only loading/parsing is best-effort: a corrupt or malformed episode is
  # reported and skipped. Errors while logging are not masked.
  try:
    with np.load(filename) as episode:
      obs = episode["observation"]
    # NOTE(review): assumes observation columns 5: hold num_obj equal-width
    # per-object blocks whose last column is the object's z (height) — confirm.
    all_obj_pos = np.split(obs[:, 5:], num_obj, axis=1)
  except Exception as e:
    print(f'Could not load episode {str(filename)}: {e}')
    continue

  summary = []
  for idx, obj_pos in enumerate(all_obj_pos):
    # Number of timesteps in this episode where the object is above threshold.
    obj_above = np.sum(obj_pos[:, -1] > z_threshold).astype(np.float32)
    all_total_obj_above[idx] += obj_above
    summary.append((num_steps, f"obj{idx}_z_above_{z_threshold:.2f}", obj_above))
    summary.append((num_steps, f"total_obj{idx}_z_above_{z_threshold:.2f}", np.array(all_total_obj_above[idx])))
  tb_output(summary)


# all_imgs = []
# color = (255, 0, 0)
# thickness = 1
# for e in summary_iterator(args.filename):
#   step = e.step
#   if step < args.start:
#     continue
#   if step > args.stop:
#     break

#   for v in e.summary.value:
#     if v.tag == 'state_occupancy':
#       step = e.step
#       img_str = v.image.encoded_image_string
#       height = v.image.height
#       width = v.image.width
#       img = tf.squeeze(tf.io.decode_image(img_str, channels=3))
#       img = img.numpy()
#       cv2.putText(
#             img,
#             f"{step}",
#             (16, 32),
#             cv2.FONT_HERSHEY_SIMPLEX,
#             1.5,
#             color,
#             thickness,
#             cv2.LINE_AA,
#       )
#       cv2.putText(
#             img,
#             f"{args.method}",
#             (350, 32),
#             cv2.FONT_HERSHEY_SIMPLEX,
#             1.5,
#             color,
#             thickness,
#             cv2.LINE_AA,
#       )
#       all_imgs.append(img)
# # make gif
# print("writing", len(all_imgs))
# imageio.mimwrite(args.output, all_imgs)


