import argparse
from collections import deque
import itertools
import time

import numpy as np

import agents
import cmdline
from tasks import ControlTask
from utils import jax_device_context


def main(**kwargs):  # Hook for automation
    """Resolve configuration and launch ``run`` on the selected JAX device.

    Missing options are filled in by ``cmdline.insert_defaults`` and then
    checked by ``cmdline.assert_not_none`` (presumably rejecting any option
    that is still None — semantics live in ``cmdline``). The ``cpu`` flag is
    removed from the options and passed to ``jax_device_context``.
    Everything else is forwarded to ``run`` as keyword arguments.

    Returns:
        Whatever ``run`` returns.
    """
    resolved = cmdline.insert_defaults(kwargs)
    cmdline.assert_not_none(resolved)

    # `cpu` selects the device context; it is not an argument of `run`.
    use_cpu = resolved.pop('cpu', False)
    with jax_device_context(use_cpu):
        # Return from inside the context, exactly like a direct `return run(...)`.
        return run(**resolved)


def run(env: str, agent: str, discount: float, duration: float, seed: int, verbose: bool = False, **agent_kwargs):
    """Train an agent on a ``ControlTask`` and return its learning curve.

    Args:
        env: Environment name, forwarded to ``ControlTask``.
        agent: Name of a class/factory in the ``agents`` module. It is
            instantiated as ``cls(observation_space, action_space, seed,
            task.discount, **agent_kwargs)``.
        discount: Discount factor, forwarded to ``ControlTask``.
        duration: Number of timesteps; truncated with ``int`` and asserted > 0.
        seed: Seed passed to both the task and the agent.
        verbose: If True, print a progress line every 500 curve entries.
        **agent_kwargs: Extra keyword arguments for the agent constructor.

    Returns:
        A float array of length ``task.duration + 1``. Entries are written only
        when an episode ends. Every not-yet-written index up to the current
        timestep receives the mean undiscounted return of the (up to) 100 most
        recently finished episodes.

    NOTE(review): the function returns only when an episode ends at or past
    timestep ``task.duration``. It presumably relies on ``ControlTask``
    eventually terminating/truncating episodes; that is not visible here.
    """
    duration = int(duration)
    assert duration > 0

    task = ControlTask(env, discount, duration, seed)

    # Bind the constructed agent to its own name rather than reusing the
    # `agent` string parameter.
    learner_cls = getattr(agents, agent)
    learner = learner_cls(
        task.env.observation_space,
        task.env.action_space,
        seed,
        task.discount,
        **agent_kwargs,
    )

    # Wall-clock bookkeeping for the unconditional progress message below.
    periods_logged = 0
    last_log_time = time.time()
    log_interval_minutes = 15

    obs, _ = task.reset()

    curve = np.zeros(task.duration + 1)
    filled = 0  # next curve index that has not been written yet
    episodes = 0
    recent_returns = deque(maxlen=100)  # sliding window of episode returns

    for step in itertools.count(start=1):
        action, b_prob = learner.act(obs)
        next_obs, reward, terminated, truncated, _ = task.step(action)
        # b_prob is passed straight back to the agent (presumably a
        # behaviour-policy probability — defined by the agent classes).
        learner.reinforce(obs, action, next_obs, reward, terminated, truncated, b_prob)

        if terminated or truncated:
            episodes += 1
            recent_returns.append(task.undiscounted_return)
            moving_avg = np.mean(recent_returns)

            # Back-fill all pending timesteps (up to and including this one)
            # with the latest moving average.
            while filled <= step:
                curve[filled] = moving_avg

                if verbose and filled % 500 == 0:
                    print(f"{task.time():.2f}s  t={filled}  ep={episodes}  {task.undiscounted_return} (avg: {moving_avg:.2f})")

                if filled == task.duration:
                    return curve
                filled += 1

            next_obs, _ = task.reset()

        obs = next_obs

        # Periodic logging even when verbose=False (for time estimation on Compute Canada)
        elapsed_minutes = (time.time() - last_log_time) / 60
        if elapsed_minutes < log_interval_minutes:
            continue
        periods_logged += 1
        percent_complete = round(100 * step / task.duration, 1)
        print(f"Approximately {round(periods_logged * log_interval_minutes, 1)} minutes elapsed; "
              f"{step}/{task.duration} timesteps ({percent_complete}%) completed")
        last_log_time = time.time()


if __name__ == '__main__':
    # Command-line entry point: parse options and hand them to `main`.
    cli = argparse.ArgumentParser()
    cli.add_argument('--cpu', action='store_true')
    # Plain typed options, registered in their original order so --help is unchanged.
    for flag, value_type in (
        ('--env', str),
        ('--agent', str),
        ('--defaults', str),
        ('--discount', float),
        ('--duration', float),
    ):
        cli.add_argument(flag, type=value_type)
    cli.add_argument('--seed', type=int, default=0)
    cli.add_argument('--verbose', action='store_true')
    main(**cmdline.parse_kwargs(cli))
