'''
Use case for reward time and magnitude maps
Each reward time and magnitude map describes the reward that will be given
in each state.
We assume here that any state can be reached from any other state with no travel time.
'''

import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
from matplotlib.lines import Line2D
import os

from gridwork_reward_time_and_mag import gridWorldEnv
from gridwork_reward_time_and_mag import rewardMTAgent, rewardMTAgent_particle
from gridwork_reward_time_and_mag import normalAgent
from gridwork_reward_time_and_mag import train_agent
from gridwork_reward_time_and_mag import generate_grid_reward_magnitude_time_matrices, plot_reward_magnitude_time_matrices, plot_rewards, plot_value_function, plot_reward_MT_positions

# Parameters for plots
font_size = 9
linewidth = 1.2
length_ticks = 2        # tick length passed to ax.tick_params
scatter_size = 20       # scatter marker size
horizontal_size = 1.5   # figure width (inches) of the group plots
vertical_size = 1.5     # figure height (inches) of the group plots

# Global matplotlib style: small fonts and no top/right spines.
mpl.rcParams.update({
    'font.size': font_size,
    'lines.linewidth': linewidth,
    'xtick.labelsize': font_size - 5,
    'ytick.labelsize': font_size - 5,
    'axes.spines.right': False,
    'axes.spines.top': False,
    'axes.titlesize': font_size - 2,
    'legend.fontsize': font_size - 2,
})

# Output location for training results (figures and .npz files).
plot_folder = ''
experiment_name = 'gridworld_group_cues_2/'
saveFolder = plot_folder + experiment_name


# Environment parameters

np.random.seed(42)  # seed NumPy's global RNG (reproducibility of code using np.random)

num_runs = 10                        # number of independent runs
map_size = 5                         # map size of the grid world
max_reward_delay = 5                 # max reward delay
max_reward_magnitude = 4             # max reward magnitude
alpha = 0.005                        # learning rate
num_cues = 2                         # number of cues
cue_probs = np.full(num_cues, 0.1)   # probability of each cue at each time step
train_timesteps = 300000             # time steps used for training
test_timesteps = 10000               # time steps used for each test evaluation
test_every_n = 1000                  # test the reward rate every n training steps
discount_gamma = 0.98                # discount factor
epsilon = 0.1                        # epsilon-greedy exploration rate


def train_single(agent_type, reward_magnitude_time_matrices, cue_probs, map_size, alpha, train_timesteps, test_timesteps, test_every_n):
    """Train one agent of the requested type in freshly created grid-world environments.

    Args:
        agent_type: 'RTM' (rewardMTAgent), 'norm' (normalAgent) or
            'TMRL' (rewardMTAgent_particle).
        reward_magnitude_time_matrices: reward specification passed to both
            environments; its shape also sizes the agent's initial, all-zero
            reward estimate.
        cue_probs: cue probabilities passed to both environments.
        map_size: size of the grid world.
        alpha: learning rate of the 'RTM' and 'TMRL' agents.
        train_timesteps, test_timesteps, test_every_n: forwarded to train_agent.

    Returns:
        (agent, rewards, test_rewards): the trained agent and the first two
        values returned by train_agent.

    Raises:
        ValueError: if agent_type is not one of 'RTM', 'norm', 'TMRL'.
    """
    valid_types = ('RTM', 'norm', 'TMRL')
    # Validate before building environments so a typo fails fast and clearly.
    if agent_type not in valid_types:
        raise ValueError(f"unknown agent_type {agent_type!r}; expected one of {valid_types}")

    # The agent starts with an all-zero estimate of the reward maps.
    init_reward_MT = np.zeros(reward_magnitude_time_matrices.shape)

    # One environment to act in and a second one handed to train_agent as test_env.
    env = gridWorldEnv(map_size, reward_magnitude_time_matrices, cue_probs)
    test_env = gridWorldEnv(map_size, reward_magnitude_time_matrices, cue_probs)
    state, cue = env.reset()
    reward = 0

    # discount_gamma and epsilon are module-level settings.
    if agent_type == 'RTM':
        agent = rewardMTAgent(env, map_size, state, init_reward_MT, alpha, discount_gamma, epsilon)
    elif agent_type == 'norm':
        # NOTE(review): the standard agent uses a fixed learning rate of 0.01 and
        # ignores `alpha` -- confirm this is intended.
        agent = normalAgent(env, map_size, state, init_reward_MT, 0.01, discount_gamma, epsilon)
    else:  # 'TMRL'
        agent = rewardMTAgent_particle(env, map_size, state, init_reward_MT, alpha, discount_gamma, epsilon)

    rewards, test_rew, _ = train_agent(agent, env, test_env, train_timesteps, test_timesteps, test_every_n, state, cue, reward)
    return agent, rewards, test_rew

# run agent in environment
def train_both(run_num, reward_magnitude_time_matrices, rewarded_positions, cue_probs, map_size, alpha, train_timesteps, test_timesteps, test_every_n):
    """Train the standard ('norm') and particle ('TMRL') agents once and plot their test rewards.

    The reward-time-magnitude ('RTM', "no travel") agent is currently disabled.
    Its two slots in the returned tuple are empty lists, so the 6-tuple layout
    that callers unpack stays the same.

    Args:
        run_num: index of this run; appended to the plot file names.
        rewarded_positions: currently unused here.
        remaining args: forwarded to train_single.

    Returns:
        (rewards_no_travel, test_rew_noTravel, rewards_normAgent, test_rew_norm,
         rewards_particle, test_rew_particle)
    """
    # Placeholders for the disabled RTM agent.
    rewards_no_travel = []
    test_rew_noTravel = []

    # Only the reward traces are needed; the trained agent objects are discarded.
    # (Using `_` also avoids shadowing the imported `normalAgent` class.)
    _, rewards_normAgent, test_rew_norm = train_single('norm', reward_magnitude_time_matrices, cue_probs, map_size, alpha, train_timesteps, test_timesteps, test_every_n)
    _, rewards_particle, test_rew_particle = train_single('TMRL', reward_magnitude_time_matrices, cue_probs, map_size, alpha, train_timesteps, test_timesteps, test_every_n)

    suffix = '_' + str(run_num)
    plot_rewards(test_rew_norm, 'test_rewards_normalAgent' + suffix, saveFolder=saveFolder, opt_rewards=None)
    plot_rewards(test_rew_particle, 'test_rewards_particleAgent' + suffix, saveFolder=saveFolder, opt_rewards=None)

    return rewards_no_travel, test_rew_noTravel, rewards_normAgent, test_rew_norm, rewards_particle, test_rew_particle


def train_multiple_agents(num_runs, num_cues, map_size, max_reward_delay, max_reward_magnitude, alpha, cue_probs, train_timesteps, test_timesteps, test_every_n):
    """Train the agents for several runs on one generated reward map, saving rewards and plots.

    Everything is written into the module-level ``saveFolder``:

    * ``parameters.npz``
    * a plot of the generated reward maps
    * one ``rewards_<run>.npz`` per run
    * ``average_rewards.png``
    """
    # np.savez does not create missing directories.
    if saveFolder:
        os.makedirs(saveFolder, exist_ok=True)

    # save parameters (discount_gamma and epsilon are module-level settings)
    np.savez(saveFolder + 'parameters.npz', map_size=map_size, max_reward_delay=max_reward_delay, max_reward_magnitude=max_reward_magnitude, alpha=alpha, cue_probs=cue_probs, train_timesteps=train_timesteps, test_timesteps=test_timesteps, test_every_n=test_every_n, discount_gamma=discount_gamma, epsilon=epsilon, num_runs=num_runs, num_cues=num_cues)

    # One reward time/magnitude map shared by all runs.
    reward_magnitude_time_matrices, rewarded_positions = generate_grid_reward_magnitude_time_matrices(num_cues, map_size, max_reward_delay, max_reward_magnitude)
    plot_reward_magnitude_time_matrices(reward_magnitude_time_matrices, rewarded_positions, 'reward_time_magnitude_matrices_per_state', saveFolder=saveFolder)

    test_rewards_noTravel = []
    test_rewards_norm = []
    test_rewards_particle = []
    for run_num in range(num_runs):
        print('run number: ', run_num)
        _, r_test_noTravel, r_norm, r_test_norm, r_particle, r_test_particle = train_both(run_num, reward_magnitude_time_matrices, rewarded_positions, cue_probs, map_size, alpha, train_timesteps, test_timesteps, test_every_n)
        test_rewards_noTravel.append(r_test_noTravel)
        test_rewards_norm.append(r_test_norm)
        test_rewards_particle.append(r_test_particle)
        # Legacy layout: the '*_noTravel' keys hold the particle (TMRL) agent's
        # rewards because plot_saved_rewards reads those keys. The explicit
        # '*_particle' keys make the content unambiguous for new readers.
        np.savez(saveFolder + 'rewards_' + str(run_num) + '.npz',
                 rewards_noTravel=np.array(r_particle), test_rewards_noTravel=np.array(r_test_particle),
                 rewards_norm=np.array(r_norm), test_rewards_norm=np.array(r_test_norm),
                 rewards_particle=np.array(r_particle), test_rewards_particle=np.array(r_test_particle))

    # Average test-reward curves over runs.
    plot_average_rewards(np.array(test_rewards_noTravel), np.array(test_rewards_norm), np.array(test_rewards_particle), saveFolder)



# plot average rewards from multiple runs
def plot_average_rewards(rewards_noTravel, rewards_norm, rewards_particle=None, saveFolder=saveFolder):
    """Plot mean +/- one std (over runs, axis 0) of reward curves and save 'average_rewards.png'.

    If ``rewards_particle`` is given, it is drawn as 'particleAgent' and
    ``rewards_noTravel`` is ignored. Otherwise ``rewards_noTravel`` is drawn as
    'noTravelAgent'. ``rewards_norm`` is always drawn as 'normalAgent'.
    The figure is closed after saving so repeated calls do not accumulate open figures.
    """
    if rewards_particle is not None:
        curves = [(rewards_particle, 'particleAgent')]
    else:
        curves = [(rewards_noTravel, 'noTravelAgent')]
    curves.append((rewards_norm, 'normalAgent'))

    fig = plt.figure()
    for data, label in curves:
        # Compute the statistics once per curve instead of once per use.
        mean = np.mean(data, axis=0)
        std = np.std(data, axis=0)
        plt.plot(mean, label=label)
        plt.fill_between(range(len(mean)), mean - std, mean + std, alpha=0.2)
    plt.legend()
    plt.xlabel('time steps')
    plt.ylabel('reward')
    plt.title('Average rewards over time')
    plt.savefig(saveFolder + 'average_rewards.png')
    plt.close(fig)



def plot_saved_rewards():
    """Reload per-run test rewards from ``saveFolder`` and re-plot their averages.

    Reads ``rewards_<run>.npz`` for run indices ``0 .. num_runs-1``; both
    ``saveFolder`` and ``num_runs`` are module-level settings.
    NOTE(review): train_multiple_agents stores the particle (TMRL) agent's
    rewards under the 'test_rewards_noTravel' key, so that curve is labelled
    'noTravelAgent' here.
    """
    test_rewards_noTravel = []
    test_rewards_norm = []
    for run_num in range(num_runs):
        # The context manager closes the underlying .npz archive; loaded arrays stay valid.
        with np.load(saveFolder + 'rewards_' + str(run_num) + '.npz') as data:
            test_rewards_noTravel.append(data['test_rewards_noTravel'])
            test_rewards_norm.append(data['test_rewards_norm'])
    plot_average_rewards(np.array(test_rewards_noTravel), np.array(test_rewards_norm))


# plot average rewards from multiple runs
def plot_average_rewards_with_opt(folder, mean_rew_TMRL, std_rew_TMRL, mean_rew_norm, std_rew_norm, mean_opt):
    """Plot precomputed mean +/- std reward curves of the TMRL and standard RL agents.

    Saves ``<folder>average_rewards_opt.png`` and closes the figure afterwards.
    Inputs go through ``np.asarray``, so plain lists work as well as arrays.

    Args:
        folder: output directory prefix (concatenated with the file name).
        mean_rew_TMRL, std_rew_TMRL: mean and std curves of the TMRL agent.
        mean_rew_norm, std_rew_norm: mean and std curves of the standard agent.
        mean_opt: accepted for compatibility; currently not drawn.
    """
    mean_rew_TMRL = np.asarray(mean_rew_TMRL)
    std_rew_TMRL = np.asarray(std_rew_TMRL)
    mean_rew_norm = np.asarray(mean_rew_norm)
    std_rew_norm = np.asarray(std_rew_norm)

    fig = plt.figure()
    plt.plot(mean_rew_TMRL, label='TMRL')
    plt.plot(mean_rew_norm, label='standard RL')
    plt.fill_between(range(len(mean_rew_TMRL)), mean_rew_TMRL - std_rew_TMRL, mean_rew_TMRL + std_rew_TMRL, alpha=0.2)
    plt.fill_between(range(len(mean_rew_norm)), mean_rew_norm - std_rew_norm, mean_rew_norm + std_rew_norm, alpha=0.2)
    plt.legend()
    plt.xlabel('time steps')
    plt.ylabel('reward')
    plt.title('Average rewards over time')
    plt.savefig(folder + 'average_rewards_opt.png')
    plt.close(fig)


# plot the reward curve of every run in one figure
def plot_rewards_all(folder, num_runs, rewards):
    """Plot ``rewards[i]`` for ``i`` in ``range(num_runs)`` and save ``<folder>rewards_all.png``.

    Args:
        folder: output directory prefix (concatenated with the file name).
        num_runs: number of curves to draw.
        rewards: indexable collection of reward sequences, one per run.
    """
    fig = plt.figure()
    for i in range(num_runs):
        plt.plot(rewards[i], label=str(i))
    plt.legend()
    plt.xlabel('time steps')
    plt.ylabel('reward')
    plt.savefig(folder + 'rewards_all.png')
    # pyplot keeps figures alive until closed explicitly.
    plt.close(fig)

#####################################################################################
def fill_zeros_with_previous(arr):
    """Return a copy of ``arr`` where each zero is replaced by the last non-zero value before it.

    Leading zeros (before any non-zero entry) are left as zeros. The input is
    never modified; a list input gives a list back and an array gives an array.
    """
    filled = arr.copy()  # work on a copy so the caller's data is untouched
    carry = None
    for idx, value in enumerate(filled):
        if value != 0:
            carry = value
            continue
        if carry is not None:
            filled[idx] = carry
    return filled
######################################################################################


def get_mean_std_all_runs(group_folder, experiments, num_runs, agent_names, opt=False, ablation=False, base_folder=None):
    """Load saved per-run rewards and aggregate the test rewards per agent and experiment.

    For every experiment ``exp``, agent ``name`` and run index ``r``, the file
    ``<base_folder><group_folder>_cues_<exp>/rewards_<name>_<r>.npz`` is read.
    It must contain the arrays ``rewards`` and ``test_rewards``.

    Args:
        group_folder: group prefix of the experiment folders.
        experiments: experiment identifiers used to build the folder names.
        num_runs: number of runs per agent and experiment.
        agent_names: agents to load.
        opt, ablation: unused; accepted for call-site compatibility.
        base_folder: prefix prepended to ``group_folder``; ``None`` (default)
            uses the module-level ``plot_folder``.

    Returns:
        ``(mean_agents, std_agents, test_every_n)``:

        * ``mean_agents[name]`` and ``std_agents[name]`` are arrays whose first
          axis follows ``experiments``. They hold the mean / std over runs of
          ``test_rewards``.
        * ``test_every_n`` is the length of the first agent's ``rewards``
          integer-divided by the length of its ``test_rewards``, taken from the
          last experiment.

    Raises:
        ValueError: if ``experiments`` or ``agent_names`` is empty or ``num_runs`` < 1.
    """
    if base_folder is None:
        base_folder = plot_folder
    if len(experiments) == 0 or len(agent_names) == 0 or num_runs < 1:
        raise ValueError('need at least one experiment, one agent and one run')

    mean_agents = {name: [] for name in agent_names}
    std_agents = {name: [] for name in agent_names}
    train_len = test_len = None

    for exp in experiments:
        folder = base_folder + group_folder + '_cues_' + str(exp) + '/'
        for name in agent_names:
            train_runs = []
            test_runs = []
            for run_num in range(num_runs):
                path = folder + 'rewards_' + name + '_' + str(run_num) + '.npz'
                # Closing the archive does not invalidate the arrays already read from it.
                with np.load(path) as data:
                    train_runs.append(data['rewards'])
                    test_runs.append(data['test_rewards'])
            test_arr = np.array(test_runs)
            mean_agents[name].append(np.mean(test_arr, axis=0))
            std_agents[name].append(np.std(test_arr, axis=0))
            if name == agent_names[0]:
                # Lengths used to infer how many training steps separate evaluations.
                train_len = np.array(train_runs).shape[-1]
                test_len = test_arr.shape[-1]

    for name in agent_names:
        mean_agents[name] = np.array(mean_agents[name])
        std_agents[name] = np.array(std_agents[name])

    test_every_n = train_len // test_len
    return mean_agents, std_agents, test_every_n


def comparison_plots(group_folder, experiments, num_runs, agent_names, format='png', opt=False, ablation=False):
    """Plot normalized test-reward curves of every agent and experiment, plus a table legend.

    Two files are written to ``<plot_folder><group_folder>_cues_<experiments[-1]>/``:

    * ``group_normalized_rewards.<format>``: the curves. The x axis is the
      evaluation index times ``test_every_n``.
    * ``group_normalized_rewards_legend.<format>``: the legend table.

    In experiment ``exp``, every curve is divided by the maximum mean reward of
    a reference agent: 'rewardMTAgent' if present, else
    'rewardMTAgent_particle', else the first entry of ``agent_names``.

    Args:
        group_folder: group prefix of the experiment folders.
        experiments: experiment identifiers; one colour shade per entry.
        num_runs: runs per agent and experiment to average.
        agent_names: agents to plot; at most three (one colormap each).
        format: image file extension.
        opt: unused; kept for call-site compatibility.
        ablation: draw the three-column ablation legend (TMRL / time only /
            magnitude only); this requires exactly three agents.

    Raises:
        ValueError: if more than three agents are given, or if ``ablation`` is
            set without exactly three agents.
    """
    mean_agents, std_agents, test_every_n = get_mean_std_all_runs(group_folder, experiments, num_runs, agent_names)

    # One colormap per agent; within it one shade per experiment.
    cmap_agents = [plt.cm.Blues, plt.cm.Reds, plt.cm.Greens]
    if len(agent_names) > len(cmap_agents):
        raise ValueError(f'at most {len(cmap_agents)} agents are supported, got {len(agent_names)}')
    n_exp = len(experiments)
    # Shades run from 1.0 (first experiment, darkest) to 0.5 (last). A single
    # experiment would divide by zero, so it gets the darkest shade.
    if n_exp > 1:
        shades = [(n_exp - i * 0.5 - 1) / (n_exp - 1) for i in range(n_exp)]
    else:
        shades = [1.0] * n_exp
    colors_agents = {name: [cmap_agents[ind](shade) for shade in shades]
                     for ind, name in enumerate(agent_names)}
    # line styles
    linestyles = ['-', '-']

    # Reference agent whose best mean reward normalizes each experiment.
    if 'rewardMTAgent' in agent_names:
        ref_name = 'rewardMTAgent'
    elif 'rewardMTAgent_particle' in agent_names:
        ref_name = 'rewardMTAgent_particle'
    else:
        ref_name = agent_names[0]

    fig, ax = plt.subplots(figsize=(horizontal_size, vertical_size))
    # One dict {agent name: line} per experiment, used to build the legend table.
    legend_handles = []
    for ind in range(n_exp):
        max_mean = np.max(mean_agents[ref_name][ind])
        print('max mean', max_mean)

        legend_agents = {}
        for name in agent_names:
            mean = mean_agents[name][ind]
            std = std_agents[name][ind]
            steps = np.arange(len(mean)) * test_every_n
            line_agent, = ax.plot(steps, mean / max_mean, color=colors_agents[name][ind], lw=0.7)
            ax.fill_between(steps, (mean - std) / max_mean, (mean + std) / max_mean, color=colors_agents[name][ind], alpha=0.2)
            legend_agents[name] = line_agent
        legend_handles.append(legend_agents)

    plt.xlabel('time steps')
    plt.ylabel('normalized reward rate')
    folder = plot_folder + group_folder + '_cues_' + str(experiments[-1]) + '/'
    plt.savefig(folder + 'group_normalized_rewards.' + format)

    # Legend drawn as a table: rows are experiments, columns are agents.
    legend_fig, legend_ax = plt.subplots(figsize=(1.2 * len(agent_names), 2))
    legend_ax.axis('off')

    if ablation:
        if len(agent_names) != 3:
            raise ValueError('the ablation legend requires exactly three agents')
        exp_labels = ['0.05', '0.1', '0.2']
        for row, legend_agents in enumerate(legend_handles):
            # Pull the lines out of the dict in agent order. Unpacking the dict
            # itself would yield its keys (agent-name strings), not the lines.
            line_rtm, line_time_only, line_mag_only = [legend_agents[name] for name in agent_names]
            y = 1 - row * 0.2
            legend_ax.add_line(Line2D([0.2, 0.4], [y, y], color=line_rtm.get_color(), linestyle=linestyles[0], lw=2))
            legend_ax.add_line(Line2D([0.5, 0.7], [y, y], color=line_time_only.get_color(), linestyle=linestyles[1], lw=2))
            legend_ax.add_line(Line2D([0.8, 1.0], [y, y], color=line_mag_only.get_color(), linestyle=linestyles[1], lw=2))
            legend_ax.text(0.0, y, f"{exp_labels[row]}", va='center')

        # Column headers
        legend_ax.text(0.3, 1.05, "TMRL", ha='center', fontweight='bold')
        legend_ax.text(0.6, 1.05, "time only", ha='center', fontweight='bold')
        legend_ax.text(0.9, 1.05, "magnitude only", ha='center', fontweight='bold')
        legend_ax.text(0.04, 1.07, "probability of\nstimuli", ha='center', fontweight='bold')
    else:
        spacing = 0.9 / len(agent_names)
        for row, legend_agents in enumerate(legend_handles):
            y = 1 - row * 0.2
            for col, name in enumerate(agent_names):
                legend_ax.add_line(Line2D([0.15 + col * spacing, spacing + col * spacing], [y, y], color=legend_agents[name].get_color(), linestyle=linestyles[0], lw=2.5))
            legend_ax.text(0.0, y, f"{experiments[row]}", va='center')

        # Column headers: only as many as there are agents.
        headers = ["TMRL", "standard RL", "QR-RL"]
        for col, header in enumerate(headers[:len(agent_names)]):
            legend_ax.text(0.08 + col * spacing + spacing / 2, 1.05, header, ha='center', fontweight='bold')
        legend_ax.text(0.0, 1.1, "number of\nstimuli", ha='center', fontweight='bold')

    legend_ax.set_xlim(0, 1)
    legend_ax.set_ylim(0, 1.2)
    plt.savefig(folder + 'group_normalized_rewards_legend.' + format)
    # Release both figures; pyplot keeps them alive until closed.
    plt.close(legend_fig)
    plt.close(fig)




def across_stimuli_plots(group_folder, experiments, num_runs, agent_names, format='png', ablation=False):
    """Plot each agent's final normalized test reward against the experiment values, plus a legend.

    For each experiment, every agent's last mean test reward (error bar: std
    across runs) is divided by a normalizer. The normalizer is the maximum,
    over the last 100 evaluation points, of the reference agent's mean reward
    in that experiment. The reference agent is 'rewardMTAgent' if present,
    else 'rewardMTAgent_particle', else the first entry of ``agent_names``.

    Two files are written to ``<plot_folder><group_folder>_cues_<experiments[-1]>/``:
    ``group_final_reward_across_stimuli.<format>`` and ``across_stimuli_legend.<format>``.

    Args:
        group_folder: group prefix of the experiment folders.
        experiments: x-axis values (number of stimuli).
        num_runs: runs per agent and experiment to average.
        agent_names: agents to plot; at most three (one colormap each).
        format: image file extension.
        ablation: use the ablation legend labels instead of the default ones.

    Raises:
        ValueError: if more than three agents are given.
    """
    mean_agents, std_agents, _ = get_mean_std_all_runs(group_folder, experiments, num_runs, agent_names)

    # One colormap per agent, all sampled at the same shade.
    cmap_agents = [plt.cm.Blues, plt.cm.Reds, plt.cm.Greens]
    if len(agent_names) > len(cmap_agents):
        raise ValueError(f'at most {len(cmap_agents)} agents are supported, got {len(agent_names)}')
    n_exp = len(experiments)
    # A single experiment would divide by zero; use the darkest shade instead.
    shade = (n_exp - 0.5 - 1) / (n_exp - 1) if n_exp > 1 else 1.0
    colors_agents = {name: cmap_agents[ind](shade) for ind, name in enumerate(agent_names)}

    if 'rewardMTAgent' in agent_names:
        ref_name = 'rewardMTAgent'
    elif 'rewardMTAgent_particle' in agent_names:
        ref_name = 'rewardMTAgent_particle'
    else:
        ref_name = agent_names[0]
    # Per-experiment normalizer: best mean reward of the reference agent over its last 100 evaluations.
    max_mean = np.max(np.array(mean_agents[ref_name])[:, -100:], axis=-1)

    fig, ax = plt.subplots(figsize=(horizontal_size, vertical_size))
    for name in agent_names:
        ax.errorbar(experiments, mean_agents[name][:, -1] / max_mean, yerr=std_agents[name][:, -1] / max_mean, capsize=2, color=colors_agents[name])
    ax.set_xticks(experiments)
    ax.spines['left'].set_linewidth(linewidth)
    ax.spines['bottom'].set_linewidth(linewidth)
    ax.tick_params(width=linewidth, length=length_ticks)

    plt.xlabel('number of stimuli')
    plt.ylabel('normalized reward rate')
    folder = plot_folder + group_folder + '_cues_' + str(experiments[-1]) + '/'
    plt.savefig(folder + 'group_final_reward_across_stimuli.' + format)

    legend_fig, legend_ax = plt.subplots(figsize=(2.4, 2))
    legend_ax.axis('off')

    if ablation:
        names = ['TMRL', 'time only', 'magnitude only']
    else:
        names = ['TMRL', 'standard RL', 'QR-RL']
    # One legend row per agent, drawn in that agent's colour.
    for row, (name, label) in enumerate(zip(agent_names, names)):
        y = 1 - row * 0.3
        ylab = 1 - row * 0.3 + 0.1
        legend_ax.add_line(Line2D([0.2, 0.4], [y, y], color=colors_agents[name], lw=2))
        legend_ax.text(0.2, ylab, f"{label}", va='center', fontweight='bold')

    legend_ax.set_xlim(0, 1)
    legend_ax.set_ylim(0, 1.2)
    plt.savefig(folder + 'across_stimuli_legend.' + format)
    # Release both figures; pyplot keeps them alive until closed.
    plt.close(legend_fig)
    plt.close(fig)
    



def plot_2D_distribution(format='svg'):
    """Save three illustrative 2-D Gaussian density maps, squashed through a steep sigmoid.

    Writes ``2d_distribution_{0,1,2}.<format>`` into ``plot_folder``:

    * 0: sum of the Gaussians centred at (-2, -2) and (2, 2)
    * 1: a single Gaussian at (-2, 2)
    * 2: a single Gaussian at (2, -2)

    All use covariance diag(1, 0.3). They are evaluated on a [-5, 5) x [-5, 5)
    grid with step 0.01. Each figure is closed after saving.
    """
    from scipy.stats import multivariate_normal

    # Define grid
    x, y = np.mgrid[-5:5:.01, -5:5:.01]
    pos = np.dstack((x, y))
    cov = [[1, 0], [0, 0.3]]

    def sigmoid(x):
        # Steep logistic centred at half of the map's maximum value.
        max_x = np.max(x)
        return 1.0 / (1 + np.exp(-50 * (x - 0.5 * max_x)))

    densities = [
        multivariate_normal(mean=[-2, -2], cov=cov).pdf(pos) + multivariate_normal(mean=[2, 2], cov=cov).pdf(pos),
        multivariate_normal(mean=[-2, 2], cov=cov).pdf(pos),
        multivariate_normal(mean=[2, -2], cov=cov).pdf(pos),
    ]
    print('max', np.max(densities[0]))

    cmap = 'coolwarm'
    for idx, density in enumerate(densities):
        fig = plt.figure(figsize=(3, 3))
        plt.contourf(x, y, sigmoid(density), levels=300, cmap=cmap)
        plt.axis('off')
        plt.savefig(plot_folder + '2d_distribution_' + str(idx) + '.' + format)
        plt.close(fig)





# if __name__ == '__main__':

#     # create folder if it doesn't already exist
#     if os.path.exists(saveFolder):
#         print('overwrite')
#     else:
#         os.makedirs(saveFolder)

#     # train_just_norm()
#     train_multiple_agents(num_runs, num_cues=num_cues, map_size=map_size, max_reward_delay=max_reward_delay, max_reward_magnitude=max_reward_magnitude, alpha=alpha, cue_probs=cue_probs, train_timesteps=train_timesteps, test_timesteps=test_timesteps, test_every_n=test_every_n)
#     # plot_saved_rewards()
#     # cue_nums_to_plot = np.arange(2,4)    
#     # plot_different_cue_numbers(cue_nums_to_plot)

# if __name__ == '__main__':

#     group_folder = 'gridworld_group_14'
#     experiments = [3,4, 5]
#     comparison_plots(group_folder, experiments, num_runs, format='svg', opt=False)
#     across_stimuli_plots(group_folder, experiments, num_runs, format='svg')

    # plot_2D_distribution()



if __name__ == '__main__':
    # Regenerate the group comparison figures from rewards already saved on disk.
    group_folder = 'gridworld_particle_group'
    experiments = [3, 4, 5]
    n_runs = 10
    agents = ['rewardMTAgent_particle', 'normalAgent', 'QuantileRLAgent']
    comparison_plots(group_folder, experiments, n_runs, agents, format='svg', opt=False)
    across_stimuli_plots(group_folder, experiments, n_runs, agents, format='svg')