from itertools import count
import torch
import gym
import hydra
import numpy as np
from omegaconf import DictConfig, OmegaConf

from make_envs import make_env
from agent import make_agent
from utils.utils import eval_mode


def get_args(cfg: DictConfig):
    """Set ``cfg.device``, print the whole config as YAML, and return ``cfg``.

    ``cfg.device`` becomes ``"cuda:0"`` when torch reports CUDA as available,
    otherwise ``"cpu"``. The config object is mutated in place and the same
    object is returned.
    """
    if torch.cuda.is_available():
        device_name = "cuda:0"
    else:
        device_name = "cpu"
    cfg.device = device_name

    yaml_dump = OmegaConf.to_yaml(cfg)
    print(yaml_dump)
    return cfg


@hydra.main(config_path="conf", config_name="config")
def main(cfg: DictConfig):
    """Load a saved agent and evaluate it in the configured environment.

    Steps:
      1. ``get_args`` adds ``cfg.device`` and prints the config.
      2. When running on CUDA with ``args.cuda_deterministic`` set, cuDNN is
         switched to deterministic (non-benchmarked) kernels.
      3. The preference vector ``args.agent.preference`` is passed to the
         environment and to the evaluation.
      4. Weights are loaded from ``args.eval.policy``, or from ``'experts'``
         when that is unset, with the suffix ``_<env name>``.
      5. ``evaluate`` runs for ``args.eval.eps`` episodes.

    Args:
      cfg: Hydra/OmegaConf configuration composed from ``conf/config``.
    """
    import sys

    args = get_args(cfg)

    device = torch.device(args.device)
    if device.type == 'cuda' and torch.cuda.is_available() and args.cuda_deterministic:
        # Give up cuDNN autotuning in exchange for reproducible kernels.
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

    # NOTE(review): not used by evaluation. The reads are kept so that config
    # access (and any failure on a missing key) behaves as before.
    BATCH = args.train.batch
    EPISODE_STEPS = int(args.env.eps_steps)

    w = np.array(args.agent.preference)

    env = make_env(args, w, is_mogym=args.env.is_mogym, render=args.env.render)
    agent = make_agent(env, args)

    policy_file = 'experts'
    if args.eval.policy:
        policy_file = f'{args.eval.policy}'
    print(f'Loading policy from: {policy_file}', f'_{args.env.name}')

    # Hydra may change the working directory for the run, so resolve the
    # policy path against the original launch directory.
    agent.load(hydra.utils.to_absolute_path(policy_file), f'_{args.env.name}')

    evaluate(agent, w, env, device, num_episodes=args.eval.eps)

    if args.eval_only:
        # sys.exit rather than the site-provided `exit()` helper, which is
        # meant for the interactive shell and is missing under `python -S`.
        sys.exit()


def evaluate(actor, w, env, device, num_episodes=10):
    """Roll out ``actor`` without sampling and print return statistics.

    For each episode this prints the scalar return and the episode length.
    At the end it prints the mean per-objective return (summed from
    ``info['reward_dim{i}']``) and the mean and std of returns and lengths.

    Args:
      actor: Policy to evaluate. It must provide
        ``choose_action(state, w_tensor, sample=False)``.
      w: Numpy preference array. ``w.shape[0]`` is the number of objectives,
        and therefore the number of ``info['reward_dim{i}']`` entries read
        each step.
      env: Environment with ``reset() -> (obs, info)`` and
        ``step(a) -> (obs, reward, terminated, truncated, info)``.
      device: Torch device that the preference tensor is moved to.
      num_episodes: Number of episodes to average over.

    Returns:
      Tuple ``(average_episode_return, total_number_of_steps)``.
    """
    num_obj = w.shape[0]
    # The preference never changes during evaluation, so build the tensor once
    # instead of once per environment step.
    w_tensor = torch.FloatTensor(w).to(device)

    total_timesteps = []
    total_returns = []
    vec_returns = [[] for _ in range(num_obj)]

    for _ in range(num_episodes):
        r = [0 for _ in range(num_obj)]
        eps_timesteps = 0
        eps_returns = 0

        state, info = env.reset()
        done = False
        while not done:
            with eval_mode(actor):
                action = actor.choose_action(state, w_tensor, sample=False)
            next_state, reward, terminated, truncated, info = env.step(action)

            if 'ale.lives' in info:  # true for breakout, false for pong
                # Per-life terminations are ignored, so the episode runs until
                # all lives are gone. A truncation (e.g. a time limit) must
                # still end it, or the loop would never stop.
                done = truncated or info['ale.lives'] == 0
            else:
                done = terminated or truncated

            eps_returns += reward
            eps_timesteps += 1
            state = next_state
            for i in range(num_obj):
                r[i] += info['reward_dim{}'.format(i)]

        # The while-loop exits only once `done` is set, so every finished
        # episode contributes exactly one entry per objective.
        for i in range(num_obj):
            vec_returns[i].append(r[i])
        print("rewards: {:.2f}".format(eps_returns))
        print("len: {:.2f}".format(eps_timesteps))
        total_timesteps.append(eps_timesteps)
        total_returns.append(eps_returns)

    total_returns = np.array(total_returns)
    total_timesteps = np.array(total_timesteps)

    for i in range(num_obj):
        # vec_returns[i] is a plain list, which has no .mean(), so use np.mean.
        print('returns_dim{}: {}'.format(i, np.mean(vec_returns[i])))
    print("Avg rewards: {:.2f} +/- {:.2f}".format(total_returns.mean(), total_returns.std()))
    print("Avg len: {:.2f} +/- {:.2f}".format(total_timesteps.mean(), total_timesteps.std()))

    return float(total_returns.mean()), int(total_timesteps.sum())


if __name__ == "__main__":
    # Script entry point: main() is wrapped by @hydra.main, which supplies cfg.
    main()
