from omegaconf import DictConfig, OmegaConf
import hydra
import os

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from pandas import DataFrame
from tqdm import tqdm
import seaborn as sns
from sklearn.neural_network import MLPRegressor
from sklearn.neural_network import MLPClassifier

import obp
from obp.dataset import(
    linear_reward_function,
    logistic_reward_function,
    linear_behavior_policy,
)

from obp.ope import(
    # SlateOffPolicyEvaluation,
    RegressionModel,
    SlateStandardIPS as IPS,
    SlateIndependentIPS as IIPS,
    SlateRewardInteractionIPS as RIPS,
)

from dataset_real_modify import RealSlateBanditDataset
from dataset_real_modify import linear_behavior_policy_logit
from estimator import(
    ClickBasedIPS as CIPS,
    ClickBasedDR as CDR,
) 
from ope import OffPolicyEvaluation
from plot import(
    plot,
    plot_normalize,
)

@hydra.main(config_path="../conf",config_name="config", version_base="1.1")
def main(cfg: DictConfig) -> None:
    """Run the experiment that sweeps ``estimation_noise`` and compares OPE estimators.

    All settings are read from ``cfg.setting.real``. For every value in
    ``estimation_noise_list`` the function repeats ``num_runs`` times:

    1. draw ``num_data`` rounds of logged feedback from the dataset,
    2. compute the evaluation-policy scores for the logged slates,
    3. build a conversion estimate by adding Gaussian noise (std =
       ``estimation_noise``) to ``dataset.expected_reward_conversion``,
    4. fit an ``MLPClassifier`` click model on the logged clicks,
    5. run the IPS / IIPS / RIPS / CIPS / CDR estimators.

    After each noise level the per-run estimates are summarised (squared
    error, squared bias, squared deviation from the mean), the accumulated
    results are written to ``estimation_noise.csv`` and passed to ``plot`` and
    ``plot_normalize``. The same export is repeated once after the sweep.

    Args:
        cfg: Configuration object supplied by the ``hydra.main`` decorator.
    """
    real_cfg = cfg.setting.real

    # The config spells negative infinity as the string "-inf"; swap in the float.
    if real_cfg.deterministic_user_threshold == "-inf":
        real_cfg.deterministic_user_threshold = -np.inf
    np.random.seed(real_cfg.random_state)
    num_runs = real_cfg.num_runs
    num_data = real_cfg.num_data
    estimation_noise_list = real_cfg.estimation_noise_list

    # Only the conversion reward function depends on the conversion reward type:
    # linear for "continuous", logistic for anything else (i.e. binary).
    if real_cfg.reward_type_conversion == "continuous":
        conversion_reward_function = linear_reward_function
    else:
        conversion_reward_function = logistic_reward_function

    dataset = RealSlateBanditDataset(
        n_unique_action=real_cfg.n_unique_action,
        len_list=real_cfg.len_list,
        dim_context=real_cfg.dim_context,
        reward_type=real_cfg.reward_type,
        reward_structure=real_cfg.reward_structure,
        decay_function=real_cfg.decay_function,
        base_reward_function=logistic_reward_function,
        base_reward_function_conversion=conversion_reward_function,
        behavior_policy_function=linear_behavior_policy_logit,
        is_factorizable=real_cfg.is_factorizable,
        random_state=real_cfg.random_state,
        reward_type_conversion=real_cfg.reward_type_conversion,
        reward_structure_conversion=real_cfg.reward_structure_conversion,
        deterministic_user_threshold=real_cfg.deterministic_user_threshold,
        effect_from_ranking=real_cfg.effect_from_ranking,
    )

    def build_evaluation_policy_logit(context, action_context):
        """Return the evaluation-policy logits selected by the config."""
        if real_cfg.evaluation_policy_logit == "linear_reward_function":
            return linear_reward_function(
                context=context,
                action_context=action_context,
                random_state=real_cfg.random_state,
            )
        return linear_behavior_policy_logit(
            context=context,
            action_context=action_context,
            random_state=real_cfg.random_state,
            tau=real_cfg.tau_pi_e,
        )

    # Accumulates one summary frame per noise level.
    result_df_list = []

    def export_results():
        """Write every summary collected so far to CSV and redraw both plots."""
        merged_df = pd.concat(result_df_list).reset_index(level=0)
        merged_df.to_csv("estimation_noise.csv")
        plot(vary_list=estimation_noise_list, result_df=merged_df, variable_name="estimation_noise")
        plot_normalize(vary_list=estimation_noise_list, result_df=merged_df, variable_name="estimation_noise")

    # Value of the evaluation policy on n_test contexts sampled (with
    # replacement) from the dataset's fixed contexts.
    n_test = real_cfg.n_test
    fixed_context = dataset.fixed_context
    user_idx = np.random.choice(fixed_context.shape[0], size=n_test)
    context = fixed_context[user_idx]

    evaluation_policy_logit = build_evaluation_policy_logit(
        context=context,
        action_context=np.eye(real_cfg.n_unique_action, dtype=int),
    )
    pi_e_value = dataset.calc_ground_truth_policy_value_epsilon_greedy(
        context=context,
        evaluation_policy_logit_=evaluation_policy_logit,
        eps=real_cfg.eps,
        user_idx=user_idx,
    )
    print("pi_e_value", pi_e_value)

    for estimation_noise in estimation_noise_list:
        estimated_policy_value_list = []
        for _ in tqdm(range(num_runs), desc=f"estimation_noise={estimation_noise}..."):
            feedback = dataset.obtain_batch_bandit_feedback(n_rounds=num_data)

            evaluation_policy_logit = build_evaluation_policy_logit(
                context=feedback["context"],
                action_context=feedback["action_context"],
            )
            (
                evaluation_policy_pscore,
                evaluation_policy_pscore_item_position,
                evaluation_policy_pscore_cascade,
                evaluation_policy_p_click,
                p_click_pi_e,
            ) = dataset.obtain_pscore_given_evaluation_policy_logit_epsilon_greedy(
                context=feedback["context"],
                action=feedback["action"],
                evaluation_policy_logit_=evaluation_policy_logit,
                eps=real_cfg.eps,
            )

            # Conversion estimate: no regression model is fitted here; the
            # dataset's expected conversion reward is perturbed with Gaussian
            # noise whose standard deviation is the swept value.
            true_conversion = dataset.expected_reward_conversion
            conversion_noise = np.random.normal(
                loc=0, scale=estimation_noise, size=true_conversion.shape
            )
            estimated_conversion_for_dm_term = true_conversion + conversion_noise

            # Pick the estimate of the logged action for every (round, slot):
            # each round index is repeated len_list times to pair with "action".
            round_index = np.repeat(np.arange(num_data), dataset.len_list, axis=0)
            estimated_conversion_factual = estimated_conversion_for_dm_term[
                round_index, feedback["action"]
            ]

            # q_hat from the dataset-provided click probability, p_c(x,a_A).
            click_probability_true = feedback["expected_reward_factual_click"]
            estimated_CR_factual = click_probability_true * estimated_conversion_factual

            # Click model: features are the context joined with the slate's
            # actions (one row per round); targets are the per-slot clicks.
            click_model = MLPClassifier(
                hidden_layer_sizes=(30,30,30),
                max_iter=3000,
                early_stopping=True,
                random_state=real_cfg.random_state,
            )
            slate_actions = feedback["action"].reshape(-1, dataset.len_list)
            X_train = np.concatenate([feedback["context"], slate_actions], axis=1)
            y_train = feedback["reward_click"].reshape(-1, dataset.len_list)
            click_model.fit(X=X_train, y=y_train)

            (
                estimated_behavior_policy_p_click,
                estimated_evaluation_policy_p_click,
                p_click_pi_e_by_click_model,  # p_c(x,a,pi_e)
            ) = dataset.obtain_p_click_pi_given_estimated_click_probability(
                context=feedback["context"],
                action=feedback["action"],
                click_model=click_model,
                evaluation_policy_logit_type=real_cfg.evaluation_policy_logit,
                eps=real_cfg.eps,
                tau=real_cfg.tau_pi_e,
            )

            # q_hat from the fitted click model: its predicted probabilities are
            # flattened to one entry per element of feedback["action"].
            n_slots = feedback["action"].shape[0]
            click_probability_factual_by_click_model = click_model.predict_proba(X_train).reshape(n_slots)
            estimated_CR_factual_by_click_model = (
                click_probability_factual_by_click_model * estimated_conversion_factual
            )

            # Direct-method terms handed to the estimators as-is.
            dm_term = (p_click_pi_e * estimated_conversion_for_dm_term).sum()
            dm_term_by_click_model = (
                p_click_pi_e_by_click_model * estimated_conversion_for_dm_term
            ).sum()

            len_list = real_cfg.len_list
            ope = OffPolicyEvaluation(
                bandit_feedback=feedback,
                ope_estimators=[
                    IPS(estimator_name="IPS", len_list=len_list),
                    IIPS(estimator_name="IIPS", len_list=len_list),
                    RIPS(estimator_name="RIPS", len_list=len_list),
                    CIPS(estimator_name="CIPS", len_list=len_list),
                    CDR(estimator_name="CDR", len_list=len_list),
                    CIPS(estimator_name="CIPS (estimate)", len_list=len_list, use_estimated_click_model=True),
                    CDR(estimator_name="CDR (estimate)", len_list=len_list, use_estimated_click_model=True),
                ],
            )
            estimated_policy_values = ope.estimate_policy_values(
                evaluation_policy_pscore=evaluation_policy_pscore,
                evaluation_policy_pscore_item_position=evaluation_policy_pscore_item_position,
                evaluation_policy_pscore_cascade=evaluation_policy_pscore_cascade,
                evaluation_policy_p_click=evaluation_policy_p_click,  # pc(x,a,pi)
                behavior_policy_p_click=feedback["p_click_factual_pi_0"],  # pc(x,a,pi_0)
                estimated_conversion_factual=estimated_conversion_factual,  # p_r
                q_hat=estimated_CR_factual,
                estimated_behavior_policy_p_click=estimated_behavior_policy_p_click,
                estimated_evaluation_policy_p_click=estimated_evaluation_policy_p_click,
                q_hat_by_estimated_click_model=estimated_CR_factual_by_click_model,
                dm_term=dm_term,
                dm_term_by_click_model=dm_term_by_click_model,
            )
            estimated_policy_value_list.append(estimated_policy_values)

        # Long format: one row per (run, estimator) with the estimate in "value".
        per_run_estimates = DataFrame(estimated_policy_value_list).stack()
        result_df = (
            DataFrame(per_run_estimates)
            .reset_index(1)
            .rename(columns={"level_1": "est", 0: "value"})
        )
        result_df["estimation_noise"] = estimation_noise
        result_df["pi_e_value"] = pi_e_value
        result_df["se"] = (result_df.value - pi_e_value) ** 2
        result_df["bias"] = 0.0
        result_df["variance"] = 0.0

        # Per estimator: squared gap between pi_e_value and the mean estimate
        # ("bias") and squared deviation of each run from that mean ("variance").
        mean_by_estimator = result_df.groupby(["est"]).mean()["value"]
        for est_name, mean_value in mean_by_estimator.items():
            est_rows = result_df["est"] == est_name
            estimates = result_df.loc[est_rows, "value"].values
            mean_estimates = np.full(estimates.shape, mean_value)
            result_df.loc[est_rows, "bias"] = (pi_e_value - mean_estimates) ** 2
            result_df.loc[est_rows, "variance"] = (estimates - mean_estimates) ** 2
        result_df_list.append(result_df)

        # Largest importance weights, taken from the last run of this noise level.
        print("max_iw", (evaluation_policy_pscore/ feedback["pscore"]).max())
        print("max_iw_CIPS", (evaluation_policy_p_click/ feedback["p_click_factual_pi_0"]).max())

        # Export after every noise level so partial results are on disk.
        export_results()

        tqdm.write("=====" * 15)

    # Final export once the whole sweep has finished.
    export_results()

# Script entry point: main() is called with no arguments; the cfg parameter is
# expected to be supplied by the @hydra.main decorator applied to it.
if __name__ == "__main__":
    main()
