from __future__ import division

import argparse
import bz2
import os
import pathlib
import pickle
import time
from datetime import datetime
from test import test

import atari_py
import numpy as np
import pandas as pd
import torch
from agent import Agent
from env import Env
from memory import ReplayMemory
from tqdm import trange

# Work from the directory containing this script so that the relative paths used below
# (e.g. './results/' and the default --save-folder './exps/') resolve next to it.
os.chdir(pathlib.Path(__file__).parent.absolute())
# Note that hyperparameters may originally be reported in ATARI game frames instead of agent steps
parser = argparse.ArgumentParser(description='MfRL_Disc')
parser.add_argument('--id', type=str, default='default', help='Experiment ID')
parser.add_argument('--seed', type=int, default=0, help='Random seed')
parser.add_argument('--disable-cuda', action='store_true', help='Disable CUDA')
parser.add_argument('--game', type=str, default='ms_pacman', choices=atari_py.list_games(), help='ATARI game')
parser.add_argument('--T-max', type=int, default=int(200000001), metavar='STEPS',
                    help='Number of training steps (4x number of frames)')
parser.add_argument('--max-episode-length', type=int, default=int(108e3), metavar='LENGTH',
                    help='Max episode length in game frames (0 to disable)')
parser.add_argument('--history-length', type=int, default=4, metavar='T', help='Number of consecutive states processed')
parser.add_argument('--architecture', type=str, default='data-efficient', choices=['canonical', 'data-efficient'],
                    metavar='ARCH', help='Network architecture')
parser.add_argument('--hidden-size', type=int, default=256, metavar='SIZE', help='Network hidden size')
parser.add_argument('--noisy-std', type=float, default=0.1, metavar='?',
                    help='Initial standard deviation of noisy linear layers')
parser.add_argument('--atoms', type=int, default=51, metavar='C', help='Discretised size of value distribution')
parser.add_argument('--V-min', type=float, default=-10, metavar='V', help='Minimum of value distribution support')
parser.add_argument('--V-max', type=float, default=10, metavar='V', help='Maximum of value distribution support')
parser.add_argument('--model', type=str, metavar='PARAMS', help='Pretrained model (state dict)')
parser.add_argument('--memory-capacity', type=int, default=int(100000), metavar='CAPACITY',
                    help='Experience replay memory capacity')
parser.add_argument('--replay-frequency', type=int, default=1, metavar='k', help='Frequency of sampling from memory')
parser.add_argument('--replay-type', type=str, default='MaPER')
parser.add_argument('--priority-exponent', type=float, default=0.5, metavar='?',
                    help='Prioritised experience replay exponent (originally denoted α)')
parser.add_argument('--priority-weight', type=float, default=0.4, metavar='β',
                    help='Initial prioritised experience replay importance sampling weight')
parser.add_argument('--multi-step', type=int, default=20, metavar='n', help='Number of steps for multi-step return')
parser.add_argument('--discount', type=float, default=0.99, metavar='γ', help='Discount factor')
parser.add_argument('--target-update', type=int, default=2000, metavar='?',
                    help='Number of steps after which to update target network')
parser.add_argument('--reward-clip', type=int, default=1, metavar='VALUE', help='Reward clipping (0 to disable)')
parser.add_argument('--learning-rate', type=float, default=0.0001, metavar='η', help='Learning rate')
parser.add_argument('--adam-eps', type=float, default=1.5e-4, metavar='ε', help='Adam epsilon')
parser.add_argument('--batch-size', type=int, default=32, metavar='SIZE', help='Batch size')
parser.add_argument('--norm-clip', type=float, default=10, metavar='NORM', help='Max L2 norm for gradient clipping')
parser.add_argument('--learn-start', type=int, default=int(5000), metavar='STEPS',
                    help='Number of steps before starting training')
parser.add_argument('--evaluate', action='store_true', help='Evaluate only')
parser.add_argument('--evaluation-interval', type=int, default=5000, metavar='STEPS',
                    help='Number of training steps between evaluations')
parser.add_argument('--evaluation-episodes', type=int, default=10, metavar='N',
                    help='Number of evaluation episodes to average over')
parser.add_argument('--evaluation-size', type=int, default=500, metavar='N',
                    help='Number of transitions to use for validating Q')
parser.add_argument('--render', action='store_true', help='Display screen (testing only)')
parser.add_argument('--enable-cudnn', action='store_true', help='Enable cuDNN (faster but nondeterministic)')
# type=int so a value given on the command line is an integer like the default (argparse keeps
# untyped command-line values as strings, which would break any arithmetic on the interval).
parser.add_argument('--checkpoint-interval', type=int, default=0,
                    help='How often to checkpoint the model, defaults to 0 (never checkpoint)')
parser.add_argument('--memory', help='Path to save/load the memory from')
parser.add_argument('--disable-bzip-memory', action='store_true',
                    help='Don\'t zip the memory file. Not recommended (zipping is a bit slower and much, much smaller)')
parser.add_argument('--save-folder', type=str, default='./exps/')
parser.add_argument('--temp', type=float, default=50.0)
# Setup
args = parser.parse_args()
# NOTE(review): cuDNN is forced on here, so the --enable-cudnn flag currently has no effect - confirm intended.
args.enable_cudnn = True

# Echo every parsed option, indented by 26 spaces.
option_indent = ' ' * 26
print('{}Options'.format(option_indent))
for option_name, option_value in vars(args).items():
    print('{}{}: {}'.format(option_indent, option_name, option_value))

# The results path is only computed here; this block does not create the directory.
results_dir = os.path.join('./results/', args.id)
metrics = dict(steps=[], rewards=[], Qs=[], best_avg_reward=float('-inf'))

# Seed NumPy from --seed, then derive the torch (and CUDA) seeds from NumPy's stream.
np.random.seed(args.seed)
torch.manual_seed(np.random.randint(1, 10000))
if not torch.cuda.is_available() or args.disable_cuda:
    args.device = torch.device('cpu')
else:
    args.device = torch.device('cuda')
    torch.cuda.manual_seed(np.random.randint(1, 10000))
    torch.backends.cudnn.enabled = args.enable_cudnn


def log(s):
    """Print the string ``s`` prefixed with the current local time as ``[YYYY-MM-DDTHH:MM:SS] ``."""
    timestamp = datetime.now().strftime('%Y-%m-%dT%H:%M:%S')
    print('[{}] '.format(timestamp) + s)


def load_memory(memory_path, disable_bzip):
    """Unpickle and return the object stored at ``memory_path``.

    The file is read as a plain pickle when ``disable_bzip`` is truthy and
    through ``bz2`` decompression otherwise.

    NOTE: ``pickle.load`` can execute arbitrary code, so only load files from
    a trusted source.
    """
    opener = open if disable_bzip else bz2.open
    with opener(memory_path, 'rb') as memory_file:
        return pickle.load(memory_file)


def save_memory(memory, memory_path, disable_bzip):
    """Pickle ``memory`` to ``memory_path``, overwriting any existing file.

    The pickle is written uncompressed when ``disable_bzip`` is truthy and
    bz2-compressed otherwise.
    """
    opener = open if disable_bzip else bz2.open
    with opener(memory_path, 'wb') as memory_file:
        pickle.dump(memory, memory_file)


# Environment
env = Env(args)
env.train()
action_space = env.action_space()

# Agent
dqn = Agent(args, env)

# If a model is provided and evaluate is false, presumably we want to resume, so try to load memory
if args.model is not None and not args.evaluate:
    if not args.memory:
        raise ValueError('Cannot resume training without memory save path. Aborting...')
    elif not os.path.exists(args.memory):
        raise ValueError('Could not find memory file at {path}. Aborting...'.format(path=args.memory))

    mem = load_memory(args.memory, args.disable_bzip_memory)
elif 'PER' in args.replay_type:
    # Fresh replay memory; only replay types whose name contains 'PER' are accepted here.
    mem = ReplayMemory(args, args.memory_capacity)
else:
    # Raise explicitly rather than `assert False`: asserts are stripped under `python -O`,
    # which would leave `mem` undefined and fail much later with a confusing NameError.
    raise ValueError('Unsupported replay type {replay_type!r}: expected a name containing "PER".'.format(
        replay_type=args.replay_type))
# Per-step increment used by the training loop to anneal the importance-sampling weight β from
# --priority-weight up to 1 over the steps that follow --learn-start. The max(..., 1) keeps the
# script from dying with ZeroDivisionError at start-up when --T-max equals --learn-start.
priority_weight_increase = (1 - args.priority_weight) / max(args.T_max - args.learn_start, 1)

# Construct validation memory
val_mem = ReplayMemory(args, args.evaluation_size)
start_T, done = 0, True
base_filename = args.game + '-' + args.replay_type + '-' + str(args.seed)
folder = args.save_folder

# exist_ok avoids the check-then-create race when several runs (e.g. different seeds)
# sharing one save folder start at the same time.
os.makedirs(folder, exist_ok=True)

if start_T == 0:
    T = 0
    total_steps = []  # Training step of each evaluation
    avg_qs = []
    avg_rewards = []  # Average reward of each evaluation
    # Fill the validation memory with states visited under uniformly random actions
    while T < args.evaluation_size:
        if done:
            state, done = env.reset(), False

        next_state, _, done = env.step(np.random.randint(0, action_space))
        val_mem.append(state, None, None, done)
        state = next_state
        T += 1
    state, done = env.reset(), False
    T = 0
    # Evaluate once at T = 0, before any training step
    dqn.eval()  # Set DQN (online network) to evaluation mode
    avg_reward, avg_Q = test(args, T, dqn, val_mem, metrics, results_dir)  # Test
    log('T = ' + str(T) + ' / ' + str(args.T_max) + ' | Avg. reward: ' + str(avg_reward) + ' | Avg. Q: ' + str(
        avg_Q))
    # avg_qs.append(avg_Q)
    avg_rewards.append(avg_reward)
    total_steps.append(T)
def _write_scores_csv(csv_path, scores, steps, include_times=False):
    """Write the evaluation curve collected so far to ``csv_path`` as CSV.

    Args:
        csv_path: Destination file; overwritten if it already exists.
        scores: Average evaluation rewards, one entry per evaluation.
        steps: Training step at which each entry of ``scores`` was measured.
        include_times: When true, append a ``total_times`` column filled with
            NaN. Nothing in this block records timings; the column is kept
            only so the final CSV has the same columns as before.
    """
    column_names = ['all_scores', 'total_steps']
    columns = {'all_scores': np.array(scores), 'total_steps': np.array(steps)}
    if include_times:
        column_names.append('total_times')
        columns['total_times'] = np.nan
    pd.DataFrame(columns, columns=column_names).to_csv(csv_path)


# The output path does not change during the run, so build it once instead of per evaluation.
base_filename = args.game + '-' + args.replay_type + '-' + str(args.seed)
csv_filename = base_filename + '.csv'
csv_path = os.path.join(folder, csv_filename)

try:
    if args.evaluate:
        dqn.eval()  # Set DQN (online network) to evaluation mode
        avg_reward, avg_Q = test(args, 0, dqn, val_mem, metrics, results_dir, evaluate=True)  # Test
        print('Avg. reward: ' + str(avg_reward) + ' | Avg. Q: ' + str(avg_Q))
    else:
        # Training loop
        dqn.train()
        done = True
        counter = 0  # Number of environment resets (episodes started)
        for T in trange(start_T, args.T_max + 1):
            if done:
                state, done = env.reset(), False
                counter += 1
            if T % args.replay_frequency == 0:
                dqn.reset_noise()  # Draw a new set of noisy weights

            action = dqn.act(state)  # Choose an action greedily (with noisy weights)
            next_state, reward, done = env.step(action)  # Step
            if args.reward_clip > 0:
                reward = max(min(reward, args.reward_clip), -args.reward_clip)  # Clip rewards
            mem.append(state, action, reward, done)  # Append transition to memory

            # Train and test
            if T >= args.learn_start:
                if args.replay_type != "RANDOM":
                    mem.priority_weight = min(mem.priority_weight + priority_weight_increase,
                                              1)  # Anneal importance sampling weight β to 1
                else:
                    mem.priority_weight = 0.0

                if T % args.replay_frequency == 0:
                    dqn.learn(mem)  # Train with n-step distributional double-Q learning

                # Update target network
                if T % args.target_update == 0:
                    dqn.update_target_net()

            state = next_state
            if T % args.evaluation_interval == 0 and T > 0:
                dqn.eval()  # Set DQN (online network) to evaluation mode
                avg_reward, avg_Q = test(args, T, dqn, val_mem, metrics, results_dir)  # Test
                log('T = ' + str(T) + ' / ' + str(args.T_max) + ' | Avg. reward: ' + str(avg_reward)
                    + ' | Avg. Q: ' + str(avg_Q))
                # avg_qs.append(avg_Q)
                avg_rewards.append(avg_reward)
                total_steps.append(T)

                # Rewrite the CSV after every evaluation so partial results survive an interrupted run
                _write_scores_csv(csv_path, avg_rewards, total_steps)
                dqn.train()  # Set DQN (online network) back to training mode

    # Final write. NOTE(review): this also runs in --evaluate mode, where it replaces any CSV already
    # at this path with whatever is in avg_rewards/total_steps at that point - confirm that is intended.
    _write_scores_csv(csv_path, avg_rewards, total_steps, include_times=True)
finally:
    # Always release the environment, including when training is interrupted or raises.
    env.close()
