import numpy as np


def value_to_label(v):
    """Map a number to a one-character sign label.

    Returns '+' when ``v > 0``, '-' when ``v < 0``, and '0' otherwise.
    "Otherwise" includes zero and any value for which both comparisons
    are False, such as NaN.
    """
    return '+' if v > 0 else ('-' if v < 0 else '0')


def save_task_behavior(agent, w, risk, trial):
    """Roll out ``agent`` on task weights ``w`` and write the result to CSV.

    The first element returned by ``agent.rollouts(w, 100)`` is saved with
    ``np.savetxt`` to ``shapes_{trial}_{agent.key}_{risk}_{label}.csv``.
    ``label`` is the sign pattern ('+', '-', '0') of the first three entries
    of the flattened ``w``.

    NOTE(review): ``save_task_values`` in this module builds the same file
    name from the same arguments, so calling both for one task overwrites
    whichever ran first. Confirm whether that is intended.
    """
    # Presumably per-state visit counts; the meaning of the 100 argument
    # is defined by agent.rollouts. TODO confirm.
    rollout_output = agent.rollouts(w, 100)
    visits = rollout_output[0]

    sign_label = ''.join(value_to_label(weight) for weight in w.flatten()[:3])
    out_path = 'shapes_{}_{}_{}_{}.csv'.format(trial, agent.key, risk, sign_label)
    np.savetxt(out_path, visits, delimiter=',')


def all_states(agent):
    """Return the set of every key found in any table of ``agent.sf.psi``.

    ``agent.sf.psi`` is iterated as a sequence of mappings, and the union of
    their keys is returned as a ``set``.
    """
    return {state for table in agent.sf.psi for state in table.keys()}
    
def save_task_values(agent, w, risk, trial):
    """Build a 13x13 value heat map for task ``w``, save it as CSV and show it.

    For each grid cell (y, x), the value is the mean over all states ``s``
    with ``s[0] == (y, x)`` of ``np.max(agent.sf.GPI_w(s, w)[0])``. States are
    taken from ``all_states(agent)``. Cells with no matching state are NaN.

    The grid is written with ``np.savetxt`` to
    ``shapes_{trial}_{agent.key}_{risk}_{label}.csv``. ``label`` is the sign
    pattern ('+', '-', '0') of the first three entries of the flattened
    ``w``. The grid is then displayed with ``plt.imshow`` / ``plt.show``.

    NOTE(review): ``save_task_behavior`` in this module builds the same file
    name from the same arguments, so one output overwrites the other.
    Consider a distinct prefix once downstream readers are checked.
    NOTE(review): ``plt.show()`` typically blocks in interactive backends,
    which may be unwanted in batch runs.
    """
    grid_size = 13

    # Group states by grid cell once. Rescanning every state for each of the
    # grid_size**2 cells is O(cells * states). Within a cell, states keep
    # their set-iteration order, so GPI_w is called in the original order.
    states_by_cell = {}
    for s in all_states(agent):
        states_by_cell.setdefault(s[0], []).append(s)

    # Cells without any state stay NaN. This is the value np.mean([]) would
    # give, but without its "Mean of empty slice" RuntimeWarning.
    heat = np.full((grid_size, grid_size), np.nan)
    for y in range(grid_size):
        for x in range(grid_size):
            cell_states = states_by_cell.get((y, x))
            if cell_states:
                heat[y, x] = np.mean(
                    [np.max(agent.sf.GPI_w(s, w)[0]) for s in cell_states]
                )

    label = ''.join(value_to_label(v) for v in w.flatten()[:3])
    file_name = 'shapes_{}_{}_{}_{}.csv'.format(trial, agent.key, risk, label)
    np.savetxt(file_name, heat, delimiter=',')

    # Imported lazily, after the CSV is written, so a missing matplotlib
    # does not prevent the data from being saved.
    import matplotlib.pyplot as plt
    plt.imshow(heat)
    plt.show()