'''
Use case for reward time and magnitude maps
Each reward and time magnitude map describes the reward that will be given
in each state.
We assume here that you can reach any state from any other state with no travel time.
'''

import numpy as np
import matplotlib.pyplot as plt
import os, json
from dataclasses import dataclass, asdict

from no_travel_time_cue_reward_time_and_mag import noTravelTimeEnv
from no_travel_time_cue_reward_time_and_mag import noTravelAgent
from no_travel_time_cue_reward_time_and_mag import particle_noTravelAgent
from no_travel_time_cue_reward_time_and_mag import normalAgent
from no_travel_time_cue_reward_time_and_mag import noTravelAgent_1D_time_dist
from no_travel_time_cue_reward_time_and_mag import noTravelAgent_1D_magnitude_dist
from no_travel_time_cue_reward_time_and_mag import particle_time_only_agent, particle_mag_only_agent
# from no_travel_time_cue_reward_time_and_mag import train_agent
from no_travel_time_cue_reward_time_and_mag import plot_reward_magnitude_time_matrices, plot_rewards, generate_reward_magnitude_time_matrices
from gridworld_group_plots import comparison_plots, across_stimuli_plots

from no_travel_time_group_plots import set_train_test_env




# Output location: files for this experiment are written to paths that start
# with plot_folder followed by experiment_name (which ends with '/').
plot_folder = ''
experiment_name = 'risky_patch_0_cues_2_p_02/'
saveFolder = f'{plot_folder}{experiment_name}'

# define a risk distortion on magnitudes
def magnitude_distortion(array):
    """Return the risk-distorted value of each reward magnitude.

    Applies ``(10 * m) ** (1/4)`` element-wise.

    Args:
        array: Scalar or array-like (list, tuple, ndarray) of magnitudes.

    Returns:
        NumPy array (NumPy scalar for scalar input) with the same shape as
        the input.
    """
    # Convert first: multiplying a plain Python list by 10 would repeat the
    # list ten times instead of scaling its elements.
    magnitudes = np.asarray(array)
    return np.power(magnitudes * 10, 1.0 / 4.0)

# Environment parameters

# fix the global NumPy seed so runs are repeatable
np.random.seed(42)
# max reward delay and max reward magnitude (sizes of the last two TMD axes)
max_reward_delay, max_reward_magnitude = 5, 7
# risk weights for TMD: one distorted value per magnitude index
risk_weight = magnitude_distortion(np.arange(max_reward_magnitude))
print(risk_weight)



@dataclass
class TrainConfig:
    """Parameters for one training experiment.

    Where the name on the right-hand side of a default is not defined earlier
    in this class body (max_reward_delay, max_reward_magnitude, saveFolder,
    risk_weight), the default is taken from the module-level variable of the
    same name at class-definition time.
    """
    num_runs: int = 2 # number of runs
    num_states: int = 3 # number of states
    # Default is evaluated once in the class body (so it is 3); it does not
    # follow a num_states value passed to the constructor.
    num_cues: int = num_states
    max_reward_delay: int = max_reward_delay # max reward delay
    max_reward_magnitude: int = max_reward_magnitude # max reward magnitude
    alpha: float = 0.01 # handed to the agent constructor; presumably a learning rate -- TODO confirm
    # No type annotation, so @dataclass does not make this a field: it is a
    # class attribute shared by all instances, is not an __init__ argument,
    # and is left out of asdict().
    cue_probs = 0.1 * np.ones(num_states) # cue probabilities at each time
    train_timesteps: int = 10000 # number of time steps for training or acting
    test_timesteps: int = 1000 # number of time steps in each test episode
    test_every_n: int = 100 # test reward rate every time steps
    saveFolder: str = saveFolder # path prefix for all saved outputs
    # Also unannotated: a class attribute, not a dataclass field (see cue_probs).
    risk_weight = risk_weight




def generate_reward_magnitude_time_matrices_variable_mag(num_cues, num_states, max_reward_delay, max_reward_magnitude):
    """Build the reward time/magnitude matrices for the risky-patch setup.

    The result is indexed ``[cue, state, delay, magnitude]``; each entry is
    the probability of that magnitude at that delay. Entries written:

    * cue 0, state 0: at one randomly drawn delay (>= 1), magnitude 1 and
      magnitude ``max_reward_magnitude - 1`` share the probability mass
      (the variable-magnitude option).
    * cue 0, next state: magnitude ``max_reward_magnitude // 2`` with
      probability 1 at that same delay.
    * last cue, last state: magnitude ``max_reward_magnitude // 2 + 1`` at
      delay 1.
    * second-to-last cue and state: magnitude
      ``max_reward_magnitude // 2 + 2`` at delay 2.

    Everything else is zero. Exactly one value is drawn from NumPy's global
    random generator.

    Returns:
        ndarray of shape (num_cues, num_states, max_reward_delay,
        max_reward_magnitude).
    """
    tmd = np.zeros((num_cues, num_states, max_reward_delay, max_reward_magnitude))
    middle_mag = max_reward_magnitude // 2

    # variable-magnitude option: two magnitudes at one shared random delay
    risky_cue = 0
    risky_state = risky_cue % num_states  # ensure that the cue is associated with a state
    delay = np.random.randint(1, max_reward_delay)
    for mag in (1, max_reward_magnitude - 1):
        tmd[risky_cue, risky_state, delay, mag] = 1
    # normalize in place so this cue/state distribution sums to one
    risky_block = tmd[risky_cue, risky_state]
    risky_block /= risky_block.sum()

    # certain middle magnitude in the neighbouring state, same delay
    tmd[risky_cue, (risky_state + 1) % num_states, delay, middle_mag] = 1

    # two further certain rewards at fixed delays
    tmd[-1, -1, 1, middle_mag + 1] = 1
    tmd[-2, -2, 2, middle_mag + 2] = 1
    return tmd


class noRisk_patch_agent(noTravelAgent):
    """noTravelAgent variant that is always constructed without a risk weight.

    The constructor keeps the same parameter list as the other agents so it
    can be called interchangeably, but the ``risk_weight`` argument is
    ignored.
    """
    def __init__(self, num_states, init_state, reward_magnitude_time_matrices, alpha=0.1, risk_weight=None):
        # Forward everything to the parent except risk_weight, which is forced to None
        super().__init__(num_states, init_state, reward_magnitude_time_matrices, alpha, risk_weight=None)
        # Set again after the parent constructor -- presumably in case the
        # parent substitutes a default when given None; TODO confirm
        self.risk_weight = None


def _json_default(value):
    """Convert NumPy arrays and scalars to plain Python objects for json.dump."""
    if isinstance(value, (np.ndarray, np.generic)):
        return value.tolist()
    raise TypeError(f'Object of type {type(value).__name__} is not JSON serializable')


def _config_to_params(config):
    """Collect the config values to save, including non-field attributes."""
    params = asdict(config)
    # cue_probs and risk_weight are read from the config during training but
    # may be plain class attributes rather than dataclass fields, in which
    # case asdict() leaves them out.
    for name in ('cue_probs', 'risk_weight'):
        if name not in params and hasattr(config, name):
            params[name] = getattr(config, name)
    return params


def train_multiple_agents(generate_rew_mag_time_mat, config, agent_classes):
    """
    Train multiple agents based on the provided agent classes.

    Writes ``parameters.json``, the reward-matrix plot, one
    ``rewards_<agent>_<run>.npz`` per agent and run, and the average-reward
    plot, all under ``config.saveFolder``.

    Args:
        generate_rew_mag_time_mat (function): Function to generate reward magnitude time matrices.
            Called as f(num_cues, num_states, max_reward_delay, max_reward_magnitude).
        config (TrainConfig): Configuration object with training parameters.
        agent_classes (list): List of agent classes to initialize and train.

    Returns:
        dict: Maps each agent class name to a list with one entry per run,
        each entry being the test rewards returned by train_and_test.
    """
    num_states = config.num_states
    num_cues = config.num_cues
    max_reward_delay = config.max_reward_delay
    max_reward_magnitude = config.max_reward_magnitude
    saveFolder = config.saveFolder

    # Make sure the directory the files will land in exists. saveFolder is
    # used as a string prefix, so derive the directory from a full file path.
    params_path = saveFolder + "parameters.json"
    out_dir = os.path.dirname(params_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    # Save parameters
    with open(params_path, "w") as f:
        json.dump(_config_to_params(config), f, indent=2, default=_json_default)

    # Generate reward magnitude time matrices
    reward_magnitude_time_matrices = generate_rew_mag_time_mat(num_cues, num_states, max_reward_delay, max_reward_magnitude)
    print('reward_magnitude_time_matrices', reward_magnitude_time_matrices.shape)
    plot_reward_magnitude_time_matrices(saveFolder, reward_magnitude_time_matrices, 'reward_time_magnitude_matrices_per_state')

    # Initialize rewards storage
    rewards = {agent_class.__name__: [] for agent_class in agent_classes}
    test_rewards = {agent_class.__name__: [] for agent_class in agent_classes}

    # Train agents
    for run_num in range(config.num_runs):
        print('Run number:', run_num)
        for agent_class in agent_classes:
            agent_name = agent_class.__name__
            print(f'Training agent: {agent_name}')
            # Train the agent
            r, r_test = train_and_test(run_num, agent_class, reward_magnitude_time_matrices, config)
            # Store rewards
            rewards[agent_name].append(r)
            test_rewards[agent_name].append(r_test)
            # Save rewards (test rewards are stored averaged over their last axis)
            np.savez(saveFolder + f'rewards_{agent_name}_{run_num}.npz', rewards=r, test_rewards=np.mean(r_test,axis=-1))

    # Plot average rewards
    plot_average_rewards(saveFolder, test_rewards)
    return test_rewards


def train_and_test(run_num, agent_class, reward_magnitude_time_matrices, config):
    """Train one freshly created agent and test it periodically.

    Args:
        run_num (int): Index of the run; only used in output file names.
        agent_class: Agent class, called as
            agent_class(num_states, init_state, init_reward_MT, alpha, risk_weight).
        reward_magnitude_time_matrices (ndarray): True reward matrices used to
            build the environments.
        config (TrainConfig): Supplies cue_probs, num_states, the timestep
            settings, alpha, risk_weight and saveFolder.

    Returns:
        tuple: (train_rew, test_rew) as returned by train_agent.
    """
    # Write plots to the folder named in the config (not the module-level
    # default) so they land next to the other outputs of the same experiment.
    saveFolder = config.saveFolder
    agent_name = agent_class.__name__

    # create environment
    env, test_env, state, cue, reward = set_train_test_env(reward_magnitude_time_matrices, config.cue_probs, config.num_states)
    # create agent; its reward matrices start at all zeros
    init_reward_MT = np.zeros(reward_magnitude_time_matrices.shape)
    agent = agent_class(config.num_states, state, init_reward_MT, config.alpha, config.risk_weight)
    train_rew, test_rew = train_agent(agent, env, test_env, config.train_timesteps, config.test_timesteps, config.test_every_n, state, cue, reward)
    print('Learned TMD ',agent.reward_MT)
    test_learned_TMD(agent.reward_MT, reward_magnitude_time_matrices, agent.magnitude_values)
    plot_rewards(saveFolder, np.mean(test_rew, axis=-1), 'test_rewards_' + agent_name + '_' + str(run_num))
    # the learned matrices are only plotted for these agent types
    if agent_name in ('magnitude_risk_agent', 'noTravelAgent'):
        plot_reward_magnitude_time_matrices(saveFolder, agent.reward_MT, 'learned_reward_time_magnitude_matrices_per_state'+'_'+str(run_num))
    return train_rew, test_rew


# plot average rewards from multiple runs
def plot_average_rewards(saveFolder, rewards_dict, title=None):
    """Plot the mean reward curve per agent with a +/- one std band and save it.

    Args:
        saveFolder (str): Path prefix for the output file.
        rewards_dict (dict): Maps an agent name to an array-like of rewards.
            The last axis is averaged first; mean and std are then taken
            over axis 0 (presumably the runs -- verify against caller).
        title (str, optional): File name without extension; defaults to
            'average_rewards'. The title shown on the plot is fixed.

    Returns:
        None
    """
    if title is None:
        title = 'average_rewards'
    fig = plt.figure()
    for agent_name, rewards_full in rewards_dict.items():
        rewards = np.mean(rewards_full, axis=-1)
        # compute the statistics once and reuse them for the line and the band
        mean_curve = np.mean(rewards, axis=0)
        std_curve = np.std(rewards, axis=0)
        plt.plot(mean_curve, label=agent_name)
        plt.fill_between(range(len(mean_curve)), mean_curve - std_curve, mean_curve + std_curve, alpha=0.2)
    plt.legend()
    plt.xlabel('time steps')
    plt.ylabel('reward')
    plt.title('Average rewards over time')
    plt.savefig(saveFolder + title + '.png')
    # release the figure so repeated calls do not accumulate open figures
    plt.close(fig)

# plot average risk-distorted rewards from multiple runs
def plot_average_risk(saveFolder, rewards_dict, risk_function):
    """Plot the mean of ``risk_function(r) * r`` per agent and save the figure.

    Args:
        saveFolder (str): Path prefix; the plot is saved as
            'average_distorted_rewards.png' under it.
        rewards_dict (dict): Maps an agent name to an array-like of rewards
            (for example a list with one array per run).
        risk_function (callable): Applied element-wise to the reward array;
            its result multiplies the rewards before averaging.

    Returns:
        None
    """
    fig = plt.figure()
    for agent_name, rewards_full in rewards_dict.items():
        # Convert to one array first: a list of per-run arrays cannot be
        # scaled or multiplied element-wise.
        rewards_arr = np.asarray(rewards_full)
        rewards = np.mean(risk_function(rewards_arr) * rewards_arr, axis=-1)
        mean_curve = np.mean(rewards, axis=0)
        std_curve = np.std(rewards, axis=0)
        plt.plot(mean_curve, label=agent_name)
        plt.fill_between(range(len(mean_curve)), mean_curve - std_curve, mean_curve + std_curve, alpha=0.2)
    plt.legend()
    plt.xlabel('time steps')
    plt.ylabel('reward')
    plt.title('Average rewards over time')
    plt.savefig(saveFolder + 'average_distorted_rewards.png')
    # release the figure so repeated calls do not accumulate open figures
    plt.close(fig)

def plot_reward_histograms(saveFolder, rewards_dict, last_n=1000):
    """
    Overlay one histogram of positive rewards per agent and save the figure.

    For each agent the reward array is cut to its last `last_n` entries along
    the second axis, flattened, and reduced to values greater than zero.

    Args:
        saveFolder (str): Path prefix for the saved plot.
        rewards_dict (dict): Agent name -> reward array (at least 2-D once
            converted with np.array).
        last_n (int): Number of trailing entries along the second axis to keep.

    Returns:
        None
    """
    plt.figure()
    for label in rewards_dict:
        # keep the tail of the second axis, then collapse to one dimension
        tail = np.array(rewards_dict[label])[:, -last_n:].reshape(-1)
        # zero (and negative) rewards are left out of the histogram
        positive = tail[tail > 0]
        plt.hist(positive, bins=20, alpha=0.5, label=label)

    plt.xlabel('Reward')
    plt.ylabel('Frequency')
    plt.title(f'Reward Histograms (Last {last_n} Timesteps)')
    plt.legend()
    out_path = saveFolder + f'reward_histograms_last_{last_n}.png'
    plt.savefig(out_path)
    plt.close()


def get_new_train_test_env(num_states, reward_magnitude_time_matrices, cue_probs):
    """Create a training and a test environment from the same arguments.

    Only the training environment is reset here; its initial state and cue
    are returned together with an initial reward of 0.

    Returns:
        tuple: (env, test_env, state, cue, reward)
    """
    env_args = (num_states, reward_magnitude_time_matrices, cue_probs)
    # training environment first, then the separate test environment
    env = noTravelTimeEnv(*env_args)
    test_env = noTravelTimeEnv(*env_args)
    state, cue = env.reset()
    return env, test_env, state, cue, 0


def test_agent(agent, env, num_timesteps):
    """Run the agent in test mode for ``num_timesteps`` steps and record rewards.

    The environment is reset first and the agent is put into its test state
    via ``agent.reset_test``. Each step asks ``agent.test_act`` for an action
    and feeds it to ``env.step``.

    Returns:
        ndarray of length ``num_timesteps`` with the reward of each step.
    """
    state, cue = env.reset()
    rewards = np.zeros(num_timesteps)
    agent.reset_test(state)

    for step in range(num_timesteps):
        action = agent.test_act(state, cue)
        # env.step returns (reward, cue, state); store the reward directly
        rewards[step], cue, state = env.step(action)

    return rewards

def train_agent(agent, env, test_env, num_timesteps, test_timesteps, test_every, init_state, init_cue, init_reward):
    """Train the agent online and test it every ``test_every`` steps.

    Args:
        agent: Object providing ``train(reward, state, cue)`` (returns the
            action) plus the methods used by test_agent.
        env: Training environment; ``step(action)`` returns (reward, cue, state).
        test_env: Separate environment used only for the periodic tests.
        num_timesteps (int): Number of training steps.
        test_timesteps (int): Length of each test episode.
        test_every (int): A test runs at every step that is a multiple of
            this value, starting with step 0.
        init_state, init_cue, init_reward: Inputs for the first training step.

    Returns:
        tuple: (rewards, test_rewards) where rewards has one entry per
        training step and test_rewards has one row per test episode.
    """
    state = init_state
    cue = init_cue
    reward = init_reward
    # track rewards
    rewards = np.zeros(num_timesteps)
    # Tests run at steps 0, test_every, 2*test_every, ..., i.e.
    # ceil(num_timesteps / test_every) times. Floor division would leave the
    # array one row short whenever num_timesteps is not a multiple of
    # test_every and the last test would raise IndexError.
    num_tests = (num_timesteps + test_every - 1) // test_every
    test_rewards = np.zeros((num_tests, test_timesteps))

    for step in range(num_timesteps):
        # get action from agent (the agent also receives the previous reward)
        action = agent.train(reward, state, cue)
        # take action in environment
        reward, cue, state = env.step(action)
        rewards[step] = reward
        # test agent every test_every time steps
        if step % test_every == 0:
            test_rewards[step // test_every] = test_agent(agent, test_env, test_timesteps)

    return rewards, test_rewards


def test_learned_TMD(learned_TMD, true_TMD, risk_weight):
    """Print the risk-weighted values and resulting choices for two TMDs.

    Each TMD is multiplied by ``risk_weight`` along its last axis (the weight
    is broadcast over three leading axes, so the TMDs are expected to be
    4-D) and summed over that axis. The choice is the argmax, over the
    second axis, of the maximum taken over the third axis.

    Nothing is returned; results are printed for the true TMD first and the
    learned TMD second.
    """
    def weighted_values(tmd):
        # collapse the last axis using the risk weights
        return np.sum(tmd * risk_weight[None, None, None, :], axis=-1)

    def best_choice(values):
        # best value over the last remaining axis, then argmax over the next one
        return np.argmax(np.max(values, axis=-1), axis=-1)

    print('risk weight of agent', risk_weight)

    true_risk = weighted_values(true_TMD)
    print('true risk choice ', best_choice(true_risk), true_risk)

    learned_risk = weighted_values(learned_TMD)
    print('learned risk choice ', best_choice(learned_risk), learned_risk)




if __name__ == '__main__':

    config = TrainConfig()

    # create folder if it doesn't already exist
    if os.path.exists(config.saveFolder):
        print('overwrite')
    else:
        os.makedirs(config.saveFolder)

    test_rewards = train_multiple_agents(generate_reward_magnitude_time_matrices_variable_mag, config, [noTravelAgent, noRisk_patch_agent])
    # Stack each agent's per-run results into one array so the risk function
    # is applied element-wise rather than to a Python list.
    test_rewards = {name: np.asarray(runs) for name, runs in test_rewards.items()}
    # plot_average_risk requires the risk function; use the module's distortion
    plot_average_risk(config.saveFolder, test_rewards, magnitude_distortion)
    plot_reward_histograms(config.saveFolder, test_rewards)


# if __name__ == '__main__':

#     group_folder = 'particle_patch_group_3'
#     experiments = [2,3,4,5]
#     comparison_plots(group_folder, experiments, num_runs, format='svg')
#     across_stimuli_plots(group_folder, experiments, num_runs, format='svg')



# if __name__ == '__main__':

#     group_folder = 'particle_patch_group_7'
#     experiments = ['3_p_005', '3_p_01', '3_p_02']#, '5_p_03']
#     comparison_plots(group_folder, experiments, num_runs, format='svg', opt=False, ablation=True)
#     across_stimuli_plots(group_folder, experiments, num_runs, format='svg', ablation=True)

