import pathlib
import sys
from functools import partial as bind

sys.path.append(str(pathlib.Path(__file__).parent.parent.parent))

import elements
import embodied

from pprint import pprint

class TestDriver:
  """Throughput benchmarks for `embodied.Driver` driven by a random agent.

  Each `test_throughput_*` method steps a batch of 32 environments with
  `embodied.RandomAgent` in chunks of 100 steps and prints the running frames
  per second after every chunk. With the default `iterations=None` the loop
  runs until interrupted (e.g. Ctrl-C); pass an integer to stop after that
  many driver calls.

  NOTE(review): pytest does not collect a `Test*` class that defines
  `__init__`, so these benchmarks are presumably run by hand (see the
  `__main__` block) rather than by the test runner.
  """

  def __init__(self):
    # Resource-usage tracker printed by the crafter and unknot benchmarks
    # (`psutil=True` presumably enables psutil-based process stats).
    self.usage = elements.Usage(psutil=True)

  def test_throughput_dummy(self, parallel=True, *, iterations=None):
    """Benchmark `Dummy('disc')` environments, printing FPS only.

    Args:
      parallel: Forwarded to `embodied.Driver`.
      iterations: Number of 100-step driver calls; `None` loops forever.
    """
    from embodied.envs import dummy
    self._benchmark(
        bind(dummy.Dummy, 'disc'), parallel, iterations, report_usage=False)

  def test_throughput_crafter(self, parallel=True, *, iterations=None):
    """Benchmark `Crafter('reward')` environments, printing FPS and usage.

    Args:
      parallel: Forwarded to `embodied.Driver`.
      iterations: Number of 100-step driver calls; `None` loops forever.
    """
    from embodied.envs import crafter
    self._benchmark(
        bind(crafter.Crafter, 'reward'), parallel, iterations,
        report_usage=True)

  def test_throughput_unknot(self, parallel=True, *, iterations=None):
    """Benchmark 128x128 `flat_pixel` `Unknot` environments (mjc backend).

    Prints FPS and usage stats after every driver call.

    Args:
      parallel: Forwarded to `embodied.Driver`.
      iterations: Number of 100-step driver calls; `None` loops forever.
    """
    from embodied.envs.knot import Unknot
    make_env = bind(
        Unknot, task="flat_pixel", size=(128, 128), backend="mjc")
    self._benchmark(make_env, parallel, iterations, report_usage=True)

  def _benchmark(
      self, make_env, parallel, iterations, report_usage,
      num_envs=32, steps=100):
    """Drive `num_envs` copies of `make_env` with a random agent; print FPS.

    Args:
      make_env: Zero-argument callable that constructs one environment.
      parallel: Forwarded to `embodied.Driver`.
      iterations: Number of driver calls to make; `None` loops forever.
      report_usage: Also pretty-print `self.usage.stats()` after each call.
      num_envs: How many environment constructors to hand to the driver.
      steps: `steps` argument of each driver call; every call is credited
        with `steps * num_envs` frames.
    """
    make_env_fns = [make_env] * num_envs
    # Build one throwaway env only to read the spaces the agent needs, and
    # release it even if agent construction fails.
    example = make_env()
    try:
      agent = embodied.RandomAgent(example.obs_space, example.act_space)
    finally:
      example.close()
    driver = embodied.Driver(make_env_fns, parallel)
    # Close the driver on exit (including Ctrl-C) so parallel env workers
    # are not left behind.
    try:
      driver.reset(agent.init_policy)
      fps = elements.FPS()
      done = 0
      while iterations is None or done < iterations:
        driver(agent.policy, steps=steps)
        fps.step(steps * num_envs)
        print(f'FPS: {fps.result():.0f}')
        if report_usage:
          pprint(self.usage.stats())
        done += 1
    finally:
      driver.close()


if __name__ == '__main__':
  # Every benchmark loops until interrupted, so only one can run per
  # invocation; choose it on the command line, e.g.
  #   python <this script> crafter parallel
  # With no arguments the serial dummy benchmark runs.
  # Throughputs previously recorded in this file:
  #   dummy    serial 470_000 FPS, parallel 200_000 FPS
  #   crafter  serial  16_000 FPS, parallel  50_000 FPS (high variance)
  #   unknot   serial   3_000 FPS, parallel  17_000 FPS
  benchmarks = {
      'dummy': TestDriver.test_throughput_dummy,
      'crafter': TestDriver.test_throughput_crafter,
      'unknot': TestDriver.test_throughput_unknot,
  }
  name = sys.argv[1] if len(sys.argv) > 1 else 'dummy'
  mode = sys.argv[2] if len(sys.argv) > 2 else 'serial'
  if name not in benchmarks or mode not in ('serial', 'parallel'):
    sys.exit(
        f'usage: {sys.argv[0]} [{"|".join(benchmarks)}] [serial|parallel]')
  benchmarks[name](TestDriver(), parallel=(mode == 'parallel'))
