'''
Use case for reward time and magnitude maps
Each reward time-and-magnitude map describes the probability of the reward that will be given
in each state.
We use a square gridworld environment.
Here, we try to optimize for homeostatic losses.
'''

import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
from matplotlib.lines import Line2D
import os,  json
from dataclasses import dataclass, asdict

from gridwork_reward_time_and_mag import gridWorldEnv, random_position#, train_agent
from gridwork_reward_time_and_mag import rewardMTAgent, rewardMTAgent_particle, homeostatic_agent, normalAgent, homeostatic_agent_particle
from gridwork_reward_time_and_mag import generate_grid_reward_magnitude_time_matrices, plot_reward_magnitude_time_matrices, plot_rewards, plot_value_function, plot_reward_MT_positions
from patch_risk_sensitivity import plot_average_rewards, plot_reward_histograms
from gridwork_reward_time_and_mag import plot_policy


# Parameters for plots.
# Each name is assigned exactly once; the former duplicate assignments
# (`length_ticks` twice, `scatter_size = 2` overwritten by 20) were dead code.
length_ticks = 2
font_size = 9
linewidth = 1.2
scatter_size = 20
# Figure width / height (inches) passed as `figsize` by the plotting functions.
horizontal_size = 1.5
vertical_size = 1.5
# Global matplotlib style shared by every figure produced by this module.
mpl.rcParams.update({
    'font.size': font_size,
    'lines.linewidth': linewidth,
    'xtick.labelsize': font_size - 5,
    'ytick.labelsize': font_size - 5,
    'axes.spines.right': False,
    'axes.spines.top': False,
    'axes.titlesize': font_size - 2,
    'legend.fontsize': font_size - 2,
})


# Output location: `saveFolder` is the prefix prepended to every output file name.
plot_folder = ''
experiment_name = 'homeostatic_gridworld_group_cues_2/'
saveFolder = f'{plot_folder}{experiment_name}'


# Seed NumPy's global random generator.
np.random.seed(42)
# Upper bounds on reward delay and reward magnitude.
max_reward_delay = 10
max_reward_magnitude = 4

# DNL learning parameters
# Regular (delay, magnitude) grid used to compute the gradient.
dx_der = 0.1
dy_der = 0.1
x_der, y_der = np.mgrid[
    -1:(max_reward_delay + 1):dx_der,
    -1:(max_reward_magnitude + 1):dy_der
]
Nx_der, Ny_der = x_der.shape[0], y_der.shape[1]
# Column vectors of the grid coordinates, stacked into an (N, 2) point list.
x_der_flat = x_der.flatten()[:, np.newaxis]
y_der_flat = y_der.flatten()[:, np.newaxis]
particles_der = np.hstack((x_der_flat, y_der_flat))
n_interactions = 100
batch_size = 100
bw = 1.0
lamb = 0.8 * np.max([max_reward_delay, max_reward_magnitude])
particle_gamma = 1000.0
# Learning-rate schedule: linear decay over the first 5000 entries, then constant.
learning_rates = np.full(10000, 0.0001)
learning_rates[:5000] = np.linspace(0.001, 0.0001, 5000)
n_particles = 10 * 10
cov = 0.5 * np.eye(2)
# Covariance shrink schedule: linear decay over the first `n_interactions`
# entries, then constant.
cov_shrink = np.full(n_interactions * 1000, 0.01)
cov_shrink[:n_interactions] = np.linspace(1.0, 0.01, n_interactions)


def internal_state_update(internal_state):
    """Decay the internal state by one time step and cap the result at 2.0.

    Args:
        internal_state: Current internal state, either a Python/NumPy float
            scalar or a NumPy array of states.

    Returns:
        The state multiplied by the decay factor 0.8 and clipped from above
        at 2.0. Scalar inputs yield a float; array inputs yield a new array
        (the input is never modified in place).
    """
    gamma = 0.8
    ceiling = 2.0
    next_internal_state = internal_state * gamma
    # isinstance (rather than an identity test against the `float` type
    # object, which can never be true for a value) so that scalar inputs
    # really take the scalar branch and come back as plain floats.
    if isinstance(next_internal_state, (int, float)):
        return min(next_internal_state, ceiling)
    return np.minimum(next_internal_state, ceiling)

def homeostatic_loss(internal_state):
    """Return the homeostatic loss ``(1 - 0.1 / (internal_state + 0.1)) ** 2``.

    The value is 0 for an internal state of 0 and approaches 1 as the
    internal state grows. Works element-wise on NumPy arrays.
    """
    offset = 0.1
    ratio = offset / (internal_state + offset)
    deviation = 1 - ratio
    return deviation ** 2

def raw_reward_to_homeostatic_loss(internal_state_init, rewards):
    """Convert raw reward sequences into homeostatic-loss sequences.

    The internal state starts at ``internal_state_init`` and is rolled forward
    along the LAST axis of ``rewards``: at each step the previous state is
    passed through ``internal_state_update`` and the reward of that step is
    added. The loss is then ``homeostatic_loss`` of every internal state.

    Args:
        internal_state_init: Initial internal state (value at time index 0).
        rewards: NumPy array whose last axis is time; any number of leading
            axes is accepted.

    Returns:
        Array with the same shape as ``rewards`` holding the homeostatic loss.
    """
    print('FULL REWARD SHAPE ',rewards.shape)
    internal_states = np.ones(rewards.shape) * internal_state_init
    for t in range(rewards.shape[-1] - 1):
        # Ellipsis indexing keeps the recursion on the time (last) axis no
        # matter how many leading axes the reward array has.
        internal_states[..., t + 1] = internal_state_update(internal_states[..., t]) + rewards[..., t]
    return homeostatic_loss(internal_states)

# def risk_function(a, matrix):
#     return np.ones(matrix.shape)

@dataclass
class TrainConfig:
    """Training / evaluation configuration for the gridworld experiments.

    Only the attributes that carry a type annotation are dataclass fields:
    they are ``__init__`` parameters and are included in ``asdict(config)``.
    Attributes assigned WITHOUT an annotation (``cue_probs``,
    ``risk_function``, ``cov``, ``cov_shrink``, ``learning_rates``, ``x_der``,
    ``y_der``, ``particles_der``) are plain class attributes shared by all
    instances and are left out of ``asdict``.
    """
    num_runs: int = 10 # number of runs
    map_size: int = 5 # number of states is map_size x map_size
    num_cues: int = 2 # number of cues
    max_reward_delay: int = max_reward_delay # max reward delay
    max_reward_magnitude: int = max_reward_magnitude # max reward magnitude
    alpha: float = 0.01
    # cue probabilities at each time (class attribute, not a dataclass field)
    cue_probs = np.array([0.05, 0.1]) #np.array([0.2, 0.4, 0.05]) #0.2 * np.ones(num_cues) # cue probabilities at each time
    train_timesteps: int = 200000 # number of time steps for training or acting
    test_timesteps: int = 1000 # number of time steps for testing
    test_every_n: int = 1000 # test reward rate every time steps
    saveFolder: str = saveFolder # prefix prepended to every output file name
    # discount factor
    discount_gamma: float = 0.98
    # epsilon greedy exploration
    epsilon: float = 0.1
    # initialize internal state at this value
    internal_state_init: float = 1
    # risk weights for TMD (class attribute, not a dataclass field)
    risk_function = None #risk_function
    # DNL particle parameters; the unannotated entries below are class
    # attributes that reference the module-level arrays directly.
    n_particles: int = n_particles
    cov = cov
    cov_shrink = cov_shrink
    batch_size: int = batch_size
    dx_der: float = dx_der
    dy_der: float = dy_der
    bw: float = bw
    learning_rates = learning_rates
    x_der = x_der
    y_der = y_der
    particles_der = particles_der
    n_interactions: int = n_interactions
    lamb: float = lamb
    particle_gamma: float = particle_gamma
    Nx_der: int = Nx_der
    Ny_der: int = Ny_der
    # test cue simultaneous
    # Disabled field: train_and_test looks this name up in the class
    # annotations and falls back to False while it stays commented out.
    # test_cue_simultaneous: bool = True




def generate_grid_reward_magnitude_time_matrices_homeostatis(num_cues, map_size, max_reward_delay, max_reward_magnitude):
    """Build the fixed two-cue reward time/magnitude maps of this experiment.

    The result has shape
    ``(num_cues, map_size, map_size, max_reward_delay, max_reward_magnitude)``
    and is zero everywhere except for two entries set to 1:

    * cue 0 at grid position (4, 0): last delay index, last magnitude index;
    * cue 1 at grid position (0, 4): delay index ``max_reward_delay - 3``,
      magnitude index 1.

    The positions are hard-coded, so the arguments must be large enough for
    these indices to exist.

    Returns:
        ``(reward_magnitude_time_matrices, rewarded_positions)`` where the
        second item is the list ``[(4, 0), (0, 4)]``.
    """
    rewarded_positions = [(4, 0), (0, 4)]
    (first_row, first_col), (second_row, second_col) = rewarded_positions

    matrix_shape = (num_cues, map_size, map_size, max_reward_delay, max_reward_magnitude)
    reward_magnitude_time_matrices = np.zeros(matrix_shape)

    last_delay = max_reward_delay - 1
    last_magnitude = max_reward_magnitude - 1
    reward_magnitude_time_matrices[0, first_row, first_col, last_delay, last_magnitude] = 1
    reward_magnitude_time_matrices[1, second_row, second_col, max_reward_delay - 3, 1] = 1

    print('rewarded_positions', rewarded_positions)
    return reward_magnitude_time_matrices, rewarded_positions


def set_train_test_env(reward_magnitude_time_matrices, cue_probs, map_size, test_cue_simultaneous=False):
    """Create the training and test environments and reset the training one.

    The test environment is built from the same reward maps but with the cue
    probabilities multiplied by 3 and with ``cue_simultaneous`` set to
    ``test_cue_simultaneous``.

    Returns:
        ``(env, test_env, state, cue, reward)`` where ``state`` and ``cue``
        come from resetting the training environment and ``reward`` is 0.
    """
    training_env = gridWorldEnv(map_size, reward_magnitude_time_matrices, cue_probs)
    evaluation_env = gridWorldEnv(
        map_size,
        reward_magnitude_time_matrices,
        cue_probs * 3,
        cue_simultaneous=test_cue_simultaneous,
    )
    # Only the training environment is reset here.
    initial_state, initial_cue = training_env.reset()
    initial_reward = 0
    return training_env, evaluation_env, initial_state, initial_cue, initial_reward


def _plot_learned_reward_maps(agent, agent_name, rewarded_positions, saveFolder):
    """Plot an agent's learned reward maps for the original and rolled positions."""
    table_agents = ('magnitude_risk_agent', 'rewardMTAgent', 'homeostatic_agent', 'normalAgent')
    particle_agents = ('magnitude_risk_particle_agent', 'rewardMTAgent_particle', 'homeostatic_agent_particle')
    if agent_name in table_agents:
        extra = {}
    elif agent_name in particle_agents:
        # Particle-based agents additionally overlay their particles.
        extra = {'particles': agent.particles}
    else:
        # Unknown agent type: nothing to plot.
        return

    base_name = 'learned_reward_time_magnitude_matrices_per_state_'
    plot_reward_magnitude_time_matrices(agent.reward_MT, rewarded_positions, base_name + agent_name, saveFolder=saveFolder, **extra)
    for shift in (1, 2):
        rolled_positions = np.roll(np.array(rewarded_positions), shift, axis=0)
        plot_reward_magnitude_time_matrices(agent.reward_MT, rolled_positions, base_name + 'rolled_' + str(shift) + '_' + agent_name, saveFolder=saveFolder, **extra)


def train_multiple_agents(generate_rew_mag_time_mat, config, agent_classes):
    """
    Train multiple agents based on the provided agent classes.

    For every run and every agent class the agent is trained and tested, its
    rewards are saved to ``rewards_<agent>_<run>.npz`` under
    ``config.saveFolder``, and its learned reward maps and policies are
    plotted. Averages over runs are plotted at the end.

    Args:
        generate_rew_mag_time_mat (function): Function to generate reward magnitude time matrices.
        config (TrainConfig): Configuration object with training parameters.
        agent_classes (list): List of agent classes to initialize and train.

    Returns:
        tuple: ``(test_rewards, homeo_loss_test)``. ``test_rewards`` maps each
        agent class name to the list of per-run test rewards.
        ``homeo_loss_test`` is the homeostatic-loss array of the LAST agent
        trained in the LAST run only, or ``None`` when nothing was trained.
        NOTE(review): the per-agent dict ``test_homeo_loss`` was probably the
        intended second value; kept unchanged so existing callers still work.
    """
    map_size = config.map_size
    num_cues = config.num_cues
    max_reward_delay = config.max_reward_delay
    max_reward_magnitude = config.max_reward_magnitude
    saveFolder = config.saveFolder

    # saveFolder is used as a string prefix; make sure its directory part
    # exists before the first file is written.
    output_dir = os.path.dirname(saveFolder)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # Save parameters (only annotated dataclass fields are part of asdict).
    with open(saveFolder + "parameters.json", "w") as f:
        json.dump(asdict(config), f, indent=2)

    # Generate reward magnitude time matrices
    reward_magnitude_time_matrices, rewarded_positions = generate_rew_mag_time_mat(num_cues, map_size, max_reward_delay, max_reward_magnitude)
    print('reward_magnitude_time_matrices', reward_magnitude_time_matrices.shape)
    print('rewarded_positions', rewarded_positions)
    plot_reward_magnitude_time_matrices(reward_magnitude_time_matrices, rewarded_positions, 'reward_time_magnitude_matrices_per_state', saveFolder=saveFolder)

    # Initialize rewards storage
    rewards = {agent_class.__name__: [] for agent_class in agent_classes}
    test_rewards = {agent_class.__name__: [] for agent_class in agent_classes}
    test_homeo_loss = {agent_class.__name__: [] for agent_class in agent_classes}
    # Defined up front so the return statement is valid even when there are
    # no runs or no agent classes.
    homeo_loss_test = None

    # Train agents
    for run_num in range(config.num_runs):
        print('Run number:', run_num)
        for agent_class in agent_classes:
            agent_name = agent_class.__name__
            print(f'Training agent: {agent_name}')
            # Train the agent
            r, r_test, homeo_loss_test, agent = train_and_test(run_num, agent_class, reward_magnitude_time_matrices, config)
            # Store rewards
            rewards[agent_name].append(r)
            test_rewards[agent_name].append(r_test)
            test_homeo_loss[agent_name].append(homeo_loss_test)
            print('rewards list shapes',np.array(rewards[agent_name]).shape, np.array(test_rewards[agent_name]).shape, np.array(test_homeo_loss[agent_name]).shape)

            # Save rewards; the homeostatic loss is averaged over its last axis.
            np.savez(saveFolder + f'rewards_{agent_name}_{run_num}.npz', rewards=r, test_rewards=r_test, homeostatic_loss=np.mean(homeo_loss_test,axis=-1))
            _plot_learned_reward_maps(agent, agent_name, rewarded_positions, saveFolder)
            policies = agent.get_policies()
            plot_policy(saveFolder, map_size, policies, rewarded_positions, agent_name)
    # Plot average rewards
    print('test_rewards shapes',{np.array(i).shape for k,i in test_rewards.items()},{np.array(i).shape for k,i in test_homeo_loss.items()})
    plot_average_rewards(saveFolder, test_rewards)
    plot_average_rewards(saveFolder, test_homeo_loss, 'average_homeostatic_loss')
    return test_rewards, homeo_loss_test


def train_and_test(run_num, agent_class, reward_magnitude_time_matrices, config):
    """Train one agent, test it periodically and compute its homeostatic loss.

    Args:
        run_num: Index of the current run (used in the plot file name).
        agent_class: Agent class to instantiate. ``homeostatic_agent`` and
            ``homeostatic_agent_particle`` receive the internal-state update
            and homeostatic-loss functions; every other class receives
            ``config.risk_function``.
        reward_magnitude_time_matrices: Ground-truth reward maps used to build
            the environments; the agent starts from an all-zero array of the
            same shape.
        config: Configuration object. ``test_cue_simultaneous`` and
            ``internal_state_init`` are optional attributes.

    Returns:
        ``(train_rew, test_rew, homeostatic_rew, agent)``. ``homeostatic_rew``
        has the shape of ``test_rew`` and is all zeros when the config has no
        ``internal_state_init``.
    """
    alpha = config.alpha
    discount_gamma = config.discount_gamma
    epsilon = config.epsilon
    # Plain attribute lookup instead of inspecting `__annotations__`, so the
    # option is also honoured when it is set without a type annotation, on an
    # instance, or inherited from a parent config class.
    test_cue_simultaneous = getattr(config, 'test_cue_simultaneous', False)

    # create environment
    env, test_env, state, cue, reward = set_train_test_env(reward_magnitude_time_matrices, config.cue_probs, config.map_size, test_cue_simultaneous=test_cue_simultaneous)
    # create agent
    init_reward_MT = np.zeros(reward_magnitude_time_matrices.shape)
    if agent_class in (homeostatic_agent, homeostatic_agent_particle):
        agent = agent_class(env, config.map_size, state, init_reward_MT, internal_state_update, homeostatic_loss, alpha, discount_gamma, epsilon, internal_state_init=config.internal_state_init, config=config)
    else:
        agent = agent_class(env, config.map_size, state, init_reward_MT, alpha, discount_gamma, epsilon, config.risk_function, config)
    train_rew, test_rew = train_agent(agent, env, test_env, config.train_timesteps, config.test_timesteps, config.test_every_n, state, cue, reward)
    # Plot the mean test reward of every test epoch.
    plot_rewards(np.mean(test_rew, axis=-1), 'test_rewards_' + agent_class.__name__ + '_' + str(run_num), saveFolder=config.saveFolder)
    if hasattr(config, 'internal_state_init'):
        homeostatic_rew = raw_reward_to_homeostatic_loss(config.internal_state_init, test_rew)
    else:
        homeostatic_rew = np.zeros(test_rew.shape)
    return train_rew, test_rew, homeostatic_rew, agent



def test_agent(agent, env, num_timesteps, weights=None, prnt=False, epoch=0):
    """Run the agent in test mode for ``num_timesteps`` steps.

    The environment is reset, the agent is reset for testing with the initial
    state, and then the agent acts through ``agent.test_act`` while the
    environment is stepped.

    Args:
        agent: Object providing ``reset_test(state)`` and
            ``test_act(state, cue, prev_reward, weights)``.
        env: Object providing ``reset()`` -> ``(state, cue)`` and
            ``step(action)`` -> ``(reward, cue, state)``.
        num_timesteps: Number of test steps.
        weights: Forwarded unchanged to ``agent.test_act``.
        prnt: When true, print debugging information at every step.
        epoch: Accepted for the caller's convenience; not used here.

    Returns:
        1-D array with the reward received at every step.
    """
    state, cue = env.reset()
    collected = np.zeros(num_timesteps)
    agent.reset_test(state)

    # The first action is chosen with a previous reward of 0.
    last_reward = 0
    for step in range(num_timesteps):
        action = agent.test_act(state, cue, last_reward, weights)
        last_reward, cue, state = env.step(action)
        if prnt:
            print('agent cue delay', agent.test_cue_delay, ' next cue ', cue)
            print('final test rew ',last_reward, ' state',state, ' action',action)
        collected[step] = last_reward

    return collected

def train_agent(agent, env, test_env, num_timesteps, test_timesteps, test_every, init_state, init_cue, init_reward):
    """Train the agent online and test it every ``test_every`` steps.

    At each step the agent's ``train`` method receives the previous state and
    action, the last reward and the current state and cue, and returns the
    next action, which is then applied to ``env``. A test run on ``test_env``
    is performed at every step ``t`` with ``t % test_every == 0`` (including
    ``t = 0``).

    Args:
        agent: Agent providing ``train(prev_state, prev_action, reward, state, cue)``.
        env: Training environment providing ``step(action)`` -> ``(reward, cue, state)``.
        test_env: Environment handed to ``test_agent``.
        num_timesteps: Number of training steps.
        test_timesteps: Number of steps of every test run.
        test_every: Interval, in training steps, between test runs.
        init_state, init_cue, init_reward: Initial observation of ``env``.

    Returns:
        ``(rewards, full_test_rewards)``: the reward of every training step,
        and an array with one row of ``test_timesteps`` rewards per test run.
    """
    prev_state = None
    prev_action = None
    state, cue, reward = init_state, init_cue, init_reward

    print('train agent with ', num_timesteps, ' timesteps', test_timesteps, test_every, init_state, init_cue, init_reward)
    # track rewards
    rewards = np.zeros(num_timesteps)
    # Ceiling division: tests run at t = 0, test_every, 2 * test_every, ...,
    # so a run length that is not a multiple of test_every needs one more row
    # than floor division provides.
    num_tests = -(-num_timesteps // test_every)
    full_test_rewards = np.zeros((num_tests, test_timesteps))

    for step in range(num_timesteps):
        # get action from agent
        action = agent.train(prev_state, prev_action, reward, state, cue)
        prev_action = action
        # take action in environment
        prev_state = state
        reward, cue, state = env.step(action)
        rewards[step] = reward
        # test agent every test_every time steps
        if step % test_every == 0:
            epoch = step // test_every
            full_test_rewards[epoch] = test_agent(agent, test_env, test_timesteps, epoch=epoch)

    return rewards, full_test_rewards


def plot_saved_rewards():
    """Load ``rewards_<run>.npz`` for every run and plot the test rewards.

    Reads the arrays ``rewards_noTravel``, ``test_rewards_noTravel``,
    ``rewards_norm`` and ``test_rewards_norm`` from each file under the
    module-level ``saveFolder`` and passes the stacked test rewards of both
    variants to ``plot_average_rewards``.

    NOTE(review): ``num_runs`` is read as a global that is not defined in the
    visible part of this module — confirm it exists before calling.
    NOTE(review): elsewhere in this module ``plot_average_rewards`` is called
    with a folder and a dict; confirm this two-array call is still supported.
    """
    keys = ('rewards_noTravel', 'test_rewards_noTravel', 'rewards_norm', 'test_rewards_norm')
    loaded = {key: [] for key in keys}
    for run_num in range(num_runs):
        # load rewards
        data = np.load(saveFolder + 'rewards_' + str(run_num) + '.npz')
        for key in keys:
            loaded[key].append(data[key])
    plot_average_rewards(
        np.array(loaded['test_rewards_noTravel']),
        np.array(loaded['test_rewards_norm']),
    )



# plot average rewards from multiple runs
def plot_rewards_all(folder, num_runs, rewards):
    """Plot the reward curve of every run in one figure.

    Args:
        folder: Prefix of the output path; the figure is saved as
            ``folder + 'rewards_all.png'``.
        num_runs: Number of curves to draw (``rewards[0]`` .. ``rewards[num_runs - 1]``).
        rewards: Indexable collection of per-run reward sequences.
    """
    figure = plt.figure()
    axes = figure.gca()
    for run_index in range(num_runs):
        axes.plot(rewards[run_index], label=str(run_index))
    axes.legend()
    axes.set_xlabel('time steps')
    axes.set_ylabel('reward')
    plt.savefig(folder + 'rewards_all.png')

#####################################################################################
def fill_zeros_with_previous(arr):
    """Return a copy of ``arr`` where each zero repeats the last non-zero value.

    Zeros that appear before the first non-zero entry are left unchanged. The
    input itself is not modified.
    """
    filled = arr.copy()  # work on a copy so the caller's data stays intact
    carried = None
    for position, value in enumerate(filled):
        if value != 0:
            carried = value
            continue
        if carried is not None:
            filled[position] = carried
    return filled
######################################################################################

def get_mean_std_all_runs(group_folder, experiments, num_runs, opt=False, ablation=False):
    """Load saved rewards of a group of experiments and average them over runs.

    For every experiment ``exp`` the files ``rewards_<run>.npz`` are read from
    ``plot_folder + group_folder + '_cues_' + str(exp) + '/'``. Mean and
    standard deviation over runs (axis 0) are computed for the test rewards of
    the ``noTravel`` (TMRL) and ``norm`` variants and, when ``ablation`` is
    true, of the ``time_only`` and ``mag_only`` variants. The per-run
    ``noTravel`` test rewards of every experiment are also plotted with
    ``plot_rewards_all``.

    Args:
        group_folder: Common prefix of the experiment folders.
        experiments: Iterable of experiment identifiers.
        num_runs: Number of runs stored per experiment.
        opt: Also load ``rewards_optimal.npz`` and plot against it.
        ablation: Also load the time-only / magnitude-only variants.

    Returns:
        Arrays with one row per experiment:
        ``(mean_RTM, std_RTM, mean_norm, std_norm, mean_opt)`` when ``opt``;
        otherwise, when ``ablation``, the four base arrays followed by
        ``(mean_time_only, std_time_only, mean_mag_only, std_mag_only)``;
        otherwise only the four base arrays. ``opt`` takes precedence.
    """
    variants = ['noTravel', 'norm']
    if ablation:
        variants += ['time_only', 'mag_only']

    all_variants = ('noTravel', 'norm', 'time_only', 'mag_only')
    means = {variant: [] for variant in all_variants}
    stds = {variant: [] for variant in all_variants}
    optimal_curves = []

    for exp in experiments:
        folder = plot_folder + group_folder + '_cues_' + str(exp) +'/'

        train_runs = {variant: [] for variant in variants}
        test_runs = {variant: [] for variant in variants}
        for run_num in range(num_runs):
            # load rewards
            data = np.load(folder + 'rewards_' + str(run_num) + '.npz')
            for variant in variants:
                train_runs[variant].append(data['rewards_' + variant])
                test_runs[variant].append(data['test_rewards_' + variant])

        if opt:
            # load optimal agent rewards
            opt_data = np.load(folder + 'rewards_optimal.npz')
            opt_rewards = opt_data['opt_rewards']
            test_opt_rewards = opt_data['test_opt_rewards']
            optimal_curves.append(np.array(test_opt_rewards))

        for variant in variants:
            stacked = np.array(test_runs[variant])
            means[variant].append(np.mean(stacked, axis=0))
            stds[variant].append(np.std(stacked, axis=0))

        plot_rewards_all(folder, num_runs, test_runs['noTravel'])

        if opt:
            plot_average_rewards_with_opt(
                folder,
                means['noTravel'][-1],
                stds['noTravel'][-1],
                means['norm'][-1],
                stds['norm'][-1],
                np.array(test_opt_rewards),
            )

    base = (
        np.array(means['noTravel']),
        np.array(stds['noTravel']),
        np.array(means['norm']),
        np.array(stds['norm']),
    )
    if opt:
        return base + (np.array(optimal_curves),)
    if ablation:
        return base + (
            np.array(means['time_only']),
            np.array(stds['time_only']),
            np.array(means['mag_only']),
            np.array(stds['mag_only']),
        )
    return base


def comparison_plots(group_folder, experiments, num_runs, config, format='png', opt=False, ablation=False):
    """Plot normalised test-reward curves of all experiments of a group.

    One curve (mean over runs with a +/- one standard deviation band) is drawn
    per experiment and agent variant. Every curve of an experiment is divided
    by the maximum of that experiment's mean TMRL curve. Without ``ablation``
    the TMRL and standard-RL agents are compared; with ``ablation`` the TMRL,
    time-only and magnitude-only agents are compared. A separate legend figure
    is saved next to the plot.

    Both figures are written to the folder of the LAST experiment:
    ``group_normalized_rewards.<format>`` and
    ``group_normalized_rewards_legend.<format>``.

    Args:
        group_folder: Common prefix of the experiment folders.
        experiments: Sequence of experiment identifiers.
        num_runs: Number of runs stored per experiment.
        config: Object providing ``test_every_n`` (x-axis scaling).
        format: File extension of the saved figures.
        opt: Load the data with the optimal-agent option of
            ``get_mean_std_all_runs`` (takes precedence over ``ablation``).
        ablation: Compare against the time-only / magnitude-only agents.
    """
    test_every_n = config.test_every_n

    if opt:
        mean_RTM, std_RTM, mean_norm, std_norm, mean_opt = get_mean_std_all_runs(group_folder, experiments, num_runs, opt=True)
    elif ablation:
        (mean_RTM, std_RTM, mean_norm, std_norm,
         mean_time_only, std_time_only,
         mean_mag_only, std_mag_only) = get_mean_std_all_runs(group_folder, experiments, num_runs, opt=False, ablation=True)
    else:
        mean_RTM, std_RTM, mean_norm, std_norm = get_mean_std_all_runs(group_folder, experiments, num_runs)

    print('mean RTM shape',len(mean_RTM))

    n_experiments = len(experiments)

    def experiment_colors(cmap):
        # One shade per experiment, from the darkest end of the colormap
        # downwards. A single experiment gets the darkest shade directly,
        # which avoids dividing by (n_experiments - 1) == 0.
        if n_experiments == 1:
            return [cmap(1.0)]
        return [cmap((n_experiments - i * 0.5 - 1) / (n_experiments - 1)) for i in range(n_experiments)]

    colors_RTM = experiment_colors(plt.cm.Blues)
    colors_norm = experiment_colors(plt.cm.Reds)
    if ablation:
        colors_time_only = experiment_colors(plt.cm.Greens)
        colors_mag_only = experiment_colors(plt.cm.Oranges)
    print('colors rtm',colors_RTM,colors_norm)
    # line styles
    linestyles = ['-', '-']

    # plot average rewards
    fig, ax = plt.subplots(figsize=(horizontal_size, vertical_size))
    # Store line handles for custom legend
    legend_handles = []
    for ind in range(n_experiments):
        max_mean = np.max(mean_RTM[ind])
        print('max mean',max_mean)

        steps = np.arange(len(mean_RTM[ind])) * test_every_n
        line_rtm, = ax.plot(steps, mean_RTM[ind]/max_mean, color=colors_RTM[ind])
        ax.fill_between(steps, (mean_RTM[ind] - std_RTM[ind])/max_mean, (mean_RTM[ind] + std_RTM[ind])/max_mean, color=colors_RTM[ind], alpha=0.2)
        if ablation:
            steps_time = np.arange(len(mean_time_only[ind])) * test_every_n
            line_time_only, = ax.plot(steps_time, mean_time_only[ind]/max_mean, color=colors_time_only[ind])
            ax.fill_between(steps_time, (mean_time_only[ind] - std_time_only[ind])/max_mean, (mean_time_only[ind] + std_time_only[ind])/max_mean, color=colors_time_only[ind], alpha=0.2)
            steps_mag = np.arange(len(mean_mag_only[ind])) * test_every_n
            line_mag_only, = ax.plot(steps_mag, mean_mag_only[ind]/max_mean, color=colors_mag_only[ind])
            # Both edges of the band use the magnitude-only standard deviation.
            ax.fill_between(steps_mag, (mean_mag_only[ind] - std_mag_only[ind])/max_mean, (mean_mag_only[ind] + std_mag_only[ind])/max_mean, color=colors_mag_only[ind], alpha=0.2)
            legend_handles.append((line_rtm, line_time_only, line_mag_only))
        else:
            line_norm, = ax.plot(steps, mean_norm[ind]/max_mean, color=colors_norm[ind])
            ax.fill_between(steps, (mean_norm[ind] - std_norm[ind])/max_mean, (mean_norm[ind] + std_norm[ind])/max_mean, color=colors_norm[ind], alpha=0.2)
            legend_handles.append((line_rtm, line_norm))

    plt.xlabel('time steps')
    plt.ylabel('normalized reward rate')
    folder = plot_folder + group_folder + '_cues_' + str(experiments[-1]) +'/'
    plt.savefig(folder + 'group_normalized_rewards.'+format)

    # Separate legend figure: one row per experiment, one column per agent.
    if ablation:
        legend_fig, legend_ax = plt.subplots(figsize=(4.8, 4))
    else:
        legend_fig, legend_ax = plt.subplots(figsize=(2.4, 2))
    legend_ax.axis('off')

    if ablation:
        exp_labels = ['0.05','0.1','0.2']
        for row, (line_rtm, line_time_only, line_mag_only) in enumerate(legend_handles):
            print('line_rtm',line_rtm)
            y = 1 - row * 0.2
            legend_ax.add_line(Line2D([0.2, 0.4], [y, y], color=line_rtm.get_color(), linestyle=linestyles[0], lw=2))
            legend_ax.add_line(Line2D([0.5, 0.7], [y, y], color=line_time_only.get_color(), linestyle=linestyles[1], lw=2))
            legend_ax.add_line(Line2D([0.8, 1.0], [y, y], color=line_mag_only.get_color(), linestyle=linestyles[1], lw=2))
            # Fall back to the experiment identifier when there are more
            # experiments than predefined labels.
            row_label = exp_labels[row] if row < len(exp_labels) else str(experiments[row])
            legend_ax.text(0.0, y, f"{row_label}", va='center')

        # Add column headers
        legend_ax.text(0.3, 1.05, "TMRL", ha='center', fontweight='bold')
        legend_ax.text(0.6, 1.05, "time only", ha='center', fontweight='bold')
        legend_ax.text(0.9, 1.05, "magnitude only", ha='center', fontweight='bold')
        legend_ax.text(0.04, 1.07, "probability of\nstimuli", ha='center', fontweight='bold')
    else:
        for row, (line_rtm, line_norm) in enumerate(legend_handles):
            print('line_rtm',line_rtm)
            y = 1 - row * 0.2
            legend_ax.add_line(Line2D([0.2, 0.4], [y, y], color=line_rtm.get_color(), linestyle=linestyles[0], lw=2))
            legend_ax.add_line(Line2D([0.6, 0.8], [y, y], color=line_norm.get_color(), linestyle=linestyles[1], lw=2))
            legend_ax.text(0.0, y, f"{experiments[row]}", va='center')

        # Add column headers
        legend_ax.text(0.3, 1.05, "TMRL", ha='center', fontweight='bold')
        legend_ax.text(0.7, 1.05, "standard RL", ha='center', fontweight='bold')
        legend_ax.text(0.0, 1.1, "number of\nstimuli", ha='center', fontweight='bold')

    legend_ax.set_xlim(0, 1)
    legend_ax.set_ylim(0, 1.2)
    plt.xlabel('time steps')
    plt.ylabel('normalized reward per time step')
    plt.savefig(folder + 'group_normalized_rewards_legend.'+format)



def across_stimuli_plots(group_folder, experiments, num_runs, format='png', ablation=False):
    """Plot the final normalised test reward of each agent across experiments.

    For every experiment the value of the last test epoch (mean over runs,
    with the standard deviation as error bar) is divided by the maximum of
    that experiment's mean TMRL curve and plotted against ``experiments``.
    Without ``ablation`` the TMRL and standard-RL agents are shown; with
    ``ablation`` the TMRL, time-only and magnitude-only agents are shown.

    Two figures are written to the folder of the LAST experiment:
    ``group_final_reward_across_stimuli.<format>`` and
    ``across_stimuli_legend.<format>``.

    Args:
        group_folder: Common prefix of the experiment folders.
        experiments: Sequence of experiment identifiers (x-axis values).
        num_runs: Number of runs stored per experiment.
        format: File extension of the saved figures.
        ablation: Compare against the time-only / magnitude-only agents.
    """
    if ablation:
        (mean_RTM, std_RTM, mean_norm, std_norm,
         mean_time_only, std_time_only,
         mean_mag_only, std_mag_only) = get_mean_std_all_runs(group_folder, experiments, num_runs, opt=False, ablation=True)
    else:
        mean_RTM, std_RTM, mean_norm, std_norm = get_mean_std_all_runs(group_folder, experiments, num_runs)

    # One shade per agent, taken from its own colormap. A single experiment
    # uses the darkest shade, which avoids dividing by zero.
    n_experiments = len(experiments)
    if n_experiments > 1:
        shade = (n_experiments - 0.5 - 1) / (n_experiments - 1)
    else:
        shade = 1.0
    colors_RTM = plt.cm.Blues(shade)
    colors_norm = plt.cm.Reds(shade)
    if ablation:
        colors_time_only = plt.cm.Greens(shade)
        colors_mag_only = plt.cm.Oranges(shade)

    # plot average rewards
    fig, ax = plt.subplots(figsize=(horizontal_size, vertical_size))
    # Per-experiment maximum of the mean TMRL curve over all test epochs.
    # Reducing over the whole array keeps one value per experiment for any
    # number of experiments.
    max_mean = np.max(mean_RTM, axis=-1)
    line_rtm = ax.errorbar(experiments, mean_RTM[:,-1]/max_mean, yerr=std_RTM[:,-1]/max_mean, capsize=2, color=colors_RTM)
    if ablation:
        line_time_only = ax.errorbar(experiments, mean_time_only[:,-1]/max_mean, yerr=std_time_only[:,-1]/max_mean, capsize=2, color=colors_time_only)
        line_mag_only = ax.errorbar(experiments, mean_mag_only[:,-1]/max_mean, yerr=std_mag_only[:,-1]/max_mean, capsize=2, color=colors_mag_only)
    else:
        line_norm = ax.errorbar(experiments, mean_norm[:,-1]/max_mean, yerr=std_norm[:,-1]/max_mean, capsize=2, color=colors_norm)
    ax.spines['left'].set_linewidth(linewidth)
    ax.spines['bottom'].set_linewidth(linewidth)
    ax.tick_params(width=linewidth, length=length_ticks)

    print('line_rtm', line_rtm[0])

    # NOTE(review): comparison_plots labels the ablation experiments as
    # "probability of stimuli" and the others as "number of stimuli", i.e.
    # the opposite of the labels below — confirm which one is intended.
    if ablation:
        plt.xlabel('number of stimuli')
    else:
        plt.xlabel('probability of stimuli')
    plt.ylabel('normalized reward rate')
    folder = plot_folder + group_folder + '_cues_' + str(experiments[-1]) +'/'
    plt.savefig(folder + 'group_final_reward_across_stimuli.'+format)

    # Separate legend figure: one coloured line with its name per agent.
    legend_fig, legend_ax = plt.subplots(figsize=(2.4, 2))
    legend_ax.axis('off')

    if ablation:
        names = ['TMRL', 'time only', 'magnitude only']
        colors = [colors_RTM, colors_time_only, colors_mag_only]
    else:
        names = ['TMRL', 'standard RL']
        colors = [colors_RTM, colors_norm]
    for row, (name, color) in enumerate(zip(names, colors)):
        y = 1 - row * 0.3
        ylab = 1 - row * 0.3 + 0.1
        legend_ax.add_line(Line2D([0.2, 0.4], [y, y], color=color,  lw=2))
        legend_ax.text(0.2, ylab, f"{name}", va='center', fontweight='bold')

    legend_ax.set_xlim(0, 1)
    legend_ax.set_ylim(0, 1.2)
    plt.savefig(folder + 'across_stimuli_legend.'+format)
    


def plot_2D_distribution(format='svg'):
    """Save three illustrative contour plots of 2-D Gaussian densities.

    Panel 0 shows the sum of two Gaussians centred at (-2, -2) and (2, 2);
    panels 1 and 2 show single Gaussians centred at (-2, 2) and (2, -2). Each
    density is passed through a steep sigmoid centred at half of its maximum
    before plotting. The figures are saved as
    ``plot_folder + '2d_distribution_<k>.<format>'`` for k = 0, 1, 2.
    """
    from scipy.stats import multivariate_normal

    # Evaluation grid on [-5, 5) x [-5, 5) with step 0.01.
    grid_x, grid_y = np.mgrid[-5:5:.01, -5:5:.01]
    grid = np.dstack((grid_x, grid_y))

    def density(center):
        # All Gaussians share the same diagonal covariance.
        return multivariate_normal(mean=center, cov=[[1, 0], [0, 0.3]]).pdf(grid)

    def sigmoid(values):
        peak = np.max(values)
        return 1.0/ (1 + np.exp(-50*(values - 0.5 * peak)))

    def save_panel(values, index):
        plt.figure(figsize=(3,3))
        plt.contourf(grid_x, grid_y, sigmoid(values), levels=300, cmap='coolwarm')
        plt.axis('off')
        plt.savefig(plot_folder + '2d_distribution_' + str(index) + '.'+format)

    bimodal = density([-2, -2]) + density([2, 2])
    print('max',np.max(bimodal))
    save_panel(bimodal, 0)
    save_panel(density([-2, 2]), 1)
    save_panel(density([2, -2]), 2)



# plot average rewards from multiple runs
def plot_internal_state(saveFolder, rewards_dict, title=None, internal_state_init=1):
    """Plot the internal-state trajectory of every agent and run.

    For each agent the entry at index -1 of the second axis of its reward
    array is selected, and the internal state is rolled forward along the
    last axis: the previous state goes through ``internal_state_update`` and
    the reward of that step is added.

    Args:
        saveFolder: Prefix of the output path; the figure is saved as
            ``saveFolder + title + '.png'``.
        rewards_dict: Maps agent name to an array-like indexed as
            ``[:, -1, :]`` — assumes (runs, test epochs, test time steps) as
            collected by train_multiple_agents; TODO confirm.
        title: File name stem; defaults to ``'internal_state'``.
        internal_state_init: Internal state at the first time step.
    """
    if title is None:
        title = 'internal_state'
    plt.figure()
    for agent_name, rewards_full in rewards_dict.items():
        print('PLOT INTERNAL STATE SHAPE',np.array(rewards_full).shape)
        rewards = np.array(rewards_full)[:, -1, :]
        internal_states = np.ones(rewards.shape) * internal_state_init
        for t in range(rewards.shape[-1] - 1):
            internal_states[:, t+1] = internal_state_update(internal_states[:, t]) + rewards[:, t]
        # One curve per run.
        for run in range(rewards.shape[0]):
            plt.plot(internal_states[run], label=agent_name + '_'+str(run))
    plt.legend()
    plt.xlabel('time steps')
    plt.ylabel('internal_state')
    plt.title('internal_state over time')
    plt.savefig(saveFolder + title + '.png')

def time_decay(timesteps, init_vals, increase_time, magnitudes, decay_rate=0.99):
    """Simulate exponentially decaying values with a single additive boost.

    Args:
        timesteps: Number of columns (time steps) of the returned array.
        init_vals: Values written into the first column; a scalar or an
            array with one entry per magnitude.
        increase_time: Index ``t`` of the step whose successor is computed as
            ``vals[:, t] + magnitudes`` instead of being decayed.
        magnitudes: Boost added to each row at ``increase_time``; its length
            sets the number of rows.
        decay_rate: Multiplicative factor applied on every other step.

    Returns:
        Array of shape ``(len(magnitudes), timesteps)`` holding one
        trajectory per magnitude.
    """
    trajectory = np.zeros((len(magnitudes), timesteps))
    trajectory[:, 0] = init_vals
    for nxt in range(1, timesteps):
        previous = trajectory[:, nxt - 1]
        if nxt - 1 == increase_time:
            # The boost replaces the decay on this step rather than adding to it.
            trajectory[:, nxt] = previous + magnitudes
        else:
            trajectory[:, nxt] = decay_rate * previous
    return trajectory

def plot_homeostatic_weights(saveFolder, title=None):
    """Plot the summed homeostatic loss over boost time and boost magnitude.

    For each boost time a set of internal-state trajectories (one per
    magnitude) is generated with ``time_decay``; ``homeostatic_loss`` of each
    trajectory is summed over time and stored in one column of a
    ``(magnitude, time)`` matrix. The matrix is divided row-wise by
    ``magnitude + 1`` and shown as a heat map saved to
    ``saveFolder + title + '.svg'``.

    Args:
        saveFolder: Path prefix of the output file; it is concatenated
            directly with the file name, so it must end with a separator.
        title: Output file name without extension. Defaults to
            ``'homeostatic_weights'``.
    """
    # The default None used to reach the string concatenation below and raise
    # TypeError; fall back to a fixed file name instead.
    if title is None:
        title = 'homeostatic_weights'

    # Define the dimensions of the matrix
    time_steps = 50  # Number of boost times (columns)
    magnitudes_steps = 50  # Number of magnitudes (rows)

    # Create the input matrix
    magnitudes = np.linspace(0, 2, magnitudes_steps)
    init_vals = np.ones(magnitudes_steps) * 20.0
    weights = np.zeros((magnitudes_steps, time_steps))
    # NOTE(review): time_decay only iterates t up to time_steps - 2, so the
    # last column (increase_time == time_steps - 1) contains no boost.
    for increase_time in range(time_steps):
        internal_states = time_decay(time_steps, init_vals, increase_time, magnitudes, decay_rate=0.99)
        weights[:, increase_time] = np.sum(homeostatic_loss(internal_states), axis=-1)

    # Row-wise normalisation by (magnitude + 1); the +1 avoids dividing by the
    # zero magnitude of the first row.
    weights_without_magnitude = weights / (magnitudes[:, np.newaxis] + 1)

    # Plot the matrix
    plt.figure(figsize=(8, 6))
    plt.imshow(weights_without_magnitude, aspect='auto', cmap='inferno', origin='lower')
    plt.colorbar(label='Homeostatic Loss')
    plt.ylabel('Magnitude Dimension')
    plt.xlabel('Time Dimension')
    plt.title('Homeostatic Loss Matrix')
    plt.savefig(saveFolder + title + '.svg')



def plot_homeostatic_loss(saveFolder, num_runs, agent_names, agent_labels, colors=None, format='svg', test_every_n=1000):
    """Plot the mean test reward of several agents with a +/- 1 std band.

    For every agent, the files ``saveFolder + 'rewards_<name>_<run>.npz'``
    for ``run`` in ``range(num_runs)`` are loaded and their ``'test_rewards'``
    arrays are stacked. The stack is averaged over its last axis, then the
    mean and standard deviation across runs (axis 0) are drawn against
    ``index * test_every_n``. Despite the function name, the plotted quantity
    is the test reward (y-label ``'Reward rate'``), not the homeostatic loss.

    Args:
        saveFolder: Path prefix of the input files (concatenated directly
            with the file name) and directory of the output figure.
        num_runs: Number of runs to load per agent.
        agent_names: Agent identifiers used in the file names.
        agent_labels: Legend labels, parallel to ``agent_names``.
        colors: Optional sequence of matplotlib colours; it is cycled when
            there are more agents than colours.
        format: File extension of the saved figure.
        test_every_n: Spacing, in training time steps, between consecutive
            entries of the plotted curve.
    """
    mean_agents = {}
    std_agents = {}

    for name in agent_names:
        test_rewards = []
        for run_num in range(num_runs):
            # np.load on an .npz returns an NpzFile that holds the archive
            # open; the context manager closes it once the array is read.
            with np.load(saveFolder + 'rewards_' + name + '_' + str(run_num) + '.npz') as data:
                test_rewards.append(data['test_rewards'])
        # Average within each test first, then aggregate across runs.
        per_test_mean = np.mean(np.array(test_rewards), axis=-1)
        mean_agents[name] = np.mean(per_test_mean, axis=0)
        std_agents[name] = np.std(per_test_mean, axis=0)

    # Create the plot
    if colors is None:
        colors = ['#FA7914', plt.cm.Blues(0.9), 'green', 'orange', 'purple', 'brown']
    fig, ax = plt.subplots(figsize=(horizontal_size, vertical_size))
    for ind, name in enumerate(agent_names):
        # Cycle through the palette instead of failing on extra agents.
        color = colors[ind % len(colors)]
        mean = mean_agents[name]
        std = std_agents[name]
        steps = np.arange(len(mean)) * test_every_n
        ax.plot(steps, mean, color=color, lw=1, label=agent_labels[ind])
        ax.fill_between(steps, mean - std, mean + std, color=color, alpha=0.2)

    ax.set_xlabel('time steps')
    ax.set_ylabel('Reward rate')
    ax.legend(fontsize=6)
    # os.path.join avoids the doubled separator and, for an empty saveFolder,
    # no longer resolves to the filesystem root.
    fig.savefig(os.path.join(saveFolder, 'figure_reward_loss.' + format))

# if __name__ == '__main__':

#     # create folder if it doesn't already exist
#     if os.path.exists(saveFolder):
#         print('overwrite')
#     else:
#         os.makedirs(saveFolder)

#     config = TrainConfig()
#     # test_rewards, test_homeo_loss = train_multiple_agents(generate_grid_reward_magnitude_time_matrices_homeostatis, config, [homeostatic_agent_particle, rewardMTAgent_particle])#, normalAgent])
#     # test_rewards, test_homeo_loss = train_multiple_agents(generate_grid_reward_magnitude_time_matrices_homeostatis, config, [homeostatic_agent, rewardMTAgent, normalAgent])
#     test_rewards, test_homeo_loss = train_multiple_agents(generate_grid_reward_magnitude_time_matrices_homeostatis, config, [homeostatic_agent, rewardMTAgent])
#     plot_internal_state(saveFolder, test_rewards)
#     plot_reward_histograms(config.saveFolder, test_rewards)

#     # train_multiple_agents(num_runs, num_cues=num_cues, map_size=map_size, max_reward_delay=max_reward_delay, max_reward_magnitude=max_reward_magnitude, alpha=alpha, cue_probs=cue_probs, train_timesteps=train_timesteps, test_timesteps=test_timesteps, test_every_n=test_every_n)
#     # plot_saved_rewards()
#     # cue_nums_to_plot = np.arange(2,4)    
#     # plot_different_cue_numbers(cue_nums_to_plot)

# if __name__ == '__main__':

#     group_folder = 'gridworld_group'
#     experiments = [3,4, 5]
#     comparison_plots(group_folder, experiments, num_runs, format='svg', opt=False)
#     across_stimuli_plots(group_folder, experiments, num_runs, format='svg')

    # plot_2D_distribution()

# if __name__ == '__main__':

#     plot_homeostatic_weights(saveFolder,'homeostatic_weights')


if __name__ == '__main__':
    # Compare the homeostatic agent against the plain TMRL agent over 10 runs.
    # Alternative short legend labels: ['RS-TMRL', 'TMRL']
    plot_homeostatic_loss(
        saveFolder,
        num_runs=10,
        agent_names=['homeostatic_agent', 'rewardMTAgent'],
        agent_labels=['Time varying risk\nsensitive TMRL', 'TMRL'],
    )
   