import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import sys

from agents.sfql import SFQL
from agents.prql import PRQL
from features.tabular import TabularMVSF

from plot_shapes_covariance import save_task_behavior, save_task_values
from tasks.gridworld import Shapes
from utils import stamp
from utils.config import parse_config_file
from utils.stats import OnlineMeanVariance

# general training params
config_params = parse_config_file('gridworld.cfg')

# [GENERAL]: samples handed to each agent's train_on_task call, and the number
# of consecutive weight rows (tasks) that make up one trial
gen_params = config_params['GENERAL']
n_samples, n_tasks = gen_params['n_samples'], gen_params['n_tasks']

# [TASK]: maze layout and failure settings shared by every Shapes task
task_params = config_params['TASK']

# keyword arguments for the agent constructors: [AGENT] goes to both agent
# types, [SFQL] to the TabularMVSF features, [SMART] to PRQL
agent_params = config_params['AGENT']
sfql_params = config_params['SFQL']
rsql_params = config_params['SMART']

# [PRQL]: one PRQL agent is built per (eta, tau) combination of these values
prql_params = config_params['PRQL']
etas, taus = prql_params['etas'], prql_params['taus']


# tasks
def load_tasks(trial_num):
    """Build the list of Shapes tasks used in one trial.

    Reads reward weights from ``shapes_weights.csv`` and keeps the ``n_tasks``
    consecutive rows belonging to ``trial_num``, i.e. rows
    ``n_tasks * trial_num`` up to (but excluding) ``n_tasks * (trial_num + 1)``.
    The leading values of each row are mapped to the shape keys ``'1'``,
    ``'2'`` and ``'3'`` (extra columns are ignored by ``zip``). The maze,
    failure probability and failure reward come from the ``[TASK]`` config
    section.

    Args:
        trial_num: zero-based index of the trial whose tasks to load.

    Returns:
        list of ``Shapes`` tasks, one per selected weight row, in file order.

    Raises:
        ValueError: if ``trial_num`` is negative, or if the file does not hold
            ``n_tasks`` rows for the requested trial.
    """
    if trial_num < 0:
        # A negative start index would silently wrap to rows at the end of the file.
        raise ValueError('trial_num must be non-negative, got {}'.format(trial_num))

    # ndmin=2 keeps a single-row file two-dimensional so the row slice below works.
    weights = np.loadtxt('shapes_weights.csv', delimiter=',', ndmin=2)
    istart = n_tasks * trial_num
    iend = istart + n_tasks
    weights = weights[istart:iend, :]

    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if weights.shape[0] != n_tasks:
        raise ValueError(
            'shapes_weights.csv has {} weight rows for trial {}; expected {}'.format(
                weights.shape[0], trial_num, n_tasks))

    rewards = [dict(zip(['1', '2', '3'], list(w.flatten()))) for w in weights]
    tasks = [Shapes(maze=np.array(task_params['maze']),
                    shape_rewards=reward,
                    fail_prob=task_params['fail_prob'],
                    fail_reward=task_params['fail_reward']) for reward in rewards]
    return tasks


# training
def train_agents(trial_num, agent_name, penalty):
    """Train one agent family on the tasks of a trial and save its learning curves.

    Args:
        trial_num: zero-based trial index. It selects the block of task weights
            read by ``load_tasks`` and is embedded in the output file names.
        agent_name: ``'sfql'`` trains a single SFQL agent. ``'prql'`` trains one
            PRQL agent per ``(eta, tau)`` combination from the ``[PRQL`` config.
        penalty: passed as ``risk_aversion`` to ``TabularMVSF`` (SFQL) or as
            ``omega`` to ``PRQL``.

    Raises:
        ValueError: if ``agent_name`` is not ``'sfql'`` or ``'prql'``, or if the
            PRQL configuration yields no agents (empty ``etas`` or ``taus``).

    Side effects:
        Writes four CSV files named
        ``shapes_<key>_<penalty>_<trial>_<timestamp>_<metric>.csv`` to the
        current directory. Each holds the ``mean`` and the standard error of one
        statistics tracker, stacked as columns.
    """
    # e.g. 2.0 -> '20'; used only to tag the output file names
    param_selection = str(penalty).replace('.', '')

    # build agent(s); validated explicitly because `assert` vanishes under `python -O`
    if agent_name == 'sfql':
        agents = [SFQL(TabularMVSF(**sfql_params, risk_aversion=penalty), **agent_params)]
    elif agent_name == 'prql':
        # one PRQL agent per (eta, tau) pair, eta-major order
        agents = [PRQL(**agent_params, **rsql_params, omega=penalty, eta=eta, tau=tau)
                  for eta in etas for tau in taus]
        if not agents:
            raise ValueError('PRQL config must provide non-empty etas and taus')
    else:
        raise ValueError(
            "agent_name must be 'sfql' or 'prql', got {!r}".format(agent_name))

    # data
    data_task_return = OnlineMeanVariance()
    data_cuml_return = OnlineMeanVariance()
    data_task_failed = OnlineMeanVariance()
    data_cuml_failed = OnlineMeanVariance()
    all_data = [data_task_return, data_cuml_return, data_task_failed, data_cuml_failed]
    data_names = ['task_reward', 'cuml_reward', 'task_failed', 'cuml_failed']

    # training: every agent sees every task, task by task, in file order
    for agent in agents:
        agent.reset()
    tasks = load_tasks(trial_num)
    for itask, task in enumerate(tasks):
        print('\ntask {}'.format(itask))
        for agent in agents:
            print()
            agent.train_on_task(task, n_samples)

    # this is for generating the visitation counts
    # l, h = -1., 1.
    # for w1 in [l, 0., h]:
    #     for w2 in [l, 0., h]:
    #         for w3 in [l, 0., h]:
    #             w = np.array([w1, w2, w3, 1., -2.]).reshape((-1, 1))
    #             save_task_values(agent, w, param_selection, trial_num)

    # update performance statistics: one column per agent
    data_task_return.update(np.column_stack([agent.reward_hist for agent in agents]))
    data_cuml_return.update(np.column_stack([agent.cum_reward_hist for agent in agents]))
    data_task_failed.update(np.column_stack([agent.fails_hist for agent in agents]))
    data_cuml_failed.update(np.column_stack([agent.cum_fails_hist for agent in agents]))

    # save mean performance; the label uses the key of the last agent built
    # (the only one for SFQL), made explicit rather than relying on a leaked loop variable
    label = 'shapes_{}_{}_{}_{}_'.format(
        agents[-1].key, param_selection, trial_num, stamp.get_timestamp())
    for data, data_name in zip(all_data, data_names):
        all_curves = np.column_stack([data.mean, data.calculate_standard_error()])
        np.savetxt(label + data_name + '.csv', all_curves, delimiter=',')
        

if __name__ == "__main__":
    # Usage: python <script> [trial_num] [agent_name] [penalty]
    # Each positional argument falls back to its own default when omitted, so a
    # partial command line (e.g. only a trial number) is honoured instead of
    # being silently discarded. Arguments beyond the third are ignored.
    cli_args = sys.argv[1:]
    trial_num = int(cli_args[0]) if len(cli_args) > 0 else 0
    agent_name = cli_args[1] if len(cli_args) > 1 else 'sfql'
    penalty = float(cli_args[2]) if len(cli_args) > 2 else 2.0
    train_agents(trial_num, agent_name, penalty)
    
