import sys
import numpy as np

import tensorflow as tf
# Run TensorFlow in TF1-style graph mode (eager execution off), and make every visible
# GPU allocate memory on demand instead of reserving it all up front.
tf.compat.v1.disable_eager_execution()
for _gpu_device in tf.config.experimental.list_physical_devices('GPU'):
    tf.config.experimental.set_memory_growth(_gpu_device, True)
from tensorflow.keras import layers, Model, optimizers

from agents.c51dqn import C51DQN
from agents.sfdqn import SFDQN
from agents.base import BaseC51
from agents.buffer import ReplayBuffer
from features.c51sf import C51SF
from features.keras import KerasMVSF
from plot_reacher_covariance import plot_covariance_dqn, plot_covariance_c51
from tasks.reacher import Reacher
from utils import stamp
from utils.config import parse_config_file
from utils.stats import OnlineMeanVariance

# general training params -- every hyper-parameter below is read from reacher.cfg
config_params = parse_config_file('reacher.cfg')

gen_params = config_params['GENERAL']
n_samples, n_trials = gen_params['n_samples'], gen_params['n_trials']

# goals: the training targets come first, followed by the test targets
task_params = config_params['TASK']
train_goals, test_goals = task_params['train_targets'], task_params['test_targets']
all_goals = train_goals + test_goals

# per-agent hyper-parameter sections
agent_params, c51_params, sfc51_params, sfdqn_params = (
    config_params[section] for section in ('AGENT', 'C51DQN', 'SFC51DQN', 'SFDQN'))


# tasks
def generate_tasks(include_target_in_state):
    """Create fresh Reacher task instances for one trial.

    Every task receives the full goal list ``all_goals`` plus an integer index, which
    presumably selects that task's goal (TODO confirm in tasks.reacher). Indices
    ``0 .. len(train_goals) - 1`` form the training tasks; the following
    ``len(test_goals)`` indices form the test tasks.

    Args:
        include_target_in_state: forwarded unchanged to every ``Reacher``.

    Returns:
        ``(train_tasks, test_tasks)`` as two lists of ``Reacher`` objects.
    """
    n_train = len(train_goals)
    n_total = n_train + len(test_goals)

    def build(goal_index):
        return Reacher(all_goals, goal_index,
                       include_target_in_state=include_target_in_state, **task_params)

    train_tasks = list(map(build, range(n_train)))
    test_tasks = list(map(build, range(n_train, n_total)))
    return train_tasks, test_tasks


# keras model
# Size of the leading (non-batch) output axis of every network built below.
# Presumably the number of discrete Reacher actions -- TODO(review): confirm in tasks.reacher.
N_ACTIONS = 9


def _hidden_layers(y, keras_params):
    """Apply the configured stack of fully-connected hidden layers to tensor ``y``.

    ``keras_params['n_neurons']`` and ``keras_params['activations']`` are paired
    element-wise, producing one ``Dense`` layer per pair.

    Raises:
        ValueError: if the two lists differ in length (a bare ``zip`` would silently
            drop the unmatched tail and build a smaller network than configured).
    """
    n_neurons = list(keras_params['n_neurons'])
    activations = list(keras_params['activations'])
    if len(n_neurons) != len(activations):
        raise ValueError('keras_params has {} n_neurons entries but {} activations'.format(
            len(n_neurons), len(activations)))
    for units, activation in zip(n_neurons, activations):
        y = layers.Dense(units, activation=activation)(y)
    return y


def _compile_model(x, y, keras_params, loss):
    """Wrap the graph ``x -> y`` in a Keras ``Model`` and compile it with Adam and ``loss``."""
    model = Model(inputs=x, outputs=y)
    adam = optimizers.Adam(learning_rate=keras_params['learning_rate'])
    model.compile(optimizer=adam, loss=loss)
    return model


def _c51_model(input_dim):
    """Build a C51 network taking ``input_dim`` features, using the ``C51DQN`` config section.

    Output per sample: ``(N_ACTIONS, n_atoms)``, softmax-normalised over the atom axis.
    """
    keras_params = c51_params['keras_params']
    n_atoms = c51_params['n_atoms']
    x = layers.Input(shape=(input_dim,))
    y = _hidden_layers(x, keras_params)
    y = layers.Dense(N_ACTIONS * n_atoms)(y)
    y = layers.Reshape((N_ACTIONS, n_atoms))(y)
    y = layers.Softmax(axis=-1)(y)
    return _compile_model(x, y, keras_params, 'categorical_crossentropy')


def _sfdqn_model(x, output_activation=None):
    """Build an SFDQN head on input tensor ``x``, using the ``SFDQN`` config section.

    Output per sample: ``(N_ACTIONS, n_features)`` with ``n_features = len(all_goals) + 1``;
    trained with mean-squared error.
    """
    keras_params = sfdqn_params['keras_params']
    n_features = len(all_goals) + 1
    y = _hidden_layers(x, keras_params)
    y = layers.Dense(N_ACTIONS * n_features, activation=output_activation)(y)
    y = layers.Reshape((N_ACTIONS, n_features))(y)
    return _compile_model(x, y, keras_params, 'mse')


def sfc51dqn_model_lambda(x):
    """Build the distributional (C51) successor-feature network on input tensor ``x``.

    Output per sample: ``(N_ACTIONS, n_features, n_atoms)`` with
    ``n_features = len(all_goals) + 1``. The softmax over the last axis makes each
    ``(action, feature)`` slot a categorical distribution over atoms.
    Uses the ``SFC51DQN`` config section.
    """
    keras_params = sfc51_params['keras_params']
    n_features = len(all_goals) + 1
    n_atoms = sfc51_params['n_atoms']
    y = _hidden_layers(x, keras_params)
    y = layers.Dense(N_ACTIONS * n_features * n_atoms)(y)
    y = layers.Reshape((N_ACTIONS, n_features, n_atoms))(y)
    y = layers.Softmax(axis=-1)(y)
    return _compile_model(x, y, keras_params, 'categorical_crossentropy')


def c51dqn_model_lambda():
    """Build the C51 network with a 6-wide input; used by the ``'uvfa'`` agent.

    The ``'uvfa'`` tasks are generated with ``include_target_in_state=True``, so the extra
    inputs relative to ``c51_base_model_lambda`` are presumably the target coordinates.
    """
    return _c51_model(6)


def c51_base_model_lambda():
    """Build the C51 network with a 4-wide input; used by the ``'base'`` agent."""
    return _c51_model(4)


def sfdqn_model_lambda(x):
    """Build the SFDQN successor-feature (psi) network on ``x``; linear output."""
    return _sfdqn_model(x)


def sfdqn_sigma_model_lambda(x):
    """Build the SFDQN ``Sigma`` network on ``x``.

    The softplus output keeps every entry strictly positive, presumably because
    these are variance-like terms -- TODO(review): confirm against KerasMVSF.
    """
    return _sfdqn_model(x, output_activation='softplus')


# training
_AGENT_NAMES = ('sfc51', 'sfdqn', 'base', 'uvfa')


def _build_agent(agent_name, penalty, method):
    """Construct the agent selected by ``agent_name``.

    Args:
        agent_name: one of ``_AGENT_NAMES``.
        penalty: passed as ``risk_aversion`` to ``C51SF`` (sfc51), ``KerasMVSF`` (sfdqn)
            or ``C51DQN`` (uvfa); not used by ``base``.
        method: passed to ``C51SF`` (sfc51 only).

    Returns:
        ``(agent, include_target_in_state, plot_fn)``: the agent, the flag forwarded to
        ``generate_tasks``, and the function passed to ``agent.train`` as ``plot_var``.

    Raises:
        ValueError: if ``agent_name`` is not recognised.
    """
    if agent_name == 'sfc51':
        sf = C51SF(model_lambda=sfc51dqn_model_lambda,
                   risk_aversion=penalty, method=method, **sfc51_params)
        agent = SFDQN(deep_sf=sf, buffer=ReplayBuffer(**sfc51_params['buffer_params']),
                      **agent_params, **sfc51_params)
        return agent, False, plot_covariance_c51
    if agent_name == 'sfdqn':
        sf = KerasMVSF(keras_psi_model_handle=sfdqn_model_lambda,
                       keras_Sigma_model_handle=sfdqn_sigma_model_lambda,
                       risk_aversion=penalty, **sfdqn_params)
        agent = SFDQN(deep_sf=sf, buffer=ReplayBuffer(**sfdqn_params['buffer_params']),
                      **agent_params, **sfdqn_params)
        return agent, False, plot_covariance_dqn
    if agent_name == 'base':
        agent = BaseC51(model_lambda=c51_base_model_lambda,
                        buffer=ReplayBuffer(**c51_params['buffer_params']),
                        **agent_params, **c51_params)
        return agent, False, plot_covariance_c51
    if agent_name == 'uvfa':
        agent = C51DQN(model_lambda=c51dqn_model_lambda,
                       buffer=ReplayBuffer(**c51_params['buffer_params']),
                       risk_aversion=penalty, **agent_params, **c51_params)
        # uvfa is the only agent whose tasks include the target in the state
        return agent, True, plot_covariance_c51
    raise ValueError('unknown agent_name {!r}; expected one of {}'.format(
        agent_name, _AGENT_NAMES))


def train_agents(agent_name, penalty, method):
    """Train one agent over ``n_trials`` independent trials and save the results.

    Each trial builds fresh tasks and calls ``agent.train``. The per-task histories the
    agent exposes afterwards are folded into running mean/variance statistics.

    Output files, written to the current directory:
      * ``reacher_<key>_<method>_<penalty>_<timestamp>_<stat>.csv``: per statistic,
        the mean columns followed by the standard-error columns.
      * ``reacher_<i>_<key>_<method>_<penalty>_rollouts.csv``: for every agent except
        ``base``, rollouts of the final trial's tasks (one file per task).

    Args:
        agent_name: one of ``'sfc51'``, ``'sfdqn'``, ``'base'``, ``'uvfa'``.
        penalty: risk-aversion coefficient (not used by ``base``).
        method: method string forwarded to ``C51SF`` (used by ``sfc51`` only, but always
            part of the output file names).

    Raises:
        ValueError: if ``agent_name`` is unknown or ``n_trials`` is less than 1.
    """
    param_selection = str(penalty).replace('.', '')
    method_name = method.replace('-', '')

    # Explicit checks rather than `assert`, which is stripped under `python -O`.
    if agent_name not in _AGENT_NAMES:
        raise ValueError('unknown agent_name {!r}; expected one of {}'.format(
            agent_name, _AGENT_NAMES))
    # Without at least one trial there are no statistics to save and no tasks to roll out.
    if n_trials < 1:
        raise ValueError('n_trials must be at least 1, got {}'.format(n_trials))

    agent, include_target_in_state, plotc = _build_agent(agent_name, penalty, method)
    print('summary:\nagent = {}\npenalty = {}\nmethod = {}\n'.format(agent.key, penalty, method))

    # (output-file suffix, agent attribute holding that per-task history after train())
    tracked = (
        ('train_return', 'episode_reward_hist_per_task'),
        ('train_fails', 'episode_fails_hist_per_task'),
        ('test_return', 'test_reward_hist_per_task'),
        ('test_return_var', 'test_reward_var_hist_per_task'),
        ('test_fails', 'test_fails_hist_per_task'),
    )
    stats = [OnlineMeanVariance() for _ in tracked]

    # training
    for _ in range(n_trials):
        train_tasks, test_tasks = generate_tasks(include_target_in_state)
        agent.train(train_tasks, n_samples, test_tasks=test_tasks, plot_var=plotc)
        for (_name, attr), stat in zip(tracked, stats):
            stat.update(getattr(agent, attr))

    # save mean performance
    label = 'reacher_{}_{}_{}_{}_'.format(agent.key, method_name, param_selection, stamp.get_timestamp())
    for (data_name, _attr), data in zip(tracked, stats):
        all_curves = np.column_stack([data.mean.T, data.calculate_standard_error().T])
        np.savetxt(label + data_name + '.csv', all_curves, delimiter=',')

    # save rollouts, using the task instances from the final trial
    if agent_name in ['sfc51', 'sfdqn', 'uvfa']:
        rollouts = agent.test_agent_rollouts(train_tasks + test_tasks, n_rollouts=20)
        for i, task_rollouts in enumerate(rollouts):
            np.savetxt('reacher_{}_{}_{}_{}_rollouts.csv'.format(i, agent.key, method_name, param_selection),
                       np.vstack(task_rollouts), delimiter=',')
        

if __name__ == "__main__":
    # Command line: [agent_name [penalty [method]]]. Each omitted argument takes its
    # default independently. The old `len(sys.argv) < 4` check silently discarded any
    # arguments the user did pass whenever fewer than three were given.
    import argparse

    parser = argparse.ArgumentParser(
        description='Train an agent on the Reacher tasks configured in reacher.cfg.')
    parser.add_argument('agent_name', nargs='?', default='sfc51',
                        choices=['sfc51', 'sfdqn', 'base', 'uvfa'],
                        help='agent to train (default: %(default)s)')
    parser.add_argument('penalty', nargs='?', type=float, default=2.0,
                        help='risk-aversion penalty (default: %(default)s)')
    parser.add_argument('method', nargs='?', default='gauss',
                        help='method passed to the sfc51 agent (default: %(default)s)')
    # parse_known_args: like the original, ignore any extra trailing arguments
    cli_args, _unused = parser.parse_known_args()
    train_agents(cli_args.agent_name, cli_args.penalty, cli_args.method)
    
