import torch
import torch.nn
import torch.nn.functional as F
from IPython import embed
import numpy as np
import torch.nn as nn
import matplotlib.pyplot as plt
import time
import argparse
import torch.nn.functional as F
import time
import os
import pickle
from dataset import TrajDataset
from net import Net, Transformer, TransformerTall, TransformerBERT, TransformerTall2
from lqr_env import LQREnv, LQRController, TransformerController
from darkroom_env import DarkroomEnv, DarkroomEnvStitch, DarkroomOptPolicy,  DarkroomTransformerController, RandCommit
import bandit_env
from bandit_env import LinearBanditEnv, BanditEnvVec
from bandit_env import BanditTransformerController, BanditEnv, OptPolicy, GreedyOptPolicy, PessMeanPolicy, EmpMeanPolicy, UCBPolicy
from bandit_env import TopKBanditTransformerController, TopKBanditEnv, TopKRandCommitPolicy, LinUCB, ThompsonSamplingPolicy2
import pandas as pd
import scipy.stats
from evals import eval_bandit
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')




def analyze_pess(consts, eval_trajs, env_horizon=None, env_var=None):
    """Deploy a ``PessMeanPolicy`` for every constant on every eval trajectory.

    For each constant in ``consts`` and each trajectory in ``eval_trajs``, a new
    ``BanditEnv`` is built from the trajectory's ``'means'``, a pessimistic
    mean policy with that constant is conditioned on the trajectory's roll-in
    data, and the policy is deployed once with ``env.deploy_eval``.

    Args:
        consts: iterable of values passed as ``const=`` to ``PessMeanPolicy``.
        eval_trajs: sequence of trajectory dicts providing ``'means'``,
            ``'rollin_xs'``, ``'rollin_us'``, ``'rollin_xps'`` and
            ``'rollin_rs'``.
        env_horizon: horizon passed to ``BanditEnv``. Defaults to the
            script-level ``H`` (defined in the ``__main__`` block), which is
            what this function always used before.
        env_var: variance passed to ``BanditEnv``. Defaults to the
            script-level ``var`` (defined in the ``__main__`` block).

    Returns:
        ``np.ndarray`` built from a nested list indexed
        ``[const_index][traj_index]`` of the reward arrays returned by
        ``deploy_eval``.
    """
    # Fall back to the script globals lazily so existing calls keep working,
    # while explicit arguments make the function usable outside ``__main__``.
    env_horizon = H if env_horizon is None else env_horizon
    env_var = var if env_var is None else env_var

    # The roll-in tensors depend only on the trajectory, not on the constant,
    # so build them once per trajectory instead of once per (const, traj).
    prepared = []
    for traj in eval_trajs:
        batch = {
            'rollin_xs': torch.tensor(traj['rollin_xs'][None, :, :]).float().to(device),
            'rollin_us': torch.tensor(traj['rollin_us'][None, :, :]).float().to(device),
            'rollin_xps': torch.tensor(traj['rollin_xps'][None, :, :]).float().to(device),
            'rollin_rs': torch.tensor(traj['rollin_rs'][None, :, None]).float().to(device),
        }
        prepared.append((traj['means'], batch))

    const_values = []
    for const in consts:
        all_rs = []
        for means, batch in prepared:
            # A new env is created for every deployment, as before.
            env = BanditEnv(means, env_horizon, var=env_var)
            pess = PessMeanPolicy(env, const=const)
            pess.set_batch(batch)

            _, _, _, rs_pess = env.deploy_eval(pess)
            all_rs.append(rs_pess)

        const_values.append(all_rs)

    return np.array(const_values)
            


if __name__ == '__main__':
    # Output directories; exist_ok makes creation idempotent, so no prior
    # existence check is needed.
    os.makedirs('figs/loss', exist_ok=True)
    os.makedirs('models', exist_ok=True)

    # argparse is already imported at module level.
    parser = argparse.ArgumentParser()
    parser.add_argument("--envs", type=int, required=False, default=1000, help="Envs")
    parser.add_argument("--hists", type=int, required=False, default=1, help="Histories")
    parser.add_argument("--samples", type=int, required=False, default=1, help="Samples")
    parser.add_argument("--H", type=int, required=False, default=10, help="Context horizon")
    parser.add_argument("--embd", type=int, required=False, default=32, help="Embedding dimension")
    parser.add_argument("--head", type=int, required=False, default=1, help="Number of attention heads")
    parser.add_argument("--layer", type=int, required=False, default=3, help="Number of layers")
    parser.add_argument("--lr", type=float, required=False, default=1e-3, help="Learning rate")
    parser.add_argument("--dim", type=int, required=False, default=1, help="Dimension")
    parser.add_argument("--lin_d", type=int, required=False, default=1, help="Linear dimension")
    parser.add_argument("--epoch", type=int, required=False, default=-1, help="Epoch to evaluate")
    parser.add_argument("--opt", type=int, required=False, default=0, help="Optimizer type")
    parser.add_argument("--dropout", type=float, required=False, default=0, help="Dropout")
    parser.add_argument("--var", type=float, required=False, default=0.0, help="Variance")
    parser.add_argument("--test_var", type=float, required=False, default=-1.0, help="Test Variance")
    parser.add_argument("--cov", type=float, required=False, default=0.0, help="Coverage")
    parser.add_argument("--test_cov", type=float, required=False, default=-1.0, help="Test coverage")
    parser.add_argument("--trans", type=int, required=False, default=0, help="Transformer type")
    parser.add_argument("--hor", type=int, required=False, default=-1, help="Episode horizon (for mdp)")
    parser.add_argument("--env", type=str, required=True, help="Environment")
    parser.add_argument("--k", type=int, required=False, default=1, help="Top K value")
    parser.add_argument("--orig", type=int, required=False, default=2, help="Orig")
    parser.add_argument("--test_orig", type=int, required=False, default=2, help="Test orig")

    parser.add_argument('--full', default=False, action='store_true')
    parser.add_argument('--shuffle', default=False, action='store_true')
    parser.add_argument('--test', default=False, action='store_true')

    args = vars(parser.parse_args())
    print("Args:")
    print(args)

    # Dataset and model hyperparameters (these also select the checkpoint name).
    n_envs, n_hists, n_samples = args['envs'], args['hists'], args['samples']
    H = args['H']
    dim = args['dim']
    dx = du = dim
    n_embd, n_head, n_layer = args['embd'], args['head'], args['layer']
    lr, epoch, opt, dropout = args['lr'], args['epoch'], args['opt'], args['dropout']
    shuffle, full = args['shuffle'], args['full']
    trans = args['trans']

    # Environment settings.
    envname = args['env']
    var, cov = args['var'], args['cov']
    topk = args['k']
    orig, test_orig = args['orig'], args['test_orig']
    lin_d = args['lin_d']
    use_test = args['test']

    warm_start = False
    use_net = False

    # Negative sentinel values mean "reuse the corresponding training value".
    test_cov = cov if args['test_cov'] < 0 else args['test_cov']
    test_var = var if args['test_var'] < 0 else args['test_var']
    horizon = H if args['hor'] < 0 else args['hor']

    # Per-environment settings: checkpoint filename, and for bandits the reward
    # type plus prior parameters (0.5 and 1/12 are the mean and variance of
    # Uniform(0, 1)).
    if envname in ['bandit', 'bandit_ood']:
        bandit = True
        dx = 1
        prefix = envname
        filename = f'{prefix}_trans{trans}_full{full}_shuf{shuffle}_opt{opt}_lr{lr}_do{dropout}_embd{n_embd}_layer{n_layer}_head{n_head}_envs{n_envs}_hists{n_hists}_samples{n_samples}_var{var}_cov{cov}_orig{orig}_H{H}_d{dim}'
        bandit_type = 'uniform'
        prior_mean = 0.5
        prior_var = 1.0/12.0

    elif envname == 'bandit_thompson':
        bandit = True
        dx = 1
        prefix = envname
        filename = f'{prefix}_trans{trans}_full{full}_shuf{shuffle}_opt{opt}_lr{lr}_do{dropout}_embd{n_embd}_layer{n_layer}_head{n_head}_envs{n_envs}_hists{n_hists}_samples{n_samples}_var{var}_cov{cov}_H{H}_d{dim}'
        bandit_type = 'bernoulli'
        prior_mean = 0.5
        prior_var = 1.0/12.0

    elif envname == 'bandit_topk':
        bandit = True
        dx = 1
        prefix = envname
        filename = f'{prefix}_trans{trans}_full{full}_shuf{shuffle}_opt{opt}_lr{lr}_do{dropout}_embd{n_embd}_layer{n_layer}_head{n_head}_envs{n_envs}_hists{n_hists}_samples{n_samples}_var{var}_k{topk}_H{H}_d{dim}'
        # bandit_type was previously unset here, so building the online-eval
        # config (which reads it) raised NameError. 'uniform' matches the
        # Uniform(0, 1) prior parameters below.
        bandit_type = 'uniform'
        prior_mean = 0.5
        prior_var = 1.0/12.0

    elif envname == 'linear_bandit':
        bandit = True
        dx = 1
        prefix = envname
        filename = f'{prefix}_trans{trans}_full{full}_shuf{shuffle}_opt{opt}_lr{lr}_do{dropout}_embd{n_embd}_layer{n_layer}_head{n_head}_envs{n_envs}_hists{n_hists}_samples{n_samples}_var{var}_H{H}_d{dim}_dlin{lin_d}_ws{warm_start}'
        bandit_type = 'uniform'
        prior_mean = 0.0
        prior_var = 1.0

    elif envname in ['darkroom', 'darkroom_heldout', 'darkroom_stitch']:
        bandit = False
        dx = 2
        du = 5
        prefix = envname
        filename = f'{prefix}_trans{trans}_full{full}_shuf{shuffle}_opt{opt}_lr{lr}_do{dropout}_embd{n_embd}_layer{n_layer}_head{n_head}_envs{n_envs}_hists{n_hists}_samples{n_samples}_H{H}_d{dim}'

    else:
        raise NotImplementedError

    config = {
        'H': H,
        'dx': dx,
        'du': du,
        'n_layer': n_layer,
        'n_embd': n_embd,
        'n_head': n_head,
        'Q': False,
        'full': full,
        'dropout': dropout,
    }

    # Select the architecture matching the checkpoint.
    if use_net:
        model = Net(config).to(device)
    elif trans == 0:
        model = Transformer(config).to(device)
    elif trans == 1:
        model = TransformerTall(config).to(device)
    elif trans == 2:
        model = TransformerTall2(config).to(device)
    else:
        model = TransformerBERT(config).to(device)

    # Load the final checkpoint by default, or a specific epoch's when --epoch >= 0.
    if epoch < 0:
        model_path = f'models/{filename}.pt'
    else:
        model_path = f'models/{filename}_epoch{epoch}.pt'
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint)
    model.eval()
    # NOTE(review): presumably disables the 'full' mode for evaluation
    # regardless of the training setting — confirm against the model classes.
    model.config['full'] = False
    
    # Number of evaluation trajectories to load and evaluate.
    n_eval = 200

    # Pick the evaluation dataset and the base name used for saved results.
    if envname in ['bandit', 'bandit_ood', 'bandit_thompson']:
        eval_filepath = f'datasets/trajs_eval_{envname}_envs{n_eval}_H{horizon}_d{dim}_var{test_var}_cov{test_cov}_orig{test_orig}.pkl'
        save_filename = f'{filename}_testcov{test_cov}_testvar{test_var}_testorig{test_orig}_hor{horizon}.pkl'

    elif envname == 'bandit_topk':
        eval_filepath = f'datasets/trajs_eval_{envname}_envs{n_eval}_H{horizon}_d{dim}_var{var}_k{topk}.pkl'
        save_filename = f'{filename}_hor{horizon}.pkl'

    elif envname == 'linear_bandit':
        eval_filepath = f'datasets/trajs_eval_{envname}_envs{n_eval}_H{horizon}_d{dim}_dlin{lin_d}_var{var}_ws{warm_start}.pkl'
        save_filename = f'{filename}_hor{horizon}_ws{warm_start}.pkl'

    elif envname in ['darkroom', 'darkroom_heldout', 'darkroom_stitch']:
        if use_test:
            # Held-out evaluation uses the test split; the train split is
            # loaded later to exclude goals seen during training.
            eval_filepath = f'datasets/trajs_{envname}_envs{n_envs}_hists{n_hists}_samples{n_samples}_H{H}_d{dim}_test.pkl'
            eval_filepath_train = f'datasets/trajs_{envname}_envs{n_envs}_hists{n_hists}_samples{n_samples}_H{H}_d{dim}_train.pkl'
        else:
            eval_filepath = f'datasets/trajs_eval_{envname}_envs{n_eval}_H{H}_d{dim}.pkl'
        save_filename = f'{filename}_hor{horizon}_test{use_test}.pkl'
    else:
        raise ValueError(f'Environment {envname} not supported')

    with open(eval_filepath, 'rb') as file:
        eval_trajs = pickle.load(file)

    
    # Darkroom test-split evaluation: drop trajectories whose goal appears in
    # the training split (except for darkroom_stitch), keep one trajectory per
    # goal, then tile the result to exactly n_eval entries.
    # The env check was previously a bare (always truthy) list, so any env run
    # with --test reached this branch and failed on eval_filepath_train.
    if envname in ['darkroom', 'darkroom_heldout', 'darkroom_stitch'] and use_test:
        with open(eval_filepath_train, 'rb') as file:
            eval_trajs_train = pickle.load(file)

        train_goals = {tuple(traj['goal']) for traj in eval_trajs_train}

        if envname != 'darkroom_stitch':    # stitch keeps goals also seen in training
            eval_trajs = [traj for traj in eval_trajs if tuple(traj['goal']) not in train_goals]

        # Keep the first trajectory for each distinct goal, preserving order.
        eval_trajs2 = []
        eval_goals = set()
        for traj in eval_trajs:
            goal = tuple(traj['goal'])
            if goal not in eval_goals:
                eval_trajs2.append(traj)
                eval_goals.add(goal)
        eval_trajs = eval_trajs2

        if not eval_trajs:
            raise ValueError("No eval trajs found")
        # Repeat until at least n_eval entries, then truncate to n_eval.
        while len(eval_trajs) < n_eval:
            eval_trajs += eval_trajs
        eval_trajs = eval_trajs[:n_eval]

    n_eval = min(n_eval, len(eval_trajs))

    # Per-epoch output directories for figures and pickled results. makedirs
    # creates missing parents and exist_ok makes each call idempotent.
    evals_filename = f"evals_epoch{epoch}"
    for subdir in ('pess', 'bar', 'lines', 'online'):
        os.makedirs(f'figs/{evals_filename}/{subdir}', exist_ok=True)
    os.makedirs(f'data_results/{evals_filename}/online', exist_ok=True)
    


    # ONLINE EVALUATION
    # NOTE(review): 'bandit_ood' is not in this list, so it skips online
    # evaluation — confirm that is intended.
    if envname in ['bandit', 'bandit_topk', 'bandit_thompson', 'linear_bandit']:
        config = {
            'H': H,
            'horizon': horizon,
            'var': var,
            'n_eval': n_eval,
            'envname': envname,
            'k': topk,
            'type': bandit_type,
            'prior_mean': prior_mean,
            'prior_var': prior_var,
            'test_var': test_var,
        }
        online_results = eval_bandit.online_vec(eval_trajs, model, **config)
        plt.savefig(f'figs/{evals_filename}/online/{save_filename}.png')
        plt.clf()

        # Persist the raw online results next to the figure.
        with open(f'data_results/{evals_filename}/online/{save_filename}', 'wb') as file:
            pickle.dump(online_results, file)

    elif envname in ['darkroom', 'darkroom_heldout', 'darkroom_stitch']:
        # eval_darkroom was used here but never imported (NameError).
        # Assumes it lives in the same `evals` package as eval_bandit — confirm.
        from evals import eval_darkroom

        config = {
            'Heps': 40,
            'horizon': horizon,
            'H': H,
            'n_eval': min(20, n_eval),
            'dim': dim,
            'stitch': envname == 'darkroom_stitch',
        }
        eval_darkroom.online(eval_trajs, model, **config)
        plt.savefig(f'figs/{evals_filename}/online/{save_filename}.png')
        plt.clf()




    # Per-trajectory results collected by the offline evaluation below.
    all_xs = []
    all_rs_lnr = []
    all_rs_greedy = []
    all_rs_opt = []
    all_rs_emp = []
    all_rs_pess = []
    all_rs_rnd = []
    all_rs_lnr_greedy = []
    all_rs_lin = []
    all_rs_thmp = []
    all_rs_thmp2 = []

    # Bandit envs/trajectories, kept for the vectorized evaluation that follows.
    envs = []
    trajs = []

    # OFFLINE EVALUATION SINGLE: condition each policy on one trajectory's
    # roll-in data and deploy it in that trajectory's environment.
    for i_eval, traj in enumerate(eval_trajs[:n_eval]):
        print(f"Eval traj: {i_eval}")

        batch = {
            'rollin_xs': torch.tensor(traj['rollin_xs'][None, :, :]).float().to(device),
            'rollin_us': torch.tensor(traj['rollin_us'][None, :, :]).float().to(device),
            'rollin_xps': torch.tensor(traj['rollin_xps'][None, :, :]).float().to(device),
            'rollin_rs': torch.tensor(traj['rollin_rs'][None, :, None]).float().to(device),
        }

        if envname in ['bandit', 'bandit_ood', 'bandit_thompson', 'linear_bandit']:
            # Only the oracle and greedy baselines run per trajectory here.
            means = traj['means']
            # NOTE: the env horizon is `horizon` (original note: naming issue
            # here for length of contexts).
            env = BanditEnv(means, horizon, var=test_var, type=bandit_type)

            true_opt = OptPolicy(env)
            greedy = GreedyOptPolicy(env)
            true_opt.set_batch(batch)
            greedy.set_batch(batch)

            _, _, _, rs_greedy = env.deploy_eval(greedy)
            _, _, _, rs_opt = env.deploy_eval(true_opt)

            all_rs_opt.append(np.sum(rs_opt))
            all_rs_greedy.append(np.sum(rs_greedy))
            envs.append(env)
            trajs.append(traj)

        elif envname == 'bandit_topk':
            means = traj['means']
            env = TopKBanditEnv(means, horizon, var=test_var, k=topk)

            true_opt = OptPolicy(env)
            greedy = GreedyOptPolicy(env)
            lnr = TopKBanditTransformerController(model, k=topk, sample=False)
            rnd = TopKRandCommitPolicy(env, topk, horizon, immediate=True)
            lin = LinUCB(env, topk, const=0.0)

            for policy in (true_opt, greedy, lnr, rnd, lin):
                policy.set_batch(batch)

            _, _, _, rs_greedy = env.deploy_eval(greedy)
            xs_opt, _, _, rs_opt = env.deploy_eval(true_opt)
            xs_lnr, _, _, rs_lnr = env.deploy_eval(lnr)
            _, _, _, rs_rnd = env.deploy_eval(rnd)
            _, _, _, rs_lin = env.deploy_eval(lin)

            all_xs.append((xs_opt, xs_lnr))
            all_rs_opt.append(np.sum(rs_opt))
            all_rs_lnr.append(np.sum(rs_lnr))
            all_rs_greedy.append(np.sum(rs_greedy))
            all_rs_rnd.append(np.sum(rs_rnd))
            all_rs_lin.append(np.sum(rs_lin))

        elif envname in ['darkroom', 'darkroom_heldout', 'darkroom_stitch']:
            goal = traj['goal']
            if envname == 'darkroom_stitch':
                env = DarkroomEnvStitch(dim, goal, H, eval=False)
            else:
                env = DarkroomEnv(dim, goal, H)

            true_opt = DarkroomOptPolicy(env)
            lnr = DarkroomTransformerController(model, sample=True)
            lnr_greedy = DarkroomTransformerController(model, sample=False)
            rnd = RandCommit(env)

            for policy in (true_opt, lnr, lnr_greedy, rnd):
                policy.set_batch(batch)

            xs_opt, _, _, rs_opt = env.deploy_eval(true_opt)
            xs_lnr, _, _, rs_lnr = env.deploy_eval(lnr)
            _, _, _, rs_lnr_greedy = env.deploy_eval(lnr_greedy)
            _, _, _, rs_rnd = env.deploy_eval(rnd)

            all_xs.append((xs_opt, xs_lnr))
            all_rs_opt.append(np.sum(rs_opt))
            all_rs_lnr.append(np.sum(rs_lnr))
            all_rs_lnr_greedy.append(np.sum(rs_lnr_greedy))
            all_rs_rnd.append(np.sum(rs_rnd))

    # Vectorized offline evaluation of the learner and the empirical-mean,
    # pessimistic and Thompson-sampling baselines over all collected envs.
    # 'bandit_ood' and 'bandit_thompson' build the same BanditEnv list above;
    # previously they were skipped here, leaving these result lists empty and
    # crashing the suboptimality computation with a shape mismatch.
    if envname in ['bandit', 'bandit_ood', 'bandit_thompson', 'linear_bandit']:
        print("Running bandit offline evaluations in parallel")

        # Stack each roll-in field across trajectories once.
        stacked_xs = np.array([traj['rollin_xs'] for traj in trajs])
        stacked_us = np.array([traj['rollin_us'] for traj in trajs])
        stacked_xps = np.array([traj['rollin_xps'] for traj in trajs])
        stacked_rs = np.array([traj['rollin_rs'][:, None] for traj in trajs])

        # torch.tensor copies, so the numpy batch below stays independent.
        batch = {
            'rollin_xs': torch.tensor(stacked_xs).float().to(device),
            'rollin_us': torch.tensor(stacked_us).float().to(device),
            'rollin_xps': torch.tensor(stacked_xps).float().to(device),
            'rollin_rs': torch.tensor(stacked_rs).reshape(n_eval, -1, 1).float().to(device),
        }
        batch_numpy = {
            'rollin_xs': stacked_xs,
            'rollin_us': stacked_us,
            'rollin_xps': stacked_xps,
            'rollin_rs': stacked_rs,
        }

        vec_env = BanditEnvVec(envs)
        lnr = BanditTransformerController(model, sample=False, batch_size=n_eval)
        emp = EmpMeanPolicy(envs[0], batch_size=n_eval)
        pess = PessMeanPolicy(envs[0], 1.0, batch_size=n_eval)
        # envs[-1] is the env the loop above left in `env` (previously used here).
        thmp2 = ThompsonSamplingPolicy2(
            envs[-1],
            std=var,
            prior_mean=prior_mean,
            prior_var=prior_var,
            batch_size=n_eval)

        lnr.set_batch(batch)
        emp.set_batch_numpy_vec(batch_numpy)
        pess.set_batch_numpy_vec(batch_numpy)
        thmp2.set_batch_numpy_vec(batch_numpy)

        _, _, _, rs_lnr = vec_env.deploy_eval(lnr)
        _, _, _, rs_emp = vec_env.deploy_eval(emp)
        _, _, _, rs_pess = vec_env.deploy_eval(pess)
        _, _, _, rs_thmp2 = vec_env.deploy_eval(thmp2)

        all_rs_lnr = rs_lnr
        all_rs_emp = rs_emp
        all_rs_pess = rs_pess
        all_rs_thmp2 = rs_thmp2

    # Gather per-trajectory returns per policy and list which policies get a
    # suboptimality entry (optimal return minus the policy's return).
    if envname in ['bandit', 'bandit_ood', 'bandit_thompson', 'linear_bandit']:
        baselines = {
            'opt': np.array(all_rs_opt),
            'lnr': np.array(all_rs_lnr),
            'emp': np.array(all_rs_emp),
            'pess': np.array(all_rs_pess),
            'thmp2': np.array(all_rs_thmp2),
        }
        subopt_keys = ('lnr', 'emp', 'pess', 'thmp2')
    elif envname == 'bandit_topk':
        baselines = {
            'opt': np.array(all_rs_opt),
            'lnr': np.array(all_rs_lnr),
            'greedy': np.array(all_rs_greedy),
            'rnd': np.array(all_rs_rnd),
            'lin': np.array(all_rs_lin),
        }
        subopt_keys = ('lnr', 'greedy', 'rnd', 'lin')
    elif envname in ['darkroom', 'darkroom_heldout', 'darkroom_stitch']:
        baselines = {
            'opt': np.array(all_rs_opt),
            'lnr': np.array(all_rs_lnr),
            'rnd': np.array(all_rs_rnd),
            'lnr_greedy': np.array(all_rs_lnr_greedy),
        }
        subopt_keys = ('lnr', 'lnr_greedy', 'rnd')
    else:
        raise NotImplementedError

    subopt_baselines = {name: baselines['opt'] - baselines[name] for name in subopt_keys}

    baselines_means = {name: np.mean(vals) for name, vals in baselines.items()}
    subopt_baselines_means = {name: np.mean(vals) for name, vals in subopt_baselines.items()}

    # Bar chart of mean reward per policy, one viridis colour per bar.
    bar_colors = plt.cm.viridis(np.linspace(0, 1, len(baselines_means)))
    colors = bar_colors
    plt.bar(baselines_means.keys(), baselines_means.values(), color=colors)
    plt.title(f'Mean Reward on {n_eval} Trajectories')
    plt.savefig(f'figs/{evals_filename}/bar/{save_filename}_bar.png')
    plt.clf()





    # PESSIMISM ANALYSIS (currently disabled): sweep the PessMeanPolicy
    # constant and plot mean suboptimality with a standard-error band, against
    # the learner's mean suboptimality as a dashed reference line.
    run_pess_analysis = False
    if run_pess_analysis and envname in ['bandit', 'bandit_ood', 'bandit_thompson']:
        pess_consts = np.linspace(0, 5, 41)
        pess_values = analyze_pess(pess_consts, eval_trajs)[:, :, 0]
        pess_subopts = all_rs_opt - pess_values
        pess_means = np.mean(pess_subopts, axis=1)
        pess_sems = scipy.stats.sem(pess_subopts, axis=1)

        band_low = pess_means - pess_sems
        band_high = pess_means + pess_sems
        lnr_level = subopt_baselines_means['lnr']

        plt.plot(pess_consts, pess_means, label='PESS mean')
        plt.fill_between(pess_consts, band_low, band_high, alpha=.5)
        plt.plot(pess_consts, np.ones(len(pess_consts)) * lnr_level, linestyle='--', label='LNR')
        plt.savefig(f'figs/{evals_filename}/pess/{save_filename}_pess.png')
        plt.clf()




    # OFFLINE EVALUATION GRAPH

    def run_controllers(controllers):
        """Evaluate `controllers` on growing prefixes of each eval trajectory.

        For every evaluation trajectory and every prefix length ``cutoff`` in
        ``1..horizon``, a new environment is built from the trajectory. Each
        controller is re-targeted at it with ``set_env``, conditioned on the
        first ``cutoff`` roll-in transitions with ``set_batch``, and deployed
        once.

        Args:
            controllers: dict mapping a name to a policy exposing ``set_env``
                and ``set_batch``.

        Returns:
            dict mapping each controller name to a flat list of summed
            rewards, ordered by trajectory and then by cutoff.

        Raises:
            NotImplementedError: for environment names with no env builder.
        """
        all_rs = {key: [] for key in controllers}
        for i_eval, traj in enumerate(eval_trajs[:n_eval]):
            print(f"Eval traj: {i_eval}")

            # Full roll-in tensors for this trajectory; each cutoff slices a prefix.
            rollin = {
                'rollin_xs': torch.tensor(traj['rollin_xs'][None, :, :]).float().to(device),
                'rollin_us': torch.tensor(traj['rollin_us'][None, :, :]).float().to(device),
                'rollin_xps': torch.tensor(traj['rollin_xps'][None, :, :]).float().to(device),
                'rollin_rs': torch.tensor(traj['rollin_rs'][None, :, None]).float().to(device),
            }
            means = traj['means']

            for cutoff in range(1, horizon + 1):
                # 'bandit_thompson' also uses BanditEnv (as in the single-trajectory
                # evaluation); it previously fell through to NotImplementedError.
                if envname in ['bandit', 'bandit_ood', 'bandit_thompson']:
                    env = BanditEnv(means, horizon, var=test_var, type=bandit_type)
                elif envname == 'linear_bandit':
                    env = LinearBanditEnv(traj['theta'], traj['arms'], horizon, var=test_var)
                elif envname == 'bandit_topk':
                    env = TopKBanditEnv(means, horizon, var=test_var, k=topk)
                else:
                    raise NotImplementedError

                batch = {key: value[:, :cutoff, :] for key, value in rollin.items()}

                for key, controller in controllers.items():
                    controller.set_env(env)
                    controller.set_batch(batch)
                    _, _, _, rs = env.deploy_eval(controller)
                    all_rs[key].append(np.sum(rs))

        return all_rs

    # Policies evaluated on every (trajectory, cutoff) pair by run_controllers.
    # Constructors receive the last `env` from the offline loop above;
    # run_controllers re-targets each controller with set_env before use.
    if envname in ['bandit', 'bandit_ood', 'bandit_thompson', 'linear_bandit']:
        true_opt = OptPolicy(env)
        controllers = {
            'opt': true_opt,
        }
        if envname == 'linear_bandit':
            # LinUCB is built on a LinearBanditEnv from the last trajectory seen.
            env = LinearBanditEnv(traj['theta'], traj['arms'], horizon, var=test_var)
            lin = LinUCB(env, 1, const=0.0)
            controllers['lin'] = lin

    elif envname == 'bandit_topk':
        true_opt = OptPolicy(env)
        greedy = GreedyOptPolicy(env)
        lnr = TopKBanditTransformerController(model, k=topk, sample=True)
        lnr_greedy = TopKBanditTransformerController(model, k=topk, sample=False)
        rnd = TopKRandCommitPolicy(env, topk, horizon, immediate=True)
        lin = LinUCB(env, topk, const=0.0)
        controllers = {
            'opt': true_opt,
            'greedy': greedy,
            'lnr': lnr,
            'lnr_greedy': lnr_greedy,
            'rnd': rnd,
            'lin': lin,
        }
    else:
        raise NotImplementedError

    baselines = run_controllers(controllers)
    baselines = {k: np.array(v) for k, v in baselines.items()}

    # Vectorized graph evaluation: for every cutoff, evaluate the greedy
    # learner and the empirical-mean, pessimistic and Thompson-sampling
    # baselines on all eval trajectories at once, each conditioned on the
    # first `cutoff` roll-in transitions.
    print("Running offline graph evaluations in parallel")
    vectorized_keys = ('lnr_greedy', 'emp', 'pess', 'thmp2')
    for key in vectorized_keys:
        baselines[key] = []

    # The stacked roll-in data does not depend on the cutoff; build it once.
    trajs = eval_trajs[:n_eval]
    stacked_xs = np.array([traj['rollin_xs'] for traj in trajs])
    stacked_us = np.array([traj['rollin_us'] for traj in trajs])
    stacked_xps = np.array([traj['rollin_xps'] for traj in trajs])
    stacked_rs = np.array([traj['rollin_rs'][:, None] for traj in trajs])
    rollin_full = {
        'rollin_xs': torch.tensor(stacked_xs).float().to(device),
        'rollin_us': torch.tensor(stacked_us).float().to(device),
        'rollin_xps': torch.tensor(stacked_xps).float().to(device),
        'rollin_rs': torch.tensor(stacked_rs).reshape(n_eval, -1, 1).float().to(device),
    }

    for cutoff in range(1, horizon + 1):
        print(f"Step: {cutoff}")

        # New environments for every cutoff, as before.
        envs = []
        for traj in trajs:
            means = traj['means']
            # 'bandit_thompson' uses BanditEnv like the other bandit variants;
            # it previously fell through to NotImplementedError.
            if envname in ['bandit', 'bandit_ood', 'bandit_thompson']:
                env = BanditEnv(means, horizon, var=test_var, type=bandit_type)
            elif envname == 'linear_bandit':
                env = LinearBanditEnv(traj['theta'], traj['arms'], horizon, var=test_var)
            elif envname == 'bandit_topk':
                env = TopKBanditEnv(means, horizon, var=test_var, k=topk)
            else:
                raise NotImplementedError
            envs.append(env)

        batch = {key: value[:, :cutoff, :] for key, value in rollin_full.items()}
        # Copies give each cutoff its own arrays, as the per-cutoff rebuild did.
        batch_numpy = {
            'rollin_xs': stacked_xs[:, :cutoff, :].copy(),
            'rollin_us': stacked_us[:, :cutoff, :].copy(),
            'rollin_xps': stacked_xps[:, :cutoff, :].copy(),
            'rollin_rs': stacked_rs[:, :cutoff, :].copy(),
        }

        vec_env = BanditEnvVec(envs)
        lnr_greedy = BanditTransformerController(model, sample=False, batch_size=n_eval)
        emp = EmpMeanPolicy(envs[0], batch_size=n_eval)
        pess = PessMeanPolicy(envs[0], .8, batch_size=n_eval)
        # envs[-1] is the env the inner loop leaves in `env` (previously used here).
        thmp2 = ThompsonSamplingPolicy2(
            envs[-1],
            std=var,
            prior_mean=prior_mean,
            prior_var=prior_var,
            batch_size=n_eval)

        lnr_greedy.set_batch(batch)
        emp.set_batch_numpy_vec(batch_numpy)
        pess.set_batch_numpy_vec(batch_numpy)
        thmp2.set_batch_numpy_vec(batch_numpy)

        _, _, _, rs_lnr_greedy = vec_env.deploy_eval(lnr_greedy)
        _, _, _, rs_emp = vec_env.deploy_eval(emp)
        _, _, _, rs_pess = vec_env.deploy_eval(pess)
        _, _, _, rs_thmp2 = vec_env.deploy_eval(thmp2)

        baselines['lnr_greedy'].append(rs_lnr_greedy)
        baselines['emp'].append(rs_emp)
        baselines['pess'].append(rs_pess)
        baselines['thmp2'].append(rs_thmp2)

    # Rows are cutoffs; transposing before flattening orders results by
    # trajectory then cutoff, matching run_controllers' layout (assumes each
    # vectorized deploy returns one value per env — TODO confirm).
    for key in vectorized_keys:
        baselines[key] = np.array(baselines[key]).transpose().flatten()

    # Keep only the policies that are plotted and saved below.
    baselines = {k: np.array(baselines[k]) for k in ['opt', 'lnr_greedy', 'emp', 'pess', 'thmp2']}
    
    # Suboptimality of each policy relative to the optimal one, per
    # (trajectory, cutoff) entry.
    subopt_baselines = {
        k: baselines['opt'] - v for k, v in baselines.items() if k != 'opt'
    }

    # Mean and standard error across trajectories, as a function of cutoff.
    cutoffs = np.arange(1, horizon + 1)
    for key, subopt in subopt_baselines.items():
        values = subopt.reshape(n_eval, horizon)
        means = np.mean(values, axis=0)
        sems = scipy.stats.sem(values, axis=0)
        plt.plot(cutoffs, means, label=key)
        plt.fill_between(cutoffs, means - sems, means + sems, alpha=.2)

    # Create the output directories unconditionally. The old guards checked
    # 'figs/graph' and 'figs/graph_log' rather than these per-epoch paths, so
    # the directories were skipped (and savefig failed) whenever those
    # unrelated paths existed.
    os.makedirs(f'figs/{evals_filename}/graph', exist_ok=True)
    os.makedirs(f'figs/{evals_filename}/graph_log', exist_ok=True)
    os.makedirs(f'data_results/{evals_filename}/graph', exist_ok=True)

    plt.legend()
    plt.title('Suboptimality w.r.t optimal')
    plt.xlabel('Dataset size')
    plt.savefig(f'figs/{evals_filename}/graph/{save_filename}_mean_log.png')
    plt.yscale('log')
    plt.savefig(f'figs/{evals_filename}/graph_log/{save_filename}_mean_log.png')
    plt.clf()

    # Save the raw per-(trajectory, cutoff) returns for later analysis.
    with open(f'data_results/{evals_filename}/graph/{save_filename}.pkl', 'wb') as f:
        pickle.dump(baselines, f)



    
