from utils.datasets import expected_reward_function
from utils.datasets_trunc_to_norm import SyntheticBanditDatasetWithSurrogate
from obp.policy import NNPolicyLearner
from utils.policy import OurLearner
import pandas as pd
from tqdm import tqdm
import matplotlib.pyplot as plt
import torch
import numpy as np
import os
from obp.policy import Random
import time

# ---------------------------------------------------------------------------
# Experiment configuration
# ---------------------------------------------------------------------------
exp_name = 'exp1'
base_path = "../../"

# Report whether CUDA is visible and select the matching torch device.
_cuda_ok = torch.cuda.is_available()
print(_cuda_ok)
device = torch.device("cuda" if _cuda_ok else "cpu")

# Simulation loop: number of repetitions and the p_o values swept in each one.
n_sims = 150
p_o_list = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0]

# Synthetic dataset settings (forwarded to SyntheticBanditDatasetWithSurrogate).
n_actions = 10
dim_context = 10
reward_type = "continuous"
reward_std = 0.5
beta_data = -2
lambda_ = 0.7
s_noise = 0.5
s_dim = 5
alpha_noise = 0.4
random_state = 12350

# Sizes of the training log and of the test log used for scoring.
n_rounds = 1000
n_rounds_test = 100000

# Learner / evaluation settings.
beta = 0.3
hidden_layer_size = (150, )
batch_size = 64
num_tries = 4
num_gamma = 5

# Stem shared by the CSV files written for this configuration.
save_file_name = f"al{alpha_noise}_b{beta}_lam{lambda_}_n{n_rounds}_noise{s_noise}"

with tqdm(total=n_sims, desc='Outer', position=0) as outer_bar:
    start_time = time.time()
    for sim in range(n_sims):
        # Fresh per-simulation buffers, one (results, surrogate, full) triple
        # per method.  The random_* lists are created but never filled because
        # the random-baseline code further down is commented out.
        parips_results, parips_surrogate, parips_full = [], [], []
        surips_results, surips_surrogate, surips_full = [], [], []
        sdr_results, sdr_surrogate, sdr_full = [], [], []
        sdr_both_results, sdr_both_surrogate, sdr_both_full = [], [], []
        random_results, random_surrogate, random_full = [], [], []
        logging_results, logging_surrogate, logging_full = [], [], []
        rips_results, rips_surrogate, rips_full = [], [], []
        sips_results, sips_surrogate, sips_full = [], [], []
        rdm_results, rdm_surrogate, rdm_full = [], [], []
        sdm_results, sdm_surrogate, sdm_full = [], [], []

        
        for p_o in p_o_list:
            # Build the synthetic dataset for this (simulation, p_o) pair and
            # draw a training log followed by a larger test log from it.
            dataset_kwargs = dict(
                n_actions=n_actions,
                dim_context=dim_context,
                reward_type=reward_type,
                reward_std=reward_std,
                beta=beta_data,
                lambda_=lambda_,
                s_noise=s_noise,
                s_dim=s_dim,
                p_o=p_o,
                alpha_noise=alpha_noise,
                random_state=random_state + sim,
            )
            dataset = SyntheticBanditDatasetWithSurrogate(**dataset_kwargs)
            bandit_feedback_train = dataset.obtain_batch_bandit_feedback(n_rounds=n_rounds)
            bandit_feedback_test = dataset.obtain_batch_bandit_feedback(n_rounds=n_rounds_test)
            # "sdr" arm: OurLearner with the 'sdr-both' objective, fitted with
            # beta = 0.0.  When p_o is exactly 0.0 no learner is fitted and
            # obp's Random policy is used in its place.
            if p_o != 0.0:
                sdr_policy = OurLearner(
                    n_actions=n_actions,
                    s_dim=s_dim,
                    dim_context=dim_context,
                    hidden_layer_size=hidden_layer_size,
                    early_stopping=True,
                    activation='relu',
                    solver='adam',
                    off_policy_objective='sdr-both',
                    batch_size=64,
                    random_state=random_state + sim,
                )
                sdr_fit_inputs = {
                    "context": bandit_feedback_train["contexts"],
                    "action": bandit_feedback_train["actions"],
                    "surrogate_reward": bandit_feedback_train["surrogate_rewards"],
                    "reward": bandit_feedback_train["rewards"],
                    "obs_list": bandit_feedback_train["obs_list"],
                    "pscore": bandit_feedback_train["pscores"],
                    "s_sum": bandit_feedback_train["s_sum"],
                }
                sdr_policy.fit(p_o=p_o, beta=0.0, **sdr_fit_inputs)
                sdr_policy_action_dist = sdr_policy.predict(
                    context=bandit_feedback_test["contexts"],
                )
            else:
                sdr_policy = Random(n_actions=n_actions, random_state=random_state + sim)
                sdr_policy_action_dist = sdr_policy.compute_batch_action_dist(
                    n_rounds=n_rounds_test,
                )

            # Score the policy on the test log: against "all_q_x_a_f", against
            # "f_sum", and with calc_full_policy_value using the global beta.
            sdr_policy_value = dataset.calc_ground_truth_policy_value(
                expected_reward=bandit_feedback_test["all_q_x_a_f"],
                action_dist=sdr_policy_action_dist,
            )
            sdr_surrogate_value = dataset.calc_ground_truth_policy_value(
                expected_reward=bandit_feedback_test["f_sum"],
                action_dist=sdr_policy_action_dist,
            )
            sdr_full_value = dataset.calc_full_policy_value(
                expected_reward=bandit_feedback_test["all_q_x_a_f"],
                expected_surrogate_reward=bandit_feedback_test["f_sum"],
                action_dist=sdr_policy_action_dist,
                beta=beta,
            )
            sdr_results.append(sdr_policy_value)
            sdr_surrogate.append(sdr_surrogate_value)
            sdr_full.append(sdr_full_value)
            
            # "sdr_both" arm: same learner configuration as the "sdr" arm but
            # fitted with the global beta; it is trained for every p_o,
            # including 0.0.
            sdr_both_learner_kwargs = dict(
                n_actions=n_actions,
                s_dim=s_dim,
                dim_context=dim_context,
                hidden_layer_size=hidden_layer_size,
                early_stopping=True,
                activation='relu',
                solver='adam',
                off_policy_objective='sdr-both',
                batch_size=64,
                random_state=random_state + sim,
            )
            sdr_both_policy = OurLearner(**sdr_both_learner_kwargs)
            sdr_both_fit_inputs = {
                "context": bandit_feedback_train["contexts"],
                "action": bandit_feedback_train["actions"],
                "surrogate_reward": bandit_feedback_train["surrogate_rewards"],
                "reward": bandit_feedback_train["rewards"],
                "obs_list": bandit_feedback_train["obs_list"],
                "pscore": bandit_feedback_train["pscores"],
                "s_sum": bandit_feedback_train["s_sum"],
            }
            sdr_both_policy.fit(p_o=p_o, beta=beta, **sdr_both_fit_inputs)
            sdr_both_policy_action_dist = sdr_both_policy.predict(
                context=bandit_feedback_test["contexts"],
            )

            # Score on the test log (same three metrics as every other arm).
            sdr_both_policy_value = dataset.calc_ground_truth_policy_value(
                expected_reward=bandit_feedback_test["all_q_x_a_f"],
                action_dist=sdr_both_policy_action_dist,
            )
            sdr_both_surrogate_value = dataset.calc_ground_truth_policy_value(
                expected_reward=bandit_feedback_test["f_sum"],
                action_dist=sdr_both_policy_action_dist,
            )
            sdr_both_full_value = dataset.calc_full_policy_value(
                expected_reward=bandit_feedback_test["all_q_x_a_f"],
                expected_surrogate_reward=bandit_feedback_test["f_sum"],
                action_dist=sdr_both_policy_action_dist,
                beta=beta,
            )
            sdr_both_results.append(sdr_both_policy_value)
            sdr_both_surrogate.append(sdr_both_surrogate_value)
            sdr_both_full.append(sdr_both_full_value)
            
            # random_policy = Random(n_actions=n_actions, random_state=random_state+sim)
            # random_policy_action_dist = random_policy.compute_batch_action_dist(
            #     n_rounds=n_rounds_test,
            # )
            # random_policy_value = dataset.calc_ground_truth_policy_value(
            #     expected_reward=bandit_feedback_test["all_q_x_a_f"],
            #     action_dist=random_policy_action_dist
            # )
            # random_surrogate_value = dataset.calc_ground_truth_policy_value(
            #     expected_reward = bandit_feedback_test["f_sum"],
            #     action_dist=random_policy_action_dist
            # )
            # random_results.append(random_policy_value)
            # random_surrogate.append(random_surrogate_value)
            
            # Baseline: score the behaviour policy stored under "pi_b" in the
            # test log; nothing is trained for this arm.
            logging_policy = bandit_feedback_test["pi_b"]
            logging_policy_value = dataset.calc_ground_truth_policy_value(
                expected_reward=bandit_feedback_test["all_q_x_a_f"],
                action_dist=logging_policy,
            )
            logging_results.append(logging_policy_value)

            logging_surrogate_value = dataset.calc_ground_truth_policy_value(
                expected_reward=bandit_feedback_test["f_sum"],
                action_dist=logging_policy,
            )
            logging_surrogate.append(logging_surrogate_value)

            logging_full_value = dataset.calc_full_policy_value(
                expected_reward=bandit_feedback_test["all_q_x_a_f"],
                expected_surrogate_reward=bandit_feedback_test["f_sum"],
                action_dist=logging_policy,
                beta=beta,
            )
            logging_full.append(logging_full_value)
            
            # "parips" arm: NNPolicyLearner with the 'dr' objective, fitted on
            # the obs_* arrays of the training log (presumably the rounds whose
            # reward was observed -- confirm against the dataset class).
            # For p_o == 0.0 obp's Random policy is used instead.
            if p_o != 0.0:
                parips_policy = NNPolicyLearner(
                    n_actions=n_actions,
                    dim_context=dim_context,
                    hidden_layer_size=hidden_layer_size,
                    early_stopping=True,
                    activation='relu',
                    solver='adam',
                    off_policy_objective='dr',
                    batch_size=64,
                    random_state=random_state + sim,
                )
                parips_fit_inputs = {
                    "context": bandit_feedback_train["obs_contexts"],
                    "action": bandit_feedback_train["obs_actions"],
                    "reward": bandit_feedback_train["obs_rewards"],
                    "pscore": bandit_feedback_train["obs_pscores"],
                }
                parips_policy.fit(**parips_fit_inputs)
                parips_policy_action_dist = parips_policy.predict(
                    context=bandit_feedback_test["contexts"],
                )
            else:
                parips_policy = Random(n_actions=n_actions, random_state=random_state + sim)
                parips_policy_action_dist = parips_policy.compute_batch_action_dist(
                    n_rounds=n_rounds_test,
                )

            # Score on the test log.
            parips_policy_value = dataset.calc_ground_truth_policy_value(
                expected_reward=bandit_feedback_test["all_q_x_a_f"],
                action_dist=parips_policy_action_dist,
            )
            parips_surrogate_value = dataset.calc_ground_truth_policy_value(
                expected_reward=bandit_feedback_test["f_sum"],
                action_dist=parips_policy_action_dist,
            )
            parips_full_value = dataset.calc_full_policy_value(
                expected_reward=bandit_feedback_test["all_q_x_a_f"],
                expected_surrogate_reward=bandit_feedback_test["f_sum"],
                action_dist=parips_policy_action_dist,
                beta=beta,
            )
            parips_results.append(parips_policy_value)
            parips_surrogate.append(parips_surrogate_value)
            parips_full.append(parips_full_value)
            
            # "surips" arm: NNPolicyLearner with the 'dr' objective, fitted on
            # the full training log with "f_s" passed as the reward.
            surips_learner_kwargs = dict(
                n_actions=n_actions,
                dim_context=dim_context,
                hidden_layer_size=hidden_layer_size,
                early_stopping=True,
                activation='relu',
                solver='adam',
                off_policy_objective='dr',
                batch_size=64,
                random_state=random_state + sim,
            )
            surips_policy = NNPolicyLearner(**surips_learner_kwargs)
            surips_fit_inputs = {
                "context": bandit_feedback_train["contexts"],
                "action": bandit_feedback_train["actions"],
                "reward": bandit_feedback_train["f_s"],
                "pscore": bandit_feedback_train["pscores"],
            }
            surips_policy.fit(**surips_fit_inputs)
            surips_policy_action_dist = surips_policy.predict(
                context=bandit_feedback_test["contexts"],
            )

            # Score on the test log.
            surips_policy_value = dataset.calc_ground_truth_policy_value(
                expected_reward=bandit_feedback_test["all_q_x_a_f"],
                action_dist=surips_policy_action_dist,
            )
            surips_surrogate_value = dataset.calc_ground_truth_policy_value(
                expected_reward=bandit_feedback_test["f_sum"],
                action_dist=surips_policy_action_dist,
            )
            surips_full_value = dataset.calc_full_policy_value(
                expected_reward=bandit_feedback_test["all_q_x_a_f"],
                expected_surrogate_reward=bandit_feedback_test["f_sum"],
                action_dist=surips_policy_action_dist,
                beta=beta,
            )
            surips_results.append(surips_policy_value)
            surips_surrogate.append(surips_surrogate_value)
            surips_full.append(surips_full_value)
            
            # "rips" arm: NNPolicyLearner with the 'ipw' objective, fitted on
            # the obs_* arrays of the training log.  For p_o == 0.0 obp's
            # Random policy is used instead.
            if p_o != 0.0:
                rips_policy = NNPolicyLearner(
                    n_actions=n_actions,
                    dim_context=dim_context,
                    hidden_layer_size=hidden_layer_size,
                    early_stopping=True,
                    activation='relu',
                    solver='adam',
                    off_policy_objective='ipw',
                    batch_size=64,
                    random_state=random_state + sim,
                )
                rips_fit_inputs = {
                    "context": bandit_feedback_train["obs_contexts"],
                    "action": bandit_feedback_train["obs_actions"],
                    "reward": bandit_feedback_train["obs_rewards"],
                    "pscore": bandit_feedback_train["obs_pscores"],
                }
                rips_policy.fit(**rips_fit_inputs)
                rips_policy_action_dist = rips_policy.predict(
                    context=bandit_feedback_test["contexts"],
                )
            else:
                rips_policy = Random(n_actions=n_actions, random_state=random_state + sim)
                rips_policy_action_dist = rips_policy.compute_batch_action_dist(
                    n_rounds=n_rounds_test,
                )

            # Score on the test log.
            rips_policy_value = dataset.calc_ground_truth_policy_value(
                expected_reward=bandit_feedback_test["all_q_x_a_f"],
                action_dist=rips_policy_action_dist,
            )
            rips_surrogate_value = dataset.calc_ground_truth_policy_value(
                expected_reward=bandit_feedback_test["f_sum"],
                action_dist=rips_policy_action_dist,
            )
            rips_full_value = dataset.calc_full_policy_value(
                expected_reward=bandit_feedback_test["all_q_x_a_f"],
                expected_surrogate_reward=bandit_feedback_test["f_sum"],
                action_dist=rips_policy_action_dist,
                beta=beta,
            )
            rips_results.append(rips_policy_value)
            rips_surrogate.append(rips_surrogate_value)
            rips_full.append(rips_full_value)
            
            # "rdm" arm: NNPolicyLearner with the 'dm' objective, fitted on the
            # obs_* arrays of the training log.  For p_o == 0.0 obp's Random
            # policy is used instead.
            if p_o != 0.0:
                rdm_policy = NNPolicyLearner(
                    n_actions=n_actions,
                    dim_context=dim_context,
                    hidden_layer_size=hidden_layer_size,
                    early_stopping=True,
                    activation='relu',
                    solver='adam',
                    off_policy_objective='dm',
                    batch_size=64,
                    random_state=random_state + sim,
                )
                rdm_fit_inputs = {
                    "context": bandit_feedback_train["obs_contexts"],
                    "action": bandit_feedback_train["obs_actions"],
                    "reward": bandit_feedback_train["obs_rewards"],
                    "pscore": bandit_feedback_train["obs_pscores"],
                }
                rdm_policy.fit(**rdm_fit_inputs)
                rdm_policy_action_dist = rdm_policy.predict(
                    context=bandit_feedback_test["contexts"],
                )
            else:
                rdm_policy = Random(n_actions=n_actions, random_state=random_state + sim)
                rdm_policy_action_dist = rdm_policy.compute_batch_action_dist(
                    n_rounds=n_rounds_test,
                )

            # Score on the test log.
            rdm_policy_value = dataset.calc_ground_truth_policy_value(
                expected_reward=bandit_feedback_test["all_q_x_a_f"],
                action_dist=rdm_policy_action_dist,
            )
            rdm_surrogate_value = dataset.calc_ground_truth_policy_value(
                expected_reward=bandit_feedback_test["f_sum"],
                action_dist=rdm_policy_action_dist,
            )
            rdm_full_value = dataset.calc_full_policy_value(
                expected_reward=bandit_feedback_test["all_q_x_a_f"],
                expected_surrogate_reward=bandit_feedback_test["f_sum"],
                action_dist=rdm_policy_action_dist,
                beta=beta,
            )
            rdm_results.append(rdm_policy_value)
            rdm_surrogate.append(rdm_surrogate_value)
            rdm_full.append(rdm_full_value)
            
            # "sips" arm: NNPolicyLearner with the 'ipw' objective, fitted on
            # the full training log with "f_s" passed as the reward.
            sips_learner_kwargs = dict(
                n_actions=n_actions,
                dim_context=dim_context,
                hidden_layer_size=hidden_layer_size,
                early_stopping=True,
                activation='relu',
                solver='adam',
                off_policy_objective='ipw',
                batch_size=64,
                random_state=random_state + sim,
            )
            sips_policy = NNPolicyLearner(**sips_learner_kwargs)
            sips_fit_inputs = {
                "context": bandit_feedback_train["contexts"],
                "action": bandit_feedback_train["actions"],
                "reward": bandit_feedback_train["f_s"],
                "pscore": bandit_feedback_train["pscores"],
            }
            sips_policy.fit(**sips_fit_inputs)
            sips_policy_action_dist = sips_policy.predict(
                context=bandit_feedback_test["contexts"],
            )

            # Score on the test log.
            sips_policy_value = dataset.calc_ground_truth_policy_value(
                expected_reward=bandit_feedback_test["all_q_x_a_f"],
                action_dist=sips_policy_action_dist,
            )
            sips_surrogate_value = dataset.calc_ground_truth_policy_value(
                expected_reward=bandit_feedback_test["f_sum"],
                action_dist=sips_policy_action_dist,
            )
            sips_full_value = dataset.calc_full_policy_value(
                expected_reward=bandit_feedback_test["all_q_x_a_f"],
                expected_surrogate_reward=bandit_feedback_test["f_sum"],
                action_dist=sips_policy_action_dist,
                beta=beta,
            )
            sips_results.append(sips_policy_value)
            sips_surrogate.append(sips_surrogate_value)
            sips_full.append(sips_full_value)
            
            # "sdm" arm: NNPolicyLearner with the 'dm' objective, fitted on the
            # full training log with "f_s" passed as the reward.
            sdm_learner_kwargs = dict(
                n_actions=n_actions,
                dim_context=dim_context,
                hidden_layer_size=hidden_layer_size,
                early_stopping=True,
                activation='relu',
                solver='adam',
                off_policy_objective='dm',
                batch_size=64,
                random_state=random_state + sim,
            )
            sdm_policy = NNPolicyLearner(**sdm_learner_kwargs)
            sdm_fit_inputs = {
                "context": bandit_feedback_train["contexts"],
                "action": bandit_feedback_train["actions"],
                "reward": bandit_feedback_train["f_s"],
                "pscore": bandit_feedback_train["pscores"],
            }
            sdm_policy.fit(**sdm_fit_inputs)
            sdm_policy_action_dist = sdm_policy.predict(
                context=bandit_feedback_test["contexts"],
            )

            # Score on the test log.
            sdm_policy_value = dataset.calc_ground_truth_policy_value(
                expected_reward=bandit_feedback_test["all_q_x_a_f"],
                action_dist=sdm_policy_action_dist,
            )
            sdm_surrogate_value = dataset.calc_ground_truth_policy_value(
                expected_reward=bandit_feedback_test["f_sum"],
                action_dist=sdm_policy_action_dist,
            )
            sdm_full_value = dataset.calc_full_policy_value(
                expected_reward=bandit_feedback_test["all_q_x_a_f"],
                expected_surrogate_reward=bandit_feedback_test["f_sum"],
                action_dist=sdm_policy_action_dist,
                beta=beta,
            )
            sdm_results.append(sdm_policy_value)
            sdm_surrogate.append(sdm_surrogate_value)
            sdm_full.append(sdm_full_value)

        # Fold this simulation's per-p_o values into the running *_all_* tables.
        # On the first simulation the table is simply this simulation's list;
        # afterwards the new row is stacked underneath with np.vstack, giving
        # one row per completed simulation.  The random baseline is disabled,
        # so no random_all_* tables are maintained.
        first_sim = sim == 0

        parips_all_results = parips_results if first_sim else np.vstack([parips_all_results, parips_results])
        parips_all_surrogate = parips_surrogate if first_sim else np.vstack([parips_all_surrogate, parips_surrogate])
        parips_all_full = parips_full if first_sim else np.vstack([parips_all_full, parips_full])

        surips_all_results = surips_results if first_sim else np.vstack([surips_all_results, surips_results])
        surips_all_surrogate = surips_surrogate if first_sim else np.vstack([surips_all_surrogate, surips_surrogate])
        surips_all_full = surips_full if first_sim else np.vstack([surips_all_full, surips_full])

        sdr_all_results = sdr_results if first_sim else np.vstack([sdr_all_results, sdr_results])
        sdr_all_surrogate = sdr_surrogate if first_sim else np.vstack([sdr_all_surrogate, sdr_surrogate])
        sdr_all_full = sdr_full if first_sim else np.vstack([sdr_all_full, sdr_full])

        sdr_both_all_results = sdr_both_results if first_sim else np.vstack([sdr_both_all_results, sdr_both_results])
        sdr_both_all_surrogate = sdr_both_surrogate if first_sim else np.vstack([sdr_both_all_surrogate, sdr_both_surrogate])
        sdr_both_all_full = sdr_both_full if first_sim else np.vstack([sdr_both_all_full, sdr_both_full])

        logging_all_results = logging_results if first_sim else np.vstack([logging_all_results, logging_results])
        logging_all_surrogate = logging_surrogate if first_sim else np.vstack([logging_all_surrogate, logging_surrogate])
        logging_all_full = logging_full if first_sim else np.vstack([logging_all_full, logging_full])

        rips_all_results = rips_results if first_sim else np.vstack([rips_all_results, rips_results])
        rips_all_surrogate = rips_surrogate if first_sim else np.vstack([rips_all_surrogate, rips_surrogate])
        rips_all_full = rips_full if first_sim else np.vstack([rips_all_full, rips_full])

        sips_all_results = sips_results if first_sim else np.vstack([sips_all_results, sips_results])
        sips_all_surrogate = sips_surrogate if first_sim else np.vstack([sips_all_surrogate, sips_surrogate])
        sips_all_full = sips_full if first_sim else np.vstack([sips_all_full, sips_full])

        rdm_all_results = rdm_results if first_sim else np.vstack([rdm_all_results, rdm_results])
        rdm_all_surrogate = rdm_surrogate if first_sim else np.vstack([rdm_all_surrogate, rdm_surrogate])
        rdm_all_full = rdm_full if first_sim else np.vstack([rdm_all_full, rdm_full])

        sdm_all_results = sdm_results if first_sim else np.vstack([sdm_all_results, sdm_results])
        sdm_all_surrogate = sdm_surrogate if first_sim else np.vstack([sdm_all_surrogate, sdm_surrogate])
        sdm_all_full = sdm_full if first_sim else np.vstack([sdm_all_full, sdm_full])
        # Advance the progress bar and print elapsed time plus a remaining-time
        # estimate extrapolated from the average time per finished simulation.
        outer_bar.update(1)
        sims_done = sim + 1
        elapsed_time = time.time() - start_time
        remaining_time = (elapsed_time / sims_done) * (n_sims - sims_done)
        elapsed_label = tqdm.format_interval(elapsed_time)
        remaining_label = tqdm.format_interval(remaining_time)
        tqdm.write(f"\rSimulations: {outer_bar.n}/{outer_bar.total} - "
                   f"{elapsed_label}<{remaining_label}")

        if sim<2:
            continue
        # ------------------------------------------------------------------
        # Summary statistics over the simulations completed so far.
        #
        # For every method and metric this produces four arrays indexed by
        # p_o: the mean, the standard deviation, the standard error and the
        # half-width of the normal-approximation 95% confidence interval.
        #
        # The standard error is divided by the number of rows actually present
        # in the accumulated table, not by n_sims.  These statistics are
        # recomputed (and saved) after every simulation, so dividing by n_sims
        # made every intermediate checkpoint report intervals that were too
        # narrow.  After the last simulation the row count equals n_sims, so
        # the final output is unchanged.
        # ------------------------------------------------------------------
        def _summarize(all_values, z_value):
            """Return (means, stds, stes, ci_half_widths) over axis 0.

            ``all_values`` is a 2-D array with one row per completed
            simulation; ``z_value`` is the normal quantile multiplied into
            the standard error to obtain the confidence half-width.
            """
            n_completed = all_values.shape[0]
            means = all_values.mean(axis=0)
            stds = all_values.std(axis=0)
            stes = stds / np.sqrt(n_completed)
            return means, stds, stes, stes * z_value

        # z-value for a two-sided 95% confidence interval.
        confidence = 1.96

        # Metric: value against "all_q_x_a_f".
        parips_means, parips_stds, parips_stes, parips_conf_intervals = _summarize(parips_all_results, confidence)
        surips_means, surips_stds, surips_stes, surips_conf_intervals = _summarize(surips_all_results, confidence)
        sdr_means, sdr_stds, sdr_stes, sdr_conf_intervals = _summarize(sdr_all_results, confidence)
        sdr_both_means, sdr_both_stds, sdr_both_stes, sdr_both_conf_intervals = _summarize(sdr_both_all_results, confidence)
        logging_means, logging_stds, logging_stes, logging_conf_intervals = _summarize(logging_all_results, confidence)
        rips_means, rips_stds, rips_stes, rips_conf_intervals = _summarize(rips_all_results, confidence)
        sips_means, sips_stds, sips_stes, sips_conf_intervals = _summarize(sips_all_results, confidence)
        rdm_means, rdm_stds, rdm_stes, rdm_conf_intervals = _summarize(rdm_all_results, confidence)
        sdm_means, sdm_stds, sdm_stes, sdm_conf_intervals = _summarize(sdm_all_results, confidence)

        # Metric: value against "f_sum" (surrogate).
        parips_s_means, parips_s_stds, parips_s_stes, parips_s_conf_intervals = _summarize(parips_all_surrogate, confidence)
        surips_s_means, surips_s_stds, surips_s_stes, surips_s_conf_intervals = _summarize(surips_all_surrogate, confidence)
        sdr_s_means, sdr_s_stds, sdr_s_stes, sdr_s_conf_intervals = _summarize(sdr_all_surrogate, confidence)
        sdr_both_s_means, sdr_both_s_stds, sdr_both_s_stes, sdr_both_s_conf_intervals = _summarize(sdr_both_all_surrogate, confidence)
        logging_s_means, logging_s_stds, logging_s_stes, logging_s_conf_intervals = _summarize(logging_all_surrogate, confidence)
        rips_s_means, rips_s_stds, rips_s_stes, rips_s_conf_intervals = _summarize(rips_all_surrogate, confidence)
        sips_s_means, sips_s_stds, sips_s_stes, sips_s_conf_intervals = _summarize(sips_all_surrogate, confidence)
        rdm_s_means, rdm_s_stds, rdm_s_stes, rdm_s_conf_intervals = _summarize(rdm_all_surrogate, confidence)
        sdm_s_means, sdm_s_stds, sdm_s_stes, sdm_s_conf_intervals = _summarize(sdm_all_surrogate, confidence)

        # Metric: calc_full_policy_value with the global beta.
        parips_full_means, parips_full_stds, parips_full_stes, parips_full_conf_intervals = _summarize(parips_all_full, confidence)
        surips_full_means, surips_full_stds, surips_full_stes, surips_full_conf_intervals = _summarize(surips_all_full, confidence)
        sdr_full_means, sdr_full_stds, sdr_full_stes, sdr_full_conf_intervals = _summarize(sdr_all_full, confidence)
        sdr_both_full_means, sdr_both_full_stds, sdr_both_full_stes, sdr_both_full_conf_intervals = _summarize(sdr_both_all_full, confidence)
        logging_full_means, logging_full_stds, logging_full_stes, logging_full_conf_intervals = _summarize(logging_all_full, confidence)
        rips_full_means, rips_full_stds, rips_full_stes, rips_full_conf_intervals = _summarize(rips_all_full, confidence)
        sips_full_means, sips_full_stds, sips_full_stes, sips_full_conf_intervals = _summarize(sips_all_full, confidence)
        rdm_full_means, rdm_full_stds, rdm_full_stes, rdm_full_conf_intervals = _summarize(rdm_all_full, confidence)
        sdm_full_means, sdm_full_stds, sdm_full_stes, sdm_full_conf_intervals = _summarize(sdm_all_full, confidence)

        # Row labels for the CSV files: the swept p_o values as strings
        # (a value of 0.001 would be written as "0.0").
        list_ = [str(0.0) if i == 0.001 else str(i) for i in p_o_list]

        # Write the running summaries to CSV; the files are overwritten after
        # every simulation.  exist_ok=True replaces the former
        # exists()/makedirs() pair, so directory creation is a single
        # race-free call.
        os.makedirs(f'../../results/synthetic/{exp_name}/data', exist_ok=True)

        # Means and 95% confidence half-widths for the "all_q_x_a_f" metric.
        results = pd.DataFrame({
            'list_': list_,
            'parips_means': parips_means,
            'parips_conf_intervals': parips_conf_intervals,
            'surips_means': surips_means,
            'surips_conf_intervals': surips_conf_intervals,
            'sdr_means': sdr_means,
            'sdr_conf_intervals': sdr_conf_intervals,
            'sdr_both_means': sdr_both_means,
            'sdr_both_conf_intervals': sdr_both_conf_intervals,
            'logging_means': logging_means,
            'logging_conf_intervals': logging_conf_intervals,
            'rips_means': rips_means,
            'rips_conf_intervals': rips_conf_intervals,
            'sips_means': sips_means,
            'sips_conf_intervals': sips_conf_intervals,
            'rdm_means': rdm_means,
            'rdm_conf_intervals': rdm_conf_intervals,
            'sdm_means': sdm_means,
            'sdm_conf_intervals': sdm_conf_intervals,
        })
        results.to_csv(f'../../results/synthetic/{exp_name}/data/{save_file_name}.csv', index=False)

        # Same layout for the surrogate ("f_sum") metric, saved with a "-s" suffix.
        results_s = pd.DataFrame({
            'list_': list_,
            'parips_s_means': parips_s_means,
            'parips_s_conf_intervals': parips_s_conf_intervals,
            'surips_s_means': surips_s_means,
            'surips_s_conf_intervals': surips_s_conf_intervals,
            'sdr_s_means': sdr_s_means,
            'sdr_s_conf_intervals': sdr_s_conf_intervals,
            'sdr_both_s_means': sdr_both_s_means,
            'sdr_both_s_conf_intervals': sdr_both_s_conf_intervals,
            'logging_s_means': logging_s_means,
            'logging_s_conf_intervals': logging_s_conf_intervals,
            'rips_s_means': rips_s_means,
            'rips_s_conf_intervals': rips_s_conf_intervals,
            'sips_s_means': sips_s_means,
            'sips_s_conf_intervals': sips_s_conf_intervals,
            'rdm_s_means': rdm_s_means,
            'rdm_s_conf_intervals': rdm_s_conf_intervals,
            'sdm_s_means': sdm_s_means,
            'sdm_s_conf_intervals': sdm_s_conf_intervals,
        })
        results_s.to_csv(f'../../results/synthetic/{exp_name}/data/{save_file_name}-s.csv', index=False)

        # Full-value summary table: one row per entry of list_, with a
        # (mean, confidence half-width) column pair for each estimator.
        # Column order matches the keyword order below.
        results_full = pd.DataFrame(dict(
            list_=list_,
            parips_full_means=parips_full_means,
            parips_full_conf_intervals=parips_full_conf_intervals,
            surips_full_means=surips_full_means,
            surips_full_conf_intervals=surips_full_conf_intervals,
            sdr_full_means=sdr_full_means,
            sdr_full_conf_intervals=sdr_full_conf_intervals,
            sdr_both_full_means=sdr_both_full_means,
            sdr_both_full_conf_intervals=sdr_both_full_conf_intervals,
            logging_full_means=logging_full_means,
            logging_full_conf_intervals=logging_full_conf_intervals,
            rips_full_means=rips_full_means,
            rips_full_conf_intervals=rips_full_conf_intervals,
            sips_full_means=sips_full_means,
            sips_full_conf_intervals=sips_full_conf_intervals,
            rdm_full_means=rdm_full_means,
            rdm_full_conf_intervals=rdm_full_conf_intervals,
            sdm_full_means=sdm_full_means,
            sdm_full_conf_intervals=sdm_full_conf_intervals,
        ))
        # Written next to the primary-reward CSV, with a "-full" suffix.
        results_full.to_csv(
            f'../../results/synthetic/{exp_name}/data/{save_file_name}-full.csv',
            index=False,
        )


        # Figure: relative primary-reward policy value against the primary
        # reward observation probability. Each estimator is drawn as a mean
        # curve with a shaded band of mean +/- its confidence half-width.
        plt.figure(figsize=(10, 6))
        # (legend label, colour, means, confidence half-widths); the tuple
        # order fixes both the drawing order and the legend order.
        for curve_label, curve_color, curve_means, curve_ci in (
            ('r-IPS', 'purple', rips_means, rips_conf_intervals),
            ('s-IPS', 'brown', sips_means, sips_conf_intervals),
            ('r-DM', 'pink', rdm_means, rdm_conf_intervals),
            ('s-DM', 'gray', sdm_means, sdm_conf_intervals),
            ('r-DR', 'blue', parips_means, parips_conf_intervals),
            ('s-DR', 'red', surips_means, surips_conf_intervals),
            ('c-DR', 'orange', sdr_means, sdr_conf_intervals),
            (r'HyPeR($\beta=\gamma$)', 'green', sdr_both_means, sdr_both_conf_intervals),
        ):
            plt.plot(list_, curve_means, label=curve_label, color=curve_color)
            plt.fill_between(list_, curve_means - curve_ci, curve_means + curve_ci, color=curve_color, alpha=0.2)

        # Configuring plot details
        plt.xlabel('Primary Reward Observation Probability', fontsize=18)
        plt.ylabel('Relative Primary Reward Policy Value', fontsize=18)
        plt.title('Relative Primary Reward Policy Value vs Primary Reward Observation Probability', fontsize=18)
        plt.legend(fontsize=16)
        plt.xticks(fontsize=16)
        plt.yticks(fontsize=16)
        plt.savefig(f'../../results/synthetic/{exp_name}/{save_file_name}.png')
        # pyplot keeps every figure alive until it is closed; close the one
        # just saved so repeated executions do not accumulate open figures.
        plt.close()

        # Figure: relative surrogate policy value against the primary reward
        # observation probability. Each estimator is drawn as a mean curve
        # with a shaded band of mean +/- its confidence half-width.
        plt.figure(figsize=(10, 6))
        # (legend label, colour, means, confidence half-widths); the tuple
        # order fixes both the drawing order and the legend order.
        for curve_label, curve_color, curve_means, curve_ci in (
            ('r-IPS', 'purple', rips_s_means, rips_s_conf_intervals),
            ('s-IPS', 'brown', sips_s_means, sips_s_conf_intervals),
            ('r-DM', 'pink', rdm_s_means, rdm_s_conf_intervals),
            ('s-DM', 'gray', sdm_s_means, sdm_s_conf_intervals),
            ('r-DR', 'blue', parips_s_means, parips_s_conf_intervals),
            ('s-DR', 'red', surips_s_means, surips_s_conf_intervals),
            ('c-DR', 'orange', sdr_s_means, sdr_s_conf_intervals),
            (r'HyPeR($\beta=\gamma$)', 'green', sdr_both_s_means, sdr_both_s_conf_intervals),
        ):
            plt.plot(list_, curve_means, label=curve_label, color=curve_color)
            plt.fill_between(list_, curve_means - curve_ci, curve_means + curve_ci, color=curve_color, alpha=0.2)

        # Configuring plot details
        plt.xlabel('Primary Reward Observation Probability', fontsize=18)
        plt.ylabel('Relative Surrogate Policy Value', fontsize=18)
        plt.title('Relative Surrogate Policy Value vs Primary Reward Observation Probability', fontsize=18)
        plt.legend(fontsize=16)
        plt.xticks(fontsize=16)
        plt.yticks(fontsize=16)
        plt.savefig(f'../../results/synthetic/{exp_name}/{save_file_name}-s.png')
        # pyplot keeps every figure alive until it is closed; close the one
        # just saved so repeated executions do not accumulate open figures.
        plt.close()

        # Figure: full policy value against the primary reward observation
        # probability. Each estimator is drawn as a mean curve with a shaded
        # band of mean +/- its confidence half-width.
        plt.figure(figsize=(10, 6))
        # (legend label, colour, means, confidence half-widths); the tuple
        # order fixes both the drawing order and the legend order.
        for curve_label, curve_color, curve_means, curve_ci in (
            ('r-IPS', 'purple', rips_full_means, rips_full_conf_intervals),
            ('s-IPS', 'brown', sips_full_means, sips_full_conf_intervals),
            ('r-DM', 'pink', rdm_full_means, rdm_full_conf_intervals),
            ('s-DM', 'gray', sdm_full_means, sdm_full_conf_intervals),
            ('r-DR', 'blue', parips_full_means, parips_full_conf_intervals),
            ('s-DR', 'red', surips_full_means, surips_full_conf_intervals),
            ('c-DR', 'orange', sdr_full_means, sdr_full_conf_intervals),
            (r'HyPeR($\beta=\gamma$)', 'green', sdr_both_full_means, sdr_both_full_conf_intervals),
        ):
            plt.plot(list_, curve_means, label=curve_label, color=curve_color)
            plt.fill_between(list_, curve_means - curve_ci, curve_means + curve_ci, color=curve_color, alpha=0.2)

        # Configuring plot details
        plt.xlabel('Primary Reward Observation Probability', fontsize=18)
        plt.ylabel('Full Policy Value', fontsize=18)
        plt.title('Full Policy Value vs Primary Reward Observation Probability', fontsize=18)
        plt.legend(fontsize=16)
        plt.xticks(fontsize=16)
        plt.yticks(fontsize=16)
        plt.savefig(f'../../results/synthetic/{exp_name}/{save_file_name}-full.png')
        # pyplot keeps every figure alive until it is closed; close the one
        # just saved so repeated executions do not accumulate open figures.
        plt.close()