from omegaconf import DictConfig, OmegaConf
import hydra
import os

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from pandas import DataFrame
from tqdm import tqdm
import seaborn as sns
from sklearn.neural_network import MLPRegressor
from sklearn.neural_network import MLPClassifier

import obp
from obp.dataset import(
    linear_reward_function,
    logistic_reward_function,
    linear_behavior_policy,
)

from obp.ope import(
    # SlateOffPolicyEvaluation,
    RegressionModel,
    SlateStandardIPS as IPS,
    SlateIndependentIPS as IIPS,
    SlateRewardInteractionIPS as RIPS,
)

from dataset_real import RealSlateBanditDataset
from dataset_real import linear_behavior_policy_logit
from estimator import(
    ClickBasedIPS as CIPS,
    ClickBasedDR as CDR,
) 
from ope import OffPolicyEvaluation
from plot import(
    plot,
    plot_normalize,
)


@hydra.main(config_path="../conf",config_name="config", version_base="1.1")
def main(cfg: DictConfig) -> None:
    np.random.seed(cfg.setting.real.random_state)
    num_runs = cfg.setting.real.num_runs
    deterministic_user_threshold_list = cfg.setting.real.deterministic_user_threshold_list
    num_data = cfg.setting.real.num_data


    result_df_list = []
    for deterministic_user_threshold in deterministic_user_threshold_list:
        if cfg.setting.real.reward_type_conversion == "continuous":
            dataset = RealSlateBanditDataset(
                n_unique_action=cfg.setting.real.n_unique_action,
                len_list=cfg.setting.real.len_list,
                dim_context=cfg.setting.real.dim_context,
                reward_type=cfg.setting.real.reward_type,
                reward_structure=cfg.setting.real.reward_structure,
                decay_function=cfg.setting.real.decay_function,
                base_reward_function=logistic_reward_function,
                base_reward_function_conversion=linear_reward_function,
                behavior_policy_function=linear_behavior_policy_logit,
                is_factorizable=cfg.setting.real.is_factorizable,
                random_state=cfg.setting.real.random_state,
                reward_type_conversion=cfg.setting.real.reward_type_conversion,
                reward_structure_conversion=cfg.setting.real.reward_structure_conversion,
                deterministic_user_threshold=deterministic_user_threshold,
                effect_from_ranking=cfg.setting.real.effect_from_ranking,
            )
        else: #binary
            dataset = RealSlateBanditDataset(
                n_unique_action=cfg.setting.real.n_unique_action,
                len_list=cfg.setting.real.len_list,
                dim_context=cfg.setting.real.dim_context,
                reward_type=cfg.setting.real.reward_type,
                reward_structure=cfg.setting.real.reward_structure,
                decay_function=cfg.setting.real.decay_function,
                base_reward_function=logistic_reward_function,
                base_reward_function_conversion=logistic_reward_function,
                behavior_policy_function=linear_behavior_policy_logit,
                is_factorizable=cfg.setting.real.is_factorizable,
                random_state=cfg.setting.real.random_state,
                reward_type_conversion=cfg.setting.real.reward_type_conversion,
                reward_structure_conversion=cfg.setting.real.reward_structure_conversion,
                deterministic_user_threshold=deterministic_user_threshold,
                effect_from_ranking=cfg.setting.real.effect_from_ranking,
            )

        #evaluation policy
        n_test = cfg.setting.real.n_test
        fixed_context = dataset.fixed_context
        user_idx = np.random.choice(fixed_context.shape[0], size=n_test)
        context = fixed_context[user_idx]
        
        if cfg.setting.real.evaluation_policy_logit == "linear_reward_function":
                evaluation_policy_logit = linear_reward_function(
                    context=context,
                    action_context=np.eye(cfg.setting.real.n_unique_action, dtype=int),
                    random_state=cfg.setting.real.random_state,
                )
        else:
            evaluation_policy_logit = linear_behavior_policy_logit(
                context=context,
                action_context=np.eye(cfg.setting.real.n_unique_action, dtype=int),
                random_state=cfg.setting.real.random_state,
                tau=cfg.setting.real.tau_pi_e,
            )
            
        pi_e_value = dataset.calc_ground_truth_policy_value_epsilon_greedy(
            context=context,
            evaluation_policy_logit_=evaluation_policy_logit,
            eps=cfg.setting.real.eps,
            user_idx=user_idx,
        )
        print("pi_e_value", pi_e_value)

        estimated_policy_value_list = []
        for _ in tqdm(range(num_runs), desc=f"deterministic_user_threshold={deterministic_user_threshold}..."):
            validation_bandit_data = dataset.obtain_batch_bandit_feedback(
                n_rounds=num_data,
                # clip_logit_value=700.0,
            )
            # print("expected_reward_factual", validation_bandit_data["expected_reward_factual"])
            # print("expected_reward_factual_click", validation_bandit_data["expected_reward_factual_click"])
            # print("expected_reward_factual_conversion", validation_bandit_data["expected_reward_factual_conversion"])
            # print("action", validation_bandit_data["action"])
            # print("pscore_item_position", validation_bandit_data["pscore_item_position"])
            # print("pscore_cascade", validation_bandit_data["pscore_cascade"])
            # print("pscore", validation_bandit_data["pscore"])
            
            if cfg.setting.real.evaluation_policy_logit == "linear_reward_function":
                evaluation_policy_logit = linear_reward_function(
                    context=validation_bandit_data["context"],
                    action_context=validation_bandit_data["action_context"],
                    random_state=cfg.setting.real.random_state,
                )
            else:
                evaluation_policy_logit = linear_behavior_policy_logit(
                    context=validation_bandit_data["context"],
                    action_context=validation_bandit_data["action_context"],
                    random_state=cfg.setting.real.random_state,
                    tau=cfg.setting.real.tau_pi_e,
                )

            (
                evaluation_policy_pscore, 
                evaluation_policy_pscore_item_position, 
                evaluation_policy_pscore_cascade,
                evaluation_policy_p_click, 
                p_click_pi_e,
            )  = dataset.obtain_pscore_given_evaluation_policy_logit_epsilon_greedy(
                context=validation_bandit_data["context"],
                action=validation_bandit_data["action"],
                evaluation_policy_logit_=evaluation_policy_logit,
                eps=cfg.setting.real.eps,
            )
            
            #obtain regression model
            click_probability_true = validation_bandit_data["expected_reward_factual_click"] 
            ################################################
            reg_model = RegressionModel(
                n_actions=cfg.setting.real.n_unique_action, 
                base_model=MLPRegressor(hidden_layer_sizes=(30,30,30), max_iter=3000,early_stopping=True,random_state=cfg.setting.real.random_state),
            )
            mask = (validation_bandit_data["reward_click"]==1)
            reg_model.fit(
                context=np.repeat(validation_bandit_data["context"], dataset.len_list, axis=0)[mask], # context; x
                action=validation_bandit_data["action"][mask], # action; a
                reward=validation_bandit_data["reward"][mask], # reward; r
            )
            # estimated_conversion (n_rounds*len_list, n_unique_actions, 1)
            estimated_conversion = reg_model.predict(
                context=np.repeat(validation_bandit_data["context"], dataset.len_list, axis=0)
            )
            estimated_conversion_for_dm_term = reg_model.predict(
                context=validation_bandit_data["context"]
            )[:,:,0]
            # print("estimated_conversion", estimated_conversion.shape)
            estimated_conversion_factual = estimated_conversion[np.arange(dataset.len_list*validation_bandit_data["context"].shape[0]),validation_bandit_data["action"],0]
            # print(estimated_conversion_factual.shape)
            estimated_CR_factual = click_probability_true * estimated_conversion_factual #true_click * estimated conversion
            ################################################
            ################################################
            #estimate click probability
            click_model=MLPClassifier(hidden_layer_sizes=(30,30,30), max_iter=3000,early_stopping=True,random_state=cfg.setting.real.random_state)
            X_train = np.concatenate([validation_bandit_data["context"], validation_bandit_data["action"].reshape(-1,dataset.len_list)], axis=1)
            y_train = validation_bandit_data["reward_click"].reshape(-1,dataset.len_list)
            click_model.fit(
                X=X_train, 
                y=y_train, 
            )
            
            (
                estimated_behavior_policy_p_click, 
                estimated_evaluation_policy_p_click,
                p_click_pi_e_by_click_model, #p_c(x,a,pi_e)
            )  = dataset.obtain_p_click_pi_given_estimated_click_probability(
                        context=validation_bandit_data["context"],
                        action=validation_bandit_data["action"],
                        click_model=click_model,
                        evaluation_policy_logit_type=cfg.setting.real.evaluation_policy_logit,
                        eps=cfg.setting.real.eps,
                        tau=cfg.setting.real.tau_pi_e,
                )
            click_probability_factual_by_click_model = click_model.predict_proba(X_train).reshape(validation_bandit_data["action"].shape[0])
            estimated_CR_factual_by_click_model = click_probability_factual_by_click_model * estimated_conversion_factual #true_click * estimated conversion

            dm_term = (p_click_pi_e*estimated_conversion_for_dm_term).sum()
            dm_term_by_click_model = (p_click_pi_e_by_click_model*estimated_conversion_for_dm_term).sum()
            ################################################

            ope = OffPolicyEvaluation(
                bandit_feedback=validation_bandit_data,
                ope_estimators=[
                        IPS(estimator_name="IPS", len_list=cfg.setting.real.len_list), 
                        IIPS(estimator_name="IIPS", len_list=cfg.setting.real.len_list),  
                        RIPS(estimator_name="RIPS", len_list=cfg.setting.real.len_list),
                        CIPS(estimator_name="CIPS", len_list=cfg.setting.real.len_list),
                        CDR(estimator_name="CDR", len_list=cfg.setting.real.len_list),
                        CIPS(estimator_name="CIPS (estimate)", len_list=cfg.setting.real.len_list, use_estimated_click_model=True),
                        CDR(estimator_name="CDR (estimate)", len_list=cfg.setting.real.len_list, use_estimated_click_model=True),
                    ]
            )

            estimated_policy_values = ope.estimate_policy_values(
                evaluation_policy_pscore=evaluation_policy_pscore,
                evaluation_policy_pscore_item_position=evaluation_policy_pscore_item_position,
                evaluation_policy_pscore_cascade=evaluation_policy_pscore_cascade,
                evaluation_policy_p_click=evaluation_policy_p_click, #pc(x,a,pi)
                behavior_policy_p_click=validation_bandit_data["p_click_factual_pi_0"], #pc(x,a,pi_0)
                estimated_conversion_factual=estimated_conversion_factual, #p_r
                q_hat=estimated_CR_factual, # q_hat
                estimated_behavior_policy_p_click= estimated_behavior_policy_p_click,
                estimated_evaluation_policy_p_click=estimated_evaluation_policy_p_click,
                q_hat_by_estimated_click_model=estimated_CR_factual_by_click_model,
                dm_term=dm_term,
                dm_term_by_click_model=dm_term_by_click_model,
            )
            estimated_policy_value_list.append(estimated_policy_values)
        
        #summarize result
        result_df = (
            DataFrame(DataFrame(estimated_policy_value_list).stack())
            .reset_index(1)
            .rename(columns={"level_1": "est", 0: "value"})
        )
        result_df["deterministic_user_threshold"] = deterministic_user_threshold
        result_df["pi_e_value"] = pi_e_value
        result_df["se"] = (result_df.value - pi_e_value) ** 2
        result_df["bias"] = 0.0
        result_df["variance"] = 0.0

        sample_mean = DataFrame(result_df.groupby(["est"]).mean().value).reset_index()
        for est_ in sample_mean["est"]:
            estimates = result_df.loc[result_df["est"] == est_, "value"].values
            mean_estimates = sample_mean.loc[sample_mean["est"] == est_, "value"].values
            mean_estimates = np.ones_like(estimates) * mean_estimates
            result_df.loc[result_df["est"] == est_, "bias"] = (
                pi_e_value - mean_estimates
            ) ** 2
            result_df.loc[result_df["est"] == est_, "variance"] = (
                estimates - mean_estimates
            ) ** 2
        result_df_list.append(result_df)
        print("max_iw", (evaluation_policy_pscore/ validation_bandit_data["pscore"]).max())
        print("max_iw_CIPS", (evaluation_policy_p_click/ validation_bandit_data["p_click_factual_pi_0"]).max())

        result_df = pd.concat(result_df_list).reset_index(level=0)
        result_df.to_csv("deterministic_user_threshold.csv")
        plot(vary_list=deterministic_user_threshold_list, result_df=result_df, variable_name="deterministic_user_threshold")
        plot_normalize(vary_list=deterministic_user_threshold_list, result_df=result_df, variable_name="deterministic_user_threshold")
        tqdm.write("=====" * 15)
    
    result_df = pd.concat(result_df_list).reset_index(level=0)
    result_df.to_csv("deterministic_user_threshold.csv")

    plot(vary_list=deterministic_user_threshold_list, result_df=result_df, variable_name="deterministic_user_threshold")
    plot_normalize(vary_list=deterministic_user_threshold_list, result_df=result_df, variable_name="deterministic_user_threshold")

if __name__ == "__main__":
    main()
