import pandas as pd
import numpy as np
import plotly.graph_objects as go
import scipy.stats as stats
from tqdm import tqdm
import copy

from src.rl_main import ReinforcementLearning


class RLExperiments:
    def __init__(self):
        """Create an experiment runner with no active experiment."""
        # Populated by start_experiment(); stays None until then.
        self.rl = None

    def start_experiment(self, agent, env, state_representation=None):
        """Build a new ``ReinforcementLearning`` object and store it on ``self.rl``.

        Args:
            agent: Passed through as the first constructor argument.
            env: Passed through as the second constructor argument.
            state_representation: Passed through as the third constructor
                argument; defaults to ``None``.
        """
        experiment_driver = ReinforcementLearning(agent, env, state_representation)
        self.rl = experiment_driver

    def pendulum(self, agent, env, state_representation, num_runs, max_steps, step_size):
        """Run the pendulum experiment for ``num_runs`` runs and gather the results.

        For every run a new experiment is started via ``start_experiment``,
        NumPy's global RNG is seeded with the zero-based run index, ``rl_start``
        is called once and ``rl_step`` is then called ``max_steps`` times.

        Args:
            agent: Agent passed to ``start_experiment``.
            env: Environment passed to ``start_experiment``.
            state_representation: State representation passed to
                ``start_experiment``.
            num_runs: Number of runs to perform.
            max_steps: Number of ``rl_step`` calls per run.
            step_size: Mapping with the keys ``'value'``, ``'policy'``,
                ``'avg_reward'`` and ``'var'``. ``'value'`` is either a number
                or the string ``'1/n'`` (step size ``1 / (step + 1)``); the
                other three entries are multipliers applied to that base
                step size.

        Returns:
            pandas.DataFrame: the frames returned by ``self.rl.get_data()``
            for every run, stacked with a fresh index and a 1-based ``'run'``
            column. An empty DataFrame when ``num_runs`` is 0.
        """
        decaying = step_size['value'] == '1/n'

        # Collect one frame per run and concatenate once at the end; calling
        # pd.concat inside the loop re-copies everything gathered so far.
        per_run_results = []

        for run in tqdm(range(num_runs)):
            self.start_experiment(agent, env, state_representation)
            np.random.seed(run)

            last_state, last_action = self.rl.rl_start()

            for step_n in range(max_steps):
                # A new dict is built for every step so rl_step never sees an
                # object shared with a previous call.
                if decaying:
                    use_step_size = {
                        'value': 1 / (step_n + 1),
                        'policy': step_size['policy'] * 1 / (step_n + 1),
                        'avg_reward': step_size['avg_reward'] * 1 / (step_n + 1),
                        'var': step_size['var'] * 1 / (step_n + 1),
                    }
                else:
                    use_step_size = {
                        'value': step_size['value'],
                        'policy': step_size['policy'] * step_size['value'],
                        'avg_reward': step_size['avg_reward'] * step_size['value'],
                        'var': step_size['var'] * step_size['value'],
                    }

                reward, state, action, terminal = self.rl.rl_step(last_state, last_action, step_size=use_step_size)

                last_state, last_action = state, action

            # get experiment data and tag it with its 1-based run number
            results_df = self.rl.get_data()
            results_df['run'] = run + 1
            per_run_results.append(results_df)

        if not per_run_results:
            # pd.concat raises on an empty list; keep returning an empty frame.
            return pd.DataFrame()

        return pd.concat(per_run_results, ignore_index=True)

    def redpillbluepill(self, agent, env, num_runs, max_steps, step_size, epsilon):
        """Run the red-pill/blue-pill experiment for ``num_runs`` runs and gather the results.

        For every run a new experiment is started via ``start_experiment``,
        NumPy's global RNG is seeded with the zero-based run index, ``rl_start``
        is called once and ``rl_step`` is then called ``max_steps`` times.

        Args:
            agent: Agent passed to ``start_experiment``.
            env: Environment passed to ``start_experiment``.
            num_runs: Number of runs to perform.
            max_steps: Number of ``rl_step`` calls per run.
            step_size: Mapping with the keys ``'value'``, ``'avg_reward'`` and
                ``'var'``. ``'value'`` is either a number or the string
                ``'1/n'`` (step size ``1 / (step + 1)``); the other two entries
                are multipliers applied to that base step size.
            epsilon: Forwarded unchanged to every ``rl_step`` call.

        Returns:
            pandas.DataFrame: the frames returned by ``self.rl.get_data()``
            for every run, stacked with a fresh index and a 1-based ``'run'``
            column. An empty DataFrame when ``num_runs`` is 0.
        """
        decaying = step_size['value'] == '1/n'

        # Collect one frame per run and concatenate once at the end; calling
        # pd.concat inside the loop re-copies everything gathered so far.
        per_run_results = []

        for run in tqdm(range(num_runs)):
            self.start_experiment(agent, env)
            np.random.seed(run)

            last_state, last_action = self.rl.rl_start()

            for step_n in range(max_steps):
                # A new dict is built for every step so rl_step never sees an
                # object shared with a previous call.
                if decaying:
                    use_step_size = {
                        'value': 1 / (step_n + 1),
                        'avg_reward': step_size['avg_reward'] * 1 / (step_n + 1),
                        'var': step_size['var'] * 1 / (step_n + 1),
                    }
                else:
                    use_step_size = {
                        'value': step_size['value'],
                        'avg_reward': step_size['avg_reward'] * step_size['value'],
                        'var': step_size['var'] * step_size['value'],
                    }

                reward, state, action, terminal = self.rl.rl_step(
                    last_state, last_action, step_size=use_step_size, epsilon=epsilon)

                last_state, last_action = state, action

            # get experiment data and tag it with its 1-based run number
            results_df = self.rl.get_data()
            results_df['run'] = run + 1
            per_run_results.append(results_df)

        if not per_run_results:
            # pd.concat raises on an empty list; keep returning an empty frame.
            return pd.DataFrame()

        return pd.concat(per_run_results, ignore_index=True)

    def get_performance_figure(self, experiment, df_cvar, df_reg, quantile, rolling_average_amount=1000,
                               x_max=100000, confidence_interval=0.95):
        """Plot rolling reward CVaR and rolling average reward for two result sets.

        For each run, a window of ``rolling_average_amount`` rewards ending just
        before step ``i`` is used to compute the window mean and the mean of
        all rewards at or below the window's ``quantile`` quantile (the CVaR).
        The per-step mean over runs is drawn together with a normal-approximation
        confidence band. The figure is shown with ``fig.show()``.

        Args:
            experiment: Not used by this method.
            df_cvar: DataFrame with ``'run'`` and ``'reward'`` columns, drawn
                under the label ``'RED CVaR'``. Its unique ``'run'`` values
                define the runs used for both frames.
            df_reg: DataFrame with ``'run'`` and ``'reward'`` columns, drawn
                under the label ``'Differential'``.
            quantile: Quantile level used for the VaR threshold of each window.
            rolling_average_amount: Window length in rows.
            x_max: Upper bound of the x-axis range.
            confidence_interval: Confidence level of the shaded band.

        Returns:
            None.
        """
        fig = go.Figure()

        df_dict = {
            'Differential': {
                'df': df_reg,
                'color_cvar': '#007FA3',
                'color_average': '#2FD1FF',
                'color_cvar_ci': '#C5D3EE',
                'color_average_ci': '#BAF0FF',
            },
            'RED CVaR': {
                'df': df_cvar,
                'color_cvar': '#AB1368',
                'color_average': '#EC52A8',
                'color_cvar_ci': '#E9C9EF',
                'color_average_ci': '#F9C5E2',
            },
        }

        # two-sided normal critical value for the requested confidence level
        z_value = stats.norm.ppf(0.5 + confidence_interval / 2)
        ci_suffix = str(np.around(100 * confidence_interval, 0)).replace('.0', '') + '% CI'

        runs = df_cvar['run'].unique()
        sqrt_n_runs = np.sqrt(len(runs))

        for df_name, style in df_dict.items():
            source_df = style['df']
            cvar_by_run = {}
            avg_by_run = {}
            counter = []  # stays bound (and empty) when there are no runs

            for run_i, run_id in enumerate(tqdm(runs)):
                # Only the reward column is needed, so slice that Series
                # rather than the whole frame at every step.
                reward = source_df.loc[source_df['run'] == run_id, 'reward'].reset_index(drop=True)

                counter = list(range(rolling_average_amount, len(reward)))
                rolling_cvar = []
                rolling_average = []
                for i in counter:
                    window = reward.iloc[i - rolling_average_amount:i]
                    var_threshold = window.quantile(quantile)
                    # "<=" keeps the threshold sample in the tail, so the tail
                    # is never empty when the quantile equals the window minimum.
                    rolling_cvar.append(window[window <= var_threshold].mean())
                    rolling_average.append(window.mean())

                cvar_by_run['cvar_run_' + str(run_i)] = rolling_cvar
                avg_by_run['avg_run_' + str(run_i)] = rolling_average

            # Build each frame in one go instead of inserting columns one by one.
            cvar_runs = pd.DataFrame(cvar_by_run)
            avg_runs = pd.DataFrame(avg_by_run)

            cvar_mean = cvar_runs.mean(axis=1)
            cvar_std = cvar_runs.std(axis=1)
            avg_mean = avg_runs.mean(axis=1)
            avg_std = avg_runs.std(axis=1)

            cvar_half_width = z_value * (cvar_std / sqrt_n_runs)
            avg_half_width = z_value * (avg_std / sqrt_n_runs)

            # The upper bound is added first so the lower bound can fill up to
            # it with fill='tonexty'.
            fig.add_trace(go.Scatter(x=counter, y=cvar_mean + cvar_half_width, mode='lines',
                                     line=dict(color=style['color_cvar_ci'], width=0.5),
                                     showlegend=False))

            fig.add_trace(go.Scatter(x=counter, y=cvar_mean - cvar_half_width, mode='lines',
                                     name=df_name + ': Reward CVaR ' + ci_suffix,
                                     line=dict(color=style['color_cvar_ci'], width=0.5),
                                     fill='tonexty', fillcolor=style['color_cvar_ci']))

            fig.add_trace(go.Scatter(x=counter, y=cvar_mean, mode='lines', name=df_name + ': Reward CVaR',
                                     line=dict(color=style['color_cvar'], width=2)))

            fig.add_trace(go.Scatter(x=counter, y=avg_mean + avg_half_width, mode='lines',
                                     line=dict(color=style['color_average_ci'], width=0.5),
                                     showlegend=False))

            fig.add_trace(go.Scatter(x=counter, y=avg_mean - avg_half_width, mode='lines',
                                     name=df_name + ': Average Reward ' + ci_suffix,
                                     line=dict(color=style['color_average_ci'], width=0.5),
                                     fill='tonexty', fillcolor=style['color_average_ci']))

            fig.add_trace(go.Scatter(x=counter, y=avg_mean, mode='lines', name=df_name + ': Average Reward',
                                     line=dict(color=style['color_average'], width=2)))

        fig.update_xaxes(title='Time Step', range=[rolling_average_amount, x_max], linewidth=3, mirror=False,
                         ticks='outside', showline=True, linecolor='#262626', gridcolor='rgba(243,243,241, 0.75)', gridwidth=1)

        fig.update_yaxes(title='Reward', linewidth=3, mirror=False, ticks='outside', showline=True,
                         linecolor='#262626', gridcolor='rgba(243,243,241, 0.75)', gridwidth=1)

        fig.update_layout(template='plotly_white', height=500, width=800, font=dict(color='#3F3F3F', size=14, family='times'),
                          paper_bgcolor='rgba(0,0,0,0)', plot_bgcolor='rgba(0,0,0,0)')

        fig.show()
        return

    def get_3d_plot(self, agent, env, num_runs, max_steps, step_size, epsilon, plots):
        """Show a 3D scatter of the agent's ``var_reward``/``avg_reward`` over time.

        For every entry of ``plots`` the agent is deep-copied, its
        ``avg_reward`` and ``var_reward`` attributes are overwritten with the
        entry's initial values, and ``num_runs`` runs of ``max_steps`` steps
        are executed (NumPy's global RNG is seeded with the zero-based run
        index). After each step ``self.rl.agent.avg_reward`` and
        ``self.rl.agent.var_reward`` are recorded. The recorded lists are reset
        at the start of each run, so the trace shows the last run only.

        Args:
            agent: Template agent; it is deep-copied once per ``plots`` entry.
            env: Environment passed to ``start_experiment``.
            num_runs: Number of runs per ``plots`` entry.
            max_steps: Number of ``rl_step`` calls per run.
            step_size: Mapping with the keys ``'value'``, ``'avg_reward'`` and
                ``'var'``. ``'value'`` is either a number or the string
                ``'1/n'``; the other entries are multipliers of the base step
                size.
            epsilon: Forwarded unchanged to every ``rl_step`` call.
            plots: Mapping whose values are dicts with the keys
                ``'init_avg_reward'``, ``'init_var_reward'``, ``'z_max'``
                (number of leading steps to draw) and ``'color'``.

        Returns:
            None.
        """
        fig = go.Figure()

        decaying = step_size['value'] == '1/n'

        for plot_info in plots.values():
            use_agent = copy.deepcopy(agent)

            use_agent.avg_reward = plot_info['init_avg_reward']
            use_agent.var_reward = plot_info['init_var_reward']

            z_max = plot_info['z_max']
            color = plot_info['color']

            # Bound up front so an empty trace is drawn when num_runs is 0.
            avg_reward = []
            var_reward = []

            for run in range(num_runs):
                self.start_experiment(use_agent, env)
                np.random.seed(run)

                last_state, last_action = self.rl.rl_start()

                # restart the recorded trajectory for this run
                avg_reward = []
                var_reward = []
                for step_n in tqdm(range(max_steps)):
                    if decaying:
                        use_step_size = {
                            'value': 1 / (step_n + 1),
                            'avg_reward': step_size['avg_reward'] * 1 / (step_n + 1),
                            'var': step_size['var'] * 1 / (step_n + 1),
                        }
                    else:
                        use_step_size = {
                            'value': step_size['value'],
                            'avg_reward': step_size['avg_reward'] * step_size['value'],
                            'var': step_size['var'] * step_size['value'],
                        }

                    reward, state, action, terminal = self.rl.rl_step(
                        last_state, last_action, step_size=use_step_size, epsilon=epsilon)

                    last_state, last_action = state, action

                    avg_reward.append(self.rl.agent.avg_reward)
                    var_reward.append(self.rl.agent.var_reward)

            var_path = var_reward[:z_max]
            avg_path = avg_reward[:z_max]

            fig.add_trace(go.Scatter3d(
                x=var_path,
                y=avg_path,
                # z must have exactly as many entries as the truncated x/y
                z=np.arange(len(var_path)),
                mode='markers',
                marker=dict(
                    size=2,
                    opacity=0.75,
                    color=color,
                )
            ))

        axis_style = dict(linewidth=1, mirror=True, ticks='outside',
                          showline=True, linecolor='#E7E8E4', gridcolor='#E7E8E4')

        fig.update_layout(template='plotly_white', showlegend=False,
                          height=600, width=1200,
                          paper_bgcolor='rgba(0,0,0,0)', plot_bgcolor='rgba(0,0,0,0)',
                          scene=dict(
                              aspectratio_x=2, aspectratio_y=1, aspectratio_z=1,
                              xaxis=dict(title='VaR Estimate', **axis_style),
                              yaxis=dict(title='CVaR Estimate', **axis_style),
                              zaxis=dict(title='Time Step', **axis_style),
                          ),
                          font=dict(color='#3F3F3F', family='times'),
                          )

        fig.update_scenes(zaxis_autorange="reversed")

        fig.show()
        return

    def cvar_redpillbluepill_comparison(self, experiment, agent, env, num_runs, max_steps, step_size, epsilon):
        """Run the red-pill/blue-pill task and plot the agent's two running estimates.

        After every step ``self.rl.agent.avg_reward`` and
        ``self.rl.agent.var_reward`` are recorded; every 100th value is drawn,
        labelled 'CVaR Estimate' and 'VaR Estimate' respectively. The recorded
        lists are re-created for each run, so the figure reflects the last run.

        Args:
            experiment: Not used by this method.
            agent: Agent passed to ``start_experiment``.
            env: Environment passed to ``start_experiment``.
            num_runs: Number of runs; run ``k`` seeds NumPy's global RNG with ``k``.
            max_steps: Number of ``rl_step`` calls per run.
            step_size: Mapping with keys ``'value'`` (a number or ``'1/n'``),
                ``'avg_reward'`` and ``'var'`` (multipliers of the base step size).
            epsilon: Forwarded unchanged to every ``rl_step`` call.

        Returns:
            None.
        """
        for seed in range(num_runs):
            self.start_experiment(agent, env)
            np.random.seed(seed)

            prev_state, prev_action = self.rl.rl_start()

            avg_reward, var_reward = [], []
            for t in tqdm(range(max_steps)):
                if step_size['value'] == '1/n':
                    n = t + 1
                    use_step_size = dict(
                        value=1 / n,
                        avg_reward=step_size['avg_reward'] / n,
                        var=step_size['var'] / n,
                    )
                else:
                    base = step_size['value']
                    use_step_size = dict(
                        value=base,
                        avg_reward=step_size['avg_reward'] * base,
                        var=step_size['var'] * base,
                    )

                reward, prev_state, prev_action, terminal = self.rl.rl_step(
                    prev_state, prev_action, step_size=use_step_size, epsilon=epsilon)

                avg_reward.append(self.rl.agent.avg_reward)
                var_reward.append(self.rl.agent.var_reward)

        # every 100th step, starting from step 1
        sampled_steps = list(range(1, max_steps + 1, 100))
        axis_kwargs = dict(linewidth=3, mirror=False, ticks='outside', showline=True,
                           linecolor='#262626', gridcolor='rgba(243,243,241, 0.75)', gridwidth=1)

        fig = go.Figure()
        for label, values, colour in (
                ('CVaR Estimate', avg_reward, '#AB1368'),
                ('VaR Estimate', var_reward, '#8DBF2E'),
        ):
            fig.add_trace(go.Scatter(x=list(sampled_steps), y=values[0::100], name=label,
                                     mode='lines', line=dict(color=colour, width=3)))

        fig.update_xaxes(title='Time Step', **axis_kwargs)
        fig.update_yaxes(title='Reward', **axis_kwargs)

        fig.update_layout(template='plotly_white', height=500, width=800,
                          font=dict(color='#3F3F3F', size=14, family='times'),
                          paper_bgcolor='rgba(0,0,0,0)', plot_bgcolor='rgba(0,0,0,0)')

        fig.show()
        return

    def cvar_pendulum_comparison(self, experiment, agent, env, state_representation, num_runs, max_steps, step_size):
        """Run the pendulum task and plot the agent's two running estimates.

        After every step ``self.rl.agent.avg_reward`` and
        ``self.rl.agent.var_reward`` are recorded; every 100th value is drawn,
        labelled 'CVaR Estimate' and 'VaR Estimate' respectively. The recorded
        lists are re-created for each run, so the figure reflects the last run.

        Args:
            experiment: Not used by this method.
            agent: Agent passed to ``start_experiment``.
            env: Environment passed to ``start_experiment``.
            state_representation: State representation passed to
                ``start_experiment``.
            num_runs: Number of runs; run ``k`` seeds NumPy's global RNG with ``k``.
            max_steps: Number of ``rl_step`` calls per run.
            step_size: Mapping with keys ``'value'`` (a number or ``'1/n'``),
                ``'policy'``, ``'avg_reward'`` and ``'var'`` (multipliers of
                the base step size).

        Returns:
            None.
        """
        for seed in range(num_runs):
            self.start_experiment(agent, env, state_representation)
            np.random.seed(seed)

            prev_state, prev_action = self.rl.rl_start()

            avg_reward, var_reward = [], []
            for t in tqdm(range(max_steps)):
                if step_size['value'] == '1/n':
                    n = t + 1
                    use_step_size = dict(
                        value=1 / n,
                        policy=step_size['policy'] / n,
                        avg_reward=step_size['avg_reward'] / n,
                        var=step_size['var'] / n,
                    )
                else:
                    base = step_size['value']
                    use_step_size = dict(
                        value=base,
                        policy=step_size['policy'] * base,
                        avg_reward=step_size['avg_reward'] * base,
                        var=step_size['var'] * base,
                    )

                reward, prev_state, prev_action, terminal = self.rl.rl_step(
                    prev_state, prev_action, step_size=use_step_size)

                avg_reward.append(self.rl.agent.avg_reward)
                var_reward.append(self.rl.agent.var_reward)

        # every 100th step, starting from step 1
        sampled_steps = list(range(1, max_steps + 1, 100))
        axis_kwargs = dict(linewidth=3, mirror=False, ticks='outside', showline=True,
                           linecolor='#262626', gridcolor='rgba(243,243,241, 0.75)', gridwidth=1)

        fig = go.Figure()
        for label, values, colour in (
                ('CVaR Estimate', avg_reward, '#AB1368'),
                ('VaR Estimate', var_reward, '#8DBF2E'),
        ):
            fig.add_trace(go.Scatter(x=list(sampled_steps), y=values[0::100], name=label,
                                     mode='lines', line=dict(color=colour, width=3)))

        fig.update_xaxes(title='Time Step', **axis_kwargs)
        fig.update_yaxes(title='Reward', **axis_kwargs)

        fig.update_layout(template='plotly_white', height=500, width=800,
                          font=dict(color='#3F3F3F', size=14, family='times'),
                          paper_bgcolor='rgba(0,0,0,0)', plot_bgcolor='rgba(0,0,0,0)')

        fig.show()
        return

    def get_cvar_by_tau_plot(self, n_samples=100000, epsillon=0.1):
        """Estimate by Monte Carlo, and plot, the CVaR of two fixed policies per tau.

        Rewards are drawn from the normal distributions hard-coded below. A
        "red" sample comes from ``dist_1`` with probability ``1 - epsillon``
        and otherwise from the ``dist_2a``/``dist_2b`` mixture; a "blue" sample
        comes from ``dist_1`` with probability ``epsillon`` and otherwise from
        the mixture. For every tau in ``np.arange(0.01, 1, 0.1)`` a new batch
        of ``n_samples`` rewards per policy is drawn from NumPy's global RNG,
        the tau-quantile (VaR) is taken, and the CVaR is the mean of the
        samples at or below it. The two CVaR curves are shown with
        ``fig.show()``.

        Args:
            n_samples: Number of reward samples per policy and per tau.
            epsillon: Probability of drawing from the policy's secondary
                distribution (parameter name kept as spelled for callers).

        Returns:
            None.
        """
        # Use distributions for Red-Pill Blue-Pill environment
        dist_1 = {'mean': -0.7, 'stdev': 0.05}
        dist_2a = {'mean': -1, 'stdev': 0.05}
        dist_2b = {'mean': -0.2, 'stdev': 0.05}
        dist_2_prob = 0.5

        def sample_rewards(from_dist_1):
            """Draw one reward per entry of the boolean mask ``from_dist_1``."""
            # Where the mask is False, pick dist_2a with probability dist_2_prob.
            from_dist_2a = np.random.rand(from_dist_1.size) < dist_2_prob
            means = np.where(from_dist_1, dist_1['mean'],
                             np.where(from_dist_2a, dist_2a['mean'], dist_2b['mean']))
            stdevs = np.where(from_dist_1, dist_1['stdev'],
                              np.where(from_dist_2a, dist_2a['stdev'], dist_2b['stdev']))
            return np.random.normal(loc=means, scale=stdevs)

        def empirical_cvar(samples, tau):
            """Return the mean of the samples at or below their tau-quantile."""
            var = np.quantile(samples, q=tau)
            return samples[samples <= var].mean()

        tau_values = np.arange(0.01, 1, 0.1)
        red_cvar = []
        blue_cvar = []
        for tau in tau_values:
            # red policy: dist_1 unless the uniform draw falls at or below epsillon
            red_samples = sample_rewards(np.random.rand(n_samples) > epsillon)
            red_cvar.append(empirical_cvar(red_samples, tau))

            # blue policy: dist_1 only when the uniform draw falls at or below epsillon
            blue_samples = sample_rewards(np.random.rand(n_samples) <= epsillon)
            blue_cvar.append(empirical_cvar(blue_samples, tau))

        # plot results
        fig = go.Figure()

        fig.add_trace(go.Scatter(
            mode='lines',
            name='CVaR of Red Policy',
            x=tau_values,
            y=np.asarray(red_cvar),
            line=dict(color='#AB1368', width=3))
        )

        fig.add_trace(go.Scatter(
            mode='lines',
            name='CVaR of Blue Policy',
            x=tau_values,
            y=np.asarray(blue_cvar),
            line=dict(color='#007FA3', width=3)),
        )

        fig.update_xaxes(title='CVaR Parameter, τ', linewidth=3, mirror=False, ticks='outside', showline=True,
                         linecolor='#262626', gridcolor='rgba(243,243,241, 0.75)', gridwidth=1)

        fig.update_yaxes(title='CVaR', linewidth=3, mirror=False, ticks='outside', showline=True,
                         linecolor='#262626', gridcolor='rgba(243,243,241, 0.75)', gridwidth=1)

        fig.update_layout(template='plotly_white', height=500, width=1000,
                          font=dict(color='#3F3F3F', size=14, family='times'),
                          paper_bgcolor='rgba(0,0,0,0)', plot_bgcolor='rgba(0,0,0,0)')

        fig.show()
        return

    def get_tau_results_figure(self, experiment, results_dict, n_runs, rolling_average_amount=1000,
                              x_max=100000, confidence_interval=0.95):
        """Plot, per tau, the rolling fraction of steps whose state is ``'blueworld'``.

        For each tau and each run (rows whose ``'run'`` equals 1..``n_runs``),
        the fraction of ``'blueworld'`` rows in the ``rolling_average_amount``
        rows preceding step ``i`` is computed. The per-step mean over runs is
        drawn with a normal-approximation confidence band. Taus are drawn from
        the largest to the smallest. The figure is shown with ``fig.show()``.

        Args:
            experiment: Not used by this method.
            results_dict: Mapping that must contain the keys 0.1, 0.25, 0.5,
                0.75, 0.85 and 0.9, each a DataFrame with ``'run'`` and
                ``'state'`` columns.
            n_runs: Number of runs to read from each DataFrame.
            rolling_average_amount: Window length in rows; must be positive.
            x_max: Upper bound of the x-axis range.
            confidence_interval: Confidence level of the shaded band.

        Returns:
            None.

        Raises:
            ValueError: If ``rolling_average_amount`` is not positive.
        """
        if rolling_average_amount <= 0:
            raise ValueError('rolling_average_amount must be a positive number of steps')

        df_dict = {
            0.1: {
                'df': results_dict[0.1],
                'color_percent': '#AB1368',
                'color_percent_ci': 'rgba(213, 137, 179, 0.5)',
            },
            0.25: {
                'df': results_dict[0.25],
                'color_percent': '#007FA3',
                'color_percent_ci': 'rgba(197, 211, 238, 0.5)',
            },
            0.5: {
                'df': results_dict[0.5],
                'color_percent': '#F1C500',
                'color_percent_ci': 'rgba(248, 226, 128, 0.5)',
            },
            0.75: {
                'df': results_dict[0.75],
                'color_percent': '#00A189',
                'color_percent_ci': 'rgba(128, 208, 196, 0.5)',
            },
            0.85: {
                'df': results_dict[0.85],
                'color_percent': '#DC4633',
                'color_percent_ci': 'rgba(237, 162, 153, 0.5)',
            },
            0.9: {
                'df': results_dict[0.90],
                'color_percent': '#8DBF2E',
                'color_percent_ci': 'rgba(198, 223, 150, 0.5)',
            },
        }

        fig = go.Figure()

        # two-sided normal critical value for the requested confidence level
        z_value = stats.norm.ppf(0.5 + confidence_interval / 2)
        window = rolling_average_amount

        for df_name in reversed(df_dict.keys()):
            quantile = df_name
            style = df_dict[df_name]
            tau_df = style['df']

            percent_by_run = {}
            counter = []  # stays bound (and empty) when n_runs is 0
            for run_i in tqdm(range(n_runs)):
                # boolean indicator per step of this run, in row order
                in_blue = (tau_df.loc[tau_df['run'] == run_i + 1, 'state'] == 'blueworld').to_numpy()
                n_steps = len(in_blue)
                n_windows = max(n_steps - window, 0)
                counter = list(range(window, n_steps))

                # blue_before[k] = number of blueworld rows among the first k
                # rows, so the count in rows [i - window, i) is
                # blue_before[i] - blue_before[i - window].
                blue_before = np.concatenate(([0], np.cumsum(in_blue)))
                window_counts = blue_before[window:window + n_windows] - blue_before[:n_windows]

                percent_by_run['percent_in_blueworld_run_' + str(run_i)] = window_counts / window

            df_runs = pd.DataFrame(percent_by_run)
            percent_mean = df_runs.mean(axis=1)
            percent_std = df_runs.std(axis=1)

            half_width = z_value * (percent_std / np.sqrt(n_runs))

            # The upper bound is added first so the lower bound can fill up to
            # it with fill='tonexty'.
            fig.add_trace(go.Scatter(x=counter, y=percent_mean + half_width, mode='lines',
                                     line=dict(color=style['color_percent_ci'], width=0.5),
                                     showlegend=False))

            fig.add_trace(go.Scatter(x=counter, y=percent_mean - half_width, mode='lines',
                                     line=dict(color=style['color_percent_ci'], width=0.5),
                                     fill='tonexty', fillcolor=style['color_percent_ci'], showlegend=False))

            fig.add_trace(go.Scatter(x=counter, y=percent_mean, mode='lines', name='τ=' + str(quantile) + '',
                                     line=dict(color=style['color_percent'], width=2)))

        fig.update_xaxes(title='Time Step', range=[rolling_average_amount, x_max], linewidth=3, mirror=False,
                         ticks='outside', showline=True, linecolor='#262626', gridcolor='rgba(243,243,241, 0.75)', gridwidth=1)

        fig.update_yaxes(title='Percent of Time in Blue World State (x100%)', linewidth=3, mirror=False, ticks='outside', showline=True,
                         linecolor='#262626', gridcolor='rgba(243,243,241, 0.75)', gridwidth=1)

        fig.update_layout(template='plotly_white', height=500, width=1000, font=dict(color='#3F3F3F', size=14, family='times'),
                          paper_bgcolor='rgba(0,0,0,0)', plot_bgcolor='rgba(0,0,0,0)')

        fig.show()
        return
