from omegaconf import DictConfig, OmegaConf
import hydra
import os

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from pandas import DataFrame
from tqdm import tqdm
import seaborn as sns
from sklearn.neural_network import MLPRegressor
from sklearn.neural_network import MLPClassifier

import obp
from obp.dataset import(
    linear_reward_function,
    logistic_reward_function,
    linear_behavior_policy,
)

from obp.ope import(
    # SlateOffPolicyEvaluation,
    RegressionModel,
    SlateStandardIPS as IPS,
    SlateIndependentIPS as IIPS,
    SlateRewardInteractionIPS as RIPS,
)

from dataset_real import RealSlateBanditDataset
from dataset_real import linear_behavior_policy_logit
from estimator import(
    ClickBasedIPS as CIPS,
    ClickBasedDR as CDR,
) 
from ope import OffPolicyEvaluation
from plot import(
    plot,
    plot_normalize,
)


def _build_dataset(real_cfg, len_list, deterministic_user_threshold):
    """Build the semi-synthetic slate dataset for one slate length.

    The two ``reward_type_conversion`` settings differ only in
    ``base_reward_function_conversion``: ``linear_reward_function`` for
    ``"continuous"`` and ``logistic_reward_function`` otherwise (binary).

    Args:
        real_cfg: The ``cfg.setting.real`` config node.
        len_list: Number of slots per slate.
        deterministic_user_threshold: Forwarded unchanged to
            ``RealSlateBanditDataset`` (``main`` converts the strings
            ``"inf"``/``"-inf"`` to floats before calling).

    Returns:
        A ``RealSlateBanditDataset`` instance.
    """
    if real_cfg.reward_type_conversion == "continuous":
        conversion_function = linear_reward_function
    else:  # binary
        conversion_function = logistic_reward_function
    return RealSlateBanditDataset(
        n_unique_action=real_cfg.n_unique_action,
        len_list=len_list,
        dim_context=real_cfg.dim_context,
        reward_type=real_cfg.reward_type,
        reward_structure=real_cfg.reward_structure,
        decay_function=real_cfg.decay_function,
        base_reward_function=logistic_reward_function,
        base_reward_function_conversion=conversion_function,
        behavior_policy_function=linear_behavior_policy_logit,
        is_factorizable=real_cfg.is_factorizable,
        random_state=real_cfg.random_state,
        reward_type_conversion=real_cfg.reward_type_conversion,
        reward_structure_conversion=real_cfg.reward_structure_conversion,
        deterministic_user_threshold=deterministic_user_threshold,
        effect_from_ranking=real_cfg.effect_from_ranking,
    )


def _evaluation_policy_logit(real_cfg, context, action_context):
    """Return evaluation-policy logits for ``context`` x ``action_context``.

    Uses ``linear_reward_function`` when ``real_cfg.evaluation_policy_logit``
    is ``"linear_reward_function"``; otherwise ``linear_behavior_policy_logit``
    with ``tau=real_cfg.tau_pi_e``. Both are called with
    ``real_cfg.random_state``, so the ground-truth computation and every run
    presumably share the same policy parameters (confirm in ``dataset_real``).
    """
    if real_cfg.evaluation_policy_logit == "linear_reward_function":
        return linear_reward_function(
            context=context,
            action_context=action_context,
            random_state=real_cfg.random_state,
        )
    return linear_behavior_policy_logit(
        context=context,
        action_context=action_context,
        random_state=real_cfg.random_state,
        tau=real_cfg.tau_pi_e,
    )


def _summarize_estimates(estimated_policy_value_list, len_list, pi_e_value):
    """Turn per-run estimator outputs into a long-format result frame.

    Args:
        estimated_policy_value_list: One entry per run, presumably a mapping
            of estimator name -> estimated policy value (it is stacked so the
            keys become the ``est`` column).
        len_list: Slate length these runs belong to.
        pi_e_value: Ground-truth policy value of the evaluation policy.

    Returns:
        DataFrame with columns ``est, value, len_list, pi_e_value, se, bias,
        variance``. ``se`` is each run's squared error; ``bias`` is the squared
        bias of the estimator's mean over runs (repeated on each of its rows);
        ``variance`` is each run's squared deviation from that mean.
    """
    result_df = (
        DataFrame(DataFrame(estimated_policy_value_list).stack())
        .reset_index(1)
        .rename(columns={"level_1": "est", 0: "value"})
    )
    result_df["len_list"] = len_list
    result_df["pi_e_value"] = pi_e_value
    result_df["se"] = (result_df["value"] - pi_e_value) ** 2
    # Per-row mean of the row's estimator; .to_numpy() avoids index alignment
    # (the run index is duplicated across estimators).
    mean_estimate = result_df.groupby("est")["value"].transform("mean").to_numpy()
    result_df["bias"] = (pi_e_value - mean_estimate) ** 2
    result_df["variance"] = (result_df["value"].to_numpy() - mean_estimate) ** 2
    return result_df


def _save_and_plot(result_df_list, len_list_list):
    """Concatenate all results so far, write ``len_list.csv`` to the current
    working directory, and redraw both plots."""
    result_df = pd.concat(result_df_list).reset_index(level=0)
    result_df.to_csv("len_list.csv")
    plot(vary_list=len_list_list, result_df=result_df, variable_name="len_list")
    plot_normalize(vary_list=len_list_list, result_df=result_df, variable_name="len_list")


@hydra.main(config_path="../conf", config_name="config", version_base="1.1")
def main(cfg: DictConfig) -> None:
    """Run the slate OPE experiment for every slate length in the config.

    For each ``len_list`` in ``cfg.setting.real.len_list_list``:

    1. Build the dataset.
    2. Compute the ground-truth value of the epsilon-greedy evaluation policy
       on ``n_test`` sampled users.
    3. For ``num_runs`` repetitions:
       - draw ``num_data`` logged rounds;
       - fit a conversion model (on clicked items) and a click model;
       - evaluate IPS/IIPS/RIPS and CIPS/CDR, both with and without
         ``use_estimated_click_model``.

    Results are checkpointed to ``len_list.csv`` and plotted after every slate
    length, so the last iteration leaves the complete output.

    Args:
        cfg: Hydra config; only the ``setting.real`` node is read.
    """
    # The config may carry infinite thresholds as the strings "-inf"/"inf".
    if cfg.setting.real.deterministic_user_threshold == "-inf":
        cfg.setting.real.deterministic_user_threshold = -np.inf
    if cfg.setting.real.deterministic_user_threshold == "inf":
        cfg.setting.real.deterministic_user_threshold = np.inf

    real_cfg = cfg.setting.real
    random_state = real_cfg.random_state
    np.random.seed(random_state)
    num_runs = real_cfg.num_runs
    len_list_list = real_cfg.len_list_list
    num_data = real_cfg.num_data
    deterministic_user_threshold = real_cfg.deterministic_user_threshold
    n_unique_action = real_cfg.n_unique_action
    eps = real_cfg.eps

    result_df_list = []
    for len_list in len_list_list:
        dataset = _build_dataset(real_cfg, len_list, deterministic_user_threshold)

        # Ground truth: evaluation policy value on n_test users drawn (with the
        # global RNG) from the dataset's fixed contexts.
        fixed_context = dataset.fixed_context
        user_idx = np.random.choice(fixed_context.shape[0], size=real_cfg.n_test)
        context = fixed_context[user_idx]
        evaluation_policy_logit = _evaluation_policy_logit(
            real_cfg,
            context=context,
            action_context=np.eye(n_unique_action, dtype=int),
        )
        pi_e_value = dataset.calc_ground_truth_policy_value_epsilon_greedy(
            context=context,
            evaluation_policy_logit_=evaluation_policy_logit,
            eps=eps,
            user_idx=user_idx,
        )
        print("pi_e_value", pi_e_value)

        estimated_policy_value_list = []
        for _ in tqdm(range(num_runs), desc=f"len_list={len_list}..."):
            bandit_data = dataset.obtain_batch_bandit_feedback(n_rounds=num_data)
            round_context = bandit_data["context"]
            action = bandit_data["action"]
            # Per-slot features: each round's context repeated len_list times.
            item_context = np.repeat(round_context, dataset.len_list, axis=0)

            evaluation_policy_logit = _evaluation_policy_logit(
                real_cfg,
                context=round_context,
                action_context=bandit_data["action_context"],
            )
            (
                evaluation_policy_pscore,
                evaluation_policy_pscore_item_position,
                evaluation_policy_pscore_cascade,
                evaluation_policy_p_click,
                p_click_pi_e,
            ) = dataset.obtain_pscore_given_evaluation_policy_logit_epsilon_greedy(
                context=round_context,
                action=action,
                evaluation_policy_logit_=evaluation_policy_logit,
                eps=eps,
            )

            # Conversion model, trained only on slots that were clicked.
            reg_model = RegressionModel(
                n_actions=n_unique_action,
                base_model=MLPRegressor(
                    hidden_layer_sizes=(30, 30, 30),
                    max_iter=3000,
                    early_stopping=True,
                    random_state=random_state,
                ),
            )
            clicked = bandit_data["reward_click"] == 1
            reg_model.fit(
                context=item_context[clicked],
                action=action[clicked],
                reward=bandit_data["reward"][clicked],
            )
            # estimated_conversion (n_rounds*len_list, n_unique_actions, 1)
            estimated_conversion = reg_model.predict(context=item_context)
            estimated_conversion_for_dm_term = reg_model.predict(
                context=round_context
            )[:, :, 0]
            estimated_conversion_factual = estimated_conversion[
                np.arange(dataset.len_list * round_context.shape[0]), action, 0
            ]
            # q_hat with the true click probability.
            estimated_CR_factual = (
                bandit_data["expected_reward_factual_click"] * estimated_conversion_factual
            )

            # Click model: one multi-output classifier over all slots of a slate.
            click_model = MLPClassifier(
                hidden_layer_sizes=(30, 30, 30),
                max_iter=3000,
                early_stopping=True,
                random_state=random_state,
            )
            X_train = np.concatenate(
                [round_context, action.reshape(-1, dataset.len_list)], axis=1
            )
            y_train = bandit_data["reward_click"].reshape(-1, dataset.len_list)
            click_model.fit(X=X_train, y=y_train)
            (
                estimated_behavior_policy_p_click,
                estimated_evaluation_policy_p_click,
                p_click_pi_e_by_click_model,  # p_c(x, a, pi_e)
            ) = dataset.obtain_p_click_pi_given_estimated_click_probability(
                context=round_context,
                action=action,
                click_model=click_model,
                evaluation_policy_logit_type=real_cfg.evaluation_policy_logit,
                eps=eps,
                tau=real_cfg.tau_pi_e,
            )
            click_probability_factual_by_click_model = click_model.predict_proba(
                X_train
            ).reshape(action.shape[0])
            # q_hat with the estimated click probability.
            estimated_CR_factual_by_click_model = (
                click_probability_factual_by_click_model * estimated_conversion_factual
            )

            dm_term = (p_click_pi_e * estimated_conversion_for_dm_term).sum()
            dm_term_by_click_model = (
                p_click_pi_e_by_click_model * estimated_conversion_for_dm_term
            ).sum()

            ope = OffPolicyEvaluation(
                bandit_feedback=bandit_data,
                ope_estimators=[
                    IPS(estimator_name="IPS", len_list=len_list),
                    IIPS(estimator_name="IIPS", len_list=len_list),
                    RIPS(estimator_name="RIPS", len_list=len_list),
                    CIPS(estimator_name="CIPS", len_list=len_list),
                    CDR(estimator_name="CDR", len_list=len_list),
                    CIPS(estimator_name="CIPS (estimate)", len_list=len_list, use_estimated_click_model=True),
                    CDR(estimator_name="CDR (estimate)", len_list=len_list, use_estimated_click_model=True),
                ],
            )
            estimated_policy_value_list.append(
                ope.estimate_policy_values(
                    evaluation_policy_pscore=evaluation_policy_pscore,
                    evaluation_policy_pscore_item_position=evaluation_policy_pscore_item_position,
                    evaluation_policy_pscore_cascade=evaluation_policy_pscore_cascade,
                    evaluation_policy_p_click=evaluation_policy_p_click,  # pc(x,a,pi)
                    behavior_policy_p_click=bandit_data["p_click_factual_pi_0"],  # pc(x,a,pi_0)
                    estimated_conversion_factual=estimated_conversion_factual,  # p_r
                    q_hat=estimated_CR_factual,
                    estimated_behavior_policy_p_click=estimated_behavior_policy_p_click,
                    estimated_evaluation_policy_p_click=estimated_evaluation_policy_p_click,
                    q_hat_by_estimated_click_model=estimated_CR_factual_by_click_model,
                    dm_term=dm_term,
                    dm_term_by_click_model=dm_term_by_click_model,
                )
            )

        result_df_list.append(
            _summarize_estimates(estimated_policy_value_list, len_list, pi_e_value)
        )
        # Importance-weight diagnostics, taken from the last run only.
        print("max_iw", (evaluation_policy_pscore / bandit_data["pscore"]).max())
        print(
            "max_iw_CIPS",
            (evaluation_policy_p_click / bandit_data["p_click_factual_pi_0"]).max(),
        )
        # Checkpoint after every slate length so partial results survive an
        # interrupted sweep; the final iteration writes the complete output,
        # so no separate post-loop write is needed.
        _save_and_plot(result_df_list, len_list_list)

        tqdm.write("=====" * 15)


if __name__ == "__main__":
    # cfg is supplied by the @hydra.main decorator.
    main()