import os
import json


class Logger:
	"""Periodically renders planning samples and writes a final rollout summary.

	All output files are placed under `logpath`. The `renderer` object must
	provide `composite(path, observations, conditions)` and
	`render_rollout(path, rollout, conditions, fps=...)`.
	"""

	def __init__(self, renderer, logpath, vis_freq=10, max_render=8):
		"""Store the renderer and logging configuration.

		Args:
			renderer: object used to draw plans and rollouts.
			logpath: directory that receives every file written by this logger.
			vis_freq: `log` only renders on steps that are multiples of this value.
			max_render: cap on the number of samples for plan videos; currently
				referenced only by the disabled `render_plan` call in `log`.
		"""
		self.renderer = renderer
		self.savepath = logpath
		self.vis_freq = vis_freq
		self.max_render = max_render

	def log(self, t, samples, state, rollout=None, conditions=None):
		"""Render visualizations for step `t` when `t` is a multiple of `vis_freq`.

		Args:
			t: current step index; also used as the file name of the plan image.
			samples: object exposing an `observations` attribute to be rendered.
			state: unused at present (only the disabled `render_plan` call needs it).
			rollout: if not None, forwarded to `renderer.render_rollout`.
			conditions: forwarded unchanged to the renderer calls.
		"""
		if t % self.vis_freq != 0:
			return

		## render image of plans
		self.renderer.composite(
			os.path.join(self.savepath, f'{t}.png'),
			samples.observations,
			conditions
		)

		## render video of plans (disabled)
		# self.renderer.render_plan(
		#     os.path.join(self.savepath, f'{t}_plan.mp4'),
		#     samples.actions[:self.max_render],
		#     samples.observations[:self.max_render],
		#     state,
		#     conditions,
		# )

		if rollout is not None:
			## render rollout thus far; the same file is overwritten on every call
			## NOTE(review): the path ends in .png although an fps is passed --
			## confirm the output format expected by the renderer
			self.renderer.render_rollout(
				os.path.join(self.savepath, 'rollout.png'),
				rollout,
				conditions,
				fps=80,
			)

	def finish(self, t, score, total_reward, terminal, diffusion_experiment, value_experiment):
		"""Write the final rollout summary to `<savepath>/rollout.json`.

		Args:
			t: final step index, stored under 'step'.
			score: stored under 'score'.
			total_reward: stored under 'return'.
			terminal: stored under 'term'.
			diffusion_experiment: object whose `epoch` is stored under 'epoch_diffusion'.
			value_experiment: object whose `epoch` is stored under 'epoch_value'.
		"""
		json_path = os.path.join(self.savepath, 'rollout.json')
		json_data = {
			'score': score,
			'step': t,
			'return': total_reward,
			'term': terminal,
			'epoch_diffusion': diffusion_experiment.epoch,
			'epoch_value': value_experiment.epoch,
		}
		## context manager guarantees the file is flushed and closed,
		## even if serialization raises
		with open(json_path, 'w') as json_file:
			json.dump(json_data, json_file, indent=2, sort_keys=True)
		print(f'[ utils/logger ] Saved log to {json_path}')
