# -*- coding: UTF-8 -*-  
import matplotlib.pyplot as plt
import numpy as np

import tensorflow as tf
# Turn off eager execution through the TF1 compatibility API.
tf.compat.v1.disable_eager_execution()
# Enable memory growth on every physical GPU that TensorFlow reports.
# NOTE(review): presumably so TF does not reserve all GPU memory up front -
# confirm against the TensorFlow documentation.
for gpu in tf.config.experimental.list_physical_devices('GPU'):
    tf.config.experimental.set_memory_growth(gpu, True)
from tensorflow.keras import layers, Model, optimizers

from agents.dqn import DQN
from agents.dqn_gpi import DQN_GPI
from agents.sfdqn import SFDQN
from agents.buffer import ReplayBuffer
from features.deep import DeepSF
from tasks.reacher import Reacher
from utils.config import parse_config_file

import copy
import argparse

# Command-line interface: every option has a default, so the script can be
# launched without arguments.
parser = argparse.ArgumentParser(description='Succesor Feature Deep Q-learning')
parser.add_argument(
    '--test_task_idx',
    default=11,
    type=int,
    help='Traget task used for transfer',
)
parser.add_argument(
    '--gamma',
    default=0.8,
    type=float,
    help='discount factor to be used',
)
parser.add_argument(
    '--n_samples',
    default=100000,
    type=int,
    help='number of training samples(steps) per task',
)
# Passed as the text "True"/"False"; it is converted to a bool after parsing.
parser.add_argument(
    '--include_target_dqn',
    default="False",
    type=str,
    help='whether or not to use UVFA for DQN',
)
parser.add_argument(
    '--train_task_idxs',
    default='',
    type=str,
    help='list of custom train task indeces to train on before transfering to test task',
)
parser.add_argument(
    '--imbalance_task_idx',
    default=None,
    type=str,
    help='test task that has imbalanced reward function (compared to trianing tasks)',
)
parser.add_argument(
    '--imbalance_task_scale',
    default=1,
    type=float,
    help='scaling of imbalanced reward function (compared to trianing tasks)',
)

args = parser.parse_args()

# Ablation parameters supplied on the command line.
test_task_idx = args.test_task_idx
gamma = args.gamma
n_samples = args.n_samples
imbalance_task_idx = args.imbalance_task_idx
imbalance_task_scale = args.imbalance_task_scale

# --include_target_dqn arrives as text; only the exact spellings "True" and
# "False" are accepted and mapped to the corresponding bool.
include_target_dqn = args.include_target_dqn
if include_target_dqn not in ("True", "False"):
    raise ValueError(f'Unsupported input {include_target_dqn} for include_target_dqn (only "True"/"False" allowed)')
include_target_dqn = (include_target_dqn == "True")

# DEBUG
if include_target_dqn:
    print("Error")

# Echo the effective experiment settings, one "name: value" line each.
_summary = (
    ("test_task_idx", test_task_idx),
    ("gamma", gamma),
    ("n_samples", n_samples),
    ("include_target_dqn", include_target_dqn),
    ("imbalance_task_idx", imbalance_task_idx),
    ("imbalance_task_scale", imbalance_task_scale),
)
print('\n==========Importnant Experiment Parameters===========')
print('\n'.join(f"{name}: {setting}" for name, setting in _summary))
print('=====================================================\n')

# Read the experiment configuration from 'reacher.cfg' (resolved by
# parse_config_file; NOTE(review): presumably relative to the working
# directory - confirm in utils.config).
config_params = parse_config_file('reacher.cfg')

# Override the configured discount factor and sample budget with the values
# given on the command line.
config_params['AGENT']['gamma'] = gamma
config_params['GENERAL']['n_samples'] = n_samples

# n_samples is read back from the config, i.e. it carries the value that was
# just written above.
gen_params = config_params['GENERAL']
n_samples = gen_params['n_samples']

# Goal lists: training targets followed by test targets. The concatenated
# list is what Reacher receives in generate_tasks together with a task index.
task_params = config_params['TASK']
goals = task_params['train_targets']
test_goals = task_params['test_targets']
all_goals = goals + test_goals

# Per-component parameter sections used when building the agents.
agent_params = config_params['AGENT']
dqn_params = config_params['DQN']
sfdqn_params = config_params['SFDQN']

# SFDQN agent without GPI: an independent deep copy of the SFDQN section, so
# setting 'use_gpi' here does not alter sfdqn_params.
sfdqn_wo_gpi_params = copy.deepcopy(config_params['SFDQN'])
sfdqn_wo_gpi_params['use_gpi'] =  False



# tasks
def generate_tasks(include_target):
    """Create the task lists handed to an agent's ``train`` call.

    Each list holds two ``Reacher`` tasks, in this order: the first training
    goal (index 0 of ``goals``) and the transfer target selected by the
    module-level ``test_task_idx``. Both indices refer to ``all_goals``.

    Args:
        include_target: forwarded unchanged to every ``Reacher``.

    Returns:
        ``(train_tasks, test_tasks)``: two lists built from the same indices
        but made of separate ``Reacher`` instances.
    """
    # Indexing the range (rather than writing 0) still fails with an
    # IndexError when there are no training goals.
    first_source_idx = range(len(goals))[0]
    task_idxs = [first_source_idx, test_task_idx]

    def instantiate():
        return [
            Reacher(all_goals, idx, include_target, imbalance_task_idx, imbalance_task_scale)
            for idx in task_idxs
        ]

    # Training tasks are constructed before the evaluation tasks.
    train_tasks = instantiate()
    test_tasks = instantiate()
    return train_tasks, test_tasks


# keras model
def dqn_model_lambda(include_target):
    """Build and compile the Keras network used by the DQN agents.

    The hidden stack comes from ``dqn_params['keras_params']`` (paired
    ``n_neurons`` / ``activations`` entries). The output layer has 9 linear
    units and the model is compiled with Adam and an MSE loss.

    Args:
        include_target: when truthy the network takes a 6-wide input,
            otherwise a 4-wide one, so the DQN can run without goal
            locations as input features.

    Returns:
        The compiled Keras ``Model``.
    """
    keras_params = dqn_params['keras_params']

    input_width = 6 if include_target else 4
    inputs = layers.Input(input_width)

    hidden = inputs
    layer_specs = zip(keras_params['n_neurons'], keras_params['activations'])
    for width, activation in layer_specs:
        hidden = layers.Dense(width, activation=activation)(hidden)

    # NOTE(review): 9 outputs presumably correspond to the action count - confirm.
    q_values = layers.Dense(9, activation='linear')(hidden)

    model = Model(inputs=inputs, outputs=q_values)
    optimizer = optimizers.Adam(learning_rate=keras_params['learning_rate'])
    model.compile(optimizer, 'mse')
    return model


# keras model for the SF
def sf_model_lambda(x):
    """Build and compile the Keras network for the successor features.

    The hidden stack comes from ``sfdqn_params['keras_params']``. The final
    dense layer is reshaped to ``(9, n_features)``, where ``n_features`` is
    the number of entries in ``all_goals``.

    Args:
        x: the Keras input the network is built on; it is also used as the
            model's ``inputs``.

    Returns:
        The Keras ``Model`` compiled with Adam and an MSE loss.
    """
    n_features = len(all_goals)
    keras_params = sfdqn_params['keras_params']

    hidden = x
    layer_specs = zip(keras_params['n_neurons'], keras_params['activations'])
    for width, activation in layer_specs:
        hidden = layers.Dense(width, activation=activation)(hidden)

    # One flat linear layer, then split into a 9 x n_features block.
    flat = layers.Dense(9 * n_features, activation='linear')(hidden)
    successor_features = layers.Reshape((9, n_features))(flat)

    model = Model(inputs=x, outputs=successor_features)
    optimizer = optimizers.Adam(learning_rate=keras_params['learning_rate'])
    model.compile(optimizer, 'mse')
    return model


def train():
    """Train the three agents one after another and save their results.

    Agents, in order: SFDQN, SFDQN built from ``sfdqn_wo_gpi_params``
    (``use_gpi`` set to False), and DQN with GPI. Each agent is trained on
    task objects generated freshly for it.

    The value returned by each ``train`` call is written with ``np.save``
    (which appends ``.npy``) under a name encoding the run settings, relative
    to the current working directory.

    NOTE: ``plot()`` reads these files from ``../data/perfs/``, so they have
    to be placed there before plotting.
    """

    def build_sfdqn(params):
        # The successor-feature model is created first, then the agent around it.
        deep_sf = DeepSF(keras_model_handle=sf_model_lambda, **params)
        return SFDQN(deep_sf=deep_sf, buffer=ReplayBuffer(params['buffer_params']),
                     **params, **agent_params)

    def run(agent, include_target):
        # Every agent gets its own task instances.
        train_tasks, test_tasks = generate_tasks(include_target)
        return agent.train(train_tasks, n_samples, test_tasks=test_tasks,
                           n_test_ev=agent_params['n_test_ev'], test_task_idx=-1)

    # SFDQN
    print('building SFDQN')
    sfdqn = build_sfdqn(sfdqn_params)
    print('training SFDQN')
    sfdqn_perf = run(sfdqn, False)

    # SFDQN with GPI switched off
    print('building SFDQN (wo GPI enabled)')
    sfdqn_wo_gpi = build_sfdqn(sfdqn_wo_gpi_params)
    print('training SFDQN (wo GPI enabled)')
    sfdqn_wo_gpi_perf = run(sfdqn_wo_gpi, False)

    # DQN with GPI
    print('building DQN with GPI')
    dqn_gpi = DQN_GPI(model_lambda=dqn_model_lambda,
                      buffer=ReplayBuffer(dqn_params['buffer_params']),
                      include_target=include_target_dqn,
                      **dqn_params, **agent_params)
    print('training DQN with GPI')
    dqn_gpi_perf = run(dqn_gpi, include_target_dqn)

    # Persist the results; the filename fragments are shared between outputs.
    run_tag = f'{test_task_idx}_gamma-{gamma}_n_samples-{n_samples}'
    imbalance_tag = f'imbalance_task_scale-{imbalance_task_scale}-new'
    np.save(f'sfdqn_perf_{run_tag}_{imbalance_tag}', sfdqn_perf)
    np.save(f'dqn_perf_{run_tag}_include_target_dqn-{include_target_dqn}-new', dqn_gpi_perf)
    np.save(f'sfdqn_wo_gpi_perf_{run_tag}_{imbalance_tag}', sfdqn_wo_gpi_perf)

def plot():
    """Draw the transfer-performance figures from saved results.

    Two PDFs are written under ``figures/paper/``:

    1. the curves of the task selected by ``test_task_idx``;
    2. the element-wise average of the curves of tasks 4, 7, 9 and 10.

    Curves are read with ``np.load`` from ``../data/perfs/``, passed through
    a 20-point box filter and stripped of their last 10 samples before being
    plotted.
    """

    def smooth(y, box_pts):
        # Box filter implemented as a convolution with a uniform kernel.
        return np.convolve(y, np.ones(box_pts) / box_pts, mode='same')

    def load_smoothed(task_idx):
        # Returns (sfdqn, dqn_gpi, sfdqn_wo_gpi) for one task; all three files
        # are loaded before any smoothing happens.
        run_tag = f'{task_idx}_gamma-{gamma}_n_samples-{n_samples}'
        imbalance_tag = f'imbalance_task_scale-{imbalance_task_scale}-new.npy'
        raw_curves = (
            np.load(f'../data/perfs/sfdqn_perf_{run_tag}_{imbalance_tag}'),
            np.load(f'../data/perfs/dqn_perf_{run_tag}_include_target_dqn-{include_target_dqn}-new.npy'),
            np.load(f'../data/perfs/sfdqn_wo_gpi_perf_{run_tag}_{imbalance_tag}'),
        )
        return tuple(smooth(curve, 20)[:-10] for curve in raw_curves)

    def render(sfdqn_curve, dqn_gpi_curve, sfdqn_wo_gpi_curve, target_label, task_tag, suffix):
        # The x axis spans the two training phases (source task, then target).
        x = np.linspace(0, 2, sfdqn_curve.size)

        # reporting progress
        ticksize = 20
        textsize = 25
        rc_settings = (
            ('font', {'size': textsize}),         # default text sizes
            ('axes', {'titlesize': textsize}),    # axes title
            ('axes', {'labelsize': textsize}),    # x and y labels
            ('xtick', {'labelsize': ticksize}),   # x tick labels
            ('ytick', {'labelsize': ticksize}),   # y tick labels
            ('legend', {'fontsize': ticksize}),   # legend
        )
        for group, options in rc_settings:
            plt.rc(group, **options)

        plt.figure(figsize=(8, 6))
        ax = plt.gca()
        # Draw order fixes both the legend order and the default colour of the
        # unstyled middle curve.
        ax.plot(x, sfdqn_curve, color='#386cb0', linewidth=2, label='SFDQN (GPI)')
        ax.plot(x, sfdqn_wo_gpi_curve, label='SFDQN_wo_GPI')
        ax.plot(x, dqn_gpi_curve, color='#33a02c', linewidth=2, label='DQN (GPI)')
        ax.set_xticks([1, 2], labels=['Src1', target_label])

        # Axis labels are only drawn for particular values of test_task_idx.
        # NOTE(review): presumably the outer panels of a figure grid - confirm.
        if test_task_idx in (8, 9, 10, 11):
            plt.xlabel('training task')
        if test_task_idx in (4, 8):
            plt.ylabel('target task reward')

        plt.tight_layout()
        plt.legend(frameon=False)
        plt.savefig(
            f'figures/paper/sfdqn_dqn_gpi_return_new_task-{task_tag}_gamma-{gamma}'
            f'_n_samples-{n_samples}_imbalance_task_scale-{imbalance_task_scale}'
            f'_include_target_dqn-{include_target_dqn}-{suffix}.pdf',
            format="pdf", bbox_inches="tight")

    # Figure 1: the task chosen on the command line.
    sfdqn_perf, dqn_gpi_perf, sfdqn_wo_gpi_perf = load_smoothed(test_task_idx)
    render(sfdqn_perf, dqn_gpi_perf, sfdqn_wo_gpi_perf,
           target_label=f'Trg{test_task_idx +1}', task_tag=test_task_idx, suffix='new')

    # Figure 2: average over a fixed set of target tasks.
    averaged_task_idxs = [4, 7, 9, 10]
    totals = [0, 0, 0]
    for task_idx in averaged_task_idxs:
        for slot, curve in enumerate(load_smoothed(task_idx)):
            totals[slot] += curve
    sfdqn_perf_avg, dqn_gpi_perf_avg, sfdqn_wo_gpi_perf_avg = (
        total / len(averaged_task_idxs) for total in totals
    )
    render(sfdqn_perf_avg, dqn_gpi_perf_avg, sfdqn_wo_gpi_perf_avg,
           target_label='Trg', task_tag='avg', suffix='avg')


# Script entry: training is currently switched off (commented out), so running
# the file only regenerates the figures via plot(), which loads previously
# saved result files.
# train()
plot()
