from utils.datasets_no_n_user import BanditDatasetWithSurrogate
from obp.policy import NNPolicyLearner
import pandas as pd
from tqdm import tqdm
import matplotlib.pyplot as plt
import torch
import numpy as np
import os
from obp.policy import Random
import time
from sklearn.utils import check_random_state

# ---------------------------------------------------------------------------
# Experiment configuration
# ---------------------------------------------------------------------------
exp_name = 'exp1'
print(torch.cuda.is_available())  # report whether torch can see a CUDA device

# Target-reward observation probabilities swept by the r-IPS / r-DM learners.
p_o_list = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0]

# Simulation and learner hyper-parameters.
n_sims, n_rounds, batch_size = 200, 1000, 64
hidden_layer_size = (150,)
beta, alpha_noise = 0.4, 0.00
reward_type = "continuous"
beta_data = -2
random_state = 2255
n_actions = 100

# Input / output locations (relative to the working directory).
base_path = "../../"
path = base_path + "data/"
expected_rewards_df, user_features_df = (
    pd.read_csv(path + csv_name)
    for csv_name in ("expected_rewards_df.csv", "user_features_df.csv")
)

unique_user_ids = expected_rewards_df['user_id'].unique()

save_file_name = "only_side"

# Per-metric output configuration:
# (history key, CSV column infix, output file suffix, y-axis label).
_METRIC_SPECS = (
    ("results", "", "", "Relative Target Reward Policy Value"),
    ("surrogate", "_s", "-s", "Relative Surrogate Policy Value"),
    ("full", "_full", "-full", "Relative Full Policy Value"),
)
_METRIC_KEYS = tuple(spec[0] for spec in _METRIC_SPECS)
# CSV column order of the four learners.
_METHOD_NAMES = ("rips", "rdm", "sips", "sdm")
# Plot order, legend label and colour of each learner.
_PLOT_STYLE = (
    ("rips", "r-IPS", "purple"),
    ("sips", "s-IPS", "cyan"),
    ("rdm", "r-DM", "magenta"),
    ("sdm", "s-DM", "lime"),
)
# z-value of a two-sided 95% normal confidence interval.
_CONFIDENCE_Z = 1.96


def _build_learner(dim_context, off_policy_objective, seed, **extra_kwargs):
    """Create an NNPolicyLearner with the hyper-parameters shared by every learner.

    ``extra_kwargs`` carries learner-specific overrides (e.g. ``max_iter``).
    """
    return NNPolicyLearner(
        n_actions=n_actions,
        dim_context=dim_context,
        hidden_layer_size=hidden_layer_size,
        early_stopping=True,
        activation='relu',
        solver='adam',
        off_policy_objective=off_policy_objective,
        batch_size=batch_size,
        random_state=seed,
        **extra_kwargs,
    )


def _fit_predict_surrogate(off_policy_objective, feedback_train, feedback_test, seed):
    """Train on the ``f_s`` reward of all logged rounds; return the test action distribution."""
    learner = _build_learner(feedback_train["dim_context"], off_policy_objective, seed)
    learner.fit(
        context=feedback_train["contexts"],
        action=feedback_train["actions"],
        reward=feedback_train["f_s"],
        pscore=feedback_train["pscores"],
    )
    return learner.predict(context=feedback_test["contexts"])


def _fit_predict_observed(off_policy_objective, p_o, feedback_train, feedback_test, seed):
    """Train on the ``obs_*`` rounds; return the test action distribution.

    When ``p_o == 0.0`` an obp ``Random`` policy is used instead of a trained learner
    (presumably because no target rewards are observed in that setting).
    """
    if p_o == 0.0:
        random_policy = Random(n_actions=n_actions, random_state=seed)
        return random_policy.compute_batch_action_dist(n_rounds=feedback_test["n_rounds"])
    learner = _build_learner(
        feedback_train["dim_context"], off_policy_objective, seed,
        max_iter=200, learning_rate_init=0.0001,
    )
    learner.fit(
        context=feedback_train["obs_contexts"],
        action=feedback_train["obs_actions"],
        reward=feedback_train["obs_rewards"],
        pscore=feedback_train["obs_pscores"],
    )
    return learner.predict(context=feedback_test["contexts"])


def _evaluate_policy(dataset, feedback_test, action_dist):
    """Return (target-reward value, surrogate value, full value) of ``action_dist`` on the test split."""
    policy_value = dataset.calc_ground_truth_policy_value(
        expected_reward=feedback_test["all_q_x_a_f"],
        action_dist=action_dist,
    )
    surrogate_value = dataset.calc_ground_truth_policy_value(
        expected_reward=feedback_test["f_sum"],
        action_dist=action_dist,
    )
    full_value = dataset.calc_full_policy_value(
        expected_reward=feedback_test["all_q_x_a_f"],
        expected_surrogate_reward=feedback_test["f_sum"],
        action_dist=action_dist,
        beta=beta,
    )
    return policy_value, surrogate_value, full_value


def _mean_and_ci(rows):
    """Return the per-column mean and 95% CI half-width over the simulations in ``rows``.

    ``rows`` holds one list per completed simulation, with one value per entry of p_o_list.
    """
    values = np.vstack(rows)
    means = values.mean(axis=0)
    # The standard error uses the number of simulations completed so far, not n_sims,
    # so intermediate checkpoints do not report artificially narrow intervals.
    conf_intervals = (values.std(axis=0) / np.sqrt(values.shape[0])) * _CONFIDENCE_Z
    return means, conf_intervals


def _save_checkpoint(history, p_o_labels, results_dir):
    """Write the CSV summary and plot of each metric for all simulations completed so far."""
    os.makedirs(f'{results_dir}/data', exist_ok=True)
    for key, infix, suffix, ylabel in _METRIC_SPECS:
        stats = {name: _mean_and_ci(history[name][key]) for name in _METHOD_NAMES}

        columns = {'list': p_o_labels}
        for name in _METHOD_NAMES:
            means, conf_intervals = stats[name]
            columns[f'{name}{infix}_means'] = means
            columns[f'{name}{infix}_conf_intervals'] = conf_intervals
        pd.DataFrame(columns).to_csv(f'{results_dir}/data/{save_file_name}{suffix}.csv', index=False)

        fig = plt.figure(figsize=(10, 6))
        for name, label, color in _PLOT_STYLE:
            means, conf_intervals = stats[name]
            plt.plot(p_o_labels, means, label=label, color=color)
            plt.fill_between(p_o_labels, means - conf_intervals, means + conf_intervals, color=color, alpha=0.2)
        plt.xlabel('Target Reward Observation Probability', fontsize=18)
        plt.ylabel(ylabel, fontsize=18)
        plt.title(f'{ylabel} vs Target Reward Observation Probability', fontsize=18)
        plt.legend(fontsize=16)
        plt.xticks(fontsize=16)
        plt.yticks(fontsize=16)
        plt.savefig(f'{results_dir}/{save_file_name}{suffix}.png')
        # This runs after every simulation. Unclosed pyplot figures accumulate,
        # causing memory growth and "too many open figures" warnings.
        plt.close(fig)


list_ = [str(i) for i in p_o_list]
results_dir = f'{base_path}/results/realdata/{exp_name}'
# history[method][metric] holds one row (one value per p_o) per completed simulation.
history = {name: {key: [] for key in _METRIC_KEYS} for name in _METHOD_NAMES}

with tqdm(total=n_sims, desc='Outer', position=0) as outer_bar:
    start_time = time.time()
    for sim in range(n_sims):
        seed = random_state + sim
        random_ = check_random_state(seed)
        dataset = BanditDatasetWithSurrogate(n_actions=n_actions, reward_type=reward_type, beta=beta_data, p_o=0.0,
                                             alpha_noise=alpha_noise, random_state=seed)
        train_user_ids = random_.choice(unique_user_ids, size=int(0.7 * len(unique_user_ids)), replace=False)
        # Sample this simulation's videos from the FULL table. Do not rebind
        # expected_rewards_df to the filtered subset: that would shrink the pool to
        # n_actions videos, and every later simulation would draw the same set.
        random_videoids = expected_rewards_df['video_id'].drop_duplicates().sample(n=n_actions, random_state=seed)
        sim_rewards_df = expected_rewards_df[expected_rewards_df['video_id'].isin(random_videoids)]
        train_df = sim_rewards_df[sim_rewards_df['user_id'].isin(train_user_ids)]
        validation_df = sim_rewards_df[~sim_rewards_df['user_id'].isin(train_user_ids)]

        bandit_feedback_train = dataset.obtain_batch_bandit_feedback(n_rounds=n_rounds, expected_rewards_df=train_df,
                                                                     user_features_df=user_features_df, test=False)
        bandit_feedback_test = dataset.obtain_batch_bandit_feedback(expected_rewards_df=validation_df,
                                                                    user_features_df=user_features_df, test=True)

        sim_values = {name: {key: [] for key in _METRIC_KEYS} for name in _METHOD_NAMES}

        # s-IPS / s-DM are trained once per simulation on the p_o=0.0 logging data.
        # Their value is repeated across p_o so every curve spans the same x-axis.
        for name, objective in (("sips", "ipw"), ("sdm", "dm")):
            action_dist = _fit_predict_surrogate(objective, bandit_feedback_train, bandit_feedback_test, seed)
            for key, value in zip(_METRIC_KEYS, _evaluate_policy(dataset, bandit_feedback_test, action_dist)):
                sim_values[name][key] = [value] * len(p_o_list)

        # r-IPS / r-DM are retrained for every observation probability p_o.
        for p_o in p_o_list:
            dataset = BanditDatasetWithSurrogate(n_actions=n_actions, reward_type=reward_type, beta=beta_data, p_o=p_o,
                                                 alpha_noise=alpha_noise, random_state=seed)
            bandit_feedback_train = dataset.obtain_batch_bandit_feedback(n_rounds=n_rounds, expected_rewards_df=train_df,
                                                                         user_features_df=user_features_df, test=False)
            for name, objective in (("rips", "ipw"), ("rdm", "dm")):
                action_dist = _fit_predict_observed(objective, p_o, bandit_feedback_train, bandit_feedback_test, seed)
                for key, value in zip(_METRIC_KEYS, _evaluate_policy(dataset, bandit_feedback_test, action_dist)):
                    sim_values[name][key].append(value)

        for name in _METHOD_NAMES:
            for key in _METRIC_KEYS:
                history[name][key].append(sim_values[name][key])

        outer_bar.update(1)
        elapsed_time = time.time() - start_time
        remaining_time = (elapsed_time / (sim + 1)) * (n_sims - (sim + 1))
        tqdm.write(f"\rSimulations: {outer_bar.n}/{outer_bar.total} - "
                   f"{tqdm.format_interval(elapsed_time)}<{tqdm.format_interval(remaining_time)}")

        # Checkpoint after every simulation once three are done, and always after the
        # final one, so a run with n_sims < 3 still produces output.
        n_done = sim + 1
        if n_done < 3 and n_done != n_sims:
            continue
        _save_checkpoint(history, list_, results_dir)
