'''
Use case for reward time and magnitude maps
Each reward time and magnitude map describes the reward that will be given
in each state.
We assume here that you can reach any state from any other state with no travel time.
'''

import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
from matplotlib.lines import Line2D
import os, json
from dataclasses import dataclass, asdict

from no_travel_time_cue_reward_time_and_mag import noTravelTimeEnv
from no_travel_time_cue_reward_time_and_mag import noTravelAgent
from no_travel_time_cue_reward_time_and_mag import particle_noTravelAgent
from no_travel_time_cue_reward_time_and_mag import normalAgent
from no_travel_time_cue_reward_time_and_mag import noTravelAgent_1D_time_dist
from no_travel_time_cue_reward_time_and_mag import noTravelAgent_1D_magnitude_dist
from no_travel_time_cue_reward_time_and_mag import particle_time_only_agent, particle_mag_only_agent
from no_travel_time_cue_reward_time_and_mag import train_agent
from no_travel_time_cue_reward_time_and_mag import generate_reward_magnitude_time_matrices, generate_reward_magnitude_time_matrices_difficult, plot_reward_magnitude_time_matrices, plot_rewards, plot_value_function, plot_1D_matrix_and_expected_values
# from gridworld_group_plots import comparison_plots, across_stimuli_plots


# Prefix prepended to '<group_folder>_cues_<experiment>/' when experiment
# folders are resolved by the group-plot functions below.
plot_folder = ''


# Parameters for plots
font_size = 9
linewidth = 1.2
length_ticks = 2
scatter_size = 20
# Figure width / height handed to plt.subplots(figsize=...) for the group plots.
horizontal_size = 1.5
vertical_size = 1.5

# Shared matplotlib style: small tick labels, no top/right spines.
mpl.rcParams.update({
    'font.size': font_size,
    'lines.linewidth': linewidth,
    'xtick.labelsize': font_size - 5,
    'ytick.labelsize': font_size - 5,
    'axes.spines.right': False,
    'axes.spines.top': False,
    'axes.titlesize': font_size - 2,
    'legend.fontsize': font_size - 2,
})

def set_train_test_env(reward_magnitude_time_matrices, cue_probs, num_states):
    """Create a training and a test environment with identical settings.

    Both environments are separate ``noTravelTimeEnv`` instances built from
    the same arguments; only the training environment is reset here.

    Returns
    -------
    tuple
        ``(env, test_env, state, cue, reward)`` where ``state`` and ``cue``
        come from ``env.reset()`` and ``reward`` is the initial reward ``0``.
    """
    env, test_env = (
        noTravelTimeEnv(num_states, reward_magnitude_time_matrices, cue_probs)
        for _ in range(2)
    )
    # Only the training environment provides the starting observation.
    state, cue = env.reset()
    return env, test_env, state, cue, 0


# run agent in environment
# def train_both(run_num, reward_magnitude_time_matrices, cue_probs, num_states, alpha, train_timesteps, test_timesteps, test_every_n, include_1D= False):
def train_both(run_num, reward_magnitude_time_matrices, config):
    """Train the agents of one run on a shared set of reward maps.

    The no-travel-time agent (particle variant when ``config.particle_agent``
    is set) and the normal agent are always trained. When
    ``config.include_1D`` is set, the time-only and magnitude-only particle
    agents are trained as well. Every agent gets a freshly built train/test
    environment pair and is handed the same zero-initialised reward map
    array. Diagnostic figures are written to ``config.saveFolder`` with
    ``run_num`` appended to the file names.

    Parameters
    ----------
    run_num
        Index of the run; only used in output file names.
    reward_magnitude_time_matrices
        True reward maps. Axis 1 is taken as the number of states; the last
        two axes are weighted by delay and magnitude respectively when the
        true expected values are computed.
    config
        Object exposing ``cue_probs``, ``alpha``, ``train_timesteps``,
        ``test_timesteps``, ``test_every_n``, ``include_1D``,
        ``particle_agent``, ``saveFolder``, ``max_reward_delay``,
        ``max_reward_magnitude`` and (non-particle agent only) ``risk_weight``.

    Returns
    -------
    tuple
        ``(train, test)`` reward pairs for the no-travel, normal, time-only
        and magnitude-only agents, in that order (eight values). The last
        four are ``0`` when ``config.include_1D`` is false.
    """
    cue_probs = config.cue_probs
    alpha = config.alpha
    schedule = (config.train_timesteps, config.test_timesteps, config.test_every_n)
    include_1D = config.include_1D
    use_particles = config.particle_agent
    saveFolder = config.saveFolder
    max_reward_delay = config.max_reward_delay
    max_reward_magnitude = config.max_reward_magnitude
    suffix = str(run_num)

    # The state count is taken from the reward maps themselves.
    num_states = reward_magnitude_time_matrices.shape[1]
    # All agents start from (and are handed) this same all-zero array.
    init_reward_MT = np.zeros(reward_magnitude_time_matrices.shape)

    def run(build_agent):
        """Build fresh environments, create the agent and train it."""
        env, test_env, state, cue, reward = set_train_test_env(
            reward_magnitude_time_matrices, cue_probs, num_states)
        learner = build_agent(state)
        train_rew, test_rew = train_agent(
            learner, env, test_env, *schedule, state, cue, reward)
        return learner, train_rew, test_rew

    def build_main(state):
        if use_particles:
            return particle_noTravelAgent(num_states, state, init_reward_MT, alpha)
        return noTravelAgent(num_states, state, init_reward_MT, alpha,
                             risk_weight=config.risk_weight)

    def build_normal(state):
        return normalAgent(num_states, state, init_reward_MT, alpha)

    def build_time_only(state):
        return particle_time_only_agent(num_states, state, init_reward_MT, alpha)

    def build_mag_only(state):
        return particle_mag_only_agent(num_states, state, init_reward_MT, alpha)

    # --- no travel time agent ---
    agent, rewards_no_travel, test_rew_noTravel = run(build_main)
    plot_rewards(saveFolder, test_rew_noTravel, 'test_rewards_no_travel_' + suffix)
    print('Average reward over time noTravelAgent: ', np.mean(rewards_no_travel))

    # --- normal agent ---
    normAgent, rewards_normAgent, test_rew_norm = run(build_normal)
    print('Average reward over time normalAgent: ', np.mean(rewards_normAgent))

    # Placeholders returned when the 1D (ablated) agents are not trained.
    rewards_timeAgent = test_rew_time = rewards_magAgent = test_rew_mag = 0

    if include_1D:
        # --- time distribution only agent ---
        timeAgent, rewards_timeAgent, test_rew_time = run(build_time_only)
        print('Average reward over time timeAgent: ', np.mean(rewards_timeAgent))
        plot_rewards(saveFolder, test_rew_time, 'test_rewards_timeAgent' + '_' + suffix)

        # --- magnitude distribution only agent ---
        magAgent, rewards_magAgent, test_rew_mag = run(build_mag_only)
        print('Average reward over time magAgent: ', np.mean(rewards_magAgent))

        # True expected magnitude: weight the last axis by 1..max_reward_magnitude.
        magnitude_weights = np.arange(max_reward_magnitude) + 1
        true_magnitude_expected_values = np.sum(
            reward_magnitude_time_matrices * magnitude_weights, axis=(-1, -2))
        plot_1D_matrix_and_expected_values(
            saveFolder, true_magnitude_expected_values, timeAgent.reward_MT,
            timeAgent.expected_return, 'learned_time_delay_vector_' + suffix)

        # True expected delay: weight the second-to-last axis by 1..max_reward_delay.
        delay_weights = (np.arange(max_reward_delay) + 1)[:, None]
        true_delay_expected_values = np.sum(
            reward_magnitude_time_matrices * delay_weights, axis=(-1, -2))
        plot_1D_matrix_and_expected_values(
            saveFolder, true_delay_expected_values, magAgent.reward_MT,
            magAgent.expected_delay, 'learned_magnitude_vector_' + suffix)

    plot_reward_magnitude_time_matrices(
        saveFolder, agent.reward_MT,
        'learned_reward_time_magnitude_matrices_per_state' + '_' + suffix)
    plot_value_function(saveFolder, normAgent.value, 'learned_value_function' + '_' + suffix)

    return (rewards_no_travel, test_rew_noTravel,
            rewards_normAgent, test_rew_norm,
            rewards_timeAgent, test_rew_time,
            rewards_magAgent, test_rew_mag)


# def train_multiple_agents(num_runs, num_cues, num_states, max_reward_delay, max_reward_magnitude, alpha, cue_probs, train_timesteps, test_timesteps, test_every_n, include_1D=False):
def train_multiple_agents(generate_rew_mag_time_mat, config):
    """Run ``config.num_runs`` independent training runs on one set of reward maps.

    The configuration is written to ``parameters.json`` in
    ``config.saveFolder``. One set of reward maps is then generated with
    ``generate_rew_mag_time_mat`` and plotted, and every run trains its
    agents via ``train_both``. The rewards of each run are stored in
    ``rewards_<run>.npz``, and the averaged test rewards are plotted at the end.

    Parameters
    ----------
    generate_rew_mag_time_mat
        Callable invoked as ``f(num_cues, num_states, max_reward_delay,
        max_reward_magnitude)`` that returns the reward maps.
    config
        Dataclass instance (it is serialised with ``dataclasses.asdict``)
        holding the experiment settings.
    """
    map_args = (config.num_cues, config.num_states,
                config.max_reward_delay, config.max_reward_magnitude)
    saveFolder = config.saveFolder

    # Persist the parameters next to the results.
    with open(saveFolder + "parameters.json", "w") as params_file:
        json.dump(asdict(config), params_file, indent=2)

    # The same reward maps are shared by all runs.
    reward_magnitude_time_matrices = generate_rew_mag_time_mat(*map_args)
    print('reward_magnitude_time_matrices', reward_magnitude_time_matrices.shape)
    plot_reward_magnitude_time_matrices(
        saveFolder, reward_magnitude_time_matrices,
        'reward_time_magnitude_matrices_per_state')

    # Test-reward curves per agent, one entry per run.
    test_curves = {'noTravel': [], 'norm': [], 'time_only': [], 'mag_only': []}

    for run_num in range(config.num_runs):
        print('run number: ', run_num)
        (r_noTravel, r_test_noTravel,
         r_norm, r_test_norm,
         r_time, r_test_time,
         r_mag, r_test_mag) = train_both(run_num, reward_magnitude_time_matrices, config)

        test_curves['noTravel'].append(r_test_noTravel)
        test_curves['norm'].append(r_test_norm)
        if config.include_1D:
            test_curves['time_only'].append(r_test_time)
            test_curves['mag_only'].append(r_test_mag)

        # Save everything this run produced (ablation entries are 0 when not trained).
        run_results = {
            'rewards_noTravel': r_noTravel,
            'test_rewards_noTravel': r_test_noTravel,
            'rewards_norm': r_norm,
            'test_rewards_norm': r_test_norm,
            'rewards_time_only': r_time,
            'test_rewards_time_only': r_test_time,
            'rewards_mag_only': r_mag,
            'test_rewards_mag_only': r_test_mag,
        }
        np.savez(saveFolder + 'rewards_' + str(run_num) + '.npz', **run_results)

    # Average test rewards across runs; ablation curves only when they were trained.
    if config.include_1D:
        ablation_curves = (np.array(test_curves['time_only']),
                           np.array(test_curves['mag_only']))
    else:
        ablation_curves = (None, None)
    plot_average_rewards(saveFolder,
                         np.array(test_curves['noTravel']),
                         np.array(test_curves['norm']),
                         *ablation_curves)


# plot average rewards from multiple runs
def plot_average_rewards(saveFolder, rewards_noTravel, rewards_norm, rewards_time=None, rewards_mag=None, labels=None):
    """Plot mean +/- one standard deviation of the rewards across runs.

    Each ``rewards_*`` argument is reduced with ``np.mean`` / ``np.std`` over
    axis 0, so axis 0 is expected to index the runs. The figure is saved as
    ``average_rewards.png`` under ``saveFolder`` and left open.

    Parameters
    ----------
    saveFolder : str
        Path prefix the file name is appended to.
    rewards_noTravel, rewards_norm
        Rewards of the two agents that are always plotted.
    rewards_time, rewards_mag
        Optional rewards of the two ablated agents. They are handled
        independently: each one is drawn if it is given, so supplying only
        one of them no longer fails or gets ignored.
    labels : sequence of str, optional
        Legend labels in the order (noTravel, norm, time, mag). Entries for
        series that are not plotted are never accessed.
    """
    if labels is None:
        labels = ['TMRL', 'standard RL', 'mag-ablated', 'time-ablated']

    plt.figure()

    # (data, index into labels) for every series that was supplied.
    series = [(rewards_noTravel, 0), (rewards_norm, 1)]
    if rewards_time is not None:
        series.append((rewards_time, 2))
    if rewards_mag is not None:
        series.append((rewards_mag, 3))

    for rewards, label_index in series:
        # Compute the statistics once per series.
        mean = np.mean(rewards, axis=0)
        std = np.std(rewards, axis=0)
        line, = plt.plot(mean, label=labels[label_index])
        # Tie the band colour to its line explicitly so they always match.
        plt.fill_between(range(len(mean)), mean - std, mean + std,
                         facecolor=line.get_color(), alpha=0.2)

    plt.legend()
    plt.xlabel('time steps')
    plt.ylabel('reward')
    plt.title('Average rewards over time')
    plt.savefig(saveFolder + 'average_rewards.png')


def plot_saved_rewards(saveFolder=None, num_runs=None):
    """Re-plot the average test rewards from the ``rewards_<run>.npz`` files.

    Parameters
    ----------
    saveFolder : str, optional
        Path prefix of the folder holding the ``rewards_<run>.npz`` files.
    num_runs : int, optional
        Number of run files to load (runs ``0 .. num_runs - 1``).

    For backward compatibility with the original zero-argument call, an
    omitted argument is looked up as a module-level global of the same name.

    Raises
    ------
    NameError
        If a value is neither passed nor available as a module global.
    """
    module_globals = globals()
    if saveFolder is None:
        saveFolder = module_globals.get('saveFolder')
    if num_runs is None:
        num_runs = module_globals.get('num_runs')
    if saveFolder is None or num_runs is None:
        raise NameError(
            "plot_saved_rewards requires 'saveFolder' and 'num_runs': pass them "
            "as arguments or define module-level globals with these names")

    test_rewards_noTravel = []
    test_rewards_norm = []
    for run_num in range(num_runs):
        # Close each archive as soon as the needed arrays have been read.
        with np.load(saveFolder + 'rewards_' + str(run_num) + '.npz') as data:
            test_rewards_noTravel.append(data['test_rewards_noTravel'])
            test_rewards_norm.append(data['test_rewards_norm'])

    plot_average_rewards(saveFolder, np.array(test_rewards_noTravel), np.array(test_rewards_norm))



def get_mean_std_all_runs(group_folder, experiments, num_runs, opt=False, ablation=False):
    """Load the saved test rewards of several experiments and average over runs.

    For every entry ``exp`` of ``experiments`` the files
    ``rewards_<run>.npz`` (``run`` in ``0 .. num_runs - 1``) are read from
    ``plot_folder + group_folder + '_cues_' + str(exp) + '/'``. The mean and
    standard deviation are taken across runs (axis 0).

    Parameters
    ----------
    group_folder : str
        Common prefix of the experiment folders.
    experiments : sequence
        Experiment identifiers; ``str(exp)`` forms the folder suffix.
    num_runs : int
        Number of run files per experiment.
    opt : bool
        Also load ``test_opt_rewards`` from ``rewards_optimal.npz``.
    ablation : bool
        Also load the time-only and magnitude-only test rewards.

    Returns
    -------
    tuple of np.ndarray
        Every array has one entry per experiment.

        * ``opt``: ``(mean_RTM, std_RTM, mean_norm, std_norm, mean_opt)``
          (takes precedence over ``ablation``);
        * ``ablation``: ``(mean_RTM, std_RTM, mean_norm, std_norm,
          mean_time_only, std_time_only, mean_mag_only, std_mag_only)``;
        * otherwise: ``(mean_RTM, std_RTM, mean_norm, std_norm)``.
    """
    # Keys of the .npz archives that are actually needed, in output order.
    keys = ['test_rewards_noTravel', 'test_rewards_norm']
    if ablation:
        keys += ['test_rewards_time_only', 'test_rewards_mag_only']

    means = {key: [] for key in keys}
    stds = {key: [] for key in keys}
    mean_opt = []

    for exp in experiments:
        folder = plot_folder + group_folder + '_cues_' + str(exp) + '/'

        per_run = {key: [] for key in keys}
        for run_num in range(num_runs):
            # The context manager releases the file handle of every archive.
            with np.load(folder + 'rewards_' + str(run_num) + '.npz') as data:
                for key in keys:
                    per_run[key].append(data[key])

        if opt:
            # load optimal agent rewards
            with np.load(folder + 'rewards_optimal.npz') as opt_data:
                mean_opt.append(np.array(opt_data['test_opt_rewards']))

        for key in keys:
            stacked = np.array(per_run[key])
            means[key].append(np.mean(stacked, axis=0))
            stds[key].append(np.std(stacked, axis=0))

    base = (np.array(means['test_rewards_noTravel']), np.array(stds['test_rewards_noTravel']),
            np.array(means['test_rewards_norm']), np.array(stds['test_rewards_norm']))
    if opt:
        return base + (np.array(mean_opt),)
    if ablation:
        return base + (
            np.array(means['test_rewards_time_only']), np.array(stds['test_rewards_time_only']),
            np.array(means['test_rewards_mag_only']), np.array(stds['test_rewards_mag_only']))
    return base


def _resolve_test_every_n(folder, default=1):
    """Best-effort lookup of the number of time steps between test evaluations.

    Order: a module-level ``test_every_n`` global (legacy behaviour), then the
    ``test_every_n`` entry of ``parameters.json`` in ``folder`` (the file
    ``train_multiple_agents`` writes next to the reward archives), then
    ``default``.
    """
    legacy_value = globals().get('test_every_n')
    if legacy_value is not None:
        return legacy_value
    try:
        with open(folder + 'parameters.json') as params_file:
            value = json.load(params_file)['test_every_n']
    except (OSError, ValueError, KeyError, TypeError):
        return default
    return value if isinstance(value, (int, float)) else default


def _plot_normalized_band(ax, mean, std, scale, step, color):
    """Draw ``mean / scale`` with a +/- ``std / scale`` band and return the line."""
    x = np.arange(len(mean)) * step
    line, = ax.plot(x, mean / scale, color=color)
    ax.fill_between(x, (mean - std) / scale, (mean + std) / scale, color=color, alpha=0.2)
    return line


def comparison_plots(group_folder, experiments, num_runs, format='png', opt=False, ablation=False, test_every_n=None):
    """Plot normalized test-reward curves of several experiments in one figure.

    Every curve of an experiment is divided by the maximum of that
    experiment's mean TMRL curve. Without ``ablation`` the TMRL and standard
    RL curves are drawn; with ``ablation`` the TMRL, time-only and
    magnitude-only curves are drawn. Two files are written into the folder of
    the last experiment: ``group_normalized_rewards.<format>`` and a separate
    legend figure ``group_normalized_rewards_legend.<format>``.

    Parameters
    ----------
    group_folder, experiments, num_runs
        Forwarded to ``get_mean_std_all_runs``.
    format : str
        File extension passed to ``savefig``.
    opt : bool
        Load the optimal-agent rewards as well (they are not drawn).
    ablation : bool
        Draw the ablated agents instead of the standard RL agent.
    test_every_n : number, optional
        Time steps between two test evaluations, used to scale the x axis.
        When omitted it is resolved per experiment by
        ``_resolve_test_every_n`` (falling back to 1).

    Raises
    ------
    ValueError
        If ``opt`` and ``ablation`` are both set; the loader returns no
        ablation statistics in that case.
    """
    if opt and ablation:
        raise ValueError("comparison_plots: 'opt' and 'ablation' cannot be combined")

    if opt:
        mean_RTM, std_RTM, mean_norm, std_norm, _mean_opt = get_mean_std_all_runs(group_folder, experiments, num_runs, opt=True)
    elif ablation:
        mean_RTM, std_RTM, mean_norm, std_norm, mean_time_only, std_time_only, mean_mag_only, std_mag_only = get_mean_std_all_runs(group_folder, experiments, num_runs, opt=False, ablation=True)
    else:
        mean_RTM, std_RTM, mean_norm, std_norm = get_mean_std_all_runs(group_folder, experiments, num_runs)

    print('mean RTM shape',len(mean_RTM))

    # One shade per experiment, from 1.0 (darkest) downwards in steps of
    # 0.5 / (n - 1). A single experiment gets the darkest shade instead of
    # dividing by zero.
    n_exp = len(experiments)
    shades = [1.0 if n_exp == 1 else (n_exp - i * 0.5 - 1) / (n_exp - 1) for i in range(n_exp)]
    colors_RTM = [plt.cm.Blues(s) for s in shades]
    colors_norm = [plt.cm.Reds(s) for s in shades]
    if ablation:
        colors_time_only = [plt.cm.Greens(s) for s in shades]
        colors_mag_only = [plt.cm.Oranges(s) for s in shades]
    print('colors rtm',colors_RTM,colors_norm)

    fig, ax = plt.subplots(figsize=(horizontal_size, vertical_size))
    # One tuple of line handles per experiment, used to colour the legend.
    legend_handles = []
    for ind, exp in enumerate(experiments):
        max_mean = np.max(mean_RTM[ind])
        print('max mean',max_mean)

        if test_every_n is None:
            step = _resolve_test_every_n(plot_folder + group_folder + '_cues_' + str(exp) + '/')
        else:
            step = test_every_n

        line_rtm = _plot_normalized_band(ax, mean_RTM[ind], std_RTM[ind], max_mean, step, colors_RTM[ind])
        if ablation:
            line_time_only = _plot_normalized_band(ax, mean_time_only[ind], std_time_only[ind], max_mean, step, colors_time_only[ind])
            # The band uses the magnitude-only std on both edges.
            line_mag_only = _plot_normalized_band(ax, mean_mag_only[ind], std_mag_only[ind], max_mean, step, colors_mag_only[ind])
            legend_handles.append((line_rtm, line_time_only, line_mag_only))
        else:
            line_norm = _plot_normalized_band(ax, mean_norm[ind], std_norm[ind], max_mean, step, colors_norm[ind])
            legend_handles.append((line_rtm, line_norm))

    ax.set_xlabel('time steps')
    ax.set_ylabel('normalized reward rate')
    folder = plot_folder + group_folder + '_cues_' + str(experiments[-1]) +'/'
    fig.savefig(folder + 'group_normalized_rewards.'+format)

    # The legend is drawn by hand as a small table in its own figure:
    # one row per experiment, one column per agent.
    if ablation:
        legend_fig, legend_ax = plt.subplots(figsize=(4.8, 4))
        column_spans = [(0.2, 0.4), (0.5, 0.7), (0.8, 1.0)]
        headers = [(0.3, 1.05, "TMRL"),
                   (0.6, 1.05, "time only"),
                   (0.9, 1.05, "magnitude only"),
                   (0.04, 1.07, "probability of\nstimuli")]
        # Preset labels of the ablation experiments; fall back to the
        # experiment identifiers when there are more rows than presets.
        exp_labels = ['0.05','0.1','0.2']
        if n_exp <= len(exp_labels):
            row_labels = exp_labels[:n_exp]
        else:
            row_labels = [str(exp) for exp in experiments]
    else:
        legend_fig, legend_ax = plt.subplots(figsize=(2.4, 2))
        column_spans = [(0.2, 0.4), (0.6, 0.8)]
        headers = [(0.3, 1.05, "TMRL"),
                   (0.7, 1.05, "standard RL"),
                   (0.0, 1.1, "number of\nstimuli")]
        row_labels = [f"{exp}" for exp in experiments]
    legend_ax.axis('off')

    for row, lines in enumerate(legend_handles):
        print('line_rtm',lines[0])
        y = 1 - row * 0.2
        for (x_start, x_end), line in zip(column_spans, lines):
            legend_ax.add_line(Line2D([x_start, x_end], [y, y], color=line.get_color(), linestyle='-', lw=2))
        legend_ax.text(0.0, y, row_labels[row], va='center')

    for x, y, text in headers:
        legend_ax.text(x, y, text, ha='center', fontweight='bold')

    legend_ax.set_xlim(0, 1)
    legend_ax.set_ylim(0, 1.2)
    legend_fig.savefig(folder + 'group_normalized_rewards_legend.'+format)


def across_stimuli_plots(group_folder, experiments, num_runs, format='png', ablation=False):
    """Plot the final normalized test reward of each agent against the experiment.

    For every experiment the last point of the mean test-reward curve
    (error bar: last point of the std curve) is divided by the maximum of
    that experiment's mean TMRL curve. That is the same normalization
    ``comparison_plots`` uses. Two files are written into the folder of the
    last experiment: ``group_final_reward_across_stimuli.<format>`` and the
    legend figure ``across_stimuli_legend.<format>``.

    Parameters
    ----------
    group_folder, experiments, num_runs
        Forwarded to ``get_mean_std_all_runs``; ``experiments`` also provides
        the x positions of the error bars.
    format : str
        File extension passed to ``savefig``.
    ablation : bool
        Draw the time-only and magnitude-only agents instead of standard RL.
    """
    if ablation:
        mean_RTM, std_RTM, mean_norm, std_norm, mean_time_only, std_time_only, mean_mag_only, std_mag_only = get_mean_std_all_runs(group_folder, experiments, num_runs, opt=False, ablation=True)
    else:
        mean_RTM, std_RTM, mean_norm, std_norm = get_mean_std_all_runs(group_folder, experiments, num_runs)

    # Single shade per agent: the second step of the per-experiment colour
    # ramp used in comparison_plots. A lone experiment would divide by zero,
    # so it gets the mid shade directly.
    n_exp = len(experiments)
    shade = (n_exp - 1.5) / (n_exp - 1) if n_exp > 1 else 0.5
    colors_RTM = plt.cm.Blues(shade)
    colors_norm = plt.cm.Reds(shade)
    if ablation:
        colors_time_only = plt.cm.Greens(shade)
        colors_mag_only = plt.cm.Oranges(shade)

    fig, ax = plt.subplots(figsize=(horizontal_size, vertical_size))
    # Per-experiment maximum of the mean TMRL curve. The previous
    # `mean_RTM[-100:]` sliced the experiment axis, which is only a no-op
    # for fewer than 100 experiments.
    max_mean = np.max(mean_RTM, axis=-1)

    def final_errorbar(mean, std, color):
        """Error bar of the last test point of every experiment, normalized."""
        return ax.errorbar(experiments, mean[:, -1] / max_mean, yerr=std[:, -1] / max_mean, capsize=2, color=color)

    line_rtm = final_errorbar(mean_RTM, std_RTM, colors_RTM)
    if ablation:
        final_errorbar(mean_time_only, std_time_only, colors_time_only)
        final_errorbar(mean_mag_only, std_mag_only, colors_mag_only)
        names = ['TMRL', 'time only', 'magnitude only']
        colors = [colors_RTM, colors_time_only, colors_mag_only]
    else:
        final_errorbar(mean_norm, std_norm, colors_norm)
        names = ['TMRL', 'standard RL']
        colors = [colors_RTM, colors_norm]

    ax.spines['left'].set_linewidth(linewidth)
    ax.spines['bottom'].set_linewidth(linewidth)
    ax.tick_params(width=linewidth, length=length_ticks)
    print('line_rtm', line_rtm[0])

    ax.set_xlabel('number of stimuli')
    ax.set_ylabel('normalized reward rate')
    folder = plot_folder + group_folder + '_cues_' + str(experiments[-1]) +'/'
    fig.savefig(folder + 'group_final_reward_across_stimuli.'+format)

    # Hand-drawn legend in its own figure: one coloured segment per agent
    # with the agent name just above it.
    legend_fig, legend_ax = plt.subplots(figsize=(2.4, 2))
    legend_ax.axis('off')
    for row, (name, color) in enumerate(zip(names, colors)):
        y = 1 - row * 0.3
        legend_ax.add_line(Line2D([0.2, 0.4], [y, y], color=color, lw=2))
        legend_ax.text(0.2, y + 0.1, name, va='center', fontweight='bold')

    legend_ax.set_xlim(0, 1)
    legend_ax.set_ylim(0, 1.2)
    legend_fig.savefig(folder + 'across_stimuli_legend.'+format)
    



# if __name__ == '__main__':

#     # create folder if it doesn't already exist
#     if os.path.exists(saveFolder):
#         print('overwrite')
#     else:
#         os.makedirs(saveFolder)

#     train_multiple_agents(num_runs, num_cues=num_cues, num_states=num_states, max_reward_delay=max_reward_delay, max_reward_magnitude=max_reward_magnitude, alpha=alpha, cue_probs=cue_probs, train_timesteps=train_timesteps, test_timesteps=test_timesteps, test_every_n=test_every_n, include_1D=include_1D)
#     # plot_saved_rewards()
#     # cue_nums_to_plot = np.arange(2,4)    
#     # plot_different_cue_numbers(cue_nums_to_plot)


if __name__ == '__main__':

    # Regenerate both group figures as SVG for one experiment group,
    # averaging over 10 runs per experiment.
    group_folder = 'particle_patch_group_3'
    experiments = [2,3,4,5]
    for make_plot in (comparison_plots, across_stimuli_plots):
        make_plot(group_folder, experiments, num_runs=10, format='svg')



# if __name__ == '__main__':

#     group_folder = 'particle_patch_group_7'
#     experiments = ['3_p_005', '3_p_01', '3_p_02']#, '5_p_03']
#     comparison_plots(group_folder, experiments, num_runs, format='svg', opt=False, ablation=True)
#     across_stimuli_plots(group_folder, experiments, num_runs, format='svg', ablation=True)

