import numpy as np
import pandas as pd 

from matplotlib.pyplot import figure
import matplotlib.pyplot as plt
import seaborn as sns

import os
from os import path
from pathlib import Path
from sys import dont_write_bytecode, platform
import csv, re
from scipy.special import softmax


def load_results(path, load_type, param_names, filters=None):
    '''
    Walks a results tree and collects, for every run, the requested parameters
    and the requested data.

    A run is any directory below `path` that has no sub-directories. Each run
    must contain an 'args.csv' whose rows are read as (name, value) pairs; the
    values are kept as strings.

        path : root of results
        load_type : what to load for each run
            'trace'  : sum of the array stored in the run's file whose name contains
                       'trace.episode_returns' or 'trace.rewards' (if several files
                       match, the last one listed wins). Runs without such a file
                       are skipped with a warning.
            'params' : nothing is loaded; only the parameters are collected and the
                       returned data array is empty
            'state'  : contents of 'states.npy'
            anything else : contents of '<load_type>.npy', e.g. 'offline_eval_returns',
                       'offline_eval_steps', 'actions', 'states', 'td_errs'
        param_names : list of parameters to extract. The special name 'exp_val' is
            resolved to the float value of 'temp', 'epsilon', 'omega' or 'eta',
            depending on the run's 'exploration_strategy'. Any other name missing
            from args.csv yields None.
        filters : a dictionary of filters to apply. Maps a parameter name to the
            collection of values a run is allowed to have; only names that also
            appear in `param_names` are checked.

    returns : (np.array of per-run parameter lists, np.array of per-run data),
        both in the order the runs were visited

    raises : KeyError if 'exp_val' is requested and the run's args.csv lacks
        'exploration_strategy' or the matching value entry
    '''
    import warnings  # only needed to report incomplete runs

    def load_params(root):
        with open(os.path.join(root, 'args.csv'), mode='r') as infile:
            return {line[0]: line[1] for line in csv.reader(infile)}

    def exploration_value(params):
        # each exploration strategy stores its hyperparameter under its own name
        strategy = params['exploration_strategy']
        if 'softmax' in strategy:
            key = 'temp'
        elif strategy == 'epsilon-greedy':
            key = 'epsilon'
        elif strategy == 'mellowmax':
            key = 'omega'
        else:
            key = 'eta'
        return float(params[key])

    data = []
    all_params = []

    for root, subdirs, files in os.walk(path):
        if subdirs:
            continue  # only leaf directories hold a run

        params = load_params(root)
        params_run = []  # relevant parameters for a single run
        excluded = False

        for p in param_names:
            if p == 'exp_val':
                param = exploration_value(params)
            else:
                param = params.get(p, None)
            params_run.append(param)

            if filters is not None:
                allowed_values = filters.get(p)
                if allowed_values is not None and param not in allowed_values:
                    excluded = True

        if excluded:
            continue

        if load_type == 'params':
            # nothing to read from disk, but the parameters are still part of the result
            all_params.append(params_run)
            continue

        if load_type == 'trace':
            run_data = None
            for fname in files:
                if 'trace.episode_returns' in fname or 'trace.rewards' in fname:
                    # NOTE: allow_pickle=True unpickles arbitrary objects; only point
                    # this function at result files you produced yourself
                    run_data = np.sum(np.load(os.path.join(root, fname), allow_pickle=True))
            if run_data is None:
                # without this check the previous run's value would be appended again
                warnings.warn('no trace file found in {}; run skipped'.format(root))
                continue
        elif load_type == 'state':
            run_data = np.load(os.path.join(root, 'states.npy'))
        else:
            run_data = np.load(os.path.join(root, load_type + '.npy'))

        all_params.append(params_run)
        data.append(run_data)

    return np.array(all_params), np.array(data)


def random_argmax(a):
    '''
    argmax with random tie-breaking: returns one index, drawn uniformly with
    np.random.choice, among the positions of the flattened array that hold
    the maximum value
    '''
    is_max = a == a.max()
    tied_indices = np.flatnonzero(is_max)
    return np.random.choice(tied_indices)


def resmax(eta, q):
    """
    Return the resmax probability distribution over actions for a single state.

    Each action first receives 1 / (num_actions + (1/eta) * (max(q) - q[a])).
    One maximising action (ties broken at random by random_argmax) is then
    given whatever mass is left so that the result sums to one.

    arguments:
        eta: scalar whose reciprocal scales the gap to the best value
             (must be non-zero, the code divides by it)
        q: array of values for each action for a fixed state; flattened before use

    returns:
        p: 1-D array with the probability of each action
    """

    action_values = q.flatten()
    gap_to_best = np.max(action_values) - action_values
    greedy = random_argmax(action_values)

    p = 1 / (action_values.shape[0] + (1/eta)*gap_to_best)

    # the greedy action takes the remaining probability mass
    others = np.delete(p, [greedy])
    p[greedy] = 1 - np.sum(others)

    return p


def sensitivity(title :str, df : pd.DataFrame, ci=None, x='Exp Value', y='Total Return', hue='Exploration Technique', xlabel=r'$\epsilon, \eta, \tau, \omega$', legend=True):
    """
    This function draws `y` against `x` on a base-2 logarithmic x-axis, with one
    line per category of `hue`. Rows that share the same x and hue (e.g. different
    random seeds) are aggregated by seaborn's lineplot.

    title : str
        The title that will be shown at the top of the plot
    df : pd.DataFrame
        Data to plot. Every cell equal to 'epsilon-greedy' is replaced by its
        LaTeX form before plotting; the caller's frame is not modified
    ci : None
        Confidence interval passed through to seaborn (None disables it)
    x : str
        X-axis variable in df
    y : str
        y-axis variable in df
    hue : str
        categorical variable. Plots a different line for each
    xlabel : str
        label of the x-axis
    legend : bool
        whether to keep a legend on the plot

    Returns the matplotlib axes the lines were drawn on. Note that the seaborn
    font scale and style are changed globally.
    """

    df = df.replace('epsilon-greedy', r'$\epsilon$-greedy')
    sns.set(font_scale=1.6)
    sns.set_style(style='white')

    # one colour per line: size the palette by the column that is actually used as hue
    palette = sns.color_palette("tab10", len(df[hue].unique()))
    # NOTE(review): recent seaborn releases deprecate `ci` in favour of `errorbar`;
    # kept as-is so the installed version keeps working -- confirm before upgrading
    plot = sns.lineplot(x=x, y=y, hue=hue, data=df, ci=ci, palette=palette)

    plot.set_xscale('log', base=2)
    plot.set_xlabel(xlabel)
    plot.set_title(title)

    # remove every second label
    for ind, label in enumerate(plot.get_xticklabels()):
        label.set_visible(ind % 2 == 0)

    if legend:
        plot.legend(framealpha=0, title='')
    else:
        # lineplot adds its own legend when hue is given
        existing = plot.get_legend()
        if existing is not None:
            existing.remove()

    plot.spines['top'].set_visible(False)
    plot.spines['right'].set_visible(False)

    return plot


def figure_1():
    '''
    Plots the probability of the first action in a two-action state as a function
    of the value gap q(a_1)-q(a_2), for softmax (solid lines) and resmax (dashed
    lines) at several parameter settings, and saves it to figures/overemphasis.pdf
    '''
    plt.clf()
    plt.rcParams.update({'font.size': 14})

    figure(figsize=(5, 3), dpi=80)

    gaps = np.arange(0, 10, 0.01)
    settings = [2**k for k in range(-3, 5, 2)]
    palette = sns.color_palette("rocket",  len(settings)).as_hex()

    def softmax_curve(temperature):
        return [softmax((1/temperature)* np.array([gap, 0]))[0] for gap in gaps]

    def resmax_curve(eta):
        return [resmax(eta, np.array([gap, 0]))[0] for gap in gaps]

    # one colour per setting; only the solid line carries the setting as its label
    for colour, setting in zip(palette, settings):
        plt.plot(gaps, softmax_curve(setting), color=colour, label=setting)
        plt.plot(gaps, resmax_curve(setting), linestyle='dashed', color=colour)

    # the first setting is drawn once more so the legend gets a solid and a dashed entry
    plt.plot(gaps, softmax_curve(0.125), color=palette[0], label='softmax')
    plt.plot(gaps, resmax_curve(0.125), linestyle='dashed', color=palette[0], label='resmax')

    plt.xlabel(r'$q(a_1)-q(a_2)$')
    plt.ylabel(r'$\pi(a_1)$', rotation=90)

    # add two legends: the first four handles are the settings, the last two the line styles
    handles, labels = plt.gca().get_legend_handles_labels()
    settings_legend = plt.legend(handles[:4], labels[:4], loc='lower left', bbox_to_anchor=(-0., -0.55), framealpha=0, ncol=4, columnspacing=0.9, handletextpad=0.2)
    plt.legend(handles[4:6], labels[4:6], loc='lower left', bbox_to_anchor=(-0, -0.7), framealpha=0, ncol=4, handletextpad=0.2)
    # the second plt.legend call replaced the first one, so put it back by hand
    plt.gca().add_artist(settings_legend)

    axes = plt.gca()
    for side in ('top', 'right'):
        axes.spines[side].set_visible(False)
    plt.savefig("figures/overemphasis.pdf", bbox_inches='tight')


def _plot_sensitivity_curves(results_root, save_dir, return_scale=None, **sensitivity_kwargs):
    '''
    Shared body of the sensitivity figures: loads the summed traces below
    `results_root` and writes one sensitivity plot per (step size, environment,
    algorithm) combination that has at least one run.

        results_root : root directory handed to load_results
        save_dir : pathlib.Path of the output directory, created if missing
        return_scale : if not None, 'Total Return' is cast to float and multiplied by it
        sensitivity_kwargs : extra keyword arguments forwarded to sensitivity()

    Files are named '<first 8 chars of env>_<algorithm>_step_size_<step size>.pdf'.
    '''
    save_dir.mkdir(parents=True, exist_ok=True)

    params, data = load_results(results_root, 'trace', ['env_name', 'algorithm', 'exploration_strategy', 'exp_val', 'step_size'],
        filters={'exploration_strategy' : ['resmax', 'softmax', 'epsilon-greedy', 'mellowmax']})

    # params and data are aligned row by row, so a positional concat is enough
    df = pd.concat([pd.DataFrame(params), pd.DataFrame(data)], axis=1)
    df.columns = ['Environment', 'Algorithm', 'Exploration Technique', 'Exp Value',  'Step Size', 'Total Return']
    df['Exp Value'] = df['Exp Value'].astype(float)
    df['Step Size'] = df['Step Size'].astype(float)
    if return_scale is not None:
        df['Total Return'] = df['Total Return'].astype(float) * return_scale

    # groupby only yields combinations that actually occur, so no empty frame is plotted
    for (step_size, env, alg), run_df in df.groupby(['Step Size', 'Environment', 'Algorithm'], sort=False):
        sensitivity("", run_df, ci=68, legend=False, **sensitivity_kwargs)
        plt.savefig(save_dir/("%s_%s_step_size_%0.3f.pdf" % (env[0:8], alg, step_size)), bbox_inches = "tight")
        plt.clf()


def figure_2_b():
    '''
    Sensitivity curves for the results in results/results_hardsquare,
    saved below figures/hardsquare
    '''
    plt.clf()
    _plot_sensitivity_curves('results/results_hardsquare', Path('figures')/'hardsquare',
        xlabel=r'$\eta, \tau, \omega$')


def figure_3_a():
    '''
    Sensitivity curves for the results in results/results_riverswim,
    saved below figures/riverswim
    '''
    # since averaged results are only saved every 1000 steps
    _plot_sensitivity_curves('results/results_riverswim', Path('figures')/'riverswim',
        return_scale=1000)


def figure_3_b():
    '''
    Sensitivity curves for the results in results/results_stochastic_rs,
    saved below figures/stochastic_rs
    '''
    # since averaged results are only saved every 1000 steps
    _plot_sensitivity_curves('results/results_stochastic_rs', Path('figures')/'stochastic_rs',
        return_scale=1000)

def figure_4():
    '''
    Bar chart of the mean execution time (error bars: standard deviation) of each
    exploration strategy, read from every 'log_file' below the riverswim and
    stochastic riverswim results. Saved to figures/runtime.pdf

    Log files that lack an execution time or an exploration strategy, and
    strategies other than the four listed below, are ignored.
    '''
    import matplotlib

    execution_times = {
        'mellowmax' : [],
        'softmax' : [],
        'resmax' : [],
        'epsilon-greedy' : []
        }

    # raw strings: '\(' and '\.' are not valid escapes in a normal string literal
    time_pattern = re.compile(r'Execution Time \(s\): ([0-9]*\.[0-9]*)')
    # note that '|' inside the character class is a literal pipe, not an alternation
    strategy_pattern = re.compile(r"'exploration_strategy': '([a-z|-]*)'")

    def add_times(root):
        '''
        Adds results to execution times
        '''
        for log_path in root.rglob('log_file'):
            with open(log_path, 'r') as log:
                text = log.read()

            # only the first occurrence of each field is used
            time_match = time_pattern.search(text)
            strategy_match = strategy_pattern.search(text)
            if time_match is None or strategy_match is None:
                continue  # the log lacks one of the two fields

            exp_strat = strategy_match.group(1)
            if exp_strat in execution_times:
                execution_times[exp_strat].append(float(time_match.group(1)))

    add_times(Path('results/results_riverswim'))
    add_times(Path('results/results_stochastic_rs'))

    labels = list(execution_times)
    means = [np.mean(execution_times[name]) for name in labels]
    errors = [np.std(execution_times[name]) for name in labels]

    # palette index used for each strategy, in the order of `labels`
    palette = sns.color_palette("tab10", 4)
    order = [2, 0 , 1, 3]
    colors = [palette[i] for i in order]

    plt.figure(figsize=(4, 2))
    matplotlib.rcParams.update({'font.size': 8})
    plt.bar(range(len(labels)), means, yerr=errors,  capsize=10, color=colors)
    plt.gca().spines['top'].set_visible(False)
    plt.gca().spines['right'].set_visible(False)
    plt.ylabel('Runtime (s)')
    plt.tick_params(
        axis='x',          # changes apply to the x-axis
        which='both',      # both major and minor ticks are affected
        bottom=False,      # ticks along the bottom edge are off
        top=False,         # ticks along the top edge are off
        labelbottom=False)

    # the x tick labels are hidden, so the strategies are named in the legend instead
    handles = [plt.Rectangle((0,1),1,1, color=color) for color in colors]
    plt.legend(handles, labels, frameon=False)
    plt.savefig('figures/runtime.pdf', bbox_inches='tight')


def figure_5_a():
    '''
    Plots, per episode, how often states 0 and 5 are visited by the expected-sarsa
    runs stored in results/results_riverswim, averaged over runs and smoothed with
    a 10-episode moving average. State 0 is drawn faded and without a legend entry.
    epsilon-greedy runs are left out and resmax is only shown for the value listed
    per environment below. One pdf per environment is written to figures/.
    '''
    import matplotlib

    def smooth(values, window):
        kernel = np.ones(window)
        return np.convolve(values, kernel, 'valid') / window

    def visit_frequency(visited, ep_len=20, state=5):
        # one row per block of ep_len consecutive steps, True where `state` was visited
        hits = np.reshape(visited==state, (len(visited)//ep_len, ep_len ))
        return np.average(hits, axis=1)

    # load data
    params, data = load_results('results/results_riverswim', 'state', ['env_name', 'algorithm', 'exploration_strategy', 'exp_val', 'step_size'], filters={'algorithm':['expected-sarsa']})
    params_df = pd.DataFrame(params)

    # each cell of this column holds the whole array of one run
    states_df = pd.DataFrame(columns=['States'])
    for run_index, run_states in enumerate(data):
        states_df.loc[run_index, 'States'] = run_states

    df = pd.concat([params_df, states_df], axis=1)
    df.columns = ['Environment', 'Algorithm', 'Exploration Technique', 'Exp Value',  'Step Size', 'States Visited']
    for column in ('Exp Value', 'Step Size'):
        df[column] = df[column].astype(float)

    palette = sns.color_palette("tab10", 2)

    best_resmax = { # the optimal hyperparameters for resmax in each env
        'gym_riverswim:riverswim-v0' : 1/16,
        'riverswim_variants:stochastic-riverswim-v0' : 1/32
    }

    colors = {
        'resmax' : palette[1],
        'softmax' : palette[0]
    }

    plt.figure(figsize=(5, 3))
    matplotlib.rcParams.update({'font.size': 16})

    for env in df["Environment"].unique():
        env_df = df[df["Environment"]==env]

        for alg in env_df["Algorithm"].unique():
            alg_df = env_df[env_df["Algorithm"]==alg]

            for exp_tech in alg_df['Exploration Technique'].unique():
                if exp_tech == 'epsilon-greedy':
                    continue
                exp_df = alg_df[alg_df['Exploration Technique'] == exp_tech]

                for exp_val in exp_df['Exp Value'].unique():
                    # resmax is only drawn at the value listed for this environment
                    if exp_tech == 'resmax' and exp_val != best_resmax[env]:
                        continue

                    plot_df = exp_df[exp_df['Exp Value'] == exp_val]

                    for s in [0, 5]:
                        per_run = plot_df.apply(lambda row : visit_frequency(row['States Visited'], state=s), axis=1)
                        avg_state_freq = np.mean(np.stack(per_run.values, axis=0), axis=0)

                        # state 0 is faded and unlabeled, state 5 is opaque and labeled
                        line_style = {'c': colors[exp_tech], 'alpha': 0.3 if s == 0 else 1}
                        if s != 0:
                            line_style['label'] = '{}'.format(exp_tech)
                        plt.plot(smooth(avg_state_freq, 10), **line_style)

            plt.legend(frameon=False)
            axes = plt.gca()
            axes.spines['top'].set_visible(False)
            axes.spines['right'].set_visible(False)
            plt.xlabel('Episode')
            plt.ylabel('Frequency of State')
            plt.savefig('figures/{}_state_visit.pdf'.format(env), bbox_inches='tight')
            plt.clf()


def figure_5_b():
    '''
    For every algorithm found in results/results_policy, plots the 'policy' data
    of the softmax and resmax runs averaged over runs, and saves it to
    figures/policy_curves/policy_<algorithm>.pdf
    '''
    save_dir = Path('figures')/'policy_curves'
    if not save_dir.exists():
        os.makedirs(save_dir)

    # DF for normalized runs
    params, data = load_results('results/results_policy', 'policy', ['env_name', 'algorithm', 'exploration_strategy', 'exp_val', 'step_size'], filters={'exploration_strategy': ['epsilon-greedy', 'softmax', 'resmax']})

    plt.rcParams.update({'font.size': 30})

    figure(figsize=(10, 6), dpi=80)

    colours_map = {
        'epsilon-greedy' : '#2ca02c',
        'resmax' : '#ff7f0e',
        'softmax' : '#1f77b4'
        }

    # columns follow the order of the parameter names requested above
    algorithms = params[:, 1]
    strategies = params[:, 2]

    for alg in np.unique(algorithms):
        # the figure is cleared after every algorithm, so the labels are set again each time
        plt.xlabel('Episode')
        plt.ylabel('Prob. Optimal Action')

        for exp_tech in ('softmax', 'resmax'):
            rows = np.flatnonzero((algorithms == alg) & (strategies == exp_tech))
            mean_curve = np.mean(data[rows], axis=0)
            episodes = np.arange(mean_curve.shape[0])
            plt.plot(episodes, mean_curve, label=exp_tech, c=colours_map[exp_tech])

        axes = plt.gca()
        axes.spines['top'].set_visible(False)
        axes.spines['right'].set_visible(False)
        plt.legend(framealpha=0)
        plt.savefig(save_dir/'policy_{}.pdf'.format(alg), bbox_inches='tight')
        plt.clf()


def figure_6():
    '''
    Runs the tabular agent on the two-state MDP and plots a 10-step moving average
    of the two entries of agent.q[0, 0, :] recorded after every step, averaged
    over runs.

    NOTE: Since this function runs the MDP, it will take quite some time
    '''
    from agent.agents import TwoStateTabularAgent
    import gym

    def moving_average(x, w):
        return np.convolve(x, np.ones(w), 'valid') / w

    NUM_EPISODES = 10000000
    NUM_RUNS = 1
    EXP_VALUES = {
            'softmax': [1/16.55],
            'resmax':  [0.000085],
            }

    agent_params: dict = {
            'algorithm': 'expected-sarsa',
            'ac_dim': 2,
            'ob_dim': 2,
            'step_size': None,
            'input_shape': None, 
            'gamma': 0.98,
            'num_timesteps': NUM_EPISODES,
            'save_policy':  False,
            'horizon' : False,
            'initial_optimism' : 0}

    # per strategy: (exploration value, step, action) -> action-values summed over runs
    qs = {
        'softmax' : np.zeros((len(EXP_VALUES['softmax']), NUM_EPISODES, 2)),
        'resmax' : np.zeros((len(EXP_VALUES['resmax']), NUM_EPISODES, 2)),
    }

    for exp in EXP_VALUES:
        if exp == 'softmax':
            continue  # softmax is not simulated; its entries in qs stay zero
        for i_v, v in enumerate(EXP_VALUES[exp]):
            agent_params['exploration_strategy'] = exp
            agent_params['exp_value'] = v

            for n in range(NUM_RUNS):
                env = gym.make('gym_exploration:TwoState-v0')
                env.reset()
                agent = TwoStateTabularAgent(env, agent_params)
                agent.start()
                for i in range(NUM_EPISODES):
                    agent.step_env()
                    qs[exp][i_v, i, :] += agent.q[0, 0, :]       

    # Averaging and plotting happen once, after all simulations. Nested inside the
    # loop above they would divide by NUM_RUNS and draw every curve again for each
    # simulated strategy.
    # NOTE(review): strategies that were skipped are still drawn, as all-zero
    # curves -- confirm this is wanted
    for strategy in qs:
        qs[strategy] /= NUM_RUNS
        for val in range(qs[strategy].shape[0]):
            plt.plot(moving_average(qs[strategy][val, :, 0], 10), label='Q(s_1, a)')
            plt.plot(moving_average(qs[strategy][val, :, 1], 10), label='Q(s_1, b)')

    plt.xlabel('Iteration')
    plt.ylabel('Action-value')
    plt.legend()
    # `exp` is the last strategy name of EXP_VALUES here
    plt.savefig('{}-{}-{}.pdf'.format(agent_params['algorithm'], exp, agent_params['step_size']), bbox_inches='tight')
    plt.clf()


if __name__ == "__main__":

    # several figure functions save directly into ./figures, so create it first
    figures_root = Path('figures')
    if not figures_root.exists():
        os.makedirs(figures_root)

    # regenerate every figure, in order
    for make_figure in (
        figure_1,
        figure_2_b,
        figure_3_a,
        figure_3_b,
        figure_4,
        figure_5_a,
        figure_5_b,
        figure_6,
    ):
        make_figure()