import collections
from functools import partial as bind

import elements
import embodied
import numpy as np
import cv2
import os

def _save_reconstruction(root, ep, tran):
  """Dump one transition's observation, reconstruction and mask to disk.

  Writes ``{root}/obs_img/{ep}/{idx}.png``, ``{root}/recon_img/{ep}/{idx}.png``
  and ``{root}/prev_mask/{ep}/{idx}.npy``, where ``idx`` is
  ``tran['log/obs_env_time']``.

  Args:
    root: Output root directory.
    ep: Episode index, used as the sub-folder name.
    tran: Transition dict; reads 'image', 'recon', 'prev_mask' and
      'log/obs_env_time'.
  """
  idx = tran['log/obs_env_time']
  images = {'obs_img': tran['image'], 'recon_img': tran['recon']}
  for name, img in images.items():
    # Rotate by 180 degrees, then convert RGB -> BGR because cv2.imwrite
    # interprets 3-channel arrays as BGR.
    img = cv2.cvtColor(cv2.rotate(img, cv2.ROTATE_180), cv2.COLOR_RGB2BGR)
    os.makedirs(f"{root}/{name}/{ep}", exist_ok=True)
    cv2.imwrite(f"{root}/{name}/{ep}/{idx}.png", img)
  os.makedirs(f"{root}/prev_mask/{ep}", exist_ok=True)
  np.save(f"{root}/prev_mask/{ep}/{idx}.npy", tran.get('prev_mask'))


def delayedeval(
    make_agent,
    make_env,
    make_logger,
    args):
  """Evaluate a checkpointed agent through the driver's delayed-call loop.

  Loads the agent weights from ``args.from_checkpoint`` and steps
  ``args.envs`` environments with ``driver.delayed_call`` until the logger
  step reaches ``args.steps``. Episode statistics, resource usage, policy
  FPS and timer summaries are written every ``args.log_every``. When
  ``args.delay.visualize_reconstruction`` is set, the observation,
  reconstruction and ``prev_mask`` of every step of worker 0 are dumped
  under ``args.logdir``.

  Args:
    make_agent: Zero-argument factory for the agent.
    make_env: Environment factory, called with the environment index.
    make_logger: Zero-argument factory for the logger; its ``step`` counter
      is incremented once per environment step and bounds the loop.
    args: Config providing ``logdir``, ``usage``, ``log_every``, ``envs``,
      ``debug``, ``from_checkpoint``, ``steps`` and
      ``delay.visualize_reconstruction``.
  """
  agent = make_agent()
  logger = make_logger()

  logdir = elements.Path(args.logdir)
  logdir.mkdir()
  print('Logdir', logdir)
  step = logger.step
  usage = elements.Usage(**args.usage)
  agg = elements.Agg()
  epstats = elements.Agg()
  episodes = collections.defaultdict(elements.Agg)
  should_log = elements.when.Clock(args.log_every)
  policy_fps = elements.FPS()
  # Finished-episode count per worker; names the reconstruction dump folders.
  finished = collections.defaultdict(int)

  @elements.timer.section('logfn')
  def logfn(tran, worker):
    episode = episodes[worker]
    if tran['is_first']:
      episode.reset()
    episode.add('score', tran['reward'], agg='sum')
    episode.add('length', 1, agg='sum')
    episode.add('rewards', tran['reward'], agg='stack')
    episode.add('success', tran.get('log/success', 0), agg='max')
    # Only worker 0 is dumped (like the stacked policy images below) so that
    # parallel environments never write to the same file paths. The dump
    # happens before the is_last increment, so the final frame of an
    # episode lands in that episode's folder.
    if args.delay.visualize_reconstruction and worker == 0:
      _save_reconstruction(args.logdir, finished[worker], tran)

    for key, value in tran.items():
      isimage = (value.dtype == np.uint8) and (value.ndim == 3)
      if isimage and worker == 0:
        episode.add(f'policy_{key}', value, agg='stack')
      elif key.startswith('log/'):
        assert value.ndim == 0, (key, value.shape, value.dtype)
        episode.add(key + '/avg', value, agg='avg')
        episode.add(key + '/max', value, agg='max')
        episode.add(key + '/sum', value, agg='sum')
    if tran['is_last']:
      finished[worker] += 1
      result = episode.result()
      logger.add({
          'score': result.pop('score'),
          'length': result.pop('length'),
          'success': result.pop('success'),
      }, prefix='episode')
      rew = result.pop('rewards')
      if len(rew) > 1:
        # Fraction of consecutive steps whose reward changed by >= 0.01.
        result['reward_rate'] = (np.abs(rew[1:] - rew[:-1]) >= 0.01).mean()
      epstats.add(result)

  fns = [bind(make_env, i) for i in range(args.envs)]
  driver = embodied.Driver(fns, parallel=(not args.debug))
  driver.on_step(lambda tran, _: step.increment())
  driver.on_step(lambda tran, _: policy_fps.step())
  driver.on_step(logfn)

  cp = elements.Checkpoint()
  cp.agent = agent
  cp.load(args.from_checkpoint, keys=['agent'])

  print('Start evaluation')
  policy = bind(agent.policy_delayed, report_recon=True, mode='eval')

  driver.reset(agent.init_policy)
  try:
    while step < args.steps:
      driver.delayed_call(agent.imagine, agent.observe, policy, steps=10)
      if should_log(step):
        logger.add(agg.result())
        logger.add(epstats.result(), prefix='epstats')
        logger.add(usage.stats(), prefix='usage')
        logger.add({'fps/policy': policy_fps.result()})
        logger.add({'timer': elements.timer.stats()['summary']})
        logger.write()
  finally:
    # Close the logger even if evaluation fails part-way.
    logger.close()
