from utils.datasets_no_n_user import expected_reward_function
from utils.datasets_no_n_user import BanditDatasetWithSurrogate
from obp.policy import NNPolicyLearner
from utils.policy import OurLearner
import pandas as pd
from tqdm import tqdm
import matplotlib.pyplot as plt
import torch
import numpy as np
import os
from obp.policy import Random
import time
from sklearn.utils import check_random_state
from utils.policy import GammaOptimizer3 as GammaOptimizer


# Identifier for this experiment run.
exp_name = "exp3"
# Report whether torch can see a CUDA device.
print(torch.cuda.is_available())


# Values of `beta` swept by the simulation loop. Each value is passed as
# `beta` to `dataset.calc_full_policy_value` and to the 'sdr-both' learners.
beta_list = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0]

# Cross-simulation accumulators: one independent empty list per
# (method, metric) pair.
#   *_all_results   -> values computed against bandit_feedback_test["all_q_x_a_f"]
#   *_all_surrogate -> values computed against bandit_feedback_test["f_sum"]
#   *_all_full      -> values from dataset.calc_full_policy_value(...)
# The simulation loop rebinds most of these on its first iteration (sim == 0)
# and stacks later simulations onto them with np.vstack. The random_* updates
# are commented out in that loop.
(
    parips_all_results,
    surips_all_results,
    sdr_all_results,
    sdr_both_all_results,
    s_opt_all_results,
    random_all_results,
    logging_all_results,
) = ([] for _ in range(7))
(
    parips_all_surrogate,
    surips_all_surrogate,
    sdr_all_surrogate,
    sdr_both_all_surrogate,
    s_opt_all_surrogate,
    random_all_surrogate,
    logging_all_surrogate,
) = ([] for _ in range(7))
(
    parips_all_full,
    surips_all_full,
    sdr_all_full,
    sdr_both_all_full,
    s_opt_all_full,
    random_all_full,
    logging_all_full,
) = ([] for _ in range(7))

# ---- Simulation hyperparameters -------------------------------------------
n_sims = 100  # number of independent simulations (outer loop iterations)
hidden_layer_size = (150,)  # passed to every NNPolicyLearner / OurLearner
num_tries = 4  # forwarded to GammaOptimizer.optimize
num_gamma = 6  # used by GammaOptimizer.optimize / optimize_true
batch_size = 64  # mini-batch size for every learner
n_rounds = 1000  # rounds drawn for the training bandit feedback
alpha_noise = 0.00  # passed to the dataset and to GammaOptimizer
p_o = 0.2  # passed to dataset/learners; 0.0 makes several baselines Random
reward_type = "continuous"
beta_data = -2  # `beta` argument of BanditDatasetWithSurrogate
random_state = 2255  # base seed; each simulation uses random_state + sim
n_actions = 100  # number of videos sampled per simulation (= n_actions)

# ---- Data ------------------------------------------------------------------
# Paths are relative, so they resolve against the current working directory.
base_path = "../../"
path = "../../data/"
# NOTE(review): the simulation loop below rebinds `expected_rewards_df` to the
# rows whose video_id was sampled. From the second simulation on, the video
# sample is therefore drawn from that already-filtered frame, which holds
# exactly n_actions videos, so every simulation reuses the same video set.
# Consider filtering into a separate per-simulation variable instead.
expected_rewards_df = pd.read_csv(os.path.join(path, "expected_rewards_df.csv"))
user_features_df = pd.read_csv(os.path.join(path, "user_features_df.csv"))
unique_user_ids = expected_rewards_df['user_id'].unique()
unique_video_ids = expected_rewards_df['video_id'].unique()
# NOTE(review): hard-coded tag, not derived from `p_o` above. It presumably
# encodes p_o = 0.2; update it by hand if p_o changes.
save_file_name = "po02"

with tqdm(total=n_sims, desc='Outer', position=0) as outer_bar:
    start_time = time.time()
    for sim in range(n_sims):
        parips_results = []
        surips_results = []
        sdr_results = []
        s_opt_results = []
        random_results=[]
        logging_results = []
        sdr_both_results = []
        parips_surrogate = []
        surips_surrogate = []
        sdr_surrogate = []
        sdr_both_surrogate = []
        random_surrogate=[]
        logging_surrogate = []
        s_opt_surrogate = []
        parips_full = []
        surips_full = []
        sdr_full = []
        sdr_both_full = []
        random_full=[]
        logging_full = []
        s_opt_full = []
        ours_gamma_results = []
        ours_gamma_surrogate = []
        ours_gamma_full = []
        opt_gamma_results = []
        opt_gamma_surrogate = []
        opt_gamma_full = []
        rips_results = []
        rips_surrogate = []
        rips_full = []
        sips_results = []
        sips_surrogate = []
        sips_full = []
        rdm_results = []
        rdm_surrogate = []
        rdm_full = []
        sdm_results = []
        sdm_surrogate = []
        sdm_full = []
        
        gammas=[]
        gamma_trues=[]
        
        random_ = check_random_state(random_state+sim)
        dataset=BanditDatasetWithSurrogate(n_actions = n_actions, reward_type=reward_type, beta=beta_data, p_o=p_o, alpha_noise=alpha_noise, random_state=random_state+sim)
        train_user_ids = random_.choice(unique_user_ids, size=int(0.7*len(unique_user_ids)), replace=False)
        random_videoids = expected_rewards_df['video_id'].drop_duplicates().sample(n=n_actions, random_state=random_state+sim)
        expected_rewards_df = expected_rewards_df[expected_rewards_df['video_id'].isin(random_videoids)]

        train_df = expected_rewards_df[expected_rewards_df['user_id'].isin(train_user_ids)]
        validation_df = expected_rewards_df[~expected_rewards_df['user_id'].isin(train_user_ids)]
        
        bandit_feedback_train =  dataset.obtain_batch_bandit_feedback(n_rounds=n_rounds, expected_rewards_df=train_df, user_features_df=user_features_df, test=False)
        bandit_feedback_test =  dataset.obtain_batch_bandit_feedback(expected_rewards_df=validation_df, user_features_df=user_features_df, test=True)
        
        logging_policy = bandit_feedback_test["pi_b"]
        logging_policy_value = dataset.calc_ground_truth_policy_value(
            expected_reward=bandit_feedback_test["all_q_x_a_f"],
            action_dist=logging_policy
        )
        logging_surrogate_value = dataset.calc_ground_truth_policy_value(
            expected_reward = bandit_feedback_test["f_sum"],
            action_dist=logging_policy
        )
        logging_results = [logging_policy_value] * len(beta_list)
        logging_surrogate = [logging_surrogate_value] * len(beta_list)
        n_rounds_test = bandit_feedback_test["n_rounds"]
        
        if p_o==0.0:
            parips_policy=Random(n_actions=n_actions, random_state=random_state+sim)
            parips_policy_action_dist = parips_policy.compute_batch_action_dist(
                n_rounds=n_rounds_test,
            )
        else:    
            parips_policy = NNPolicyLearner(
                n_actions=n_actions,
                dim_context=bandit_feedback_train["dim_context"],
                hidden_layer_size=hidden_layer_size,
                max_iter=200,
                learning_rate_init=0.0001,
                early_stopping=True,
                activation='relu',
                solver='adam',
                off_policy_objective='dr',
                batch_size=batch_size,
                random_state=random_state+sim,
            )
            parips_policy.fit(
                context=bandit_feedback_train["obs_contexts"],
                action=bandit_feedback_train["obs_actions"],
                reward=bandit_feedback_train["obs_rewards"],
                pscore=bandit_feedback_train["obs_pscores"],
            )
            parips_policy_action_dist = parips_policy.predict(
                context=bandit_feedback_test["contexts"],
            )
        parips_policy_value = dataset.calc_ground_truth_policy_value(
            expected_reward=bandit_feedback_test["all_q_x_a_f"],
            action_dist=parips_policy_action_dist
        )
        parips_surrogate_value = dataset.calc_ground_truth_policy_value(
            expected_reward = bandit_feedback_test["f_sum"],
            action_dist=parips_policy_action_dist
        )
        parips_results = [parips_policy_value] * len(beta_list)
        parips_surrogate = [parips_surrogate_value] * len(beta_list)
        
        if p_o==0.0:
            surips_policy=Random(n_actions=n_actions, random_state=random_state+sim)
            surips_policy_action_dist = surips_policy.compute_batch_action_dist(
                n_rounds=n_rounds_test,
            )
        else:    
            surips_policy = NNPolicyLearner(
                n_actions=n_actions,
                dim_context=bandit_feedback_train["dim_context"],
                hidden_layer_size=hidden_layer_size,
                early_stopping=True,
                activation='relu',
                solver='adam',
                off_policy_objective='dr',
                batch_size=batch_size,
                random_state=random_state+sim,
            )
            surips_policy.fit(
                context=bandit_feedback_train["contexts"],
                action=bandit_feedback_train["actions"],
                reward=bandit_feedback_train["f_s"],
                pscore=bandit_feedback_train["pscores"],
            )
            surips_policy_action_dist = surips_policy.predict(
                context=bandit_feedback_test["contexts"],
            )
        surips_policy_value = dataset.calc_ground_truth_policy_value(
            expected_reward=bandit_feedback_test["all_q_x_a_f"],
            action_dist=surips_policy_action_dist
        )
        surips_surrogate_value = dataset.calc_ground_truth_policy_value(
            expected_reward = bandit_feedback_test["f_sum"],
            action_dist=surips_policy_action_dist
        )
        
        surips_results = [surips_policy_value] * len(beta_list)
        surips_surrogate = [surips_surrogate_value] * len(beta_list)
        
        if p_o==0.0:
            rips_policy=Random(n_actions=n_actions, random_state=random_state+sim)
            rips_policy_action_dist = rips_policy.compute_batch_action_dist(
                n_rounds=n_rounds_test,
            )
        else:    
            rips_policy = NNPolicyLearner(
                n_actions=n_actions,
                dim_context=bandit_feedback_train["dim_context"],
                hidden_layer_size=hidden_layer_size,
                max_iter=200,
                learning_rate_init=0.0001,
                early_stopping=True,
                activation='relu',
                solver='adam',
                off_policy_objective='ipw',
                batch_size=batch_size,
                random_state=random_state+sim,
            )
            rips_policy.fit(
                context=bandit_feedback_train["obs_contexts"],
                action=bandit_feedback_train["obs_actions"],
                reward=bandit_feedback_train["obs_rewards"],
                pscore=bandit_feedback_train["obs_pscores"],
            )
            rips_policy_action_dist = rips_policy.predict(
                context=bandit_feedback_test["contexts"],
            )
        rips_policy_value = dataset.calc_ground_truth_policy_value(
            expected_reward=bandit_feedback_test["all_q_x_a_f"],
            action_dist=rips_policy_action_dist
        )
        rips_surrogate_value = dataset.calc_ground_truth_policy_value(
            expected_reward = bandit_feedback_test["f_sum"],
            action_dist=rips_policy_action_dist
        )
        rips_results = [rips_policy_value] * len(beta_list)
        rips_surrogate = [rips_surrogate_value] * len(beta_list)
        
        if p_o==0.0:
            rdm_policy=Random(n_actions=n_actions, random_state=random_state+sim)
            rdm_policy_action_dist = rdm_policy.compute_batch_action_dist(
                n_rounds=n_rounds_test,
            )
        else:    
            rdm_policy = NNPolicyLearner(
                n_actions=n_actions,
                dim_context=bandit_feedback_train["dim_context"],
                hidden_layer_size=hidden_layer_size,
                max_iter=200,
                learning_rate_init=0.0001,
                early_stopping=True,
                activation='relu',
                solver='adam',
                off_policy_objective='dm',
                batch_size=batch_size,
                random_state=random_state+sim,
            )
            rdm_policy.fit(
                context=bandit_feedback_train["obs_contexts"],
                action=bandit_feedback_train["obs_actions"],
                reward=bandit_feedback_train["obs_rewards"],
                pscore=bandit_feedback_train["obs_pscores"],
            )
            rdm_policy_action_dist = rdm_policy.predict(
                context=bandit_feedback_test["contexts"],
            )
        rdm_policy_value = dataset.calc_ground_truth_policy_value(
            expected_reward=bandit_feedback_test["all_q_x_a_f"],
            action_dist=rdm_policy_action_dist
        )
        rdm_surrogate_value = dataset.calc_ground_truth_policy_value(
            expected_reward = bandit_feedback_test["f_sum"],
            action_dist=rdm_policy_action_dist
        )
        rdm_results = [rdm_policy_value] * len(beta_list)
        rdm_surrogate = [rdm_surrogate_value] * len(beta_list)
        
        sips_policy = NNPolicyLearner(
            n_actions=n_actions,
            dim_context=bandit_feedback_train["dim_context"],
            hidden_layer_size=hidden_layer_size,
            early_stopping=True,
            activation='relu',
            solver='adam',
            off_policy_objective='ipw',
            batch_size=batch_size,
            random_state=random_state+sim,
        )
        sips_policy.fit(
            context=bandit_feedback_train["contexts"],
            action=bandit_feedback_train["actions"],
            reward=bandit_feedback_train["f_s"],
            pscore=bandit_feedback_train["pscores"],
        )
        sips_policy_action_dist = sips_policy.predict(
            context=bandit_feedback_test["contexts"],
        )
        sips_policy_value = dataset.calc_ground_truth_policy_value(
            expected_reward=bandit_feedback_test["all_q_x_a_f"],
            action_dist=sips_policy_action_dist
        )
        sips_surrogate_value = dataset.calc_ground_truth_policy_value(
            expected_reward = bandit_feedback_test["f_sum"],
            action_dist=sips_policy_action_dist
        )
        
        sips_results = [sips_policy_value] * len(beta_list)
        sips_surrogate = [sips_surrogate_value] * len(beta_list)
        
        sdm_policy = NNPolicyLearner(
            n_actions=n_actions,
            dim_context=bandit_feedback_train["dim_context"],
            hidden_layer_size=hidden_layer_size,
            early_stopping=True,
            activation='relu',
            solver='adam',
            off_policy_objective='dm',
            batch_size=batch_size,
            random_state=random_state+sim,
        )
        sdm_policy.fit(
            context=bandit_feedback_train["contexts"],
            action=bandit_feedback_train["actions"],
            reward=bandit_feedback_train["f_s"],
            pscore=bandit_feedback_train["pscores"],
        )
        sdm_policy_action_dist = sdm_policy.predict(
            context=bandit_feedback_test["contexts"],
        )
        sdm_policy_value = dataset.calc_ground_truth_policy_value(
            expected_reward=bandit_feedback_test["all_q_x_a_f"],
            action_dist=sdm_policy_action_dist
        )
        sdm_surrogate_value = dataset.calc_ground_truth_policy_value(
            expected_reward = bandit_feedback_test["f_sum"],
            action_dist=sdm_policy_action_dist
        )
        
        sdm_results = [sdm_policy_value] * len(beta_list)
        sdm_surrogate = [sdm_surrogate_value] * len(beta_list)
        
        if p_o==0.0:
            s_opt_policy=Random(n_actions=n_actions, random_state=random_state+sim)
            s_opt_policy_action_dist = s_opt_policy.compute_batch_action_dist(
                n_rounds=n_rounds_test,
            )
        else:    
            s_opt_policy = OurLearner(
                n_actions=n_actions,
                s_dim=bandit_feedback_train["s_dim"],
                dim_context=bandit_feedback_train["dim_context"],
                hidden_layer_size=hidden_layer_size,
                early_stopping=True,
                activation='relu',
                solver='adam',
                off_policy_objective='sdr-both',
                batch_size=batch_size,
                random_state=random_state+sim,
            )
            s_opt_policy.fit(
                context=bandit_feedback_train["contexts"],
                action=bandit_feedback_train["actions"],
                surrogate_reward=bandit_feedback_train["surrogate_rewards"],
                reward=bandit_feedback_train["rewards"],
                obs_list=bandit_feedback_train["obs_list"],
                p_o=p_o,
                pscore=bandit_feedback_train["pscores"],
                s_sum=bandit_feedback_train["s_sum"],
                beta=1.0,
            )
            s_opt_policy_action_dist = s_opt_policy.predict(
                context=bandit_feedback_test["contexts"],
            )
        s_opt_policy_value = dataset.calc_ground_truth_policy_value(
            expected_reward=bandit_feedback_test["all_q_x_a_f"],
            action_dist=s_opt_policy_action_dist
        )
        s_opt_surrogate_value = dataset.calc_ground_truth_policy_value(
            expected_reward = bandit_feedback_test["f_sum"],
            action_dist=s_opt_policy_action_dist
        )
        s_opt_results = [s_opt_policy_value] * len(beta_list)
        s_opt_surrogate = [s_opt_surrogate_value] * len(beta_list)
        
        if p_o==0.0:
            sdr_policy=Random(n_actions=n_actions, random_state=random_state+sim)
            sdr_policy_action_dist = sdr_policy.compute_batch_action_dist(
                n_rounds=n_rounds_test,
            )
        else:  
            sdr_policy = OurLearner(
                n_actions=n_actions,
                s_dim=bandit_feedback_train["s_dim"],
                dim_context=bandit_feedback_train["dim_context"],
                hidden_layer_size=hidden_layer_size,
                max_iter=200,
                early_stopping=True,
                activation='relu',
                solver='adam',
                off_policy_objective='sdr',
                batch_size=batch_size,
                random_state=random_state+sim,
            )
            sdr_policy.fit(
                context=bandit_feedback_train["contexts"],
                action=bandit_feedback_train["actions"],
                surrogate_reward=bandit_feedback_train["surrogate_rewards"],
                reward=bandit_feedback_train["rewards"],
                obs_list=bandit_feedback_train["obs_list"],
                p_o=p_o,
                pscore=bandit_feedback_train["pscores"],
                beta = 0.0,
            )
            sdr_policy_action_dist = sdr_policy.predict(
                context=bandit_feedback_test["contexts"],
            )
        sdr_policy_value = dataset.calc_ground_truth_policy_value(
            expected_reward=bandit_feedback_test["all_q_x_a_f"],
            action_dist=sdr_policy_action_dist
        )
        sdr_surrogate_value = dataset.calc_ground_truth_policy_value(
            expected_reward = bandit_feedback_test["f_sum"],
            action_dist=sdr_policy_action_dist
        )
        sdr_results = [sdr_policy_value] * len(beta_list)
        sdr_surrogate = [sdr_surrogate_value] * len(beta_list)

        for beta in beta_list:
            logging_full_value = dataset.calc_full_policy_value(
                expected_reward=bandit_feedback_test["all_q_x_a_f"],
                expected_surrogate_reward=bandit_feedback_test["f_sum"],
                action_dist=logging_policy,
                beta=beta
            )
            logging_full.append(logging_full_value)
            parips_full_value = dataset.calc_full_policy_value(
                expected_reward=bandit_feedback_test["all_q_x_a_f"],
                expected_surrogate_reward=bandit_feedback_test["f_sum"],
                action_dist=parips_policy_action_dist,
                beta=beta
            )
            parips_full.append(parips_full_value)
            surips_full_value = dataset.calc_full_policy_value(
                expected_reward=bandit_feedback_test["all_q_x_a_f"],
                expected_surrogate_reward=bandit_feedback_test["f_sum"],
                action_dist=surips_policy_action_dist,
                beta=beta
            )
            surips_full.append(surips_full_value)
            s_opt_full_value = dataset.calc_full_policy_value(
                expected_reward=bandit_feedback_test["all_q_x_a_f"],
                expected_surrogate_reward=bandit_feedback_test["f_sum"],
                action_dist=s_opt_policy_action_dist,
                beta=beta
            )
            s_opt_full.append(s_opt_full_value)
            
            sdr_full_value = dataset.calc_full_policy_value(
                expected_reward=bandit_feedback_test["all_q_x_a_f"],
                expected_surrogate_reward=bandit_feedback_test["f_sum"],
                action_dist=sdr_policy_action_dist,
                beta=beta
            )
            sdr_full.append(sdr_full_value)
            
            rips_full_value = dataset.calc_full_policy_value(
                expected_reward=bandit_feedback_test["all_q_x_a_f"],
                expected_surrogate_reward=bandit_feedback_test["f_sum"],
                action_dist=rips_policy_action_dist,
                beta=beta
            )
            rips_full.append(rips_full_value)
            rdm_full_value = dataset.calc_full_policy_value(
                expected_reward=bandit_feedback_test["all_q_x_a_f"],
                expected_surrogate_reward=bandit_feedback_test["f_sum"],
                action_dist=rdm_policy_action_dist,
                beta=beta
            )
            rdm_full.append(rdm_full_value)
            sips_full_value = dataset.calc_full_policy_value(
                expected_reward=bandit_feedback_test["all_q_x_a_f"],
                expected_surrogate_reward=bandit_feedback_test["f_sum"],
                action_dist=sips_policy_action_dist,
                beta=beta
            )
            sips_full.append(sips_full_value)
            sdm_full_value = dataset.calc_full_policy_value(
                expected_reward=bandit_feedback_test["all_q_x_a_f"],
                expected_surrogate_reward=bandit_feedback_test["f_sum"],
                action_dist=sdm_policy_action_dist,
                beta=beta
            )
            sdm_full.append(sdm_full_value)

            if beta==0:
                sdr_both_results.append(sdr_policy_value)
                sdr_both_surrogate.append(sdr_surrogate_value)
                sdr_both_full.append(sdr_full_value)
            elif beta==1:
                sdr_both_results.append(s_opt_policy_value)
                sdr_both_surrogate.append(s_opt_surrogate_value)
                sdr_both_full.append(s_opt_full_value)
            else:
                sdr_both_policy = OurLearner(
                    n_actions=n_actions,
                    s_dim=bandit_feedback_train["s_dim"],
                    dim_context=bandit_feedback_train["dim_context"],
                    hidden_layer_size=hidden_layer_size,
                    max_iter=200,
                    learning_rate_init=0.0001,
                    early_stopping=True,
                    activation='relu',
                    solver='adam',
                    off_policy_objective='sdr-both',
                    batch_size=batch_size,
                    random_state=random_state+sim,
                )
                sdr_both_policy.fit(
                    context=bandit_feedback_train["contexts"],
                    action=bandit_feedback_train["actions"],
                    surrogate_reward=bandit_feedback_train["surrogate_rewards"],
                    reward=bandit_feedback_train["rewards"],
                    obs_list=bandit_feedback_train["obs_list"],
                    p_o=p_o,
                    pscore=bandit_feedback_train["pscores"],
                    s_sum=bandit_feedback_train["s_sum"],
                    beta = beta,
                )
                sdr_both_policy_action_dist = sdr_both_policy.predict(
                    context=bandit_feedback_test["contexts"],
                )
                sdr_both_policy_value = dataset.calc_ground_truth_policy_value(
                    expected_reward=bandit_feedback_test["all_q_x_a_f"],
                    action_dist=sdr_both_policy_action_dist
                )
                sdr_both_surrogate_value = dataset.calc_ground_truth_policy_value(
                    expected_reward = bandit_feedback_test["f_sum"],
                    action_dist=sdr_both_policy_action_dist
                )
                sdr_both_full_value = dataset.calc_full_policy_value(
                    expected_reward=bandit_feedback_test["all_q_x_a_f"],
                    expected_surrogate_reward=bandit_feedback_test["f_sum"],
                    action_dist=sdr_both_policy_action_dist,
                    beta=beta
                )
                sdr_both_results.append(sdr_both_policy_value)
                sdr_both_surrogate.append(sdr_both_surrogate_value)
                sdr_both_full.append(sdr_both_full_value)
            
            if beta==1.0:
                ours_gamma_full.append(s_opt_full_value)
                ours_gamma_results.append(s_opt_policy_value)
                ours_gamma_surrogate.append(s_opt_surrogate_value)
                gammas.append(1.0)
            else:
                optimizer=GammaOptimizer(n_actions=n_actions, 
                        context=bandit_feedback_train["contexts"],
                        reward=bandit_feedback_train["rewards"],
                        action=bandit_feedback_train["actions"],
                        surrogate_reward=bandit_feedback_train["surrogate_rewards"],
                        obs_list = bandit_feedback_train["obs_list"],
                        p_o=p_o,
                        s_sum=bandit_feedback_train["s_sum"],
                        pscore = bandit_feedback_train["pscores"],
                        beta=beta,
                        alpha_noise = alpha_noise,
                        s_dim=bandit_feedback_train["s_dim"],
                        dim_context=bandit_feedback_train["dim_context"],
                        hidden_layer_size=hidden_layer_size,
                        early_stopping=True,
                        activation='relu',
                        solver='adam',
                        off_policy_objective='sdr-both',
                        batch_size=batch_size,
                        random_state=random_state+sim,
                        )
                gamma=optimizer.optimize(
                    num_tries=num_tries,
                    num_gamma = num_gamma,
                )
                ours_gamma_policy = OurLearner(
                    n_actions=n_actions,
                    s_dim=bandit_feedback_train["s_dim"],
                    dim_context=bandit_feedback_train["dim_context"],
                    hidden_layer_size=hidden_layer_size,
                    early_stopping=True,
                    activation='relu',
                    solver='adam',
                    off_policy_objective='sdr-both',
                    batch_size=batch_size,
                    random_state=random_state+sim,
                )
                ours_gamma_policy.fit(
                    context=bandit_feedback_train["contexts"],
                    action=bandit_feedback_train["actions"],
                    surrogate_reward=bandit_feedback_train["surrogate_rewards"],
                    reward=bandit_feedback_train["rewards"],
                    obs_list=bandit_feedback_train["obs_list"],
                    p_o=p_o,
                    pscore=bandit_feedback_train["pscores"],
                    s_sum=bandit_feedback_train["s_sum"],
                    beta = gamma
                )
                ours_gamma_policy_action_dist = ours_gamma_policy.predict(
                    context=bandit_feedback_test["contexts"],
                )
                ours_gamma_policy_value = dataset.calc_ground_truth_policy_value(
                    expected_reward=bandit_feedback_test["all_q_x_a_f"],
                    action_dist=ours_gamma_policy_action_dist
                )
                ours_gamma_surrogate_value = dataset.calc_ground_truth_policy_value(
                    expected_reward = bandit_feedback_test["f_sum"],
                    action_dist=ours_gamma_policy_action_dist
                )
                ours_gamma_full_value = dataset.calc_full_policy_value(
                    expected_reward=bandit_feedback_test["all_q_x_a_f"],
                    expected_surrogate_reward=bandit_feedback_test["f_sum"],
                    action_dist=ours_gamma_policy_action_dist,
                    beta=beta
                )
                ours_gamma_results.append(ours_gamma_policy_value)
                ours_gamma_surrogate.append(ours_gamma_surrogate_value)
                ours_gamma_full.append(ours_gamma_full_value)
                gammas.append(gamma)
            
            if beta==1.0:
                opt_gamma_full.append(s_opt_full_value)
                opt_gamma_results.append(s_opt_policy_value)
                opt_gamma_surrogate.append(s_opt_surrogate_value)
                gamma_trues.append(1.0)
            else:
                optimizer=GammaOptimizer(n_actions=n_actions, 
                        context=bandit_feedback_train["contexts"],
                        reward=bandit_feedback_train["rewards"],
                        action=bandit_feedback_train["actions"],
                        surrogate_reward=bandit_feedback_train["surrogate_rewards"],
                        obs_list = bandit_feedback_train["obs_list"],
                        p_o=p_o,
                        s_sum=bandit_feedback_train["s_sum"],
                        pscore = bandit_feedback_train["pscores"],
                        beta=beta,
                        alpha_noise = alpha_noise,
                        s_dim=bandit_feedback_train["s_dim"],
                        dim_context=bandit_feedback_train["dim_context"],
                        hidden_layer_size=hidden_layer_size,
                        early_stopping=True,
                        activation='relu',
                        solver='adam',
                        off_policy_objective='sdr-both',
                        batch_size=batch_size,
                        random_state=random_state+sim,
                        )
                gamma_true, opt_gamma_full_value, opt_gamma_policy_value, opt_gamma_surrogate_value = optimizer.optimize_true(
                    bandit_feedback_train,
                    bandit_feedback_test,
                    num_gamma = int(num_gamma+(beta*0.5))
                )
                opt_gamma_results.append(opt_gamma_policy_value)
                opt_gamma_surrogate.append(opt_gamma_surrogate_value)
                opt_gamma_full.append(opt_gamma_full_value)
                gamma_trues.append(gamma_true)
        
        if sim==0:
            parips_all_results=parips_results
            surips_all_results=surips_results
            sdr_all_results=sdr_results
            sdr_both_all_results=sdr_both_results
            s_opt_all_results=s_opt_results
            # random_all_results=random_results
            logging_all_results=logging_results
            parips_all_surrogate=parips_surrogate
            surips_all_surrogate=surips_surrogate
            sdr_all_surrogate=sdr_surrogate
            sdr_both_all_surrogate=sdr_both_surrogate
            s_opt_all_surrogate=s_opt_surrogate
            # random_all_surrogate=random_surrogate
            logging_all_surrogate=logging_surrogate
            parips_all_full = parips_full
            surips_all_full = surips_full
            sdr_all_full = sdr_full
            sdr_both_all_full = sdr_both_full
            s_opt_all_full = s_opt_full
            # random_all_full = random_full
            logging_all_full = logging_full
            ours_gamma_all_results = ours_gamma_results
            ours_gamma_all_surrogate = ours_gamma_surrogate
            ours_gamma_all_full = ours_gamma_full
            opt_gamma_all_results = opt_gamma_results
            opt_gamma_all_surrogate = opt_gamma_surrogate
            opt_gamma_all_full = opt_gamma_full
            gamma_list = gammas
            gamma_true_list = gamma_trues
            rips_all_results = rips_results
            rips_all_surrogate = rips_surrogate
            rips_all_full = rips_full
            sips_all_results = sips_results
            sips_all_surrogate = sips_surrogate
            sips_all_full = sips_full
            rdm_all_results = rdm_results
            rdm_all_surrogate = rdm_surrogate
            rdm_all_full = rdm_full
            sdm_all_results = sdm_results
            sdm_all_surrogate = sdm_surrogate
            sdm_all_full = sdm_full
        else:
            parips_all_results=np.vstack([parips_all_results, parips_results])
            surips_all_results=np.vstack([surips_all_results, surips_results])
            sdr_all_results=np.vstack([sdr_all_results, sdr_results])
            sdr_both_all_results=np.vstack([sdr_both_all_results, sdr_both_results])
            s_opt_all_results=np.vstack([s_opt_all_results, s_opt_results])
            # random_all_results=np.vstack([random_all_results, random_results])
            logging_all_results=np.vstack([logging_all_results, logging_results])
            parips_all_surrogate=np.vstack([parips_all_surrogate, parips_surrogate])
            surips_all_surrogate=np.vstack([surips_all_surrogate, surips_surrogate])
            sdr_all_surrogate=np.vstack([sdr_all_surrogate, sdr_surrogate])
            sdr_both_all_surrogate=np.vstack([sdr_both_all_surrogate, sdr_both_surrogate])
            s_opt_all_surrogate=np.vstack([s_opt_all_surrogate, s_opt_surrogate])
            # random_all_surrogate=np.vstack([random_all_surrogate, random_surrogate])
            logging_all_surrogate=np.vstack([logging_all_surrogate, logging_surrogate])
            parips_all_full = np.vstack([parips_all_full, parips_full])
            surips_all_full = np.vstack([surips_all_full, surips_full])
            sdr_all_full = np.vstack([sdr_all_full, sdr_full])
            sdr_both_all_full = np.vstack([sdr_both_all_full, sdr_both_full])
            s_opt_all_full = np.vstack([s_opt_all_full, s_opt_full])
            # random_all_full = random_full
            logging_all_full = np.vstack([logging_all_full, logging_full])
            
            ours_gamma_all_results=np.vstack([ours_gamma_all_results, ours_gamma_results])
            ours_gamma_all_surrogate=np.vstack([ours_gamma_all_surrogate, ours_gamma_surrogate])
            ours_gamma_all_full=np.vstack([ours_gamma_all_full, ours_gamma_full])
            opt_gamma_all_results=np.vstack([opt_gamma_all_results, opt_gamma_results])
            opt_gamma_all_surrogate=np.vstack([opt_gamma_all_surrogate, opt_gamma_surrogate])
            opt_gamma_all_full=np.vstack([opt_gamma_all_full, opt_gamma_full])
            gamma_list = np.vstack([gamma_list, gammas])
            gamma_true_list = np.vstack([gamma_true_list, gamma_trues])
            rips_all_results = np.vstack([rips_all_results, rips_results])
            rips_all_surrogate = np.vstack([rips_all_surrogate, rips_surrogate])
            rips_all_full = np.vstack([rips_all_full, rips_full])
            sips_all_results = np.vstack([sips_all_results, sips_results])
            sips_all_surrogate = np.vstack([sips_all_surrogate, sips_surrogate])
            sips_all_full = np.vstack([sips_all_full, sips_full])
            rdm_all_results = np.vstack([rdm_all_results, rdm_results])
            rdm_all_surrogate = np.vstack([rdm_all_surrogate, rdm_surrogate])
            rdm_all_full = np.vstack([rdm_all_full, rdm_full])
            sdm_all_results = np.vstack([sdm_all_results, sdm_results])
            sdm_all_surrogate = np.vstack([sdm_all_surrogate, sdm_surrogate])
            sdm_all_full = np.vstack([sdm_all_full, sdm_full])
            
        # Advance the outer progress bar and print elapsed time plus a naive
        # ETA (average time per finished simulation x simulations remaining).
        outer_bar.update(1)
        completed_sims = sim + 1
        elapsed_time = time.time() - start_time
        remaining_time = (elapsed_time / completed_sims) * (n_sims - completed_sims)
        progress_msg = (
            f"\rSimulations: {outer_bar.n}/{outer_bar.total} - "
            f"{tqdm.format_interval(elapsed_time)}<{tqdm.format_interval(remaining_time)}"
        )
        tqdm.write(progress_msg)

        if sim<2:
            continue
        parips_means  = parips_all_results.mean(axis=0)
        surips_means  = surips_all_results.mean(axis=0)
        sdr_means  = sdr_all_results.mean(axis=0)
        sdr_both_means  = sdr_both_all_results.mean(axis=0)
        s_opt_means  = s_opt_all_results.mean(axis=0)
        # random_means  = random_all_results.mean(axis=0)
        logging_means  = logging_all_results.mean(axis=0)

        parips_s_means  = parips_all_surrogate.mean(axis=0)
        surips_s_means  = surips_all_surrogate.mean(axis=0)
        sdr_s_means  = sdr_all_surrogate.mean(axis=0)
        sdr_both_s_means  = sdr_both_all_surrogate.mean(axis=0)
        s_opt_s_means  = s_opt_all_surrogate.mean(axis=0)
        # random_s_means  = random_all_surrogate.mean(axis=0)
        logging_s_means  = logging_all_surrogate.mean(axis=0)

        parips_full_means = parips_all_full.mean(axis=0)
        surips_full_means = surips_all_full.mean(axis=0)
        sdr_full_means = sdr_all_full.mean(axis=0)
        sdr_both_full_means = sdr_both_all_full.mean(axis=0)
        s_opt_full_means = s_opt_all_full.mean(axis=0)
        # random_full_means = random_all_full.mean(axis=0)
        logging_full_means = logging_all_full.mean(axis=0)


        parips_stds = parips_all_results.std(axis=0)
        surips_stds = surips_all_results.std(axis=0)
        sdr_stds = sdr_all_results.std(axis=0)
        sdr_both_stds = sdr_both_all_results.std(axis=0)
        s_opt_stds = s_opt_all_results.std(axis=0)
        # random_stds = random_all_results.std(axis=0)
        logging_stds = logging_all_results.std(axis=0)

        parips_s_stds = parips_all_surrogate.std(axis=0)
        surips_s_stds = surips_all_surrogate.std(axis=0)
        sdr_s_stds = sdr_all_surrogate.std(axis=0)
        sdr_both_s_stds = sdr_both_all_surrogate.std(axis=0)
        s_opt_s_stds = s_opt_all_surrogate.std(axis=0)
        # random_s_stds = random_all_surrogate.std(axis=0)
        logging_s_stds = logging_all_surrogate.std(axis=0)

        parips_full_stds = parips_all_full.std(axis=0)
        surips_full_stds = surips_all_full.std(axis=0)
        sdr_full_stds = sdr_all_full.std(axis=0)
        sdr_both_full_stds = sdr_both_all_full.std(axis=0)
        s_opt_full_stds = s_opt_all_full.std(axis=0)
        # random_full_stds = random_all_full.std(axis=0)
        logging_full_stds = logging_all_full.std(axis=0)

        # 標準誤差を計算
        parips_stes = parips_stds / np.sqrt(n_sims)
        surips_stes = surips_stds / np.sqrt(n_sims)
        sdr_stes = sdr_stds / np.sqrt(n_sims)
        sdr_both_stes = sdr_both_stds / np.sqrt(n_sims)
        s_opt_stes = s_opt_stds / np.sqrt(n_sims)
        # random_stes = random_stds / np.sqrt(n_sims)
        logging_stes = logging_stds / np.sqrt(n_sims)

        parips_s_stes = parips_s_stds / np.sqrt(n_sims)
        surips_s_stes = surips_s_stds / np.sqrt(n_sims)
        sdr_s_stes = sdr_s_stds / np.sqrt(n_sims)
        sdr_both_s_stes = sdr_both_s_stds / np.sqrt(n_sims)
        s_opt_s_stes = s_opt_s_stds / np.sqrt(n_sims)
        # random_s_stes = random_s_stds / np.sqrt(n_sims)
        logging_s_stes = logging_s_stds / np.sqrt(n_sims)

        parips_full_stes = parips_full_stds / np.sqrt(n_sims)
        surips_full_stes = surips_full_stds / np.sqrt(n_sims)
        sdr_full_stes = sdr_full_stds / np.sqrt(n_sims)
        sdr_both_full_stes = sdr_both_full_stds / np.sqrt(n_sims)
        s_opt_full_stes = s_opt_full_stds / np.sqrt(n_sims)
        # random_full_stes = random_full_stds / np.sqrt(n_sims)
        logging_full_stes = logging_full_stds / np.sqrt(n_sims)

        # 95% 信頼区間の計算
        confidence = 1.96
        parips_conf_intervals = (parips_stes * confidence)
        surips_conf_intervals = (surips_stes * confidence)
        sdr_conf_intervals = (sdr_stes * confidence)
        sdr_both_conf_intervals = (sdr_both_stes * confidence) 
        s_opt_conf_intervals = (s_opt_stes * confidence) 
        # random_conf_intervals = (random_stes * confidence) 
        logging_conf_intervals = (logging_stes * confidence) 

        parips_s_conf_intervals = (parips_s_stes * confidence) 
        surips_s_conf_intervals = (surips_s_stes * confidence) 
        sdr_s_conf_intervals = (sdr_s_stes * confidence) 
        sdr_both_s_conf_intervals = (sdr_both_s_stes * confidence)
        s_opt_s_conf_intervals = (s_opt_s_stes * confidence) 
        # random_s_conf_intervals = (random_s_stes * confidence) 
        logging_s_conf_intervals = (logging_s_stes * confidence) 

        parips_full_conf_intervals = (parips_full_stes * confidence) 
        surips_full_conf_intervals = (surips_full_stes * confidence) 
        sdr_full_conf_intervals = (sdr_full_stes * confidence) 
        sdr_both_full_conf_intervals = (sdr_both_full_stes * confidence) 
        s_opt_full_conf_intervals = (s_opt_full_stes * confidence) 
        # random_full_conf_intervals = (random_full_stes * confidence) 
        logging_full_conf_intervals = (logging_full_stes * confidence) 

        opt_gamma_means = opt_gamma_all_results.mean(axis=0)
        opt_gamma_s_means = opt_gamma_all_surrogate.mean(axis=0)
        opt_gamma_full_means = opt_gamma_all_full.mean(axis=0)
        opt_gamma_stds = opt_gamma_all_results.std(axis=0)
        opt_gamma_s_stds = opt_gamma_all_surrogate.std(axis=0)
        opt_gamma_full_stds = opt_gamma_all_full.std(axis=0)
        opt_gamma_stes = opt_gamma_stds / np.sqrt(n_sims)
        opt_gamma_s_stes = opt_gamma_s_stds / np.sqrt(n_sims)
        opt_gamma_full_stes = opt_gamma_full_stds / np.sqrt(n_sims)
        opt_gamma_conf_intervals = (opt_gamma_stes * confidence)
        opt_gamma_s_conf_intervals = (opt_gamma_s_stes * confidence)
        opt_gamma_full_conf_intervals = (opt_gamma_full_stes * confidence)

        ours_gamma_means = ours_gamma_all_results.mean(axis=0)
        ours_gamma_s_means = ours_gamma_all_surrogate.mean(axis=0)
        ours_gamma_full_means = ours_gamma_all_full.mean(axis=0)
        ours_gamma_stds = ours_gamma_all_results.std(axis=0)
        ours_gamma_s_stds = ours_gamma_all_surrogate.std(axis=0)
        ours_gamma_full_stds = ours_gamma_all_full.std(axis=0)
        ours_gamma_stes = ours_gamma_stds / np.sqrt(n_sims)
        ours_gamma_s_stes = ours_gamma_s_stds / np.sqrt(n_sims)
        ours_gamma_full_stes = ours_gamma_full_stds / np.sqrt(n_sims)
        ours_gamma_conf_intervals = (ours_gamma_stes * confidence)
        ours_gamma_s_conf_intervals = (ours_gamma_s_stes * confidence)
        ours_gamma_full_conf_intervals = (ours_gamma_full_stes * confidence)

        
        rips_means = rips_all_results.mean(axis=0)
        rips_s_means = rips_all_surrogate.mean(axis=0)
        rips_full_means = rips_all_full.mean(axis=0)
        rips_stds = rips_all_results.std(axis=0)
        rips_s_stds = rips_all_surrogate.std(axis=0)
        rips_full_stds = rips_all_full.std(axis=0)
        rips_stes = rips_stds / np.sqrt(n_sims)
        rips_s_stes = rips_s_stds / np.sqrt(n_sims)
        rips_full_stes = rips_full_stds / np.sqrt(n_sims)
        rips_conf_intervals = (rips_stes * confidence)
        rips_s_conf_intervals = (rips_s_stes * confidence)
        rips_full_conf_intervals = (rips_full_stes * confidence)
        sips_means = sips_all_results.mean(axis=0)
        sips_s_means = sips_all_surrogate.mean(axis=0)
        sips_full_means = sips_all_full.mean(axis=0)
        sips_stds = sips_all_results.std(axis=0)
        sips_s_stds = sips_all_surrogate.std(axis=0)
        sips_full_stds = sips_all_full.std(axis=0)
        sips_stes = sips_stds / np.sqrt(n_sims)
        sips_s_stes = sips_s_stds / np.sqrt(n_sims)
        sips_full_stes = sips_full_stds / np.sqrt(n_sims)
        sips_conf_intervals = (sips_stes * confidence)
        sips_s_conf_intervals = (sips_s_stes * confidence)
        sips_full_conf_intervals = (sips_full_stes * confidence)
        rdm_means = rdm_all_results.mean(axis=0)
        rdm_s_means = rdm_all_surrogate.mean(axis=0)
        rdm_full_means = rdm_all_full.mean(axis=0)
        rdm_stds = rdm_all_results.std(axis=0)
        rdm_s_stds = rdm_all_surrogate.std(axis=0)
        rdm_full_stds = rdm_all_full.std(axis=0)
        rdm_stes = rdm_stds / np.sqrt(n_sims)
        rdm_s_stes = rdm_s_stds / np.sqrt(n_sims)
        rdm_full_stes = rdm_full_stds / np.sqrt(n_sims)
        rdm_conf_intervals = (rdm_stes * confidence)
        rdm_s_conf_intervals = (rdm_s_stes * confidence)
        rdm_full_conf_intervals = (rdm_full_stes * confidence)
        sdm_means = sdm_all_results.mean(axis=0)
        sdm_s_means = sdm_all_surrogate.mean(axis=0)
        sdm_full_means = sdm_all_full.mean(axis=0)
        sdm_stds = sdm_all_results.std(axis=0)
        sdm_s_stds = sdm_all_surrogate.std(axis=0)
        sdm_full_stds = sdm_all_full.std(axis=0)
        sdm_stes = sdm_stds / np.sqrt(n_sims)
        sdm_s_stes = sdm_s_stds / np.sqrt(n_sims)
        sdm_full_stes = sdm_full_stds / np.sqrt(n_sims)
        sdm_conf_intervals = (sdm_stes * confidence)
        sdm_s_conf_intervals = (sdm_s_stes * confidence)
        sdm_full_conf_intervals = (sdm_full_stes * confidence)


        # Beta values as strings: the key column of every CSV below and the
        # categorical x-axis of the plots.
        list_=[str(i) for i in beta_list]

        # Save per-beta means and 95% CI half-widths for each evaluation target.
        # s_opt statistics are computed above but are not written to these CSVs.
        results_data_dir = f'{base_path}/results/realdata/{exp_name}/data'
        os.makedirs(results_data_dir, exist_ok=True)

        # Target (primary) reward policy values: the columns without "_s"/"_full".
        results = pd.DataFrame({
            'list': list_,
            'parips_means': parips_means,
            'parips_conf_intervals': parips_conf_intervals,
            'surips_means': surips_means,
            'surips_conf_intervals': surips_conf_intervals,
            'sdr_means': sdr_means,
            'sdr_conf_intervals': sdr_conf_intervals,
            'sdr_both_means': sdr_both_means,
            'sdr_both_conf_intervals': sdr_both_conf_intervals,
            'logging_means': logging_means,
            'logging_conf_intervals': logging_conf_intervals,
            'ours_gamma_means': ours_gamma_means,
            'ours_gamma_conf_intervals': ours_gamma_conf_intervals,
            'opt_gamma_means': opt_gamma_means,
            'opt_gamma_conf_intervals': opt_gamma_conf_intervals,
            'rips_means': rips_means,
            'rips_conf_intervals': rips_conf_intervals,
            'sips_means': sips_means,
            'sips_conf_intervals': sips_conf_intervals,
            'rdm_means': rdm_means,
            'rdm_conf_intervals': rdm_conf_intervals,
            'sdm_means': sdm_means,
            'sdm_conf_intervals': sdm_conf_intervals
        })
        results.to_csv(f'{results_data_dir}/{save_file_name}.csv', index=False)

        # Surrogate ("_s") policy values; every CI column pairs with its own
        # surrogate-specific interval.
        results_s = pd.DataFrame({
            'list': list_,
            'parips_s_means': parips_s_means,
            'parips_s_conf_intervals': parips_s_conf_intervals,
            'surips_s_means': surips_s_means,
            'surips_s_conf_intervals': surips_s_conf_intervals,
            'sdr_s_means': sdr_s_means,
            'sdr_s_conf_intervals': sdr_s_conf_intervals,
            'sdr_both_s_means': sdr_both_s_means,
            'sdr_both_s_conf_intervals': sdr_both_s_conf_intervals,
            'logging_s_means': logging_s_means,
            'logging_s_conf_intervals': logging_s_conf_intervals,
            'ours_gamma_s_means': ours_gamma_s_means,
            'ours_gamma_s_conf_intervals': ours_gamma_s_conf_intervals,
            'opt_gamma_s_means': opt_gamma_s_means,
            'opt_gamma_s_conf_intervals': opt_gamma_s_conf_intervals,
            'rips_s_means': rips_s_means,
            'rips_s_conf_intervals': rips_s_conf_intervals,
            'sips_s_means': sips_s_means,
            'sips_s_conf_intervals': sips_s_conf_intervals,
            'rdm_s_means': rdm_s_means,
            'rdm_s_conf_intervals': rdm_s_conf_intervals,
            'sdm_s_means': sdm_s_means,
            'sdm_s_conf_intervals': sdm_s_conf_intervals
        })
        results_s.to_csv(f'{results_data_dir}/{save_file_name}-s.csv', index=False)

        # Combined ("_full") policy values; every CI column pairs with its own
        # full-value interval.
        results_full = pd.DataFrame({
            'list': list_,
            'parips_full_means': parips_full_means,
            'parips_full_conf_intervals': parips_full_conf_intervals,
            'surips_full_means': surips_full_means,
            'surips_full_conf_intervals': surips_full_conf_intervals,
            'sdr_full_means': sdr_full_means,
            'sdr_full_conf_intervals': sdr_full_conf_intervals,
            'sdr_both_full_means': sdr_both_full_means,
            'sdr_both_full_conf_intervals': sdr_both_full_conf_intervals,
            'logging_full_means': logging_full_means,
            'logging_full_conf_intervals': logging_full_conf_intervals,
            'ours_gamma_full_means': ours_gamma_full_means,
            'ours_gamma_full_conf_intervals': ours_gamma_full_conf_intervals,
            'opt_gamma_full_means': opt_gamma_full_means,
            'opt_gamma_full_conf_intervals': opt_gamma_full_conf_intervals,
            'rips_full_means': rips_full_means,
            'rips_full_conf_intervals': rips_full_conf_intervals,
            'sips_full_means': sips_full_means,
            'sips_full_conf_intervals': sips_full_conf_intervals,
            'rdm_full_means': rdm_full_means,
            'rdm_full_conf_intervals': rdm_full_conf_intervals,
            'sdm_full_means': sdm_full_means,
            'sdm_full_conf_intervals': sdm_full_conf_intervals
        })
        results_full.to_csv(f'{results_data_dir}/{save_file_name}-full.csv', index=False)


        # Per-beta averages of the accumulated `gammas` and `gamma_trues` rows.
        gamma_list_means = gamma_list.mean(axis=0)
        gamma_true_list_means = gamma_true_list.mean(axis=0)
        gamma_list_df = pd.DataFrame({'beta': list_, 'gamma_list_means': gamma_list_means})
        gamma_list_df.to_csv(f'{results_data_dir}/{save_file_name}-gamma.csv', index=False)
        gamma_true_list_df = pd.DataFrame({'beta': list_, 'gamma_true_list_means': gamma_true_list_means})
        gamma_true_list_df.to_csv(f'{results_data_dir}/{save_file_name}-gamma-true.csv', index=False)


        # Draw one figure per evaluation target (target reward, surrogate,
        # combined). Every figure shows the same methods in the same order and
        # colors: a mean line plus a shaded 95% CI band over the beta values.
        plot_styles = [
            (r'Ours(Tuned $\hat{\gamma}^*$)', 'turquoise'),
            (r'Ours($\gamma = \beta$)', 'green'),
            (r'Ours(Optimal $\gamma^*$)', 'hotpink'),
            ('r-ParDR', 'blue'),
            ('r-SurDR', 'red'),
            ('r-TwoDR', 'orange'),
            ('Logging', 'saddlebrown'),
            ('r-IPS', 'purple'),
            ('s-IPS', 'cyan'),
            ('r-DM', 'magenta'),
            ('s-DM', 'lime'),
        ]
        # (y-axis label, output-file suffix, [(means, CI half-widths)] aligned with plot_styles)
        plot_panels = [
            ('Relative Target Reward Policy Value', '', [
                (ours_gamma_means, ours_gamma_conf_intervals),
                (sdr_both_means, sdr_both_conf_intervals),
                (opt_gamma_means, opt_gamma_conf_intervals),
                (parips_means, parips_conf_intervals),
                (surips_means, surips_conf_intervals),
                (sdr_means, sdr_conf_intervals),
                (logging_means, logging_conf_intervals),
                (rips_means, rips_conf_intervals),
                (sips_means, sips_conf_intervals),
                (rdm_means, rdm_conf_intervals),
                (sdm_means, sdm_conf_intervals),
            ]),
            ('Relative Secondary Policy Value', '-s', [
                (ours_gamma_s_means, ours_gamma_s_conf_intervals),
                (sdr_both_s_means, sdr_both_s_conf_intervals),
                (opt_gamma_s_means, opt_gamma_s_conf_intervals),
                (parips_s_means, parips_s_conf_intervals),
                (surips_s_means, surips_s_conf_intervals),
                (sdr_s_means, sdr_s_conf_intervals),
                (logging_s_means, logging_s_conf_intervals),
                (rips_s_means, rips_s_conf_intervals),
                (sips_s_means, sips_s_conf_intervals),
                (rdm_s_means, rdm_s_conf_intervals),
                (sdm_s_means, sdm_s_conf_intervals),
            ]),
            ('Relative Combined Policy Value', '-full', [
                (ours_gamma_full_means, ours_gamma_full_conf_intervals),
                (sdr_both_full_means, sdr_both_full_conf_intervals),
                (opt_gamma_full_means, opt_gamma_full_conf_intervals),
                (parips_full_means, parips_full_conf_intervals),
                (surips_full_means, surips_full_conf_intervals),
                (sdr_full_means, sdr_full_conf_intervals),
                (logging_full_means, logging_full_conf_intervals),
                (rips_full_means, rips_full_conf_intervals),
                (sips_full_means, sips_full_conf_intervals),
                (rdm_full_means, rdm_full_conf_intervals),
                (sdm_full_means, sdm_full_conf_intervals),
            ]),
        ]

        for plot_ylabel, plot_suffix, plot_series in plot_panels:
            plt.figure(figsize=(10, 6))
            for (plot_label, plot_color), (plot_mean, plot_ci) in zip(plot_styles, plot_series):
                plt.plot(list_, plot_mean, label=plot_label, color=plot_color)
                plt.fill_between(list_, plot_mean - plot_ci, plot_mean + plot_ci, color=plot_color, alpha=0.2)
            plt.xlabel('Beta Value', fontsize=18)
            plt.ylabel(plot_ylabel, fontsize=18)
            plt.title(f'{plot_ylabel} vs Beta Value', fontsize=18)
            plt.legend(fontsize=16)
            plt.xticks(fontsize=16)
            plt.yticks(fontsize=16)
            plt.savefig(f'{base_path}/results/realdata/{exp_name}/{save_file_name}{plot_suffix}.png')
            # These plots are regenerated on every simulation; close each figure
            # so hundreds of figures do not accumulate in memory.
            plt.close()