from utils.datasets import expected_reward_function
from utils.datasets_trunc_to_norm import SyntheticBanditDatasetWithSurrogate
from obp.policy import NNPolicyLearner
from utils.policy import OurLearner
import pandas as pd
from tqdm import tqdm
import matplotlib.pyplot as plt
import torch
import numpy as np
import os
from obp.policy import Random
import time

# --- Output naming -----------------------------------------------------------
# Both names are interpolated into the result paths further down the script.
exp_name = "exp4"
save_file_name = "test_dr"

# --- Device probe ------------------------------------------------------------
# Report CUDA availability once and pick the matching torch device.
_cuda_ok = torch.cuda.is_available()
print(_cuda_ok)
device = torch.device("cuda" if _cuda_ok else "cpu")

# --- Swept variable ----------------------------------------------------------
# Reward noise levels, one dataset per entry; 0.001 is relabelled "0.0" when
# the results are saved and plotted.
reward_std_list = [0.001, 1.0, 2.0, 3.0, 4.0, 5.0]

# --- Cross-simulation accumulators --------------------------------------------
# Each name gets its own fresh list (the generator yields a new object per
# target), so no two accumulators alias each other.
(
    parips_all_results,
    surips_all_results,
    sdr_all_results,
    sdr_both_all_results,
    s_opt_all_results,
    random_all_results,
    logging_all_results,
) = ([] for _ in range(7))
(
    parips_all_surrogate,
    surips_all_surrogate,
    sdr_all_surrogate,
    sdr_both_all_surrogate,
    s_opt_all_surrogate,
    random_all_surrogate,
    logging_all_surrogate,
) = ([] for _ in range(7))

# --- Experiment size -----------------------------------------------------------
n_sims = 50               # number of independent simulation repetitions
n_rounds = 2000           # rounds drawn for the training feedback
n_rounds_test = 100000    # rounds drawn for the evaluation feedback
random_state = 12345      # base seed; the simulation index is added to it

# --- Synthetic dataset parameters ----------------------------------------------
# Passed straight to SyntheticBanditDatasetWithSurrogate; see that class for
# their exact meaning.
n_actions = 10
dim_context = 10
reward_type = "continuous"
alpha_noise = 0.5
s_noise = 0.2
s_dim = 5
beta_data = -2            # forwarded as the dataset's ``beta`` argument
lambda_ = 0.7
p_o = 0.2                 # also forwarded to OurLearner.fit as ``p_o``

# --- Learning / evaluation weight ----------------------------------------------
# Given to the 'sdr-both' learner and to calc_full_policy_value.
beta = 0.3

def _build_our_learner(objective, seed):
    """Create an ``OurLearner`` with the settings shared by every SDR variant.

    Only ``off_policy_objective`` and ``random_state`` differ between the
    variants trained in this script.
    """
    return OurLearner(
        n_actions=n_actions,
        s_dim=s_dim,
        dim_context=dim_context,
        hidden_layer_size=(200,),
        early_stopping=False,
        activation='identity',
        solver='adam',
        off_policy_objective=objective,
        batch_size=64,
        random_state=seed,
    )


def _build_dr_learner(seed):
    """Create the ``NNPolicyLearner`` ('dr' objective) used by both baselines."""
    return NNPolicyLearner(
        n_actions=n_actions,
        dim_context=dim_context,
        hidden_layer_size=(200,),
        early_stopping=False,
        activation='identity',
        solver='adam',
        off_policy_objective='dr',
        batch_size=64,
        random_state=seed,
    )


def _random_action_dist(seed):
    """Return the action distribution of obp's ``Random`` policy on the test rounds."""
    fallback = Random(n_actions=n_actions, random_state=seed)
    return fallback.compute_batch_action_dist(n_rounds=n_rounds_test)


def _score_policy(dataset, test_feedback, action_dist):
    """Evaluate one action distribution; return ``(primary, surrogate, full)`` values."""
    primary = dataset.calc_ground_truth_policy_value(
        expected_reward=test_feedback["all_q_x_a_f"],
        action_dist=action_dist,
    )
    surrogate = dataset.calc_ground_truth_policy_value(
        expected_reward=test_feedback["f_sum"],
        action_dist=action_dist,
    )
    full = dataset.calc_full_policy_value(
        expected_reward=test_feedback["all_q_x_a_f"],
        expected_surrogate_reward=test_feedback["f_sum"],
        action_dist=action_dist,
        beta=beta,
    )
    return primary, surrogate, full


def _evaluate_setting(reward_std, seed):
    """Train and score every method for one reward noise level.

    A fresh dataset is generated with ``seed``; training feedback is drawn
    first, then test feedback. Methods are constructed and fitted in the order
    sdr, sdr_both, s_opt, logging, parips, surips. The order is kept fixed so
    that any global RNG state touched by the learners is consumed the same way
    on every run.

    Returns a dict mapping method name to ``(primary, surrogate, full)``.
    """
    dataset = SyntheticBanditDatasetWithSurrogate(
        alpha_noise=alpha_noise,
        n_actions=n_actions,
        dim_context=dim_context,
        reward_type=reward_type,
        reward_std=reward_std,
        beta=beta_data,
        lambda_=lambda_,
        s_noise=s_noise,
        s_dim=s_dim,
        p_o=p_o,
        random_state=seed,
    )
    train = dataset.obtain_batch_bandit_feedback(n_rounds=n_rounds)
    test = dataset.obtain_batch_bandit_feedback(n_rounds=n_rounds_test)
    scores = {}

    # --- sdr: falls back to the Random policy when p_o is exactly 0 ---------
    if p_o == 0.0:
        sdr_dist = _random_action_dist(seed)
    else:
        sdr_policy = _build_our_learner('sdr', seed)
        sdr_policy.fit(
            context=train["contexts"],
            action=train["actions"],
            surrogate_reward=train["surrogate_rewards"],
            reward=train["rewards"],
            obs_list=train["obs_list"],
            p_o=p_o,
            pscore=train["pscores"],
        )
        sdr_dist = sdr_policy.predict(context=test["contexts"])
    scores["sdr"] = _score_policy(dataset, test, sdr_dist)

    # --- sdr_both / s_opt: same learner, differing only in the fit-time beta -
    for method, fit_beta in (("sdr_both", beta), ("s_opt", 1.0)):
        learner = _build_our_learner('sdr-both', seed)
        learner.fit(
            context=train["contexts"],
            action=train["actions"],
            surrogate_reward=train["surrogate_rewards"],
            reward=train["rewards"],
            obs_list=train["obs_list"],
            p_o=p_o,
            pscore=train["pscores"],
            s_sum=train["s_sum"],
            beta=fit_beta,
        )
        scores[method] = _score_policy(
            dataset, test, learner.predict(context=test["contexts"])
        )

    # --- logging policy: scored directly from the test feedback -------------
    scores["logging"] = _score_policy(dataset, test, test["pi_b"])

    # --- parips: DR learner trained on the "obs_*" subset of the feedback ---
    if p_o == 0.0:
        parips_dist = _random_action_dist(seed)
    else:
        parips_policy = _build_dr_learner(seed)
        parips_policy.fit(
            context=train["obs_contexts"],
            action=train["obs_actions"],
            reward=train["obs_rewards"],
            pscore=train["obs_pscores"],
        )
        parips_dist = parips_policy.predict(context=test["contexts"])
    scores["parips"] = _score_policy(dataset, test, parips_dist)

    # --- surips: DR learner trained with "f_s" as the reward signal ---------
    surips_policy = _build_dr_learner(seed)
    surips_policy.fit(
        context=train["contexts"],
        action=train["actions"],
        reward=train["f_s"],
        pscore=train["pscores"],
    )
    scores["surips"] = _score_policy(
        dataset, test, surips_policy.predict(context=test["contexts"])
    )
    return scores


def _as_matrix(rows):
    """Stack per-simulation rows into an ``(n_sims, n_settings)`` float array.

    The explicit reshape keeps the result two-dimensional even for a single
    (or zero) simulation, so downstream ``.mean(axis=0)`` / ``.std(axis=0)``
    always reduce over simulations.
    """
    return np.asarray(rows, dtype=float).reshape(-1, len(reward_std_list))


_METHODS = ("parips", "surips", "sdr", "sdr_both", "s_opt", "logging")
# For every method: three lists of rows (primary, surrogate, full), one row
# per simulation and one column per entry of reward_std_list.
_collected = {method: ([], [], []) for method in _METHODS}

with tqdm(total=n_sims, desc='Outer', position=0) as outer_bar:
    start_time = time.time()
    for sim in range(n_sims):
        # All noise levels of one simulation share the seed random_state + sim.
        per_setting = [
            _evaluate_setting(reward_std, random_state + sim)
            for reward_std in reward_std_list
        ]
        for method in _METHODS:
            for metric_idx, rows in enumerate(_collected[method]):
                rows.append([scores[method][metric_idx] for scores in per_setting])

        outer_bar.update(1)
        elapsed_time = time.time() - start_time
        remaining_time = (elapsed_time / (sim + 1)) * (n_sims - (sim + 1))
        tqdm.write(f"\rSimulations: {outer_bar.n}/{outer_bar.total} - "
                   f"{tqdm.format_interval(elapsed_time)}<{tqdm.format_interval(remaining_time)}")

# Publish the module-level arrays that the rest of the script reads.
parips_all_results, parips_all_surrogate, parips_all_full = map(_as_matrix, _collected["parips"])
surips_all_results, surips_all_surrogate, surips_all_full = map(_as_matrix, _collected["surips"])
sdr_all_results, sdr_all_surrogate, sdr_all_full = map(_as_matrix, _collected["sdr"])
sdr_both_all_results, sdr_both_all_surrogate, sdr_both_all_full = map(_as_matrix, _collected["sdr_both"])
s_opt_all_results, s_opt_all_surrogate, s_opt_all_full = map(_as_matrix, _collected["s_opt"])
logging_all_results, logging_all_surrogate, logging_all_full = map(_as_matrix, _collected["logging"])

# z-multiplier used for the 95% confidence interval half-width.
confidence = 1.96


def _summarize(values):
    """Summarise one results matrix across simulations (axis 0).

    Returns ``(mean, std, standard_error, conf_interval)`` where the standard
    error is ``std / sqrt(n_sims)`` and the confidence interval is the
    half-width ``standard_error * confidence``.
    """
    mean = values.mean(axis=0)
    std = values.std(axis=0)
    # Compute the standard error.
    ste = std / np.sqrt(n_sims)
    # Compute the 95% confidence interval.
    return mean, std, ste, ste * confidence


# Primary-reward policy values.
parips_means, parips_stds, parips_stes, parips_conf_intervals = _summarize(parips_all_results)
surips_means, surips_stds, surips_stes, surips_conf_intervals = _summarize(surips_all_results)
sdr_means, sdr_stds, sdr_stes, sdr_conf_intervals = _summarize(sdr_all_results)
sdr_both_means, sdr_both_stds, sdr_both_stes, sdr_both_conf_intervals = _summarize(sdr_both_all_results)
s_opt_means, s_opt_stds, s_opt_stes, s_opt_conf_intervals = _summarize(s_opt_all_results)
logging_means, logging_stds, logging_stes, logging_conf_intervals = _summarize(logging_all_results)

# Surrogate policy values.
parips_s_means, parips_s_stds, parips_s_stes, parips_s_conf_intervals = _summarize(parips_all_surrogate)
surips_s_means, surips_s_stds, surips_s_stes, surips_s_conf_intervals = _summarize(surips_all_surrogate)
sdr_s_means, sdr_s_stds, sdr_s_stes, sdr_s_conf_intervals = _summarize(sdr_all_surrogate)
sdr_both_s_means, sdr_both_s_stds, sdr_both_s_stes, sdr_both_s_conf_intervals = _summarize(sdr_both_all_surrogate)
s_opt_s_means, s_opt_s_stds, s_opt_s_stes, s_opt_s_conf_intervals = _summarize(s_opt_all_surrogate)
logging_s_means, logging_s_stds, logging_s_stes, logging_s_conf_intervals = _summarize(logging_all_surrogate)

# Full policy values.
parips_full_means, parips_full_stds, parips_full_stes, parips_full_conf_intervals = _summarize(parips_all_full)
surips_full_means, surips_full_stds, surips_full_stes, surips_full_conf_intervals = _summarize(surips_all_full)
sdr_full_means, sdr_full_stds, sdr_full_stes, sdr_full_conf_intervals = _summarize(sdr_all_full)
sdr_both_full_means, sdr_both_full_stds, sdr_both_full_stes, sdr_both_full_conf_intervals = _summarize(sdr_both_all_full)
s_opt_full_means, s_opt_full_stds, s_opt_full_stes, s_opt_full_conf_intervals = _summarize(s_opt_all_full)
logging_full_means, logging_full_stds, logging_full_stes, logging_full_conf_intervals = _summarize(logging_all_full)

# Turn the noise levels into x-axis labels; the smallest level (0.001) is
# reported as "0.0".
reward_std_list = [str(0.0) if std == 0.001 else str(std) for std in reward_std_list]


def _results_frame(tag, stats):
    """Build one results table.

    ``tag`` is inserted between the method name and the column suffix
    ('' / '_s' / '_full'). ``stats`` maps method name to
    ``(means, conf_intervals)``; its insertion order fixes the column order.
    The trailing space in the ``*_means `` column names is kept on purpose so
    existing readers of these CSV files keep working.
    """
    columns = {'reward_std': reward_std_list}
    for method, (means, conf_intervals) in stats.items():
        columns[f'{method}{tag}_means '] = means
        columns[f'{method}{tag}_conf_intervals'] = conf_intervals
    return pd.DataFrame(columns)


# Save results as CSV (including s_opt). exist_ok avoids the check-then-create
# race of a separate os.path.exists test.
_data_dir = f'../../results/synthetic/{exp_name}/data'
os.makedirs(_data_dir, exist_ok=True)

# Primary-reward values: the statistics without the "_s" / "_full" suffix.
results = _results_frame('', {
    'parips': (parips_means, parips_conf_intervals),
    'surips': (surips_means, surips_conf_intervals),
    'sdr': (sdr_means, sdr_conf_intervals),
    'sdr_both': (sdr_both_means, sdr_both_conf_intervals),
    's_opt': (s_opt_means, s_opt_conf_intervals),
    'logging': (logging_means, logging_conf_intervals),
})
results.to_csv(f'../../results/synthetic/{exp_name}/data/{save_file_name}.csv', index=False)

# Surrogate values. Every interval here must be the "_s" one: previously the
# sdr_both and logging columns were filled with the primary-reward intervals.
results_s = _results_frame('_s', {
    'parips': (parips_s_means, parips_s_conf_intervals),
    'surips': (surips_s_means, surips_s_conf_intervals),
    'sdr': (sdr_s_means, sdr_s_conf_intervals),
    'sdr_both': (sdr_both_s_means, sdr_both_s_conf_intervals),
    's_opt': (s_opt_s_means, s_opt_s_conf_intervals),
    'logging': (logging_s_means, logging_s_conf_intervals),
})
results_s.to_csv(f'../../results/synthetic/{exp_name}/data/{save_file_name}-s.csv', index=False)

# Full policy values.
results_full = _results_frame('_full', {
    'parips': (parips_full_means, parips_full_conf_intervals),
    'surips': (surips_full_means, surips_full_conf_intervals),
    'sdr': (sdr_full_means, sdr_full_conf_intervals),
    'sdr_both': (sdr_both_full_means, sdr_both_full_conf_intervals),
    's_opt': (s_opt_full_means, s_opt_full_conf_intervals),
    'logging': (logging_full_means, logging_full_conf_intervals),
})
results_full.to_csv(f'../../results/synthetic/{exp_name}/data/{save_file_name}-full.csv', index=False)


# Figure 1: primary-reward policy value against the reward noise level.
plt.figure(figsize=(10, 6))

# One entry per curve: (mean, CI half-width, legend label, colour).
_curves = (
    (parips_means, parips_conf_intervals, 'r-ParDR', 'blue'),
    (surips_means, surips_conf_intervals, 'r-SurDR', 'red'),
    (sdr_means, sdr_conf_intervals, 'r-TwoDR', 'orange'),
    (sdr_both_means, sdr_both_conf_intervals, 'Ours', 'green'),
    (s_opt_means, s_opt_conf_intervals, 's-SurDR', 'purple'),
    (logging_means, logging_conf_intervals, 'Logging', 'saddlebrown'),
)
for _mean, _half_width, _label, _color in _curves:
    # Mean line plus a translucent band of +/- the confidence half-width.
    plt.plot(reward_std_list, _mean, label=_label, color=_color)
    plt.fill_between(
        reward_std_list,
        _mean - _half_width,
        _mean + _half_width,
        color=_color,
        alpha=0.2,
    )

# Axis labels, title, legend and tick sizes, then write the image.
plt.xlabel('Noise of r', fontsize=18)
plt.ylabel('Relative Primary Reward Policy Value', fontsize=18)
plt.title('Relative Primary Reward Policy Value vs Noise of r', fontsize=18)
plt.legend(fontsize=16)
plt.xticks(fontsize=16)
plt.yticks(fontsize=16)
plt.savefig(f'../../results/synthetic/{exp_name}/{save_file_name}.png')

# Figure 2: surrogate policy value against the reward noise level.
plt.figure(figsize=(10, 6))

# One entry per curve: (mean, CI half-width, legend label, colour).
_curves = (
    (parips_s_means, parips_s_conf_intervals, 'r-ParDR', 'blue'),
    (surips_s_means, surips_s_conf_intervals, 'r-SurDR', 'red'),
    (sdr_s_means, sdr_s_conf_intervals, 'r-TwoDR', 'orange'),
    (sdr_both_s_means, sdr_both_s_conf_intervals, 'Ours', 'green'),
    (s_opt_s_means, s_opt_s_conf_intervals, 's-SurDR', 'purple'),
    (logging_s_means, logging_s_conf_intervals, 'Logging', 'saddlebrown'),
)
for _mean, _half_width, _label, _color in _curves:
    # Mean line plus a translucent band of +/- the confidence half-width.
    plt.plot(reward_std_list, _mean, label=_label, color=_color)
    plt.fill_between(
        reward_std_list,
        _mean - _half_width,
        _mean + _half_width,
        color=_color,
        alpha=0.2,
    )

# Axis labels, title, legend and tick sizes, then write the image.
plt.xlabel('Noise of r', fontsize=18)
plt.ylabel('Relative Surrogate Policy Value', fontsize=18)
plt.title('Relative Surrogate Policy Value vs Noise of r', fontsize=18)
plt.legend(fontsize=16)
plt.xticks(fontsize=16)
plt.yticks(fontsize=16)
plt.savefig(f'../../results/synthetic/{exp_name}/{save_file_name}-s.png')


# Figure 3: full policy value against the reward noise level.
plt.figure(figsize=(10, 6))

# One entry per curve: (mean, CI half-width, legend label, colour).
_curves = (
    (parips_full_means, parips_full_conf_intervals, 'r-ParDR', 'blue'),
    (surips_full_means, surips_full_conf_intervals, 'r-SurDR', 'red'),
    (sdr_full_means, sdr_full_conf_intervals, 'r-TwoDR', 'orange'),
    (sdr_both_full_means, sdr_both_full_conf_intervals, 'Ours', 'green'),
    (s_opt_full_means, s_opt_full_conf_intervals, 's-SurDR', 'purple'),
    (logging_full_means, logging_full_conf_intervals, 'Logging', 'saddlebrown'),
)
for _mean, _half_width, _label, _color in _curves:
    # Mean line plus a translucent band of +/- the confidence half-width.
    plt.plot(reward_std_list, _mean, label=_label, color=_color)
    plt.fill_between(
        reward_std_list,
        _mean - _half_width,
        _mean + _half_width,
        color=_color,
        alpha=0.2,
    )

# Axis labels, title, legend and tick sizes, then write the image.
plt.xlabel('Noise of r', fontsize=18)
plt.ylabel('Full Policy Value', fontsize=18)
plt.title('Full Policy Value vs Noise of r', fontsize=18)
plt.legend(fontsize=16)
plt.xticks(fontsize=16)
plt.yticks(fontsize=16)
plt.savefig(f'../../results/synthetic/{exp_name}/{save_file_name}-full.png')