import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

# Hex colour used for each algorithm's curve and shaded band.
# Lookups elsewhere fall back to 'black' for algorithms not listed here.
COLOR_MAP = {
    # Colours the author tagged as "Original".
    'TRPO': '#1f77b4',
    'PPO': '#ff7f0e',
    'A2C': '#2ca02c',
    'DQN': '#d62728',
    'NACE': '#9467bd',
    'BT': '#8c564b',
    'IMPALA': '#e377c2',
    'COUNT': '#7f7f7f',
    'RIDE': '#bcbd22',
    'CURIOSITY': '#17becf',
    # Colours the author tagged as newly chosen.
    'RND': '#ddbb78',
    'AMIGO': '#98df8a',
    # No annotation in the original.
    'DreamerV3': '#a0aa8f',
}


# Environment name -> integer id. The id is compared against the 'Env' column
# of the All_Envs_BT.csv table when the BT baseline is plotted.
# NOTE: id 0 is a valid entry, so test membership with `in`, not truthiness.
BT_ALL_ENVS_MAP = {
    "MiniGrid-ObstructedMaze-1Dlhb-v0": 0,
    "MiniGrid-Empty-6x6-v0": 1,
    "MiniGrid-Empty-8x8-v0": 2,
    "MiniGrid-Empty-Random-5x5-v0": 3,
    "MiniGrid-Empty-Random-6x6-v0": 4,
    "BabyAI-GoToRedBallNoDists-v0": 5,
    "MiniGrid-DistShift2-v0": 6,
    "MiniGrid-LavaGapS7-v0": 7,
    "MiniGrid-FourRooms-v0": 8,
    "MiniGrid-MultiRoom-N6-v0": 9,
    "MiniGrid-SimpleCrossingS11N5-v0": 10,
    "MiniGrid-LavaCrossingS11N5-v0": 11,
    "MiniGrid-Unlock-v0": 12,
    "MiniGrid-DoorKey-8x8-v0": 13,
    "MiniGrid-UnlockPickup-v0": 14,
    "MiniGrid-Empty-16x16-v0": 15,
}

# Environments listed for NACE only.
NACE_env_list = [
    "MiniGrid-BlockedUnlockPickup-v0",
]

# Environments listed for both the RL baselines and NACE.
# NOTE: "MiniGrid-BlockedUnlockPickup-v0" also appears in NACE_env_list, so a
# plain concatenation of the two lists contains it twice.
RL_NACE_env_list = [
    "BabyAI-GoToRedBallNoDists-v0",
    "MiniGrid-DistShift2-v0",
    "MiniGrid-DoorKey-8x8-v0",
    "MiniGrid-Empty-16x16-v0",
    "MiniGrid-Empty-6x6-v0",
    "MiniGrid-Empty-8x8-v0",
    "MiniGrid-Empty-Random-5x5-v0",
    "MiniGrid-Empty-Random-6x6-v0",
    "MiniGrid-LavaCrossingS11N5-v0",
    "MiniGrid-LavaGapS7-v0",
    "MiniGrid-SimpleCrossingS11N5-v0",
    "MiniGrid-Unlock-v0",
    "MiniGrid-UnlockPickup-v0",
    "MiniGrid-BlockedUnlockPickup-v0",
]

# Environments listed for the RL baselines only.
# This used to be a set literal; it is a list now so that it has a fixed
# iteration order and can be concatenated with the sibling lists using `+`.
RL_env_list = [
    "MiniGrid-Empty-5x5-v0",
    "MiniGrid-FourRooms-v0",
    "MiniGrid-MultiRoom-N6-v0",
]

# Figures are written below base_dir; per-model CSV files are read from save_dir.
base_dir = './'
save_dir = './SavedCSV/'

# Algorithms to try for every environment.
algos = [
    'AMIGO',
    'TRPO',
    'PPO',
    'A2C',
    'DQN',
    'NACE',
    'BT',
    'IMPALA',
    'COUNT',
    'RIDE',
    'CURIOSITY',
    'RND',
    'DreamerV3',
]

# Rolling-window length (in CSV rows) used to smooth the plotted curves.
window = 10000
# Half-width of the shaded band, as a multiple of the standard deviation.
C = 0.5
# Filled in by plotter_func: "<model>__<envname>" -> last standard deviation.
standard_deviations = {}

def _load_results_frame(save_dir, envname, model):
    """Load the results table for ``model`` on ``envname``.

    Returns a DataFrame with a 'Timesteps' column plus the metric columns, or
    ``None`` when there is nothing to plot for this model/environment pair
    (CSV file missing, or no BT entry for the environment).
    """
    if model != 'BT':
        try:
            return pd.read_csv(f"{save_dir}{envname}_{model}.csv")
        except FileNotFoundError:
            # For some envs not all algos are present; a missing file is expected.
            return None

    # BT (hand solution): a single scalar row per environment, stretched into
    # a constant series so it can be drawn next to the learning curves.
    # Membership test rather than truthiness: environment id 0 is valid.
    if envname not in BT_ALL_ENVS_MAP:
        return None
    try:
        all_envs = pd.read_csv(f"{save_dir}All_Envs_{model}.csv")
    except FileNotFoundError:
        return None
    env_rows = all_envs[all_envs['Env'] == BT_ALL_ENVS_MAP[envname]]
    if env_rows.empty:
        return None

    # One row every 100 timesteps from 100 up to and including 10e6.
    frame = pd.DataFrame({'Timesteps': np.arange(100, 10e6 + 100, 100, dtype=int)})
    # The first CSV column is skipped, exactly as before; every remaining
    # column is filled with the scalar from the first matching row.
    for col in all_envs.columns[1:]:
        frame[col] = env_rows[col].iloc[0]
    return frame


def _estimate_reward_std(mean_reward, std_window):
    """Estimate a rolling standard deviation from the mean-reward series alone.

    The series is detrended with a rolling mean; the rolling variance of the
    residuals is scaled by ``std_window / (std_window - 1)`` and square-rooted.
    """
    residuals = mean_reward - mean_reward.rolling(std_window, min_periods=1).mean()
    rolling_var = residuals.rolling(std_window, min_periods=1).var(ddof=1)
    # NOTE(review): var(ddof=1) already applies Bessel's correction, so this
    # factor applies a second n/(n-1) scaling. Kept as-is so the reported
    # numbers do not change; confirm whether it is intended.
    scaling_factor = std_window / (std_window - 1)
    return (scaling_factor * rolling_var) ** 0.5


def plotter_func(base_dir, save_dir, envname, models, window, C):
    """Plot the smoothed mean episode reward of every available model.

    For each model a CSV is read from ``save_dir`` (``{envname}_{model}.csv``,
    or ``All_Envs_BT.csv`` for the 'BT' baseline). Models without data are
    skipped. If at least one model was plotted, the figure is saved as
    ``{base_dir}Figure/{envname}.png``.

    Side effect: the last standard-deviation value of each plotted model is
    stored in the module-level ``standard_deviations`` dict under the key
    ``"<model>__<envname>"``.

    Args:
        base_dir: Prefix of the output directory ("Figure/" is appended).
        save_dir: Prefix of the directory holding the input CSV files.
        envname: Environment name used in file names and dictionary keys.
        models: Iterable of algorithm names to try.
        window: Rolling-window length (rows) used for smoothing.
        C: Half-width of the shaded band in standard deviations.

    Returns:
        None.

    Raises:
        KeyError: If a loaded table lacks a required column.
    """
    # Models whose standard deviation is estimated from the mean-reward
    # residuals instead of being read from the CSV.
    estimated_std_models = ("IMPALA", "COUNT", "RIDE", "CURIOSITY", "RND", "AMIGO")
    std_window = 1000 * 100

    fig, axs = plt.subplots(figsize=(5, 3.5))
    try:
        axs.set_xscale('log')
        plotted_models = []  # For some ENVs not all algos are present

        for i, model in enumerate(models):
            df = _load_results_frame(save_dir, envname, model)
            if df is None or df.empty:
                continue
            plotted_models.append(model)

            mean_reward = df['Mean Episode Reward']
            rolling_ep_reward_MEAN = mean_reward.rolling(window, min_periods=1).mean()

            if model in estimated_std_models:
                rolling_ep_reward_STD = _estimate_reward_std(mean_reward, std_window)
                print("STD estimate", i, model, rolling_ep_reward_STD.iloc[-1])
            else:
                rolling_ep_reward_STD = (
                    df['Standard Deviation of Episode Reward']
                    .rolling(window, min_periods=1)
                    .mean()
                )
                print("STD         ", i, model, rolling_ep_reward_STD.iloc[-1])
            standard_deviations[model + "__" + envname] = rolling_ep_reward_STD.iloc[-1]

            color = COLOR_MAP.get(model, 'black')
            axs.plot(df['Timesteps'], rolling_ep_reward_MEAN, label=model, color=color)
            axs.fill_between(df['Timesteps'],
                             rolling_ep_reward_MEAN - C * rolling_ep_reward_STD,
                             rolling_ep_reward_MEAN + C * rolling_ep_reward_STD,
                             alpha=0.2,
                             color=color)

        if not plotted_models:
            return

        axs.set_xlabel('Timesteps')
        axs.set_ylabel('Mean Reward')

        # NOTE(review): tight_layout() below recomputes the subplot
        # parameters, so this adjustment probably has no visible effect; it
        # is kept so the rendered output cannot change.
        fig.subplots_adjust(right=0.33)
        # Legend sits just outside the right edge of the axes.
        axs.legend(loc='center left', bbox_to_anchor=(1.0, 0.5))
        fig.tight_layout()

        models_generated_str = "".join(name + ' ' for name in plotted_models)
        print("Plot for {0} env generated with models: {1}".format(envname, models_generated_str))

        dire = base_dir + "Figure/"
        os.makedirs(dire, exist_ok=True)
        fig.savefig(dire + f"{envname}.png", dpi=144, format=None, metadata=None, bbox_inches=None,
                    pad_inches=0.0, facecolor='auto', edgecolor='auto', backend=None)
    finally:
        # Release the figure; this function is called once per environment
        # and pyplot otherwise keeps every figure alive.
        plt.close(fig)







# "MiniGrid-BlockedUnlockPickup-v0" is in both lists; dict.fromkeys drops the
# duplicate while keeping the original order, so each env is plotted once.
for envname in dict.fromkeys(RL_NACE_env_list + NACE_env_list):
    plotter_func(base_dir, save_dir, envname, algos, window, C)

# Dump the collected standard deviations as the repr of the dict.
with open("stddevs_dict", "w", encoding="utf-8") as f:
    f.write(str(standard_deviations))
