'''
Use case for reward time-and-magnitude maps.
Each reward time-and-magnitude map describes the reward that will be given
in each state.
We assume here that you can reach any state from any other state with no travel time.
'''

import numpy as np
import matplotlib.pyplot as plt
import os
from scipy.stats import gaussian_kde




plot_folder = ''
experiment_name = 'gridworld/'
# figures are written to saveFolder + <file name>, so the trailing '/' matters
saveFolder = plot_folder + experiment_name

# make sure the output folder exists; announce when existing files may be overwritten
if not os.path.exists(saveFolder):
    os.makedirs(saveFolder)
else:
    print('overwrite')




def random_position(map_size, prev_pos):
    """Draw a uniformly random grid cell that is not already in ``prev_pos``.

    Cells are sampled with ``np.random.randint`` (x first, then y) until an
    unused one is found. The loop is iterative, so nearly-full maps do not hit
    the recursion limit.

    Args:
        map_size: side length of the square grid; coordinates are in [0, map_size).
        prev_pos: iterable of (x, y) tuples that must not be returned.

    Returns:
        A new (x, y) tuple.

    Raises:
        ValueError: if every cell of the grid is already in ``prev_pos``.
    """
    taken = set(prev_pos)
    occupied = sum(1 for (px, py) in taken if 0 <= px < map_size and 0 <= py < map_size)
    if map_size * map_size - occupied <= 0:
        raise ValueError('no free position left on a %d x %d map' % (map_size, map_size))
    while True:
        x = np.random.randint(0, map_size)
        y = np.random.randint(0, map_size)
        if (x, y) not in taken:
            return (x, y)


def generate_grid_reward_magnitude_time_matrices(num_cues, map_size, max_reward_delay, max_reward_magnitude):
    """Randomly place one rewarded cell per cue and give it a small delay/magnitude distribution.

    For each cue, a cell (x, y) not used by an earlier cue is drawn with
    ``random_position``. Two (delay, magnitude) index pairs are then drawn
    uniformly from ``[1, max_reward_delay) x [1, max_reward_magnitude)`` and
    marked with 1. The cue's whole matrix is then normalised to sum to 1, so a
    duplicate draw leaves a single entry equal to 1.

    Returns:
        (matrices, positions): ``matrices`` is indexed
        [cue, x, y, delay, magnitude]; ``positions`` lists the chosen (x, y)
        cells in cue order.
    """
    shape = (num_cues, map_size, map_size, max_reward_delay, max_reward_magnitude)
    matrices = np.zeros(shape)
    positions = []
    draws_per_cue = 2  # number of (delay, magnitude) outcomes marked per cue
    for cue in range(num_cues):
        x, y = random_position(map_size, positions)
        print('got x y', x, y)
        positions.append((x, y))
        cue_matrix = matrices[cue]  # view: edits below write into `matrices`
        for _ in range(draws_per_cue):
            delay_idx = np.random.randint(1, max_reward_delay)
            mag_idx = np.random.randint(1, max_reward_magnitude)
            cue_matrix[x, y, delay_idx, mag_idx] = 1
        cue_matrix /= cue_matrix.sum()
    return matrices, positions


def generate_discount_weights(max_reward_delay, max_reward_magnitude, s_exp, k_hyper, c_scale):
    """Build hyperbolic discount weights over a (delay x magnitude-index) grid.

    For delay index ``i`` and magnitude index ``j``, with ``m = max_reward_magnitude``:

    * joint:      1 / (1 + k * (i + (m - j) * c)) ** s
    * factorized: 1 / (1 + k * i) ** s  *  1 / (1 + k * (m - j) * c) ** s

    Returns:
        (weights_2D, factorized_weights), both of shape
        (max_reward_delay, max_reward_magnitude).
    """
    shape = (max_reward_delay, max_reward_magnitude)
    weights_2D = np.zeros(shape)
    factorized_weights = np.zeros(shape)
    for delay in range(max_reward_delay):
        time_discount = 1.0 / (1 + k_hyper * delay) ** s_exp
        for mag_idx in range(max_reward_magnitude):
            joint_arg = delay + (max_reward_magnitude - mag_idx) * c_scale
            weights_2D[delay, mag_idx] = 1.0 / (1 + k_hyper * joint_arg) ** s_exp
            mag_discount = 1.0 / (1 + k_hyper * (max_reward_magnitude - mag_idx) * c_scale) ** s_exp
            factorized_weights[delay, mag_idx] = time_discount * mag_discount
    return weights_2D, factorized_weights


def generate_preset_grid_reward_magnitude_time_matrices(num_cues, map_size, max_reward_delay, max_reward_magnitude):
    """Return a fixed two-cue reward layout (deterministic counterpart of the random generator).

    Cue 0 pays magnitude index 1 at cell (0, 0) with delay index 0. Cue 1 pays
    magnitude index 3 at cell (2, 2) with delay index 2. The arguments must be
    large enough to hold these indices (num_cues >= 2, map_size >= 3,
    max_reward_delay >= 3, max_reward_magnitude >= 4), otherwise indexing
    raises IndexError. Only two positions are returned regardless of num_cues.

    Returns:
        (matrices, positions): matrices indexed [cue, x, y, delay, magnitude],
        positions == [(0, 0), (2, 2)].
    """
    presets = [
        # (cue, x, y, delay index, magnitude index)
        (0, 0, 0, 0, 1),
        (1, 2, 2, 2, 3),
    ]
    matrices = np.zeros((num_cues, map_size, map_size, max_reward_delay, max_reward_magnitude))
    for cue, x, y, delay, mag in presets:
        matrices[cue, x, y, delay, mag] = 1
    positions = [(x, y) for _, x, y, _, _ in presets]
    return matrices, positions



def plot_reward_magnitude_time_matrices(RT_matrices, rewarded_positions, title, saveFolder=saveFolder, particles=None):
    """Plot the (delay x magnitude) reward map at each cue's rewarded cell and save it.

    Args:
        RT_matrices: array indexed [cue, x, y, delay, magnitude].
        rewarded_positions: sequence of (x, y); entry i selects the cell shown for cue i.
        title: file stem; the figure is written to ``saveFolder + title + '.png'``.
        saveFolder: output prefix (string-concatenated, so include the trailing '/').
        particles: optional array indexed [cue, x, y, particle, (delay, magnitude)]
            whose points are overlaid on each panel.
    """
    max_reward_delay = RT_matrices.shape[-2]
    max_reward_magnitude = RT_matrices.shape[-1]
    num_cues = RT_matrices.shape[0]
    # squeeze=False keeps a 2-D axes array even when there is a single cue
    fig, axs = plt.subplots(num_cues, 1, figsize=(5, 10), squeeze=False)
    for i in range(num_cues):
        ax = axs[i, 0]
        x, y = rewarded_positions[i]
        # draw once; the same image handle feeds the colour bar
        im = ax.imshow(RT_matrices[i, x, y], cmap='hot', interpolation='nearest', vmin=0, vmax=1)
        if particles is not None:
            n_particles = particles.shape[-2]
            # particle coordinates are continuous (delay, magnitude); shift by half a
            # cell so integer coordinates land on pixel edges, and draw on top
            ax.scatter(particles[i, x, y, :, 1] - 0.5, particles[i, x, y, :, 0] - 0.5,
                       color="limegreen", s=1, zorder=n_particles + 1)
        ax.set_title('State ' + str(i))
        ax.set_xlabel('Reward Magnitude')
        ax.set_ylabel('Reward Time')
        ax.set_xticks(np.arange(max_reward_magnitude))
        ax.set_yticks(np.arange(max_reward_delay))
        ax.set_xticklabels(np.arange(1, max_reward_magnitude + 1))
        ax.set_yticklabels(np.arange(1, max_reward_delay + 1))
        fig.colorbar(im, ax=ax)
    fig.savefig(saveFolder + title + '.png')
    # release the figure; pyplot otherwise keeps every created figure alive
    plt.close(fig)

def plot_reward_MT_positions(RT_matrices, rewarded_positions, title, saveFolder=saveFolder, particles=None):
    """Plot, for every cue, a heat map over grid cells of where reward can occur.

    Each cell shows ``max over delays of sum over magnitudes`` of
    ``RT_matrices[cue, x, y]``.

    Args:
        RT_matrices: array indexed [cue, x, y, delay, magnitude].
        rewarded_positions: kept for signature parity with the sibling plotting
            helpers; not used by this plot.
        title: file stem; the figure is written to ``saveFolder + title + '.png'``.
        saveFolder: output prefix (string-concatenated, so include the trailing '/').
        particles: kept for signature parity; no particle overlay is drawn.
    """
    num_cues = RT_matrices.shape[0]
    map_size = RT_matrices.shape[1]
    # (cue, x, y): total probability per delay, then the largest over delays
    RT = np.max(np.sum(RT_matrices, axis=-1), axis=-1)
    # squeeze=False keeps a 2-D axes array even when there is a single cue
    fig, axs = plt.subplots(num_cues, 1, figsize=(5, 10), squeeze=False)
    for i in range(num_cues):
        ax = axs[i, 0]
        im = ax.imshow(RT[i], cmap='hot', interpolation='nearest', vmin=0, vmax=1)
        ax.set_title('State ' + str(i))
        ax.set_xticks(np.arange(map_size))
        ax.set_yticks(np.arange(map_size))
        fig.colorbar(im, ax=ax)
    fig.savefig(saveFolder + title + '.png')
    # release the figure; pyplot otherwise keeps every created figure alive
    plt.close(fig)


def plot_weights(weights, title):
    """Save a heat map of a (delay x magnitude) weight matrix.

    The file is written to ``saveFolder + title + '.png'`` using the
    module-level ``saveFolder``.

    Args:
        weights: 2-D array whose rows are delays and columns are magnitudes.
        title: file stem.
    """
    max_reward_delay = weights.shape[-2]
    max_reward_magnitude = weights.shape[-1]
    fig, ax = plt.subplots(1, 1, figsize=(5, 10))
    # draw once; the same image handle feeds the colour bar
    im = ax.imshow(weights, cmap='hot', interpolation='nearest')
    ax.set_xlabel('Reward Magnitude')
    ax.set_ylabel('Reward Time')
    ax.set_xticks(np.arange(max_reward_magnitude))
    ax.set_yticks(np.arange(max_reward_delay))
    fig.colorbar(im, ax=ax)
    fig.savefig(saveFolder + title + '.png')
    # release the figure; pyplot otherwise keeps every created figure alive
    plt.close(fig)

def plot_policy(saveFolder, grid_size, policies, reward_locations, name=None):
    """Draw one arrow-per-cell policy panel per cue and save the figure.

    Args:
        saveFolder: output prefix (string-concatenated with the file name).
        grid_size: side length of the square grid.
        policies: array indexed [cue, x, y] of action ids 0..3; float-valued
            arrays (as produced by ``np.ones``/``np.zeros``) are accepted.
        reward_locations: sequence of (x, y); entry i is highlighted on panel i.
        name: optional prefix. The file is ``<name>_policies_for_each_cue.png``,
            or ``learned_policies_for_each_cue.png`` when None.
    """
    num_cues = policies.shape[0]
    # Arrow offsets in plot coordinates. Rows are drawn flipped
    # (plot_y = grid_size - y - 0.5), so action 0 (y + 1 in
    # gridWorldEnv.get_next_state) points downward on the figure.
    action_arrows = {
        0: (0, -0.3),   # up
        1: (0, 0.3),    # down
        2: (-0.3, 0),   # left
        3: (0.3, 0),    # right
    }

    # squeeze=False keeps a 2-D axes array even when there is a single cue
    fig, axs = plt.subplots(num_cues, 1, figsize=(6, 6), squeeze=False)
    for i in range(num_cues):
        ax = axs[i, 0]
        ax.set_xlim(0, grid_size)
        ax.set_ylim(0, grid_size)
        ax.set_xticks(np.arange(grid_size + 1))
        ax.set_yticks(np.arange(grid_size + 1))
        ax.grid(True)
        ax.set_aspect('equal')

        for x in range(grid_size):
            for y in range(grid_size):
                dx, dy = action_arrows[int(policies[i, x, y])]
                ax.arrow(
                    x + 0.5, grid_size - y - 0.5, dx, dy,
                    head_width=0.2, head_length=0.2, fc='blue', ec='blue'
                )

        # highlight this cue's rewarded cell
        r, c = reward_locations[i]
        ax.add_patch(plt.Rectangle((r, grid_size - c - 1), 1, 1, color='gold', alpha=0.5))

        ax.set_xticklabels([])
        ax.set_yticklabels([])
    if name is None:
        fig.savefig(saveFolder + 'learned_policies_for_each_cue.png')
    else:
        fig.savefig(saveFolder + name + '_policies_for_each_cue.png')
    # release the figure; pyplot otherwise keeps every created figure alive
    plt.close(fig)
    print('REWARD LOCATIONS ',reward_locations, policies)

def plot_policy_two_cues(grid_size, fact_policies, policies_2D, rewarded_locations, title):
    """Draw the factorized and the 2-D policy side by side and save the figure.

    The left panel shows ``fact_policies`` and the right panel shows
    ``policies_2D``. Both panels highlight every cell in ``rewarded_locations``.
    The file goes to the module-level ``saveFolder``, named
    ``<title>_policies_for_each_cue.png``, or
    ``learned_policies_for_each_cue.png`` when ``title`` is None.

    Args:
        grid_size: side length of the square grid.
        fact_policies: array indexed [x, y] of action ids 0..3 (floats accepted).
        policies_2D: array indexed [x, y] of action ids 0..3 (floats accepted).
        rewarded_locations: sequence of (x, y) cells to highlight.
        title: file-name prefix or None.
    """
    # Arrow offsets in plot coordinates; rows are drawn flipped
    # (plot_y = grid_size - y - 0.5).
    action_arrows = {
        0: (0, -0.3),   # up
        1: (0, 0.3),    # down
        2: (-0.3, 0),   # left
        3: (0.3, 0),    # right
    }

    fig, axs = plt.subplots(1, 2, figsize=(6, 6))
    for ax, policy in zip(axs, (fact_policies, policies_2D)):
        ax.set_xlim(0, grid_size)
        ax.set_ylim(0, grid_size)
        ax.set_xticks(np.arange(grid_size + 1))
        ax.set_yticks(np.arange(grid_size + 1))
        ax.grid(True)
        ax.set_aspect('equal')

        # highlight every rewarded cell
        for r, c in rewarded_locations:
            ax.add_patch(plt.Rectangle((r, grid_size - c - 1), 1, 1, color='gold', alpha=0.5))

        ax.set_xticklabels([])
        ax.set_yticklabels([])

        for x in range(grid_size):
            for y in range(grid_size):
                dx, dy = action_arrows[int(policy[x, y])]
                ax.arrow(
                    x + 0.5, grid_size - y - 0.5, dx, dy,
                    head_width=0.2, head_length=0.2, fc='blue', ec='blue'
                )

    if title is None:
        fig.savefig(saveFolder + 'learned_policies_for_each_cue.png')
    else:
        fig.savefig(saveFolder + title + '_policies_for_each_cue.png')
    # release the figure; pyplot otherwise keeps every created figure alive
    plt.close(fig)


def plot_SR(SR, policy, name=None):
    """Plot a grid of successor-representation maps, one panel per start cell.

    Panel (i, j) shows ``SR[i, j]`` summed over its last axis.
    NOTE(review): this assumes SR is a single policy's SR indexed
    [x, y, x_later, y_later, delay] — confirm against callers.
    The figure goes to the module-level ``saveFolder`` as
    ``SR_reps_<policy>.png``, or ``<name>_SR_reps_<policy>.png``.

    Args:
        SR: successor-representation array (see note above).
        policy: label inserted into the file name via ``str``.
        name: optional file-name prefix.
    """
    grid_size = SR.shape[1]
    # squeeze=False keeps a 2-D axes array even when grid_size == 1
    fig, axs = plt.subplots(grid_size, grid_size, figsize=(6, 6), squeeze=False)
    for i in range(grid_size):
        for j in range(grid_size):
            axs[i, j].imshow(np.sum(SR[i, j], axis=-1), cmap='hot', interpolation='nearest')
            axs[i, j].set_xticklabels([])
            axs[i, j].set_yticklabels([])

    if name is None:
        fig.savefig(saveFolder + 'SR_reps_' + str(policy) + '.png')
    else:
        fig.savefig(saveFolder + name + '_SR_reps_' + str(policy) + '.png')
    # release the figure; pyplot otherwise keeps every created figure alive
    plt.close(fig)




# construct no travel time environment
class gridWorldEnv(object):
    """Grid world in which randomly firing cues schedule future rewards at cells.

    The agent moves one cell per step: action 0 does y + 1, 1 does y - 1,
    2 does x - 1 and 3 does x + 1, all clamped to the grid. On every step each
    cue ``c`` fires independently with probability ``cue_probs[c]``.

    A fired cue makes one Bernoulli draw per (delay, magnitude, x, y) entry of
    ``reward_magnitude_time_matrices[c]``, using the entry as its probability.
    A successful draw writes the magnitude index into that cell/delay slot, but
    only if the slot is still 0. Standing on a cell collects whatever sits in
    its delay-0 slot.

    Args:
        map_size: side length of the square grid.
        reward_magnitude_time_matrices: array indexed
            [cue, x, y, delay, magnitude] of probabilities in [0, 1].
        cue_probs: per-cue firing probability; its length sets the number of cues.
        cue_simultaneous: if True, cue 1 fires exactly when cue 0 does
            (requires at least two cues).

    Raises:
        ValueError: if ``cue_simultaneous`` is set with fewer than two cues.
    """

    def __init__(self, map_size, reward_magnitude_time_matrices, cue_probs, cue_simultaneous=False):
        # map size
        self.map_size = map_size
        # reward time/magnitude probability matrices, one per cue
        self.reward_magnitude_time_matrices = reward_magnitude_time_matrices
        # number of magnitude bins (magnitude values are the bin indices 0..max-1)
        self.max_reward_magnitude = reward_magnitude_time_matrices.shape[-1]
        # number of delay bins
        self.max_reward_delay = reward_magnitude_time_matrices.shape[-2]
        # probability of each cue firing on a step
        self.cue_probs = cue_probs
        # number of cues
        self.num_cues = len(cue_probs)
        if cue_simultaneous and self.num_cues < 2:
            raise ValueError('cue_simultaneous requires at least two cues')
        # which state the agent is currently in
        self.current_state = (0,0)
        # scheduled reward magnitude per (x, y, delay); slot 0 is paid this step
        self.cur_reward_time_mag = np.zeros((self.map_size, self.map_size, self.max_reward_delay))
        # steps since each cue, saturating at max_reward_delay
        self.cur_cue_delay = np.ones(self.num_cues) * self.max_reward_delay
        # actions : 0: up, 1: down, 2: left, 3: right
        self.action_space = 4
        # if cues can be simultaneous
        self.cue_simultaneous = cue_simultaneous

    def reset(self):
        """Put the agent at (0, 0), clear the reward schedule and the cue clocks.

        Returns:
            (state, init_cue): (0, 0) and an all-zero cue vector.
        """
        self.current_state = (0,0)
        init_cue = np.zeros(self.num_cues)
        self.cur_reward_time_mag = np.zeros((self.map_size, self.map_size, self.max_reward_delay))
        self.cur_cue_delay = np.ones(self.num_cues) * self.max_reward_delay
        return self.current_state, init_cue

    def step(self, action):
        """Move, collect the reward due now, sample new cues and advance the schedule.

        Returns:
            (reward, next_cue, current_state): ``reward`` is the magnitude index
            scheduled at the new cell for delay 0 (0 if none). ``next_cue`` is a
            0/1 array of the cues that fired this step. ``current_state`` is the
            new (x, y).
        """
        # move first, then read what is due at the new cell this step
        self.current_state = self.get_next_state(self.current_state, action)
        reward = self.cur_reward_time_mag[self.current_state[0], self.current_state[1], 0]

        # each cue fires independently
        next_cue = np.zeros(self.num_cues)
        for indc in range(self.num_cues):
            if np.random.binomial(1, p=self.cue_probs[indc]):
                next_cue[indc] = 1
        if self.cue_simultaneous:
            next_cue[1] = next_cue[0]

        # move time forward: every scheduled reward gets one step closer
        self.cur_reward_time_mag[:, :, :-1] = self.cur_reward_time_mag[:, :, 1:]
        self.cur_reward_time_mag[:, :, -1] = 0
        self.cur_cue_delay += 1
        self.cur_cue_delay[self.cur_cue_delay >= self.max_reward_delay] = self.max_reward_delay
        # schedule rewards for the cues that fired
        for indc in range(self.num_cues):
            if next_cue[indc] == 1:
                # independent draws at each delay; magnitudes are tried in random order
                for time in range(self.max_reward_delay):
                    shuffle_mags = np.arange(self.max_reward_magnitude)
                    np.random.shuffle(shuffle_mags)
                    for mag in shuffle_mags:
                        for x in range(self.map_size):
                            for y in range(self.map_size):
                                if np.random.binomial(1, p=self.reward_magnitude_time_matrices[indc, x, y, time, mag]):
                                    # NOTE(review): nothing visible here lowers cur_cue_delay,
                                    # so the second condition is always true — confirm intent.
                                    if (self.cur_reward_time_mag[x, y, time] == 0) and (self.cur_cue_delay[indc] == self.max_reward_delay):
                                        # only fill an empty slot
                                        self.cur_reward_time_mag[x, y, time] += mag

        return reward, next_cue, self.current_state

    def get_next_state(self, state, action):
        """Return the (x, y) reached by taking ``action`` from ``state`` (clamped to the grid).

        Raises:
            ValueError: if ``action`` is not in 0..action_space-1.
        """
        if action not in range(self.action_space):
            raise ValueError("Invalid action")
        x, y = state
        if action == 0: y = min(y + 1, self.map_size - 1)  # up
        elif action == 1: y = max(y - 1, 0)  # down
        elif action == 2: x = max(x - 1, 0)  # left
        elif action == 3: x = min(x + 1, self.map_size - 1)  # right
        return (x, y)

    def get_reward_TM(self, cue_delay):
        '''
        Return the expected-magnitude map implied by the given cue delays.

        For every cue whose delay is below max_reward_delay, its matrix is
        shifted forward by that delay. The shifted matrices are summed and capped
        at 1, then weighted by magnitude index.

        Returns:
            array indexed [x, y, delay] of expected magnitude.
        '''
        reward_time_mag = np.zeros((self.map_size, self.map_size, self.max_reward_delay, self.max_reward_magnitude))
        for ind, delay in enumerate(cue_delay):
            # delays may arrive as floats (cur_cue_delay is a float array); slices need ints
            delay = int(delay)
            if delay < self.max_reward_delay:
                reward_time_mag[:, :, :(self.max_reward_delay - delay), :] += self.reward_magnitude_time_matrices[ind, :, :, delay:, :]
        # overlapping cues: cap each probability at 1
        reward_time_mag[reward_time_mag > 1] = 1
        return np.sum(reward_time_mag * np.arange(self.max_reward_magnitude)[None, None, None, :], axis=-1)

    def get_avg_return(self, cue_delay, state, action):
        '''
        Return (expected immediate magnitude, next state) for taking ``action`` from ``state``.

        The expected magnitude is read from ``get_reward_TM(cue_delay)`` at the
        next state, delay 0.
        '''
        next_state = self.get_next_state(state, action)
        reward_map = self.get_reward_TM(cue_delay)
        return reward_map[next_state[0], next_state[1], 0], next_state



def softmax(x):
    """Softmax over the last two axes, then shifted so each slice's minimum is 0.

    Note that the result is *not* a probability distribution: after the shift
    the smallest entry of every (-2, -1) slice is 0.

    Args:
        x: array with at least two dimensions.

    Returns:
        Float array with the same shape as ``x``.
    """
    x = np.asarray(x, dtype=float)
    # subtracting the per-slice max leaves the softmax unchanged but prevents exp overflow
    e_x = np.exp(x - np.max(x, axis=(-1, -2), keepdims=True))
    norm = e_x / e_x.sum(axis=(-1, -2), keepdims=True)
    return norm - np.min(norm, axis=(-1, -2), keepdims=True)

def min_norm(x):
    """Divide ``x`` by the sum of its last two axes, with 1e-5 added per element.

    The epsilon is added to every element before summing, so the denominator is
    ``sum(x) + n * 1e-5`` where ``n`` is the slice size. This avoids division
    by zero for all-zero slices. The input is not modified.
    """
    denominator = (x + 1E-5).sum(axis=(-1, -2), keepdims=True)
    return x / denominator


# def particle_update(particles, reward, indc, cue_delay, next_state, iter, alpha, max_r_delay, max_r_magnitude, particle_reward_MT, cue_weights, n_particles=n_particles, cov=cov, cov_shrink=cov_shrink, batch_size=batch_size, dx_der=dx_der, dy_der=dy_der, bw=bw):
def particle_update(particles, reward, indc, cue_delay, next_state, iter, alpha, max_r_delay, max_r_magnitude, particle_reward_MT, cue_weights, config):
    '''
    Move the particles of one (cue, cell) towards an observed (delay, reward) sample.

    A likelihood over the (delay, magnitude) plane is built by sampling
    ``config.batch_size`` Gaussian points around (cue_delay[indc] + 0.5,
    reward + 0.5) and fitting a KDE. The KDE is evaluated on the
    ``config.particles_der`` grid and turned into a log-likelihood.

    Each particle then takes one gradient step. The step combines the negative
    log-likelihood gradient at the particle's grid bin (F1) and a pairwise
    particle-interaction term (F2). Both are scaled by cue_weights[indc] and
    the learning rate ``config.learning_rates[iter[indc]]``.

    Finally, the particle histogram (from ``matrix_particles``) is blended into
    ``particle_reward_MT`` for that cell.

    Args:
        particles: array indexed [cue, x, y, particle, (delay, magnitude)];
            updated in place for [indc, next_state].
        reward: observed reward magnitude (used as the sample mean's 2nd coordinate).
        indc: cue index being updated.
        cue_delay: per-cue delays; cue_delay[indc] is the sample mean's 1st coordinate.
        next_state: (x, y) cell whose particles are updated.
        iter: per-cue update counters (indexes cov_shrink / learning_rates);
            iter[indc] is incremented in place.
        alpha: blend factor for the new histogram into particle_reward_MT.
        max_r_delay, max_r_magnitude: histogram size passed to matrix_particles.
        particle_reward_MT: array indexed [cue, x, y, delay, magnitude]; updated in place.
        cue_weights: per-cue scale applied to the gradient and to the blended map.
        config: object providing n_particles, cov, cov_shrink, batch_size, dx_der,
            dy_der, bw, x_der, y_der, particles_der, Nx_der, Ny_der, lamb,
            learning_rates and particle_gamma.

    Returns:
        (particles, iter, particle_reward_MT) — the same objects passed in.
    '''
    # get parameters
    n_particles = config.n_particles
    cov = config.cov
    cov_shrink = config.cov_shrink
    batch_size = config.batch_size
    dx_der = config.dx_der
    dy_der = config.dy_der
    bw = config.bw
    x_der = config.x_der
    y_der = config.y_der
    particles_der = config.particles_der
    Nx_der = config.Nx_der 
    Ny_der = config.Ny_der
    lamb = config.lamb
    learning_rates = config.learning_rates
    particle_gamma = config.particle_gamma
    # Gaussian sample cloud around the observation; +0.5 puts it at the bin centre.
    # The covariance shrinks with the per-cue update count.
    mean = (cue_delay[indc] + 0.5, reward + 0.5)
    sample_rew=np.random.multivariate_normal(mean, cov * cov_shrink[int(iter[indc])], batch_size)

    # KDE of the sample cloud evaluated on the derivative grid -> (Nx_der, Ny_der)
    RBF = gaussian_kde(sample_rew.T, bw_method=bw)
    likelihood = np.reshape(RBF(particles_der.T), (Nx_der, Ny_der))
    # floor tiny values so the log below stays finite, then normalise and take the log
    likelihood = likelihood_limits(likelihood)
    likelihood= np.log(likelihood/np.sum(likelihood))

    # Gradient of F1: negative log-likelihood gradient looked up at each particle's grid bin
    gradient_likelihood = np.gradient(likelihood, dx_der, dy_der)
    # NOTE(review): particles outside the x_der / y_der range digitize to -1 (wraps to the
    # last bin) or past the end (IndexError); particle_limits is not called here.
    bins_x = np.digitize(particles[indc, next_state[0], next_state[1], :, 0], x_der[:, 0]) - 1
    bins_y = np.digitize(particles[indc, next_state[0], next_state[1], :, 1], y_der[0, :]) - 1
    gradient_f1 =np.zeros((n_particles,2))
    gradient_f2 =np.zeros((n_particles,2))  # overwritten below; kept only as an initialiser
    gradient_f1[:, 0] = -np.ndarray.flatten(gradient_likelihood[0][bins_x, bins_y])
    gradient_f1[:, 1] = -np.ndarray.flatten(gradient_likelihood[1][bins_x, bins_y])

    # Gradient of F2: pairwise interaction between particles of this cell.
    # dif_matrix[a, b] = p[b] - p[a]; the sum over a gives, for each particle b,
    # sum_a gamma * (d/lamb - 1) * exp(-d/lamb) * (p[b] - p[a]), averaged over particles.
    particles_matrix = np.tile(particles[indc, next_state[0], next_state[1]], (n_particles, 1, 1))
    dif_matrix = np.subtract(particles_matrix[:, :, :], particles[indc, next_state[0], next_state[1], :, np.newaxis, :])
    distance_matrix = np.sqrt(np.sum(dif_matrix**2,axis=-1,keepdims=True))
    gradient_f2 = np.sum(particle_gamma * (distance_matrix / lamb - 1) * np.exp(-distance_matrix / lamb) * dif_matrix, axis=0)
    gradient_f2=gradient_f2/n_particles

    # Total gradient, scaled by the cue weight
    gradient = (gradient_f1 + gradient_f2) * cue_weights[indc]
    learn_r=learning_rates[int(iter[indc])] # learning rate

    # Gradient-descent step on this cell's particles (written back into `particles`)
    particles[indc, next_state[0], next_state[1]] = particles[indc, next_state[0], next_state[1]] - learn_r * gradient

    # Blend the new particle histogram into the running reward map for this cell
    next_reward_mat = matrix_particles(particles[indc, next_state[0], next_state[1]], max_r_delay, max_r_magnitude)    
    particle_reward_MT[indc, next_state[0], next_state[1]] = cue_weights[indc] * (alpha * next_reward_mat + (1 - alpha) * particle_reward_MT[indc, next_state[0], next_state[1]])
    iter[indc] += 1

    return particles, iter, particle_reward_MT

def particle_limits(particles, max_delay, max_mag, margin=2.0):
    """Clamp particle coordinates to the delay/magnitude box padded by ``margin``.

    The last axis of ``particles`` holds (delay, magnitude) coordinates, and any
    number of leading axes is accepted. Delays are clamped to
    [-margin, max_delay + margin] and magnitudes to [-margin, max_mag + margin].
    NaNs are left unchanged.

    Args:
        particles: array whose last axis has length 2; modified in place.
        max_delay: upper delay bound before padding.
        max_mag: upper magnitude bound before padding.
        margin: padding added on both sides (default 2.0).

    Returns:
        The same ``particles`` array.
    """
    # `...` addresses the coordinate axis whatever the rank (the [:, :, k]
    # form only worked for 3-D arrays)
    particles[..., 0] = np.clip(particles[..., 0], -margin, max_delay + margin)
    particles[..., 1] = np.clip(particles[..., 1], -margin, max_mag + margin)
    return particles

def likelihood_limits(likelihood, floor=1E-50):
    """Raise every value below ``floor`` up to ``floor`` so a later log stays finite.

    The removed leading ``np.log(...)`` statement had no effect other than
    emitting divide-by-zero warnings, because its result was discarded.

    Args:
        likelihood: float array; modified in place. NaNs are left unchanged.
        floor: minimum allowed value (default 1e-50).

    Returns:
        The same ``likelihood`` array.
    """
    likelihood[likelihood < floor] = floor
    return likelihood


def matrix_particles(particles, max_reward_delay, max_reward_magnitude):
    """Histogram continuous (delay, magnitude) particles onto the integer grid.

    Each particle is assigned to the cell ``floor(coordinate)``. Particles
    outside [0, max_reward_delay) x [0, max_reward_magnitude) are dropped.
    The counts are divided by ``sum(counts + 1e-7)``, so the epsilon is added
    per cell and the result sums to slightly less than 1.

    Args:
        particles: array of shape (n_particles, 2) holding (delay, magnitude).
        max_reward_delay: number of delay cells.
        max_reward_magnitude: number of magnitude cells.

    Returns:
        Float array of shape (max_reward_delay, max_reward_magnitude).
    """
    counts = np.zeros((max_reward_delay, max_reward_magnitude), dtype=int)
    grid_coords = np.floor(particles).astype(int)
    gx = grid_coords[:, 0]
    gy = grid_coords[:, 1]
    inside = (gx >= 0) & (gx < max_reward_delay) & (gy >= 0) & (gy < max_reward_magnitude)
    # unbuffered add so repeated cells accumulate (plain fancy-index += would not)
    np.add.at(counts, (gx[inside], gy[inside]), 1)
    return counts/np.sum(counts + 1E-7)

def set_init_particles(max_reward_delay, max_reward_magnitude, num_cues, n_particles, map_size):
    """Create a regular particle grid and copy it to every (cue, x, y).

    The grid has ``sqrt(n_particles)`` points per side. Delays span
    [-0.5, max_reward_delay + 0.5] and magnitudes span
    [0.5, max_reward_magnitude + 0.5].

    Args:
        max_reward_delay: delay extent of the grid.
        max_reward_magnitude: magnitude extent of the grid.
        num_cues: number of cues.
        n_particles: total particles per cell; must be a perfect square.
        map_size: side length of the map.

    Returns:
        Array of shape (num_cues, map_size, map_size, n_particles, 2) holding
        (delay, magnitude) coordinates.

    Raises:
        ValueError: if ``n_particles`` is not a perfect square (or is negative).
    """
    import math

    side = math.isqrt(n_particles)
    if side * side != n_particles:
        raise ValueError('n_particles must be a perfect square to form a regular grid, got %r' % (n_particles,))
    x, y = np.meshgrid(
        np.linspace(-0.5, max_reward_delay + 0.5, side),
        np.linspace(0.5, max_reward_magnitude + 0.5, side),
    )
    particles_init = np.column_stack((x.ravel(), y.ravel()))
    particles = np.zeros((num_cues, map_size, map_size, n_particles, 2))
    # broadcast the same starting grid into every (cue, x, y)
    particles[...] = particles_init
    return particles


def update_sr(prev_state, next_state, sr, num_cues, alpha, discount_gamma, importance_weights):
    """TD-update the time-resolved successor representation of every policy.

    For each policy ``pi``, the target for row ``prev_state`` is a one-hot at
    (next_state, delay 0) plus ``discount_gamma`` times next_state's row shifted
    one delay later. The row moves towards that target by
    ``alpha * importance_weights[pi]``. All ``num_cues`` policies are updated.

    Args:
        prev_state, next_state: (x, y) cells of the observed transition.
        sr: array indexed [policy, x, y, x_later, y_later, delay]; updated in place.
        num_cues: number of policies to update.
        alpha: learning rate.
        discount_gamma: per-step discount.
        importance_weights: per-policy scale on the update.

    Returns:
        The same ``sr`` array.
    """
    px, py = prev_state
    nx, ny = next_state
    map_size = sr.shape[1]
    horizon = sr.shape[-1]
    for pi in range(num_cues):
        target = np.zeros((map_size, map_size, horizon))
        target[nx, ny, 0] = 1
        target[:, :, 1:] += discount_gamma * sr[pi, nx, ny, :, :, :-1]
        td_error = target - sr[pi, px, py]
        sr[pi, px, py] += alpha * importance_weights[pi] * td_error
    return sr

def TRUE_SR(reward_MT, discount_gamma, env, risk_function=None):
    '''
    Build a hand-constructed successor representation for each cue's policy.

    For each cue, the target is the cell with the largest expected magnitude.
    The expectation is weighted by ``risk_function(reward_MT[cue])`` when
    given, else by the magnitude index; ties go to the first cell in row-major
    order. From every start (x, y), each later cell that is no further from the
    target than the start in both x and y gets ``discount_gamma ** time`` at
    delay ``time``.

    NOTE(review): ``time`` is the later cell's Manhattan distance to the target,
    not its distance from the start cell — confirm this is the intended indexing.

    Entries whose ``time`` exceeds the SR horizon (max_reward_delay - 1) are
    skipped, since they cannot be stored.

    Args:
        reward_MT: array indexed [cue, x, y, delay, magnitude].
        discount_gamma: per-step discount.
        env: unused; kept for interface compatibility.
        risk_function: optional callable mapping one cue's reward matrix to
            weights broadcastable against it.

    Returns:
        Array indexed [cue, x, y, x_later, y_later, delay].
    '''
    num_cues = reward_MT.shape[0]
    map_size = reward_MT.shape[1]
    max_reward_delay = reward_MT.shape[-2]
    sr = np.zeros((num_cues, map_size, map_size, map_size, map_size, max_reward_delay))
    for indc in range(num_cues):
        if risk_function is not None:
            weights = risk_function(reward_MT[indc])
            target = np.sum(reward_MT[indc] * weights, axis=(-1, -2))
        else:
            target = np.sum(reward_MT[indc] * np.arange(reward_MT.shape[-1])[None, None, None, :], axis=(-1, -2))
        x_targ, y_targ = np.unravel_index(np.argmax(target), target.shape)
        for x in range(map_size):
            for y in range(map_size):
                for x_later in range(map_size):
                    for y_later in range(map_size):
                        if (np.abs(x_later - x_targ) <= np.abs(x - x_targ)) and (np.abs(y_later - y_targ) <= np.abs(y - y_targ)):
                            time = np.abs(x_later - x_targ) + np.abs(y_later - y_targ)
                            # beyond the stored horizon: previously an IndexError
                            if time >= max_reward_delay:
                                continue
                            sr[indc, x, y, x_later, y_later, time] = discount_gamma ** time
    return sr

def increment_cue_delay(cue_delay, MT_map, reward_MT, next_cue, num_cues):
    """Advance the cue clocks and the expected reward map by one step.

    Steps, in order:

    * Every cue delay is incremented and saturated at the horizon
      (``MT_map.shape[-2]``).
    * ``MT_map`` is shifted one delay earlier and its last delay slot is zeroed.
    * Every cue with ``next_cue[c] == 1`` has its delay reset to 0 and its full
      reward matrix added to ``MT_map``.

    Both arrays are modified in place and returned.

    Returns:
        (cue_delay, MT_map)
    """
    horizon = MT_map.shape[-2]
    cue_delay += 1
    cue_delay[cue_delay >= horizon] = horizon
    # shift the schedule one step closer; the newest slot becomes empty
    MT_map[:, :, :-1, :] = MT_map[:, :, 1:, :]
    MT_map[:, :, -1, :] = 0
    for cue_idx in range(num_cues):
        if next_cue[cue_idx] != 1:
            continue
        cue_delay[cue_idx] = 0
        MT_map += reward_MT[cue_idx]
    return cue_delay, MT_map

def gpi_action(state, reward_map, sr, num_cues, env, weights=None):
    """Pick the (policy, action) pair with the highest one-step-lookahead value.

    For each action, Q is computed as follows:

    * Immediate value: the weighted magnitudes of ``reward_map`` at the next
      state, delay 0.
    * Future value, per policy: the sum of
      ``sr[pi, next_state, :, :, :-1] * future_value``. Here ``future_value`` is
      the per-cell, per-delay weighted magnitude of ``reward_map[:, :, 1:]``.

    Magnitudes are weighted by the magnitude index, or by ``weights`` when
    given. The full-map ``future_value`` does not depend on state, action or
    policy, so it is computed once.

    Args:
        state: current (x, y).
        reward_map: array indexed [x, y, delay, magnitude].
        sr: array indexed [policy, x, y, x_later, y_later, delay].
        num_cues: number of policies.
        env: object providing ``action_space`` and ``get_next_state``.
        weights: optional array indexed like ``reward_map`` used instead of
            the magnitude index.

    Returns:
        (max_act, max_pi, Q_values), where Q_values has shape (num_cues, action_space).
    """
    max_reward_magnitude = reward_map.shape[-1]
    if weights is None:
        magnitudes = np.arange(max_reward_magnitude)
        future_value = np.sum(reward_map[:, :, 1:] * magnitudes[None, None, None, :], axis=-1)
    else:
        future_value = np.sum(reward_map[:, :, 1:] * weights[:, :, 1:], axis=-1)

    Q_values = np.zeros((num_cues, env.action_space))
    for action in range(env.action_space):
        nx, ny = env.get_next_state(state, action)
        if weights is None:
            next_rew = np.sum(reward_map[nx, ny, 0] * magnitudes)
        else:
            next_rew = np.sum(reward_map[nx, ny, 0] * weights[nx, ny, 0])
        for ind_pi in range(num_cues):
            Q_values[ind_pi, action] = next_rew + np.sum(sr[ind_pi, nx, ny, :, :, :-1] * future_value)
    # argmax over the flattened (policy, action) table
    max_pi, max_act = np.unravel_index(np.argmax(Q_values, axis=None), Q_values.shape)
    return max_act, max_pi, Q_values

def epsilon_greedy(action, epsilon, env):
    """With probability ``epsilon`` return a uniformly random action, else ``action``.

    Exactly one ``np.random.random()`` draw is always made; a ``randint`` draw
    follows only when exploring.
    """
    explore = np.random.random() < epsilon
    return np.random.randint(0, env.action_space) if explore else action

def reward_MT_map(reward_MT, cue):
    """Sum the reward matrices of all cues flagged with 1 in ``cue``.

    Args:
        reward_MT: array indexed [cue, ...].
        cue: sequence of per-cue flags; entries equal to 1 are included.

    Returns:
        Float array of shape ``reward_MT.shape[1:]`` (all zeros if no cue is active).
    """
    active = (reward_MT[i] for i in range(reward_MT.shape[0]) if cue[i] == 1)
    return sum(active, np.zeros(reward_MT.shape[1:]))

def _greedy_policy_grid(MT_map, sr, num_cues, env, map_size):
    '''Return a (map_size, map_size) float array of gpi_action's greedy action for every cell.'''
    actions = np.zeros((map_size, map_size))
    for x in range(map_size):
        for y in range(map_size):
            max_act, _, _ = gpi_action((x, y), MT_map, sr, num_cues, env)
            actions[x, y] = max_act
    return actions


def get_policies(reward_MT, sr, env, pair_cues=None, weights=None):
    '''
    Return the GPI-greedy action for every cell under one or more cue scenarios.

    If ``pair_cues`` is given, all listed cues are treated as active together
    and a single (map_size, map_size) policy is returned. Otherwise one policy
    per cue (that cue alone active) is returned as a
    (num_cues, map_size, map_size) array. Actions are stored as floats.

    Args:
        reward_MT: array indexed [cue, x, y, delay, magnitude].
        sr: successor representation passed to gpi_action.
        env: environment passed to gpi_action.
        pair_cues: optional iterable of cue indices to activate jointly.
        weights: optional (delay, magnitude) array multiplied into the reward
            map before greedy selection.
    '''
    num_cues = reward_MT.shape[0]
    map_size = reward_MT.shape[1]

    def cue_map(active_cues):
        # reward map for the given active cues, optionally re-weighted
        cue = np.zeros(num_cues)
        for indc in active_cues:
            cue[indc] = 1
        MT_map = reward_MT_map(reward_MT, cue)
        if weights is not None:
            MT_map = MT_map * weights[None, None, :, :]
        return MT_map

    if pair_cues is not None:
        return _greedy_policy_grid(cue_map(pair_cues), sr, num_cues, env, map_size)

    policies = np.ones((num_cues, map_size, map_size))
    for indc in range(num_cues):
        policies[indc] = _greedy_policy_grid(cue_map([indc]), sr, num_cues, env, map_size)
    return policies

def get_separate_policies(reward_MT, sr, env):
    '''
    Return each policy's own greedy action for every cell, under each single-cue scenario.

    For cue ``c`` active alone and policy ``pi``, the result entry
    ``policies[c, pi, x, y]`` is the argmax over actions of gpi_action's
    ``Q_values[pi]`` at (x, y). This is the per-policy choice, not the GPI max
    over policies.

    Args:
        reward_MT: array indexed [cue, x, y, delay, magnitude].
        sr: successor representation passed to gpi_action.
        env: environment passed to gpi_action.

    Returns:
        Float array of shape (num_cues, num_cues, map_size, map_size).
    '''
    num_cues = reward_MT.shape[0]
    map_size = reward_MT.shape[1]
    policies = np.zeros((num_cues, num_cues, map_size, map_size))
    for indc in range(num_cues):
        cue = np.zeros(num_cues)
        cue[indc] = 1
        # previously called reward_MT_map(cue), which raised TypeError
        MT_map = reward_MT_map(reward_MT, cue)
        for x in range(map_size):
            for y in range(map_size):
                _, _, Q_vals = gpi_action((x, y), MT_map, sr, num_cues, env)
                # fill the policy axis explicitly instead of broadcasting one
                # action over the last axis of a 4-D array
                policies[indc, :, x, y] = np.argmax(Q_vals, axis=1)
    return policies

def Q_value_each_policy(state, reward_MT, cue_delay, sr, env, weights=None):
    '''
    One-step-lookahead Q-values of every action for each policy, using only that policy's cue.

    For policy ``pi`` the reward map is built in two steps:

    * ``reward_MT[pi]`` shifted earlier by ``cue_delay[pi]`` delays;
    * plus the unshifted ``reward_MT[pi]`` (see the note at that line).

    Q is then the immediate weighted magnitude at the next state (delay 0) plus
    ``sum(sr[pi, next_state, :, :, :-1] * future weighted magnitudes)``.
    Magnitudes are weighted by their index, or by ``weights`` when given.

    Args:
        state: current (x, y).
        reward_MT: array indexed [cue, x, y, delay, magnitude].
        cue_delay: per-cue delays; cast with int() before slicing.
        sr: array indexed [policy, x, y, x_later, y_later, delay].
        env: object providing ``action_space`` and ``get_next_state``.
        weights: optional array indexed like one cue's reward map.

    Returns:
        Array of shape (num_cues, action_space).
    '''
    # for each policy, get value of next state after every action
    # (local name shadows the module-level reward_MT_map function inside this body)
    reward_MT_map = np.zeros(reward_MT.shape)
    max_reward_delay = reward_MT.shape[-2]
    max_reward_magnitude = reward_MT.shape[-1]
    num_cues = reward_MT.shape[0]
    for ind_pi in range(num_cues):
        # shift the cue's schedule by its elapsed delay; a delay of max_reward_delay
        # gives empty slices on both sides, so nothing is copied
        reward_MT_map[ind_pi, :, :, :int(max_reward_delay - cue_delay[ind_pi])] = reward_MT[ind_pi, :, :, int(cue_delay[ind_pi]):]
        # NOTE(review): the unshifted matrix is added on top of the shifted one, so every
        # policy's own reward is included even when its cue has not fired — confirm intent.
        reward_MT_map[ind_pi, :, :, :] = reward_MT_map[ind_pi, :, :, :] + reward_MT[ind_pi, :, :, :]

    Q_values = np.zeros((num_cues, env.action_space))
    for ind_pi in range(num_cues):
        for action in range(env.action_space):
            next_state = env.get_next_state(state, action)
            if weights is not None:
                next_rew = np.sum(reward_MT_map[ind_pi, next_state[0], next_state[1], 0] * weights[next_state[0], next_state[1], 0])
                Q_values[ind_pi, action] = next_rew + np.sum(sr[ind_pi, next_state[0], next_state[1], :, :, :-1] * np.sum(reward_MT_map[ind_pi, :, :, 1:] * weights[:, :, 1:], axis=-1))
            else:
                next_rew = np.sum(reward_MT_map[ind_pi, next_state[0], next_state[1], 0] * np.arange(max_reward_magnitude))
                Q_values[ind_pi, action] = next_rew + np.sum(sr[ind_pi, next_state[0], next_state[1], :, :, :-1] * np.sum(reward_MT_map[ind_pi, :, :, 1:] * np.arange(max_reward_magnitude)[None, None, None, :], axis=-1))
    return Q_values
    

def homeostatic_weights(reward_map_shape, internal_state, internal_state_update, homeostatic_loss):
    '''
    Weight each (delay, magnitude) outcome by the homeostatic loss it would produce alone.

    A baseline internal-state trajectory with no reward is rolled forward with
    ``internal_state_update``. Then, for every delay ``t``:

    * the magnitude index (0..n-1, one column per magnitude) is added to the
      state at ``t``;
    * the state is rolled forward to the last delay;
    * ``homeostatic_loss`` is applied and its output is summed over the delay axis.

    The result is written to ``weights[:, :, t, :]`` and is identical for every
    cell. NOTE(review): magnitudes are added as 0-based indices — confirm whether
    magnitude values (index + 1) were intended.

    Args:
        reward_map_shape: shape of the returned array; the last two entries are
            (num delays, num magnitudes), and it is expected to be 4-D
            (x, y, delay, magnitude).
        internal_state: initial internal state (broadcast across magnitudes).
        internal_state_update: callable advancing a per-magnitude state row one step.
        homeostatic_loss: callable mapping a (delay, magnitude) trajectory to losses.

    Returns:
        Array of shape ``reward_map_shape``.
    '''
    n_delays = reward_map_shape[-2]
    weights = np.zeros(reward_map_shape)
    magnitudes = np.arange(reward_map_shape[-1])
    baseline = np.zeros(reward_map_shape[-2:])
    baseline[0, :] = internal_state
    for t in range(n_delays):
        if t > 0:
            baseline[t, :] = internal_state_update(baseline[t - 1, :])
        else:
            baseline[0, :] = internal_state
        rollout = baseline.copy()
        rollout[t, :] = rollout[t, :] + magnitudes
        for later in range(t + 1, n_delays):
            rollout[later, :] = internal_state_update(rollout[later - 1, :])
        weights[:, :, t, :] = np.sum(homeostatic_loss(rollout), axis=0)
    return weights
        





class homeostatic_agent(object):
    '''
    Agent that learns a per-cue successor representation (SR) and a reward
    magnitude/time map (reward_MT). At test time it weights the observed map
    by the homeostatic loss of its internal state (see `homeostatic_weights`).
    '''
    def __init__(self, env, map_size, init_state, reward_magnitude_time_matrices, internal_state_update, homeostatic_loss, alpha=0.1, discount_gamma=0.9, epsilon=0.1, internal_state_init=1, config=None):
        '''
        Parameters
        ----------
        env : environment; `env.action_space` is used as the number of actions.
        map_size : side length of the square grid.
        init_state : initial (x, y) state.
        reward_magnitude_time_matrices : array indexed (cue, x, y, delay, magnitude).
            Stored by reference; `train` updates it in place.
        internal_state_update : callable advancing the internal state one step.
        homeostatic_loss : callable scoring internal states (see `homeostatic_weights`).
        alpha : learning rate for the SR and reward-map updates.
        discount_gamma : SR discount factor.
        epsilon : exploration rate used in the importance weights.
        internal_state_init : initial internal state.
        config : experiment configuration; stored but not read by this class.
        '''
        # grid side length
        self.map_size = map_size
        # number of cues
        self.num_cues = reward_magnitude_time_matrices.shape[0]
        # initial state
        self.current_state = init_state
        # reward magnitude and time matrices (shared reference, updated in place)
        self.reward_MT = reward_magnitude_time_matrices
        # max reward delay
        self.max_reward_delay = reward_magnitude_time_matrices.shape[-2]
        # max reward magnitude
        self.max_reward_magnitude = reward_magnitude_time_matrices.shape[-1]
        # reward magnitude values
        self.magnitude_values = np.arange(self.max_reward_magnitude) + 1
        # current reward map from observed cues
        self.current_MT_map = np.zeros(self.reward_MT.shape[1:])
        # delay since each cue; a cue counts as inactive when its delay equals
        # max_reward_delay (see _update_reward_MT)
        # NOTE(review): initialised to num_cues rather than max_reward_delay - confirm intent
        self.cur_cue_delay = np.ones(self.num_cues, dtype=int) * self.num_cues
        # learning rate
        self.alpha = alpha
        # discount factor
        self.gamma = discount_gamma
        # successor representation, one per cue/policy
        self.sr = np.abs(np.random.normal(0,0.1,size=(self.num_cues, self.map_size, self.map_size, self.map_size, self.map_size, self.max_reward_delay)))
        # environment
        self.env = env
        # number of actions
        self.num_actions = self.env.action_space
        # epsilon for exploration
        self.epsilon = epsilon
        # update function for internal state
        self.internal_state_update = internal_state_update
        # initial internal state
        self.init_internal = internal_state_init
        # current internal state (test time)
        self.cur_internal = internal_state_init
        # current internal state during training
        self.train_internal = internal_state_init
        # homeostatic loss function on internal state
        self.homeostatic_loss = homeostatic_loss
        self.config = config

    def test_act(self, state, cue, prev_reward, weights=None, epoch=0):
        '''
        Advance the test-time cue delays, reward map and internal state, then
        return the action chosen by `gpi_action` on the homeostatically
        weighted map.

        The `weights` argument is ignored: weights are always recomputed from
        the current internal state. It is kept for interface compatibility
        with the other agents.
        '''
        self.test_state = state
        # update time state and MT map
        self.test_cue_delay, self.test_MT_map = increment_cue_delay(self.test_cue_delay, self.test_MT_map, self.reward_MT, cue, self.num_cues)
        # advance internal state and add the reward just received
        self.cur_internal = self.internal_state_update(self.cur_internal) + prev_reward
        weights = homeostatic_weights(self.test_MT_map.shape, self.cur_internal, self.internal_state_update, self.homeostatic_loss)

        action, max_pi, Q_values = gpi_action(state, self.test_MT_map, self.sr, self.num_cues, self.env, weights=weights)
        # unweighted action, only used in the debug printout below
        action_normal, max_pi_normal, Q_values_normal = gpi_action(state, self.test_MT_map, self.sr, self.num_cues, self.env)
        # the debug printout inspects cells (0, 4) and (4, 0); skip it on
        # grids smaller than 5x5 instead of raising IndexError
        if cue[0] > 0 and min(self.test_MT_map.shape[:2]) > 4:
            print('cur internal state ', self.cur_internal, prev_reward)
            print('weights ',weights[0,4])
            print('map 0,4 early', self.test_MT_map[0,4], np.sum(self.test_MT_map[0,4] * weights[0,4], axis=-1))
            print('map 4,0 late', self.test_MT_map[4,0], np.sum(self.test_MT_map[4,0] * weights[4,0], axis=-1))
            print('actions ', action, action_normal, max_pi, max_pi_normal, state)
        if prev_reward > 0:
            print('prev reward ', prev_reward, ' state ', state)
        return action

    def reset_test(self, state):
        '''Reset the test-time state, reward map, cue delays and internal state.'''
        self.test_state = state
        self.test_MT_map = np.zeros(self.reward_MT.shape[1:])
        # NOTE(review): float zeros here, whereas cur_cue_delay is an int array initialised to num_cues - confirm
        self.test_cue_delay = np.zeros(self.num_cues)
        self.cur_internal = self.init_internal

    def _update_reward_MT(self, reward, next_state):
        '''
        Move reward_MT at `next_state` towards the observed (delay, magnitude)
        of `reward`, for every cue whose delay differs from max_reward_delay.

        If no cue qualifies, the normaliser is zero. The update is then skipped,
        since dividing would write NaN into reward_MT.
        '''
        x, y = next_state[0], next_state[1]
        reward_map = np.zeros(self.reward_MT.shape)
        for indc in range(self.num_cues):
            if self.cur_cue_delay[indc] != self.max_reward_delay:
                reward_map[indc, x, y, int(self.cur_cue_delay[indc]), int(reward)] = 1
        weighted = (self.reward_MT + 1E-6) * reward_map
        total = np.sum(weighted)
        if total == 0:
            return
        reward_map = weighted / total
        self.reward_MT[:, x, y] = (1 - self.alpha) * self.reward_MT[:, x, y] + self.alpha * reward_map[:, x, y]

    def train(self, prev_state, action, reward, next_state, next_cue):
        '''
        Perform one learning step from the transition prev_state -> next_state.

        Updates the SR (when prev_state is not None) and reward_MT (when
        reward > 0), advances the cue delays and the training internal state,
        and returns a uniformly random next action.
        '''
        if prev_state is not None:
            Q_values = Q_value_each_policy(prev_state, self.reward_MT, self.cur_cue_delay, self.sr, self.env)
            # greedy action of each cue's policy
            max_act = np.argmax(Q_values, axis=-1)
            random_prob_action = self.epsilon / self.num_actions
            # NOTE(review): the greedy weight 1 - eps/(n_actions - 1) differs from the
            # usual eps-greedy probability 1 - eps + eps/n_actions, and it divides by
            # zero when n_actions == 1 - confirm intent
            importance_weights = np.ones(self.num_cues) * random_prob_action
            importance_weights[max_act == action] = 1 - self.epsilon / (self.num_actions - 1)
            self.sr = update_sr(prev_state, next_state, self.sr, self.num_cues, self.alpha, self.gamma, importance_weights)

        if reward > 0:
            self._update_reward_MT(reward, next_state)

        # update time state
        self.cur_cue_delay, self.current_MT_map = increment_cue_delay(self.cur_cue_delay, self.current_MT_map, self.reward_MT, next_cue, self.num_cues)
        action = np.random.randint(0, self.env.action_space)
        # increment internal state
        self.train_internal = self.internal_state_update(self.train_internal) + reward
        return action

    def get_policies(self, pair_cues=None, weights=None):
        '''
        output max action for every state for every policy
        '''
        return get_policies(self.reward_MT, self.sr, self.env, pair_cues, weights)







class rewardMTAgent(object):
    '''
    Agent that learns a per-cue successor representation (SR) and a reward
    magnitude/time map (reward_MT). At test time it acts with `gpi_action`
    on the map built from the observed cues.
    '''
    def __init__(self, env, map_size, init_state, reward_magnitude_time_matrices, alpha=0.1, discount_gamma=0.9, epsilon=0.1, risk_weight=None, config=None):
        '''
        Parameters
        ----------
        env : environment; `env.action_space` is used as the number of actions.
        map_size : side length of the square grid.
        init_state : initial (x, y) state.
        reward_magnitude_time_matrices : array indexed (cue, x, y, delay, magnitude).
            Stored by reference; `train` updates it in place.
        alpha : learning rate for the SR and reward-map updates.
        discount_gamma : SR discount factor.
        epsilon : exploration rate used in the importance weights.
        risk_weight : accepted for interface compatibility; not used here.
        config : experiment configuration; stored but not read by this class.
        '''
        # grid side length
        self.map_size = map_size
        # number of cues
        self.num_cues = reward_magnitude_time_matrices.shape[0]
        # initial state
        self.current_state = init_state
        # reward magnitude and time matrices (shared reference, updated in place)
        self.reward_MT = reward_magnitude_time_matrices
        # max reward delay
        self.max_reward_delay = reward_magnitude_time_matrices.shape[-2]
        # max reward magnitude
        self.max_reward_magnitude = reward_magnitude_time_matrices.shape[-1]
        # reward magnitude values
        self.magnitude_values = np.arange(self.max_reward_magnitude) + 1
        # current reward map from observed cues
        self.current_MT_map = np.zeros(self.reward_MT.shape[1:])
        # delay since each cue; a cue counts as inactive when its delay equals
        # max_reward_delay (see _update_reward_MT)
        # NOTE(review): initialised to num_cues rather than max_reward_delay - confirm intent
        self.cur_cue_delay = np.ones(self.num_cues, dtype=int) * self.num_cues
        # learning rate
        self.alpha = alpha
        # discount factor
        self.gamma = discount_gamma
        # successor representation, one per cue/policy
        self.sr = np.abs(np.random.normal(0,0.1,size=(self.num_cues, self.map_size, self.map_size, self.map_size, self.map_size, self.max_reward_delay)))
        # environment
        self.env = env
        # number of actions
        self.num_actions = self.env.action_space
        # epsilon for exploration
        self.epsilon = epsilon
        self.config = config

    def test_act(self, state, cue, prev_reward=None, weights=None, epoch=0):
        '''
        Advance the test-time cue delays and reward map, then return the action
        chosen by `gpi_action`.

        weights : optional (delay, magnitude) array. When given, it multiplies
            the map (broadcast over the two grid axes) before action selection.
        '''
        self.test_state = state
        # update time state and MT map
        self.test_cue_delay, self.test_MT_map = increment_cue_delay(self.test_cue_delay, self.test_MT_map, self.reward_MT, cue, self.num_cues)
        if weights is None:
            action, max_pi, Q_values = gpi_action(state, self.test_MT_map, self.sr, self.num_cues, self.env)
        else:
            action, max_pi, Q_values = gpi_action(state, self.test_MT_map * weights[None, None, :, :], self.sr, self.num_cues, self.env)
        self._print_debug(state, cue, prev_reward)
        return action

    def _print_debug(self, state, cue, prev_reward):
        '''
        Print diagnostics for grid cells (0, 4) and (4, 0) when cue 0 is shown.

        Skipped on grids smaller than 5x5, where those cells do not exist.
        '''
        if cue[0] > 0 and min(self.test_MT_map.shape[:2]) > 4:
            print('cue delays ', self.test_cue_delay, ' state ', state, ' reward ', prev_reward)
            print('test map 0, 4 early ', self.test_MT_map[0, 4], np.sum(self.test_MT_map[0, 4] * np.arange(self.max_reward_magnitude), axis=-1))
            print('test map 4, 0 late ', self.test_MT_map[4, 0], np.sum(self.test_MT_map[4, 0] * np.arange(self.max_reward_magnitude), axis=-1))

    def reset_test(self, state):
        '''Reset the test-time state, reward map and cue delays.'''
        self.test_state = state
        self.test_MT_map = np.zeros(self.reward_MT.shape[1:])
        # NOTE(review): float zeros here, whereas cur_cue_delay is an int array initialised to num_cues - confirm
        self.test_cue_delay = np.zeros(self.num_cues)

    def _update_reward_MT(self, reward, next_state):
        '''
        Move reward_MT at `next_state` towards the observed (delay, magnitude)
        of `reward`, for every cue whose delay differs from max_reward_delay.

        If no cue qualifies, the normaliser is zero. The update is then skipped,
        since dividing would write NaN into reward_MT.
        '''
        x, y = next_state[0], next_state[1]
        reward_map = np.zeros(self.reward_MT.shape)
        for indc in range(self.num_cues):
            if self.cur_cue_delay[indc] != self.max_reward_delay:
                reward_map[indc, x, y, int(self.cur_cue_delay[indc]), int(reward)] = 1
        weighted = (self.reward_MT + 1E-6) * reward_map
        total = np.sum(weighted)
        if total == 0:
            return
        reward_map = weighted / total
        self.reward_MT[:, x, y] = (1 - self.alpha) * self.reward_MT[:, x, y] + self.alpha * reward_map[:, x, y]

    def train(self, prev_state, action, reward, next_state, next_cue):
        '''
        Perform one learning step from the transition prev_state -> next_state.

        Updates the SR (when prev_state is not None) and reward_MT (when
        reward > 0), advances the cue delays, and returns a uniformly random
        next action.
        '''
        if prev_state is not None:
            Q_values = Q_value_each_policy(prev_state, self.reward_MT, self.cur_cue_delay, self.sr, self.env)
            # greedy action of each cue's policy
            max_act = np.argmax(Q_values, axis=-1)
            random_prob_action = self.epsilon / self.num_actions
            # NOTE(review): the greedy weight 1 - eps/(n_actions - 1) differs from the
            # usual eps-greedy probability 1 - eps + eps/n_actions, and it divides by
            # zero when n_actions == 1 - confirm intent
            importance_weights = np.ones(self.num_cues) * random_prob_action
            importance_weights[max_act == action] = 1 - self.epsilon / (self.num_actions - 1)
            self.sr = update_sr(prev_state, next_state, self.sr, self.num_cues, self.alpha, self.gamma, importance_weights)

        if reward > 0:
            self._update_reward_MT(reward, next_state)

        # update time state
        self.cur_cue_delay, self.current_MT_map = increment_cue_delay(self.cur_cue_delay, self.current_MT_map, self.reward_MT, next_cue, self.num_cues)

        # get next action
        action = np.random.randint(0, self.env.action_space)
        return action

    def get_policies(self, pair_cues=None, weights=None):
        '''
        output max action for every state for every policy
        '''
        return get_policies(self.reward_MT, self.sr, self.env, pair_cues, weights)


class magnitude_risk_agent(object):
    '''
    Agent that learns a per-cue successor representation (SR) and a reward
    magnitude/time map (reward_MT). At test time it weights the observed map
    with `risk_function` before calling `gpi_action`.
    '''
    def __init__(self, env, map_size, init_state, reward_magnitude_time_matrices, alpha=0.1, discount_gamma=0.9, epsilon=0.1, risk_function=None, config=None):
        '''
        Parameters
        ----------
        env : environment; `env.action_space` is used as the number of actions.
        map_size : side length of the square grid.
        init_state : initial (x, y) state.
        reward_magnitude_time_matrices : array indexed (cue, x, y, delay, magnitude).
            Stored by reference; `train` updates it in place.
        alpha : learning rate for the SR and reward-map updates.
        discount_gamma : SR discount factor.
        epsilon : exploration rate used in the importance weights.
        risk_function : callable mapping an MT map to weights for `gpi_action`.
        config : experiment configuration; stored but not read by this class.
        '''
        # grid side length
        self.map_size = map_size
        # number of cues
        self.num_cues = reward_magnitude_time_matrices.shape[0]
        # initial state
        self.current_state = init_state
        # reward magnitude and time matrices (shared reference, updated in place)
        self.reward_MT = reward_magnitude_time_matrices
        # max reward delay
        self.max_reward_delay = reward_magnitude_time_matrices.shape[-2]
        # max reward magnitude
        self.max_reward_magnitude = reward_magnitude_time_matrices.shape[-1]
        # reward magnitude values (unlike the other agents, no +1 offset here)
        self.magnitude_values = np.arange(self.max_reward_magnitude)
        # current reward map from observed cues
        self.current_MT_map = np.zeros(self.reward_MT.shape[1:])
        # delay since each cue; a cue counts as inactive when its delay equals
        # max_reward_delay (see _update_reward_MT)
        # NOTE(review): initialised to num_cues rather than max_reward_delay - confirm intent
        self.cur_cue_delay = np.ones(self.num_cues, dtype=int) * self.num_cues
        # learning rate
        self.alpha = alpha
        # discount factor
        self.gamma = discount_gamma
        # successor representation, one per cue/policy
        self.sr = np.abs(np.random.normal(0,0.1,size=(self.num_cues, self.map_size, self.map_size, self.map_size, self.map_size, self.max_reward_delay)))
        # environment
        self.env = env
        # number of actions
        self.num_actions = self.env.action_space
        # epsilon for exploration
        self.epsilon = epsilon
        # risk function on reward magnitude
        self.risk_function = risk_function
        self.config = config

    def test_act(self, state, cue, prev_reward=None, epoch=0):
        '''
        Advance the test-time cue delays and reward map, then return the action
        chosen by `gpi_action` with weights from `risk_function`.
        '''
        self.test_state = state
        # update time state and MT map
        self.test_cue_delay, self.test_MT_map = increment_cue_delay(self.test_cue_delay, self.test_MT_map, self.reward_MT, cue, self.num_cues)
        # the result of this call is discarded; it is kept in case
        # risk_function relies on being called first
        # NOTE(review): confirm whether this call is still needed
        self.risk_function(np.zeros(self.test_MT_map.shape))
        weights = self.risk_function(self.test_MT_map)
        # the debug printout inspects cells (4, 0) and (0, 4); skip it on
        # grids smaller than 5x5 instead of raising IndexError
        if cue[0] > 0 and min(self.test_MT_map.shape[:2]) > 4:
            print('weights 4,0', weights[4,0], np.sum((self.test_MT_map)[4,0] * weights[4,0], axis=-1))
            print('weights 0,4', weights[0,4], np.sum((self.test_MT_map)[0,4] * weights[0,4], axis=-1))
        action, max_pi, Q_values = gpi_action(state, self.test_MT_map, self.sr, self.num_cues, self.env, weights=weights)
        return action

    def reset_test(self, state):
        '''Reset the test-time state, reward map and cue delays.'''
        self.test_state = state
        self.test_MT_map = np.zeros(self.reward_MT.shape[1:])
        # NOTE(review): float zeros here, whereas cur_cue_delay is an int array initialised to num_cues - confirm
        self.test_cue_delay = np.zeros(self.num_cues)

    def _update_reward_MT(self, reward, next_state):
        '''
        Move reward_MT at `next_state` towards the observed (delay, magnitude)
        of `reward`, for every cue whose delay differs from max_reward_delay.
        Uses 0.01 smoothing, as the original update in this agent did.

        If no cue qualifies, the normaliser is zero. The update is then skipped,
        since dividing would write NaN into reward_MT.
        '''
        x, y = next_state[0], next_state[1]
        reward_map = np.zeros(self.reward_MT.shape)
        for indc in range(self.num_cues):
            if self.cur_cue_delay[indc] != self.max_reward_delay:
                reward_map[indc, x, y, int(self.cur_cue_delay[indc]), int(reward)] = 1
        weighted = (self.reward_MT + 0.01) * reward_map
        total = np.sum(weighted)
        if total == 0:
            return
        reward_map = weighted / total
        self.reward_MT[:, x, y] = (1 - self.alpha) * self.reward_MT[:, x, y] + self.alpha * reward_map[:, x, y]

    def train(self, prev_state, action, reward, next_state, next_cue):
        '''
        Perform one learning step from the transition prev_state -> next_state.

        Updates the SR (when prev_state is not None) and reward_MT (when
        reward > 0), advances the cue delays, and returns a uniformly random
        next action.
        '''
        if prev_state is not None:
            Q_values = Q_value_each_policy(prev_state, self.reward_MT, self.cur_cue_delay, self.sr, self.env)
            # greedy action of each cue's policy
            max_act = np.argmax(Q_values, axis=-1)
            random_prob_action = self.epsilon / self.num_actions
            # NOTE(review): the greedy weight 1 - eps/(n_actions - 1) differs from the
            # usual eps-greedy probability 1 - eps + eps/n_actions, and it divides by
            # zero when n_actions == 1 - confirm intent
            importance_weights = np.ones(self.num_cues) * random_prob_action
            importance_weights[max_act == action] = 1 - self.epsilon / (self.num_actions - 1)
            self.sr = update_sr(prev_state, next_state, self.sr, self.num_cues, self.alpha, self.gamma, importance_weights)

        if reward > 0:
            self._update_reward_MT(reward, next_state)

        # update time state
        self.cur_cue_delay, self.current_MT_map = increment_cue_delay(self.cur_cue_delay, self.current_MT_map, self.reward_MT, next_cue, self.num_cues)

        # get next action
        action = np.random.randint(0, self.env.action_space)
        return action

    def get_policies(self, pair_cues=None, weights=None):
        '''
        output max action for every state for every policy
        '''
        return get_policies(self.reward_MT, self.sr, self.env, pair_cues, weights)

class magnitude_risk_particle_agent(object):
    '''
    Agent that learns a per-cue successor representation (SR) and a reward
    magnitude/time map. The map is updated with `particle_update` rather than
    a direct blend. At test time the map is weighted with `risk_function`
    before calling `gpi_action`.
    '''
    def __init__(self, env, map_size, init_state, reward_magnitude_time_matrices, alpha=0.1, discount_gamma=0.9, epsilon=0.1, risk_function=None, config=None):
        '''
        Parameters
        ----------
        env : environment; `env.action_space` is used as the number of actions.
        map_size : side length of the square grid.
        init_state : initial (x, y) state.
        reward_magnitude_time_matrices : array indexed (cue, x, y, delay, magnitude)
            (per the indexing in `train`). Stored by reference; `reward_MT` and
            `particle_reward_MT` both start as this same object.
        alpha : learning rate passed to `update_sr` and `particle_update`.
        discount_gamma : SR discount factor.
        epsilon : exploration rate used in the importance weights.
        risk_function : callable mapping an MT map to weights for `gpi_action`.
        config : must provide `n_particles` (read below). Despite the None
            default, omitting it raises AttributeError.
        '''
        # number of states
        self.map_size = map_size
        # number of cues
        self.num_cues = reward_magnitude_time_matrices.shape[0]
        # initial state
        self.current_state = init_state
        # reward magnitude and time matrices
        self.reward_MT = reward_magnitude_time_matrices
        # max reward delay
        self.max_reward_delay = reward_magnitude_time_matrices.shape[-2]
        # max reward magnitude
        self.max_reward_magnitude = reward_magnitude_time_matrices.shape[-1]
        # reward magnitude values
        self.magnitude_values = np.arange(self.max_reward_magnitude) + 1
        # current reward map from observed cues       
        self.current_MT_map = np.zeros(self.reward_MT.shape[1:])
        # current time state: delay since each cue; `train` treats a cue as
        # inactive when its delay equals max_reward_delay
        # NOTE(review): initialised to num_cues rather than max_reward_delay - confirm intent
        self.cur_cue_delay = np.ones(self.num_cues, dtype=int) * self.num_cues
        # learning rate
        self.alpha = alpha
        # discount factor
        self.gamma = discount_gamma
        # successor representation, policies
        self.sr = np.abs(np.random.normal(0,0.1,size=(self.num_cues, self.map_size, self.map_size, self.map_size, self.map_size, self.max_reward_delay)))
        # bandwidth of gaussian likelihood (not read by this class)
        self.bw = 2
        # iterations of particle learning (per cue; updated by particle_update)
        self.iter = np.zeros(self.num_cues,dtype=int)
        # particles for each state
        self.particles = set_init_particles(self.max_reward_delay, self.max_reward_magnitude, self.num_cues, config.n_particles, map_size)
        # particle reward magnitude and time matrices (same object as reward_MT at construction)
        self.particle_reward_MT = reward_magnitude_time_matrices
        # environment
        self.env = env
        # number of action
        self.num_actions = self.env.action_space
        # epsilon for exploration
        self.epsilon = epsilon
        # risk function on reward magnitude
        self.risk_function = risk_function
        self.config = config

    
    def test_act(self, state, cue, prev_reward=None, weights=None, epoch=0):
        '''
        Advance the test-time cue delays and reward map, then return the action
        chosen by `gpi_action` with weights from `risk_function`.

        The `weights` argument is ignored (overwritten below); `prev_reward` and
        `epoch` are not used.
        '''
        # set current state
        self.test_state = state
        # update time state and MT map
        self.test_cue_delay, self.test_MT_map = increment_cue_delay(self.test_cue_delay, self.test_MT_map, self.reward_MT, cue, self.num_cues)        
        # get Q values for each action and best action
        # NOTE(review): the result of this call is discarded - confirm whether it is needed
        self.risk_function(np.zeros(self.test_MT_map.shape))
        weights = self.risk_function(self.test_MT_map)
        action, max_pi, Q_values = gpi_action(state, self.test_MT_map, self.sr, self.num_cues, self.env, weights=weights)
        
        # print('cue delay, action, max pi, Q val', self.test_cue_delay, action, max_pi, Q_values)
        
        # return action
        return action
        # return self.epsilon_greedy(action)
    
    def reset_test(self, state):
        '''Reset the test-time state, reward map and cue delays.'''
        # reset test state
        self.test_state = state
        # reset test MT map
        self.test_MT_map = np.zeros(self.reward_MT.shape[1:])
        # test cue delay (float zeros, unlike the int cur_cue_delay)
        self.test_cue_delay = np.zeros(self.num_cues)
    

    def train(self, prev_state, action, reward, next_state, next_cue):
        '''
        Perform one learning step from the transition prev_state -> next_state.

        Updates the SR (when prev_state is not None) and, when reward > 0, runs
        `particle_update` for each cue whose delay differs from max_reward_delay.
        Then sets reward_MT to the particle map, advances the cue delays, and
        returns a uniformly random next action.
        '''

        if prev_state is not None:
            # find policy with highest expected reward for previous action
            # get Q values for each action from prev_state
            # print('train')
            # _, _, Q_values = self.gpi_action(prev_state, self.current_MT_map)
            # weights = self.risk_function(self.current_MT_map)
            Q_values = Q_value_each_policy(prev_state, self.reward_MT, self.cur_cue_delay, self.sr, self.env)#, weights=weights)
            # policy with highest expected reward
            max_act = np.argmax(Q_values, axis=-1)
            # Q_pi = Q_values[:, action]
            
            random_prob_action = self.epsilon / self.num_actions
            # only update max pi
            # NOTE(review): greedy weight 1 - eps/(n_actions - 1) differs from the usual
            # eps-greedy probability 1 - eps + eps/n_actions - confirm intent
            importance_weights = np.ones(self.num_cues) * random_prob_action
            importance_weights[max_act == action] = 1 - self.epsilon / (self.num_actions - 1)

            # take softmax of Q values as probability
            # update successor representation ONLY for the policy with highest expected reward
            self.sr = update_sr(prev_state, next_state, self.sr, self.num_cues, self.alpha, self.gamma, importance_weights)
            
        # make reward 1 for all cues at current delays and reward magnitudes
        if reward > 0:
            # importance weights of particles that make the learning rates different for each cue:
            # each active cue's current probability of this (delay, magnitude) at next_state,
            # normalised to sum to 1 (uniform if all are zero)
            cue_weights = np.zeros(self.num_cues)
            for indc in range(self.num_cues):
                if self.cur_cue_delay[indc] != self.max_reward_delay:
                    cue_weights[indc] = self.reward_MT[indc, next_state[0], next_state[1], int(self.cur_cue_delay[indc]), int(reward)]
            if np.sum(cue_weights) == 0.0:
                cue_weights = np.ones(self.num_cues) / float(self.num_cues)
            else:
                cue_weights /= np.sum(cue_weights) 
            # print('CUE WEIGHTS ',cue_weights)

            

            for indc in range(self.num_cues):
                if self.cur_cue_delay[indc] != self.max_reward_delay:
                    self.particles, self.iter, self.particle_reward_MT = particle_update(self.particles, reward, indc, self.cur_cue_delay, next_state, self.iter, self.alpha, self.max_reward_delay, self.max_reward_magnitude, self.particle_reward_MT, cue_weights, self.config)
            # self.particle_update(reward, next_state)
        # reward_MT always tracks the particle-based map
        self.reward_MT = self.particle_reward_MT

        # # update time state
        self.cur_cue_delay, self.current_MT_map = increment_cue_delay(self.cur_cue_delay, self.current_MT_map, self.reward_MT, next_cue, self.num_cues)
        # get next action (uniformly random behaviour policy)
        action = np.random.randint(0, self.env.action_space)

        return action
    
    def get_policies(self, pair_cues=None, weights=None):
        '''
        output max action for every state for every policy
        '''
        return get_policies(self.reward_MT, self.sr, self.env, pair_cues, weights)
    

    
class rewardMTAgent_particle(object):
    '''
    Agent that learns a per-cue successor representation (SR) and a reward
    magnitude/time map, updating the map with `particle_update`. At test time
    it acts with `gpi_action` on the map built from the observed cues.
    '''
    def __init__(self, env, map_size, init_state, reward_magnitude_time_matrices, alpha=0.1, discount_gamma=0.9, epsilon=0.1, risk_weight=None, config=None):
        '''
        Parameters
        ----------
        env : environment; `env.action_space` is used as the number of actions.
        map_size : side length of the square grid.
        init_state : initial (x, y) state.
        reward_magnitude_time_matrices : array indexed (cue, x, y, delay, magnitude).
            Stored by reference; `reward_MT` and `particle_reward_MT` both start
            as this same object.
        alpha : learning rate passed to `update_sr` and `particle_update`.
        discount_gamma : SR discount factor.
        epsilon : exploration rate used in the importance weights.
        risk_weight : accepted for interface compatibility; not used here.
        config : must provide `n_particles`. Despite the None default, omitting
            it raises AttributeError.
        '''
        # grid side length
        self.map_size = map_size
        # number of cues
        self.num_cues = reward_magnitude_time_matrices.shape[0]
        # initial state
        self.current_state = init_state
        # reward magnitude and time matrices
        self.reward_MT = reward_magnitude_time_matrices
        # max reward delay
        self.max_reward_delay = reward_magnitude_time_matrices.shape[-2]
        # max reward magnitude
        self.max_reward_magnitude = reward_magnitude_time_matrices.shape[-1]
        # reward magnitude values
        self.magnitude_values = np.arange(self.max_reward_magnitude) + 1
        # current reward map from observed cues
        self.current_MT_map = np.zeros(self.reward_MT.shape[1:])
        # delay since each cue; `train` treats a cue as inactive when its delay
        # equals max_reward_delay
        # NOTE(review): initialised to num_cues rather than max_reward_delay - confirm intent
        self.cur_cue_delay = np.ones(self.num_cues, dtype=int) * self.num_cues
        # learning rate
        self.alpha = alpha
        # discount factor
        self.gamma = discount_gamma
        # successor representation, one per cue/policy
        self.sr = np.abs(np.random.normal(0,0.1,size=(self.num_cues, self.map_size, self.map_size, self.map_size, self.map_size, self.max_reward_delay)))
        # bandwidth of gaussian likelihood (not read by this class)
        self.bw = 2
        # iterations of particle learning, per cue
        self.iter = np.zeros(self.num_cues,dtype=int)
        # particles for each state
        self.particles = set_init_particles(self.max_reward_delay, self.max_reward_magnitude, self.num_cues, config.n_particles, map_size)
        # particle reward magnitude and time matrices
        self.particle_reward_MT = reward_magnitude_time_matrices
        # environment
        self.env = env
        # number of actions
        self.num_actions = self.env.action_space
        # epsilon for exploration
        self.epsilon = epsilon
        self.config = config

    def test_act(self, state, cue, prev_reward=None, weights=None, epoch=0):
        '''
        Advance the test-time cue delays and reward map, then return the action
        chosen by `gpi_action`.

        weights : optional (delay, magnitude) array. When given, it multiplies
            the map (broadcast over the two grid axes) before action selection.
        '''
        self.test_state = state
        # update time state and MT map
        self.test_cue_delay, self.test_MT_map = increment_cue_delay(self.test_cue_delay, self.test_MT_map, self.reward_MT, cue, self.num_cues)
        if weights is None:
            action, max_pi, Q_values = gpi_action(state, self.test_MT_map, self.sr, self.num_cues, self.env)
        else:
            action, max_pi, Q_values = gpi_action(state, self.test_MT_map * weights[None, None, :, :], self.sr, self.num_cues, self.env)
        self._print_debug(state, cue, prev_reward)
        return action

    def _print_debug(self, state, cue, prev_reward):
        '''
        Print diagnostics for grid cells (0, 4) and (4, 0) when cue 0 is shown.

        Skipped on grids smaller than 5x5, where those cells do not exist.
        '''
        if cue[0] > 0 and min(self.test_MT_map.shape[:2]) > 4:
            print('cue delays ', self.test_cue_delay, ' state ', state, ' reward ', prev_reward)
            print('test map 0, 4 ', self.test_MT_map[0,4], np.sum(self.test_MT_map[0,4] * np.arange(self.max_reward_magnitude), axis=-1))
            print('test map 4, 0 ', self.test_MT_map[4,0], np.sum(self.test_MT_map[4,0] * np.arange(self.max_reward_magnitude), axis=-1))

    def reset_test(self, state):
        '''Reset the test-time state, reward map and cue delays.'''
        self.test_state = state
        self.test_MT_map = np.zeros(self.reward_MT.shape[1:])
        # NOTE(review): float zeros here, whereas cur_cue_delay is an int array - confirm
        self.test_cue_delay = np.zeros(self.num_cues)

    def train(self, prev_state, action, reward, next_state, next_cue):
        '''
        Perform one learning step from the transition prev_state -> next_state.

        Updates the SR (when prev_state is not None) and, when reward > 0, runs
        `particle_update` for each cue whose delay differs from max_reward_delay.
        Then sets reward_MT to the particle map, advances the cue delays, and
        returns a uniformly random next action.
        '''
        if prev_state is not None:
            Q_values = Q_value_each_policy(prev_state, self.reward_MT, self.cur_cue_delay, self.sr, self.env)
            # greedy action of each cue's policy
            max_act = np.argmax(Q_values, axis=-1)
            random_prob_action = self.epsilon / self.num_actions
            # NOTE(review): the greedy weight 1 - eps/(n_actions - 1) differs from the
            # usual eps-greedy probability 1 - eps + eps/n_actions - confirm intent
            importance_weights = np.ones(self.num_cues) * random_prob_action
            importance_weights[max_act == action] = 1 - self.epsilon / (self.num_actions - 1)
            self.sr = update_sr(prev_state, next_state, self.sr, self.num_cues, self.alpha, self.gamma, importance_weights)

        if reward > 0:
            # per-cue weights: each active cue's current probability of this
            # (delay, magnitude) at next_state, normalised (uniform if all zero)
            cue_weights = np.zeros(self.num_cues)
            for indc in range(self.num_cues):
                if self.cur_cue_delay[indc] != self.max_reward_delay:
                    cue_weights[indc] = self.reward_MT[indc, next_state[0], next_state[1], int(self.cur_cue_delay[indc]), int(reward)]
            if np.sum(cue_weights) == 0.0:
                cue_weights = np.ones(self.num_cues) / float(self.num_cues)
            else:
                cue_weights /= np.sum(cue_weights)

            for indc in range(self.num_cues):
                if self.cur_cue_delay[indc] != self.max_reward_delay:
                    self.particles, self.iter, self.particle_reward_MT = particle_update(self.particles, reward, indc, self.cur_cue_delay, next_state, self.iter, self.alpha, self.max_reward_delay, self.max_reward_magnitude, self.particle_reward_MT, cue_weights, self.config)
        # reward_MT always tracks the particle-based map
        self.reward_MT = self.particle_reward_MT

        # update time state
        self.cur_cue_delay, self.current_MT_map = increment_cue_delay(self.cur_cue_delay, self.current_MT_map, self.reward_MT, next_cue, self.num_cues)
        # get next action (uniformly random behaviour policy)
        action = np.random.randint(0, self.env.action_space)
        return action

    def get_policies(self, pair_cues=None, weights=None):
        '''
        output max action for every state for every policy
        '''
        return get_policies(self.reward_MT, self.sr, self.env, pair_cues, weights)
    


class homeostatic_agent_particle(object):
    '''
    Particle-learning agent (per-cue successor representations plus learned
    reward magnitude/time maps) whose test-time action choice is weighted by a
    homeostatic internal state.

    `train` learns the SR and the reward maps while acting uniformly at random;
    `reset_test` / `test_act` act greedily through `gpi_action`, with reward
    weights from `homeostatic_weights` evaluated at the current internal state.
    '''
    def __init__(self, env, map_size, init_state, reward_magnitude_time_matrices, internal_state_update, homeostatic_loss, alpha=0.1, discount_gamma=0.9, epsilon=0.1, internal_state_init=1, config=None):
        '''
        Parameters
        ----------
        env : environment exposing `action_space` (number of actions)
        map_size : side length of the square grid
        init_state : starting (x, y) position
        reward_magnitude_time_matrices : reward maps indexed
            [cue, x, y, delay, magnitude]; also the initial particle maps
        internal_state_update : callable advancing the internal state one step
        homeostatic_loss : loss on the internal state, passed to `homeostatic_weights`
        alpha, discount_gamma, epsilon : learning rate, SR discount, exploration rate
        internal_state_init : initial internal state
        config : must provide `n_particles` (dereferenced here, so None fails)
        '''
        # number of states
        self.map_size = map_size
        # number of cues
        self.num_cues = reward_magnitude_time_matrices.shape[0]
        # initial state
        self.current_state = init_state
        # reward magnitude and time matrices
        self.reward_MT = reward_magnitude_time_matrices
        # max reward delay
        self.max_reward_delay = reward_magnitude_time_matrices.shape[-2]
        # max reward magnitude
        self.max_reward_magnitude = reward_magnitude_time_matrices.shape[-1]
        # reward magnitude values
        self.magnitude_values = np.arange(self.max_reward_magnitude) + 1
        # current reward map from observed cues
        self.current_MT_map = np.zeros(self.reward_MT.shape[1:])
        # current time state
        # NOTE(review): initialised to num_cues, while train() treats max_reward_delay
        # as "no cue pending" and the tabular agents start at max_reward_delay -- confirm
        self.cur_cue_delay = np.ones(self.num_cues, dtype=int) * self.num_cues
        # learning rate
        self.alpha = alpha
        # discount factor
        self.gamma = discount_gamma
        # successor representation, one per cue/policy:
        # [policy, x, y, x', y', delay], small positive random init
        self.sr = np.abs(np.random.normal(0,0.1,size=(self.num_cues, self.map_size, self.map_size, self.map_size, self.map_size, self.max_reward_delay)))
        # bandwidth of gaussian likelihood (not read by this class)
        self.bw = 2
        # iterations of particle learning, per cue
        self.iter = np.zeros(self.num_cues,dtype=int)
        # particles for each state
        self.particles = set_init_particles(self.max_reward_delay, self.max_reward_magnitude, self.num_cues, config.n_particles, map_size)
        # particle reward magnitude and time matrices (initially the same array object as self.reward_MT)
        self.particle_reward_MT = reward_magnitude_time_matrices
        # environment
        self.env = env
        # number of action
        self.num_actions = self.env.action_space
        # epsilon for exploration
        self.epsilon = epsilon
        # update function for internal state
        self.internal_state_update = internal_state_update
        # initial internal state
        self.init_internal = internal_state_init
        # current internal state (test rollouts)
        self.cur_internal = internal_state_init
        # current internal state during training
        self.train_internal = internal_state_init
        # homeostatic loss function on internal state
        self.homeostatic_loss = homeostatic_loss
        self.config = config


    def test_act(self, state, cue, prev_reward=None, weights=None, epoch=0):
        '''
        Advance the test-time cue delays and internal state, then return the
        greedy GPI action at `state`.

        `prev_reward` is the reward received on the previous step; None (the
        default, and what callers that do not track it pass) counts as 0.
        The `weights` argument is ignored: weights always come from
        `homeostatic_weights` at the updated internal state.
        '''
        # set current state
        self.test_state = state
        # update time state and MT map
        self.test_cue_delay, self.test_MT_map = increment_cue_delay(self.test_cue_delay, self.test_MT_map, self.reward_MT, cue, self.num_cues)
        # a missing previous reward means "no reward"; adding None would raise TypeError
        if prev_reward is None:
            prev_reward = 0
        # increment internal state
        self.cur_internal = self.internal_state_update(self.cur_internal) + prev_reward
        weights = homeostatic_weights(self.test_MT_map.shape, self.cur_internal, self.internal_state_update, self.homeostatic_loss)

        # get Q values for each action and best action
        action, max_pi, Q_values = gpi_action(state, self.test_MT_map, self.sr, self.num_cues, self.env, weights=weights)

        return action

    def reset_test(self, state):
        '''Reset the test-time position, reward map, cue delays and internal state.'''
        # reset test state
        self.test_state = state
        # reset test MT map
        self.test_MT_map = np.zeros(self.reward_MT.shape[1:])
        # test cue delay
        self.test_cue_delay = np.zeros(self.num_cues)
        # reset internal state
        self.cur_internal = self.init_internal


    def train(self, prev_state, action, reward, next_state, next_cue):
        '''
        One learning step for the transition prev_state --action--> next_state.

        Updates the successor representations (skipped when prev_state is None),
        updates the particle reward maps when reward > 0, advances the cue
        delays with `next_cue`, advances the training internal state, and
        returns a uniformly random next action.
        '''
        if prev_state is not None:
            # Q values of every action under every policy from prev_state
            Q_values = Q_value_each_policy(prev_state, self.reward_MT, self.cur_cue_delay, self.sr, self.env)
            # greedy action of each policy
            max_act = np.argmax(Q_values, axis=-1)

            random_prob_action = self.epsilon / self.num_actions
            # policies whose greedy action matches the taken action get a larger weight
            # NOTE(review): epsilon-greedy would give 1 - eps + eps/A to the greedy action;
            # kept as-is for consistency with the other particle agent -- confirm
            importance_weights = np.ones(self.num_cues) * random_prob_action
            importance_weights[max_act == action] = 1 - self.epsilon / (self.num_actions - 1)

            # update every policy's SR; importance_weights presumably scales each policy's update
            self.sr = update_sr(prev_state, next_state, self.sr, self.num_cues, self.alpha, self.gamma, importance_weights)

        # on reward, update the particle reward maps of every cue that is currently pending
        if reward > 0:
            # weight of each cue: current map probability of this (position, delay, magnitude)
            cue_weights = np.zeros(self.num_cues)
            for indc in range(self.num_cues):
                if self.cur_cue_delay[indc] != self.max_reward_delay:
                    cue_weights[indc] = self.reward_MT[indc, next_state[0], next_state[1], int(self.cur_cue_delay[indc]), int(reward)]
            # fall back to uniform weights when no cue explains the reward
            if np.sum(cue_weights) == 0.0:
                cue_weights = np.ones(self.num_cues) / float(self.num_cues)
            else:
                cue_weights /= np.sum(cue_weights)
            print('CUE WEIGHTS ',cue_weights)

            for indc in range(self.num_cues):
                if self.cur_cue_delay[indc] != self.max_reward_delay:
                    self.particles, self.iter, self.particle_reward_MT = particle_update(self.particles, reward, indc, self.cur_cue_delay, next_state, self.iter, self.alpha, self.max_reward_delay, self.max_reward_magnitude, self.particle_reward_MT, cue_weights, self.config)
        self.reward_MT = self.particle_reward_MT

        # update time state
        self.cur_cue_delay, self.current_MT_map = increment_cue_delay(self.cur_cue_delay, self.current_MT_map, self.reward_MT, next_cue, self.num_cues)
        # get next action (uniformly random during training)
        action = np.random.randint(0, self.env.action_space)
        # increment internal state
        self.train_internal = self.internal_state_update(self.train_internal) + reward

        return action

    def get_policies(self, pair_cues=None, weights=None):
        '''
        output max action for every state for every policy
        '''
        return get_policies(self.reward_MT, self.sr, self.env, pair_cues, weights)
    


def gpi_action_time_only(state, reward_map, expected_return, sr, num_cues, env):
    '''
    Generalized policy improvement over the per-cue successor representations.

    For every policy and action, the value is the expected reward at delay index
    0 of the next state plus the SR-weighted expected reward at later delays
    (SR delay d is paired with reward delay d + 1).

    Parameters
    ----------
    state : (x, y) position
    reward_map : array whose last axis is treated as reward magnitude (value =
        index); assumes layout [x, y, delay, magnitude] -- TODO confirm for callers
        passing a map without a magnitude axis
    expected_return : unused, kept for call compatibility
    sr : successor representation indexed [policy, x, y, x', y', delay]
    num_cues : number of policies
    env : exposes `action_space` (int) and `get_next_state(state, action)`

    Returns
    -------
    (max_act, max_pi, Q_values) where Q_values has shape (num_cues, env.action_space)
    and (max_pi, max_act) is its argmax.
    '''
    max_reward_magnitude = reward_map.shape[-1]
    magnitudes = np.arange(max_reward_magnitude)
    # expected reward at every location for delays >= 1; independent of policy and action,
    # so computed once instead of inside the loops
    future_reward = np.sum(reward_map[:, :, 1:] * magnitudes[None, None, None, :], axis=-1)
    Q_values = np.zeros((num_cues, env.action_space))
    for ind_pi in range(num_cues):
        for action in range(env.action_space):
            next_state = env.get_next_state(state, action)
            next_rew = np.sum(reward_map[next_state[0], next_state[1], 0] * magnitudes)
            Q_values[ind_pi, action] = next_rew + np.sum(sr[ind_pi, next_state[0], next_state[1], :, :, :-1] * future_reward)
    # choose action based on expected reward
    max_pi, max_act = np.unravel_index(np.argmax(Q_values, axis=None), Q_values.shape)
    return max_act, max_pi, Q_values

class rewardMTAgent_particle_time_only(object):
    '''
    Particle-learning agent variant whose working reward maps keep only the
    delay axis: `reward_MT` is indexed [cue, x, y, delay] (the magnitude-0 slice
    of the input at construction, the magnitude-summed particle maps after each
    `train` call). Test-time actions come from `gpi_action_time_only`.
    '''
    def __init__(self, env, map_size, init_state, reward_magnitude_time_matrices, alpha=0.1, discount_gamma=0.9, epsilon=0.1, weighting=False, weights=None, config=None):
        '''
        Parameters
        ----------
        env : environment exposing `action_space` (number of actions)
        map_size : side length of the square grid
        init_state : starting (x, y) position
        reward_magnitude_time_matrices : array indexed [cue, x, y, delay, magnitude]
        alpha, discount_gamma, epsilon : learning rate, SR discount, exploration rate
        weighting, weights : accepted but not used by this class
        config : must provide `n_particles` (dereferenced here, so None fails)
        '''
        # number of states
        self.map_size = map_size
        # number of cues
        self.num_cues = reward_magnitude_time_matrices.shape[0]
        # initial state
        self.current_state = init_state
        # reward time matrices: magnitude-0 slice of (cue x map_size x map_size x max_reward_delay x max_reward_magnitude)
        self.reward_MT = reward_magnitude_time_matrices[:, :, :, :, 0]
        # max reward delay
        self.max_reward_delay = reward_magnitude_time_matrices.shape[-2]
        # max reward magnitude
        self.max_reward_magnitude = reward_magnitude_time_matrices.shape[-1]
        # reward magnitude values
        self.magnitude_values = np.arange(self.max_reward_magnitude) + 1
        # current reward map from observed cues
        self.current_MT_map = np.zeros(self.reward_MT.shape[1:])
        # current time state
        # NOTE(review): initialised to num_cues while train() treats max_reward_delay as "no cue pending" -- confirm
        self.cur_cue_delay = np.ones(self.num_cues, dtype=int) * self.num_cues
        # learning rate
        self.alpha = alpha
        # discount factor
        self.gamma = discount_gamma
        # successor representation, one per cue/policy: [policy, x, y, x', y', delay]
        self.sr = np.abs(np.random.normal(0,0.1,size=(self.num_cues, self.map_size, self.map_size, self.map_size, self.map_size, self.max_reward_delay)))
        # bandwidth of gaussian likelihood (not read by this class)
        self.bw = 2
        # iterations of particle learning, per cue
        self.iter = np.zeros(self.num_cues,dtype=int)
        # particles for each state
        self.particles = set_init_particles(self.max_reward_delay, self.max_reward_magnitude, self.num_cues, config.n_particles, map_size)
        # particle reward magnitude and time matrices
        self.particle_reward_MT = reward_magnitude_time_matrices
        # expected return per state; recomputed by train(). Initialised here so that
        # test_act can pass it to gpi_action_time_only before any training step.
        self.expected_return = None
        # environment
        self.env = env
        # number of action
        self.num_actions = self.env.action_space
        # epsilon for exploration
        self.epsilon = epsilon


    def test_act(self, state, cue, weights=None, epoch=0):
        '''
        Advance the test-time cue delays / reward map and return the greedy GPI
        action at `state`. If `weights` is given it is broadcast as
        weights[None, None, :, :] over the test reward map.
        '''
        # set current state
        self.test_state = state
        # update time state and MT map
        self.test_cue_delay, self.test_MT_map = increment_cue_delay(self.test_cue_delay, self.test_MT_map, self.reward_MT, cue, self.num_cues)
        # get Q values for each action and best action
        # (gpi_action_time_only takes expected_return as its third positional argument)
        # NOTE(review): test_MT_map here has no magnitude axis, but gpi_action_time_only
        # treats the last axis as magnitude -- confirm the intended layout
        if weights is None:
            action, max_pi, Q_values = gpi_action_time_only(state, self.test_MT_map, self.expected_return, self.sr, self.num_cues, self.env)
        else:
            action, max_pi, Q_values = gpi_action_time_only(state, self.test_MT_map * weights[None, None, :, :], self.expected_return, self.sr, self.num_cues, self.env)

        print('cue delay, action, max pi, Q val', self.test_cue_delay, action, max_pi, Q_values)

        return action

    def reset_test(self, state):
        '''Reset the test-time position, reward map and cue delays.'''
        # reset test state
        self.test_state = state
        # reset test MT map
        self.test_MT_map = np.zeros(self.reward_MT.shape[1:])
        # test cue delay
        self.test_cue_delay = np.zeros(self.num_cues)


    def train(self, prev_state, action, reward, next_state, next_cue):
        '''
        One learning step: update the SRs (skipped when prev_state is None),
        update the particle maps when reward > 0, refresh `reward_MT` (summed
        over magnitude) and `expected_return`, advance the cue delays, and return
        a uniformly random next action.
        '''
        if prev_state is not None:
            # find policy with highest expected reward for previous action
            print('train')
            Q_values = Q_value_each_policy(prev_state, self.reward_MT, self.cur_cue_delay, self.sr, self.env)
            # greedy action of each policy
            max_act = np.argmax(Q_values, axis=-1)

            random_prob_action = self.epsilon / self.num_actions
            # policies whose greedy action matches the taken action get a larger weight
            importance_weights = np.ones(self.num_cues) * random_prob_action
            importance_weights[max_act == action] = 1 - self.epsilon / (self.num_actions - 1)

            # update every policy's SR; importance_weights presumably scales each policy's update
            self.sr = update_sr(prev_state, next_state, self.sr, self.num_cues, self.alpha, self.gamma, importance_weights)

        # on reward, update the particle maps of every cue that is currently pending
        if reward > 0:
            for indc in range(self.num_cues):
                if self.cur_cue_delay[indc] != self.max_reward_delay:
                    # NOTE(review): the sibling agents also pass cue_weights and config here;
                    # confirm particle_update has defaults for them
                    self.particles, self.iter, self.particle_reward_MT = particle_update(self.particles, reward, indc, self.cur_cue_delay, next_state, self.iter, self.alpha, self.max_reward_delay, self.max_reward_magnitude, self.particle_reward_MT)
        self.reward_MT = np.sum(self.particle_reward_MT, axis=-1) # sum over magnitudes
        self.expected_return = np.sum(self.particle_reward_MT * self.magnitude_values[None, None, None, None, :], axis=(-1,-2))  # expected return for each state

        # update time state
        self.cur_cue_delay, self.current_MT_map = increment_cue_delay(self.cur_cue_delay, self.current_MT_map, self.reward_MT, next_cue, self.num_cues)
        # get next action (uniformly random during training)
        action = np.random.randint(0, self.env.action_space)

        return action

    def get_policies(self, pair_cues=None, weights=None):
        '''
        output max action for every state for every policy
        '''
        return get_policies(self.reward_MT, self.sr, self.env, pair_cues, weights)
    
    

class QuantileRLAgent(object):
    '''
    Tabular quantile-regression (distributional) Q-learning agent.

    The quantile table is indexed by grid position (x, y), a joint integer
    encoding of every cue's delay (see `get_cue_index`) and action; each entry
    holds `num_quantiles` return quantiles at midpoints tau_i = (i + 0.5) / N.
    Actions are chosen greedily on the mean over quantiles.
    '''
    def __init__(self, env, map_size, init_state, reward_magnitude_time_matrices, alpha=0.1, discount_gamma=0.9, epsilon=0.1, num_quantiles=50, config=None):
        # number of states
        self.map_size = map_size
        # number of cues
        self.num_cues = reward_magnitude_time_matrices.shape[0]
        # initial state
        self.current_state = init_state
        # reward magnitude and time matrices
        self.reward_MT = reward_magnitude_time_matrices
        # max reward delay
        self.max_reward_delay = reward_magnitude_time_matrices.shape[-2]
        # max reward magnitude
        self.max_reward_magnitude = reward_magnitude_time_matrices.shape[-1]
        # reward magnitude values
        self.magnitude_values = np.arange(self.max_reward_magnitude) + 1
        # per-cue delay since last cue onset, saturating at max_reward_delay
        self.cur_cue_delay = np.ones(self.num_cues) * self.max_reward_delay
        # scalar value table (allocated but not read by this class)
        self.Q_value = np.zeros((self.map_size, self.map_size, (self.max_reward_delay + 1)**self.num_cues, env.action_space))
        # learning rate
        self.alpha = alpha
        # discount factor
        self.gamma = discount_gamma
        # exploration rate
        self.epsilon = epsilon
        # environment
        self.env = env
        print('normal agent params: ',self.alpha, self.map_size, self.num_cues, self.max_reward_delay, self.max_reward_magnitude, self.cur_cue_delay, self.gamma, self.epsilon)
        # number of quantiles
        self.num_quantiles = num_quantiles

        # quantile values: [x, y, cue-delay index, action, quantile]
        self.quantiles = np.zeros((self.map_size, self.map_size, (self.max_reward_delay + 1)**self.num_cues, env.action_space, num_quantiles))

        # Fixed quantile midpoints tau_i = (i + 0.5) / N
        self.taus = (np.arange(num_quantiles) + 0.5) / num_quantiles

    def get_cue_index(self, cue_delay):
        '''
        Encode per-cue delays as one integer: digits in base
        (max_reward_delay + 1), cue 0 least significant.
        '''
        index = 0
        for ind in range(self.num_cues):
            index += cue_delay[ind] * ((self.max_reward_delay + 1) ** ind)
        return int(index)

    def test_act(self, state, cue, prev_reward=None, weights=None, epoch=0):
        '''
        Advance the test-time cue delays with `cue` (entries equal to 1 reset
        that cue's delay to 0) and return the greedy action on the mean over quantiles.
        '''
        # set current state
        self.test_state = state
        # update time state
        self.test_cue_delay += 1
        self.test_cue_delay[self.test_cue_delay >= self.max_reward_delay] = self.max_reward_delay
        # set current cue
        for indc in range(self.num_cues):
            if cue[indc] == 1:
                # update time state
                self.test_cue_delay[indc] = 0
        index = self.get_cue_index(self.test_cue_delay)
        q_means = self.quantiles[state[0], state[1], index].mean(axis=1)  # shape (A,)
        return np.argmax(q_means)

    def act(self, state, cue_index):
        """
        Select greedy action based on mean over quantiles.
        """
        q_means = self.quantiles[state[0], state[1], cue_index].mean(axis=1)  # shape (A,)
        return np.argmax(q_means)

    def reset_test(self, state):
        '''Reset the test-time position and set every cue delay to max_reward_delay.'''
        # reset test state
        self.test_state = state
        # reset test cue delays
        self.test_cue_delay = np.ones(self.num_cues) * self.max_reward_delay

    def train(self, prev_state, action, reward, next_state, next_cue):
        """
        Perform one quantile regression update for (s,a,r,s').

        The cue delays are advanced with `next_cue` first; the update (skipped
        when prev_state is None) moves each quantile of (prev_state, action)
        toward reward + gamma * quantiles of the greedy next action.
        Returns a uniformly random next action.
        """
        # current state is still before action was taken
        prev_cue = self.cur_cue_delay.copy()
        prev_cue_index = self.get_cue_index(prev_cue)
        # update time state
        self.cur_cue_delay += 1
        self.cur_cue_delay[self.cur_cue_delay >= self.max_reward_delay] = self.max_reward_delay
        # set current cue
        for indc in range(self.num_cues):
            if next_cue[indc] == 1:
                # update time state
                self.cur_cue_delay[indc] = 0

        # update quantiles function
        if prev_state is not None:
            # Current quantiles: shape (N,)
            theta = self.quantiles[prev_state[0], prev_state[1], prev_cue_index, action]
            # Given the cue delays, get unique index
            next_cue_index = self.get_cue_index(self.cur_cue_delay)
            # Next-state greedy action (max-mean)
            next_action = self.act(next_state, next_cue_index)
            next_theta = self.quantiles[next_state[0], next_state[1], next_cue_index, next_action]
            # Target distribution
            target = reward + self.gamma * next_theta

            # Pairwise TD errors u_ij = target_j - theta_i: shape (N, N)
            diff = target[None, :] - theta[:, None]

            # Huber loss derivative L'_kappa(u)
            huber_kappa = 1.0
            abs_diff = np.abs(diff)
            huber_loss_grad = np.where(abs_diff <= huber_kappa, diff, huber_kappa * np.sign(diff))

            # Quantile Huber loss: rho = |tau - 1{u<0}| * L_kappa(u) / kappa.
            # The weight must be the ABSOLUTE value: the signed weight (tau - 1) times a
            # negative error pushed quantiles UP when the target lay below them.
            quantile_indicator = (diff < 0).astype(float)
            quantile_weight = np.abs(self.taus[:, None] - quantile_indicator)

            # Descent direction for each quantile, averaged over target samples: shape (N,)
            grad = (quantile_weight * huber_loss_grad / huber_kappa).mean(axis=1)

            # Update rule
            self.quantiles[prev_state[0], prev_state[1], prev_cue_index, action] += self.alpha * grad

        # Use random action during training
        action = np.random.randint(0, self.env.action_space)

        return action



class normalAgent(object):
    '''
    Tabular Q-learning baseline.

    The Q table is indexed [x, y, cue-delay code, action]. The cue-delay code
    packs every cue's delay since onset (each in 0..max_reward_delay, saturating
    at max_reward_delay, which is also the starting value) into one integer via
    `get_cue_index`.
    '''
    def __init__(self, env, map_size, init_state, reward_magnitude_time_matrices, alpha=0.1, discount_gamma=0.9, epsilon=0.1, risk_function=None, config=None):
        rmt = reward_magnitude_time_matrices
        self.map_size = map_size
        self.current_state = init_state
        self.reward_MT = rmt
        self.num_cues = rmt.shape[0]
        self.max_reward_delay, self.max_reward_magnitude = rmt.shape[-2], rmt.shape[-1]
        # magnitude values are 1-based
        self.magnitude_values = 1 + np.arange(self.max_reward_magnitude)
        # every cue starts at the saturated delay
        self.cur_cue_delay = self.max_reward_delay * np.ones(self.num_cues)
        n_delay_codes = (self.max_reward_delay + 1) ** self.num_cues
        self.Q_value = np.zeros((map_size, map_size, n_delay_codes, env.action_space))
        self.alpha = alpha
        self.gamma = discount_gamma
        self.epsilon = epsilon
        self.env = env
        print('normal agent params: ', self.alpha, self.map_size, self.num_cues, self.max_reward_delay, self.max_reward_magnitude, self.cur_cue_delay, self.gamma, self.epsilon)

    def get_value(self, cue_delay):
        '''Return the Q table slice [x, y, action] for the given per-cue delays.'''
        return self.Q_value[:, :, self.get_cue_index(cue_delay)]

    def get_cue_index(self, cue_delay):
        '''
        Pack per-cue delays into one integer, reading them as digits in base
        (max_reward_delay + 1) with cue 0 least significant.
        '''
        base = self.max_reward_delay + 1
        return int(sum(cue_delay[pos] * base ** pos for pos in range(self.num_cues)))

    def act(self, state, cue_delay=None):
        '''Greedy action at `state`; uses the training cue delays unless `cue_delay` is given.'''
        delays = self.cur_cue_delay if cue_delay is None else cue_delay
        return np.argmax(self.get_value(delays)[state[0], state[1]])

    def _advance_delays(self, delays, cue):
        # in place: age every cue one step, saturate at max_reward_delay,
        # then restart (delay 0) each cue flagged with 1 in `cue`
        delays += 1
        np.minimum(delays, self.max_reward_delay, out=delays)
        for slot in range(self.num_cues):
            if cue[slot] == 1:
                delays[slot] = 0

    def test_act(self, state, cue, prev_reward=None, weights=None, epoch=0):
        '''Advance the test-time cue delays with `cue` and return the greedy action at `state`.'''
        self.test_state = state
        self._advance_delays(self.test_cue_delay, cue)
        return self.act(state, cue_delay=self.test_cue_delay)

    def reset_test(self, state):
        '''Reset the test-time position and set every test cue delay to max_reward_delay.'''
        self.test_state = state
        self.test_cue_delay = np.ones(self.num_cues) * self.max_reward_delay

    def train(self, prev_state, action, reward, next_state, next_cue):
        '''
        Advance the training cue delays with `next_cue`, apply one Q-learning
        update for (prev_state, action) unless prev_state is None, and return a
        uniformly random next action.
        '''
        # the delay code before the step belongs to prev_state
        prev_cue_index = self.get_cue_index(self.cur_cue_delay.copy())
        self._advance_delays(self.cur_cue_delay, next_cue)
        if prev_state is not None:
            next_cue_index = self.get_cue_index(self.cur_cue_delay)
            best_next = np.max(self.Q_value[next_state[0], next_state[1], next_cue_index])
            target_value = reward + self.gamma * best_next
            cell = (prev_state[0], prev_state[1], prev_cue_index, action)
            self.Q_value[cell] = self.alpha * target_value + (1 - self.alpha) * self.Q_value[cell]
        return np.random.randint(0, self.env.action_space)

    def epsilon_greedy(self, action):
        '''With probability epsilon replace `action` by a uniformly random one.'''
        explore = np.random.random() < self.epsilon
        return np.random.randint(0, self.env.action_space) if explore else action

    def update_Q(self, prev_state, action, reward, next_state, cue_delay):
        '''
        Overwrite Q(prev_state, cue_delay, action) with the one-step target,
        assuming every cue delay simply ages by one (no new cue onset).
        '''
        prev_cue_index = self.get_cue_index(cue_delay.copy())
        aged = np.minimum(np.array(cue_delay.copy()) + 1, self.max_reward_delay)
        next_cue_index = self.get_cue_index(aged)
        best_next = np.max(self.Q_value[next_state[0], next_state[1], next_cue_index])
        self.Q_value[prev_state[0], prev_state[1], prev_cue_index, action] = reward + self.gamma * best_next

    def get_policies(self):
        '''
        Greedy action per (cue, x, y), evaluated with that cue at delay 0 and
        every other cue at max_reward_delay.
        '''
        policies = np.zeros((self.num_cues, self.map_size, self.map_size))
        for cue_idx in range(self.num_cues):
            delays = np.full(self.num_cues, self.max_reward_delay, dtype=float)
            delays[cue_idx] = 0
            for row in range(self.map_size):
                for col in range(self.map_size):
                    policies[cue_idx, row, col] = self.act((row, col), cue_delay=delays)
        return policies




def plot_rewards(rewards, title, saveFolder=saveFolder, opt_rewards=None):
    '''
    Line-plot `rewards` against time and save it to `<saveFolder><title>.png`.
    When `opt_rewards` is given, a red horizontal reference line is drawn at that value.
    '''
    out_path = saveFolder + title + '.png'
    plt.figure()
    plt.plot(rewards)
    if opt_rewards is not None:
        plt.axhline(y=opt_rewards, color='r')
    plt.title(title)
    plt.xlabel('Time')
    plt.ylabel('Reward')
    plt.savefig(out_path)
    plt.close()
        
def plot_time_avg_rewards(rewards, title):
    '''
    plot rewards averaged in a 100 time window
    '''
    plt.figure()
    plt.plot(np.convolve(rewards, np.ones(100)/100, mode='valid'))
    plt.title(title)
    plt.xlabel('Time')
    plt.ylabel('Reward')
    plt.savefig(saveFolder + title + '.png')
    plt.close()

def plot_value_function(value, title):
    '''
    Save a heat map of the 2-D array `value` (x axis labelled "Cue Delay",
    y axis "State") with a colour bar to `<saveFolder><title>.png`.
    '''
    out_path = saveFolder + title + '.png'
    plt.figure()
    plt.imshow(value, cmap='hot', interpolation='nearest')
    plt.title(title)
    plt.xlabel('Cue Delay')
    plt.ylabel('State')
    plt.colorbar()
    plt.savefig(out_path)
    plt.close()


def plot_reward_distribution(rewards, title):
    '''
    Save a 4-bin histogram of all entries of the array `rewards` to
    `<saveFolder><title>.png`.
    '''
    out_path = saveFolder + title + '.png'
    flat_rewards = rewards.flatten()
    plt.figure()
    plt.hist(flat_rewards, bins=4)
    plt.title(title)
    plt.xlabel('rewards')
    plt.ylabel('Frequency')
    plt.savefig(out_path)
    plt.close()

        

def test_agent(agent, env, num_timesteps, weights=None, prnt=False, epoch=0):
    # reset environment
    state, cue = env.reset()
    reward = 0
    # track rewards
    rewards = np.zeros(num_timesteps)
    # reset agent for test
    agent.reset_test(state)
    prev_state = None
    prev_action = None
    
    for time in range(num_timesteps):              
        # get action from agent
        if weights is not None:
            action = agent.test_act(state, cue, weights)
        else:
            action = agent.test_act(state, cue)
        prev_action = action
        prev_state = state
        # take action in environment
        reward, cue, state = env.step(action)
        if prnt:
            print('agent cue delay', agent.test_cue_delay, ' next cue ', cue)
            print('final test rew ',reward, ' state',state, ' action',action)            
        rewards[time] = reward
        
    return np.mean(rewards), rewards

def train_agent(agent, env, test_env, num_timesteps, test_timesteps, test_every, init_state, init_cue, init_reward):
    '''
    Train `agent` in `env` for `num_timesteps` steps, evaluating it on `test_env`
    with `test_agent` at time 0, test_every, 2*test_every, ...

    Returns
    -------
    rewards : per-step training rewards (length num_timesteps)
    test_rewards : mean test reward per evaluation (ceil(num_timesteps / test_every) entries)
    full_test_rewards : per-step rewards of the last evaluation (None if none ran)
    '''
    prev_state = None
    prev_action = None
    state = init_state
    cue = init_cue
    reward = init_reward

    print('train agent with ', num_timesteps, ' timesteps', test_timesteps, test_every, init_state, init_cue, init_reward)
    # track rewards
    rewards = np.zeros(num_timesteps)
    # evaluations happen at every multiple of test_every below num_timesteps, i.e.
    # ceil(num_timesteps / test_every) times; floor division under-allocated and the
    # last evaluation raised IndexError when num_timesteps % test_every != 0
    num_tests = (num_timesteps + test_every - 1) // test_every
    test_rewards = np.zeros(num_tests)
    # stays None when no evaluation runs (num_timesteps == 0)
    full_test_rewards = None

    for time in range(num_timesteps):
        # learn from the previous transition and get the next action
        action = agent.train(prev_state, prev_action, reward, state, cue)
        prev_action = action
        # take action in environment
        prev_state = state
        reward, cue, state = env.step(action)
        rewards[time] = reward
        # test agent every test_every time steps
        if time % test_every == 0:
            epoch = time // test_every
            test_rewards[epoch], full_test_rewards = test_agent(agent, test_env, test_timesteps, epoch=epoch)

    return rewards, test_rewards, full_test_rewards

def nested_loop(ncues, delay):
    '''
    Enumerate every list of `ncues` delays, each in 0..delay inclusive.

    The first entry varies fastest, e.g. nested_loop(2, 1) ==
    [[0, 0], [1, 0], [0, 1], [1, 1]]; nested_loop(0, d) == [[]].
    '''
    combos = [[]]
    for _ in range(ncues):
        # prepend every possible delay to every shorter combination
        combos = [[d] + tail for tail in combos for d in range(delay + 1)]
    return combos


def train_optimal_agent(agent, test_env, train_loops, test_timesteps, map_size, num_cues, max_reward_delay):
    '''
    Sweep-based training: `train_loops` times, for every grid position, action
    and combination of cue delays, query test_env.get_avg_return and apply
    agent.update_Q. Then evaluate with test_agent (printing each step).

    Returns the (mean reward, per-step rewards) pair from test_agent.
    '''
    # all cue-delay combinations; independent of the loop variables, so enumerate once
    all_cue_delays = nested_loop(num_cues, max_reward_delay)
    for _ in range(train_loops):
        for x in range(map_size):
            for y in range(map_size):
                state = (x,y)
                for action in range(test_env.action_space):
                    for delays in all_cue_delays:
                        # fresh list per use, as before, in case a callee mutates it
                        cue_delay = list(delays)
                        reward, next_state = test_env.get_avg_return(cue_delay, state, action)
                        print('cue delay',cue_delay, reward, state, action, next_state)
                        agent.update_Q(state, action, reward, next_state, cue_delay)

    test_rewards, full_test_rewards = test_agent(agent, test_env, test_timesteps, prnt=True)

    return test_rewards, full_test_rewards


# Train and compare the particle, reward-map and normal agents on one random gridworld.
def main_train():
    """Train and compare three agents on one randomly generated gridworld.

    The function reads module-level names it does not define itself
    (presumably configuration globals set elsewhere in this file):
    ``num_cues``, ``map_size``, ``max_reward_delay``, ``max_reward_magnitude``,
    ``cue_probs``, ``alpha``, ``discount_gamma``, ``epsilon``,
    ``train_timesteps``, ``test_timesteps`` and ``test_every_n``.

    Steps, in order:
      1. Sample reward time/magnitude matrices and plot them.
      2. Train ``rewardMTAgent_particle``. Plot its test rewards, learned
         matrices, policies and per-cue SRs.
      3. Train ``rewardMTAgent`` from a zero reward model. Plot the same
         diagnostics plus its training rewards.
      4. Train ``normalAgent`` with a fixed learning rate of 0.01. Plot its
         rewards and the max-over-actions value function.

    Every agent gets its own freshly constructed train/test environments and
    starts from a zero reward model and a zero previous reward.
    """

    def fresh_setup():
        # Zero reward model, a new train/test environment pair, and the
        # train environment's initial (state, cue).
        blank_MT = np.zeros(reward_magnitude_time_matrices.shape)
        train_env = gridWorldEnv(map_size, reward_magnitude_time_matrices, cue_probs)
        eval_env = gridWorldEnv(map_size, reward_magnitude_time_matrices, cue_probs)
        start_state, start_cue = train_env.reset()
        return blank_MT, train_env, eval_env, start_state, start_cue

    def run_training(learner, train_env, eval_env, start_state, start_cue):
        # Every agent starts with a previous reward of 0.
        return train_agent(learner, train_env, eval_env, train_timesteps, test_timesteps,
                           test_every_n, start_state, start_cue, 0)

    # --- ground-truth task ---
    reward_magnitude_time_matrices, rewarded_positions = generate_grid_reward_magnitude_time_matrices(
        num_cues, map_size, max_reward_delay, max_reward_magnitude)
    plot_reward_magnitude_time_matrices(reward_magnitude_time_matrices, rewarded_positions,
                                        'reward_time_magnitude_matrices_per_state')
    plot_reward_MT_positions(reward_magnitude_time_matrices, rewarded_positions,
                             'reward_time_magnitude_matrices_over_states')

    # --- particle-based agent ---
    blank_MT, train_env, eval_env, start_state, start_cue = fresh_setup()
    particle_agent = rewardMTAgent_particle(train_env, map_size, start_state, blank_MT,
                                            alpha, discount_gamma, epsilon)
    _, particle_test_rewards, _ = run_training(particle_agent, train_env, eval_env,
                                               start_state, start_cue)
    plot_rewards(particle_test_rewards, 'test_rewards_particle_agt')
    plot_reward_magnitude_time_matrices(particle_agent.reward_MT, rewarded_positions,
                                        'learned_reward_time_magnitude_matrices_per_state_particles',
                                        particles=particle_agent.particles)
    plot_policy(map_size, particle_agent.get_policies(), rewarded_positions, 'particles_agt')
    for cue_idx in range(num_cues):
        plot_SR(particle_agent.sr[cue_idx], cue_idx, 'particles_agt')
    plot_reward_MT_positions(particle_agent.reward_MT, rewarded_positions,
                             'learned_reward_time_magnitude_matrices_over_states_particles',
                             particles=particle_agent.particles)

    # --- reward time/magnitude map agent ---
    blank_MT, train_env, eval_env, start_state, start_cue = fresh_setup()
    map_agent = rewardMTAgent(train_env, map_size, start_state, blank_MT,
                              alpha, discount_gamma, epsilon)
    map_train_rewards, map_test_rewards, _ = run_training(map_agent, train_env, eval_env,
                                                          start_state, start_cue)
    print('Average reward over time noTravelAgent: ', np.mean(map_train_rewards))
    plot_rewards(map_train_rewards, 'rewards_map_agt')
    plot_rewards(map_test_rewards, 'test_rewards_map_agt')
    plot_reward_magnitude_time_matrices(map_agent.reward_MT, rewarded_positions,
                                        'learned_reward_time_magnitude_matrices_per_state')
    plot_policy(map_size, map_agent.get_policies(), rewarded_positions)
    for cue_idx in range(num_cues):
        plot_SR(map_agent.sr[cue_idx], cue_idx)
    plot_reward_MT_positions(map_agent.reward_MT, rewarded_positions,
                             'learned_reward_time_magnitude_matrices_over_states')

    # --- baseline normal agent (learning rate fixed at 0.01) ---
    blank_MT, train_env, eval_env, start_state, start_cue = fresh_setup()
    norm_agent = normalAgent(train_env, map_size, start_state, blank_MT, 0.01, discount_gamma, epsilon)
    norm_train_rewards, norm_test_rewards, _ = run_training(norm_agent, train_env, eval_env,
                                                            start_state, start_cue)
    print('Average reward over time normalAgent: ', np.mean(norm_train_rewards))
    plot_rewards(norm_train_rewards, 'rewards_normalAgent')
    plot_rewards(norm_test_rewards, 'test_rewards_normalAgent')
    # Max over the last (action) axis of Q, flattened to rows of width Q.shape[2].
    greedy_values = np.max(norm_agent.Q_value, axis=-1).reshape(-1, norm_agent.Q_value.shape[2])
    plot_value_function(greedy_values, 'learned_value_function')



# Train a TMD agent on the preset gridworld, evaluate risk-sensitive weightings, then a baseline.
def main_train_risk_sensitive():
    """Train a TMD agent, evaluate risk-sensitive weightings, then train a baseline.

    The function reads module-level names it does not define itself
    (presumably configuration globals set elsewhere in this file), e.g.
    ``weighting``, ``s_exp``, ``k_hyper``, ``c_scale``, ``num_cues``,
    ``map_size``, ``max_reward_delay``, ``max_reward_magnitude``, ``cue_probs``,
    ``alpha``, ``discount_gamma``, ``epsilon``, ``train_timesteps``,
    ``test_timesteps`` and ``test_every_n``.

    Steps, in order:
      1. If ``weighting`` is truthy, build and plot the 2D and factorized
         discount weights.
      2. Build the preset reward time/magnitude matrices and plot them.
      3. Train ``rewardMTAgent`` (TMD) and plot its diagnostics.
      4. If ``weighting`` is truthy, re-test the TMD agent under each weighting
         and plot the resulting policies and reward distributions. Otherwise
         this step is skipped with a notice. Previously the step referenced
         the unbound weight variables and crashed with ``UnboundLocalError``.
      5. Train ``normalAgent`` (learning rate 0.01) and plot its results.

    Note: step 4 compares policies for cues 0 and 1, so it presumably requires
    ``num_cues >= 2``.
    """
    # Evaluate the flag once so steps 1 and 4 always agree.
    use_weighting = bool(weighting)

    weights_2D = factorized_weights = None
    if use_weighting:
        weights_2D, factorized_weights = generate_discount_weights(max_reward_delay, max_reward_magnitude, s_exp, k_hyper, c_scale)
        plot_weights(weights_2D, 'weights_2D')
        plot_weights(factorized_weights, 'factorized_weights')

    # Preset (not random) reward time and magnitude matrices.
    reward_magnitude_time_matrices, rewarded_positions = generate_preset_grid_reward_magnitude_time_matrices(num_cues, map_size, max_reward_delay, max_reward_magnitude)
    init_reward_MT = np.zeros(reward_magnitude_time_matrices.shape)
    plot_reward_magnitude_time_matrices(reward_magnitude_time_matrices, rewarded_positions, 'reward_time_magnitude_matrices_per_state')
    plot_reward_MT_positions(reward_magnitude_time_matrices, rewarded_positions, 'reward_time_magnitude_matrices_over_states')

    # --- TMD agent ---
    env = gridWorldEnv(map_size, reward_magnitude_time_matrices, cue_probs)
    test_env = gridWorldEnv(map_size, reward_magnitude_time_matrices, cue_probs)
    state, cue = env.reset()
    reward = 0
    agent_TMD = rewardMTAgent(env, map_size, state, init_reward_MT, alpha, discount_gamma, epsilon)
    rewards_TMD, test_rew_TMD, full_test_rew_TMD = train_agent(agent_TMD, env, test_env, train_timesteps, test_timesteps, test_every_n, state, cue, reward, rewarded_positions)

    print('Average reward over time noTravelAgent: ', np.mean(rewards_TMD))

    plot_rewards(rewards_TMD, 'rewards_TMD_agt')
    plot_rewards(test_rew_TMD, 'test_rewards_TMD_agt')
    plot_reward_magnitude_time_matrices(agent_TMD.reward_MT, rewarded_positions, 'learned_reward_time_magnitude_matrices_per_state_TMD')
    policies = agent_TMD.get_policies()
    plot_policy(map_size, policies, rewarded_positions, 'TMD_agt')
    for i in range(num_cues):
        plot_SR(agent_TMD.sr[i], i, 'TMD_agt')
    plot_reward_MT_positions(agent_TMD.reward_MT, rewarded_positions, 'learned_reward_time_magnitude_matrices_over_states_TMD')
    plot_reward_distribution(full_test_rew_TMD, 'reward_distribution_TMD_agt')

    def _evaluate_weighted(weights, policy_tag, dist_tag):
        # Re-test the trained TMD agent on a fresh environment under `weights`.
        # Then plot the weighted policy and the full test-reward distribution.
        eval_env = gridWorldEnv(map_size, reward_magnitude_time_matrices, cue_probs)
        _, full_rewards = test_agent(agent_TMD, eval_env, test_timesteps, weights=weights)
        weighted_policies = agent_TMD.get_policies(weights=weights)
        plot_policy(map_size, weighted_policies, rewarded_positions, policy_tag)
        plot_reward_distribution(full_rewards, dist_tag)

    if use_weighting:
        _evaluate_weighted(weights_2D, '2D_agt', 'reward_distribution_2D_agt')
        _evaluate_weighted(factorized_weights, 'factorized_agt', 'reward_distribution_factorized_agt')

        # Compare the two weightings on the policies for cues 0 and 1.
        fact_policies = agent_TMD.get_policies(pair_cues=[0, 1], weights=factorized_weights)
        policies_2D = agent_TMD.get_policies(pair_cues=[0, 1], weights=weights_2D)
        plot_policy_two_cues(map_size, fact_policies, policies_2D, rewarded_positions[:2], '2D_vs_factorized_agt')
    else:
        print('weighting disabled: skipping 2D / factorized weighted evaluation')

    # --- baseline normal agent ---
    init_reward_MT = np.zeros(reward_magnitude_time_matrices.shape)
    env = gridWorldEnv(map_size, reward_magnitude_time_matrices, cue_probs)
    test_env = gridWorldEnv(map_size, reward_magnitude_time_matrices, cue_probs)
    state, cue = env.reset()
    reward = 0

    # Learning rate fixed at 0.01 rather than the module-level `alpha`.
    normAgent = normalAgent(env, map_size, state, init_reward_MT, 0.01, discount_gamma, epsilon)

    rewards_normAgent, test_rew_norm, _ = train_agent(normAgent, env, test_env, train_timesteps, test_timesteps, test_every_n, state, cue, reward)
    print('Average reward over time normalAgent: ', np.mean(rewards_normAgent))

    plot_rewards(rewards_normAgent, 'rewards_normalAgent')
    plot_rewards(test_rew_norm, 'test_rewards_normalAgent')
    # Max over the last (action) axis of Q, flattened to rows of width Q.shape[2].
    plot_value_function(np.max(normAgent.Q_value, axis=-1).reshape(-1, normAgent.Q_value.shape[2]), 'learned_value_function')



if __name__ == '__main__':
    # Default experiment: compare the particle, reward-map and normal agents.
    main_train()
    # To run the risk-sensitive (2D / factorized weighting) experiment instead,
    # uncomment the call below:
    # main_train_risk_sensitive()
