import algo.net as net
from algo.agent import Policy, Agent
from algo.env_util import make_vec_env
from algo.util import LinearSchedule

import numpy as np
import torch
import wandb
import argparse
import yaml
import os
import random


# Offset added to the run seed to derive the seed for the final evaluation.
MAGIC = 0xFFFF


def get_args():
    """Parse the command-line options of the training script.

    Returns:
        argparse.Namespace: one attribute per option listed below, filled
        with the option's default when it is not given on the command line.
    """
    quantizers = ['vq', 'gumbel_hard', 'gumbel_soft', 'exact']
    backups = ['naive', 'dae', 'offdae', 'tree']

    def switch():
        # Boolean flag: off unless the option is present.
        return {'default': False, 'action': 'store_true'}

    def typed(kind, default, **extra):
        # Regular valued option with a converter and a default.
        return {'type': kind, 'default': default, **extra}

    # (flag, add_argument keyword arguments), in registration order.
    option_table = [
        ('--env', typed(str, 'Breakout-MinAtar-v0')),
        ('--num_envs', typed(int, 1)),
        ('--buffer_size', typed(int, 1000000)),
        ('--learning_rate', typed(float, 2.5e-4)),
        ('--learning_rate_decay', switch()),
        ('--learning_rate_model', typed(float, 2.5e-4)),
        ('--adam_eps', typed(float, 1e-5)),
        ('--initial_steps', typed(int, 1000)),
        ('--updates_per_step', typed(float, 0.03125)),
        ('--batch_size', typed(int, 128)),
        ('--n_step', typed(int, 16)),
        ('--gamma', typed(float, 0.99)),
        ('--backup', typed(str, 'naive', choices=backups)),
        ('--max_grad_norm', typed(float, -1.)),
        ('--steps', typed(int, 100000)),
        ('--net', typed(str, 'MinAtarCNN')),
        ('--target_update_steps', typed(int, 1)),
        ('--target_update_tau', typed(float, 0.995)),
        ('--target_bootstrap', switch()),
        ('--beta_kl', typed(float, 10.)),
        ('--quantizer', typed(str, 'gumbel_soft', choices=quantizers)),
        ('--beta_entropy_model', typed(float, 1e-4)),
        ('--beta_commit', typed(float, 1.)),
        ('--z_dim', typed(int, 16)),
        ('--gumbel_temp_final', typed(float, 0.1)),
        ('--gumbel_temp_steps', typed(int, 1000000)),
        ('--evaluate_episodes', typed(int, 100)),
        ('--steps_per_eval', typed(int, -1)),
        ('--config', typed(str, None)),
        ('--hparam_search', switch()),
        ('--seed', typed(int, 0)),
        ('--logging', switch()),
        ('--project', typed(str, 'test')),
    ]

    parser = argparse.ArgumentParser()
    for flag, options in option_table:
        parser.add_argument(flag, **options)
    return parser.parse_args()


def initialize_wandb(args):
    """Start a wandb run and register the parsed arguments as config defaults.

    Args:
        args: Parsed command-line namespace; `hparam_search`, `config` and
            `project` are read here, and the whole namespace is handed to
            `wandb.config.setdefaults`.
    """
    if args.hparam_search:
        # No explicit config/project in this mode -- presumably the sweep
        # agent supplies them; confirm against the sweep setup.
        init_kwargs = {}
    else:
        init_kwargs = {'config': args.config, 'project': args.project}
    wandb.init(**init_kwargs)

    wandb.config.setdefaults(args)


if __name__ == "__main__":
    # Script entry point: resolve the configuration, build the environments,
    # policy and agent, train, run a final evaluation and report the score.

    args = get_args()

    if args.logging:
        # With logging on, hyperparameters are read from wandb.config;
        # initialize_wandb registers the CLI arguments as its defaults.
        initialize_wandb(args)
        config = wandb.config
    else:
        # Without wandb, overlay the YAML config onto the parsed arguments.
        # Entries are expected in the `key: {value: ...}` layout.
        config = args
        if args.config and os.path.isfile(args.config):
            with open(args.config) as f:
                # An empty file parses to None; treat it as "no overrides".
                overrides = yaml.safe_load(f) or {}
            for k, v in overrides.items():
                if k == 'wandb_version':
                    # wandb-written config files carry a top-level
                    # `wandb_version: <int>` entry; it is a format marker,
                    # not a `{value: ...}` hyperparameter, so skip it.
                    continue
                setattr(config, k, v['value'])

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using {device} device", flush=True)

    # Read env and seed from `config` like every other hyperparameter, so a
    # value set by the config file (or a sweep) is the one actually used,
    # with or without logging.
    env = make_vec_env(config.env, config.num_envs)
    env_eval = make_vec_env(config.env, config.evaluate_episodes)

    # NOTE: a seed of 0 (the default) is falsy, so the global RNGs are left
    # unseeded in that case.
    if config.seed:
        np.random.seed(config.seed)
        torch.manual_seed(config.seed)
        random.seed(config.seed)

    policy = Policy(
        env,
        torso=getattr(net, config.net),  # network class looked up by name in algo.net
        z_dim=config.z_dim,
        lucky=(config.backup == 'offdae'),
    ).to(device=device)

    if config.learning_rate_decay:
        lr = LinearSchedule(config.learning_rate, 0, config.steps)
    else:
        lr = LinearSchedule(config.learning_rate)

    agent = Agent(
        env,
        policy,
        learning_rate=lr,
        learning_rate_model=config.learning_rate_model,
        adam_eps=config.adam_eps,
        buffer_size=config.buffer_size,
        initial_steps=config.initial_steps,
        updates_per_step=config.updates_per_step,
        batch_size=config.batch_size,
        n_step=config.n_step,
        z_dim=config.z_dim,
        gumbel_temperature=LinearSchedule(1, config.gumbel_temp_final, config.gumbel_temp_steps),
        quantizer=config.quantizer,
        gamma=config.gamma,
        target_update_steps=config.target_update_steps,
        target_update_tau=config.target_update_tau,
        target_bootstrap=config.target_bootstrap,
        beta_kl=config.beta_kl,
        beta_entropy_model=config.beta_entropy_model,
        beta_commit=config.beta_commit,
        backup=config.backup,
        max_grad_norm=config.max_grad_norm,
        device=device,
        logging=args.logging,
    )

    if args.logging:
        wandb.watch(policy)
    agent.train(config.steps, env_eval, config.steps_per_eval, config.evaluate_episodes, config.seed)
    # The final evaluation uses a seed offset by MAGIC from the training seed.
    average_score = agent.evaluate(env_eval, config.evaluate_episodes, config.seed + MAGIC)
    print(f'Evaluation Score: {average_score:.3f}', flush=True)
    if args.logging:
        wandb.log({
            'score_avg': average_score,
            'episode_stats': wandb.Table(columns=['r', 'l'], data=agent.episode_stats),
        })
