'''
Use case for reward time/magnitude maps.
Each reward time/magnitude map describes the reward that will be given
in each state.
We assume here that any state can be reached from any other state with no travel time.
'''

import numpy as np
import matplotlib.pyplot as plt
import os
from scipy.stats import gaussian_kde




plot_folder = ''
experiment_name = 'no_travel_time_cue_test4/'
saveFolder = f'{plot_folder}{experiment_name}'

# Output folder for this experiment's figures: create it on the first run,
# otherwise warn that existing figures may be overwritten.
if not os.path.exists(saveFolder):
    os.makedirs(saveFolder)
else:
    print('overwrite')


# ---- Environment parameters ----
num_states = 3              # number of states the agent can move between
max_reward_delay = 5        # number of reward-delay bins in each reward map
max_reward_magnitude = 10   # number of reward-magnitude bins in each reward map
alpha = 0.01                # learning rate

# Per-step probability that each cue appears (here one cue per state).
cue_probs = np.full(num_states, 0.1)
num_cues = cue_probs.size

# Simulation lengths in time steps, and how often the reward rate is tested.
train_timesteps = 10000
test_timesteps = 10000
test_every_n = 100


# DNL learning parameters
# Regular (delay, magnitude) grid, extending one unit beyond each axis range,
# on which the KDE likelihood and its gradient are evaluated.
dx_der = 0.1
dy_der = 0.1
x_der, y_der = np.mgrid[-1:(max_reward_delay + 1):dx_der, -1:(max_reward_magnitude + 1):dy_der]
Nx_der, Ny_der = x_der.shape[0], y_der.shape[1]
# Grid points as an (N, 2) array of (delay, magnitude) pairs.
x_der_flat = x_der.flatten()[:, np.newaxis]
y_der_flat = y_der.flatten()[:, np.newaxis]
particles_der = np.hstack((x_der_flat, y_der_flat))

# Parameters for distributional neural learning (WORKING configuration)
n_interactions = 100
slower = 2
batch_size = 10     # pseudo-observations sampled per particle update
bw = 0.5            # bandwidth passed to gaussian_kde
lamb = 0.4 * np.max([max_reward_delay, max_reward_magnitude])  # length scale of the particle interaction term
gamma = 100.0       # strength of the particle interaction term

# Learning-rate schedule indexed by update count: linear decrease from 1e-3
# to 1e-4 over the first 500 updates, then constant at 1e-4.
learning_rates = np.full(10000, 0.0001)
learning_rates[:500] = np.linspace(0.001, 0.0001, 500)

# Particles per state; initial particles are laid out on a square grid.
n_particles = 5 * 5

# Covariance of sampled pseudo-observations and its shrink schedule:
# from 1.0 down to 0.1 over the first n_interactions updates, then 0.1.
cov = 0.05 * np.eye(2)
cov_shrink = np.full(n_interactions * 1000, 0.1)
cov_shrink[:n_interactions] = np.linspace(1.0, 0.1, n_interactions)




def generate_reward_magnitude_time_matrices(num_cues, num_states, max_reward_delay, max_reward_magnitude):
    # matrices for reward time and magnitude
    # each element is a prob of magnitude given a time
    # whole matrix should sum to one for each state
    reward_magnitude_time_matrices = np.zeros((num_cues, num_states, max_reward_delay, max_reward_magnitude))
    # number of possible reward time and mags
    num_poss = 2
    for i in range(num_cues):
        j = i % num_states  # ensure that the cue is associated with a state
        for n in range(num_poss):
            time = np.random.randint(1, max_reward_delay)
            magnitude = np.random.randint(1, max_reward_magnitude)
            reward_magnitude_time_matrices[i, j, time, magnitude] = 1
        # normalize the matrix
        reward_magnitude_time_matrices[i] /= np.sum(reward_magnitude_time_matrices[i])
    return reward_magnitude_time_matrices

def generate_reward_magnitude_time_matrices_difficult(num_cues, num_states, max_reward_delay, max_reward_magnitude, seed=42):
    '''Build reward matrices with a small-early and a large-late reward per cue.

    Cue ``c`` places mass only in state ``c % num_states``, in two cells:
    magnitude index 1 at a random delay ``t`` in [1, max_reward_delay - 2], and
    magnitude index ``max_reward_magnitude - 1`` at delay
    ``(t + 3) % max_reward_delay``. Each cue's matrix is normalised to sum to one.

    Args:
        seed: value passed to ``np.random.seed`` before drawing (default 42, as
            before). This reseeds the *global* NumPy RNG, which also affects all
            later ``np.random`` calls. Pass ``None`` to leave the global RNG
            untouched.

    Returns:
        float array of shape (num_cues, num_states, max_reward_delay, max_reward_magnitude).
    '''
    reward_magnitude_time_matrices = np.zeros((num_cues, num_states, max_reward_delay, max_reward_magnitude))
    if seed is not None:
        np.random.seed(seed)
    small_magnitude = 1
    large_magnitude = max_reward_magnitude - 1
    for i in range(num_cues):
        j = i % num_states  # cue i is tied to state i % num_states
        time = np.random.randint(1, max_reward_delay)
        reward_magnitude_time_matrices[i, j, time, small_magnitude] = 1
        time2 = (time + 3) % max_reward_delay
        reward_magnitude_time_matrices[i, j, time2, large_magnitude] = 1
        reward_magnitude_time_matrices[i] /= np.sum(reward_magnitude_time_matrices[i])
    return reward_magnitude_time_matrices


def generate_reward_magnitude_time_matrices_same_expected_value(num_cues, num_states, max_reward_delay, max_reward_magnitude):
    '''Build reward matrices whose expected magnitude index is the same for every cue.

    Cue ``c`` places mass only in state ``c % num_states`` (consistent with the
    other generators; previously state ``c``, which failed when
    num_cues > num_states). It gets magnitude index ``m`` at a random delay and
    magnitude index ``max_reward_magnitude - m`` at another random delay, then
    is normalised to sum to one. The expected magnitude index is therefore
    ``max_reward_magnitude / 2`` for every cue, including when both cells coincide.

    Returns:
        float array of shape (num_cues, num_states, max_reward_delay, max_reward_magnitude).
    '''
    reward_magnitude_time_matrices = np.zeros((num_cues, num_states, max_reward_delay, max_reward_magnitude))
    for i in range(num_cues):
        j = i % num_states  # cue i is tied to state i % num_states
        time = np.random.randint(1, max_reward_delay)
        magnitude = np.random.randint(1, max_reward_magnitude)
        reward_magnitude_time_matrices[i, j, time, magnitude] = 1
        time2 = np.random.randint(1, max_reward_delay)
        magnitude2 = max_reward_magnitude - magnitude
        reward_magnitude_time_matrices[i, j, time2, magnitude2] = 1
        reward_magnitude_time_matrices[i] /= np.sum(reward_magnitude_time_matrices[i])
    return reward_magnitude_time_matrices


def plot_reward_magnitude_time_matrices(saveFolder, RT_matrices, title):
    '''Plot one reward time/magnitude heat map per state and save it as a PNG.

    Row ``i`` shows ``RT_matrices[i % num_cues, i]``, with delay on the vertical
    axis and magnitude on the horizontal axis. Tick labels are shown 1-based.
    The figure is written to ``saveFolder + title + '.png'`` and then closed, so
    repeated calls do not accumulate open figures.
    '''
    max_reward_delay = RT_matrices.shape[-2]
    max_reward_magnitude = RT_matrices.shape[-1]
    num_states = RT_matrices.shape[1]
    num_cues = RT_matrices.shape[0]
    # squeeze=False keeps axs 2-D, so indexing also works when num_states == 1
    fig, axs = plt.subplots(num_states, 1, figsize=(5, 10), squeeze=False)
    for i in range(num_states):
        ax = axs[i, 0]
        j = i % num_cues  # ensure that the cue is associated with a state
        print('RT_matrices', RT_matrices.shape, i, j)
        image = ax.imshow(RT_matrices[j, i], cmap='hot', interpolation='nearest', vmin=0, vmax=1)
        ax.set_title('State ' + str(i))
        ax.set_xlabel('Reward Magnitude')
        ax.set_ylabel('Reward Time')
        ax.set_xticks(np.arange(max_reward_magnitude))
        ax.set_yticks(np.arange(max_reward_delay))
        ax.set_xticklabels(np.arange(1, max_reward_magnitude + 1))
        ax.set_yticklabels(np.arange(1, max_reward_delay + 1))
        plt.colorbar(image, ax=ax)
    fig.savefig(saveFolder + title + '.png')
    plt.close(fig)


def plot_particle_reward_magnitude_time_matrices(saveFolder, RT_matrices, particles, title):
    '''Plot each state's reward time/magnitude map with that state's particles overlaid.

    Row ``i`` shows ``RT_matrices[i % num_cues, i]`` (delay vertical, magnitude
    horizontal) and scatters ``particles[i]``. Particles are assumed to be
    (delay, magnitude) pairs per particle — TODO confirm against callers. Each
    particle is shifted by -0.5 so a particle at (d, m) lands in image cell
    (floor(d), floor(m)). The figure is written to ``saveFolder + title + '.png'``
    and then closed.
    '''
    max_reward_delay = RT_matrices.shape[-2]
    max_reward_magnitude = RT_matrices.shape[-1]
    num_states = RT_matrices.shape[1]
    num_cues = RT_matrices.shape[0]
    # squeeze=False keeps axs 2-D, so indexing also works when num_states == 1
    fig, axs = plt.subplots(num_states, 1, figsize=(5, 10), squeeze=False)
    for i in range(num_states):
        ax = axs[i, 0]
        j = i % num_cues  # same cue/state pairing as plot_reward_magnitude_time_matrices
        image = ax.imshow(RT_matrices[j, i], cmap='hot', interpolation='nearest', vmin=0, vmax=1)
        # draw the particles above the image
        ax.scatter(particles[i, :, 1] - 0.5, particles[i, :, 0] - 0.5, color="limegreen", s=1, zorder=particles.shape[1] + 1)
        print('particles', i, particles[i, :, 0], particles[i, :, 1])
        ax.set_title('State ' + str(i))
        ax.set_xlabel('Reward Magnitude')
        ax.set_ylabel('Reward Time')
        ax.set_xticks(np.arange(max_reward_magnitude))
        ax.set_yticks(np.arange(max_reward_delay))
        ax.set_xticklabels(np.arange(1, max_reward_magnitude + 1))
        ax.set_yticklabels(np.arange(1, max_reward_delay + 1))
        plt.colorbar(image, ax=ax)
    fig.savefig(saveFolder + title + '.png')
    plt.close(fig)


def plot_1D_matrix_and_expected_values(saveFolder, true_expected_values, RT_vectors, expected_values, title, folder=None):
    ''' Plot 1D reward vectors side by side, annotated with expected values.

    Column ``i`` of the image is ``RT_vectors[i, i, :]``. Below column ``i`` the
    estimate ``expected_values[i]`` is written in black and
    ``true_expected_values[i, i]`` in red, both rounded to 2 decimals. The
    figure is saved as ``<folder or saveFolder> + title + '.png'`` and then
    closed, so repeated calls do not accumulate open figures.
    '''
    fig, ax = plt.subplots(figsize=(5, 5))
    RT_show = np.array([RT_vectors[i, i, :] for i in range(RT_vectors.shape[0])]).T
    ax.imshow(RT_show, cmap='hot', interpolation='nearest', vmin=0, vmax=1)
    ax.axis('off')  # Turns off both axes (x and y), including ticks and labels
    # write each expected value below the corresponding column
    for i in range(len(expected_values)):
        ax.text(i, len(RT_show), str(round(expected_values[i], 2)), ha='center', va='bottom', fontsize=8, color='black')
        ax.text(i, len(RT_show) - 0.2, str(round(true_expected_values[i, i], 2)), ha='center', va='bottom', fontsize=8, color='red')
    target_folder = folder if folder is not None else saveFolder
    fig.savefig(target_folder + title + '.png')
    plt.close(fig)



# construct no travel time environment
class noTravelTimeEnv(object):
    '''Environment in which any state can be reached from any other in one step.

    Each step, cue ``c`` appears independently with probability ``cue_probs[c]``.
    For every appearing cue, each (state, delay, magnitude) cell of that cue's
    reward matrix is drawn independently as a Bernoulli trial with that cell's
    value as probability. A successful draw schedules ``magnitude`` (the index)
    into a per-state buffer of pending rewards, unless that (state, delay) slot
    already holds a non-zero reward. The buffer advances one slot per ``step``.
    '''

    def __init__(self, num_states, reward_magnitude_time_matrices, cue_probs):
        # number of states the agent can move between
        self.num_states = num_states
        # reward matrices indexed [cue, state, delay, magnitude]
        self.reward_magnitude_time_matrices = reward_magnitude_time_matrices
        # number of magnitude bins
        self.max_reward_magnitude = reward_magnitude_time_matrices.shape[-1]
        # number of delay bins
        self.max_reward_delay = reward_magnitude_time_matrices.shape[-2]
        # per-step appearance probability of each cue
        self.cue_probs = cue_probs
        # number of cues; cue vectors have this length (it need not equal num_states)
        self.num_cues = len(cue_probs)
        # which state the agent is currently in
        self.current_state = 0
        # pending reward magnitude per [state, steps until delivery]
        self.cur_reward_time_mag = np.zeros((self.num_states, self.max_reward_delay))

    def reset(self):
        '''Return to state 0 with no pending rewards.

        Returns:
            (current_state, init_cue), where init_cue is an all-zero cue vector
            of length num_cues.
        '''
        self.current_state = 0
        init_cue = np.zeros(self.num_cues)
        self.cur_reward_time_mag = np.zeros((self.num_states, self.max_reward_delay))
        return self.current_state, init_cue

    def step(self, action):
        '''Move to state ``action`` and collect its reward, then draw new cues.

        Returns:
            (reward, next_cue, current_state). ``reward`` is the value pending at
            delay slot 0 of the chosen state before the buffer advances.
            ``next_cue`` is a 0/1 vector of length num_cues.
        '''
        # check for valid action
        assert action < self.num_states
        self.current_state = int(action)

        reward = self.cur_reward_time_mag[self.current_state, 0]

        # one independent Bernoulli draw per cue
        next_cue = np.zeros(self.num_cues)
        for indc in range(self.num_cues):
            if np.random.binomial(1, p=self.cue_probs[indc]):
                next_cue[indc] = 1

        # advance pending rewards by one time step
        self.cur_reward_time_mag[:, :-1] = self.cur_reward_time_mag[:, 1:]
        self.cur_reward_time_mag[:, -1] = 0

        # schedule rewards for every cue that just appeared
        for indc in range(self.num_cues):
            if next_cue[indc] != 1:
                continue
            for inds in range(self.num_states):
                # independent draw for every (delay, magnitude) cell
                for time in range(self.max_reward_delay):
                    for mag in range(self.max_reward_magnitude):
                        if np.random.binomial(1, p=self.reward_magnitude_time_matrices[indc, inds, time, mag]):
                            # keep the first non-zero reward scheduled in this slot
                            if self.cur_reward_time_mag[inds, time] == 0:
                                self.cur_reward_time_mag[inds, time] += mag

        return reward, next_cue, self.current_state



class noTravelAgent(object):
    '''Agent that learns a joint reward time/magnitude map for every cue.

    ``reward_MT[c, s, d, m]`` is the agent's weight for cue ``c`` yielding
    magnitude index ``m`` in state ``s``, ``d`` steps after the cue appeared.
    ``act``/``test_act`` add the maps of appearing cues to a running prediction,
    choose the state with the highest predicted value at the next slot, and
    shift the prediction one step. ``train`` moves every active cue's map
    toward the observed (delay, magnitude) pair.
    '''

    def __init__(self, num_states, init_state, reward_magnitude_time_matrices, alpha=0.1, risk_weight=None):
        '''
        Args:
            num_states: number of states the agent can choose between.
            init_state: starting state.
            reward_magnitude_time_matrices: initial maps with shape
                (num_cues, num_states, max_reward_delay, max_reward_magnitude).
                A float copy is stored because ``train`` updates it in place.
                Without the copy, the caller's array (possibly the
                environment's own matrices) would be modified.
            alpha: learning rate of the exponential update in ``train``.
            risk_weight: optional per-magnitude values (array-like) used to
                score predictions; defaults to the indices 0 .. max_reward_magnitude - 1.
        '''
        self.num_states = num_states
        self.num_cues = reward_magnitude_time_matrices.shape[0]
        self.current_state = init_state
        self.reward_MT = np.array(reward_magnitude_time_matrices, dtype=float)
        self.max_reward_delay = self.reward_MT.shape[-2]
        self.max_reward_magnitude = self.reward_MT.shape[-1]
        if risk_weight is None:
            self.magnitude_values = np.arange(self.max_reward_magnitude)
        else:
            # weights on TMD for risk sensitivity
            self.magnitude_values = np.asarray(risk_weight)
        # running prediction from observed cues: [state, delay, magnitude]
        self.current_MT_map = np.zeros(self.reward_MT.shape[1:])
        # Steps since each cue last appeared. max_reward_delay marks "no active
        # cue" (train skips it), so nothing is credited before a cue's first onset.
        self.cur_cue_delay = np.full(self.num_cues, self.max_reward_delay, dtype=int)
        self.alpha = alpha
        # test-time state, so test_act also works before the first reset_test
        self.test_state = init_state
        self.test_MT_map = np.zeros(self.reward_MT.shape[1:])

    def _update_and_choose(self, MT_map, cue):
        '''Add appearing cues' maps to MT_map, pick a state, then shift MT_map one step (in place).'''
        for indc in range(self.num_cues):
            if cue[indc] == 1:
                MT_map += self.reward_MT[indc]
        # magnitude distribution expected at the next delay slot, per state
        next_state_rew_mag = MT_map[:, 0]
        expected_reward = np.sum(next_state_rew_mag * self.magnitude_values[None, :], axis=-1)
        action = np.argmax(expected_reward)
        MT_map[:, :-1] = MT_map[:, 1:]
        MT_map[:, -1] = 0
        return action

    def act(self, state, cue):
        '''Update the training prediction with ``cue`` and return the best next state.'''
        self.current_state = state
        return self._update_and_choose(self.current_MT_map, cue)

    def test_act(self, state, cue):
        '''Same as ``act`` but on the separate test prediction.'''
        self.test_state = state
        return self._update_and_choose(self.test_MT_map, cue)

    def reset_test(self, state):
        '''Start a new test run from ``state`` with an empty test prediction.'''
        self.test_state = state
        self.test_MT_map = np.zeros(self.reward_MT.shape[1:])

    def train(self, reward, next_state, next_cue):
        '''Learn from one transition and return a uniformly random next action.

        When ``reward > 0``, every cue with an active delay ``d`` has its map for
        ``next_state`` moved toward a one-hot map at (d, int(reward)) with rate
        ``alpha``. Cue delays then advance (saturating at max_reward_delay), and
        cues present in ``next_cue`` restart at delay 0.
        '''
        if reward > 0:
            for indc in range(self.num_cues):
                delay = self.cur_cue_delay[indc]
                if delay != self.max_reward_delay:
                    reward_map = np.zeros(self.reward_MT.shape[2:])
                    reward_map[delay, int(reward)] = 1
                    self.reward_MT[indc, next_state] = (1 - self.alpha) * self.reward_MT[indc, next_state] + self.alpha * reward_map

        # update time state
        self.cur_cue_delay += 1
        self.cur_cue_delay[self.cur_cue_delay >= self.max_reward_delay] = self.max_reward_delay
        for indc in range(self.num_cues):
            if next_cue[indc] == 1:
                self.cur_cue_delay[indc] = 0

        # random exploration during training
        action = np.random.choice(self.num_states)
        return action
    
class noTravelAgent_1D_time_dist(object):
    '''Agent that learns, per cue, a distribution over reward delay only.

    ``reward_MT[c, s, d]`` is the weight of cue ``c`` delivering a reward in state
    ``s``, ``d`` steps after the cue appeared. ``expected_return`` holds a running
    estimate of the reward size. ``act``/``test_act`` add the delay maps of
    appearing cues to a running prediction, choose the state maximising
    (weight at the next slot) x expected_return, and shift the prediction.
    '''

    def __init__(self, num_states, init_state, reward_magnitude_time_matrices, alpha=0.1):
        self.num_states = num_states
        self.num_cues = reward_magnitude_time_matrices.shape[0]
        self.current_state = init_state
        # Delay maps taken from magnitude index 0 of the input: [cue, state, delay].
        # NOTE(review): the generators in this file never put mass at magnitude
        # index 0, so for their output this starts all-zero — confirm intended.
        self.reward_MT = reward_magnitude_time_matrices[:, :, :, 0]
        # Number of delay bins: the last axis of the *sliced* array. The last
        # axis of the 4-D input is the magnitude axis. Using it let delays run
        # past the delay axis and index out of range in train.
        self.max_reward_delay = self.reward_MT.shape[-1]
        # running prediction from observed cues: [state, delay]
        self.current_MT_map = np.zeros(self.reward_MT.shape[1:])
        # Steps since each cue last appeared; max_reward_delay marks "no active cue".
        self.cur_cue_delay = np.full(self.num_cues, self.max_reward_delay, dtype=int)
        self.alpha = alpha
        # NOTE(review): train() indexes this by cue, while act()/test_act()
        # multiply it element-wise with per-state values. These only agree when
        # num_cues == num_states.
        self.expected_return = np.zeros(self.reward_MT.shape[1])
        # test-time state, so test_act also works before the first reset_test
        self.test_state = init_state
        self.test_MT_map = np.zeros(self.reward_MT.shape[1:])

    def _update_and_choose(self, MT_map, cue):
        '''Add appearing cues' delay maps to MT_map, pick a state, then shift MT_map one step (in place).'''
        for indc in range(self.num_cues):
            if cue[indc] == 1:
                MT_map += self.reward_MT[indc]
        expected_reward = MT_map[:, 0] * self.expected_return
        action = np.argmax(expected_reward)
        MT_map[:, :-1] = MT_map[:, 1:]
        MT_map[:, -1] = 0
        return action

    def act(self, state, cue):
        '''Update the training prediction with ``cue`` and return the best next state.'''
        self.current_state = state
        return self._update_and_choose(self.current_MT_map, cue)

    def test_act(self, state, cue):
        '''Same as ``act`` but on the separate test prediction.'''
        self.test_state = state
        return self._update_and_choose(self.test_MT_map, cue)

    def reset_test(self, state):
        '''Start a new test run from ``state`` with an empty test prediction.'''
        self.test_state = state
        self.test_MT_map = np.zeros(self.reward_MT.shape[1:])

    def train(self, reward, next_state, next_cue):
        '''Learn from one transition and return a uniformly random next action.

        When ``reward > 0``, every cue with an active delay ``d`` updates
        ``expected_return[cue]`` toward ``reward``. The step size is ``alpha``
        scaled by the cue's relative weight at delay ``d``. Then *all* delay maps
        decay by ``1 - alpha``, with ``alpha`` added at each active cue's
        (next_state, d). Cue delays then advance (saturating at
        max_reward_delay), and cues in ``next_cue`` restart at 0.
        '''
        reward_map = np.zeros(self.reward_MT.shape)

        if reward > 0:
            for indc in range(self.num_cues):
                delay = self.cur_cue_delay[indc]
                if delay != self.max_reward_delay:
                    reward_map[indc, next_state, delay] = 1
                    learn_rate = self.alpha * (self.reward_MT[indc, next_state, delay] / np.sum(self.reward_MT[indc, next_state, :] + 1E-5))
                    self.expected_return[indc] = (1 - learn_rate) * self.expected_return[indc] + learn_rate * reward
            # rebinds (does not modify the caller's array in place)
            self.reward_MT = (1 - self.alpha) * self.reward_MT + self.alpha * reward_map

        # update time state
        self.cur_cue_delay += 1
        self.cur_cue_delay[self.cur_cue_delay >= self.max_reward_delay] = self.max_reward_delay
        for indc in range(self.num_cues):
            if next_cue[indc] == 1:
                self.cur_cue_delay[indc] = 0

        _ = self.act(next_state, next_cue)  # keeps current_state / current_MT_map in sync
        action = np.random.choice(self.num_states)
        return action
    

class noTravelAgent_1D_magnitude_dist(object):
    '''Agent that learns, per cue, a reward-magnitude distribution and a scalar expected delay.

    ``reward_MT[c, s, m]`` is the weight of cue ``c`` giving magnitude index ``m``
    in state ``s``. ``expected_delay[c]`` is an exponential average of the delays
    at which rewards followed cue ``c``. A state is scored by its mean magnitude
    (magnitude values 1 .. max_reward_magnitude). The score is only kept when
    the cue's current delay equals its rounded expected delay.
    '''

    def __init__(self, num_states, init_state, reward_magnitude_time_matrices, alpha=0.1):
        self.num_states = num_states
        self.num_cues = reward_magnitude_time_matrices.shape[0]
        self.current_state = init_state
        # magnitude maps taken from delay index 0 of the input: [cue, state, magnitude]
        self.reward_MT = reward_magnitude_time_matrices[:, :, 0, :]
        self.max_reward_delay = reward_magnitude_time_matrices.shape[-2]
        self.max_reward_magnitude = reward_magnitude_time_matrices.shape[-1]
        # magnitude values used for scoring: 1 .. max_reward_magnitude
        self.magnitude_values = np.arange(self.max_reward_magnitude) + 1
        # running prediction from observed cues: [state, magnitude]
        self.current_MT_map = np.zeros(self.reward_MT.shape[1:])
        # Steps since each cue last appeared; max_reward_delay marks "no active cue"
        # (the value update_time_delays saturates at).
        self.cur_cue_delay = np.full(self.num_cues, self.max_reward_delay, dtype=int)
        self.test_cue_delay = np.full(self.num_cues, self.max_reward_delay, dtype=int)
        self.alpha = alpha
        # learned delay per cue
        self.expected_delay = np.zeros(self.reward_MT.shape[0])
        # test-time state, so test_act also works before the first reset_test
        self.test_state = init_state
        self.test_MT_map = np.zeros(self.reward_MT.shape[1:])

    def act(self, state, cue):
        '''Update the training prediction with ``cue`` and return the best next state.'''
        self.current_state = state
        for indc in range(self.num_cues):
            if cue[indc] == 1:
                self.current_MT_map += self.reward_MT[indc]
        print('self.current_MT_map', self.current_MT_map.shape, self.magnitude_values.shape, self.expected_delay.shape, self.cur_cue_delay.shape)
        mean_magnitude = np.sum(self.current_MT_map * self.magnitude_values[None, :], axis=-1)
        # One flag per cue, compared element-wise. Broadcasting against
        # cur_cue_delay[:, None] would build a (num_cues, num_cues) matrix whose
        # flat argmax is not a valid state. Assumes num_cues == num_states — TODO confirm.
        delay_matches = (np.round(self.expected_delay) == self.cur_cue_delay).astype(float)
        expected_reward = mean_magnitude * delay_matches
        action = np.argmax(expected_reward)
        # NOTE(review): current_MT_map is [state, magnitude], so this shifts the
        # magnitude axis rather than time; test_act does not do this shift.
        self.current_MT_map[:, :-1] = self.current_MT_map[:, 1:]
        self.current_MT_map[:, -1] = 0

        return action

    def test_act(self, state, cue):
        '''Choose a state from the learned maps and test-time cue delays, then advance those delays.

        Each state ``s`` is scored by cue ``s``'s mean magnitude in state ``s``.
        The score is kept only where ``round(expected_delay) == test_cue_delay``.
        '''
        self.test_state = state
        expected_reward = np.diag(np.sum(self.reward_MT * self.magnitude_values[None, None, :], axis=-1)) * np.equal(np.round(self.expected_delay), self.test_cue_delay).astype(float)
        action = np.argmax(expected_reward)

        # update cue delays
        self.test_cue_delay = update_time_delays(self.test_cue_delay, self.num_cues, self.max_reward_delay, cue)

        return action

    def reset_test(self, state):
        '''Start a new test run from ``state``: empty test map, no active cues.'''
        self.test_state = state
        self.test_MT_map = np.zeros(self.reward_MT.shape[1:])
        self.test_cue_delay = np.full(self.num_cues, self.max_reward_delay, dtype=int)

    def train(self, reward, next_state, next_cue):
        '''Learn from one transition and return a uniformly random next action.

        When ``reward > 0``, every cue with an active delay ``d`` moves its
        ``expected_delay`` toward ``d`` with rate ``alpha``. All magnitude maps
        then decay by ``1 - alpha``, with ``alpha`` added at each active cue's
        (next_state, int(reward)). Cue delays are then advanced by ``update_time_delays``.
        '''
        reward_map = np.zeros(self.reward_MT.shape)

        if reward > 0:
            for indc in range(self.num_cues):
                if self.cur_cue_delay[indc] != self.max_reward_delay:
                    print('reward_map', indc, next_state, self.cur_cue_delay[indc], int(reward))
                    reward_map[indc, next_state, int(reward)] = 1
                    self.expected_delay[indc] = (1 - self.alpha) * self.expected_delay[indc] + self.alpha * self.cur_cue_delay[indc]
                    print('expected_delay', indc, self.expected_delay[indc], self.cur_cue_delay[indc])
            self.reward_MT = (1 - self.alpha) * self.reward_MT + self.alpha * reward_map

        # update time state
        self.cur_cue_delay = update_time_delays(self.cur_cue_delay, self.num_cues, self.max_reward_delay, next_cue)

        # random exploration during training
        action = np.random.choice(self.num_states)

        return action
    

def update_time_delays(cue_delay, num_cues, max_reward_delay, next_cue):
    '''Advance per-cue delays by one step, in place, and return the same array.

    Every delay grows by one and saturates at ``max_reward_delay``. Any of the
    first ``num_cues`` cues flagged with 1 in ``next_cue`` restarts at 0.
    '''
    cue_delay += 1
    saturated = cue_delay >= max_reward_delay
    cue_delay[saturated] = max_reward_delay
    fired = [idx for idx in range(num_cues) if next_cue[idx] == 1]
    cue_delay[fired] = 0
    return cue_delay

def step_time_MT_map(MT_map):
    '''Shift an MT map one step earlier along axis 1, in place, zero-filling the last slot.

    Returns the same array object.
    '''
    shifted = np.roll(MT_map, -1, axis=1)
    MT_map[...] = shifted
    MT_map[:, -1] = 0
    return MT_map

def update_cur_MT_map(cue, reward_MT, test_MT_map):
    '''Add ``reward_MT[c]`` to ``test_MT_map`` (in place) for every cue ``c`` flagged 1 in ``cue``.

    Returns the same ``test_MT_map`` object.
    '''
    for idx, flag in enumerate(cue):
        if flag == 1:
            test_MT_map += reward_MT[idx]
    return test_MT_map

def likelihood_limits(likelihood):
    '''Floor a likelihood array at 1e-50, in place, and return it.

    Entries below 1e-50 (including zeros) become 1e-50; NaNs are left as they
    are. This keeps the ``np.log`` that particle_update applies afterwards finite.
    '''
    likelihood[likelihood < 1E-50] = 1E-50
    return likelihood

def particle_update(particles, reward, cue_delay, next_state, iter, max_r_delay, max_r_magnitude, particle_reward_MT, n_particles=n_particles, cov=cov, cov_shrink=cov_shrink, batch_size=batch_size, dx_der=dx_der, dy_der=dy_der, bw=bw):
    '''
    Update the particles of one state after observing a reward there.

    Let ``s = int(next_state)``. A batch of pseudo-observations is drawn around
    (cue_delay[s] + 0.5, reward + 0.5) and smoothed with a Gaussian KDE into a
    log-likelihood on the module-level (x_der, y_der) grid. Each particle of
    state ``s`` takes one gradient step combining
      * F1: movement up the log-likelihood gradient, and
      * F2: a pairwise distance-based interaction between the state's particles
        (strength ``gamma``, length scale ``lamb``).
    All particles are then clamped by ``particle_limits``. The state's particle
    histogram is written to ``particle_reward_MT[s, s]``, and ``iter[s]`` is
    incremented.

    Parameters
    ----------
    particles : array indexed [state, particle, (delay, magnitude)]; updated in place.
    reward : observed reward magnitude.
    cue_delay : steps since each cue appeared; entry ``s`` is used.
    next_state : state where the reward was observed (also used as the cue index).
    iter : per-state update counters. They select the ``cov_shrink`` and
        ``learning_rates`` entries (the last entry is reused once a schedule is
        exhausted) and are incremented in place.
    max_r_delay, max_r_magnitude : histogram size for ``matrix_particles``.
    particle_reward_MT : array receiving the histogram; written in place.

    Returns
    -------
    (particles, iter, particle_reward_MT)
    '''
    indc = int(next_state)
    print('reward', reward)
    mean = (cue_delay[indc] + 0.5, reward + 0.5)
    print('next_state', indc)
    step = int(iter[indc])
    # Schedules are finite arrays: past their end, keep using the last entry.
    shrink = cov_shrink[min(step, len(cov_shrink) - 1)]
    learn_r = learning_rates[min(step, len(learning_rates) - 1)]

    sample_rew = np.random.multivariate_normal(mean, cov * shrink, batch_size)
    RBF = gaussian_kde(sample_rew.T, bw_method=bw)
    likelihood = np.reshape(RBF(particles_der.T), (Nx_der, Ny_der))
    likelihood = likelihood_limits(likelihood)
    likelihood = np.log(likelihood / np.sum(likelihood))

    # Gradient of F1, looked up at each particle's grid cell.
    gradient_likelihood = np.gradient(likelihood, dx_der, dy_der)
    # np.digitize returns 0 below the first grid edge. The "- 1" would then give
    # -1, which silently indexes the opposite end of the grid (particle_limits
    # allows particles below the grid's lower edge), so clamp into range.
    bins_x = np.clip(np.digitize(particles[indc, :, 0], x_der[:, 0]) - 1, 0, Nx_der - 1)
    bins_y = np.clip(np.digitize(particles[indc, :, 1], y_der[0, :]) - 1, 0, Ny_der - 1)
    gradient_f1 = np.zeros((n_particles, 2))
    gradient_f1[:, 0] = -np.ndarray.flatten(gradient_likelihood[0][bins_x, bins_y])
    gradient_f1[:, 1] = -np.ndarray.flatten(gradient_likelihood[1][bins_x, bins_y])

    # Gradient of F2 from all pairwise particle differences.
    particles_matrix = np.tile(particles[indc], (n_particles, 1, 1))
    dif_matrix = particles_matrix - particles[indc, :, np.newaxis, :]
    distance_matrix = np.sqrt(np.sum(dif_matrix**2, axis=-1, keepdims=True))
    gradient_f2 = np.sum(gamma * (distance_matrix / lamb - 1) * np.exp(-distance_matrix / lamb) * dif_matrix, axis=0)
    gradient_f2 = gradient_f2 / n_particles

    # Gradient step, then clamp.
    gradient = gradient_f1 + gradient_f2
    particles[indc] = particles[indc] - learn_r * gradient
    particles = particle_limits(particles, max_r_delay, max_r_magnitude)

    particle_reward_MT[indc, indc] = matrix_particles(particles[indc], max_r_delay, max_r_magnitude)
    iter[indc] += 1

    return particles, iter, particle_reward_MT

def particle_limits(particles, max_delay, max_mag):
    '''Clamp particle coordinates in place to the box [-2, max + 2] per axis.

    Axis 0 of the last dimension is bounded by ``max_delay`` and axis 1 by
    ``max_mag``, each widened by a margin of 2. Returns the same array.
    '''
    margin = 2.0
    for axis, upper in ((0, max_delay + margin), (1, max_mag + margin)):
        coords = particles[:, :, axis]  # view: writes go to particles
        coords[coords > upper] = upper
        coords[coords < -margin] = -margin
    return particles

def matrix_particles(particles, max_reward_delay, max_reward_magnitude):
    '''Histogram (delay, magnitude) particles onto the integer grid.

    Each particle falls in cell (floor(delay), floor(magnitude)). Particles
    outside the grid are dropped. Counts are divided by ``sum(counts + 1e-5)``,
    so the result sums to slightly less than one (and is all-zero when no
    particle lands on the grid).
    '''
    cells = np.floor(particles).astype(int)
    rows, cols = cells[:, 0], cells[:, 1]
    inside = (rows >= 0) & (rows < max_reward_delay) & (cols >= 0) & (cols < max_reward_magnitude)
    counts = np.zeros((max_reward_delay, max_reward_magnitude), dtype=int)
    np.add.at(counts, (rows[inside], cols[inside]), 1)
    return counts/np.sum(counts + 1E-5)

def set_init_particles(max_reward_delay, max_reward_magnitude, num_states, num_cues, n_particles=n_particles):
    '''Create initial particles on a regular square grid.

    The grid spans delay [-0.5, max_reward_delay + 0.5] and magnitude
    [0.5, max_reward_magnitude + 0.5], with sqrt(n_particles) points per axis.
    Cue ``i`` is tied to state ``i % num_states``, so the first
    ``min(num_cues, num_states)`` states receive a copy of the grid. Any
    remaining states stay at zero.

    Returns:
        float array of shape (num_states, n_particles, 2) of (delay, magnitude) pairs.

    Raises:
        ValueError: if n_particles is not a perfect square.
    '''
    side = int(np.sqrt(n_particles))
    if side * side != n_particles:
        raise ValueError('n_particles must be a perfect square, got %d' % n_particles)
    x, y = np.meshgrid(
        np.linspace(-0.5, max_reward_delay + 0.5, side),
        np.linspace(0.5, max_reward_magnitude + 0.5, side)
    )
    particles_init = np.vstack((x.ravel(), y.ravel())).T
    particles = np.zeros((num_states, n_particles, 2))
    for i in range(min(num_cues, num_states)):
        particles[i] = particles_init
    return particles

    

class particle_noTravelAgent(object):
    '''
    Particle-based agent that learns a joint reward delay x magnitude map per cue.

    Every state can be reached in one step (no travel time). Learned maps live
    in ``particle_reward_MT``, indexed [cue, state, delay, magnitude].
    ``reward_MT`` mirrors them, and ``act``/``test_act`` use it to pick the
    state with the highest expected immediate reward.
    '''

    def __init__(self, num_states, init_state, reward_magnitude_time_matrices, alpha=0.1):
        '''
        num_states: number of states (one action per state).
        init_state: state the agent starts in.
        reward_magnitude_time_matrices: initial maps of shape
            (cues, states, delays, magnitudes). The array is copied, never modified.
        alpha: stored for reference only. The particle update reads the
            module-level ``learning_rates`` schedule instead.
        '''
        self.num_states = num_states
        self.num_cues = reward_magnitude_time_matrices.shape[0]
        self.current_state = init_state
        # Private float copy. particle_update writes learned maps into
        # particle_reward_MT in place, and the drivers pass one initial array
        # to several agents. Aliasing it would leak this agent's learning into
        # the others. reward_MT and particle_reward_MT deliberately share the copy.
        self.reward_MT = np.array(reward_magnitude_time_matrices, dtype=float)
        self.max_reward_delay = self.reward_MT.shape[-2]
        self.max_reward_magnitude = self.reward_MT.shape[-1]
        # reward magnitude associated with each magnitude index
        self.magnitude_values = np.arange(self.max_reward_magnitude) + 1
        # per-state (delay x magnitude) map built from the cues observed so far
        self.current_MT_map = np.zeros(self.reward_MT.shape[1:])
        # NOTE(review): starts at num_cues, whereas normalAgent starts at
        # max_reward_delay. Confirm which "no cue seen yet" value
        # update_time_delays / particle_update expect.
        self.cur_cue_delay = np.ones(self.num_cues, dtype=int) * self.num_cues
        self.alpha = alpha
        # not used by the methods of this class
        self.bw = 2
        # number of particle updates performed per state (indexes learning_rates)
        self.iter = np.zeros(self.num_cues, dtype=int)
        self.particles = set_init_particles(self.max_reward_delay, self.max_reward_magnitude, self.num_states, self.num_cues)
        self.particle_reward_MT = self.reward_MT

    def _expected_rewards(self, mt_map):
        '''Expected immediate reward per state from the delay-0 slice of ``mt_map`` (states x delays x magnitudes).'''
        return np.sum(mt_map[:, 0] * self.magnitude_values[None, :], axis=-1)

    def act(self, state, cue):
        '''Fold ``cue`` into the running map, pick the best state greedily, then advance the map one step.'''
        self.current_state = state
        self.current_MT_map = update_cur_MT_map(cue, self.reward_MT, self.current_MT_map)
        action = np.argmax(self._expected_rewards(self.current_MT_map))
        self.current_MT_map = step_time_MT_map(self.current_MT_map)
        return action

    def test_act(self, state, cue):
        '''Same as ``act`` but on the separate test map, so evaluation leaves training state untouched.'''
        self.test_state = state
        self.test_MT_map = update_cur_MT_map(cue, self.reward_MT, self.test_MT_map)
        action = np.argmax(self._expected_rewards(self.test_MT_map))
        self.test_MT_map = step_time_MT_map(self.test_MT_map)
        return action

    def reset_test(self, state):
        '''Start a fresh evaluation episode from ``state`` with an empty test map.'''
        self.test_state = state
        self.test_MT_map = np.zeros(self.reward_MT.shape[1:])

    def train(self, reward, next_state, next_cue):
        '''
        Learn from ``reward`` received on arriving in ``next_state``, then return a random action.

        Particles move only when a positive reward was received. The cue delays
        are then advanced with ``next_cue``.
        '''
        if reward > 0:
            self.particles, self.iter, self.particle_reward_MT = particle_update(self.particles, reward, self.cur_cue_delay, next_state, self.iter, self.max_reward_delay, self.max_reward_magnitude, self.particle_reward_MT)
        self.reward_MT = self.particle_reward_MT

        self.cur_cue_delay = update_time_delays(self.cur_cue_delay, self.num_cues, self.max_reward_delay, next_cue)

        # uniform exploration during training
        return np.random.choice(self.num_states)


class particle_time_only_agent(object):
    '''
    Ablated particle agent that uses reward timing plus one scalar return per state.

    Particles still learn the joint (delay, magnitude) map per cue in
    ``particle_reward_MT``. ``reward_MT`` holds that map summed over
    magnitudes (cue x state x delay). Each state's magnitude information is
    collapsed into the scalar ``expected_return``.
    '''

    def __init__(self, num_states, init_state, reward_magnitude_time_matrices, alpha=0.1):
        '''
        reward_magnitude_time_matrices: initial maps (cue x state x time x magnitude).
            The array is copied, never modified.
        alpha: stored only; the particle update reads the module-level schedule.
        '''
        self.num_states = num_states
        self.num_cues = reward_magnitude_time_matrices.shape[0]
        self.current_state = init_state
        # Private float copy. particle_update writes into particle_reward_MT in
        # place, so sharing the caller's array would leak this agent's learning
        # into any other agent built from the same initial array.
        self.particle_reward_MT = np.array(reward_magnitude_time_matrices, dtype=float)
        # Until the first rewarded update, use the first-magnitude slice
        # (cue x state x time). It is a view into particle_reward_MT.
        self.reward_MT = self.particle_reward_MT[:, :, :, 0]
        self.max_reward_delay = reward_magnitude_time_matrices.shape[-2]
        self.max_reward_magnitude = reward_magnitude_time_matrices.shape[-1]
        # reward magnitude associated with each magnitude index
        self.magnitude_values = np.arange(self.max_reward_magnitude) + 1
        # per-state timing map built from the cues observed so far
        self.current_MT_map = np.zeros(self.reward_MT.shape[1:])
        # NOTE(review): starts at num_cues, whereas normalAgent starts at
        # max_reward_delay. Confirm the intended "no cue seen yet" value.
        self.cur_cue_delay = np.ones(self.num_cues, dtype=int) * self.num_cues
        self.alpha = alpha
        # not used by the methods of this class
        self.bw = 2
        # number of particle updates performed per state (indexes learning_rates)
        self.iter = np.zeros(self.num_cues, dtype=int)
        self.particles = set_init_particles(self.max_reward_delay, self.max_reward_magnitude, self.num_states, self.num_cues)
        # expected reward magnitude per state
        self.expected_return = np.zeros(self.reward_MT.shape[1])

    def _expected_returns(self):
        '''Expected magnitude for each state i, read from cue i's joint map for state i.'''
        return np.array([
            np.sum(self.particle_reward_MT[i, i] * self.magnitude_values[None, :])
            for i in range(self.num_states)
        ])

    def act(self, state, cue):
        '''Greedy action: probability of reward now times expected magnitude, per state.'''
        self.current_state = state
        self.current_MT_map = update_cur_MT_map(cue, self.reward_MT, self.current_MT_map)
        expected_reward = self.current_MT_map[:, 0] * self.expected_return
        action = np.argmax(expected_reward)
        self.current_MT_map = step_time_MT_map(self.current_MT_map)
        return action

    def test_act(self, state, cue):
        '''Same as ``act`` but on the separate test map.'''
        self.test_state = state
        self.test_MT_map = update_cur_MT_map(cue, self.reward_MT, self.test_MT_map)
        expected_reward = self.test_MT_map[:, 0] * self.expected_return
        action = np.argmax(expected_reward)
        self.test_MT_map = step_time_MT_map(self.test_MT_map)
        return action

    def reset_test(self, state):
        '''Start a fresh evaluation episode from ``state`` with an empty test map.'''
        self.test_state = state
        self.test_MT_map = np.zeros(self.reward_MT.shape[1:])

    def train(self, reward, next_state, next_cue):
        '''
        Learn from a positive ``reward`` on arriving in ``next_state``, then return a random action.

        Raises ValueError if the particle update produces NaNs.
        '''
        if reward > 0:
            self.particles, self.iter, self.particle_reward_MT = particle_update(self.particles, reward, self.cur_cue_delay, next_state, self.iter, self.max_reward_delay, self.max_reward_magnitude, self.particle_reward_MT)
            if np.isnan(self.particles).any() or np.isnan(self.particle_reward_MT).any():
                print('particles nan', self.particles)
                print('particle_reward_MT nan', self.particle_reward_MT)
                print('reward', reward, 'next_state', next_state, 'cur_cue_delay', self.cur_cue_delay, 'iter', self.iter)
                raise ValueError('Particles or particle_reward_MT contain NaN values.')
            # timing marginal: sum over magnitudes
            self.reward_MT = np.sum(self.particle_reward_MT, axis=-1)
            self.expected_return = self._expected_returns()
        self.cur_cue_delay = update_time_delays(self.cur_cue_delay, self.num_cues, self.max_reward_delay, next_cue)
        # uniform exploration during training
        return np.random.choice(self.num_states)



class particle_mag_only_agent(object):
    '''
    Ablated particle agent that uses the magnitude distribution plus one expected delay per state.

    Particles learn the joint (delay, magnitude) map per cue in
    ``particle_reward_MT``. ``reward_MT`` holds that map summed over delays
    (cue x state x magnitude). Timing is reduced to the scalar
    ``expected_delay`` per state, and a state is scored only when the rounded
    expected delay matches the current cue delay.
    '''

    def __init__(self, num_states, init_state, reward_magnitude_time_matrices, alpha=0.1):
        '''
        reward_magnitude_time_matrices: initial maps (cue x state x time x magnitude).
            The array is copied, never modified.
        alpha: stored only; the particle update reads the module-level schedule.
        '''
        self.num_states = num_states
        self.num_cues = reward_magnitude_time_matrices.shape[0]
        self.current_state = init_state
        # Private float copy. particle_update writes into particle_reward_MT in
        # place, so sharing the caller's array would leak this agent's learning
        # into any other agent built from the same initial array.
        self.particle_reward_MT = np.array(reward_magnitude_time_matrices, dtype=float)
        # Until the first rewarded update, use the first-delay slice
        # (cue x state x magnitude). It is a view into particle_reward_MT.
        self.reward_MT = self.particle_reward_MT[:, :, 0, :]
        self.max_reward_delay = reward_magnitude_time_matrices.shape[-2]
        self.max_reward_magnitude = reward_magnitude_time_matrices.shape[-1]
        # reward magnitude associated with each magnitude index
        self.magnitude_values = np.arange(self.max_reward_magnitude) + 1
        # per-state magnitude map built from the cues observed so far
        self.current_MT_map = np.zeros(self.reward_MT.shape[1:])
        # NOTE(review): starts at num_cues (reset_test uses max_reward_magnitude).
        # Confirm the intended "no cue seen yet" value.
        self.cur_cue_delay = np.ones(self.num_cues, dtype=int) * self.num_cues
        self.test_cue_delay = np.ones(self.num_cues, dtype=int) * self.num_cues
        self.alpha = alpha
        # not used by the methods of this class
        self.bw = 2
        # number of particle updates performed per state (indexes learning_rates)
        self.iter = np.zeros(self.num_cues, dtype=int)
        # delay axis collapsed: the initial delay coordinates span only [-0.5, 0.5]
        self.particles = set_init_particles(0, self.max_reward_magnitude, self.num_states, self.num_cues)
        # expected reward delay per state
        self.expected_delay = np.zeros(self.reward_MT.shape[0])

    def _delay_match(self, cue_delay):
        '''1.0 where a state's rounded expected delay equals the matching cue's delay, else 0.0 (element-wise).'''
        return np.equal(np.round(self.expected_delay), cue_delay).astype(float)

    def _expected_delays(self):
        '''
        Expected delay for each state i, from cue i's joint (delay x magnitude) map.

        This must use the joint map. The delay-marginalised reward_MT has
        already lost the delay axis, so weighting it by arange(delay) gives the
        constant D * (D - 1) / 2 instead of an expectation.
        '''
        delays = np.arange(self.max_reward_delay)[:, None]
        return np.array([
            np.sum(self.particle_reward_MT[i, i] * delays)
            for i in range(self.num_states)
        ])

    def act(self, state, cue):
        '''
        Greedy action: expected magnitude per state, gated by the delay match.

        The gate is element-wise (state i vs cue i), as in test_act. The
        previous (C, 1) broadcast produced a C x S score matrix whose argmax
        could exceed the number of states.
        '''
        self.current_state = state
        self.current_MT_map = update_cur_MT_map(cue, self.reward_MT, self.current_MT_map)
        magnitude_score = np.sum(self.current_MT_map * self.magnitude_values[None, :], axis=-1)
        expected_reward = magnitude_score * self._delay_match(self.cur_cue_delay)
        action = np.argmax(expected_reward)
        self.current_MT_map = step_time_MT_map(self.current_MT_map)
        return action

    def test_act(self, state, cue):
        '''Evaluation action from cue i's magnitude map for state i, gated by the test cue delays.'''
        self.test_state = state
        magnitude_score = np.diag(np.sum(self.reward_MT * self.magnitude_values[None, None, :], axis=-1))
        expected_reward = magnitude_score * self._delay_match(self.test_cue_delay)
        action = np.argmax(expected_reward)
        self.test_cue_delay = update_time_delays(self.test_cue_delay, self.num_cues, self.max_reward_delay, cue)
        return action

    def reset_test(self, state):
        '''Start a fresh evaluation episode from ``state``.'''
        self.test_state = state
        self.test_MT_map = np.zeros(self.reward_MT.shape[1:])
        # NOTE(review): sentinel differs from __init__ (num_cues); confirm intent.
        self.test_cue_delay = np.ones(self.num_cues, dtype=int) * self.max_reward_magnitude

    def train(self, reward, next_state, next_cue):
        '''Learn from a positive ``reward`` on arriving in ``next_state``, then return a random action.'''
        if reward > 0:
            self.particles, self.iter, self.particle_reward_MT = particle_update(self.particles, reward, self.cur_cue_delay, next_state, self.iter, self.max_reward_delay, self.max_reward_magnitude, self.particle_reward_MT)
            # magnitude marginal: sum over delays
            self.reward_MT = np.sum(self.particle_reward_MT, axis=-2)
            self.expected_delay = self._expected_delays()

        self.cur_cue_delay = update_time_delays(self.cur_cue_delay, self.num_cues, self.max_reward_delay, next_cue)
        # uniform exploration during training
        return np.random.choice(self.num_states)


class normalAgent(object):
    '''
    Tabular baseline: a per-state running average of reward, keyed by the joint cue-delay vector.

    The cue-delay vector has one entry per cue, each assumed to lie in
    0..max_reward_delay. It is encoded as a base-(max_reward_delay + 1)
    integer, so ``value`` has shape
    (num_states, (max_reward_delay + 1) ** num_cues). There is no
    bootstrapping: each entry is an exponential moving average of the reward
    received on arriving in that state under that cue-delay vector.
    '''

    def __init__(self, num_states, init_state, reward_magnitude_time_matrices, alpha=0.1):
        '''
        reward_magnitude_time_matrices: used only for its shape
            (cues x states x delays x magnitudes).
        alpha: step size of the moving average.
        '''
        self.num_states = num_states
        self.num_cues = reward_magnitude_time_matrices.shape[0]
        self.current_state = init_state
        self.max_reward_delay = reward_magnitude_time_matrices.shape[-2]
        self.max_reward_magnitude = reward_magnitude_time_matrices.shape[-1]
        self.magnitude_values = np.arange(self.max_reward_magnitude) + 1
        # One delay per cue: get_cue_index reads num_cues entries, so sizing by
        # num_states broke whenever the two counts differ. Every cue starts at
        # max_reward_delay, the largest value the index encoding can hold.
        self.cur_cue_delay = np.ones(self.num_cues) * self.max_reward_delay
        self.value = np.zeros((self.num_states, (self.max_reward_delay + 1) ** self.num_cues))
        self.alpha = alpha

    def get_value(self, cue_delay):
        '''Value of every state for the given cue-delay vector.'''
        return self.value[:, self.get_cue_index(cue_delay)]

    def get_cue_index(self, cue_delay):
        '''Column index of ``cue_delay``: its base-(max_reward_delay + 1) encoding, cue 0 least significant.'''
        base = self.max_reward_delay + 1
        return int(sum(cue_delay[ind] * base ** ind for ind in range(self.num_cues)))

    def test_act(self, state, cue):
        '''Advance the test cue delays with ``cue`` and pick the state with the highest value.'''
        self.test_state = state
        self.test_cue_delay = update_time_delays(self.test_cue_delay, self.num_cues, self.max_reward_delay, cue)
        return np.argmax(self.get_value(self.test_cue_delay))

    def reset_test(self, state):
        '''Start a fresh evaluation episode: every cue at the max_reward_delay start value.'''
        self.test_state = state
        self.test_cue_delay = np.ones(self.num_cues) * self.max_reward_delay

    def train(self, reward, next_state, next_cue):
        '''Blend ``reward`` into value[next_state, current cue index], advance cue delays, return a random action.'''
        col = self.get_cue_index(self.cur_cue_delay)
        self.value[next_state, col] = self.alpha * reward + (1 - self.alpha) * self.value[next_state, col]
        self.cur_cue_delay = update_time_delays(self.cur_cue_delay, self.num_cues, self.max_reward_delay, next_cue)
        # uniform exploration during training
        return np.random.choice(self.num_states)

class normalAgent_bootstrapped(object):
    '''
    Tabular Q-learning baseline over (action, state, cue-delay vector).

    Q_value[a, s, c] estimates the discounted return of moving to state a from
    state s when the joint cue-delay vector has index c. The encoding is
    base-(max_reward_delay + 1), cue 0 least significant. Updates bootstrap
    from the greedy value of the next (state, cue-delay) pair.
    '''

    def __init__(self, num_states, init_state, reward_magnitude_time_matrices, alpha=0.1, discount=0.9):
        '''
        reward_magnitude_time_matrices: used only for its shape
            (cues x states x delays x magnitudes).
        alpha: learning rate.
        discount: weight (< 1) on the bootstrapped next-state value. It keeps
            the estimates bounded in this continuing task.
        '''
        self.num_states = num_states
        self.num_cues = reward_magnitude_time_matrices.shape[0]
        self.current_state = init_state
        self.max_reward_delay = reward_magnitude_time_matrices.shape[-2]
        self.max_reward_magnitude = reward_magnitude_time_matrices.shape[-1]
        self.magnitude_values = np.arange(self.max_reward_magnitude) + 1
        # one delay per cue (get_cue_index reads num_cues entries)
        self.cur_cue_delay = np.ones(self.num_cues) * self.max_reward_delay
        # value (action x state x cue delay)
        self.Q_value = np.zeros((self.num_states, self.num_states, (self.max_reward_delay + 1) ** self.num_cues))
        self.alpha = alpha
        self.discount = discount

    def get_value(self, state, cue_delay):
        '''Q-values of every action from ``state`` under ``cue_delay``.'''
        return self.Q_value[:, state, self.get_cue_index(cue_delay)]

    def get_cue_index(self, cue_delay):
        '''Column index of ``cue_delay``: its base-(max_reward_delay + 1) encoding, cue 0 least significant.'''
        base = self.max_reward_delay + 1
        return int(sum(cue_delay[ind] * base ** ind for ind in range(self.num_cues)))

    def _td_update(self, action, state, cue_ind, reward, next_state, next_cue_ind):
        '''One Q-learning step on Q_value[action, state, cue_ind].'''
        target = reward + self.discount * np.max(self.Q_value[:, next_state, next_cue_ind])
        old = self.Q_value[action, state, cue_ind]
        self.Q_value[action, state, cue_ind] = (1 - self.alpha) * old + self.alpha * target

    def test_act(self, state, cue):
        '''Advance the test cue delays with ``cue`` and act greedily on Q.'''
        self.test_state = state
        self.test_cue_delay = update_time_delays(self.test_cue_delay, self.num_cues, self.max_reward_delay, cue)
        return np.argmax(self.get_value(self.test_state, self.test_cue_delay))

    def reset_test(self, state):
        '''Start a fresh evaluation episode: every cue at the max_reward_delay start value.'''
        self.test_state = state
        self.test_cue_delay = np.ones(self.num_cues) * self.max_reward_delay

    def train(self, reward, next_state, next_cue):
        '''
        Q-learning update for the transition into ``next_state``, then return a random action.

        The old rule alpha * (reward + Q) + (1 - alpha) * Q reduced to
        Q + alpha * reward: a running sum that never bootstrapped and grew
        without bound.
        '''
        prev_state = self.current_state
        cue_ind = self.get_cue_index(self.cur_cue_delay)
        # action in the patch environment is equivalent to the next state/location
        action = next_state
        self.cur_cue_delay = update_time_delays(self.cur_cue_delay, self.num_cues, self.max_reward_delay, next_cue)
        next_cue_ind = self.get_cue_index(self.cur_cue_delay)
        self._td_update(action, prev_state, cue_ind, reward, next_state, next_cue_ind)
        self.current_state = next_state
        # uniform exploration during training
        return np.random.choice(self.num_states)
    
def plot_rewards(saveFolder, rewards, title):
    '''Line-plot a reward trace against time and save it as saveFolder + title + '.png'.'''
    fig, ax = plt.subplots()
    ax.plot(rewards)
    ax.set_title(title)
    ax.set_xlabel('Time')
    ax.set_ylabel('Reward')
    fig.savefig(saveFolder + title + '.png')
    plt.close(fig)
        
def plot_time_avg_rewards(saveFolder, rewards, title, window=100):
    '''
    Plot a moving average of ``rewards`` and save it as saveFolder + title + '.png'.

    window: moving-average length (default 100). It is clamped to the number
        of rewards. With a kernel longer than the signal, np.convolve in
        'valid' mode returns max(M, N) - min(M, N) + 1 samples that are not a
        moving average of the rewards. An empty series produces an empty plot
        instead of raising.
    '''
    rewards = np.asarray(rewards, dtype=float)
    width = max(1, min(int(window), rewards.size))
    plt.figure()
    if rewards.size:
        plt.plot(np.convolve(rewards, np.ones(width) / width, mode='valid'))
    plt.title(title)
    plt.xlabel('Time')
    plt.ylabel('Reward')
    plt.savefig(saveFolder + title + '.png')
    plt.close()

def plot_value_function(saveFolder, value, title):
    '''Heat-map a 2-D value table (rows: states, columns: cue-delay index) and save it as saveFolder + title + '.png'.'''
    fig, ax = plt.subplots()
    image = ax.imshow(value, cmap='hot', interpolation='nearest')
    ax.set_title(title)
    ax.set_xlabel('Cue Delay')
    ax.set_ylabel('State')
    fig.colorbar(image, ax=ax)
    fig.savefig(saveFolder + title + '.png')
    plt.close(fig)
        
# run agent in environment
def main_act(num_timesteps=None, reward_magnitude_time_matrices=None):
    '''
    Run a noTravelAgent that only acts (no learning) and report its mean reward.

    num_timesteps: number of environment steps. Defaults to the module-level
        ``test_timesteps``. The previous body used a ``num_timesteps`` global
        that the script's parameter section does not define.
    reward_magnitude_time_matrices: ground-truth reward maps, handed to both
        the environment and the agent. If None, a fresh random set is
        generated the same way main_train does.

    Returns the mean reward over the run (also printed).
    '''
    if num_timesteps is None:
        num_timesteps = test_timesteps
    if reward_magnitude_time_matrices is None:
        reward_magnitude_time_matrices = generate_reward_magnitude_time_matrices(
            num_cues, num_states, max_reward_delay, max_reward_magnitude)

    env = noTravelTimeEnv(num_states, reward_magnitude_time_matrices, cue_probs)
    state, cue = env.reset()
    agent = noTravelAgent(num_states, state, reward_magnitude_time_matrices)

    rewards = np.zeros(num_timesteps)
    for time in range(num_timesteps):
        action = agent.act(state, cue)
        reward, cue, state = env.step(action)
        rewards[time] = reward

    mean_reward = np.mean(rewards)
    print('Average reward over time noTravelAgent: ', mean_reward)
    return mean_reward

def test_agent(agent, env, num_timesteps):
    # reset environment
    state, cue = env.reset()
    reward = 0
    # track rewards
    rewards = np.zeros(num_timesteps)
    # reset agent for test
    agent.reset_test(state)
    
    for time in range(num_timesteps):              
        # get action from agent
        action = agent.test_act(state, cue)
        # take action in environment
        reward, cue, state = env.step(action)            
        rewards[time] = reward
        
    return np.mean(rewards)

def train_agent(agent, env, test_env, num_timesteps, test_timesteps, test_every, init_state, init_cue, init_reward, verbose=True):
    '''
    Train ``agent`` on ``env``, evaluating it on ``test_env`` every ``test_every`` steps.

    Each step passes the previous reward, the current state and the cue to
    ``agent.train``, which returns the next action. Evaluation runs at steps
    0, test_every, 2 * test_every, ... using test_agent with
    ``test_timesteps`` steps.

    verbose: print a per-step trace (the previous always-on behaviour).

    Returns
    -------
    rewards : ndarray, shape (num_timesteps,)
        Training reward at each step.
    test_rewards : ndarray, shape (ceil(num_timesteps / test_every),)
        Mean test reward at each evaluation point.
    '''
    state, cue, reward = init_state, init_cue, init_reward
    rewards = np.zeros(num_timesteps)
    # One slot per evaluation point. Floor division dropped the last slot when
    # num_timesteps is not a multiple of test_every, which caused an IndexError.
    test_rewards = np.zeros((num_timesteps + test_every - 1) // test_every)

    for time in range(num_timesteps):
        action = agent.train(reward, state, cue)
        reward, cue, state = env.step(action)
        rewards[time] = reward
        if verbose:
            print('time: ',time,'cue: ', cue, 'state: ', state, 'action: ', action, 'reward: ', reward, 'cur cue delay: ', agent.cur_cue_delay)
        if time % test_every == 0:
            test_rewards[time // test_every] = test_agent(agent, test_env, test_timesteps)

    return rewards, test_rewards

# plot average rewards from multiple runs
def plot_average_rewards(saveFolder, rewards_noTravel, rewards_norm, rewards_time=None, rewards_mag=None, labels=None):
    '''
    Plot mean +/- one standard deviation across runs for each reward series.

    Each rewards_* argument is a 2-D array (runs x evaluation points).
    rewards_time and rewards_mag are optional and drawn independently of each
    other. labels[i] names the i-th series in argument order. The figure is
    saved to saveFolder + 'average_rewards.png' and then closed, so repeated
    calls do not accumulate open figures.
    '''
    if labels is None:
        labels = ['TMRL', 'standard RL', 'mag-ablated', 'time-ablated']
    series = (rewards_noTravel, rewards_norm, rewards_time, rewards_mag)

    plt.figure()
    for idx, data in enumerate(series):
        if data is None:
            continue
        mean = np.mean(data, axis=0)
        std = np.std(data, axis=0)
        line, = plt.plot(mean, label=labels[idx])
        # give the band its line's colour so each band visibly belongs to its curve
        plt.fill_between(range(len(mean)), mean - std, mean + std, color=line.get_color(), alpha=0.2)
    plt.legend()
    plt.xlabel('time steps')
    plt.ylabel('reward')
    plt.title('Average rewards over time')
    plt.savefig(saveFolder + 'average_rewards.png')
    plt.close()

def get_new_train_test_env(num_states, reward_magnitude_time_matrices, cue_probs):
    '''
    Build a fresh (training, testing) environment pair with identical settings.

    Only the training environment is reset. Returns
    (train_env, test_env, initial_state, initial_cue, initial_reward=0).
    '''
    train_env, eval_env = (
        noTravelTimeEnv(num_states, reward_magnitude_time_matrices, cue_probs)
        for _ in range(2)
    )
    start_state, start_cue = train_env.reset()
    return train_env, eval_env, start_state, start_cue, 0

def main_train_only_stardard_agent():
    '''
    Compare the tabular baseline (normalAgent) with its bootstrapped variant.

    Runs num_runs independent training runs of each agent on one randomly
    generated set of reward maps. It collects every run's periodic test
    rewards and saves the mean +/- std curves to
    saveFolder + 'average_rewards.png'.
    '''
    num_runs = 10

    # randomly generate reward time and magnitude matrices
    reward_magnitude_time_matrices = generate_reward_magnitude_time_matrices(num_cues, num_states, max_reward_delay, max_reward_magnitude)
    init_reward_MT = np.zeros(reward_magnitude_time_matrices.shape)
    plot_reward_magnitude_time_matrices(saveFolder, reward_magnitude_time_matrices, 'reward_time_magnitude_matrices_per_state')

    test_rewards_norm = []
    test_rewards_bootstrap = []
    for _ in range(num_runs):
        env, test_env, state, cue, reward = get_new_train_test_env(num_states, reward_magnitude_time_matrices, cue_probs)
        agent = normalAgent(num_states, state, init_reward_MT, alpha)
        _, test_r_norm = train_agent(agent, env, test_env, train_timesteps, test_timesteps, test_every_n, state, cue, reward)

        env, test_env, state, cue, reward = get_new_train_test_env(num_states, reward_magnitude_time_matrices, cue_probs)
        agent_bootstrap = normalAgent_bootstrapped(num_states, state, init_reward_MT, alpha)
        _, test_r_bootstrap = train_agent(agent_bootstrap, env, test_env, train_timesteps, test_timesteps, test_every_n, state, cue, reward)

        # Collect inside the loop. Appending after it kept only the last run,
        # so the plotted mean/std covered a single run.
        test_rewards_norm.append(test_r_norm)
        test_rewards_bootstrap.append(test_r_bootstrap)

    plot_average_rewards(saveFolder, np.array(test_rewards_norm), np.array(test_rewards_bootstrap), None, None, labels=['non-bootstrapped', 'bootstrapped'])



# run agent in environment
def main_train():
    '''
    Train and compare four agents on one randomly generated set of reward maps.

    The agents are noTravelAgent (joint time x magnitude map), normalAgent
    (tabular baseline), and the single-axis variants
    noTravelAgent_1D_time_dist and noTravelAgent_1D_magnitude_dist. Each gets
    fresh train/test environments. The function saves training and test
    reward traces, the learned maps, the tabular value function and the 1-D
    expected-value comparisons under saveFolder.
    '''
    # randomly generate reward time and magnitude matrices
    reward_magnitude_time_matrices = generate_reward_magnitude_time_matrices(num_cues, num_states, max_reward_delay, max_reward_magnitude)
    init_reward_MT = np.zeros(reward_magnitude_time_matrices.shape)
    plot_reward_magnitude_time_matrices(saveFolder, reward_magnitude_time_matrices, 'reward_time_magnitude_matrices_per_state')

    # train no travel time agent
    env, test_env, state, cue, reward = get_new_train_test_env(num_states, reward_magnitude_time_matrices, cue_probs)
    agent = noTravelAgent(num_states, state, init_reward_MT, alpha)
    rewards_no_travel, test_rew_noTravel = train_agent(agent, env, test_env, train_timesteps, test_timesteps, test_every_n, state, cue, reward)
    print('Average reward over time noTravelAgent: ', np.mean(rewards_no_travel))

    # train normal agent
    env, test_env, state, cue, reward = get_new_train_test_env(num_states, reward_magnitude_time_matrices, cue_probs)
    normAgent = normalAgent(num_states, state, init_reward_MT, 0.01)
    rewards_normAgent, test_rew_norm = train_agent(normAgent, env, test_env, train_timesteps, test_timesteps, test_every_n, state, cue, reward)
    print('Average reward over time normalAgent: ', np.mean(rewards_normAgent))

    # train only time distribution agent
    env, test_env, state, cue, reward = get_new_train_test_env(num_states, reward_magnitude_time_matrices, cue_probs)
    timeAgent = noTravelAgent_1D_time_dist(num_states, state, init_reward_MT, alpha)
    rewards_timeAgent, test_rew_time = train_agent(timeAgent, env, test_env, train_timesteps, test_timesteps, test_every_n, state, cue, reward)
    print('Average reward over time timeAgent: ', np.mean(rewards_timeAgent))

    # train only magnitude distribution agent
    env, test_env, state, cue, reward = get_new_train_test_env(num_states, reward_magnitude_time_matrices, cue_probs)
    magAgent = noTravelAgent_1D_magnitude_dist(num_states, state, init_reward_MT, alpha)
    rewards_magAgent, test_rew_mag = train_agent(magAgent, env, test_env, train_timesteps, test_timesteps, test_every_n, state, cue, reward)
    print('Average reward over time magAgent: ', np.mean(rewards_magAgent))

    # plot rewards
    plot_rewards(saveFolder, rewards_no_travel, 'rewards_noTravelAgent')
    plot_rewards(saveFolder, rewards_normAgent, 'rewards_normalAgent')
    plot_rewards(saveFolder, rewards_timeAgent, 'rewards_timeAgent')
    # previously plotted rewards_timeAgent under the magAgent title
    plot_rewards(saveFolder, rewards_magAgent, 'rewards_magAgent')

    plot_rewards(saveFolder, test_rew_noTravel, 'test_rewards_noTravelAgent')
    plot_rewards(saveFolder, test_rew_norm, 'test_rewards_normalAgent')
    plot_rewards(saveFolder, test_rew_time, 'test_rewards_timeAgent')
    plot_rewards(saveFolder, test_rew_mag, 'test_rewards_magAgent')

    # plot learned reward time and magnitude matrices
    plot_reward_magnitude_time_matrices(saveFolder, agent.reward_MT, 'learned_reward_time_magnitude_matrices_per_state')
    plot_value_function(saveFolder, normAgent.value, 'learned_value_function')

    # NOTE(review): magnitudes are weighted by index 0..M-1 here, while the
    # particle agents use magnitude_values = index + 1. Confirm which matches the env.
    true_magnitude_expected_values = np.sum(reward_magnitude_time_matrices * (np.arange(max_reward_magnitude)), axis=(-1, -2))
    plot_1D_matrix_and_expected_values(saveFolder, true_magnitude_expected_values, timeAgent.reward_MT, timeAgent.expected_return, 'learned_time_delay_vector')

    true_delay_expected_values = np.sum(reward_magnitude_time_matrices * (np.arange(max_reward_delay))[:,None], axis=(-1, -2))
    plot_1D_matrix_and_expected_values(saveFolder, true_delay_expected_values, magAgent.reward_MT, magAgent.expected_delay, 'learned_magnitude_vector')


def _make_env_pair(reward_magnitude_time_matrices):
    """Create a fresh (train, test) pair of noTravelTimeEnv and reset the train one.

    The two environments are built in the same order as the original inline code.
    Only the training environment is reset, so random-number consumption is unchanged.

    Returns:
        (env, test_env, state, cue, reward): the training and test environments,
        the ``state`` and ``cue`` returned by ``env.reset()``, and an initial
        reward of 0.
    """
    env = noTravelTimeEnv(num_states, reward_magnitude_time_matrices, cue_probs)
    test_env = noTravelTimeEnv(num_states, reward_magnitude_time_matrices, cue_probs)
    state, cue = env.reset()
    return env, test_env, state, cue, 0


def main_train_particles(norm_alpha=0.01):
    """Train a particle-based no-travel agent and a normal agent, then plot results.

    Both agents are trained on the same randomly generated "difficult" reward
    magnitude/time matrices. Each agent gets its own freshly created and reset
    training and test environments. The function saves these plots under the
    module-level ``saveFolder``:
    - the true matrices
    - the training and test reward curves of both agents
    - the particle agent's learned matrices and particles
    - the normal agent's value function

    Args:
        norm_alpha: learning rate of the normal agent. Defaults to 0.01, the
            value previously hard-coded here. It is deliberately separate from
            the module-level ``alpha`` used by the particle agent.
    """
    # Generate the ground-truth reward time/magnitude matrices.
    reward_magnitude_time_matrices = generate_reward_magnitude_time_matrices_difficult(
        num_cues, num_states, max_reward_delay, max_reward_magnitude)
    # Agents start from an all-zero estimate with the same shape as the truth.
    init_reward_MT = np.zeros(reward_magnitude_time_matrices.shape)
    plot_reward_magnitude_time_matrices(
        saveFolder, reward_magnitude_time_matrices,
        'reward_time_magnitude_matrices_per_state')

    # --- particle-based no-travel-time agent ---
    env, test_env, state, cue, reward = _make_env_pair(reward_magnitude_time_matrices)
    # Each agent gets its own copy of the initial estimate. If an agent updates
    # the array in place, the other agent still starts from zeros.
    agent = particle_noTravelAgent(num_states, state, init_reward_MT.copy(), alpha)
    rewards_no_travel, test_rew_noTravel = train_agent(
        agent, env, test_env, train_timesteps, test_timesteps, test_every_n,
        state, cue, reward)
    print('Average reward over time noTravelAgent: ', np.mean(rewards_no_travel))

    # --- normal agent ---
    env, test_env, state, cue, reward = _make_env_pair(reward_magnitude_time_matrices)
    normAgent = normalAgent(num_states, state, init_reward_MT.copy(), norm_alpha)
    rewards_normAgent, test_rew_norm = train_agent(
        normAgent, env, test_env, train_timesteps, test_timesteps, test_every_n,
        state, cue, reward)
    print('Average reward over time normalAgent: ', np.mean(rewards_normAgent))

    # --- plots ---
    plot_rewards(saveFolder, rewards_no_travel, 'rewards_noTravelAgent')
    plot_rewards(saveFolder, rewards_normAgent, 'rewards_normalAgent')
    plot_rewards(saveFolder, test_rew_noTravel, 'test_rewards_noTravelAgent')
    plot_rewards(saveFolder, test_rew_norm, 'test_rewards_normalAgent')
    plot_reward_magnitude_time_matrices(
        saveFolder, agent.reward_MT,
        'learned_reward_time_magnitude_matrices_per_state')
    plot_value_function(saveFolder, normAgent.value, 'learned_value_function')
    plot_particle_reward_magnitude_time_matrices(
        saveFolder, agent.particle_reward_MT, agent.particles,
        'particle_reward_time_magnitude_matrices_per_state')


if __name__ == '__main__':
    # Script entry point: by default only the particle experiment runs.
    # For reproducible runs, seed numpy before training, e.g. np.random.seed(42).
    # Alternative entry points (swap in to use instead):
    #   main_train(), main_train_only_stardard_agent()
    main_train_particles()