import numpy as np
from collections import defaultdict
from matplotlib import pyplot as plt
from matplotlib import font_manager as fm, rcParams
from matplotlib.lines import Line2D
from matplotlib import rc
from matplotlib.ticker import ScalarFormatter
import os
import pandas as pd
import seaborn as sns
import argparse

import envs
import gym
from library import *
from envs.make_env import *
from envs.wrappers import *

import popgym
from popgym.wrappers import PreviousAction, Antialias, Flatten, DiscreteAction

import gymnasium_robotics
gym.register_envs(gymnasium_robotics)
from gymnasium.wrappers import FlattenObservation

import torch
from algos.grpo import GRPO
from stable_baselines3 import DQN, TD3, PPO #, GRPO

#####################################################################################

# Command-line options. ``--path`` is the output directory that the plotting
# functions below join into their figure file names (default ``./images/``).
parser = argparse.ArgumentParser()
parser.add_argument('--path', help="path", default='./images/')
args = parser.parse_args()
#####################################################################################

def plotdata(data, name, save_path, legend_path="images/legend.png", x_scale=100):
    """Draw mean curves with shaded bands, save the figure and a standalone legend.

    Args:
        data: iterable of ``(mean, std, dmax, label)`` entries (the output format
            of ``process_data``). ``mean`` and ``std`` are 1-D arrays of equal
            length; the band drawn is ``mean - std`` .. ``mean + std``. ``dmax``
            is accepted for compatibility but is not drawn.
        name: y-axis label.
        save_path: file path for the main figure.
        legend_path: file path for the legend-only figure (defaults to the
            previously hard-coded location).
        x_scale: multiplier turning a point index into its x-axis value.
    """
    # s = 20
    # rc_ = {'figure.figsize':(10,8),'axes.labelsize': 30, 'xtick.labelsize': s, 
    #        'ytick.labelsize': s, 'legend.fontsize': 20}
    # sns.set(rc=rc_, style="darkgrid")
    s = 50
    rc_ = {'figure.figsize':(10,8),'axes.labelsize': 60, 'xtick.labelsize': s, 
           'ytick.labelsize': s, 'legend.fontsize': 25}
    sns.set(rc=rc_, style="darkgrid")
    rc('text', usetex=False)

    fig, ax = plt.subplots()

    lw = 2.0
    colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
    for i, (mean, std, dmax, label) in enumerate(data):
        # Wrap around the palette so more series than colors does not raise IndexError.
        color = colors[i % len(colors)]
        x = np.arange(len(mean)) * x_scale
        ax.plot(x, mean, label=label, lw=lw, color=color)
        ax.fill_between(x, mean - std, mean + std, alpha=0.4, color=color)

        # ax.plot(x, dmax, lw = lw, ls = "--", color=color)

    ax.set_xlabel("steps ")
    ax.set_ylabel(name) # Number of observations , Returns, Success
    # plt.ylim(top=10)
    ax.yaxis.get_major_formatter().set_powerlimits((0, 1))
    ax.xaxis.get_major_formatter().set_powerlimits((0, 1))
    fig.tight_layout()
    fig.savefig(save_path, bbox_inches='tight')

    # The legend is saved as its own small figure so it can be placed once for a
    # whole row of plots.
    fig_legend = plt.figure(figsize=(6, 1))
    handles, labels = ax.get_legend_handles_labels()
    fig_legend.legend(handles, labels, loc='center', ncol=6)
    fig_legend.savefig(legend_path, bbox_inches='tight')

    # Release both figures; otherwise every call keeps them alive in pyplot.
    plt.close(fig)
    plt.close(fig_legend)

def process_data(alldata, smooth=0, z=1.96):
    """Aggregate per-run curves into mean, confidence half-width and max across runs.

    Args:
        alldata: iterable of ``(data, label)`` pairs. ``data`` holds one curve
            per row, i.e. shape ``(n_runs, n_points)``; a 1-D array is treated
            as a single run.
        smooth: moving-average window applied to every output curve. ``0`` or
            ``None`` disables smoothing; a window longer than the curve is
            clamped to the curve length.
        z: normal quantile for the band half-width (1.96 ~ 95% interval).

    Returns:
        A list of ``[mean, ci, dmax, label]`` entries, one per input with data.
        ``ci = z * std(ddof=1) / sqrt(n_runs)``, or all zeros for a single run.
        Inputs with no data are reported and skipped.
    """
    pdata = []

    for data, label in alldata:
        data = np.atleast_2d(np.asarray(data))
        if data.size == 0:
            # e.g. every run file was missing; max() would raise on an empty array.
            print("No data for", label, "- skipping")
            continue
        n = data.shape[0]  # number of runs

        dmax = data.max(axis=0)
        mean = data.mean(axis=0)
        if n > 1:
            # Standard error of the mean over the actual number of runs.
            ci = z * data.std(axis=0, ddof=1) / np.sqrt(n)
        else:
            # One run has no sample spread (ddof=1 would produce NaN).
            ci = np.zeros_like(mean)

        print("Area under the curve", label, data.sum(axis=1).mean(), data.sum(axis=1).std())
        print("Average episodic at the end of training", label, data[:,-1].mean(), data[:,-1].std())

        if smooth:
            # np.convolve(..., 'valid') with a kernel longer than the signal returns
            # a meaningless result, so clamp the window.
            m = min(int(smooth), mean.shape[0])
            kernel = np.ones(m) / m
            dmax = np.convolve(dmax, kernel, mode='valid')
            mean = np.convolve(mean, kernel, mode='valid')
            ci   = np.convolve(ci,   kernel, mode='valid')

        pdata.append([mean, ci, dmax, label])

    return pdata
            
def plots_curves():
    """Plot T-maze learning curves, one figure per maze configuration and metric.

    For every ``(maze_length, k, kappa)`` in ``mazes`` and every algorithm variant
    in ``algos``, this loads ``{data_root}/{algo}-run_{run}.npy`` for each run and
    reduces each logged metric to 1000 per-chunk sums. It then aggregates across
    runs with ``process_data`` and draws the curves with ``plotdata``. Missing run
    files are reported ("FAILED") and skipped. Figures are written under
    ``args.path``.
    """
    runs = [0] # range(10)
    maxiter = 10000  # only referenced by the commented-out truncation toggle below
    algorithm, env_name, arch = "PPO", "passive_tmaze-continual-v0", "mlp"
    data_root = "data_iclr" # "data_rebuttal_2" # "data_rlc" # "data_neurips_pponewdata"

    if "active" in env_name:
        mazes = [[0, 2, 2], [2, 4, 2], [6, 8, 2], [8, 10, 2]]
    else:
        mazes = [[0, 2, 2], [2, 4, 2], [6, 8, 2], [14, 16, 2], [62, 64, 2]]
    if "rlc" in data_root or "pponewdata" in data_root:
        mazes = [[1, 5, 3]] # [[0, 2, 2], [1, 3, 2], [2, 4, 2], [3, 5, 2], [4, 6, 2]]

    # NOTE: the assignment below overrides the selection above; it is the
    # configuration that is actually plotted.
    # mazes = [[16, 2, 2],[16, 4, 4],[16, 8, 8],[16, 16, 16],[16, 32, 32]]
    mazes = [[2, 4, 2]]
    # mazes = [[2, 10, 10]]

    names = ["rewards"] #,"passive_count","active_count","mask_regret"]
    nice_names = ["rewards"] #,"passive regret","active regret","memory regret"]
    for maze_length, k, kappa in mazes:
        # Some configurations only have a subset of seeds completed. Use a
        # per-maze variable so such an override does not leak into later mazes.
        maze_runs = runs
        if arch == "lstm" and "active" not in env_name and k==2: maze_runs = [1,2,3,4,5,6,7,8,9] # Completed runs
        if arch == "transformer" and k==64: maze_runs = [1,2,3,4,6,7] # Completed runs

        if "rlc" in data_root or "pponewdata" in data_root:
            algos = [
                    "{}-env_{}_maze_length_{}-num_stack_{}-mask_type_masked".format(algorithm, env_name, maze_length, kappa),
                    "{}-env_{}_maze_length_{}-num_stack_{}-mask_type_framestack".format(algorithm, env_name, maze_length, kappa),
                    "{}-env_{}_maze_length_{}-num_stack_{}-mask_type_framestack".format(algorithm, env_name, maze_length, k),
                    # "{}-env_{}_maze_length_{}-num_stack_{}-mask_type_demir".format(algorithm, env_name, maze_length, kappa),
                    ]
        elif "tmaze" in env_name:
            algos = [
                    "{}-arch_{}-env_{}_maze_length_{}-random_length_True-num_stack_{}-mask_type_masked".format(algorithm, arch, env_name, maze_length, kappa),
                    "{}-arch_{}-env_{}_maze_length_{}-random_length_True-num_stack_{}-mask_type_framestack".format(algorithm, arch, env_name, maze_length, kappa),
                    # "{}-arch_{}-env_{}_maze_length_{}-random_length_True-num_stack_{}-mask_type_framestack".format(algorithm, arch, env_name, maze_length, k),
                    ]
        # NOTE: the assignment below overrides the `algos` chosen above; these
        # are the variants actually plotted.
        algos = [
                "{}-arch_mlp-env_{}_maze_length_{}-random_length_False-num_stack_{}-mask_type_masked".format(algorithm, env_name, maze_length, kappa),
                # "{}-arch_mlp-env_{}_maze_length_{}-random_length_False-num_stack_{}-mask_type_framestack".format(algorithm, env_name, maze_length, kappa),
                "{}-arch_mlp-env_{}_maze_length_{}-random_length_False-num_stack_{}-mask_type_framestack".format(algorithm, env_name, maze_length, k),
                "{}-arch_mlp-env_{}_maze_length_{}-random_length_False-num_stack_{}-mask_type_demir".format(algorithm, env_name, maze_length, kappa),
                ]
        # Raw strings: "\k" is an invalid escape sequence (SyntaxWarning on 3.12+);
        # the resulting text is unchanged.
        # algos_label = [r"AdaptiveStack $\kappa$ (Ours)",r"FrameStack $\kappa$",r"FrameStack $k^*$",r"DemirStack $\kappa$" ]
        algos_label = [r"AdaptiveStack $\kappa$ (Ours)", r"FrameStack $k^*$", r"DemirStack $\kappa$"]
        # algos_label = [r"DemirStack $\kappa$" ]

        returns = {name: [] for name in names}
        for algo, algo_label in zip(algos, algos_label):
            data = {name: [] for name in names}
            for run in maze_runs:
                data_path = "{}/{}-run_{}.npy".format(data_root,algo,run)
                # if "demir" in algo: data_path = "{}/{}-intrinsic_rewards_False-run_{}.npy".format(data_root,algo,run)
                print(data_path)
                if os.path.exists(data_path):
                    log = np.load(data_path, allow_pickle=True).tolist()
                    for name in names:
                        d = np.array(log[name])#[:maxiter]
                        # Sum the log over 1000 contiguous chunks. array_split yields
                        # chunks whose lengths differ by one when len(d) is not a
                        # multiple of 1000, so sum each chunk instead of stacking
                        # them into a (ragged) array first.
                        d = np.array([chunk.sum(axis=0) for chunk in np.array_split(d, 1000)])
                        # if "demir" in algo:
                        #     chunks = np.array_split(d, 10000)
                        #     d = np.array([c.sum() for c in chunks])/9.8
                        # if "demir" in algo:
                        #     newd = []
                        #     steps = maze_length+2+("active" in env_name)
                        #     n_out = 10000
                        #     chunks = np.array_split(d, n_out)
                        #     for c in chunks:
                        #         G=0 
                        #         for i in range(len(c)):
                        #             G+=c[i]*0.99**(i*steps)
                        #             c[i] = c[i]*0.99**(i*steps)
                        #         print(len(c),c,G)
                        #         newd.append(G)
                        #     d = np.array(newd)*0
                        # if "regret" in name or "count" in name: d = np.cumsum(d)
                        data[name].append(d)
                        print(np.array(log[name]).shape, d.shape)
                else:
                    print("FAILED")
                    # break
            for name in names:
                if data[name]:
                    returns[name].append((np.array(data[name]), algo_label))
                else:
                    # No run file was found; an empty array would crash process_data.
                    print("No runs loaded for", algo_label)

        for nice_name, name in zip(nice_names,names):  
            save_path = "{}/{}-{}-env_{}_maze_length_{}-num_stack_{}-{}-intrinsic_rewards_False.png".format(args.path,algorithm, arch, env_name, maze_length, k, name)
            plotdata(process_data(returns[name], smooth=None), nice_name, save_path)

            
def plots_curves_otherenvs():
    """Plot learning curves for the non-T-maze environments (popgym, cube, FetchReach).

    For each experiment in ``exps``, this builds the environment name from
    ``base_env`` and loads ``{data_root}/{algo}-run_{run}.npy`` for every run of
    every algorithm variant. Runs are truncated to a common length, aggregated
    with ``process_data`` (moving-average window 1000) and drawn with
    ``plotdata``. Missing run files are reported ("FAILED") and skipped. Figures
    are written under ``args.path``.

    Raises:
        ValueError: if ``base_env`` matches none of the supported families.
    """
    runs = [0,1,2,3,4,5,7,8,9] # [0,1,2,3,4,5,8,9] # [0,1,2,6,7,9] # [0,6,7] #
    maxiter = 100000  # only referenced by the commented-out toggle below
    algorithm, base_env, arch = "PPO", "FetchReachDense-v4", "lstm"
    data_root = "data" # "data_rlc" "data_neurips_pponewdata"

    # exps = [["PositionOnlyCartPoleHard", 2],["VelocityOnlyCartPoleHard", 2], ["NoisyPositionOnlyCartPole", 2]]
    # exps = [["PositionOnlyCartPoleHard", 2]]
    # exps = [[2,"face",2],[5,"face",2],[5,"face",10],[2,"orthographic",2],[5,"orthographic",2],[5,"orthographic",10]]
    # exps = [[5,"face",2],[5,"face",8],[5,"face",16],[5,"face",32],[5,"face",64],[10,"face",2],[10,"face",8],[10,"face",16],[10,"face",32],[10,"face",64]] # 
    # exps = [[10,"face",2],[10,"face",8],[10,"face",16],[10,"face",32],[10,"face",64]] # 
    # exps = [[5,"face",10]] # 
    # exps = [[2,"orthographic",2]] # 
    exps = [["FetchReachDense-v4",4]]

    names = ["success"]
    nice_names = ["successes"]
    for exp in exps:
        if "popgym" in base_env:
            env_name, k = exp
            env_name = base_env + "-" + env_name
        elif "cube" in base_env:
            scramble_steps,cube_cam, k = exp
            env_name = "{}_scramble_steps_{}-random_length_False-cube_cam_{}".format(base_env, scramble_steps, cube_cam)
        elif "FetchReachDense-v4" in base_env:
            env_name, k = exp
        else:
            # Otherwise env_name/k would be unbound and fail later with a NameError.
            raise ValueError("Unsupported base_env: {}".format(base_env))

        if (algorithm, base_env, arch) == ("PPO", "FetchReachDense-v4", "lstm"):
            algos = [
                    "{}-arch_{}-env_{}-num_stack_{}-mask_type_masked".format("PPO", "lstm", env_name, k),
                    "{}-arch_{}-env_{}-num_stack_{}-mask_type_framestack".format("PPO", "mlp", env_name, k),
                    "{}-arch_{}-env_{}-num_stack_{}-mask_type_framestack".format("RecurrentPPO", "mlp", env_name, 1),
                    ]
            # algos_label = [r"AdaptiveStack $\kappa$ (Ours)",r"FrameStack $k^*$"]
        elif (algorithm, base_env, arch) == ("PPO", "FetchReachDense-v4", "transformer"):
            algos = [
                    "{}-arch_{}-env_{}-num_stack_{}-mask_type_masked".format(algorithm, arch, env_name, k),
                    "{}-arch_{}-env_{}-num_stack_{}-mask_type_framestack".format(algorithm, "mlp", env_name, k),
                    "{}-arch_{}-env_{}-num_stack_{}-mask_type_framestack".format(algorithm, arch, env_name, 50),
                    ]
            # algos_label = [r"AdaptiveStack $\kappa$ (Ours)",r"FrameStack $k^*$"]
        else:
            algos = [
                    "{}-arch_{}-env_{}-num_stack_{}-mask_type_masked".format(algorithm, arch, env_name, k),
                    "{}-arch_{}-env_{}-num_stack_{}-mask_type_framestack".format(algorithm, arch, env_name, k),
                    "{}-arch_{}-env_{}-num_stack_{}-mask_type_framestack".format(algorithm, arch, env_name, 50),
                    # "{}-arch_{}-env_{}-num_stack_{}-mask_type_demir".format(algorithm, arch, env_name, k),
                    ]
        # algos_label = ["AdaptiveStack (Ours)","FrameStack","DemirStack"]
        # Raw strings: "\k" is an invalid escape sequence (SyntaxWarning on 3.12+);
        # the resulting text is unchanged.
        # NOTE(review): in the ("PPO", "FetchReachDense-v4", "lstm") branch the second
        # and third `algos` entries are an MLP PPO frame stack and a RecurrentPPO run,
        # so these labels may not describe them. Confirm before publishing.
        algos_label = [r"AdaptiveStack $\kappa$ (Ours)", r"FrameStack $\kappa$", r"FrameStack $k^*$"] #,r"DemirStack $\kappa$" ]

        returns = {name: [] for name in names}
        for algo, algo_label in zip(algos, algos_label):
            data = {name: [] for name in names}
            for run in runs:
                data_path = "{}/{}-run_{}.npy".format(data_root,algo,run)
                print(data_path)
                if os.path.exists(data_path):
                    log = np.load(data_path, allow_pickle=True).tolist()
                    for name in names:
                        d = np.array(log[name]) # >= -0.05
                        # d[d.argmax():] = d.max()
                        # d = d.reshape(1000000//50,50).sum(axis=1)
                        # d = d.reshape(-1, 50).sum(axis=1)
                        # d = d.reshape(-1, 50)[:,-1]
                        # if name!="steps" and name!="success": # "popgym" not in base_env:
                        #     d = np.array(log[name])[:maxiter]
                        # else:
                        #     metric = np.array(log[name])
                        #     cumsum = np.array(log["steps"]).cumsum()
                        #     d = np.array([metric[np.where(cumsum<=step)[0][-1]] for step in range(0,1000000,100)])
                        # d = np.cumsum(d)

                        data[name].append(d)
                        print(np.array(log[name]).shape,d.shape)
                else:
                    print("FAILED")
                    # break
            for name in names:
                curves = data[name]
                if not curves:
                    # No run file was found; an empty array would crash process_data.
                    print("No runs loaded for", algo_label)
                    continue
                # Runs may log different numbers of entries; truncate to the shortest
                # so they stack into a rectangular (n_runs, n_points) array.
                min_len = min(len(c) for c in curves)
                if any(len(c) != min_len for c in curves):
                    print("Truncating runs of", algo_label, "to", min_len, "entries")
                returns[name].append((np.array([c[:min_len] for c in curves]), algo_label))

        for nice_name, name in zip(nice_names,names):  
            save_path = "{}/{}-{}-env_{}-num_stack_{}-{}.pdf".format(args.path,algorithm, arch, env_name, k, nice_name)
            plotdata(process_data(returns[name], smooth=1000), nice_name, save_path)

# def plots_barchart():
#     s = 50
#     rc_ = {'figure.figsize':(10,8),'axes.labelsize': 60, 'xtick.labelsize': s, 
#            'ytick.labelsize': s, 'legend.fontsize': s}
#     sns.set(rc=rc_, style="darkgrid")
#     rc('text', usetex=True)

#     gamma = 0.99
#     runs = 10
#     maxiter = 10000
#     arch = "mlp"
#     algorithm = "QL"
#     env_name = "active_tmaze-v0" # "passive_tmaze-v0" "active_tmaze-v0"
#     data_root = "data_rlc" # "data" "data_rlc" "data_neurips_pponewdata"
    
#     # names = ["returns","regrets","passive_count","active_count","mask_regret"]
#     # nice_names = ["returns","rewards regret","passive regret","active regret","memory regret"]
#     names = ["eval_successes", "eval_returns", "value_gap", "returns","passive_count","active_count","mask_regret"]
#     nice_names = ["eval successes", "eval returns", "value gap", "returns","passive regret","active regret","memory regret"]
#     if "active" in env_name:
#         mazes = [[0, 2, 2], [2, 4, 2], [6, 8, 2], [8, 10, 2]]
#     else:
#         mazes = [[0, 2, 2], [2, 4, 2], [6, 8, 2], [14, 16, 2], [62, 64, 2]]
#     if "rlc" in data_root or "pponewdata" in data_root:
#         mazes = [[0, 2, 2], [1, 3, 2], [2, 4, 2], [3, 5, 2], [4, 6, 2]]

#     data1 = {name:[0 for maze in mazes] for name in names}
#     data2 = {name:[0 for maze in mazes] for name in names}
#     data3 = {name:[0 for maze in mazes] for name in names}
#     data4 = {name:[0 for maze in mazes] for name in names}

#     for ml, (maze_length, k, kappa) in enumerate(mazes):
        
#         if "rlc" in data_root or "pponewdata" in data_root:
#             algos = [
#                     "{}-env_{}_maze_length_{}-num_stack_{}-mask_type_framestack".format(algorithm, env_name, maze_length, k),
#                     "{}-env_{}_maze_length_{}-num_stack_{}-mask_type_framestack".format(algorithm, env_name, maze_length, kappa),
#                     "{}-env_{}_maze_length_{}-num_stack_{}-mask_type_masked".format(algorithm, env_name, maze_length, kappa),]
#         elif "tmaze" in env_name:
#             algos = [
#                     "{}-arch_{}-env_{}_maze_length_{}-random_length_True-num_stack_{}-mask_type_framestack".format(algorithm, arch, env_name, maze_length, k),
#                     "{}-arch_{}-env_{}_maze_length_{}-random_length_True-num_stack_{}-mask_type_framestack".format(algorithm, arch, env_name, maze_length, kappa),
#                     "{}-arch_{}-env_{}_maze_length_{}-random_length_True-num_stack_{}-mask_type_masked".format(algorithm, arch, env_name, maze_length, kappa),]
#         algos_label = ["FrameStack $k^*$","FrameStack $\kappa$", "AdaptiveStack $\kappa$ (Ours)"]
#         for algo,algo_label in zip(algos,algos_label):
#             if "tmaze" in env_name:
#                 continual = "rlc" in data_root
#                 env = gym.make("tmaze-v0", length=maze_length, random_length=False, active="active" in env_name, continual=continual, fix_start=True, goal_obs=False, fully_obs=False)
#                 if algo_label=="FrameStack $k^*$": 
#                     env = FrameStack(env, k)
#                 if algo_label=="FrameStack $\kappa$": 
#                     env = FrameStack(env, kappa)
#                 if algo_label=="AdaptiveStack $\kappa$ (Ours)": 
#                     env = MaskedFrameStack(env, kappa)
#                 if algorithm=="QL": env = TupleObs(env)
                            
#             data = {name:[] for name in names}
#             for run in range(runs): 
#                 data_path = "{}/{}-run_{}.npy".format(data_root,algo,run)
#                 if os.path.exists(data_path):
#                     print(data_path)
#                     log = np.load(data_path, allow_pickle=True).tolist()
#                     if algorithm=="QL":
#                         model_path = "data_rlc/{}-run_{}-values.npy".format(algo,run)
#                         Q = defaultdict(lambda: np.zeros(env.action_space.n))
#                         Q.update(np.load(model_path, allow_pickle=True).tolist())
#                         Agent = lambda state: (Q[state].argmax(), Q[state].max())
#                     elif algorithm=="PPO":
#                         model_path = "{}/{}-run_{}_values".format(data_root,algo,run)
#                         model = PPO.load(model_path, env=env)
#                         def Agent(state):
#                             action = model.predict(state, deterministic=True)[0]
#                             state = torch.as_tensor(state).to(model.device).float().unsqueeze(0)
#                             with torch.no_grad(): value = model.policy.predict_values(state)
#                             return action, float(value.cpu().numpy())

#                     for name in names:
#                         if len(log["returns"])<maxiter: 
#                             print("FAILED", run, algo_label,name, len(log["returns"]))
#                         else:
#                             if "eval" in name or "value" in name:
#                                 successes = []
#                                 returns = []
#                                 values = []
#                                 for start_goal in range(len(env.goals)):
#                                     env.unwrapped.start_goal = start_goal
#                                     state, _ = env.reset()
#                                     G = 0
#                                     for t in range(100):
#                                         action, value = Agent(state)
#                                         state, reward, done, truncate, _ = env.step(action) 
#                                         G += (gamma**t)*reward
#                                         if t==1: values.append(abs(value-gamma**(maze_length+t+env.active)))
#                                         if reward!=0: successes.append((reward>0)+0.0)
#                                         if done or truncate: break
#                                     returns.append(G)
#                                 if len(successes)==0: successes.append(0)
#                                 if name=="eval successes":
#                                     data[name].append(np.array([np.mean(successes)]))
#                                 if name=="eval returns":
#                                     data[name].append(np.array([np.mean(returns)]))
#                                 if name=="value gap":
#                                     data[name].append(np.array([np.mean(values)]))
#                                 # print(data[name][-1].shape)
#                             else:
#                                 data[name].append(np.array(log[name])[:maxiter])
#                                 # print(np.array(log[name]).shape)
#                         # print(algo_label,name, np.array(log[name]).shape)
#                 else:
#                     print("FAILED", data_path)
#                     # break
            
#             for name in names: 
#                 if algo_label == "FrameStack $k^*$":
#                     # print(name, mazes[ml], np.array(data[name]).shape)
#                     data2[name][ml] = np.array(data[name]).sum(axis=1)
#                 if algo_label == "FrameStack $\kappa$":
#                     data3[name][ml] = np.array(data[name]).sum(axis=1)
#                 if algo_label == "AdaptiveStack $\kappa$ (Ours)":
#                     print(np.array(data[name]).shape)
#                     data4[name][ml] = np.array(data[name]).sum(axis=1)

#                 if name=="eval successes":
#                     if algo_label == "FrameStack $k^*$":
#                         data2[name][ml] = np.array([np.mean(data[name])])
#                     if algo_label == "FrameStack $\kappa$":
#                         data3[name][ml] = np.array([np.mean(data[name])])
#                     if algo_label == "AdaptiveStack $\kappa$ (Ours)":
#                         data4[name][ml] = np.array([np.mean(data[name])])
    
#     for nice_name, name in zip(nice_names,names):  
#         min_runs = runs
#         for i in range(len(data1[name])):
#             # if len(data1[name][i])<min_runs: min_runs=len(data1[name][i])
#             if len(data2[name][i])<min_runs: min_runs=len(data2[name][i])
#             if len(data3[name][i])<min_runs: min_runs=len(data3[name][i])
#             if len(data4[name][i])<min_runs: min_runs=len(data4[name][i])
#         for i in range(len(data1[name])):
#             # data1[name][i] = data1[name][i][:min_runs]
#             data2[name][i] = data2[name][i][:min_runs]
#             data3[name][i] = data3[name][i][:min_runs]
#             data4[name][i] = data4[name][i][:min_runs]
#         print(name, "completed runs: ", min_runs)

#         data1[name] = np.array(data1[name])
#         data2[name] = np.array(data2[name])
#         data3[name] = np.array(data3[name])
#         data4[name] = np.array(data4[name])
#         print(type(data1[name]), data1[name].shape, data2[name].shape, data3[name].shape, data4[name].shape)
#         # mean1 = data1[name].mean(axis=1)
#         # std1 = data1[name].std(axis=1)
#         mean2 = data2[name].mean(axis=1)
#         std2 = data2[name].std(axis=1)
#         mean3 = data3[name].mean(axis=1)
#         std3 = data3[name].std(axis=1)
#         mean4 = data4[name].mean(axis=1)
#         std4 = data4[name].std(axis=1)
        
#         width = 1.5/6
#         fig,ax=plt.subplots()
#         mazes_ = np.arange(2,2+len(mazes))
#         ax.bar(mazes_-0.25, mean4, width, yerr=std4, align='center', ecolor='black', capsize=2.5, label=r"AdaptiveStack $k=\kappa$ (Ours)")
#         ax.bar(mazes_, mean3, width, yerr=std3, align='center', ecolor='black', capsize=2.5, label=r"FrameStack $k=\kappa$")
#         ax.bar(mazes_+0.25, mean2, width, yerr=std2, align='center', ecolor='black', capsize=2.5, label="FrameStack $k=k^*$")
#         ax.set_xticks(mazes_)
#         ax.set_xticklabels([m for (_,m,_) in mazes])
#         # plt.legend()
#         plt.xlabel(r"maze length $(L+2)$")
#         plt.ylabel(nice_name)
#         ax.yaxis.get_major_formatter().set_powerlimits((0, 1))
#         # plt.xlim(2, 2+len(mazes))
#         # plt.show()
            
#         fig_legend = plt.figure(figsize=(4, 1))
#         handles, labels = ax.get_legend_handles_labels()
#         fig_legend.legend(handles, labels, loc='center', ncol=6)
#         fig_legend.savefig("images/pdf/legend.pdf", bbox_inches='tight')

#         if algorithm == "QL" or "rlc" in data_root or "pponewdata" in data_root:
#             save_path = "{}/pdf/{}-env_{}-{}.pdf".format(args.path, algorithm, env_name, name)
#         else:
#             save_path = "{}/pdf/{}-{}-env_{}-{}.pdf".format(args.path, algorithm, arch, env_name, name)
#         fig.savefig(save_path, bbox_inches='tight')


def plots_barchart():
    """Plot grouped bar charts of training metrics for the frame-stacking variants.

    The experiment settings (algorithm, environment, ``data_root``, number of
    runs, maze configurations) are hard-coded at the top of the function.
    For every maze configuration ``(maze_length, k, kappa)`` and every
    stacking variant, the per-run logs ``<data_root>/<algo>-run_<run>.npy``
    are loaded (DemirStack logs carry an extra ``-intrinsic_rewards_<flag>``
    tag in their file name) and each metric in ``names`` is summed over all
    logged iterations, giving one total per run. Runs whose ``returns`` log is
    shorter than ``maxiter`` are reported as FAILED and skipped. Metrics whose
    name contains "eval" or "value" are instead obtained by rolling out the
    saved policy in a freshly built environment.

    The plotted variants are then truncated to the smallest number of
    completed runs. For every metric, a bar chart (mean +/- std over runs,
    one bar group per maze length) is saved under ``args.path``, together with
    a standalone legend figure.
    """
    s = 50
    rc_ = {'figure.figsize':(10,8),'axes.labelsize': 60, 'xtick.labelsize': s,
           'ytick.labelsize': s, 'legend.fontsize': s}
    sns.set(rc=rc_, style="darkgrid")
    rc('text', usetex=True)

    intrinsic_rewards = False
    gamma = 0.99
    runs = 5
    maxiter = 10000
    arch = "mlp"
    algorithm = "QL"
    env_name = "passive_tmaze-v0" # "passive_tmaze-v0" "active_tmaze-v0"
    data_root = "data_" # "data" "data_rlc" "data_neurips_pponewdata"

    names = ["returns","regrets","passive_count","active_count","mask_regret"]
    nice_names = ["returns","rewards regret","passive regret","active regret","memory regret"]
    # names = ["eval_successes", "eval_returns", "value_gap", "returns","passive_count","active_count","mask_regret"]
    # nice_names = ["eval successes", "eval returns", "value gap", "returns","passive regret","active regret","memory regret"]
    if "active" in env_name:
        mazes = [[0, 2, 2], [2, 4, 2], [6, 8, 2], [8, 10, 2]]
    else:
        mazes = [[0, 2, 2], [2, 4, 2], [6, 8, 2], [14, 16, 2], [62, 64, 2]]
    if "rlc" in data_root or "pponewdata" in data_root or "data_" == data_root:
        mazes = [[0, 2, 2], [1, 3, 2], [2, 4, 2], [3, 5, 2], [4, 6, 2]]
    # mazes = [[1,5,3],[2,10,3]]
    # NOTE: this unconditional assignment currently overrides the choices above.
    mazes = [[0, 2, 2], [1, 3, 2], [2, 4, 2], [3, 5, 2], [4, 6, 2]]

    # Variant labels. Raw strings because "\k" is an invalid escape sequence
    # (SyntaxWarning since Python 3.12); the resulting string values are unchanged.
    fs_k_label = "FrameStack $k^*$"
    fs_kappa_label = r"FrameStack $\kappa$"
    ds_label = r"DemirStack $\kappa$"
    as_label = r"AdaptiveStack $\kappa$ (Ours)"
    # algos_label = [fs_k_label, fs_kappa_label, ds_label, as_label]
    algos_label = [fs_k_label, ds_label, as_label]

    # metric name -> one per-run array per maze configuration. The integer 0
    # placeholders are replaced once a variant has been processed; data3
    # (FrameStack kappa) keeps them while that variant is disabled.
    data1 = {name:[0 for maze in mazes] for name in names}  # DemirStack kappa
    data2 = {name:[0 for maze in mazes] for name in names}  # FrameStack k*
    data3 = {name:[0 for maze in mazes] for name in names}  # FrameStack kappa
    data4 = {name:[0 for maze in mazes] for name in names}  # AdaptiveStack kappa
    store_for_label = {ds_label: data1, fs_k_label: data2,
                       fs_kappa_label: data3, as_label: data4}

    # Saved policies are only needed for rollout-based ("eval"/"value") metrics.
    needs_rollout = any("eval" in name or "value" in name for name in names)

    for ml, (maze_length, k, kappa) in enumerate(mazes):
        # Only the flat "<algo>-env_<env>_maze_length_..." log naming scheme is
        # supported here; the entries must line up with algos_label.
        algos = [
                "{}-env_{}_maze_length_{}-num_stack_{}-mask_type_framestack".format(algorithm, env_name, maze_length, k),
                # "{}-env_{}_maze_length_{}-num_stack_{}-mask_type_framestack".format(algorithm, env_name, maze_length, kappa),
                "{}-env_{}_maze_length_{}-num_stack_{}-mask_type_demir".format(algorithm, env_name, maze_length, kappa),
                "{}-env_{}_maze_length_{}-num_stack_{}-mask_type_masked".format(algorithm, env_name, maze_length, kappa),
                ]
        for algo, algo_label in zip(algos, algos_label):
            if "tmaze" in env_name:
                continual = "rlc" in data_root
                env = gym.make("tmaze-v0", length=maze_length, random_length=False, active="active" in env_name, continual=continual, fix_start=True, goal_obs=False, fully_obs=False)
                # Wrap the environment with the stacking wrapper matching the variant.
                if algo_label == fs_k_label:
                    env = FrameStack(env, k)
                elif algo_label == fs_kappa_label:
                    env = FrameStack(env, kappa)
                elif algo_label == ds_label:
                    env = DemirFrameStack(env, kappa)
                elif algo_label == as_label:
                    env = MaskedFrameStack(env, kappa)
                if algorithm == "QL":
                    env = TupleObs(env)

            data = {name:[] for name in names}
            for run in range(runs):
                # DemirStack files carry the intrinsic-reward flag in their names.
                if "demir" in algo:
                    run_prefix = "{}/{}-intrinsic_rewards_{}-run_{}".format(data_root, algo, intrinsic_rewards, run)
                else:
                    run_prefix = "{}/{}-run_{}".format(data_root, algo, run)
                data_path = run_prefix + ".npy"
                if not os.path.exists(data_path):
                    print("FAILED", data_path)
                    continue
                print(data_path)
                log = np.load(data_path, allow_pickle=True).tolist()

                if needs_rollout:
                    if algorithm == "QL":
                        Q = defaultdict(lambda: np.zeros(env.action_space.n))
                        Q.update(np.load(run_prefix + "-values.npy", allow_pickle=True).tolist())
                        Agent = lambda state: (Q[state].argmax(), Q[state].max())
                    elif algorithm in ("PPO", "GRPO"):
                        model_cls = PPO if algorithm == "PPO" else GRPO
                        model = model_cls.load("{}/{}-run_{}_values".format(data_root, algo, run), env=env)
                        def Agent(state):
                            action = model.predict(state, deterministic=True)[0]
                            state = torch.as_tensor(state).to(model.device).float().unsqueeze(0)
                            with torch.no_grad():
                                value = model.policy.predict_values(state)
                            # .item(): float() of a (1, 1) array is deprecated since NumPy 1.25.
                            return action, float(value.item())
                    else:
                        raise ValueError("no rollout agent for algorithm {!r}".format(algorithm))

                for name in names:
                    if len(log["returns"]) < maxiter:
                        print("FAILED", run, algo_label, name, len(log["returns"]))
                        continue
                    # Accept both spellings, e.g. "eval successes" and "eval_successes".
                    metric = name.replace("_", " ")
                    if "eval" in metric or "value" in metric:
                        successes = []
                        returns = []
                        values = []
                        for start_goal in range(len(env.goals)):
                            env.unwrapped.start_goal = start_goal
                            state, _ = env.reset()
                            G = 0
                            for t in range(100):
                                action, value = Agent(state)
                                state, reward, done, truncate, _ = env.step(action)
                                G += (gamma**t)*reward
                                # |value estimate at t == 1 - gamma**(maze_length + t + active)|
                                if t == 1: values.append(abs(value - gamma**(maze_length + t + env.active)))
                                if reward != 0: successes.append((reward > 0) + 0.0)
                                if done or truncate: break
                            returns.append(G)
                        if len(successes) == 0: successes.append(0)
                        if metric == "eval successes":
                            data[name].append(np.array([np.mean(successes)]))
                        if metric == "eval returns":
                            data[name].append(np.array([np.mean(returns)]))
                        if metric == "value gap":
                            data[name].append(np.array([np.mean(values)]))
                    else:
                        # The whole log is kept (not truncated to maxiter).
                        data[name].append(np.array(log[name]))

            store = store_for_label[algo_label]
            for name in names:
                if algo_label == as_label:
                    print(np.array(data[name]).shape)
                if name.replace("_", " ") == "eval successes":
                    # Success rate pooled over all runs into a single value.
                    store[name][ml] = np.array([np.mean(data[name])])
                else:
                    # One total per run: the metric summed over its logged iterations.
                    store[name][ml] = np.array(data[name]).sum(axis=1)

    os.makedirs(args.path, exist_ok=True)
    for nice_name, name in zip(nice_names, names):
        # Compare every plotted variant on the same number of runs; data3
        # (FrameStack kappa) is skipped because that variant is disabled.
        min_runs = runs
        for i in range(len(mazes)):
            min_runs = min(min_runs, len(data1[name][i]), len(data2[name][i]), len(data4[name][i]))
        for i in range(len(mazes)):
            data1[name][i] = data1[name][i][:min_runs]
            data2[name][i] = data2[name][i][:min_runs]
            data4[name][i] = data4[name][i][:min_runs]
        print(name, "completed runs: ", min_runs)

        data1[name] = np.array(data1[name])
        data2[name] = np.array(data2[name])
        data3[name] = np.array(data3[name])
        data4[name] = np.array(data4[name])
        print(type(data1[name]), data1[name].shape, data2[name].shape, data3[name].shape, data4[name].shape)
        mean1 = data1[name].mean(axis=1)
        std1 = data1[name].std(axis=1)
        mean2 = data2[name].mean(axis=1)
        std2 = data2[name].std(axis=1)
        mean4 = data4[name].mean(axis=1)
        std4 = data4[name].std(axis=1)

        print(env_name, intrinsic_rewards)
        print("performance metric", name)
        print("memory length: ", mazes)
        print("AS kappa (mean): ", mean4)
        print("AS kappa (std): ", std4)
        print("DS kappa (mean): ", mean1)
        print("DS kappa (std): ", std1)
        print("FS k (mean): ", mean2)
        print("FS k (std): ", std2)
        print("_____________________________________________")

        width = 1.5/6
        fig, ax = plt.subplots()
        mazes_ = np.arange(2, 2+len(mazes))
        ax.bar(mazes_-0.25, mean4, width, yerr=std4, align='center', ecolor='black', capsize=2.5, label=r"AdaptiveStack $k=\kappa$ (Ours)")
        ax.bar(mazes_, mean1, width, yerr=std1, align='center', ecolor='black', capsize=2.5, label=r"DemirStack $k=\kappa$")
        ax.bar(mazes_+0.25, mean2, width, yerr=std2, align='center', ecolor='black', capsize=2.5, label="FrameStack $k=k^*$")

        ax.set_xticks(mazes_)
        ax.set_xticklabels([m for (_,m,_) in mazes])
        ax.set_xlabel(r"maze length $(L+2)$")
        ax.set_ylabel(nice_name)
        ax.yaxis.get_major_formatter().set_powerlimits((0, 1))

        # Standalone legend figure (identical for every metric).
        fig_legend = plt.figure(figsize=(4, 1))
        handles, labels = ax.get_legend_handles_labels()
        fig_legend.legend(handles, labels, loc='center', ncol=6)
        fig_legend.savefig(os.path.join(args.path, "legend-demir.pdf"), bbox_inches='tight')
        plt.close(fig_legend)

        if algorithm == "QL" or "rlc" in data_root or "pponewdata" in data_root:
            # save_path = "{}/pdf/{}-env_{}-{}.pdf".format(args.path, algorithm, env_name, name)
            save_path = "{}/{}-env_{}-{}-intrinsic_rewards_{}.pdf".format(args.path, algorithm, env_name, name, intrinsic_rewards)
        else:
            save_path = "{}/{}-{}-env_{}-{}.pdf".format(args.path, algorithm, arch, env_name, name)
        fig.savefig(save_path, bbox_inches='tight')
        # Release the figure; otherwise every metric leaves open figures behind.
        plt.close(fig)


def plots_memory_barchart():
    """Plot bar charts comparing FrameStack and AdaptiveStack over memory lengths.

    The experiment settings are hard-coded at the top of the function. Each
    maze configuration ``(maze_length, k, kappa)`` shares one maze length and
    varies the stack size (``k`` for FrameStack, ``kappa`` for AdaptiveStack).

    Per-run logs ``<data_root>/<algo>-run_<run>.npy`` are loaded, and each metric
    in ``names`` is truncated to ``maxiter`` iterations and summed, giving one
    total per run. Runs shorter than ``maxiter`` are reported as FAILED and
    skipped. Metrics containing "eval"/"value" are computed from policy rollouts
    instead.

    Both variants are truncated to the smallest number of completed runs. For
    every metric, a bar chart (mean +/- std over runs, x axis = memory length
    ``k``) and a standalone legend are saved under ``<args.path>/pdf``.
    """
    s = 50
    rc_ = {'figure.figsize':(10,8),'axes.labelsize': 60, 'xtick.labelsize': s,
           'ytick.labelsize': s, 'legend.fontsize': s}
    sns.set(rc=rc_, style="darkgrid")
    rc('text', usetex=False)

    gamma = 0.99
    runs = 10
    maxiter = 10000
    arch = "mlp"
    algorithm = "GRPO"
    env_name = "passive_tmaze-v0" # "passive_tmaze-v0" "active_tmaze-v0"
    data_root = "data_grpo" # "data" "data_rlc" "data_neurips_pponewdata"

    names = ["returns"] #,"returns","passive_count","active_count","mask_regret"]
    nice_names = ["returns"] #,"returns","passive regret","active regret","memory regret"]
    # names = ["eval_successes", "eval_returns", "value_gap", "returns","passive_count","active_count","mask_regret"]
    # nice_names = ["eval successes", "eval returns", "value gap", "returns","passive regret","active regret","memory regret"]
    # (maze_length, k, kappa) triples; the active and passive T-mazes used the same list.
    mazes = [[16, 2, 2], [16, 4, 4], [16, 8, 8], [16, 16, 16], [16, 32, 32]]
    if "xormaze" in env_name:
        mazes = [[1, 3, 3],[1,5,5],[2, 3, 3]]
    # NOTE: this unconditional assignment currently overrides the choices above.
    mazes = [[14, 2, 2], [14, 4, 4], [14, 8, 8], [14, 16, 16], [14, 32, 32]]

    fs_label = "FrameStack"
    as_label = "AdaptiveStack (Ours)"
    algos_label = [fs_label, as_label]
    # metric name -> one per-run array per maze configuration (0 = not filled yet).
    data2 = {name:[0 for maze in mazes] for name in names}  # FrameStack
    data4 = {name:[0 for maze in mazes] for name in names}  # AdaptiveStack
    store_for_label = {fs_label: data2, as_label: data4}

    # Saved policies are only needed for rollout-based ("eval"/"value") metrics.
    needs_rollout = any("eval" in name or "value" in name for name in names)

    for ml, (maze_length, k, kappa) in enumerate(mazes):

        if "rlc" in data_root or "pponewdata" in data_root:
            algos = [
                    "{}-env_{}_maze_length_{}-num_stack_{}-mask_type_framestack".format(algorithm, env_name, maze_length, k),
                    "{}-env_{}_maze_length_{}-num_stack_{}-mask_type_masked".format(algorithm, env_name, maze_length, kappa),]
        elif "tmaze" in env_name:
            algos = [
                    "{}-arch_{}-env_{}_maze_length_{}-random_length_True-num_stack_{}-mask_type_framestack".format(algorithm, arch, env_name, maze_length, k),
                    "{}-arch_{}-env_{}_maze_length_{}-random_length_True-num_stack_{}-mask_type_masked".format(algorithm, arch, env_name, maze_length, kappa),]
        elif "xormaze" in env_name:
            algos = [
                    "{}-env_{}_maze_length_{}-num_stack_{}-mask_type_framestack".format(algorithm, env_name, maze_length, k),
                    "{}-env_{}_maze_length_{}-num_stack_{}-mask_type_masked".format(algorithm, env_name, maze_length, kappa),]
        else:
            raise ValueError("unsupported env_name {!r}".format(env_name))

        for algo, algo_label in zip(algos, algos_label):
            if "tmaze" in env_name or "xormaze" in env_name:
                continual = "rlc" in data_root
                env_id = "xormaze-v0" if "xormaze" in env_name else "tmaze-v0"
                env = gym.make(env_id, length=maze_length, random_length=False, active="active" in env_name, continual=continual, fix_start=True, goal_obs=False, fully_obs=False)
                # Wrap the environment with the stacking wrapper matching the variant.
                if algo_label == fs_label:
                    env = FrameStack(env, k)
                elif algo_label == as_label:
                    env = MaskedFrameStack(env, kappa,use_multidiscrete=True)
                if algorithm == "QL":
                    env = TupleObs(env)

            data = {name:[] for name in names}
            for run in range(runs):
                run_prefix = "{}/{}-run_{}".format(data_root, algo, run)
                data_path = run_prefix + ".npy"
                if not os.path.exists(data_path):
                    print("FAILED", data_path)
                    continue
                print(data_path)
                log = np.load(data_path, allow_pickle=True).tolist()

                if needs_rollout:
                    if algorithm == "QL":
                        Q = defaultdict(lambda: np.zeros(env.action_space.n))
                        Q.update(np.load(run_prefix + "-values.npy", allow_pickle=True).tolist())
                        Agent = lambda state: (Q[state].argmax(), Q[state].max())
                    elif algorithm in ("PPO", "GRPO"):
                        model_cls = PPO if algorithm == "PPO" else GRPO
                        model = model_cls.load(run_prefix + "_values", env=env)
                        def Agent(state):
                            action = model.predict(state, deterministic=True)[0]
                            state = torch.as_tensor(state).to(model.device).float().unsqueeze(0)
                            with torch.no_grad():
                                value = model.policy.predict_values(state)
                            # .item(): float() of a (1, 1) array is deprecated since NumPy 1.25.
                            return action, float(value.item())
                    else:
                        raise ValueError("no rollout agent for algorithm {!r}".format(algorithm))

                for name in names:
                    if len(log["returns"]) < maxiter:
                        print("FAILED", run, algo_label, name, len(log["returns"]))
                        continue
                    # Accept both spellings, e.g. "eval successes" and "eval_successes".
                    metric = name.replace("_", " ")
                    if "eval" in metric or "value" in metric:
                        successes = []
                        returns = []
                        values = []
                        for start_goal in range(len(env.goals)):
                            env.unwrapped.start_goal = start_goal
                            state, _ = env.reset()
                            G = 0
                            for t in range(100):
                                action, value = Agent(state)
                                state, reward, done, truncate, _ = env.step(action)
                                G += (gamma**t)*reward
                                # |value estimate at t == 1 - gamma**(maze_length + t + active)|
                                if t == 1: values.append(abs(value - gamma**(maze_length + t + env.active)))
                                if reward != 0: successes.append((reward > 0) + 0.0)
                                if done or truncate: break
                            returns.append(G)
                        if len(successes) == 0: successes.append(0)
                        if metric == "eval successes":
                            data[name].append(np.array([np.mean(successes)]))
                        if metric == "eval returns":
                            data[name].append(np.array([np.mean(returns)]))
                        if metric == "value gap":
                            data[name].append(np.array([np.mean(values)]))
                    else:
                        data[name].append(np.array(log[name])[:maxiter])

            store = store_for_label[algo_label]
            for name in names:
                if algo_label == as_label:
                    print(np.array(data[name]).shape)
                if name.replace("_", " ") == "eval successes":
                    # Success rate pooled over all runs into a single value.
                    store[name][ml] = np.array([np.mean(data[name])])
                else:
                    # One total per run: the metric summed over its logged iterations.
                    store[name][ml] = np.array(data[name]).sum(axis=1)

    out_dir = "{}/pdf".format(args.path)
    os.makedirs(out_dir, exist_ok=True)
    for nice_name, name in zip(nice_names, names):
        # Compare both variants on the same number of runs.
        min_runs = runs
        for i in range(len(mazes)):
            min_runs = min(min_runs, len(data2[name][i]), len(data4[name][i]))
        for i in range(len(mazes)):
            data2[name][i] = data2[name][i][:min_runs]
            data4[name][i] = data4[name][i][:min_runs]

        data2[name] = np.array(data2[name])
        data4[name] = np.array(data4[name])
        mean2 = data2[name].mean(axis=1)
        std2 = data2[name].std(axis=1)
        mean4 = data4[name].mean(axis=1)
        std4 = data4[name].std(axis=1)

        print("performance metric", name)
        print("memory length: ", mazes)
        print("AS (mean): ", mean4)
        print("AS (std): ", std4)
        print("FS (mean): ", mean2)
        print("FS (std): ", std2)
        print("_____________________________________________")

        width = 1.5/6
        fig, ax = plt.subplots()
        mazes_ = np.arange(2, 2+len(mazes))
        ax.bar(mazes_-0.15, mean4, width, yerr=std4, align='center', ecolor='black', capsize=2.5, label=r"AdaptiveStack (Ours)")
        ax.bar(mazes_+0.15, mean2, width, yerr=std2, align='center', ecolor='black', capsize=2.5, label=r"FrameStack")
        ax.set_xticks(mazes_)
        ax.set_xticklabels([m for (_,m,_) in mazes])
        ax.set_xlabel(r"memory length")
        ax.set_ylabel(nice_name)
        ax.yaxis.get_major_formatter().set_powerlimits((0, 1))

        # Standalone legend figure (identical for every metric).
        fig_legend = plt.figure(figsize=(4, 1))
        handles, labels = ax.get_legend_handles_labels()
        fig_legend.legend(handles, labels, loc='center', ncol=6)
        fig_legend.savefig("{}/legend.pdf".format(out_dir), bbox_inches='tight')
        plt.close(fig_legend)

        if algorithm == "QL" or "rlc" in data_root or "pponewdata" in data_root:
            save_path = "{}/pdf/memory-{}-env_{}-{}.pdf".format(args.path, algorithm, env_name, name)
        else:
            save_path = "{}/pdf/memory-{}-{}-env_{}-{}.pdf".format(args.path, algorithm, arch, env_name, name)
        fig.savefig(save_path, bbox_inches='tight')
        # Release the figure; otherwise every metric leaves open figures behind.
        plt.close(fig)


# ------------------------------------------------------------
# Helper: build agent callable exactly as in your training code
# ------------------------------------------------------------
def load_agent(algo_label, algorithm, model_path, env, device="cpu", deterministic=True):
    """Load a saved policy and wrap it as ``agent(state) -> (action, value)``.

    Args:
        algo_label: Variant label; unused here, kept for call compatibility.
        algorithm: "QL" (tabular Q-values stored as a pickled dict via
            ``np.save``), "PPO" (stable-baselines3) or "GRPO" (``algos.grpo``).
        model_path: File loaded with ``np.load`` for QL, or path passed to
            ``PPO.load`` / ``GRPO.load``.
        env: Environment the agent acts in. QL uses ``env.action_space.n``
            for unseen states; PPO/GRPO attach it to the loaded model.
        device: Torch device used by PPO/GRPO models.
        deterministic: Whether PPO/GRPO pick actions deterministically.

    Returns:
        Callable mapping a state to ``(action, value)``. For QL this is the
        greedy action and its Q-value (unseen states get all-zero Q-values);
        for PPO/GRPO it is the predicted action and the critic's value
        estimate as a Python float.

    Raises:
        ValueError: If ``algorithm`` is not one of the supported names.
    """
    if algorithm == "QL":
        q_table = defaultdict(lambda: np.zeros(env.action_space.n))
        q_table.update(np.load(model_path, allow_pickle=True).item())
        return lambda s: (q_table[s].argmax(), q_table[s].max())

    if algorithm in ("PPO", "GRPO"):
        # Both share the stable-baselines3 load/predict/predict_values API.
        model_cls = PPO if algorithm == "PPO" else GRPO
        model = model_cls.load(model_path, env=env, device=device)

        def agent(s):
            a = model.predict(s, deterministic=deterministic)[0]
            s_t = torch.as_tensor(s).to(model.device).float().unsqueeze(0)
            with torch.no_grad():
                v = model.policy.predict_values(s_t)
            # .item() instead of float(ndarray): converting an array with
            # ndim > 0 to a Python scalar is deprecated since NumPy 1.25.
            return a, float(v.item())
        return agent

    raise ValueError(f"Unknown algorithm {algorithm!r}")


# ------------------------------------------------------------
# Main: cross-maze generalisation plot
# ------------------------------------------------------------
def plots_generalisation(
        env_name           = "passive_tmaze-v0",
        data_root          = "data",        # "data" "data_rlc" "data_neurips_pponewdata"
        algorithm          = "PPO",          # "QL", "PPO" or "GRPO"
        arch               = "lstm",
        gamma              = 0.99,
        runs               = 10,
        max_steps          = 100,
        out_dir            = "images/pdf",
        figsize_per_cell   = 1.8,           # heat-map scaling
        device             = "cpu"
):
    """Cross-evaluate all trained policies on all maze lengths and plot heat-maps.

    For each training configuration ``(L_train, k, kappa)`` in ``mazes`` and
    each stacking variant, the saved model of every run is rolled out on a
    T-maze of every length in ``mazes``. The test maze is wrapped with the
    variant's stacking wrapper, using the training configuration's stack sizes.

    For every (variant, test length, train length) cell, three metrics are
    recorded as mean/std:

    * the success rate;
    * the discounted return averaged over start goals;
    * an optimality score ``1 - |V_opt - G| / (2 V_opt)``.

    One optimality heat-map per variant is written to ``out_dir``.

    Args:
        env_name: Environment name; used in the model file names and, via
            the substring "active", to build active T-mazes.
        data_root: Directory with the saved models. "rlc" in it selects
            continual mazes (and ``max_steps = 1000``); "rlc"/"pponewdata"
            select the flat file naming scheme.
        algorithm: Algorithm name passed to ``load_agent``.
        arch: Architecture tag used by the NeurIPS-style file names.
        gamma: Discount factor for returns and ``V_opt``.
        runs: Number of run indices to look for; missing runs are skipped.
        max_steps: Step cap per episode (forced to 1000 for continual mazes).
        out_dir: Output directory for the PDFs (created if needed).
        figsize_per_cell: Figure inches per heat-map cell.
        device: Torch device forwarded to ``load_agent``.
    """
    # ---------------- Plot styling ----------------
    rc_ = {'figure.figsize':(10,8),'axes.labelsize':50,'axes.titlesize':40,
           'xtick.labelsize':40,'ytick.labelsize':40,'legend.fontsize':40}
    sns.set(rc=rc_, style="darkgrid")
    rc('text', usetex=True)

    continual = "rlc" in data_root
    if continual: max_steps = 1000
    # rlc / pponewdata models use the flat "<algo>-env_<env>_maze_length_..." names.
    flat_naming = "rlc" in data_root or "pponewdata" in data_root

    # (L, k, kappa): maze length and the two stack sizes used in the file names.
    if "active" in env_name:
        mazes = [[0, 2, 2], [2, 4, 2], [6, 8, 2], [8, 10, 2]]
    else:
        mazes = [[0, 2, 2], [2, 4, 2], [6, 8, 2], [14, 16, 2], [62, 64, 2]]
    if flat_naming:
        mazes = [[0, 2, 2], [1, 3, 2], [2, 4, 2], [3, 5, 2], [4, 6, 2]]

    if flat_naming:
        algo_templates = [
            "{algo}-env_{env}_maze_length_{L}-num_stack_{k}-mask_type_framestack",
            "{algo}-env_{env}_maze_length_{L}-num_stack_{kappa}-mask_type_framestack",
            "{algo}-env_{env}_maze_length_{L}-num_stack_{kappa}-mask_type_masked"
        ]
    else:  # NeurIPS-style naming
        algo_templates = [
            "{algo}-arch_{arch}-env_{env}_maze_length_{L}-random_length_True-num_stack_{k}-mask_type_framestack",
            "{algo}-arch_{arch}-env_{env}_maze_length_{L}-random_length_True-num_stack_{kappa}-mask_type_framestack",
            "{algo}-arch_{arch}-env_{env}_maze_length_{L}-random_length_True-num_stack_{kappa}-mask_type_masked"
        ]
    algo_labels = ["FrameStack $k^*$", "FrameStack $\\kappa$", "AdaptiveStack $\\kappa$ (Ours)"]

    n_algos = len(algo_templates)
    M = len(mazes)
    # Every array is indexed [algo_id, test_idx, train_idx]: heat-map rows are
    # test maze lengths and columns are train maze lengths.
    metrics = {key: np.zeros((n_algos, M, M))
               for key in ("success_rate_mean", "success_rate_std",
                           "returns_mean", "returns_std",
                           "optimality_mean", "optimality_std")}

    # ---------------- Loop over (train_maze, test_maze, run) -------------
    for train_idx, (L_train, k, kappa) in enumerate(mazes):
        for algo_id, (tmpl, algo_label) in enumerate(zip(algo_templates, algo_labels)):
            algo_name = tmpl.format(algo=algorithm, arch=arch, env=env_name,
                                    L=L_train, k=k, kappa=kappa)

            for test_idx, (L_test, _, _) in enumerate(mazes):
                succ_all, ret_all, optimality_all = [], [], []

                # Test environment, wrapped like this variant (stack sizes are
                # taken from the *training* configuration).
                test_env = gym.make("tmaze-v0",
                                    length=L_test,
                                    random_length=False,
                                    active="active" in env_name,
                                    continual=continual,
                                    fix_start=True, goal_obs=False, fully_obs=False)
                if algo_label == "FrameStack $k^*$":
                    test_env = FrameStack(test_env, k)
                elif algo_label == "FrameStack $\\kappa$":
                    test_env = FrameStack(test_env, kappa)
                elif algo_label.startswith("AdaptiveStack"):
                    test_env = MaskedFrameStack(test_env, kappa)
                if algorithm == "QL":
                    test_env = TupleObs(test_env)

                # Reference value used to normalise the optimality score
                # (presumably the optimal discounted return - confirm get_Vopt).
                V_opt = gamma**(L_test+2-1+test_env.active)
                if continual:
                    V_opt = get_Vopt(test_env, gamma)
                    V_opt = np.mean([V_opt[s] for s in test_env.start_states])

                # run-wise evaluation
                for run in range(runs):
                    model_path = (f"{data_root}/{algo_name}-run_{run}"
                                  + ("-values.npy" if algorithm=="QL" else "_values"))
                    if not os.path.exists(model_path) and not os.path.exists(model_path+".zip"):
                        print("FAILED ", model_path)
                        continue  # missing run: reported above and skipped
                    agent = load_agent(algo_label, algorithm, model_path, test_env, device)

                    G = 0.0
                    successes = []
                    n_goals = len(test_env.goals)
                    for sg in range(n_goals):
                        test_env.unwrapped.start_goal = sg
                        state, _ = test_env.reset()
                        for t in range(max_steps):
                            action, _ = agent(state)
                            state, r, done, trunc, _ = test_env.step(action)
                            G += (gamma ** t) * r
                            if r != 0:
                                successes.append(float(r > 0))
                                # Advance start_goal to the next goal after every non-zero reward.
                                test_env.unwrapped.start_goal = (test_env.unwrapped.start_goal+1)%n_goals
                            if done or trunc: break
                    if not successes: successes.append(0)
                    ret_all.append(G/n_goals)
                    succ_all += successes
                    optimality_all.append(1-abs(V_opt - G/n_goals)/(2*V_opt))
                    # optimality_all.append(1-abs(V_opt - G/n_goals)/(V_opt+1))

                cell = (algo_id, test_idx, train_idx)
                if not ret_all:
                    # No run found for this cell: store NaN explicitly (the value
                    # np.mean([]) would give, without its RuntimeWarning).
                    print("no completed runs:", algo_name, "on maze length", L_test)
                    for key in metrics:
                        metrics[key][cell] = np.nan
                    continue

                # average over runs & goals
                metrics["success_rate_mean"][cell] = np.mean(succ_all)
                metrics["success_rate_std"][cell] = np.std(succ_all)
                metrics["returns_mean"][cell] = np.mean(ret_all)
                metrics["returns_std"][cell] = np.std(ret_all)
                metrics["optimality_mean"][cell] = np.mean(optimality_all)
                metrics["optimality_std"][cell] = np.std(optimality_all)

    # ---------------- Plot & save  ---------------------------------------
    os.makedirs(out_dir, exist_ok=True)
    tick_labels = [k for (_, k, _) in mazes]
    # File names: spaces become dashes, LaTeX markup characters are dropped.
    fname_chars = str.maketrans({" ": "-", "$": None, "\\": None, "*": None, "^": None})

    for algo_id, algo_label in enumerate(algo_labels):
        for key, title in [
                        #    ("success_rate", "success rate"),
                        #    ("returns" , "returns"),
                           ("optimality" , "optimality")
                           ]:
            means = metrics[key+"_mean"][algo_id]
            stds = metrics[key+"_std"][algo_id]
            # "mean ± std" text for every heat-map cell.
            annotations = np.empty_like(means, dtype=object)
            for i in range(means.shape[0]):
                for j in range(means.shape[1]):
                    annotations[i, j] = f"{means[i, j]:.2f} \n ± {stds[i, j]:.2f}"

            fig, ax = plt.subplots(figsize=(figsize_per_cell*M, figsize_per_cell*M))
            sns.heatmap(means,
                        annot=annotations, annot_kws={"size": 20}, fmt= "",
                        vmin=0,
                        vmax=1,
                        xticklabels=tick_labels, yticklabels=tick_labels,
                        cbar_kws=dict(label=title), ax=ax)

            ax.set_xlabel("train maze length $(L+2)$")
            ax.set_ylabel("test maze length $(L+2)$")
            ax.set_title(f"{algo_label}", pad=20)

            if algorithm == "QL" or flat_naming:
                fname = f"{algorithm}-{algo_label}-{env_name}-{key}.pdf"
            else:
                fname = f"{algorithm}-{algo_label}-{arch}-{env_name}-{key}.pdf"
            fname = fname.translate(fname_chars)
            plt.tight_layout()
            fig.savefig(os.path.join(out_dir, fname), bbox_inches="tight")
            plt.close(fig)

    print(f"Saved heat-maps to {out_dir}")


# ------------------------------------------------------------
# Main: cross-memory generalisation plot
# ------------------------------------------------------------
def plots_memory_generalisation(
        env_name           = "passive_tmaze-v0",
        data_root          = "data_grpo",
        algorithm          = "GRPO",          # or "PPO"
        arch               = "mlp",
        gamma              = 0.99,
        runs               = 10,
        max_steps          = 100,
        out_dir            = "images/pdf",
        figsize_per_cell   = 1.8,           # heat-map scaling
        device             = "cpu"
):
    """Cross-evaluate policies trained with different memory sizes on T-mazes
    of several lengths and save one heat-map per (wrapper, metric).

    For every training configuration ``[L_train, k, kappa]`` in ``mazes`` and
    every wrapper (FrameStack with ``k`` frames, AdaptiveStack with ``kappa``
    slots), each saved run is loaded and rolled out on a fixed-length
    ``tmaze-v0`` of every length in ``mazes_test``, once per start goal.
    Runs whose model file is missing are reported with ``FAILED`` and skipped.

    Plotted metrics (heat-map cell = mean ± std; rows = test length,
    columns = training memory size):

    * ``optimality`` -- per run, ``1 - |V_opt - G| / (2 V_opt)``, where ``G``
      is the discounted return averaged over start goals and
      ``V_opt = gamma ** (L_test + 1 + active)``.
    * ``states`` -- per run and start goal, the number of distinct
      observations seen, normalised against ``min(L_test + 2, k + 2)``.

    ``success_rate`` and ``returns`` are computed as well but not plotted.

    Args:
        env_name: environment tag used in run file names; the substring
            ``"active"`` selects the active T-maze variant.
        data_root: directory of saved models. ``"rlc"`` in the name enables
            the continual env; ``"rlc"`` or ``"pponewdata"`` select the short
            file-name template.
        algorithm: algorithm tag in file names; ``"QL"`` uses ``-values.npy``
            files and an extra ``TupleObs`` wrapper.
        arch: architecture tag in file names (NeurIPS-style naming only).
        gamma: discount factor for returns and ``V_opt``.
        runs: number of seeds ``run_0 .. run_{runs-1}`` to evaluate.
        max_steps: step cap per episode.
        out_dir: output directory for the PDF heat-maps (created if missing).
        figsize_per_cell: inches per heat-map cell.
        device: device forwarded to ``load_agent``.
    """

    rc_ = {'figure.figsize':(10,8),'axes.labelsize':50,'axes.titlesize':40,
           'xtick.labelsize':40,'ytick.labelsize':40,'legend.fontsize':40}
    sns.set(rc=rc_, style="darkgrid")
    rc('text', usetex=True)

    # Training configurations [L_train, k, kappa]: maze length used in the run
    # file name, FrameStack size and AdaptiveStack size.
    mazes = [[14, 2, 2], [14, 4, 4], [14, 8, 8], [14, 16, 16], [14, 32, 32]]
    # mazes = [[16, 2, 2], [16, 4, 4], [16, 8, 8], [16, 16, 16], [16, 32, 32]]
    # mazes = [[14, 2, 2], [14, 4, 4], [14, 8, 8]]
    mazes_test = [0,2,6,14,30]

    if "rlc" in data_root or "pponewdata" in data_root:
        algo_templates = [
            "{algo}-env_{env}_maze_length_{L}-num_stack_{kappa}-mask_type_framestack",
            "{algo}-env_{env}_maze_length_{L}-num_stack_{kappa}-mask_type_masked"
        ]
    else:  # NeurIPS-style naming
        algo_templates = [
            # "{algo}-arch_{arch}-env_{env}_maze_length_{L}-random_length_True-num_stack_{kappa}-mask_type_demir",
            "{algo}-arch_{arch}-env_{env}_maze_length_{L}-random_length_True-num_stack_{kappa}-mask_type_framestack",
            "{algo}-arch_{arch}-env_{env}_maze_length_{L}-random_length_True-num_stack_{kappa}-mask_type_masked"
        ]
    algo_labels = ["FrameStack", "AdaptiveStack (Ours)"]
    # algo_labels = ["DemirStack", "FrameStack", "AdaptiveStack (Ours)"]

    M = len(mazes)
    N = len(mazes_test)
    # Every metric array is indexed [algo_id, test_idx, train_idx].
    metrics = {f"{name}_{stat}": np.zeros((len(algo_labels), N, M))
               for name in ("success_rate", "returns", "optimality", "states")
               for stat in ("mean", "std")}

    # ---------------- Loop over (train_maze, test_maze, run) -------------
    for train_idx, (L_train, k, kappa) in enumerate(mazes):
        for algo_id, (tmpl, algo_label) in enumerate(zip(algo_templates, algo_labels)):
            algo_name = tmpl.format(algo=algorithm, arch=arch, env=env_name,
                                    L=L_train, k=k, kappa=kappa)

            for test_idx, L_test in enumerate(mazes_test):
                succ_all, ret_all, optimality_all, states_all = [], [], [], []

                # build test environment *for this algorithm wrapper*
                test_env = gym.make("tmaze-v0",
                                    length=L_test,
                                    random_length=False,
                                    active="active" in env_name,
                                    continual=("rlc" in data_root),
                                    fix_start=True, goal_obs=False, fully_obs=False)
                if algo_label == "FrameStack":
                    test_env = FrameStack(test_env, k)
                elif algo_label.startswith("AdaptiveStack"):
                    test_env = MaskedFrameStack(test_env, kappa,use_multidiscrete=True)
                if algorithm == "QL":
                    test_env = TupleObs(test_env)
                V_opt = gamma**(L_test+2-1+("active" in env_name))
                # Reference count for the "states" metric (loop-invariant):
                # states = 1 - (n_ref - |distinct obs|) / (n_ref - 2), or 1 when n_ref == 2.
                n_ref = min(L_test+2, k+2)

                # run-wise evaluation
                for run in range(runs):
                    model_path = (f"{data_root}/{algo_name}-run_{run}"
                                  + ("-values.npy" if algorithm=="QL" else "_values"))
                    if not os.path.exists(model_path) and not os.path.exists(model_path+".zip"):
                        print("FAILED ", model_path)
                        continue  # missing run – reported above and skipped
                    agent = load_agent(algo_label, algorithm, model_path, test_env, device)

                    G = 0.0
                    successes = []
                    n_goals = 2  # both start goals are evaluated (hard-coded)
                    for sg in range(n_goals):
                        states = set()
                        test_env.unwrapped.start_goal = sg
                        state, _ = test_env.reset()
                        for t in range(max_steps):
                            states.add(tuple(state.flatten()))
                            action, _ = agent(state)
                            state, r, done, trunc, _ = test_env.step(action)
                            G += (gamma ** t) * r
                            if r!=0:
                                successes.append(float(r > 0))
                                # rotate to the next start goal after any non-zero reward
                                test_env.unwrapped.start_goal = (test_env.unwrapped.start_goal+1)%n_goals
                            if done or trunc: break
                        if n_ref-2 != 0:
                            states_all.append(1-(n_ref-len(states))/(n_ref-2))
                        else:
                            states_all.append(1)
                    if not successes: successes.append(0)
                    ret_all.append(G/n_goals)
                    succ_all += successes
                    optimality_all.append(1-abs(V_opt - G/n_goals)/(2*V_opt))

                # Release the environment; a new one is built for the next test length.
                test_env.close()

                # average over runs & goals
                metrics["success_rate_mean"][algo_id, test_idx, train_idx] = np.mean(succ_all)
                metrics["success_rate_std"][algo_id, test_idx, train_idx] = np.std(succ_all)
                metrics["returns_mean"][algo_id, test_idx, train_idx] = np.mean(ret_all)
                metrics["returns_std"][algo_id, test_idx, train_idx] = np.std(ret_all)
                metrics["optimality_mean"][algo_id, test_idx, train_idx] = np.mean(optimality_all)
                metrics["optimality_std"][algo_id, test_idx, train_idx] = np.std(optimality_all)
                metrics["states_mean"][algo_id, test_idx, train_idx] = np.mean(states_all)
                metrics["states_std"][algo_id, test_idx, train_idx] = np.std(states_all)

    # ---------------- Plot & save  ---------------------------------------
    os.makedirs(out_dir, exist_ok=True)
    xlabels = [k for (_, k, _) in mazes]
    ylabels = [l+2 for l in mazes_test]

    for algo_id, algo_label in enumerate(algo_labels):
        for key, title in [
                        #    ("success_rate", "success rate"),
                        #    ("returns" , "returns"),
                           ("optimality" , "optimality"),
                           ("states" , "agent states")
                           ]:
            means = metrics[key+"_mean"][algo_id]
            stds = metrics[key+"_std"][algo_id]
            # Cell text: "mean \n ± std", two decimals each.
            annotations = np.empty_like(means, dtype=object)
            for i in range(means.shape[0]):
                for j in range(means.shape[1]):
                    annotations[i, j] = f"{means[i, j]:.2f} \n ± {stds[i, j]:.2f}"

            fig, ax = plt.subplots(figsize=(figsize_per_cell*M, figsize_per_cell*M))
            sns.heatmap(means,
                        annot=annotations, annot_kws={"size": 20}, fmt= "",
                        vmin=0,
                        vmax=1,
                        xticklabels=xlabels, yticklabels=ylabels,
                        cbar_kws=dict(label=title), ax=ax)

            ax.set_xlabel("train memory length $(L+2)$")
            ax.set_ylabel("test maze length $(L+2)$")
            ax.set_title(f"{algo_label}", pad=20)

            if algorithm == "QL" or "rlc" in data_root or "pponewdata" in data_root:
                fname = f"memory-{algorithm}-{algo_label}-{env_name}-{key}.pdf".replace(" ", "-").replace("$","").replace("\\","").replace("*","").replace("^","")
            else:
                fname = f"memory-{algorithm}-{algo_label}-{arch}-{env_name}-{key}.pdf".replace(" ", "-").replace("$","").replace("\\","").replace("*","").replace("^","")
            plt.tight_layout()
            fig.savefig(os.path.join(out_dir, fname), bbox_inches="tight")
            plt.close(fig)

    print(f"Saved heat-maps to {out_dir}")


# ------------------------------------------------------------
# Main: memory usage statistics (per-observation visit weights)
# ------------------------------------------------------------
def plots_memory_usage(
        env_name           = "passive_tmaze-v0",
        data_root          = "data_rebuttal_2",
        algorithm          = "PPO",          # or "PPO"
        arch               = "mlp",
        gamma              = 0.99,
        runs               = 2,
        max_steps          = 100,
        out_dir            = "images/pdf",
        figsize_per_cell   = 1.8,           # heat-map scaling
        device             = "cpu"
):
    """Roll out trained T-maze policies and record per-observation visit weights.

    For each training configuration ``[L_train, k, kappa]`` in ``mazes`` and
    each wrapper (FrameStack with ``k`` frames, AdaptiveStack with ``kappa``
    slots), every saved run is rolled out on a fixed-length ``tmaze-v0`` of
    each length in ``mazes_test``, once per start goal. Every step adds
    ``1 / (L_test + 2) / len(mazes_test)`` to the entry of the current
    (flattened) observation for that run. Runs whose model file is missing
    are reported with ``FAILED`` and skipped.

    The number of distinct FrameStack observations is printed. No figure is
    produced yet; ``gamma``, ``out_dir`` and ``figsize_per_cell`` are
    currently unused (presumably kept for signature parity with the other
    plotting functions).

    Args:
        env_name: environment tag used in run file names; the substring
            ``"active"`` selects the active T-maze variant.
        data_root: directory of saved models. ``"rlc"`` in the name enables
            the continual env; ``"rlc"`` or ``"pponewdata"`` select the short
            file-name template.
        algorithm: algorithm tag in file names; ``"QL"`` uses ``-values.npy``
            files and an extra ``TupleObs`` wrapper.
        arch: architecture tag in file names (NeurIPS-style naming only).
        runs: number of seeds ``run_0 .. run_{runs-1}`` to evaluate.
        max_steps: step cap per episode.
        device: device forwarded to ``load_agent``.

    Returns:
        dict mapping each algorithm label to a ``defaultdict`` from observation
        tuple to a length-``runs`` NumPy array of accumulated visit weights.
    """

    # ---------------- Settings identical to your original ----------------
    rc_ = {'figure.figsize':(10,8),'axes.labelsize':50,'axes.titlesize':40,
           'xtick.labelsize':40,'ytick.labelsize':40,'legend.fontsize':40}
    sns.set(rc=rc_, style="darkgrid")
    rc('text', usetex=True)

    mazes = [[16, 16, 16]] # [[16, 2, 2], [16, 4, 4], [16, 8, 8], [16, 16, 16], [16, 32, 32]]
    mazes_test = [0,2,6,14,30]

    if "rlc" in data_root or "pponewdata" in data_root:
        algo_templates = [
            "{algo}-env_{env}_maze_length_{L}-num_stack_{k}-mask_type_framestack",
            "{algo}-env_{env}_maze_length_{L}-num_stack_{kappa}-mask_type_masked"
        ]
    else:  # NeurIPS-style naming
        algo_templates = [
            "{algo}-arch_{arch}-env_{env}_maze_length_{L}-random_length_True-num_stack_{k}-mask_type_framestack",
            "{algo}-arch_{arch}-env_{env}_maze_length_{L}-random_length_True-num_stack_{kappa}-mask_type_masked"
        ]
    algo_labels = ["FrameStack", "AdaptiveStack (Ours)"]
    memory_usage = {algo: defaultdict(lambda: np.zeros(runs)) for algo in algo_labels}

    # ---------------- Loop over (train_maze, test_maze, run) -------------
    for L_train, k, kappa in mazes:
        for tmpl, algo_label in zip(algo_templates, algo_labels):
            algo_name = tmpl.format(algo=algorithm, arch=arch, env=env_name,
                                    L=L_train, k=k, kappa=kappa)
            usage = memory_usage[algo_label]

            for L_test in mazes_test:
                # Weight added per visited step on this maze (loop-invariant).
                step_weight = (1/(L_test+2))/len(mazes_test)

                # build test environment *for this algorithm wrapper*
                test_env = gym.make("tmaze-v0",
                                    length=L_test,
                                    random_length=False,
                                    active="active" in env_name,
                                    continual=("rlc" in data_root),
                                    fix_start=True, goal_obs=False, fully_obs=False)
                if algo_label == "FrameStack":
                    test_env = FrameStack(test_env, k)
                elif algo_label.startswith("AdaptiveStack"):
                    test_env = MaskedFrameStack(test_env, kappa,use_multidiscrete=True)
                if algorithm == "QL":
                    test_env = TupleObs(test_env)
                # Env-specific attributes live on the base env; go through
                # `.unwrapped` rather than relying on attribute forwarding by
                # the gym.make / stacking wrappers.
                base_env = test_env.unwrapped

                # run-wise evaluation
                for run in range(runs):
                    model_path = (f"{data_root}/{algo_name}-run_{run}"
                                  + ("-values.npy" if algorithm=="QL" else "_values"))
                    if not os.path.exists(model_path) and not os.path.exists(model_path+".zip"):
                        print("FAILED ", model_path)
                        continue  # missing run – reported above and skipped
                    agent = load_agent(algo_label, algorithm, model_path, test_env, device)

                    n_goals = len(base_env.goals)
                    for sg in range(n_goals):
                        base_env.start_goal = sg
                        state, _ = test_env.reset()
                        for _ in range(max_steps):
                            usage[tuple(state.flatten())][run] += step_weight
                            action, _ = agent(state)
                            state, r, done, trunc, _ = test_env.step(action)
                            if r != 0:
                                # rotate to the next start goal after any non-zero reward
                                base_env.start_goal = (base_env.start_goal+1)%n_goals
                            if done or trunc: break

                test_env.close()

    # ---------------- Plot & save  ---------------------------------------
    states = set(memory_usage[algo_labels[0]])  # .union(set(memory_usage[algo_labels[1]]))
    print(len(states))

    # width = 1.5/6
    # fig,ax=plt.subplots()
    # mazes_ = np.arange(2,2+len(mazes))
    # ax.bar(mazes_-0.15, mean4, width, yerr=std4, align='center', ecolor='black', capsize=2.5, label=r"AdaptiveStack (Ours)")
    # ax.bar(mazes_+0.15, mean2, width, yerr=std2, align='center', ecolor='black', capsize=2.5, label=r"FrameStack")
    # # ax.bar(mazes_+0.15, mean2, width, yerr=std2, align='center', ecolor='black', capsize=2.5, label="FrameStack")
    # ax.set_xticks(mazes_)
    # ax.set_xticklabels([m for (_,m,_) in mazes])
    # # plt.legend()
    # plt.xlabel(r"memory length")
    # plt.ylabel(nice_name)
    # ax.yaxis.get_major_formatter().set_powerlimits((0, 1))
    # # plt.xlim(2, 2+len(mazes))
    # # plt.show()

    # fig_legend = plt.figure(figsize=(4, 1))
    # handles, labels = ax.get_legend_handles_labels()
    # fig_legend.legend(handles, labels, loc='center', ncol=6)
    # fig_legend.savefig("images/pdf/legend.pdf", bbox_inches='tight')

    # if algorithm == "QL" or "rlc" in data_root or "pponewdata" in data_root:
    #     save_path = "{}/memory-{}-env_{}-{}.pdf".format(args.path, algorithm, env_name, name)
    # else:
    #     save_path = "{}/memory-{}-{}-env_{}-{}.pdf".format(args.path, algorithm, arch, env_name, name)
    # fig.savefig(save_path, bbox_inches='tight')

    return memory_usage
    
# ------------------------------------------------------------
# Main: cross-scrambles generalisation plot
# ------------------------------------------------------------
def plots_scrambles_generalisation(
        env_name           = "cube-v0",
        data_root          = "data_rebuttal_3",
        algorithm          = "PPO",          # or "PPO"
        arch               = "mlp",
        gamma              = 0.99,
        runs               = 10,
        max_steps          = 100,
        out_dir            = "images",
        figsize_per_cell   = 1.8,           # heat-map scaling
        device             = "cpu"
):
    """Cross-evaluate cube policies on a range of scramble depths and save one
    heat-map per (wrapper, metric).

    For each training configuration ``[L_train, k, kappa]`` in ``cubes`` and
    each wrapper (FrameStack with ``k`` frames, AdaptiveStack with ``kappa``
    slots), run ``run`` loads model ``run_{run % 10}`` and plays one episode on
    ``cube-v0`` seeded with ``seed=run`` and scrambled ``L_test`` times, for
    every ``L_test`` in ``cubes_test``. Runs whose model file is missing are
    reported with ``FAILED`` and skipped.

    Heat-map cells are means over runs of:

    * ``success_rate`` -- fraction of non-zero rewards that were positive
      (an episode with none counts as one failure); colour scale [0, 1].
    * ``returns`` -- discounted return; colour scale fitted to the data.
    * ``steps`` -- episode length; colour scale [0, max_steps].

    Args:
        env_name: environment tag used in run file names.
        data_root: directory of saved models.
        algorithm: algorithm tag in file names; ``"QL"`` uses ``-values.npy``
            files and an extra ``TupleObs`` wrapper.
        arch: architecture tag in file names.
        gamma: discount factor for returns.
        runs: number of evaluated episodes per cell (env seeds ``0..runs-1``).
        max_steps: step cap per episode; also the upper colour limit for steps.
        out_dir: output directory for the PNG heat-maps (created if missing).
        figsize_per_cell: inches per heat-map row.
        device: device forwarded to ``load_agent``.
    """

    # ---------------- Settings identical to your original ----------------
    rc_ = {'figure.figsize':(10,8),'axes.labelsize':40,'axes.titlesize':30,
           'xtick.labelsize':30,'ytick.labelsize':30,'legend.fontsize':30}
    sns.set(rc=rc_, style="darkgrid")
    rc('text', usetex=True)

    cube_cam = "orthographic" # "face" #
    cubes = [[2, 2, 2]] #,[5, 10, 10]]
    cubes_test = list(range(20))

    algo_templates = [
        "{algo}-arch_{arch}-env_{env}_scramble_steps_{L}-random_length_False-cube_cam_{cube_cam}-num_stack_{k}-mask_type_framestack",
        "{algo}-arch_{arch}-env_{env}_scramble_steps_{L}-random_length_False-cube_cam_{cube_cam}-num_stack_{kappa}-mask_type_masked"
    ]
    algo_labels = ["FrameStack", "AdaptiveStack (Ours)"]

    M = len(cubes_test)
    N = len(cubes)
    # Every metric array is indexed [algo_id, test_idx, train_idx].
    metrics = {name: np.zeros((len(algo_labels), M, N))
               for name in ("success_rate", "steps", "returns", "optimality")}

    # ---------------- Loop over (train_maze, test_maze, run) -------------
    for train_idx, (L_train, k, kappa) in enumerate(cubes):
        for algo_id, (tmpl, algo_label) in enumerate(zip(algo_templates, algo_labels)):
            algo_name = tmpl.format(algo=algorithm, arch=arch, env=env_name, cube_cam=cube_cam,
                                    L=L_train, k=k, kappa=kappa)

            for test_idx, L_test in enumerate(cubes_test):
                succ_all, ret_all, steps_all = [], [], []

                # run-wise evaluation
                for run in range(runs):
                    # Models are indexed run_0..run_9; runs >= 10 reuse them
                    # with a different environment seed.
                    model_path = (f"{data_root}/{algo_name}-run_{run%10}"
                                  + ("-values.npy" if algorithm=="QL" else "_values"))
                    if not os.path.exists(model_path) and not os.path.exists(model_path+".zip"):
                        print("FAILED ", model_path)
                        continue  # missing run – reported above and skipped

                    # build test environment *for this algorithm wrapper*
                    test_env = gym.make("cube-v0", episode_steps=100, scramble_steps=L_test, random_length=False,
                                        cube_cam=cube_cam, seed=run)
                    if algo_label == "FrameStack":
                        test_env = FrameStack(test_env, k)
                    elif algo_label.startswith("AdaptiveStack"):
                        test_env = MaskedFrameStack(test_env, kappa,use_multidiscrete=True)
                    if algorithm == "QL":
                        test_env = TupleObs(test_env)

                    try:
                        agent = load_agent(algo_label, algorithm, model_path, test_env, device)
                        state, _ = test_env.reset()
                        G = 0.0
                        successes = []
                        steps = 0
                        for t in range(max_steps):
                            action, _ = agent(state)
                            state, r, done, trunc, _ = test_env.step(action)
                            G += (gamma ** t) * r
                            steps += 1
                            if r!=0:
                                successes.append(float(r > 0))
                            if done or trunc: break
                    finally:
                        # One env is built per run; release it even on errors.
                        test_env.close()

                    if not successes: successes.append(0)
                    ret_all.append(G)
                    succ_all += successes
                    steps_all.append(steps)

                # average over runs
                metrics["success_rate"][algo_id, test_idx, train_idx] = np.mean(succ_all)
                metrics["returns"][algo_id, test_idx, train_idx] = np.mean(ret_all)
                metrics["steps"][algo_id, test_idx, train_idx] = np.mean(steps_all)

    # ---------------- Plot & save  ---------------------------------------
    os.makedirs(out_dir, exist_ok=True)
    xlabels = [f"({l},{k})" for (l, k, _) in cubes]
    ylabels = list(cubes_test)

    # Colour-scale limits per metric: success rate is a fraction, episode
    # length is capped by max_steps, and returns may be negative (rewards can
    # be < 0), so seaborn fits that scale to the data.
    color_limits = {"success_rate": (0, 1),
                    "returns": (None, None),
                    "steps": (0, max_steps)}

    for algo_id, algo_label in enumerate(algo_labels):
        for key, title in [("success_rate", "success rate"),
                           ("returns" , "returns"),
                           ("steps" , "steps")]:
            vmin, vmax = color_limits[key]
            fig, ax = plt.subplots(figsize=(figsize_per_cell*M, figsize_per_cell*M))
            sns.heatmap(metrics[key][algo_id],
                        annot=True, annot_kws={"size": 20}, fmt=".2f",
                        vmin=vmin,
                        vmax=vmax,
                        xticklabels=xlabels, yticklabels=ylabels,
                        cbar_kws=dict(label=title), ax=ax)

            ax.set_xlabel("train (scrambles,stack)")
            ax.set_ylabel("test scrambles")
            ax.set_title(f"{algo_label}", pad=20)

            if algorithm == "QL" or "rlc" in data_root or "pponewdata" in data_root:
                fname = f"{algorithm}-{algo_label}-{env_name}-{cube_cam}-{key}.png".replace(" ", "-").replace("$","").replace("\\","").replace("*","").replace("^","")
            else:
                fname = f"{algorithm}-{algo_label}-{arch}-{env_name}-{cube_cam}-{key}.png".replace(" ", "-").replace("$","").replace("\\","").replace("*","").replace("^","")
            plt.tight_layout()
            fig.savefig(os.path.join(out_dir, fname), bbox_inches="tight")
            plt.close(fig)

    print(f"Saved heat-maps to {out_dir}")


    
# ------------------------------------------------------------
# Main: cross-scrambles generalisation per-run line plot (mean, 95% CI)
# ------------------------------------------------------------
def plots_scrambles_generalisation_boxplot(
        env_name           = "cube-v0",
        data_root          = "data_grpo",
        algorithm          = "GRPO",          # or "PPO"
        arch               = "mlp",
        gamma              = 0.99,
        runs               = 50,
        max_steps          = 100,
        out_dir            = "images",
        figsize_per_cell   = 1.8,           # heat-map scaling
        device             = "cpu"
):
    """Evaluate cube policies on a range of scramble depths and save one
    per-run line plot (mean with 95% CI) per (training config, metric).

    For each training configuration ``[L_train, k, kappa]`` in ``cubes`` and
    each wrapper (FrameStack and DemirFrameStack with ``k`` frames,
    AdaptiveStack with ``kappa`` slots), run ``run`` loads model
    ``run_{run % 10}`` and plays one episode with ``deterministic=False`` on
    ``cube-v0`` seeded with ``seed=run`` and scrambled ``L_test`` times, for
    every ``L_test`` in ``cubes_test``.

    Each run's result is stored at index ``run`` of the metric arrays. A run's
    success rate is the fraction of its non-zero rewards that were positive
    (0 if it saw none). Runs whose model file is missing are reported with
    ``FAILED`` and left as NaN; seaborn's lineplot drops those rows.

    Args:
        env_name: environment tag used in run file names.
        data_root: directory of saved models.
        algorithm: algorithm tag in file names; ``"QL"`` uses ``-values.npy``
            files and an extra ``TupleObs`` wrapper.
        arch: architecture tag in file names.
        gamma: discount factor for returns.
        runs: number of evaluated episodes per cell (env seeds ``0..runs-1``).
        max_steps: step cap per episode.
        out_dir: output directory for the PDF plots (created if missing).
        figsize_per_cell: unused here (kept for signature parity).
        device: device forwarded to ``load_agent``.
    """

    # ---------------- Settings identical to your original ----------------
    rc_ = {'figure.figsize':(10,8),'axes.labelsize':60,'axes.titlesize':50,
           'xtick.labelsize':50,'ytick.labelsize':50,'legend.fontsize':30}
    sns.set(rc=rc_, style="darkgrid")
    rc('text', usetex=True)

    # cube_cam = "orthographic" # "face" #
    # cubes = [[2, 2, 2]] # [[5, 10, 10]] #
    cube_cam = "face" #
    cubes = [[5, 10, 10]] #
    # cubes = [[5,2,2],[5,8,8],[5,16,16],[5,32,32],[5,64,64]] #
    # cubes = [[10,2,2],[10,8,8],[10,16,16],[10,32,32],[10,64,64]] #
    cubes_test = list(range(11))

    algo_templates = [
        "{algo}-arch_{arch}-env_{env}_scramble_steps_{L}-random_length_False-cube_cam_{cube_cam}-num_stack_{k}-mask_type_framestack",
        "{algo}-arch_{arch}-env_{env}_scramble_steps_{L}-random_length_False-cube_cam_{cube_cam}-num_stack_{kappa}-mask_type_demir",
        "{algo}-arch_{arch}-env_{env}_scramble_steps_{L}-random_length_False-cube_cam_{cube_cam}-num_stack_{kappa}-mask_type_masked"
    ]
    algo_labels = ["FrameStack", "DemirFrameStack", "AdaptiveStack (Ours)"]

    n_algos = len(algo_labels)
    M = len(cubes_test)
    N = len(cubes)
    # Per-run values indexed [algo_id, test_idx, train_idx, run]; NaN marks a
    # run that was not evaluated (missing model file).
    metrics = {name: np.full((n_algos, M, N, runs), np.nan)
               for name in ("success_rate", "steps", "returns", "optimality")}

    # ---------------- Loop over (train_maze, test_maze, run) -------------
    for train_idx, (L_train, k, kappa) in enumerate(cubes):
        for algo_id, (tmpl, algo_label) in enumerate(zip(algo_templates, algo_labels)):
            algo_name = tmpl.format(algo=algorithm, arch=arch, env=env_name, cube_cam=cube_cam,
                                    L=L_train, k=k, kappa=kappa)

            for test_idx, L_test in enumerate(cubes_test):
                steps_all = []

                # run-wise evaluation
                for run in range(runs):
                    # Models are indexed run_0..run_9; runs >= 10 reuse them
                    # with a different environment seed.
                    model_path = (f"{data_root}/{algo_name}-run_{run%10}"
                                  + ("-values.npy" if algorithm=="QL" else "_values"))
                    if not os.path.exists(model_path) and not os.path.exists(model_path+".zip"):
                        print("FAILED ", model_path)
                        continue  # missing run – stays NaN in `metrics`

                    # build test environment *for this algorithm wrapper*
                    test_env = gym.make("cube-v0", episode_steps=100, scramble_steps=L_test, random_length=False,
                                        cube_cam=cube_cam, seed=run)
                    if algo_label == "FrameStack":
                        test_env = FrameStack(test_env, k)
                    elif algo_label == "DemirFrameStack":
                        test_env = DemirFrameStack(test_env, k)
                    elif algo_label.startswith("AdaptiveStack"):
                        test_env = MaskedFrameStack(test_env, kappa,use_multidiscrete=True)
                    if algorithm == "QL":
                        test_env = TupleObs(test_env)

                    try:
                        agent = load_agent(algo_label, algorithm, model_path, test_env, device, deterministic=False)
                        state, _ = test_env.reset()
                        G = 0.0
                        successes = []
                        steps = 0
                        for t in range(max_steps):
                            action, _ = agent(state)
                            state, r, done, trunc, _ = test_env.step(action)
                            G += (gamma ** t) * r
                            steps += 1
                            if r!=0:
                                successes.append(float(r > 0))
                            if done or trunc: break
                    finally:
                        # One env is built per run; release it even on errors.
                        test_env.close()

                    # Exactly one value per run slot, so the arrays never depend
                    # on how many runs were found or how many rewards occurred.
                    metrics["success_rate"][algo_id, test_idx, train_idx, run] = np.mean(successes) if successes else 0.0
                    metrics["returns"][algo_id, test_idx, train_idx, run] = G
                    metrics["steps"][algo_id, test_idx, train_idx, run] = steps
                    steps_all.append(steps)

                if L_test == 5 and steps_all:  # debug output
                    print(min(steps_all), steps_all)

    # ---------------- Plot & save  ---------------------------------------
    os.makedirs(out_dir, exist_ok=True)
    xlabels = list(cubes_test)

    for key, title in [
                        ("success_rate", "success rate"),
                        # ("returns" , "returns"),
                        ("steps" , "steps")]:
        for train_idx, (L_train, k, kappa) in enumerate(cubes):
            # Wide table: one row per (algorithm, run), one column per test
            # depth, plus the algorithm label in the column named "" (used as
            # hue). Algorithms are listed last-to-first (AdaptiveStack first).
            rows = [
                [metrics[key][a, t, train_idx, i] for t in range(len(xlabels))] + [algo_labels[a]]
                for a in reversed(range(n_algos))
                for i in range(runs)
            ]
            data = pd.DataFrame(rows, columns=xlabels+[""])
            data = pd.melt(data, id_vars="", var_name="test scrambles", value_name=title)
            fig, ax = plt.subplots()
            # ax = sns.boxplot(x="test scrambles", y=title, data=data, linewidth=3, hue="", showfliers = False)
            ax = sns.lineplot(
                x="test scrambles",
                y=title,
                data=data,
                hue="",
                errorbar=("ci", 95),   # 95% bootstrap confidence interval
                linewidth=2,
                marker="o"       # markers for clarity
            )
            # ax.get_legend().set_visible(False)

            # fig_legend = plt.figure(figsize=(3, 1))
            # handles, labels = ax.get_legend_handles_labels()
            # fig_legend.legend(handles, labels, loc='center', ncol=6)
            # fig_legend.savefig("images/legend2.pdf", bbox_inches='tight')

            fname = f"{algorithm}-{arch}-{env_name}-cube_cam_{cube_cam}-scramble_steps_{L_train}-num_stack_{k}-{key}.pdf".replace(" ", "-").replace("$","").replace("\\","").replace("*","").replace("^","")
            plt.tight_layout()
            fig.savefig(os.path.join(out_dir, fname), bbox_inches="tight")
            plt.close(fig)

    print(f"Saved heat-maps to {out_dir}")

def plots_fetch_generalisation(
        env_name           = "FetchReachDense-v4",
        data_root          = "data",
        algorithm          = "PPO",          # "PPO" "GRPO"
        arch               = "transformer",
        gamma              = 0.99,
        runs               = [2],           # range(10)
        max_steps          = 1000000,
        max_episodes       = 100,
        out_dir            = "images",
        figsize_per_cell   = 1.8,           # heat-map scaling
        device             = "cuda"
):
    """Evaluate Fetch policies under different goal-visibility windows and plot line charts.

    For every training configuration in ``visible`` and every algorithm
    variant, each run in ``runs`` is loaded once. It is then evaluated for
    ``max_episodes`` seeded episodes (episodes truncated at 50 steps) for every
    ``visible_goal_steps`` value in ``visible_test``. Per (run, episode) the
    final ``info["is_success"]``, the discounted return and the episode length
    are stored. Success rate and steps are then plotted against the number of
    visible goal steps, one PDF per metric.

    Args:
        env_name: Gym id of the Fetch task.
        data_root: Directory containing the saved models.
        algorithm: Algorithm prefix used in the model file names.
        arch: Architecture tag used in the model file names.
        gamma: Discount factor of the reported return.
        runs: Iterable of run ids (list or range); runs whose model file is
            missing are skipped and excluded from the plots.
        max_steps: Hard cap on steps per episode.
        max_episodes: Seeded evaluation episodes per run and test setting.
        out_dir: Directory the PDFs are written to (created if needed).
        figsize_per_cell: Unused; kept for interface compatibility.
        device: Device forwarded to ``load_agent``.
    """
    rc_ = {'figure.figsize':(10,8),'axes.labelsize':60,'axes.titlesize':50,
           'xtick.labelsize':50,'ytick.labelsize':50,'legend.fontsize':30}
    sns.set(rc=rc_, style="darkgrid")
    rc('text', usetex=False)

    # Training configurations as [visible goal steps at train time, k, kappa].
    visible = [[2, 50, 4]]
    visible_test = [1,2,3,4,5,10,20,30,40,50]
    render_mode = "rgb_array"
    runs = list(runs)  # accept ranges/tuples as well as lists

    algo_templates = [
        "{algo}-arch_{arch}-env_{env}-num_stack_{k}-mask_type_framestack",
        "{algo}-arch_{arch}-env_{env}-num_stack_{kappa}-mask_type_framestack",
        # "{algo}-arch_{arch}-env_{env}-num_stack_{kappa}-mask_type_demir",
        "{algo}-arch_{arch}-env_{env}-num_stack_{kappa}-mask_type_masked"
    ]
    algo_labels = ["FrameStack k*", "FrameStack", "AdaptiveStack (Ours)"]

    def make_env(algo_label, k, kappa, visible_goal_steps, **make_kwargs):
        """Build the Fetch env with goal masking plus the observation stack *algo_label* uses."""
        env = gym.make(env_name, render_mode=render_mode, **make_kwargs)
        env = PartialObsGoal(env, visible_goal_steps=visible_goal_steps)
        env = FlattenObservation(env)
        if algo_label in ("FrameStack k*", "Transformer FrameStack"):
            env = FrameStack(env, k)
        elif algo_label == "FrameStack":
            env = FrameStack(env, kappa)
        elif algo_label == "DemirFrameStack":
            env = DemirFrameStack(env, kappa)
        elif algo_label.startswith("AdaptiveStack"):
            env = MaskedFrameStack(env, kappa, use_multidiscrete=True)
        if algorithm == "QL":
            env = TupleObs(env)
        return env

    M = len(visible_test)
    N = len(visible)
    # One slot per (run, episode): run i owns slots [i*max_episodes, (i+1)*max_episodes).
    n_slots = len(runs) * max_episodes
    metrics = {
        name: np.zeros((len(algo_templates), M, N, n_slots))  # algo × test × train × (run, episode)
        for name in ("success_rate", "steps", "returns", "optimality")
    }

    for train_idx, (L_train, k, kappa) in enumerate(visible):
        for algo_id, (tmpl, algo_label) in enumerate(zip(algo_templates, algo_labels)):
            algo_name = tmpl.format(algo=algorithm, arch=arch, env=env_name, k=k, kappa=kappa)

            for run_idx, run in enumerate(runs):
                slots = slice(run_idx * max_episodes, (run_idx + 1) * max_episodes)
                model_path = (f"{data_root}/{algo_name}-run_{run}"
                              + ("-values.npy" if algorithm == "QL" else "_values"))
                print(model_path)

                if not os.path.exists(model_path) and not os.path.exists(model_path + ".zip"):
                    print("FAILED ", model_path)
                    # Mark the run as missing so it is dropped instead of plotted as zeros.
                    for values in metrics.values():
                        values[algo_id, :, train_idx, slots] = np.nan
                    continue

                # Env used only to construct/load the agent (training goal visibility).
                load_env = make_env(algo_label, k, kappa, visible_goal_steps=2)
                agent = load_agent(algo_label, algorithm, model_path, load_env, device, deterministic=False)

                # The agent is loaded once per run and reused for every test setting.
                for test_idx, L_test in enumerate(visible_test):
                    print(test_idx, L_test, algo_label, "_____________________________")
                    test_env = make_env(algo_label, k, kappa, visible_goal_steps=L_test,
                                        max_episode_steps=50)

                    succ_all, ret_all, steps_all = [], [], []
                    try:
                        for episode in range(max_episodes):
                            state, _ = test_env.reset(seed=episode)
                            G = 0.0
                            steps = 0
                            success = 0
                            for t in range(max_steps):
                                action, _ = agent(state)
                                state, r, done, trunc, info = test_env.step(action)
                                G += (gamma ** t) * r
                                steps += 1
                                if done or trunc:
                                    # Success is judged on the final step of the episode only.
                                    success = info["is_success"]
                                    break
                            ret_all.append(G)
                            succ_all.append(success)
                            steps_all.append(steps)
                    finally:
                        test_env.close()

                    metrics["success_rate"][algo_id, test_idx, train_idx, slots] = succ_all
                    metrics["returns"][algo_id, test_idx, train_idx, slots] = ret_all
                    metrics["steps"][algo_id, test_idx, train_idx, slots] = steps_all

    # ---------------- Plot & save  ---------------------------------------
    os.makedirs(out_dir, exist_ok=True)
    xlabels = list(visible_test)
    var_name = "visible goal steps"

    for key, title in [
                        ("success_rate", "success rate"),
                        # ("returns" , "returns"),
                        ("steps" , "steps")]:
        for train_idx, (L_train, k, kappa) in enumerate(visible):
            # Wide table: one row per (algo, slot), one column per test setting.
            # The last algorithm comes first so hue colours match the original figures.
            rows = [
                list(metrics[key][algo_id, :, train_idx, i]) + [algo_labels[algo_id]]
                for algo_id in reversed(range(len(algo_labels)))
                for i in range(n_slots)
            ]
            data = pd.DataFrame(rows, columns=xlabels + [""])
            data = pd.melt(data, id_vars="", var_name=var_name, value_name=title)
            data = data.dropna(subset=[title])  # drop slots of missing runs
            if data.empty:
                print("FAILED no data for", key, k)
                continue

            fig, ax = plt.subplots()
            ax = sns.lineplot(
                x=var_name,
                y=title,
                data=data,
                hue="",
                errorbar=("ci", 95),
                linewidth=10,
                marker="o"
            )
            ax.get_legend().set_visible(False)
            ax.set_xscale('log', base=2)
            ax.xaxis.set_major_formatter(ScalarFormatter())
            ax.ticklabel_format(style='plain', axis='x')

            fname = f"{algorithm}-{arch}-{env_name}-num_stack_{k}-{key}.pdf".replace(" ", "-").replace("$","").replace("\\","").replace("*","").replace("^","")
            plt.tight_layout()
            fig.savefig(os.path.join(out_dir, fname), bbox_inches="tight")
            plt.close(fig)

    print(f"Saved heat-maps to {out_dir}")

def plots_fetch_generalisation_episode(
        env_name           = "FetchReachDense-v4",
        data_root          = "data",
        algorithm          = "PPO",          # "PPO" "GRPO"
        arch               = "transformer",
        gamma              = 0.99,
        runs               = [2],           # range(10)
        max_steps          = 1000000,
        max_episodes       = 100,
        out_dir            = "images",
        figsize_per_cell   = 1.8,           # heat-map scaling
        device             = "cuda"
):
    """Record per-step success and reward of Fetch policies over long episodes and plot them.

    Each run's policy is loaded and evaluated for ``max_episodes`` seeded
    episodes. The env keeps ``visible_goal_steps=2`` but truncates episodes at
    ``visible_test[-1]`` steps. After every step the instantaneous
    ``info["is_success"]`` flag and discounted reward are stored at the index
    equal to the number of steps taken (index 0 is therefore always 0). Both
    metrics are plotted against that step count with a dashed vertical line
    at step 50.

    Args:
        env_name: Gym id of the Fetch task.
        data_root: Directory containing the saved models.
        algorithm: Algorithm prefix used in the model file names.
        arch: Architecture tag used in the model file names.
        gamma: Discount factor applied to the stored per-step rewards.
        runs: Iterable of run ids (list or range); runs whose model file is
            missing are excluded from the plots.
        max_steps: Hard cap on steps per episode.
        max_episodes: Seeded evaluation episodes per run.
        out_dir: Directory the PDFs are written to (created if needed).
        figsize_per_cell: Unused; kept for interface compatibility.
        device: Device forwarded to ``load_agent``.
    """
    rc_ = {'figure.figsize':(10,8),'axes.labelsize':60,'axes.titlesize':50,
           'xtick.labelsize':50,'ytick.labelsize':50,'legend.fontsize':30}
    sns.set(rc=rc_, style="darkgrid")
    rc('text', usetex=False)

    # Training configurations as [visible goal steps at train time, k, kappa].
    visible = [[2, 50, 4]]
    # Must be range(0, T): the step count (1..T-1) is used directly as the index.
    visible_test = list(range(0,100))
    render_mode = "rgb_array"
    runs = list(runs)  # accept ranges/tuples as well as lists

    algo_templates = [
        "{algo}-arch_{arch}-env_{env}-num_stack_{k}-mask_type_framestack",
        "{algo}-arch_{arch}-env_{env}-num_stack_{kappa}-mask_type_framestack",
        # "{algo}-arch_{arch}-env_{env}-num_stack_{kappa}-mask_type_demir",
        "{algo}-arch_{arch}-env_{env}-num_stack_{kappa}-mask_type_masked"
    ]
    algo_labels = ["FrameStack k*", "FrameStack", "AdaptiveStack (Ours)"]

    M = len(visible_test)
    N = len(visible)
    # One slot per (run, episode): run i owns slots [i*max_episodes, (i+1)*max_episodes).
    n_slots = len(runs) * max_episodes
    metrics = {
        name: np.zeros((len(algo_templates), M, N, n_slots))  # algo × step × train × (run, episode)
        for name in ("success_rate", "steps", "returns", "optimality")
    }

    for train_idx, (L_train, k, kappa) in enumerate(visible):
        for algo_id, (tmpl, algo_label) in enumerate(zip(algo_templates, algo_labels)):
            algo_name = tmpl.format(algo=algorithm, arch=arch, env=env_name, k=k, kappa=kappa)
            print(algo_name)

            for run_idx, run in enumerate(runs):
                first_slot = run_idx * max_episodes
                model_path = (f"{data_root}/{algo_name}-run_{run}"
                              + ("-values.npy" if algorithm == "QL" else "_values"))
                print(model_path)

                if not os.path.exists(model_path) and not os.path.exists(model_path + ".zip"):
                    print("FAILED ", model_path)
                    # Mark the run as missing so it is dropped instead of plotted as zeros.
                    for values in metrics.values():
                        values[algo_id, :, train_idx, first_slot:first_slot + max_episodes] = np.nan
                    continue

                test_env = gym.make(env_name, render_mode=render_mode, max_episode_steps=visible_test[-1])
                test_env = PartialObsGoal(test_env, visible_goal_steps=2)
                test_env = FlattenObservation(test_env)
                if algo_label in ("FrameStack k*", "Transformer FrameStack"):
                    test_env = FrameStack(test_env, k)
                elif algo_label == "FrameStack":
                    test_env = FrameStack(test_env, kappa)
                elif algo_label == "DemirFrameStack":
                    test_env = DemirFrameStack(test_env, kappa)
                elif algo_label.startswith("AdaptiveStack"):
                    test_env = MaskedFrameStack(test_env, kappa, use_multidiscrete=True)
                if algorithm == "QL":
                    test_env = TupleObs(test_env)

                agent = load_agent(algo_label, algorithm, model_path, test_env, device, deterministic=False)

                try:
                    for episode in range(max_episodes):
                        slot = first_slot + episode
                        state, _ = test_env.reset(seed=episode)
                        for t in range(max_steps):
                            action, _ = agent(state)
                            state, r, done, trunc, info = test_env.step(action)
                            steps = t + 1  # number of steps taken so far
                            metrics["success_rate"][algo_id, steps, train_idx, slot] = info["is_success"]
                            metrics["returns"][algo_id, steps, train_idx, slot] = (gamma ** t) * r
                            if done or trunc:
                                break
                finally:
                    test_env.close()

    # ---------------- Plot & save  ---------------------------------------
    os.makedirs(out_dir, exist_ok=True)
    xlabels = list(visible_test)
    var_name = "test episode steps"

    for key, title in [
                        ("success_rate", "success rate"),
                        ("returns" , "returns"),
                        ]:
        for train_idx, (L_train, k, kappa) in enumerate(visible):
            # Wide table: one row per (algo, slot), one column per step index.
            # The last algorithm comes first so hue colours match the original figures.
            rows = [
                list(metrics[key][algo_id, :, train_idx, i]) + [algo_labels[algo_id]]
                for algo_id in reversed(range(len(algo_labels)))
                for i in range(n_slots)
            ]
            data = pd.DataFrame(rows, columns=xlabels + [""])
            data = pd.melt(data, id_vars="", var_name=var_name, value_name=title)
            data = data.dropna(subset=[title])  # drop slots of missing runs
            if data.empty:
                print("FAILED no data for", key, k)
                continue

            fig, ax = plt.subplots()
            ax = sns.lineplot(
                x=var_name,
                y=title,
                data=data,
                hue="",
                errorbar=("ci", 95),
                linewidth=10,
            )
            ax.get_legend().set_visible(False)
            if key == "success_rate":
                ax.vlines(x=50, ymin=0,ymax=1, color="black", lw=10, ls="--")
            if key == "returns":
                ax.vlines(x=50, ymin=-0.5,ymax=0, color="black", lw=10, ls="--")

            # NOTE(review): same filename pattern as plots_fetch_generalisation, so the
            # success_rate PDF of one function overwrites the other's when both run.
            fname = f"{algorithm}-{arch}-{env_name}-num_stack_{k}-{key}.pdf".replace(" ", "-").replace("$","").replace("\\","").replace("*","").replace("^","")
            plt.tight_layout()
            fig.savefig(os.path.join(out_dir, fname), bbox_inches="tight")
            plt.close(fig)

    print(f"Saved heat-maps to {out_dir}")
    
# ------------------------------------------------------------
# POPGym: generalisation to longer test episodes (box plot)
# ------------------------------------------------------------
def plots_popgym_generalisation_boxplot(
        env_name           = "popgym-PositionOnlyCartPoleHard",
        data_root          = "data_rebuttal_1",
        algorithm          = "PPO",          # or "PPO"
        arch               = "mlp",
        gamma              = 0.99,
        runs               = 100,
        out_dir            = "images",
        figsize_per_cell   = 1.8,           # heat-map scaling
        device             = "cpu"
):
    """Evaluate POPGym CartPole policies on longer episodes and draw box plots of steps survived.

    For each training configuration in ``config`` and each algorithm variant,
    the policy is rolled out ``runs`` times for every episode length in
    ``config_test``. Rollout ``r`` loads the model saved as ``run_{r % 10}``.
    The number of steps and the discounted return of each rollout are stored;
    one box plot of steps vs. test episode length is saved per training
    configuration.

    Args:
        env_name: Must contain ``NoisyPositionOnlyCartPole``,
            ``PositionOnlyCartPole`` or ``VelocityOnlyCartPole``.
        data_root: Directory containing the saved models.
        algorithm: Algorithm prefix used in the model file names.
        arch: Architecture tag used in the model file names.
        gamma: Discount factor of the recorded return.
        runs: Number of rollouts per (algorithm, test length); rollouts whose
            model file is missing are excluded from the plot.
        out_dir: Directory the PNGs are written to (created if needed).
        figsize_per_cell: Unused; kept for interface compatibility.
        device: Device forwarded to ``load_agent``.

    Raises:
        ValueError: if ``env_name`` names no supported CartPole variant.
    """
    # Fail fast on an unsupported env instead of hitting an undefined env later.
    _popgym_cartpole_variant(env_name)

    rc_ = {'figure.figsize':(10,8),'axes.labelsize':40,'axes.titlesize':30,
           'xtick.labelsize':30,'ytick.labelsize':30,'legend.fontsize':30}
    sns.set(rc=rc_, style="darkgrid")
    rc('text', usetex=False)

    config = [[2, 2]]  # training configurations as [k, kappa]
    config_test = [600,800,1000]  # test episode lengths

    algo_templates = [
        "{algo}-arch_{arch}-env_{env}-num_stack_{k}-mask_type_framestack",
        "{algo}-arch_{arch}-env_{env}-num_stack_{kappa}-mask_type_masked"
    ]
    algo_labels = ["FrameStack", "AdaptiveStack (Ours)"]

    M = len(config_test)
    N = len(config)
    # NaN marks rollouts whose model was missing; they are dropped before plotting.
    metrics = {
        name: np.full((len(algo_templates), M, N, runs), np.nan)  # algo × test × train × run
        for name in ("success_rate", "steps", "returns", "optimality")
    }

    for train_idx, (k, kappa) in enumerate(config):
        for algo_id, (tmpl, algo_label) in enumerate(zip(algo_templates, algo_labels)):
            algo_name = tmpl.format(algo=algorithm, arch=arch, env=env_name, k=k, kappa=kappa)

            for test_idx, L_test in enumerate(config_test):
                for run in range(runs):
                    model_path = (f"{data_root}/{algo_name}-run_{run%10}"
                                  + ("-values.npy" if algorithm=="QL" else "_values"))
                    if not os.path.exists(model_path) and not os.path.exists(model_path+".zip"):
                        print("FAILED ", model_path)
                        continue  # missing run – left as NaN

                    test_env = _make_popgym_test_env(env_name, L_test)
                    if algo_label == "FrameStack":
                        test_env = FrameStack(test_env, k)
                    elif algo_label.startswith("AdaptiveStack"):
                        test_env = MaskedFrameStack(test_env, kappa, use_multidiscrete=True)
                    if algorithm == "QL":
                        test_env = TupleObs(test_env)

                    agent = load_agent(algo_label, algorithm, model_path, test_env, device, deterministic=False)

                    G = 0.0
                    steps = 0
                    try:
                        state, _ = test_env.reset()
                        for t in range(L_test):
                            action, _ = agent(state)
                            state, r, done, trunc, _ = test_env.step(action)
                            G += (gamma ** t) * r
                            steps += 1
                            if done or trunc:
                                break
                    finally:
                        test_env.close()

                    metrics["returns"][algo_id, test_idx, train_idx, run] = G
                    metrics["steps"][algo_id, test_idx, train_idx, run] = steps

    # ---------------- Plot & save  ---------------------------------------
    os.makedirs(out_dir, exist_ok=True)
    xlabels = list(config_test)
    var_name = "test episode length"

    for key, title in [
                        # ("success_rate", "success rate"),
                        # ("returns" , "returns"),
                        ("steps" , "steps")]:
        for train_idx, (k, kappa) in enumerate(config):
            # Wide table: one row per (algo, run), one column per test length.
            # The last algorithm comes first so hue colours match the original figures.
            rows = [
                list(metrics[key][algo_id, :, train_idx, i]) + [algo_labels[algo_id]]
                for algo_id in reversed(range(len(algo_labels)))
                for i in range(runs)
            ]
            data = pd.DataFrame(rows, columns=xlabels + [""])
            data = pd.melt(data, id_vars="", var_name=var_name, value_name=title)
            data = data.dropna(subset=[title])  # drop missing rollouts
            if data.empty:
                print("FAILED no data for", key, k)
                continue

            fig, ax = plt.subplots()
            ax = sns.boxplot(x=var_name, y=title, hue="", data=data, linewidth=3, showfliers = False)

            fname = f"{algorithm}-{arch}-{env_name}-num_stack_{k}-{key}.png".replace(" ", "-").replace("$","").replace("\\","").replace("*","").replace("^","")
            plt.tight_layout()
            fig.savefig(os.path.join(out_dir, fname), bbox_inches="tight")
            plt.close(fig)

    print(f"Saved heat-maps to {out_dir}")


def _popgym_cartpole_variant(env_name):
    """Return the POPGym CartPole variant named in *env_name*; the most specific name wins."""
    # "PositionOnlyCartPole" is a substring of "NoisyPositionOnlyCartPole",
    # so the noisy variant must be tested first.
    for variant in ("NoisyPositionOnlyCartPole", "PositionOnlyCartPole", "VelocityOnlyCartPole"):
        if variant in env_name:
            return variant
    raise ValueError(f"Unsupported POPGym environment: {env_name!r}")


def _make_popgym_test_env(env_name, episode_length):
    """Build the POPGym CartPole env for *env_name* with the given episode length, flattened and discretised."""
    variant = _popgym_cartpole_variant(env_name)
    if variant == "NoisyPositionOnlyCartPole":
        env = popgym.envs.noisy_position_only_cartpole.NoisyPositionOnlyCartPole(max_episode_length=episode_length)
    elif variant == "PositionOnlyCartPole":
        env = popgym.envs.position_only_cartpole.PositionOnlyCartPoleHard(max_episode_length=episode_length)
    else:
        env = popgym.envs.velocity_only_cartpole.VelocityOnlyCartPoleHard(max_episode_length=episode_length)
    env.max_episode_length = episode_length
    return DiscreteAction(Flatten(PreviousAction(env)))

def plots_ppo_neurips_agregate_std():
    """Draw per-architecture aggregate bar charts for the passive T-maze PPO runs.

    For every architecture in ``arches`` and every variant in ``variants``,
    the logs ``data/{algo}-run_{run}.npy`` are loaded. Each metric in
    ``names`` is summed over the whole log, and the mean ± std over runs (and
    mazes) is plotted as grouped bars. One PDF per metric is written to
    ``{args.path}/pdf/`` (created if needed); ``args`` is the module-level
    argparse namespace.
    """
    # Styling
    s = 30
    rc_ = {'figure.figsize': (10, 8), 'axes.labelsize': 40,
           'xtick.labelsize': s, 'ytick.labelsize': s, 'legend.fontsize': 30}
    sns.set(rc=rc_, style="darkgrid")
    rc('text', usetex=False)

    # Experiment settings
    runs = 3
    maxiter = 1  # minimum log length for a run to count
    algorithm, env = "PPO", "passive_tmaze-v0"
    names = ["rewards"]#, "passive_count", "active_count", "mask_regret"]
    nice_names = ["rewards"]#, "passive regret", "active regret", "memory regret"]
    # [maze_length, k, kappa]; other settings used before:
    # [[0, 2, 2], [2, 4, 2], [6, 8, 2], [14, 16, 2], [62, 64, 2]]
    mazes = [[14, 16, 2]]
    arches = ["mlp", "lstm", "transformer"]
    variants = [r"AdaptiveStack $k=\kappa$ (Ours)", r"FrameStack $k=\kappa$", r"FrameStack $k=k^*$"]

    # Containers for aggregated results: one entry per architecture.
    agg_means = {name: {var: [] for var in variants} for name in names}
    agg_stds = {name: {var: [] for var in variants} for name in names}

    for arch in arches:
        # Per metric and variant: list (one per maze) of arrays of per-run totals.
        per_variant = {name: {var: [] for var in variants} for name in names}

        for maze_length, k, kappa in mazes:
            prefix = f"{algorithm}-arch_{arch}-env_{env}_maze_length_{maze_length}-random_length_True"
            # Log-file stems, in the same order as ``variants``.
            algos = [
                f"{prefix}-num_stack_{kappa}-mask_type_masked",
                f"{prefix}-num_stack_{kappa}-mask_type_framestack",
                f"{prefix}-num_stack_{k}-mask_type_framestack",
            ]

            for algo_key, label in zip(algos, variants):
                runs_data = {name: [] for name in names}
                for run in range(runs):
                    data_path = f"data/{algo_key}-run_{run}.npy"
                    if not os.path.exists(data_path):
                        print("FAILED", data_path)
                        continue
                    print(data_path)
                    log = np.load(data_path, allow_pickle=True).tolist()
                    for name in names:
                        arr = np.array(log[name])
                        print(log["T"], len(log["rewards"]), arr.shape[0])
                        if arr.shape[0] >= maxiter:
                            runs_data[name].append([arr.sum()])
                        else:
                            print("FAILED", run, label,name, arr.shape)

                for name in names:
                    if not runs_data[name]:
                        continue
                    print(np.array(runs_data[name]).shape)
                    per_variant[name][label].append(np.array(runs_data[name]).sum(axis=1))

        # Aggregate across mazes and runs for this architecture.
        for name in names:
            for variant in variants:
                arr_list = per_variant[name][variant]
                if not arr_list:
                    agg_means[name][variant].append(np.nan)
                    agg_stds[name][variant].append(np.nan)
                    continue
                flat = np.concatenate(arr_list)
                agg_means[name][variant].append(flat.mean())
                agg_stds[name][variant].append(flat.std())

    # One grouped bar chart per metric.
    os.makedirs(f"{args.path}/pdf", exist_ok=True)
    x = np.arange(len(arches))
    width = 0.25
    centre = (len(variants) - 1) / 2  # centres the bar group on each tick
    for nice_name, name in zip(nice_names, names):
        fig, ax = plt.subplots()
        for j, variant in enumerate(variants):
            ax.bar(x + (j - centre)*width, agg_means[name][variant], width,
                   yerr=agg_stds[name][variant], align='center', ecolor='black', capsize=5,
                   label=variant)
        ax.set_xticks(x)
        ax.set_xticklabels(arches)
        ax.set_xlabel("Architecture")
        ax.set_ylabel(nice_name)
        ax.yaxis.get_major_formatter().set_powerlimits((0, 1))
        ax.legend(loc='upper right')
        fig.tight_layout()

        # Save before showing: with interactive backends the figure can be
        # destroyed once its window is closed, which would save a blank file.
        save_path = f"{args.path}/pdf/{algorithm}-{env}-{name}-aggregate.pdf"
        fig.savefig(save_path, bbox_inches='tight')
        plt.show()
        plt.close(fig)

def plots_fetch_agregate():
    """Plot aggregate FetchReachDense results as grouped bars per architecture.

    For every architecture in ``arches`` and each of the three memory variants
    (AdaptiveStack with ``kappa`` frames, FrameStack with ``kappa`` frames,
    FrameStack with ``k`` frames), this:

    * loads each run's log from ``data/{algo_key}-run_{run}.npy``;
    * sums each metric in ``names`` over the whole log, giving one total per run;
    * reports the mean over runs as the bar height and the population standard
      deviation (``ddof=0``) as the error bar.

    One figure is produced per metric. Each figure is saved to
    ``{args.path}/pdf/{algorithm}-{env}-{name}-aggregate.pdf`` (the ``pdf``
    directory is created if needed), then shown and closed.

    Side effects: updates global seaborn/matplotlib rc settings, including
    ``text.usetex=True``, which requires a working LaTeX installation. It also
    prints diagnostics to stdout and reads the module-level ``args``.
    """
    # Styling
    s = 45
    rc_ = {'figure.figsize': (10, 8), 'axes.labelsize': 80,
           'xtick.labelsize': s, 'ytick.labelsize': s, 'legend.fontsize': 40}
    sns.set(rc=rc_, style="darkgrid")
    rc('text', usetex=True)

    # Experiment settings
    runs = 3
    skipped_runs = {6}  # runs deliberately excluded (no effect while runs <= 6)
    maxiter = 1  # minimum number of logged entries for a run to be kept
    algorithm, env = "PPO", "FetchReachDense-v4"
    names = ["success"]  # , "passive_count", "active_count", "mask_regret"]
    nice_names = ["successes"]  # , "passive regret", "active regret", "memory regret"]

    # Each entry is [episode_length, k, kappa]. The third variant below stacks
    # k frames ("FrameStack k=k^*"); the first two stack kappa frames.
    episode_lengths = [[50, 50, 4]]
    arches = ["MLP", "LSTM", "Transformer"]
    variants = [r"AdaptiveStack $k=\kappa$ (Ours)", r"FrameStack $k=\kappa$", r"FrameStack $k=k^*$"]
    # Architecture -> token used in the saved data file names (anything else -> "mlp").
    arch_keys = {"LSTM": "lstm", "Transformer": "transformer"}
    template = "{}-arch_{}-env_{}-num_stack_{}-mask_type_{}"

    # Aggregated results: metric -> variant -> one value per architecture.
    agg_means = {name: {var: [] for var in variants} for name in names}
    agg_stds = {name: {var: [] for var in variants} for name in names}

    for arch in arches:
        arch_key = arch_keys.get(arch, "mlp")
        # variant -> metric -> list of per-run-total arrays for this architecture.
        collected = {var: {name: [] for name in names} for var in variants}

        for episode_length, k, kappa in episode_lengths:
            algos = [
                template.format(algorithm, arch_key, env, kappa, "masked"),
                template.format(algorithm, arch_key, env, kappa, "framestack"),
                template.format(algorithm, arch_key, env, k, "framestack"),
            ]

            for algo_key, label in zip(algos, variants):
                runs_data = {name: [] for name in names}
                for run in range(runs):
                    if run in skipped_runs:
                        continue
                    data_path = f"data/{algo_key}-run_{run}.npy"
                    if not os.path.exists(data_path):
                        print("FAILED", data_path)
                        continue
                    print(data_path)
                    # NOTE: allow_pickle=True unpickles the file; only load trusted logs.
                    log = np.load(data_path, allow_pickle=True).tolist()
                    for name in names:
                        arr = np.array(log[name])
                        # Diagnostic only: .get keeps a log missing these fields from crashing.
                        print(log.get("T"), len(log.get("rewards", [])), arr.shape[0])
                        if arr.shape[0] >= maxiter:
                            # One total per run: the metric summed over the whole log.
                            runs_data[name].append([arr.sum()])
                        else:
                            print("FAILED", run, label, name, arr.shape)

                for name in names:
                    if not runs_data[name]:
                        continue
                    print(np.array(runs_data[name]).shape)
                    collected[label][name].append(np.array(runs_data[name]).sum(axis=1))

        # Aggregate across episode_lengths and runs for this arch; NaN marks
        # a variant with no usable data (its bar is simply not drawn).
        for name in names:
            for variant in variants:
                arr_list = collected[variant][name]
                if not arr_list:
                    agg_means[name][variant].append(np.nan)
                    agg_stds[name][variant].append(np.nan)
                    continue
                flat = np.concatenate(arr_list)
                agg_means[name][variant].append(flat.mean())
                agg_stds[name][variant].append(flat.std())

    # Plot one aggregate bar chart per metric.
    out_dir = f"{args.path}/pdf"
    os.makedirs(out_dir, exist_ok=True)
    x = np.arange(len(arches))
    width = 0.25
    for nice_name, name in zip(nice_names, names):
        fig, ax = plt.subplots()
        for j, variant in enumerate(variants):
            # Offsets -width, 0, +width centre the three bars on each tick.
            ax.bar(x + (j - 1) * width, agg_means[name][variant], width,
                   yerr=agg_stds[name][variant], align='center', ecolor='black',
                   capsize=5, label=variant)
        ax.set_xticks(x)
        ax.set_xticklabels(arches)
        ax.set_xlabel("architecture")
        ax.set_ylabel(nice_name)
        ax.yaxis.get_major_formatter().set_powerlimits((0, 1))
        # ax.legend(loc='upper right')
        fig.tight_layout()
        # Save before showing: with interactive backends the figure may be
        # torn down once show() returns.
        fig.savefig(f"{out_dir}/{algorithm}-{env}-{name}-aggregate.pdf", bbox_inches='tight')
        plt.show()
        plt.close(fig)


def plots_k_kappa():
    """Plot the memory needed by FrameStack and AdaptiveStack against maze length.

    For maze lengths ``l = 2, ..., 12`` this draws three curves:

    * the FrameStack depth ``k^* = l``;
    * the constant AdaptiveStack depth ``kappa = 2``;
    * the ratio ``k^* / kappa``, which is also printed.

    The figure is saved to ``{args.path}/k_kappa.pdf``. Global
    seaborn/matplotlib rc settings are updated as a side effect.
    """
    s = 60
    rc_ = {'figure.figsize': (10, 8), 'axes.labelsize': 65, 'xtick.labelsize': s,
           'ytick.labelsize': s, 'legend.fontsize': 45}
    sns.set(rc=rc_, style="darkgrid")
    # rc('text', usetex=True)

    fig, ax = plt.subplots()

    lw = 10.0
    l = np.arange(2, 13)
    k = l
    kappa = l * 0 + 2
    print(k / kappa)
    ax.plot(l, k, label=r"FS ($k=k^*$)", lw=lw)
    ax.plot(l, kappa, label=r"AS ($k=\kappa$)", lw=lw)
    ax.plot(l, k / kappa, label=r"$k^*$/$\kappa$", lw=lw)

    ax.legend(loc='upper left')
    ax.set_xlabel(r"maze length ($L+2$)")
    ax.set_ylabel("minimal memory")
    # Span the sampled lengths. The previous upper bound ``len(l) - 1`` (= 10)
    # is an element count, not an x coordinate, and hid the points at l = 11, 12.
    ax.set_xlim((l[0], l[-1]))
    ax.set_ylim(bottom=0)
    fig.tight_layout()
    fig.savefig("{}/{}.pdf".format(args.path, "k_kappa"), bbox_inches='tight')
    # The figure is never shown, so release it once written.
    plt.close(fig)
        
def plots_values():
    """Plot closed-form optimal values against maze length.

    Uses ``gamma = 0.99`` and maze lengths ``l = 3, ..., 999``. Three curves are
    drawn:

    * FrameStack with ``k = k^*``: ``gamma**(l - 2)``;
    * AdaptiveStack with ``k = kappa``: ``gamma * (1 - gamma**(l - 2)) / ((1 - gamma) * (l - 2))``,
      which is the mean of ``gamma**1, ..., gamma**(l - 2)``;
    * the absolute gap between the two.

    The figure is saved to ``{args.path}/optimal_values.pdf``. Global
    seaborn/matplotlib rc settings are updated as a side effect.
    """
    s = 60
    rc_ = {'figure.figsize': (10, 8), 'axes.labelsize': 65, 'xtick.labelsize': s,
           'ytick.labelsize': s, 'legend.fontsize': 45}
    sns.set(rc=rc_, style="darkgrid")
    # rc('text', usetex=True)

    fig, ax = plt.subplots()

    lw = 10.0
    l = np.arange(3, 1000)
    gamma = 0.99
    V_FS_k = gamma ** (l - 2)
    V_AS_kappa = (1 / (l - 2)) * (gamma * (1 - gamma ** (l - 2)) / (1 - gamma))
    ax.plot(l, V_FS_k, label=r"FS ($k=k^*$)", lw=lw)
    ax.plot(l, V_AS_kappa, label=r"AS ($k=\kappa$)", lw=lw)
    ax.plot(l, np.abs(V_AS_kappa - V_FS_k), label=r"Value gap", lw=lw)

    ax.legend()
    ax.set_xlabel(r"maze length ($L+2$)")
    ax.set_ylabel("optimal values")
    # Span the sampled lengths; ``len(l)`` (= 997) is an element count and
    # stopped short of the last lengths 998 and 999.
    ax.set_xlim((l[0], l[-1]))
    ax.set_ylim((0, 1))
    fig.tight_layout()
    fig.savefig("{}/{}.pdf".format(args.path, "optimal_values"), bbox_inches='tight')
    # The figure is never shown, so release it once written.
    plt.close(fig)
    
        
def plots_values2():
    """Plot closed-form optimal values against maze length, including AS with ``k = k^*``.

    Uses ``gamma = 0.99`` and maze lengths ``l = 3.0, ..., 999.0``. The curves are:

    * the absolute value gap (black);
    * FrameStack with ``k = k^*``: ``gamma**(l - 2)``;
    * AdaptiveStack with ``k = kappa``: ``gamma * (1 - gamma**(l - 2)) / ((1 - gamma) * (l - 2))``;
    * two dashed curves labelled AS (``k = k^*``), which reuse those same two value arrays.

    The figure is saved to ``{args.path}/optimal_values_2.pdf``. Global
    seaborn/matplotlib rc settings are updated as a side effect.
    """
    s = 50
    rc_ = {'figure.figsize': (13, 9), 'axes.labelsize': 60, 'xtick.labelsize': s,
           'ytick.labelsize': s, 'legend.fontsize': s}
    sns.set(rc=rc_, style="darkgrid")
    # rc('text', usetex=True)

    fig, ax = plt.subplots()

    lw = 6.0
    dashed = (0, (5, 5))
    l = np.arange(3, 1000) + 0.0
    gamma = 0.99
    V_FS_k = gamma ** (l - 2)
    V_AS_kappa = (1 / (l - 2)) * (gamma * (1 - gamma ** (l - 2)) / (1 - gamma))

    ax.plot(l, np.abs(V_AS_kappa - V_FS_k), label=r"Value gap", color="black", lw=lw)
    ax.plot(l, V_FS_k, label=r"FS ($k=k^*$)", lw=lw)
    ax.plot(l, V_AS_kappa, label=r"AS ($k=\kappa$) $\pi^*_k$", lw=lw)
    # The AS (k=k^*) entries reuse the two value arrays above; they are drawn
    # dashed on top of those curves so each gets its own legend entry.
    ax.plot(l, V_AS_kappa, label=r"AS ($k=k^*$) $\pi^*_k$", linestyle=dashed, lw=lw)
    ax.plot(l, V_FS_k, label=r"AS ($k=k^*$) $\pi^{FS}_k$", linestyle=dashed, lw=lw)

    ax.legend()
    # ax.legend(loc=(1,0.04))
    ax.set_xlabel(r"maze length ($L+2$)")
    ax.set_ylabel("optimal values")
    # Span the sampled lengths; ``len(l)`` (= 997) is an element count and
    # stopped short of the last lengths 998 and 999.
    ax.set_xlim((l[0], l[-1]))
    ax.set_ylim((0, 1))
    fig.tight_layout()
    fig.savefig("{}/{}.pdf".format(args.path, "optimal_values_2"), bbox_inches='tight')
    # The figure is never shown, so release it once written.
    plt.close(fig)
    
#####################################################################################

if __name__ == '__main__':
    # Figure generators to run, called in order with no arguments.
    # Uncomment an entry to include it in this run.
    figures_to_render = (
        # Learning curves and bar charts
        # plots_curves,
        # plots_barchart,
        # plots_ppo_neurips_aggregate_CI,
        # plots_ppo_neurips_agregate_std,

        # Generalisation and memory studies
        # plots_generalisation,
        # plots_memory_generalisation,
        # plots_memory_barchart,
        # plots_curves_otherenvs,
        # plots_scrambles_generalisation,
        # plots_scrambles_generalisation_boxplot,
        # plots_popgym_generalisation_boxplot,
        # plots_memory_usage,

        # Fetch experiments
        plots_fetch_generalisation,
        # plots_fetch_generalisation_episode,
        # plots_fetch_agregate,

        # Analytical figures
        # plots_values,
        # plots_k_kappa,
        # plots_values2,
    )
    for render_figure in figures_to_render:
        render_figure()
