'''
Use case for reward time and magnitude maps
Each reward time and magnitude map gives, for each state, the probability of
receiving a reward of a given magnitude at a given delay.
We use a square gridworld environment.
Here, we try to optimize for homeostatic losses.
'''

import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
from matplotlib.lines import Line2D
import os,  json
from dataclasses import dataclass, asdict

from gridwork_reward_time_and_mag import gridWorldEnv, random_position#, train_agent
from gridwork_reward_time_and_mag import rewardMTAgent, rewardMTAgent_particle, homeostatic_agent, normalAgent, QuantileRLAgent, magnitude_risk_agent, magnitude_risk_particle_agent
from gridwork_reward_time_and_mag import generate_grid_reward_magnitude_time_matrices, plot_reward_magnitude_time_matrices, plot_rewards, plot_value_function, plot_reward_MT_positions
from patch_risk_sensitivity import plot_average_rewards, plot_average_risk, plot_reward_histograms
from gridworld_homeostatis import train_multiple_agents, plot_homeostatic_loss

# import plotly.express as px
# import plotly.graph_objects as go


# Parameters for plots.
# Each name is assigned exactly once; the values are the ones that were
# effectively in force before (scatter_size was 2, then overridden to 20).
length_ticks = 2
font_size = 9
linewidth = 1.2
scatter_size = 20
# Figure width / height passed to plt.figure(figsize=...) further down.
horizontal_size = 1.5
vertical_size = 1.5

# Global matplotlib style: small fonts, thin lines, no top/right spines.
mpl.rcParams.update({
    'font.size': font_size,
    'lines.linewidth': linewidth,
    'xtick.labelsize': font_size - 5,
    'ytick.labelsize': font_size - 5,
    'axes.spines.right': False,
    'axes.spines.top': False,
    'axes.titlesize': font_size - 2,
    'legend.fontsize': font_size - 2,
})


# Output location: saveFolder is plot_folder immediately followed by
# experiment_name (plain string concatenation, hence the trailing slash).
plot_folder = ''
experiment_name = 'magnitude_risk_gridworld_group_cues_2/'
saveFolder = ''.join((plot_folder, experiment_name))


def risk_function(a, matrix, alpha=0.5):
    """
    CVaR-style weighting of the distributions stored along the last axis.

    Every 1-D slice ``matrix[..., :]`` is normalised by its sum (plus a small
    epsilon), and the lowest ``alpha`` fraction of its mass is selected:

    * bins whose cumulative mass stays below ``alpha`` get weight 1,
    * the bin where the cumulative mass reaches ``alpha`` gets the fractional
      weight needed to make the selected mass equal ``alpha``,
    * all later bins get weight 0.

    The weights are finally multiplied by the bin index ``0 .. n_bins - 1``.

    Parameters
    ----------
    a :
        Not used by the computation; kept so the signature is unchanged.
    matrix : array-like
        Non-negative values with the distribution on the last axis. Any number
        of leading axes is accepted.
    alpha : float, optional
        Fraction of the (normalised) mass to keep, counted from the lowest bin
        index. Defaults to 0.5, the value that used to be hard-coded.

    Returns
    -------
    numpy.ndarray
        Floating-point array with the same shape as ``matrix``.
    """
    eps = 1e-6  # guards the divisions against empty (all-zero) slices
    matrix = np.asarray(matrix)
    n_bins = matrix.shape[-1]

    # Step 1: normalise along the last axis
    normalized_matrix = matrix / (np.sum(matrix, axis=-1, keepdims=True) + eps)

    # Step 2: cumulative mass along the last axis
    cumulative = np.cumsum(normalized_matrix, axis=-1)

    # Step 3: weights. Allocate with the dtype of the *normalised* matrix
    # (always floating point); using the dtype of an integer input would
    # truncate the fractional threshold weight.
    output_matrix = np.zeros_like(normalized_matrix)

    # Step 4: apply the threshold logic to every 1-D slice
    for idx in np.ndindex(*matrix.shape[:-1]):
        weights = output_matrix[idx]  # view: writes go into output_matrix
        probs = normalized_matrix[idx]

        # First bin at which the cumulative mass reaches alpha
        threshold_index = np.searchsorted(cumulative[idx], alpha)

        # Bins entirely below the threshold are fully selected
        weights[:threshold_index] = 1

        # The threshold bin is selected only partially
        if threshold_index < n_bins:
            remaining_value = alpha - np.sum(probs[:threshold_index])
            weights[threshold_index] = remaining_value / (probs[threshold_index] + eps)

    return output_matrix * np.arange(n_bins)

# Fix the global NumPy random seed
np.random.seed(42)

# Maximum reward delay and maximum reward magnitude
max_reward_delay = 10
max_reward_magnitude = 8

# Regular 2-D grid covering [-1, max + 1) along the delay axis (x) and the
# magnitude axis (y), with spacing dx_der / dy_der
dx_der = 0.1
dy_der = 0.1
x_der, y_der = np.mgrid[
    -1:(max_reward_delay + 1):dx_der,
    -1:(max_reward_magnitude + 1):dy_der,
]
Nx_der = x_der.shape[0]
Ny_der = y_der.shape[1]
# Grid coordinates as column vectors, then stacked into an (N, 2) point list
x_der_flat = x_der.flatten().reshape(-1, 1)
y_der_flat = y_der.flatten().reshape(-1, 1)
particles_der = np.concatenate((x_der_flat, y_der_flat), axis=1)

# Particle-related settings (forwarded through TrainConfig)
n_interactions = 100
batch_size = 100
bw = 1.0  # 1.0 #0.8
lamb = 0.2 * np.max((max_reward_delay, max_reward_magnitude))  # 1.0
particle_gamma = 50.0  # 100.0 #2.0 #0.09*16
n_particles = 10 * 10  # 5*5

# Learning-rate schedule: linear 0.001 -> 0.0001 over the first 5000 entries,
# then constant 0.001 for the remaining 5000
learning_rates = np.concatenate((
    np.linspace(0.001, 0.0001, 5000),
    0.001 * np.ones(5000),
))

# Isotropic 2x2 covariance
cov = np.array([[0.5, 0.0], [0.0, 0.5]])

# Covariance shrink schedule: linear 1.0 -> 0.01 over the first
# n_interactions entries, then constant 0.01
cov_shrink = np.concatenate((
    np.linspace(1.0, 0.01, n_interactions),
    np.ones(n_interactions * 1000 - n_interactions) * 0.01,
))




@dataclass
class TrainConfig:
    """
    Settings for a training / testing experiment.

    Only the *annotated* attributes are dataclass fields: they can be passed
    to the constructor and are included in ``dataclasses.asdict``. The
    un-annotated attributes (``cue_probs``, ``risk_function``, ``cov``,
    ``cov_shrink``, ``learning_rates``, ``x_der``, ``y_der``,
    ``particles_der``) are plain class attributes taken from the module-level
    values; they are not constructor parameters and do not appear in
    ``asdict``.
    """
    num_runs: int = 2 # number of runs
    map_size: int = 5 # number of states is map_size x map_size
    num_cues: int = 3 # number of cues
    max_reward_delay: int = max_reward_delay # max reward delay
    max_reward_magnitude: int = max_reward_magnitude # max reward magnitude
    alpha: float = 0.005
    # Class-level default, built from the default num_cues; __post_init__
    # replaces it per instance when num_cues is overridden.
    cue_probs = 0.05 * np.ones(num_cues) # cue probabilities at each time
    train_timesteps: int = 200000 # number of time steps for training or acting
    test_timesteps: int = 1000
    test_every_n: int = 1000 # test reward rate every time steps
    saveFolder: str = saveFolder
    # discount factor
    discount_gamma: float = 0.99
    # epsilon greedy exploration
    epsilon: float = 0.1
    # risk weights for TMD
    # NOTE(review): this is a function stored as a plain class attribute, so
    # `config.risk_function(...)` is a bound-method call and the instance is
    # passed as the first positional argument (`a`). Confirm callers rely on it.
    risk_function = risk_function
    # DNL particle parameters
    n_particles: int = n_particles
    cov = cov
    cov_shrink = cov_shrink
    batch_size: int = batch_size
    dx_der: float = dx_der
    dy_der: float = dy_der
    bw: float = bw
    learning_rates = learning_rates
    x_der = x_der
    y_der = y_der
    particles_der = particles_der
    n_interactions: int = n_interactions
    lamb: float = lamb
    particle_gamma: float = particle_gamma
    Nx_der: int = Nx_der
    Ny_der: int = Ny_der
    # test cue simultaneous
    test_cue_simultaneous: bool = True

    def __post_init__(self):
        # The class-level cue_probs has one entry per *default* cue. When an
        # instance is created with a different num_cues, give that instance a
        # matching array so the two settings cannot disagree. cue_probs stays
        # a non-field attribute, so __init__ and asdict() are unchanged.
        if len(self.cue_probs) != self.num_cues:
            self.cue_probs = 0.05 * np.ones(self.num_cues)



def set_train_test_env(reward_magnitude_time_matrices, cue_probs, map_size):
    """
    Build two identically-configured gridworld environments and reset the
    training one.

    Returns
    -------
    tuple
        ``(env, test_env, state, cue, reward)`` where ``state`` and ``cue``
        are what ``env.reset()`` returned and ``reward`` is the initial 0.
        ``test_env`` is returned without being reset.
    """
    # Training environment first, then the test environment, same arguments
    env, test_env = (
        gridWorldEnv(map_size, reward_magnitude_time_matrices, cue_probs)
        for _ in range(2)
    )
    # Only the training environment is reset here
    state, cue = env.reset()
    return env, test_env, state, cue, 0



def generate_grid_reward_magnitude_time_matrices_magnitude_risk(num_cues, map_size, max_reward_delay, max_reward_magnitude):
    """
    Build the reward time/magnitude matrices for the magnitude-risk task.

    The returned array has shape
    ``(num_cues, map_size, map_size, max_reward_delay, max_reward_magnitude)``
    and is zero except for three hard-coded cue/position pairs, all at delay
    index ``max_reward_delay - 3``:

    * cue 0 at position (0, 4): magnitude index 2 with weight 1,
    * cue 1 at position (4, 0): magnitude indices 1 and
      ``max_reward_magnitude - 1``, normalised so the slice sums to 1,
    * cue 2 at position (0, 0): one magnitude index drawn with
      ``np.random.randint(1, max_reward_magnitude - 2)`` with weight 1.

    Because the indices are hard-coded, ``num_cues`` must be at least 3 and
    ``map_size`` at least 5.

    Returns
    -------
    tuple
        ``(reward_magnitude_time_matrices, rewarded_positions)`` where
        ``rewarded_positions`` is the array ``[[0, 4], [4, 0], [0, 0]]``.
    """
    # each element is a prob of magnitude given a time
    matrices = np.zeros((num_cues, map_size, map_size, max_reward_delay, max_reward_magnitude))
    delay = max_reward_delay - 3

    # Cue 2: single randomly drawn magnitude at the origin
    origin = (0, 0)
    drawn_magnitude = np.random.randint(1, max_reward_magnitude - 2)
    matrices[(2, *origin, delay, drawn_magnitude)] = 1

    # Cue 1: two possible magnitudes (lowest non-zero and highest), normalised
    risky_x, risky_y = 4, 0
    for risky_magnitude in (1, max_reward_magnitude - 1):
        matrices[1, risky_x, risky_y, delay, risky_magnitude] = 1
    matrices[1, risky_x, risky_y] /= np.sum(matrices[1, risky_x, risky_y])

    # Cue 0: fixed magnitude
    x_certain, y_certain = 0, 4
    # Most recently added position comes first
    rewarded_positions = np.array([(x_certain, y_certain), (risky_x, risky_y), origin])
    print('got x y certain', x_certain, y_certain)
    print('rewarded positions', rewarded_positions)

    matrices[0, x_certain, y_certain, delay, 2] = 1
    return matrices, rewarded_positions


# plot average rewards from multiple runs
def plot_average_risk_sensitive_reward(saveFolder, rewards_dict, risk_function):
    """
    Plot the mean risk-distorted reward per agent and save it to disk.

    For each ``agent_name -> rewards`` entry of ``rewards_dict`` the rewards
    are multiplied element-wise by ``risk_function(rewards)``, averaged over
    the last axis, and then summarised over axis 0 as mean +/- one standard
    deviation. The figure is written to
    ``saveFolder + 'average_distorted_rewards.png'``.

    ``risk_function`` is called here with a single argument (the rewards
    array).
    """
    plt.figure()
    for agent_name, rewards_full in rewards_dict.items():
        # Risk-weighted rewards, averaged over the last axis
        distorted = np.mean(risk_function(rewards_full) * rewards_full, axis=-1)
        # Summary over axis 0
        centre = np.mean(distorted, axis=0)
        spread = np.std(distorted, axis=0)
        plt.plot(centre, label=agent_name)
        plt.fill_between(range(len(centre)), centre - spread, centre + spread, alpha=0.2)
    plt.legend()
    plt.xlabel('time steps')
    plt.ylabel('reward')
    plt.title('Average rewards over time')
    plt.savefig(saveFolder + 'average_distorted_rewards.png')


def plot_choice_histograms(plot_folder, group_folder, choice_rewards, agent_names, agent_labels, num_runs=10, colors=None):
    """
    Plot, per agent, how often each choice was made at the end of training.

    For every agent in ``agent_names`` and every run ``0 .. num_runs - 1`` the
    file ``plot_folder + group_folder + '/rewards_<agent>_<run>.npz'`` is
    loaded and its ``'test_rewards'`` array collected. The last 10 entries
    along axis 1 are kept and flattened into rows; for each row the number of
    rewards belonging to each set in ``choice_rewards`` is counted and turned
    into a per-row fraction. The mean fraction per choice is drawn as a bar
    with the standard deviation as error bar, and the figure is saved to
    ``plot_folder + group_folder + '/choice_histogram_group.svg'``.

    Parameters
    ----------
    plot_folder, group_folder : str
        Concatenated (with a trailing ``'/'``) to form the data/output folder.
    choice_rewards : sequence of sequences
        One collection of reward values per choice.
    agent_names : sequence of str
        Names used in the ``.npz`` file names.
    agent_labels : sequence of str
        Legend labels, parallel to ``agent_names``.
    num_runs : int, optional
        Number of run files to load per agent.
    colors : sequence, optional
        One bar colour per agent; a built-in palette is used when ``None``.
    """
    # The folder is the same for every agent and run
    folder = plot_folder + group_folder + '/'

    dict_rewards = {}
    for agent_name in agent_names:
        runs = []
        for run_num in range(num_runs):
            # np.load on a .npz keeps the archive open until closed; the
            # array itself is read into memory by the indexing below.
            with np.load(folder + 'rewards_' + agent_name + '_' + str(run_num) + '.npz') as data:
                runs.append(data['test_rewards'])
        dict_rewards[agent_name] = runs

    plt.figure(figsize=(horizontal_size, vertical_size))
    print('colors', colors)
    if colors is None:
        colors = ['#FA8072', plt.cm.Blues(0.9), 'green', 'orange', 'purple', 'brown']

    n_choices = len(choice_rewards)
    choice_sets = [np.array(choice_rewards_set) for choice_rewards_set in choice_rewards]

    for i, agent_name in enumerate(agent_names):
        # assumes axes are (run, test evaluation, time step) -- TODO confirm
        rewards = np.array(dict_rewards[agent_name])
        print('rewards shape', rewards.shape)
        # get the rewards at the end of training and flatten to rows
        final_rewards = rewards[:, -10:, :]
        final_rewards = final_rewards.reshape(-1, final_rewards.shape[-1])

        # Per row, count the rewards that fall into each choice set
        choice_counts = np.zeros((len(final_rewards), n_choices))
        for ind_ch, choice_set in enumerate(choice_sets):
            choice_counts[:, ind_ch] = np.sum(np.isin(final_rewards, choice_set), axis=-1)

        # Turn counts into per-row fractions. Rows without any matching reward
        # are skipped: dividing them by their zero total would yield NaN and
        # turn every mean / std below into NaN.
        totals = np.sum(choice_counts, axis=-1, keepdims=True)
        has_choice = totals[:, 0] > 0
        choice_fractions = choice_counts[has_choice] / totals[has_choice]
        mean_choice_counts = np.mean(choice_fractions, axis=0)
        std_choice_counts = np.std(choice_fractions, axis=0)

        # Plot bar chart with error bars; bars of successive agents are
        # shifted by the bar width (0.3)
        plt.bar(
            np.arange(1, n_choices + 1) + i * 0.3 - 0.15,
            mean_choice_counts,
            yerr=std_choice_counts,  # Add error bars
            color=colors[i],
            width=0.3,
            alpha=0.7,
            label=agent_labels[i],
            capsize=2,  # Add flat caps to the error bars
            error_kw={'alpha': 0.7,  # Set transparency for the error bars
                      'elinewidth': 0.7}  # Set the thickness of the error bars and caps
        )
    plt.xticks(np.arange(1, n_choices + 1, dtype=int))
    plt.xlabel('Choice')
    plt.ylabel('Probability')
    plt.savefig(plot_folder + group_folder + '/choice_histogram_group.svg')


# if __name__ == '__main__':

#     # create folder if it doesn't already exist
#     if os.path.exists(saveFolder):
#         print('overwrite')
#     else:
#         os.makedirs(saveFolder)

#     config = TrainConfig()
#     # For quantile regression agent comparison
#     # test_rewards, _ = train_multiple_agents(generate_grid_reward_magnitude_time_matrices, config, [rewardMTAgent_particle, normalAgent, QuantileRLAgent])
    
#     # For magnitude risk agent
#     test_rewards, _ = train_multiple_agents(generate_grid_reward_magnitude_time_matrices_magnitude_risk, config, [magnitude_risk_agent, rewardMTAgent])
#     # test_rewards, _ = train_multiple_agents(generate_grid_reward_magnitude_time_matrices_magnitude_risk, config, [rewardMTAgent_particle, magnitude_risk_particle_agent])
#     # test_rewards, _ = train_multiple_agents(generate_grid_reward_magnitude_time_matrices_magnitude_risk, config, [rewardMTAgent])
#     # plot_average_risk(config.saveFolder, test_rewards, risk_function)
#     plot_reward_histograms(config.saveFolder, test_rewards)

    # plot_saved_rewards()
    # cue_nums_to_plot = np.arange(2,4)    
    # plot_different_cue_numbers(cue_nums_to_plot)

if __name__ == '__main__':

    # Choice histogram for the two agents of the magnitude-risk group;
    # each choice is a single reward magnitude, 1 through 7.
    group_folder = 'magnitude_risk_gridworld_group'
    plot_choice_histograms(
        plot_folder,
        group_folder,
        choice_rewards=[[magnitude] for magnitude in range(1, 8)],
        agent_names=['magnitude_risk_agent', 'rewardMTAgent'],
        agent_labels=['Magnitude risk\nsensitive TMRL', 'TMRL'],
        num_runs=10,
    )
