{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "313d98f5-34e0-4e5a-9211-4f373427d6d8",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "from collections import defaultdict\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "from sklearn.ensemble import RandomForestRegressor\n",
    "from sklearn.metrics import mean_squared_error\n",
    "\n",
    "from typing import List, Tuple, Dict\n",
    "\n",
    "from scipy import stats\n",
    "from densratio import densratio\n",
    "from sklearn.model_selection import KFold\n",
    "from sklearn.model_selection import StratifiedKFold\n",
    "import copy\n",
    "import json\n",
    "from sklearn.preprocessing import LabelEncoder\n",
    "from sklearn.model_selection import GridSearchCV\n",
    "\n",
    "import warnings\n",
    "warnings.filterwarnings('ignore')\n",
    "import torch\n",
    "from pandas import DataFrame\n",
    "from sklearn.utils import check_random_state\n",
    "from policylearners import GradientBasedPolicyLearner, GradientBasedPolicyLearnerMDOPE\n",
    "from utils import softmax"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "363fd077-2f5a-4f3e-9e65-fb8ce2df0c16",
   "metadata": {},
   "outputs": [],
   "source": [
    "# If you are running locally, make sure you are in the directory of KuaiRec.\n",
    "rootpath=\"../../../\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "30ca6e46-a1af-40e4-8c67-62031520839a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Interaction matrix (user_id, video_id, watch_ratio, ...) and per-user side features.\n",
    "print(\"Loading small matrix...\")\n",
    "small_matrix = pd.read_csv(f\"{rootpath}data/small_matrix.csv\")\n",
    "\n",
    "print(\"Loading user features...\")\n",
    "user_features = pd.read_csv(f\"{rootpath}data/user_features.csv\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a72851ef-3f16-4fb7-bf64-7cd4e80d5cd1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep only the columns this experiment uses; `watch_ratio` is the reward signal.\n",
    "df = small_matrix.loc[:, [\"user_id\", \"video_id\", \"watch_ratio\"]].copy()\n",
    "df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4c05bbd8-f6d2-4cf2-bc29-d77165cd15ef",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "SEED = 1111\n",
    "rng = np.random.default_rng(SEED)\n",
    "rng"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e086eecf-162e-4c31-afcb-a88bd14bd996",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "num_action = 30 \n",
    "domain_cluster_num = 5 \n",
    "td_cluster_num = 4"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dd2bc0e5-93f3-4ca7-abb4-0a49260c23bf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Label-encode the categorical user columns as integers.\n",
    "cat_col_name = [\"user_active_degree\", \"follow_user_num_range\", \"fans_user_num_range\",\n",
    "                \"friend_user_num_range\", \"register_days_range\"]\n",
    "for col in cat_col_name:\n",
    "    user_features[col] = LabelEncoder().fit_transform(user_features[col].values)\n",
    "user_features"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6d7442ac-f517-46cf-b9c2-81fbfefb2666",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One-hot columns excluded from the context (the name suggests they are null/empty -- TODO confirm).\n",
    "user_feature_null = [\"onehot_feat4\", \"onehot_feat12\", \"onehot_feat13\", \"onehot_feat14\",\n",
    "                     \"onehot_feat15\", \"onehot_feat16\", \"onehot_feat17\"]\n",
    "user_features = user_features.drop(columns=user_feature_null)\n",
    "user_features"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4e3b431f-4018-4404-8415-64ce60db34b4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One-hot encode every column except the id and the raw numeric columns.\n",
    "numeric_or_id_cols = [\"user_id\", \"follow_user_num\", \"fans_user_num\", \"friend_user_num\", \"register_days\"]\n",
    "onehot_cols = [c for c in user_features.columns if c not in numeric_or_id_cols]\n",
    "user_features_onehot = pd.get_dummies(user_features, columns=onehot_cols, dtype=int)\n",
    "user_features_onehot"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b5a186f6-dbe0-4b9b-aef6-3e17e6e4b7b8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# For every video, the number of distinct users with a logged interaction.\n",
    "df_count = df.groupby(\"video_id\").agg({\"user_id\": \"nunique\"})\n",
    "df_count"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "87e90588-0dd0-4e69-89b1-e578a6f4aae8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Videos that every user in `df` has interacted with (fully observed columns).\n",
    "num_users_in_df = df[\"user_id\"].nunique()\n",
    "video_id_all_user = df_count.index[df_count[\"user_id\"].eq(num_users_in_df).to_numpy()].values\n",
    "video_id_all_user"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e6b8ce83-1f59-4441-b2b7-44a9325d21fd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Restrict the log to the fully observed videos.\n",
    "df = df.loc[df[\"video_id\"].isin(video_id_all_user)]\n",
    "df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bd485854-6d0d-4bb1-ad13-3ec0001695fc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Working copy that the action-subsampling cells below modify.\n",
    "df_base = df.copy(deep=True)\n",
    "df_base"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f8d0860e-0bee-4a43-8dc2-50ba6a7b3b42",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Randomly pick `num_action` distinct videos to serve as the action set.\n",
    "candidate_videos = df_base[\"video_id\"].unique()\n",
    "use_action = rng.choice(candidate_videos, size=num_action, replace=False, shuffle=False)\n",
    "use_action"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "753e4ca7-fd77-40fe-943b-cb0e2a3d93fd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep only the sampled actions, ordered by user and then by (original) video id.\n",
    "df_base = (\n",
    "    df_base[df_base[\"video_id\"].isin(use_action)]\n",
    "    .sort_values([\"user_id\", \"video_id\"])\n",
    "    .reset_index(drop=True)\n",
    ")\n",
    "df_base"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1b665efa-c3c2-4fad-b036-bcc8d6056ca9",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Original video id -> contiguous action index 0..num_action-1 (in sampling order).\n",
    "reindex_action = dict(zip(use_action, range(len(use_action))))\n",
    "reindex_action"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6508a273-db87-4095-848a-5285758b2d8b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Replace original video ids by action indices, then re-sort by the new ids.\n",
    "# sort_values returns a new frame, so the result must be assigned back.\n",
    "df_base[\"video_id\"] = df_base[\"video_id\"].map(reindex_action)\n",
    "assert df_base[\"video_id\"].notna().all(), \"every remaining video must be in reindex_action\"\n",
    "df_base = df_base.sort_values([\"user_id\", \"video_id\"]).reset_index(drop=True)\n",
    "df_base"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bb8f406c-d76a-4738-9a71-01fa8fb0925d",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Columns of the per-user reward table: the user id, then one column per action index.\n",
    "key_ls = [\"user_id\", *(f\"action_{a_id}\" for a_id in range(num_action))]\n",
    "user_reward_dict_for_clustering = {column: [] for column in key_ls}\n",
    "user_reward_dict_for_clustering"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "19263ccd-f673-4ae2-8012-62db53f8f259",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Fill the user x action reward table. One pass over df_base builds a\n",
    "# (user, action) -> watch_ratio lookup instead of re-filtering the whole frame\n",
    "# for every pair. The first matching row wins, as with `.values[0]`.\n",
    "reward_lookup = {}\n",
    "for u_id, a_id, r in zip(df_base[\"user_id\"].values, df_base[\"video_id\"].values, df_base[\"watch_ratio\"].values):\n",
    "    reward_lookup.setdefault((u_id, a_id), r)\n",
    "\n",
    "# Start from empty columns so re-running this cell does not append duplicates.\n",
    "for column_values in user_reward_dict_for_clustering.values():\n",
    "    column_values.clear()\n",
    "\n",
    "for u_id in df_base[\"user_id\"].unique():\n",
    "    user_reward_dict_for_clustering[\"user_id\"].append(u_id)\n",
    "    for a_id in range(num_action):\n",
    "        if (u_id, a_id) not in reward_lookup:\n",
    "            raise KeyError(f\"no watch_ratio for user {u_id} and action {a_id}\")\n",
    "        user_reward_dict_for_clustering[f\"action_{a_id}\"].append(reward_lookup[(u_id, a_id)])\n",
    "user_reward_dict_for_clustering"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "700b2180-9695-486b-a40b-7b79d56d2dd3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One row per user: the user id plus one reward column per action.\n",
    "user_reward_df = pd.DataFrame.from_dict(user_reward_dict_for_clustering)\n",
    "user_reward_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5d512bc8-7887-423d-a539-74a8ebb11706",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep label-encoded features only for users present in the action-filtered log.\n",
    "user_features = (\n",
    "    user_features[user_features[\"user_id\"].isin(df_base[\"user_id\"].unique())]\n",
    "    .reset_index(drop=True)\n",
    ")\n",
    "user_features"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ee56eef6-5286-427a-83b1-60d528fb1654",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same user restriction for the one-hot feature table.\n",
    "user_features_onehot = (\n",
    "    user_features_onehot.loc[user_features_onehot[\"user_id\"].isin(df_base[\"user_id\"].unique())]\n",
    "    .reset_index(drop=True)\n",
    ")\n",
    "user_features_onehot"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9f74ee1c-6858-4e80-8080-aa2ea40b4909",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Context column names (everything except `user_id`): the label-encoded (\"_row\")\n",
    "# features fed to the reward model, and the one-hot features used for X_u.\n",
    "user_feature_not_null_except_user_id_row = [c for c in user_features.columns if c != \"user_id\"]\n",
    "user_feature_not_null_except_user_id = [c for c in user_features_onehot.columns if c != \"user_id\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d7fea139-a463-4287-8207-191c88d087d8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One-hot user context matrix, one row per user.\n",
    "X_u = user_features_onehot.loc[:, user_feature_not_null_except_user_id].to_numpy()\n",
    "X_u.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bcfc7635-b0ec-42a4-a7e6-0a16cc57e763",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean reward of each user over the action columns only. Selecting the columns\n",
    "# explicitly keeps the cell re-runnable: a previous `all_action_reward_mean`\n",
    "# column is never folded back into the mean.\n",
    "action_reward_cols = [f\"action_{a_id}\" for a_id in range(num_action)]\n",
    "user_reward_df[\"all_action_reward_mean\"] = user_reward_df[action_reward_cols].mean(axis=1)\n",
    "user_reward_df = user_reward_df.sort_values(\"all_action_reward_mean\").reset_index(drop=True)\n",
    "user_reward_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6e549c04-21e8-49b3-b04b-f6753ab7af6a",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def user_split_softmax(use_domain: int) -> np.ndarray:\n",
    "    \"\"\"Sampling distribution over users for domain `use_domain`.\n",
    "\n",
    "    p(user) is proportional to exp(alpha_ls[use_domain] * all_action_reward_mean of the user).\n",
    "    Negative alphas favour low-reward users; positive alphas favour high-reward users.\n",
    "    Entries follow the order of `user_reward_df[\"user_id\"].unique()`, which callers\n",
    "    iterate alongside the returned array.\n",
    "\n",
    "    Relies on the notebook globals `user_reward_df` and `alpha_ls`.\n",
    "    \"\"\"\n",
    "    # One mean per distinct user in first-appearance order: the same values the\n",
    "    # per-user boolean lookups produced, without an O(n_users^2) scan.\n",
    "    reward_mean_each_user = user_reward_df.drop_duplicates(\"user_id\")[\"all_action_reward_mean\"].to_numpy()\n",
    "    # No max shift: it would perturb p in the last bits and hence the sampled users.\n",
    "    # NOTE(review): with |alpha| at most 1 this is presumably safe for watch-ratio means.\n",
    "    weights = np.exp(alpha_ls[use_domain] * reward_mean_each_user)\n",
    "    prob_each_user = weights / np.sum(weights)\n",
    "    return prob_each_user"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b8fd793a-c25e-41b3-b028-85b9cb455e87",
   "metadata": {},
   "outputs": [],
   "source": [
    "alpha_ls = [-0.5, -0.4, -0.3, -0.2, -0.1, 0.2, 0.4, 0.6, 0.8, 1.0]\n",
    "alpha_ls"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f3736bcd-6174-48dc-85e9-5b0e0b344553",
   "metadata": {},
   "outputs": [],
   "source": [
    "num_domain = 10\n",
    "# alpha_ls and lp_beta are indexed by domain; catch a size mismatch early.\n",
    "assert num_domain == len(alpha_ls), \"alpha_ls must have one entry per domain\"\n",
    "num_domain"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8336a6c6-c835-4dd2-8250-27f9f5683811",
   "metadata": {},
   "outputs": [],
   "source": [
    "td_num = 1 "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3b66b9f8-47dc-4abe-8faa-a6e9ba19f4f3",
   "metadata": {},
   "outputs": [],
   "source": [
    "r_matrix = defaultdict(lambda: defaultdict(float))\n",
    "r_matrix"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c866c3a2-0cfe-4c3d-b9b5-ddc59c795ea7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Populate r_matrix[user][action] = watch_ratio (a repeated pair keeps the last row).\n",
    "for u, i, r in zip(*(df_base[col_name].values for col_name in (\"user_id\", \"video_id\", \"watch_ratio\"))):\n",
    "    r_matrix[u][i] = r"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "42b39352-143b-4a4f-9a5e-b298929ce224",
   "metadata": {},
   "outputs": [],
   "source": [
    "lp_beta = [1.7554052086783454,\n",
    " -0.0778302798603224,\n",
    " 0.4358573654398761,\n",
    " 0.9618667569858963,\n",
    " -1.6885912553614273,\n",
    " 1.4200556375331574,\n",
    " 0.07585659348725438,\n",
    " 1.4690940931261292,\n",
    " -0.1691014474606618,\n",
    " 1.666700954096115]\n",
    "\n",
    "lp_beta"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "49205f24-3901-441a-b886-73e6b206737d",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Fixed per-user, per-action Uniform(-3, 3) perturbation added to the true rewards\n",
    "# inside the logging policy; drawn once from `rng` in user_reward_df user order.\n",
    "noise_dict = {}\n",
    "for noise_user in user_reward_df[\"user_id\"].unique():\n",
    "    noise_dict[noise_user] = rng.uniform(-3, 3, num_action).tolist()\n",
    "noise_dict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6af28d48-7704-42e8-a72c-823a0155697d",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def logging_policy_softmax(user_id: int, use_domain: int) -> List[float]:\n",
    "    \"\"\"Logging policy pi_0(. | user) for domain `use_domain`.\n",
    "\n",
    "    Softmax over actions of lp_beta[use_domain] * (r_matrix reward + noise_dict noise),\n",
    "    with the maximum logit subtracted first for numerical stability.\n",
    "    Returns a NumPy array of length num_action.\n",
    "    \"\"\"\n",
    "    beta = lp_beta[use_domain]\n",
    "    noisy_reward = np.array([r_matrix[user_id][a] + noise_dict[user_id][a] for a in range(num_action)])\n",
    "    logits = beta * noisy_reward\n",
    "    logits = logits - logits.max()\n",
    "    weights = np.exp(logits)\n",
    "    return weights / np.sum(weights)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c0f63647-3dd7-4a9a-b610-d4c5f3237f32",
   "metadata": {},
   "outputs": [],
   "source": [
    "def soft_max_f(q_hat, beta=20):\n",
    "    q_hat_beta = beta * q_hat\n",
    "    \n",
    "    q_hat_beta -= q_hat_beta.max()\n",
    "\n",
    "    prob_each_action = np.exp(q_hat_beta) / np.sum(np.exp(q_hat_beta))\n",
    "\n",
    "    return prob_each_action"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "58b391c5-2168-4d2f-b405-9db9cafad007",
   "metadata": {},
   "outputs": [],
   "source": [
    "def calc_true_v_lp(td_domain_num):\n",
    "    \"\"\"Ground-truth value of the logging policy in domain `td_domain_num`.\n",
    "\n",
    "    Sums, over users, the user's weight from user_split_softmax times the expected\n",
    "    r_matrix reward under logging_policy_softmax for that user.\n",
    "    \"\"\"\n",
    "    user_p = user_split_softmax(td_domain_num)\n",
    "    value = 0\n",
    "    for u_weight, u_id in zip(user_p, user_reward_df[\"user_id\"].unique()):\n",
    "        pi_0 = logging_policy_softmax(u_id, td_domain_num)\n",
    "        expected_reward = sum((p * r_matrix[u_id][a] for a, p in enumerate(pi_0)), 0)\n",
    "        value += u_weight * expected_reward\n",
    "    return value\n",
    "\n",
    "\n",
    "def calc_true_v_ep(td_domain_num, model):\n",
    "    \"\"\"Ground-truth value of a learned policy `model` in domain `td_domain_num`.\n",
    "\n",
    "    `model.predict({\"x\": contexts})` receives the one-hot contexts of all users (rows\n",
    "    aligned with user_reward_df) and is assumed to return one row of action\n",
    "    probabilities per user -- TODO confirm against the policy learner.\n",
    "    \"\"\"\n",
    "    user_ids = user_reward_df[\"user_id\"].unique()\n",
    "    user_p = user_split_softmax(td_domain_num)\n",
    "    contexts = (\n",
    "        user_reward_df[[\"user_id\"]]\n",
    "        .merge(user_features_onehot, how=\"left\", on=\"user_id\")\n",
    "        .drop(\"user_id\", axis=1)\n",
    "        .values\n",
    "    )\n",
    "    pi = model.predict({\"x\": contexts})\n",
    "    value = 0\n",
    "    for row, u_id in enumerate(user_ids):\n",
    "        expected_reward = sum((p * r_matrix[u_id][a] for a, p in enumerate(pi[row])), 0)\n",
    "        value += user_p[row] * expected_reward\n",
    "    return value\n",
    "\n",
    "def calc_true_v_ep_dm(td_domain_num, model):\n",
    "    \"\"\"Ground-truth value of the softmax policy induced by a reward model.\n",
    "\n",
    "    `model.predict` receives one row per (user, action) -- the user's label-encoded\n",
    "    features followed by the action index -- and returns q_hat for each row. The\n",
    "    policy is soft_max_f(q_hat(x, .)); users are weighted by the user distribution\n",
    "    of domain `td_domain_num`.\n",
    "    \"\"\"\n",
    "    user_id_vec = user_reward_df[\"user_id\"].unique()\n",
    "    # Weight users by the requested domain, consistent with calc_true_v_ep.\n",
    "    user_p = user_split_softmax(td_domain_num)\n",
    "    all_user_context = (\n",
    "        user_reward_df[[\"user_id\"]]\n",
    "        .merge(user_features, how=\"left\", on=\"user_id\")\n",
    "        .drop(\"user_id\", axis=1)\n",
    "        .values\n",
    "        .tolist()\n",
    "    )\n",
    "    all_user_context_action = [\n",
    "        user_context + [a] for user_context in all_user_context for a in range(num_action)\n",
    "    ]\n",
    "    # Rows are user-major, so reshape to (n_users, num_action).\n",
    "    q_hat_each_user_action = np.asarray(model.predict(all_user_context_action)).reshape(-1, num_action)\n",
    "    true_v = 0\n",
    "    for i, u_id in enumerate(user_id_vec):\n",
    "        prob_each_a = soft_max_f(q_hat_each_user_action[i])\n",
    "        u_v = 0\n",
    "        for a, prob in enumerate(prob_each_a):\n",
    "            u_v += prob * r_matrix[u_id][a]\n",
    "        true_v += user_p[i] * u_v\n",
    "    return true_v\n",
    "\n",
    "def calc_newaction_freq_ep(td_domain_num, model):\n",
    "    \"\"\"Expected probability mass a learned policy puts on `new_action_index_array`.\n",
    "\n",
    "    Users are weighted by the user distribution of domain `td_domain_num`. Relies on\n",
    "    the global `new_action_index_array` defined in the experiment cell.\n",
    "    \"\"\"\n",
    "    user_ids = user_reward_df[\"user_id\"].unique()\n",
    "    user_p = user_split_softmax(td_domain_num)\n",
    "    contexts = (\n",
    "        user_reward_df[[\"user_id\"]]\n",
    "        .merge(user_features_onehot, how=\"left\", on=\"user_id\")\n",
    "        .drop(\"user_id\", axis=1)\n",
    "        .values\n",
    "    )\n",
    "    pi = model.predict({\"x\": contexts})\n",
    "    freq = 0\n",
    "    for row in range(len(user_ids)):\n",
    "        new_mass = sum((p for a, p in enumerate(pi[row]) if a in new_action_index_array), 0)\n",
    "        freq += user_p[row] * new_mass\n",
    "    return freq\n",
    "\n",
    "def calc_relative_newaction_value_ep(td_domain_num, model):\n",
    "    \"\"\"Value of a learned policy conditional on it choosing a new action.\n",
    "\n",
    "    Per user: sum over new actions of pi(a|x) * r(x, a), divided by the policy's\n",
    "    total mass on new actions; users with zero new-action mass contribute nothing.\n",
    "    Users are weighted by the user distribution of domain `td_domain_num`.\n",
    "    \"\"\"\n",
    "    user_ids = user_reward_df[\"user_id\"].unique()\n",
    "    user_p = user_split_softmax(td_domain_num)\n",
    "    contexts = (\n",
    "        user_reward_df[[\"user_id\"]]\n",
    "        .merge(user_features_onehot, how=\"left\", on=\"user_id\")\n",
    "        .drop(\"user_id\", axis=1)\n",
    "        .values\n",
    "    )\n",
    "    pi = model.predict({\"x\": contexts})\n",
    "    true_v = 0\n",
    "    for row, u_id in enumerate(user_ids):\n",
    "        new_value, new_mass = 0, 0\n",
    "        for a, p in enumerate(pi[row]):\n",
    "            if a not in new_action_index_array:\n",
    "                continue\n",
    "            new_value += p * r_matrix[u_id][a]\n",
    "            new_mass += p\n",
    "        if new_mass != 0:\n",
    "            true_v += user_p[row] * (new_value / new_mass)\n",
    "    return true_v\n",
    "\n",
    "def calc_relative_newaction_value_ep_dm(td_domain_num, model):\n",
    "    \"\"\"Value of the reward-model softmax policy conditional on choosing a new action.\n",
    "\n",
    "    `model.predict` scores one row per (user, action) -- label-encoded user features\n",
    "    followed by the action index -- and the policy is soft_max_f(q_hat(x, .)).\n",
    "    Per user: sum over new actions of pi(a|x) * r(x, a), divided by the mass on new\n",
    "    actions (users with zero new-action mass contribute nothing). Users are weighted\n",
    "    by the user distribution of domain `td_domain_num`.\n",
    "    \"\"\"\n",
    "    user_id_vec = user_reward_df[\"user_id\"].unique()\n",
    "    # Weight users by the requested domain, consistent with calc_relative_newaction_value_ep.\n",
    "    user_p = user_split_softmax(td_domain_num)\n",
    "    all_user_context = (\n",
    "        user_reward_df[[\"user_id\"]]\n",
    "        .merge(user_features, how=\"left\", on=\"user_id\")\n",
    "        .drop(\"user_id\", axis=1)\n",
    "        .values\n",
    "        .tolist()\n",
    "    )\n",
    "    all_user_context_action = [\n",
    "        user_context + [a] for user_context in all_user_context for a in range(num_action)\n",
    "    ]\n",
    "    # Rows are user-major, so reshape to (n_users, num_action).\n",
    "    q_hat_each_user_action = np.asarray(model.predict(all_user_context_action)).reshape(-1, num_action)\n",
    "    true_v = 0\n",
    "    for i, u_id in enumerate(user_id_vec):\n",
    "        u_v = 0\n",
    "        freq = 0\n",
    "        prob_each_a = soft_max_f(q_hat_each_user_action[i])\n",
    "        for a, p in enumerate(prob_each_a):\n",
    "            if a in new_action_index_array:\n",
    "                u_v += p * r_matrix[u_id][a]\n",
    "                freq += p\n",
    "        if freq != 0:\n",
    "            true_v += user_p[i] * (u_v / freq)\n",
    "    return true_v\n",
    "\n",
    "def calc_newaction_freq_ep_dm(td_domain_num, model):\n",
    "    \"\"\"Expected probability mass the reward-model softmax policy puts on new actions.\n",
    "\n",
    "    `model.predict` scores one row per (user, action) -- label-encoded user features\n",
    "    followed by the action index -- and the policy is soft_max_f(q_hat(x, .)).\n",
    "    Users are weighted by the user distribution of domain `td_domain_num`.\n",
    "    \"\"\"\n",
    "    user_id_vec = user_reward_df[\"user_id\"].unique()\n",
    "    # Weight users by the requested domain, consistent with calc_newaction_freq_ep.\n",
    "    user_p = user_split_softmax(td_domain_num)\n",
    "    all_user_context = (\n",
    "        user_reward_df[[\"user_id\"]]\n",
    "        .merge(user_features, how=\"left\", on=\"user_id\")\n",
    "        .drop(\"user_id\", axis=1)\n",
    "        .values\n",
    "        .tolist()\n",
    "    )\n",
    "    all_user_context_action = [\n",
    "        user_context + [a] for user_context in all_user_context for a in range(num_action)\n",
    "    ]\n",
    "    # Rows are user-major, so reshape to (n_users, num_action).\n",
    "    q_hat_each_user_action = np.asarray(model.predict(all_user_context_action)).reshape(-1, num_action)\n",
    "    freq = 0\n",
    "    for i, u_id in enumerate(user_id_vec):\n",
    "        prob_each_a = soft_max_f(q_hat_each_user_action[i])\n",
    "        u_v = 0\n",
    "        for a, prob in enumerate(prob_each_a):\n",
    "            if a in new_action_index_array:\n",
    "                u_v += prob\n",
    "        freq += user_p[i] * u_v\n",
    "    return freq\n",
    "\n",
    "def log_data_generate(domain_num: int, log_data_sample_size: int, seed: int) -> Tuple[List[int], List[List[float]], List[int], List[float], List[float], List[np.ndarray], List[List[float]], List[List[float]]]:\n",
    "    \"\"\"Simulate `log_data_sample_size` logged interactions for domain `domain_num`.\n",
    "\n",
    "    Each sample draws a user from user_split_softmax(domain_num), an action from\n",
    "    logging_policy_softmax(user, domain_num), and a reward from N(r_matrix[user][action], 1).\n",
    "\n",
    "    All randomness comes from the global `rng`; `seed` is kept for API compatibility\n",
    "    but is not used here (the experiment loop re-seeds `rng` before calling).\n",
    "\n",
    "    Returns:\n",
    "        An 8-tuple of per-sample lists: user ids, one-hot contexts, actions, rewards,\n",
    "        propensities pi_0(action | user), full pi_0 vectors, label-encoded contexts,\n",
    "        and the user's reward for every action (from user_reward_df).\n",
    "    \"\"\"\n",
    "    user_id_vec = []\n",
    "    context_vec = []\n",
    "    context_vec_row = []\n",
    "    action_vec = []\n",
    "    reward_vec = []\n",
    "    pscore_vec = []\n",
    "    pi_0_vec = []\n",
    "    q_x_a_vec = []\n",
    "    u_prob = user_split_softmax(domain_num)\n",
    "    unique_user_id = user_reward_df[\"user_id\"].unique()\n",
    "    non_action_cols = [\"user_id\", \"all_action_reward_mean\"]\n",
    "\n",
    "    # The order of rng calls (user, then action, then reward) fixes the simulated data.\n",
    "    for _ in range(log_data_sample_size):\n",
    "        u_id = rng.choice(unique_user_id, size=1, p=u_prob)[0]\n",
    "\n",
    "        context_sample_row = user_features[user_features[\"user_id\"]==u_id][user_feature_not_null_except_user_id_row].values[0].tolist()\n",
    "        context_sample = user_features_onehot[user_features_onehot[\"user_id\"]==u_id][user_feature_not_null_except_user_id].values[0].tolist()\n",
    "\n",
    "        pi_0 = logging_policy_softmax(u_id, domain_num)\n",
    "        action_sample = rng.choice(num_action, size=1, p=pi_0)[0]\n",
    "        pscore = pi_0[action_sample]\n",
    "\n",
    "        reward_sample = rng.normal(r_matrix[u_id][action_sample], 1)\n",
    "\n",
    "        # Reward of every action for this user, without the N(0, 1) sampling noise.\n",
    "        q_x_a_sample = user_reward_df[user_reward_df[\"user_id\"]==u_id].drop(non_action_cols,axis=1).values.tolist()[0]\n",
    "\n",
    "        user_id_vec.append(u_id)\n",
    "        context_vec.append(context_sample)\n",
    "        context_vec_row.append(context_sample_row)\n",
    "        action_vec.append(action_sample)\n",
    "        reward_vec.append(reward_sample)\n",
    "        pscore_vec.append(pscore)\n",
    "        pi_0_vec.append(pi_0)\n",
    "        q_x_a_vec.append(q_x_a_sample)\n",
    "\n",
    "    return user_id_vec, context_vec, action_vec, reward_vec, pscore_vec, pi_0_vec, context_vec_row, q_x_a_vec\n",
    "\n",
    "def calc_mean_prob_joint_x_a(user_id: int, context: List[float], action: int, domain_index_ls: List[int]) -> float:\n",
    "    \"\"\"Pooled propensity of (context, action) over the domains in the target cluster.\n",
    "\n",
    "    For every domain d in the global `td_cluster_domains`, the logging propensity\n",
    "    logging_policy_softmax(user_id, d)[action] is weighted by how often d occurs in\n",
    "    `domain_index_ls`; for d != td_num it is additionally multiplied by the density\n",
    "    ratio from `density_ratio_models[d][0]` evaluated at `context`. The sum is divided\n",
    "    by len(domain_index_ls).\n",
    "\n",
    "    Args:\n",
    "        user_id: user whose logging policies are evaluated.\n",
    "        context: label-encoded user features (user_feature_not_null_except_user_id_row order).\n",
    "        action: action index.\n",
    "        domain_index_ls: domain index of every pooled sample (int domain ids).\n",
    "    \"\"\"\n",
    "    context_row = np.array(context).reshape(1, len(user_feature_not_null_except_user_id_row))\n",
    "    prob_x_a = 0\n",
    "    for d in sorted(td_cluster_domains):\n",
    "        weighted_prob = domain_index_ls.count(d) * logging_policy_softmax(user_id, d)[action]\n",
    "        if d != td_num:\n",
    "            # Source domain: re-weight by the estimated density ratio (presumably\n",
    "            # p_target(x) / p_d(x) -- TODO confirm against how the models are fitted).\n",
    "            density_ratio = density_ratio_models[d][0].compute_density_ratio(context_row)[0]\n",
    "            weighted_prob = weighted_prob * density_ratio\n",
    "        prob_x_a += weighted_prob\n",
    "    return prob_x_a / len(domain_index_ls)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c3ff024c-a36a-4ed6-babe-6e1fcce850b4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Experiment configuration.\n",
    "max_iter = 30  # NOTE(review): presumably training epochs for the policy learners -- confirm\n",
    "\n",
    "# Logged samples generated per domain (target domain and each source domain).\n",
    "target_domain_sample_size = 100\n",
    "source_domain_sample_size = 100\n",
    "\n",
    "# One simulation run per seed.\n",
    "seed_ls = list(range(350))\n",
    "\n",
    "# Dimension of the one-hot user context.\n",
    "dim_x = len(user_feature_not_null_except_user_id)\n",
    "\n",
    "# Number of domains (closest to the target by mean logged reward) pooled per run.\n",
    "cluster_size_ls = [2, 4, 6, 8, 10]\n",
    "cluster_size_ls"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "55fb7e90-51f7-43d9-96cb-8bd5dc3bcab8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Result containers: method -> cluster size -> list of values appended by the\n",
    "# experiment loop. Every container gets its own independent lists.\n",
    "_result_methods = [\"DM\", \"DM_ALL\", \"IPS\", \"IPS_ALL\", \"DR\", \"DR_ALL\", \"MDOPE\", \"logging\"]\n",
    "\n",
    "def _empty_result_store():\n",
    "    \"\"\"Return a fresh {method: {cluster_size: []}} dict.\"\"\"\n",
    "    return {method: {cluster_size: [] for cluster_size in cluster_size_ls} for method in _result_methods}\n",
    "\n",
    "true_value_of_learned_policies = _empty_result_store()\n",
    "true_value_only_newaction_of_learned_policies = _empty_result_store()\n",
    "freq_newaction_of_learned_policies = _empty_result_store()\n",
    "train_true_value_per_epoch = _empty_result_store()\n",
    "true_value_per_epoch = _empty_result_store()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "22d395e4-63bc-4ea2-af71-ed1506b97173",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "%%time\n",
    "new_action_index_array = []\n",
    "for seed in seed_ls:\n",
    "    print(f\"NOW SEED = {seed}\")\n",
    "    for cluster_size in cluster_size_ls:\n",
    "        rng = np.random.default_rng(seed=seed)\n",
    "        np.random.seed(seed)\n",
    "        torch.manual_seed(seed)\n",
    "        random_ = check_random_state(seed)\n",
    "        \n",
    "        \n",
    "        pi_0_value = calc_true_v_lp(td_num)\n",
    "        #pi_0_value_only_na = calc_true_v_only_newaction_lp(td_num)\n",
    "        true_value_of_learned_policies[\"logging\"][cluster_size].append(pi_0_value)\n",
    "        #true_value_only_newaction_of_learned_policies[\"logging\"][new_action_num].append(pi_0_value_only_na)\n",
    "\n",
    "        \n",
    "        \n",
    "        user_id_dict_each_domain = {key: [] for key in range(num_domain)}\n",
    "        context_dict_each_domain = {key: [] for key in range(num_domain)}\n",
    "        context_row_dict_each_domain = {key: [] for key in range(num_domain)}\n",
    "        action_dict_each_domain = {key: [] for key in range(num_domain)}\n",
    "        reward_dict_each_domain = {key: [] for key in range(num_domain)}\n",
    "        pi_0_dict_each_domain = {key: [] for key in range(num_domain)}\n",
    "        pscore_dict_each_domain = {key: [] for key in range(num_domain)}\n",
    "        q_x_a_dict_each_domain = {key: [] for key in range(num_domain)}\n",
    "    \n",
    "        \n",
    "        for d in reward_dict_each_domain.keys():\n",
    "            if d == td_num:\n",
    "                user_id_vec, context_vec, action_vec, reward_vec, pscore_vec, pi_0_vec, context_vec_row, q_x_a_vec = log_data_generate(d, target_domain_sample_size, seed)\n",
    "            else:\n",
    "                user_id_vec, context_vec, action_vec, reward_vec, pscore_vec, pi_0_vec, context_vec_row, q_x_a_vec = log_data_generate(d, source_domain_sample_size, seed)\n",
    "            user_id_dict_each_domain[d] = user_id_vec\n",
    "            context_dict_each_domain[d] = context_vec\n",
    "            context_row_dict_each_domain[d] = context_vec_row\n",
    "            action_dict_each_domain[d] = action_vec\n",
    "            reward_dict_each_domain[d] = reward_vec\n",
    "            pi_0_dict_each_domain[d] = pi_0_vec\n",
    "            pscore_dict_each_domain[d] = pscore_vec\n",
    "            q_x_a_dict_each_domain[d] = q_x_a_vec\n",
    "        \n",
    "        \n",
    "        domain_feature_dict_each_domain = {key: [] for key in range(num_domain)}\n",
    "        domain_feature_dict_each_domain_ = {key: [] for key in range(num_domain)}\n",
    "        for d in domain_feature_dict_each_domain.keys():\n",
    "            domain_feature_dict_each_domain[d] = np.mean(reward_dict_each_domain[d])\n",
    "            domain_feature_dict_each_domain_[d] = np.mean(reward_dict_each_domain[d])\n",
    "\n",
    "        \n",
    "        td_cluster_domains = []\n",
    "        td_d_feature = domain_feature_dict_each_domain[td_num]\n",
    "        for k,v in domain_feature_dict_each_domain_.items():\n",
    "            domain_feature_dict_each_domain_[k] = abs(td_d_feature-v)\n",
    "        \n",
    "        domain_feature_dict_each_domain_ = sorted(domain_feature_dict_each_domain_.items(), key=lambda x:x[1])\n",
    "        for i in range(cluster_size):\n",
    "            td_cluster_domains.append(domain_feature_dict_each_domain_[i][0])\n",
    "        \n",
    "        print(td_cluster_domains)\n",
    "        \n",
    "        \n",
    "        print(\"Fit reward model.\")\n",
    "        \n",
    "        reward_models = {\"target_domain\": [], \"ALL_domain\": [], \"Cluster\": []}\n",
    "        \n",
    "        \n",
    "        \n",
    "        for data_type in reward_models.keys():\n",
    "            if data_type == \"target_domain\":\n",
    "                \n",
    "                context_and_action_vec = np.hstack((np.array(context_row_dict_each_domain[td_num]),np.array(action_dict_each_domain[td_num]).reshape(len(action_dict_each_domain[td_num]),-1)))\n",
    "                y_vec = np.array(reward_dict_each_domain[td_num])\n",
    "                forest = RandomForestRegressor(n_estimators=100, max_depth=6, min_samples_leaf=10, max_samples=0.8,random_state = SEED)\n",
    "                forest.fit(context_and_action_vec, y_vec)\n",
    "                \n",
    "                context_and_action_vec_each_sample_each_action = []\n",
    "                for one_context_vec in context_row_dict_each_domain[td_num]:\n",
    "                    context_and_action_one_sample = []\n",
    "                    for one_action in range(num_action):\n",
    "                        context_and_action_vec_each_sample_each_action.append(one_context_vec+[one_action])\n",
    "                context_and_action_vec_each_sample_each_action = np.array(context_and_action_vec_each_sample_each_action)\n",
    "                q_hat_dr_td = forest.predict(context_and_action_vec_each_sample_each_action).reshape(-1,num_action)\n",
    "                \n",
    "                reward_models[\"target_domain\"].append(forest)\n",
    "\n",
    "            elif data_type == \"ALL_domain\":\n",
    "                domain_index_ls = []\n",
    "                user_id_ls_all_domain = []\n",
    "                context_ls_all_domain = []\n",
    "                action_ls_all_domain = []\n",
    "                reward_ls_all_domain = []\n",
    "                for domain_num in context_row_dict_each_domain.keys():\n",
    "                    domain_index_ls += [domain_num]*len(context_row_dict_each_domain[domain_num])\n",
    "                    user_id_ls_all_domain += user_id_dict_each_domain[domain_num]\n",
    "                    context_ls_all_domain += context_row_dict_each_domain[domain_num]\n",
    "                    action_ls_all_domain += action_dict_each_domain[domain_num]\n",
    "                    reward_ls_all_domain += reward_dict_each_domain[domain_num]\n",
    "                context_and_action_vec_all_domain = np.hstack((np.array(context_ls_all_domain),np.array(action_ls_all_domain).reshape(len(action_ls_all_domain),-1)))\n",
    "                y_vec = np.array(reward_ls_all_domain)\n",
    "                \n",
    "                forest = RandomForestRegressor(n_estimators=100, max_depth=6, min_samples_leaf=10, max_samples=0.8,random_state = SEED)\n",
    "                forest.fit(context_and_action_vec_all_domain, y_vec)\n",
    "                \n",
    "                context_and_action_vec_each_sample_each_action_all = []\n",
    "                for one_context_vec in context_ls_all_domain:\n",
    "                    context_and_action_one_sample = []\n",
    "                    for one_action in range(num_action):\n",
    "                        context_and_action_vec_each_sample_each_action_all.append(one_context_vec+[one_action])\n",
    "                context_and_action_vec_each_sample_each_action_all = np.array(context_and_action_vec_each_sample_each_action_all)\n",
    "                q_hat_dr_all = forest.predict(context_and_action_vec_each_sample_each_action_all).reshape(-1,num_action)\n",
    "                reward_models[\"ALL_domain\"].append(forest)\n",
    "\n",
    "            else:\n",
    "                domain_index_ls_only_tg_cluster = []\n",
    "                user_id_ls_target_cluster_domain = []\n",
    "                context_ls_target_cluster_domain = []\n",
    "                action_ls_target_cluster_domain = []\n",
    "                reward_ls_target_cluster_domain = []\n",
    "                context_and_domain_feature_ls_target_cluster_domain = []\n",
    "                \n",
    "                for domain_num in sorted(td_cluster_domains):\n",
    "                    domain_index_ls_only_tg_cluster += [domain_num]*len(context_dict_each_domain[domain_num])\n",
    "                    user_id_ls_target_cluster_domain += user_id_dict_each_domain[domain_num]\n",
    "                    context_ls_target_cluster_domain += context_row_dict_each_domain[domain_num]\n",
    "                    action_ls_target_cluster_domain += action_dict_each_domain[domain_num]\n",
    "                    reward_ls_target_cluster_domain += reward_dict_each_domain[domain_num]\n",
    "                \n",
    "                    contest_and_domain_feature_ls = copy.deepcopy(context_dict_each_domain[domain_num])\n",
    "                    for i in range(len(contest_and_domain_feature_ls)):\n",
    "                        contest_and_domain_feature_ls[i] += [domain_feature_dict_each_domain[domain_num]]\n",
    "                    context_and_domain_feature_ls_target_cluster_domain += contest_and_domain_feature_ls\n",
    "                        \n",
    "                context_and_action_and_domain_feature_vec_target_cluster_domain = np.hstack((np.array(context_and_domain_feature_ls_target_cluster_domain),np.array(action_ls_target_cluster_domain).reshape(len(action_ls_target_cluster_domain),-1)))\n",
    "                y_vec = np.array(reward_ls_target_cluster_domain)\n",
    "                \n",
    "                forest = RandomForestRegressor(n_estimators=100, max_depth=6, min_samples_leaf=10, max_samples=0.8, random_state = SEED)\n",
    "                forest.fit(context_and_action_and_domain_feature_vec_target_cluster_domain, y_vec)\n",
    "                \n",
    "                context_and_action_vec_each_sample_each_action_cluster = []\n",
    "                for one_context_vec in context_and_domain_feature_ls_target_cluster_domain:\n",
    "                    context_and_action_one_sample = []\n",
    "                    for one_action in range(num_action):\n",
    "                        context_and_action_vec_each_sample_each_action_cluster.append(one_context_vec+[one_action])\n",
    "                context_and_action_vec_each_sample_each_action_cluster = np.array(context_and_action_vec_each_sample_each_action_cluster)\n",
    "                q_hat_dr_cluster = forest.predict(context_and_action_vec_each_sample_each_action_cluster).reshape(-1,num_action)\n",
    "                reward_models[\"Cluster\"].append(forest)\n",
    "                \n",
    "                \n",
    "                \n",
    "                \n",
    "                   \n",
    "        print(\"Fit dens ratio model.\")\n",
    "        \n",
    "        density_ratio_models = {domain_num : [] for domain_num in sorted(td_cluster_domains)}\n",
    "        for domain_num in density_ratio_models.keys():\n",
    "            \n",
    "            model = densratio(np.array(context_row_dict_each_domain[domain_num]), np.array(context_row_dict_each_domain[td_num]), alpha=0.95, verbose=False) \n",
    "            \n",
    "            density_ratio_models[domain_num].append(model)\n",
    "\n",
    "        \n",
    "        offline_logged_data_td = {\"x\":np.array(context_dict_each_domain[td_num]), \"a\": np.array(action_dict_each_domain[td_num]), \"r\": np.array(reward_dict_each_domain[td_num]), \"pi_0\": np.array(pi_0_dict_each_domain[td_num]), \"pscore\":np.array(pscore_dict_each_domain[td_num]), \"q_x_a\":np.array(q_x_a_dict_each_domain[td_num])}\n",
    "        \n",
    "        \n",
    "        all_context = []\n",
    "        all_action = []\n",
    "        all_reward = []\n",
    "        all_pi_0 = []\n",
    "        all_pscore = []\n",
    "        all_q_x_a = []\n",
    "        for i in range(len(context_dict_each_domain)):\n",
    "            all_context += context_dict_each_domain[i]\n",
    "            all_action += action_dict_each_domain[i]\n",
    "            all_reward += reward_dict_each_domain[i]\n",
    "            all_pi_0 += pi_0_dict_each_domain[i]\n",
    "            all_pscore += pscore_dict_each_domain[i]\n",
    "            all_q_x_a += q_x_a_dict_each_domain[i]\n",
    "            \n",
    "        offline_logged_data_all = {\"x\":np.array(all_context), \"a\": np.array(all_action), \"r\": np.array(all_reward), \"pi_0\": np.array(all_pi_0), \"pscore\":np.array(all_pscore), \"q_x_a\":np.array(all_q_x_a)}\n",
    "\n",
    "        \n",
    "        cluster_user_id = []\n",
    "        cluster_domain_index = []\n",
    "        cluster_context = []\n",
    "        cluster_context_row = []\n",
    "        cluster_action = []\n",
    "        cluster_reward = []\n",
    "        cluster_pi_0 = []\n",
    "        cluster_q_x_a = []\n",
    "        cluster_mean_joint_pscore = []\n",
    "        for domain_num in sorted(td_cluster_domains):\n",
    "            cluster_user_id += user_id_dict_each_domain[domain_num]\n",
    "            cluster_domain_index += [domain_num]*len(context_dict_each_domain[domain_num])\n",
    "            cluster_context += context_dict_each_domain[domain_num]\n",
    "            cluster_context_row += context_row_dict_each_domain[domain_num]\n",
    "            cluster_action += action_dict_each_domain[domain_num]\n",
    "            cluster_reward += reward_dict_each_domain[domain_num]\n",
    "            cluster_pi_0 += pi_0_dict_each_domain[domain_num]\n",
    "            cluster_q_x_a += q_x_a_dict_each_domain[domain_num]\n",
    "        for i in range(len(cluster_user_id)):\n",
    "            cluster_mean_joint_pscore.append(calc_mean_prob_joint_x_a(cluster_user_id[i],cluster_context_row[i],cluster_action[i],cluster_domain_index))\n",
    "        \n",
    "        \n",
    "        offline_logged_data_cluster = {\"x\":np.array(cluster_context), \"a\": np.array(cluster_action), \"r\": np.array(cluster_reward), \"pi_0\": np.array(cluster_pi_0), \"pscore\":np.array(cluster_mean_joint_pscore), \"d_index\": np.array(cluster_domain_index), \"q_x_a\": np.array(cluster_q_x_a)}\n",
    "        \n",
    "        user_id_vec = user_reward_df[\"user_id\"].unique()\n",
    "        user_p = user_split_softmax(td_num)\n",
    "        all_user_context = user_reward_df[[\"user_id\"]].merge(user_features_onehot,how=\"left\",on=\"user_id\").drop(\"user_id\",axis=1).values\n",
    "        test_data = {\"x\": all_user_context, \"p_x\":user_p, \"q_x_a\":user_reward_df.drop([\"user_id\",\"all_action_reward_mean\"],axis=1).values}\n",
    "        \n",
    "        \n",
    "        \n",
    "        true_value_of_learned_policies[\"DM\"][cluster_size].append(calc_true_v_ep_dm(td_num, reward_models[\"target_domain\"][0]))\n",
    "        true_value_only_newaction_of_learned_policies[\"DM\"][cluster_size].append(calc_relative_newaction_value_ep_dm(td_num, reward_models[\"target_domain\"][0]))\n",
    "        freq_newaction_of_learned_policies[\"DM\"][cluster_size].append(calc_newaction_freq_ep_dm(td_num,reward_models[\"target_domain\"][0]))\n",
    "        \n",
    "        \n",
    "        \n",
    "        #print(\"IPS\")\n",
    "        ips = GradientBasedPolicyLearner(dim_x=dim_x, num_actions=num_action, max_iter=max_iter)\n",
    "        train_pv_ips_per_epoch, test_pv_ips_per_epoch = ips.fit(offline_logged_data_td, offline_logged_data_td, test_data)\n",
    "        train_true_value_per_epoch[\"IPS\"][cluster_size].append(train_pv_ips_per_epoch)\n",
    "        true_value_per_epoch[\"IPS\"][cluster_size].append(test_pv_ips_per_epoch)\n",
    "        true_value_of_learned_policies[\"IPS\"][cluster_size].append(calc_true_v_ep(td_num, ips))\n",
    "        true_value_only_newaction_of_learned_policies[\"IPS\"][cluster_size].append(calc_relative_newaction_value_ep(td_num, ips))\n",
    "        freq_newaction_of_learned_policies[\"IPS\"][cluster_size].append(calc_newaction_freq_ep(td_num,ips))\n",
    "\n",
    "        \n",
    "        #print(\"DR\")\n",
    "        dr = GradientBasedPolicyLearner(dim_x=dim_x, num_actions=num_action, max_iter=max_iter)\n",
    "        train_pv_dr_per_epoch, test_pv_dr_per_epoch = dr.fit(offline_logged_data_td, offline_logged_data_td, test_data, q_hat=q_hat_dr_td)\n",
    "        train_true_value_per_epoch[\"DR\"][cluster_size].append(train_pv_dr_per_epoch)\n",
    "        true_value_per_epoch[\"DR\"][cluster_size].append(test_pv_dr_per_epoch)\n",
    "        true_value_of_learned_policies[\"DR\"][cluster_size].append(calc_true_v_ep(td_num, dr))\n",
    "        true_value_only_newaction_of_learned_policies[\"DR\"][cluster_size].append(calc_relative_newaction_value_ep(td_num, dr))\n",
    "        freq_newaction_of_learned_policies[\"DR\"][cluster_size].append(calc_newaction_freq_ep(td_num,dr))\n",
    "        \n",
    "        \n",
    "        \n",
    "        true_value_of_learned_policies[\"DM_ALL\"][cluster_size].append(calc_true_v_ep_dm(td_num, reward_models[\"ALL_domain\"][0]))\n",
    "        true_value_only_newaction_of_learned_policies[\"DM_ALL\"][cluster_size].append(calc_relative_newaction_value_ep_dm(td_num, reward_models[\"ALL_domain\"][0]))\n",
    "        freq_newaction_of_learned_policies[\"DM_ALL\"][cluster_size].append(calc_newaction_freq_ep_dm(td_num,reward_models[\"ALL_domain\"][0]))\n",
    "        \n",
    "        \n",
    "        \n",
    "        #print(\"IPS_ALL\")\n",
    "        ips_all = GradientBasedPolicyLearner(dim_x=dim_x, num_actions=num_action, max_iter=max_iter)\n",
    "        train_pv_ips_all_per_epoch, test_pv_ips_all_per_epoch = ips_all.fit(offline_logged_data_all, offline_logged_data_td, test_data)\n",
    "        train_true_value_per_epoch[\"IPS_ALL\"][cluster_size].append(train_pv_ips_all_per_epoch)\n",
    "        true_value_per_epoch[\"IPS_ALL\"][cluster_size].append(test_pv_ips_all_per_epoch)\n",
    "        true_value_of_learned_policies[\"IPS_ALL\"][cluster_size].append(calc_true_v_ep(td_num, ips_all))\n",
    "        true_value_only_newaction_of_learned_policies[\"IPS_ALL\"][cluster_size].append(calc_relative_newaction_value_ep(td_num, ips_all))\n",
    "        freq_newaction_of_learned_policies[\"IPS_ALL\"][cluster_size].append(calc_newaction_freq_ep(td_num,ips_all))\n",
    "        \n",
    "        \n",
    "        #print(\"DR_ALL\")\n",
    "        dr_all = GradientBasedPolicyLearner(dim_x=dim_x, num_actions=num_action, max_iter=max_iter)\n",
    "        train_pv_dr_all_per_epoch, test_pv_dr_all_per_epoch = dr_all.fit(offline_logged_data_all, offline_logged_data_td, test_data, q_hat=q_hat_dr_all)\n",
    "        train_true_value_per_epoch[\"DR_ALL\"][cluster_size].append(train_pv_dr_all_per_epoch)\n",
    "        true_value_per_epoch[\"DR_ALL\"][cluster_size].append(test_pv_dr_all_per_epoch)\n",
    "        true_value_of_learned_policies[\"DR_ALL\"][cluster_size].append(calc_true_v_ep(td_num, dr_all))\n",
    "        true_value_only_newaction_of_learned_policies[\"DR_ALL\"][cluster_size].append(calc_relative_newaction_value_ep(td_num, dr_all))\n",
    "        freq_newaction_of_learned_policies[\"DR_ALL\"][cluster_size].append(calc_newaction_freq_ep(td_num,dr_all))\n",
    "        \n",
    "        \n",
    "        #print(\"MDOPE\")\n",
    "        mdope = GradientBasedPolicyLearnerMDOPE(dim_x=dim_x, num_actions=num_action, max_iter=max_iter)\n",
    "        train_pv_mdope_per_epoch, test_pv_mdope_per_epoch = mdope.fit(offline_logged_data_cluster, offline_logged_data_td, test_data, q_hat=q_hat_dr_cluster)\n",
    "        train_true_value_per_epoch[\"MDOPE\"][cluster_size].append(train_pv_mdope_per_epoch)\n",
    "        true_value_per_epoch[\"MDOPE\"][cluster_size].append(test_pv_mdope_per_epoch)\n",
    "        true_value_of_learned_policies[\"MDOPE\"][cluster_size].append(calc_true_v_ep(td_num, mdope))\n",
    "        true_value_only_newaction_of_learned_policies[\"MDOPE\"][cluster_size].append(calc_relative_newaction_value_ep(td_num, mdope))\n",
    "        freq_newaction_of_learned_policies[\"MDOPE\"][cluster_size].append(calc_newaction_freq_ep(td_num,mdope))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "24752102-80aa-43dc-8286-b50f0f8a69f5",
   "metadata": {},
   "outputs": [],
   "source": [
    "def change_dict_key(d, old_key, new_key, default_value=None):\n",
    "    d[new_key] = d.pop(old_key, default_value)\n",
    "# Store the MDOPE results under the key 'COPE' before they are tabulated and plotted below.\n",
    "change_dict_key(true_value_of_learned_policies, 'MDOPE', 'COPE')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3cea457e-9a3e-4dc2-a2e9-f8b33548bd8a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def calculate_metrics(estimates):\n",
    "    \n",
    "    data_list = []\n",
    "    for estimator_name, prams in estimates.items():\n",
    "        for pram, values in prams.items():\n",
    "            for i,value in enumerate(values):\n",
    "            #for i,value in enumerate(values):\n",
    "                #data_list.append({'Estimator': estimator_name, 'New_Action_Num': new_action_num/30, 'Value': value/estimates[\"logging\"][new_action_num][i]})\n",
    "                if estimator_name == \"logging\":\n",
    "                    data_list.append({'Estimator': estimator_name, 'cl_size': pram, 'Value': value/estimates[\"logging\"][pram][i]})\n",
    "                elif estimator_name == \"DM\":\n",
    "                    data_list.append({'Estimator': \"Reg-based(T)\", 'cl_size': pram, 'Value': value/estimates[\"logging\"][pram][i]})\n",
    "                elif estimator_name == \"DM_ALL\":\n",
    "                    data_list.append({'Estimator': \"Reg-based(ALL)\", 'cl_size': pram, 'Value': value/estimates[\"logging\"][pram][i]})\n",
    "                elif estimator_name == \"IPS\":\n",
    "                    data_list.append({'Estimator': estimator_name+\"-PG(T)\", 'cl_size': pram, 'Value': value/estimates[\"logging\"][pram][i]})\n",
    "                elif estimator_name == \"IPS_ALL\":\n",
    "                    data_list.append({'Estimator': \"IPS\"+\"-PG(ALL)\", 'cl_size': pram, 'Value': value/estimates[\"logging\"][pram][i]})\n",
    "                elif estimator_name == \"DR\":\n",
    "                    data_list.append({'Estimator': estimator_name+\"-PG(T)\", 'cl_size': pram, 'Value': value/estimates[\"logging\"][pram][i]})\n",
    "                elif estimator_name == \"DR_ALL\":\n",
    "                    data_list.append({'Estimator': \"DR\"+\"-PG(ALL)\", 'cl_size': pram, 'Value': value/estimates[\"logging\"][pram][i]})\n",
    "                elif estimator_name == \"MDOPE\" or estimator_name == \"COPE\":\n",
    "                    data_list.append({'Estimator': \"COPE\"+\"-PG\", 'cl_size': pram, 'Value': value/estimates[\"logging\"][pram][i]})\n",
    "\n",
    "    df = pd.DataFrame(data_list)\n",
    "    return df\n",
    "\n",
    "# The logging rows are the baseline divided by itself, so they are dropped from the comparison.\n",
    "df_metrics_1 = (\n",
    "    calculate_metrics(true_value_of_learned_policies)\n",
    "    .loc[lambda frame: frame[\"Estimator\"] != \"logging\"]\n",
    "    .reset_index(drop=True)\n",
    ")\n",
    "df_metrics_1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d5f96bd9-e230-44b7-9c2e-8e31974763a7",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "plt.style.use(\"ggplot\")\n",
    "\n",
    "# One marker / colour per estimator; seaborn uses only as many entries as there are estimators.\n",
    "markers = ['o', 'v', '8', 's', 'p', '*', 'h', 'D']\n",
    "palette = sns.color_palette(\"deep\")[:8]\n",
    "# Lay the legend out in a single row, one column per plotted estimator.\n",
    "n_legend_cols = max(1, df_metrics_1[\"Estimator\"].nunique())\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(11, 7), dpi=400)\n",
    "sns.lineplot(x='cl_size', y='Value', hue='Estimator', data=df_metrics_1, ax=ax,\n",
    "             markers=markers, style='Estimator', errorbar=('ci', 95), palette=palette,\n",
    "             linewidth=4, markersize=18, dashes=False)\n",
    "\n",
    "# The y-axis label is left blank; the title names the plotted quantity.\n",
    "ax.set_ylabel('', fontsize=0)\n",
    "ax.tick_params(axis=\"y\", labelsize=18)\n",
    "\n",
    "ax.set_xlabel('size of the target cluster', fontsize=27, labelpad=20)\n",
    "ax.tick_params(axis=\"x\", labelsize=18)\n",
    "ax.set_xticks(cluster_size_ls)\n",
    "\n",
    "# Exactly one legend call: a second legend() call on the same axes replaces the first.\n",
    "ax.legend(loc='lower center', bbox_to_anchor=(.5, 1.1), ncol=n_legend_cols, fontsize=12)\n",
    "ax.set_title('relative policy value', fontsize=35)\n",
    "#plt.savefig(f'../../../output_experiments/real_world_opl-policy_value-cl.png', dpi=500, bbox_inches='tight');\n",
    "plt.show();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "833a6bc7-5fef-4425-b875-5985a5d8ac5b",
   "metadata": {},
   "outputs": [],
   "source": [
    "#df_metrics_1.to_csv(\"./OPL_results/opl_pv_cl.csv\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "79060add-9145-4c72-94bc-81b5e2f45cd5",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.15"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
