# -*- coding: UTF-8 -*-  
import matplotlib.pyplot as plt
import numpy as np

import tensorflow as tf
# Switch TensorFlow to graph (non-eager) execution through the TF1 compatibility API.
tf.compat.v1.disable_eager_execution()
# Turn on memory growth for every GPU TensorFlow detects, so GPU memory is
# claimed as needed rather than all at once.
for gpu in tf.config.experimental.list_physical_devices('GPU'):
    tf.config.experimental.set_memory_growth(gpu, True)
from tensorflow.keras import layers, Model, optimizers

from agents.dqn import DQN
from agents.dqn_gpi import DQN_GPI
from agents.sfdqn import SFDQN
from agents.buffer import ReplayBuffer
from features.deep import DeepSF
from tasks.reacher import Reacher
from utils.config import parse_config_file

import copy
import argparse

def _str2bool(value):
    """Convert a command-line token into a bool.

    argparse hands option values over as strings, and every non-empty string
    (including 'False') is truthy, so boolean options need explicit parsing.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    if token in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


# parse arguments from command line
# Every option declares its type so a value given on the command line has the
# same Python type as its default (int / float / bool) instead of arriving as str.
parser = argparse.ArgumentParser(description='Succesor Feature Deep Q-learning')
parser.add_argument('--test_task_idx', type=int, default=4, help='Traget task used for transfer')
parser.add_argument('--gamma', type=float, default=0.9, help='discount factor to be used')
parser.add_argument('--include_target_dqn', type=_str2bool, default=False, help='whether or not to use UVFA for DQN')
parser.add_argument('--itrain_task_idxs', default='', help='list of custom train task indeces to train on before transfering to test task' )

args = parser.parse_args()

# ablation parameters taken from the parsed command line
test_task_idx = int(args.test_task_idx)
gamma = args.gamma
include_target_dqn = args.include_target_dqn

# echo the experiment settings as a single banner
_banner_lines = [
    '\n==========IMportnatn Experiment Parameters===========',
    f"test_task_idx: {test_task_idx}",
    f"gamma: {gamma}",
    f"include_target_dqn: {include_target_dqn}",
    '=====================================================\n',
]
print('\n'.join(_banner_lines))


# load the experiment configuration from the config file
config_params = parse_config_file('reacher.cfg')

# override the configured discount factor with the command-line value
# (used for the gamma ablation)
config_params['AGENT']['gamma'] = gamma

# per-section views of the configuration
gen_params = config_params['GENERAL']
task_params = config_params['TASK']
agent_params = config_params['AGENT']
dqn_params = config_params['DQN']
sfdqn_params = config_params['SFDQN']

n_samples = gen_params['n_samples']

# training targets followed by test targets, in one combined list
goals = task_params['train_targets']
test_goals = task_params['test_targets']
all_goals = goals + test_goals

# SFDQN agent without GPI: independent copy of the SFDQN settings with GPI off
sfdqn_wo_gpi_params = copy.deepcopy(sfdqn_params)
sfdqn_wo_gpi_params['use_gpi'] = False



# tasks
def generate_tasks(include_target):
    """Build the training and evaluation ``Reacher`` task lists.

    The task indices are every index of ``goals`` followed by the transfer
    task ``test_task_idx``. Both returned lists are built from that same index
    list, each with its own ``Reacher`` instances.

    Args:
        include_target: forwarded unchanged as the third argument of ``Reacher``.

    Returns:
        A ``(train_tasks, test_tasks)`` tuple of lists.
    """
    task_indices = [*range(len(goals)), test_task_idx]

    def build_task_list():
        # one fresh Reacher per index, all sharing the combined goal list
        return [Reacher(all_goals, idx, include_target) for idx in task_indices]

    train_tasks = build_task_list()
    test_tasks = build_task_list()
    return train_tasks, test_tasks


# keras model
def dqn_model_lambda(include_target):
    """Build and compile the Keras network used by the DQN agent.

    Hidden layer widths, activations and the learning rate are read from
    ``dqn_params['keras_params']``.

    Args:
        include_target: if true the network takes 6 input features, otherwise 4
            (the variant without goal locations as input features).

    Returns:
        A compiled Keras ``Model`` with 9 linear outputs (presumably one per
        action -- confirm against the task definition), using Adam and an
        MSE loss.
    """
    keras_params = dqn_params['keras_params']
    # ADDED flexibility for DQN to be without goal locations as input features
    n_inputs = 6 if include_target else 4
    # Keras documents `shape` as a tuple of ints; pass one explicitly instead of
    # relying on a bare int being coerced by a particular Keras version.
    x = y = layers.Input(shape=(n_inputs,))
    for n_neurons, activation in zip(keras_params['n_neurons'], keras_params['activations']):
        y = layers.Dense(n_neurons, activation=activation)(y)
    y = layers.Dense(9, activation='linear')(y)
    model = Model(inputs=x, outputs=y)
    optimizer = optimizers.Adam(learning_rate=keras_params['learning_rate'])
    model.compile(optimizer, 'mse')
    return model


# keras model for the SF
def sf_model_lambda(x):
    """Build and compile the successor-feature network on top of input ``x``.

    Hidden layer widths, activations and the learning rate are read from
    ``sfdqn_params['keras_params']``. The feature count is ``len(all_goals)``.

    Args:
        x: Keras input tensor the network is built on.

    Returns:
        A compiled Keras ``Model`` (Adam, MSE loss) whose per-sample output is
        reshaped to ``(9, n_features)``.
    """
    n_features = len(all_goals)
    keras_params = sfdqn_params['keras_params']

    hidden = x
    layer_specs = zip(keras_params['n_neurons'], keras_params['activations'])
    for width, act in layer_specs:
        hidden = layers.Dense(width, activation=act)(hidden)

    # flat linear head, then split into a 9 x n_features block
    flat_sf = layers.Dense(9 * n_features, activation='linear')(hidden)
    sf = layers.Reshape((9, n_features))(flat_sf)

    model = Model(inputs=x, outputs=sf)
    model.compile(optimizers.Adam(learning_rate=keras_params['learning_rate']), 'mse')
    return model


def train():
    """Train the DQN-GPI agent, save its smoothed performance and plot a comparison.

    Steps:
      1. Build ``DQN_GPI`` from the DQN and agent configuration sections.
      2. Train it on the tasks from ``generate_tasks`` and collect the
         performance curve returned by ``DQN_GPI.train``.
      3. Smooth the curve with a moving average and save it with ``np.save``.
      4. Load the two ``dqn_perf_*`` curves (``include_target`` True and False)
         from the working directory; these files must already exist.
      5. Plot all three curves and write the figure under ``figures/``.
    """
    import os  # local import: only needed to create the figure output folder

    # build DQN
    print('building DQN')
    dqn_gpi = DQN_GPI(model_lambda=dqn_model_lambda, buffer=ReplayBuffer(dqn_params['buffer_params']), include_target=include_target_dqn,
              **dqn_params, **agent_params)

    # training DQN
    print('training DQN')
    train_tasks, test_tasks = generate_tasks(include_target_dqn)
    dqn_gpi_perf = dqn_gpi.train(train_tasks, n_samples, test_tasks=test_tasks, n_test_ev=agent_params['n_test_ev'], test_task_idx=-1)

    # smooth data
    def smooth(y, box_pts):
        """Moving average of ``y`` over ``box_pts`` points, same length as ``y``."""
        return np.convolve(y, np.ones(box_pts) / box_pts, mode='same')

    # the last 5 samples are dropped after smoothing -- presumably to discard
    # the right-edge artefacts of the 'same'-mode convolution (confirm)
    dqn_gpi_perf = smooth(dqn_gpi_perf, 10)[:-5]

    # save
    np.save(f'dqn_gpi_perf_{test_task_idx}_gamma-{gamma}_include_target-{include_target_dqn}', dqn_gpi_perf)

    # load the baseline curves to compare against
    # dqn_gpi_perf = np.load(f'dqn_gpi_perf_{test_task_idx}_gamma-{gamma}_include_target-{include_target_dqn}.npy')
    dqn_uvfa_perf = np.load(f'dqn_perf_{test_task_idx}_gamma-{gamma}_include_target-True.npy')
    dqn_perf = np.load(f'dqn_perf_{test_task_idx}_gamma-{gamma}_include_target-False.npy')

    # shared x axis spanning 0..5, sized from the DQN-GPI curve; the tick
    # labels below mark positions 1..5 as Tr1..Tr4 and the test task
    x = np.linspace(0, 5, dqn_gpi_perf.size)

    # reporting progress
    ticksize = 14
    textsize = 18
    plt.rc('font', size=textsize)  # controls default text sizes
    plt.rc('axes', titlesize=textsize)  # fontsize of the axes title
    plt.rc('axes', labelsize=textsize)  # fontsize of the x and y labels
    plt.rc('xtick', labelsize=ticksize)  # fontsize of the tick labels
    plt.rc('ytick', labelsize=ticksize)  # fontsize of the tick labels
    plt.rc('legend', fontsize=ticksize)  # legend fontsize

    plt.figure(figsize=(8, 6))
    ax = plt.gca()

    ax.plot(x, dqn_perf, label='DQN')
    ax.plot(x, dqn_gpi_perf, label='DQN_GPI')
    ax.plot(x, dqn_uvfa_perf, label='DQN_UVFA')

    ax.set_xticks([1, 2, 3, 4, 5], labels=['Tr1', 'Tr2', 'Tr3', 'Tr4', f'Te{test_task_idx - 3}'])
    plt.xlabel('training task')
    plt.ylabel('target task reward')
    plt.title(f'Transfer performance comaprison for test task {test_task_idx - 3}('+r'$\gamma$'+f'={gamma})')
    plt.tight_layout()
    plt.legend(frameon=False)

    # savefig does not create missing directories; make sure the output folder
    # exists so a finished training run is not lost to a FileNotFoundError
    os.makedirs('figures', exist_ok=True)
    plt.savefig(f'figures/dqn_comp_return_new_task-{test_task_idx}_gamma-{gamma}_include_target_dqn-{include_target_dqn}.png')


# run the experiment when the script is executed
train()
