from __future__ import division
from ast import arg
from distutils.util import strtobool
from multiprocessing.sharedctypes import Value
from test import test
from memory import ReplayMemory
from env_crafter import Env
from agent import QrdqnAgent, Agent
from tqdm import trange
from collections import deque
import torch
import numpy as np
import atari_py
import pickle
import os
from datetime import datetime
import bz2
import argparse
import wandb

import chunkedfile
# Runs at module import, before anything else in this script executes.
# NOTE(review): going by its name this presumably patches pathlib's append-mode
# writes to go through chunkedfile; what 3600 means (seconds? a chunk size?) is
# not visible here — confirm against the chunkedfile package.
chunkedfile.patch_pathlib_append(3600)

# Source is UTF-8 (the Python 3 default). An encoding cookie only takes effect on line 1 or 2, so none is needed here.


# from env import Env


def str2bool(value):
    """Convert a command-line truth string into a bool using ``strtobool``.

    Accepts y/yes/t/true/on/1 and n/no/f/false/off/0 (case-insensitive).
    Anything else raises ValueError, which argparse reports as an invalid
    value for the option.
    """
    return bool(strtobool(value))


# Note that hyperparameters may originally be reported in ATARI game frames instead of agent steps
parser = argparse.ArgumentParser(description='Rainbow')
parser.add_argument('--id', type=str, default='default', help='Experiment ID')
parser.add_argument('--logdir', type=str, default='logdir',
                    help='Root directory; results are written under <logdir>/results/<id>')
parser.add_argument('--seed', type=int, default=123, help='Random seed')
parser.add_argument('--disable-cuda', action='store_true', help='Disable CUDA')
parser.add_argument('--game', type=str, default='space_invaders',
                    choices=atari_py.list_games(), help='ATARI game')
parser.add_argument('--T-max', type=int, default=int(5e6), metavar='STEPS',
                    help='Number of training steps (4x number of frames)')
parser.add_argument('--max-episode-length', type=int, default=int(108e3),
                    metavar='LENGTH', help='Max episode length in game frames (0 to disable)')
parser.add_argument('--history-length', type=int, default=4,
                    metavar='T', help='Number of consecutive states processed')
parser.add_argument('--architecture', type=str, default='canonical', choices=[
                    'canonical', 'data-efficient'], metavar='ARCH', help='Network architecture')
parser.add_argument('--hidden-size', type=int, default=512,
                    metavar='SIZE', help='Network hidden size')
parser.add_argument('--noisy-std', type=float, default=0.1, metavar='σ',
                    help='Initial standard deviation of noisy linear layers')
parser.add_argument('--atoms', type=int, default=51, metavar='C',
                    help='Discretised size of value distribution')
parser.add_argument('--V-min', type=float, default=-10,
                    metavar='V', help='Minimum of value distribution support')
parser.add_argument('--V-max', type=float, default=10,
                    metavar='V', help='Maximum of value distribution support')
parser.add_argument('--model', type=str, metavar='PARAMS',
                    help='Pretrained model (state dict)')
parser.add_argument('--memory-capacity', type=int, default=int(1e6),
                    metavar='CAPACITY', help='Experience replay memory capacity')
parser.add_argument('--replay-frequency', type=int, default=4,
                    metavar='k', help='Frequency of sampling from memory')
parser.add_argument('--priority-exponent', type=float, default=0.5, metavar='ω',
                    help='Prioritised experience replay exponent (originally denoted α)')
parser.add_argument('--priority-weight', type=float, default=0.4, metavar='β',
                    help='Initial prioritised experience replay importance sampling weight')
parser.add_argument('--multi-step', type=int, default=3,
                    metavar='n', help='Number of steps for multi-step return')
parser.add_argument('--discount', type=float, default=0.99,
                    metavar='γ', help='Discount factor')
parser.add_argument('--target-update', type=int, default=int(8e3), metavar='τ',
                    help='Number of steps after which to update target network')
parser.add_argument('--reward-clip', type=int, default=1,
                    metavar='VALUE', help='Reward clipping (0 to disable)')
parser.add_argument('--learning-rate', type=float,
                    default=0.0000625, metavar='η', help='Learning rate')
parser.add_argument('--adam-eps', type=float, default=1.5e-4,
                    metavar='ε', help='Adam epsilon')
parser.add_argument('--batch-size', type=int, default=32,
                    metavar='SIZE', help='Batch size')
parser.add_argument('--norm-clip', type=float, default=10,
                    metavar='NORM', help='Max L2 norm for gradient clipping')
parser.add_argument('--learn-start', type=int, default=int(20e3),
                    metavar='STEPS', help='Number of steps before starting training')
parser.add_argument('--evaluate', action='store_true', help='Evaluate only')
parser.add_argument('--evaluation-interval', type=int, default=20000,
                    metavar='STEPS', help='Number of training steps between evaluations')
parser.add_argument('--evaluation-episodes', type=int, default=10,
                    metavar='N', help='Number of evaluation episodes to average over')
# TODO: Note that DeepMind's evaluation method is running the latest agent for 500K frames every 1M steps
parser.add_argument('--evaluation-size', type=int, default=500,
                    metavar='N', help='Number of transitions to use for validating Q')
parser.add_argument('--render', action='store_true',
                    help='Display screen (testing only)')
parser.add_argument('--enable-cudnn', action='store_true',
                    help='Enable cuDNN (faster but nondeterministic)')
# type=int is required: without it a user-supplied value stays a str and the
# training loop's `T % checkpoint_interval` raises TypeError.
parser.add_argument('--checkpoint-interval', type=int, default=0, metavar='STEPS',
                    help='How often to checkpoint the model, defaults to 0 (never checkpoint)')
parser.add_argument('--memory', help='Path to save/load the memory from')
parser.add_argument('--disable-bzip-memory', action='store_true',
                    help='Don\'t zip the memory file. Not recommended (zipping is a bit slower and much, much smaller)')
# New things
parser.add_argument('--algorithm', type=str, default='rainbow',
                    help="Agent to train: 'rainbow' or 'qrdqn'")
parser.add_argument('--explore_strat', type=str,
                    default='greedy', choices=["greedy", "egreedy", "ucb", "thompson"])
parser.add_argument('--ucb_c', default=0.5, type=float,
                    help="Exploration coefficient c used by --explore_strat ucb")
parser.add_argument('--n_quantiles', default=200, type=int,
                    help='number of quantiles for qrdqn')
parser.add_argument('--n_ensemble', default=5, type=int,
                    help='number of ensemble heads')
# epsilon greedy
parser.add_argument('--end_eps', default=0.1, type=float,
                    help='Final values for epsilon greedy')
parser.add_argument("--eps_decay_period", type=int, default=5000)

# Boolean switches take an explicit value, e.g. `--per true` / `--per 0`.
parser.add_argument("--per", type=str2bool,
                    default=False, help="Whether to use PER")
parser.add_argument("--bootstrapped_qrdqn", type=str2bool,
                    default=False, help="Whether to bootstrapping")
parser.add_argument("--use_wbb", type=str2bool,
                    default=False, help="Whether to use uncertainty weighted backprop")
parser.add_argument("--qrdqn_always_train_feat", type=str2bool,
                    default=False, help="Always train the feature extractor in ensemble")
parser.add_argument("--use_average_target", type=str2bool,
                    default=False, help="Train ensemble with average target")
parser.add_argument("--double_qrdqn", type=str2bool,
                    default=False, help="Using double q learning for qrdqn")
parser.add_argument("--wandb", type=str2bool,
                    default=False, help="Log to wandb")

# Setup
args = parser.parse_args()

# Echo the full configuration so every run's log records its options.
print(' ' * 26 + 'Options')
for k, v in vars(args).items():
    print(' ' * 26 + k + ': ' + str(v))

# All outputs of this run go under <logdir>/results/<id>.
results_dir = os.path.join(args.logdir, 'results', args.id)
# exist_ok=True replaces the racy "check exists, then create" pattern.
os.makedirs(results_dir, exist_ok=True)
metrics = {'steps': [], 'rewards': [],
           'Qs': [], 'best_avg_reward': -float('inf')}

# Seed numpy from --seed, then derive the torch (and CUDA) seeds from numpy's stream.
np.random.seed(args.seed)
torch.manual_seed(np.random.randint(1, 10000))
if torch.cuda.is_available() and not args.disable_cuda:
    args.device = torch.device('cuda')
    torch.cuda.manual_seed(np.random.randint(1, 10000))
    # cuDNN is opt-in via --enable-cudnn because it is faster but nondeterministic.
    torch.backends.cudnn.enabled = args.enable_cudnn
else:
    args.device = torch.device('cpu')


# Unless --wandb is true, wandb runs in offline mode (nothing is uploaded).
if not args.wandb:
    os.environ["WANDB_MODE"] = "offline"

wandb.init(
    # settings=wandb.Settings(start_method="fork"),
    project="crafter",
    entity="anon",
    name=args.id,
    config=args,
    group=None,
)


# Minimal logger: prefixes each message with a local ISO 8601 timestamp.
def log(s):
    """Print *s* behind a ``[YYYY-MM-DDTHH:MM:SS]`` local-time prefix."""
    timestamp = datetime.now().strftime('%Y-%m-%dT%H:%M:%S')
    prefix = '[' + timestamp + '] '
    print(prefix + s)


def load_memory(memory_path, disable_bzip):
    """Unpickle and return the object stored at *memory_path*.

    The file is read as a bz2-compressed pickle unless *disable_bzip* is
    truthy, in which case it is read as a plain pickle.

    NOTE: unpickling can execute arbitrary code; only load files you trust.
    """
    opener = open if disable_bzip else bz2.open
    with opener(memory_path, 'rb') as handle:
        return pickle.load(handle)


def save_memory(memory, memory_path, disable_bzip):
    """Pickle *memory* to *memory_path*, overwriting any existing file.

    The output is bz2-compressed unless *disable_bzip* is truthy, in which
    case a plain pickle is written. ``load_memory`` with the same flag reads
    it back.
    """
    opener = open if disable_bzip else bz2.open
    with opener(memory_path, 'wb') as handle:
        pickle.dump(memory, handle)


# Environment: build it, put it in training mode, and cache the number of actions.
env = Env(args)
env.train()
action_space = env.action_space()

# Agent: --algorithm selects the implementation (banner text, agent class).
_AGENT_CHOICES = {
    "qrdqn": ("Using QR-DQN", QrdqnAgent),
    "rainbow": ("Using Rainbow", Agent),
}
if args.algorithm not in _AGENT_CHOICES:
    raise ValueError("Unrecognized algorithm: {}".format(args.algorithm))
_agent_banner, _agent_cls = _AGENT_CHOICES[args.algorithm]
print(_agent_banner)
dqn = _agent_cls(args, env)

# If a model is provided, and evaluate is false, presumably we want to resume, so try to load memory
if args.model is not None and not args.evaluate:
    if not args.memory:
        raise ValueError(
            'Cannot resume training without memory save path. Aborting...')
    if not os.path.exists(args.memory):
        raise ValueError(
            'Could not find memory file at {path}. Aborting...'.format(path=args.memory))

    # NOTE: this unpickles the file; only resume from memory files you created.
    mem = load_memory(args.memory, args.disable_bzip_memory)

else:
    mem = ReplayMemory(args, args.memory_capacity)

# Per-step increment used by the training loop to anneal the PER
# importance-sampling weight β from --priority-weight up to 1 over the steps
# after learn_start. The denominator is clamped to at least 1 so that
# T_max == learn_start no longer crashes with ZeroDivisionError (the value is
# unused whenever T_max <= learn_start, since no step then reaches learn_start
# with steps remaining to anneal over).
priority_weight_increase = (1 - args.priority_weight) / \
    max(args.T_max - args.learn_start, 1)


# Construct validation memory: fill it with transitions from a uniformly random
# policy (action stored as -1, reward 0.0); it is what test() uses to validate Q.
val_mem = ReplayMemory(args, args.evaluation_size)

# Running statistics for the training loop.
current_episode_reward = 0
episode_reward = deque(maxlen=30)   # returns of the most recent episodes
action_infos = deque(maxlen=100)    # most recent info objects from dqn.explore

T, done = 0, True
for T in range(1, args.evaluation_size + 1):
    if done:
        state = env.reset()
    random_action = np.random.randint(0, action_space)
    next_state, _, done = env.step(random_action)
    val_mem.append(state, -1, 0.0, done)
    state = next_state

def _summarize_ucb_infos(infos):
    """Summarise buffered UCB exploration infos into scalars for wandb.

    For each of ``"mean"``, ``"eps_var"`` and ``"ale_var"`` (plus their
    ``"mean" + "eps_var"`` sum, logged as ``"combined"``), takes
    ``.max(axis=1)[0]`` of each info's entry, averages it, and then averages
    across *infos*. Keys come out as ``"ucb / <name>"``.

    NOTE(review): assumes each info maps those keys to torch tensors with at
    least two dimensions, as produced by ``dqn.explore`` under
    ``--explore_strat ucb`` — confirm against the agent.

    Returns an empty dict when *infos* is empty.
    """
    if not infos:
        return {}

    def _mean_of_row_max(extract):
        total = sum(extract(info).max(axis=1)[0].mean().item() for info in infos)
        return total / len(infos)

    stats = {}
    for key in ("mean", "eps_var", "ale_var"):
        stats["ucb / " + key] = _mean_of_row_max(lambda info, key=key: info[key])
    stats["ucb / combined"] = _mean_of_row_max(
        lambda info: info["mean"] + info["eps_var"])
    return stats


try:
    if args.evaluate:
        dqn.eval()  # Set DQN (online network) to evaluation mode
        avg_reward, avg_Q = test(args, 0, dqn, val_mem,
                                 metrics, results_dir, evaluate=True)  # Test
        print('Avg. reward: ' + str(avg_reward) + ' | Avg. Q: ' + str(avg_Q))
    else:
        # Training loop
        dqn.train()
        # --checkpoint-interval may arrive as a str if its option lacks type=int;
        # coerce once so the modulo test below is well-defined.
        checkpoint_interval = int(args.checkpoint_interval)
        done = True  # forces an env.reset() on the first step
        for T in range(1, args.T_max + 1):
            if T % 1000 == 0:
                print('Iter {} / {}'.format(T, args.T_max))

            if done:
                state = env.reset()

            if T % args.replay_frequency == 0:
                dqn.reset_noise()  # Draw a new set of noisy weights

            # Uniformly random actions before learn_start; afterwards defer to the
            # agent's exploration strategy and keep its info for UCB logging.
            if T < args.learn_start:
                action = np.random.randint(action_space)
            else:
                action, info = dqn.explore(state, t=T)
                action_infos.append(info)

            next_state, reward, done = env.step(action)  # Step
            current_episode_reward += reward  # unclipped return of the running episode
            if done:
                # Record a return only when an episode actually ends, so no
                # spurious 0 is logged before the first episode finishes.
                episode_reward.append(current_episode_reward)
                current_episode_reward = 0

            if args.reward_clip > 0:
                reward = max(min(reward, args.reward_clip),
                             -args.reward_clip)  # Clip rewards
            mem.append(state, action, reward, done)  # Append transition to memory

            # Train and test
            if T >= args.learn_start:
                # Anneal importance sampling weight β to 1
                mem.priority_weight = min(
                    mem.priority_weight + priority_weight_increase, 1)

                if T % args.replay_frequency == 0:
                    # Train with n-step distributional double-Q learning
                    dqn.learn(mem)

                if T % args.evaluation_interval == 0:
                    # Mean return of recently finished episodes; NaN (without a
                    # numpy warning) if no episode has finished yet.
                    avg_reward = np.mean(episode_reward) if episode_reward else float('nan')
                    log('T = ' + str(T) + ' / ' + str(args.T_max) +
                        ' | Avg. reward: ' + str(avg_reward))
                    stats = {
                        "episode reward": avg_reward,
                    }
                    if args.algorithm == "qrdqn":
                        stats["epsilon"] = dqn.epsilon(T)
                    if args.explore_strat == "ucb":
                        stats.update(_summarize_ucb_infos(action_infos))
                    wandb.log(stats, step=T)
                    # NOTE: greedy evaluation via test() is currently disabled:
                    # dqn.eval()  # Set DQN (online network) to evaluation mode
                    # avg_reward, avg_Q = test(
                    #     args, T, dqn, val_mem, metrics, results_dir)  # Test
                    # log('T = ' + str(T) + ' / ' + str(args.T_max) +
                    #     ' | Avg. reward: ' + str(avg_reward) + ' | Avg. Q: ' + str(avg_Q))
                    # dqn.train()  # Set DQN (online network) back to training mode

                #   # If memory path provided, save it
                #   if args.memory is not None:
                #     save_memory(mem, args.memory, args.disable_bzip_memory)

                # Update target network
                if T % args.target_update == 0:
                    dqn.update_target_net()

                # Checkpoint the network
                if checkpoint_interval != 0 and T % checkpoint_interval == 0:
                    dqn.save(results_dir, 'checkpoint.pth')

            state = next_state
finally:
    # Release the environment even if evaluation/training raises.
    env.close()
