{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "313d98f5-34e0-4e5a-9211-4f373427d6d8",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "from collections import defaultdict\n",
    "from sklearn.manifold import TSNE\n",
    "import matplotlib.pyplot as plt\n",
    "from sklearn.cluster import KMeans\n",
    "import seaborn as sns\n",
    "from sklearn.ensemble import RandomForestRegressor\n",
    "from sklearn.metrics import mean_squared_error\n",
    "\n",
    "from typing import List, Tuple, Dict\n",
    "\n",
    "from scipy import stats\n",
    "from densratio import densratio\n",
    "from sklearn.model_selection import KFold\n",
    "from sklearn.model_selection import StratifiedKFold\n",
    "import copy\n",
    "import json\n",
    "from sklearn.preprocessing import LabelEncoder\n",
    "from sklearn.model_selection import GridSearchCV"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9ad11c04-4483-468d-82db-967d3a65eba3",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "pd.set_option('display.max_rows', 100)\n",
    "\n",
    "pd.set_option('display.max_columns', 100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "363fd077-2f5a-4f3e-9e65-fb8ce2df0c16",
   "metadata": {},
   "outputs": [],
   "source": [
    "# If you are running locally, make sure you are in the directory of KuaiRec.\n",
    "rootpath=\"../../../\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "30ca6e46-a1af-40e4-8c67-62031520839a",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"Loading small matrix...\")\n",
    "small_matrix = pd.read_csv(rootpath + \"data/small_matrix.csv\")\n",
    "\n",
    "print(\"Loading user features...\")\n",
    "user_features = pd.read_csv(rootpath + \"data/user_features.csv\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "988c770d-012c-47b1-bd82-97b9e477ac32",
   "metadata": {},
   "outputs": [],
   "source": [
    "small_matrix"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a72851ef-3f16-4fb7-bf64-7cd4e80d5cd1",
   "metadata": {},
   "outputs": [],
   "source": [
    "df = small_matrix.copy()\n",
    "df = df[[\"user_id\", \"video_id\",\"watch_ratio\"]]\n",
    "df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "729c0f4b-25c4-4b96-b33a-a4f0cffff85d",
   "metadata": {},
   "outputs": [],
   "source": [
    "user_features"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4c05bbd8-f6d2-4cf2-bc29-d77165cd15ef",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "SEED = 111\n",
    "rng = np.random.default_rng(SEED)\n",
    "rng"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e086eecf-162e-4c31-afcb-a88bd14bd996",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "num_action = 30 \n",
    "domain_cluster_num = 5 \n",
    "td_cluster_num = 4"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dd2bc0e5-93f3-4ca7-abb4-0a49260c23bf",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "cat_col_name = [\"user_active_degree\",\"follow_user_num_range\",\"fans_user_num_range\",\"friend_user_num_range\",\"register_days_range\"]\n",
    "for col in cat_col_name:\n",
    "    le = LabelEncoder()\n",
    "    encoded = le.fit_transform(user_features[col].values)\n",
    "    user_features[col] = encoded\n",
    "user_features"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6d7442ac-f517-46cf-b9c2-81fbfefb2666",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "user_feature_null = [\"onehot_feat4\",\"onehot_feat12\",\"onehot_feat13\",\"onehot_feat14\",\"onehot_feat15\",\"onehot_feat16\",\"onehot_feat17\"]\n",
    "user_features = user_features.drop(user_feature_null, axis=1)\n",
    "user_features"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4e3b431f-4018-4404-8415-64ce60db34b4",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "user_features_onehot = pd.get_dummies(user_features, columns=user_features.drop([\"user_id\",\"follow_user_num\",\"fans_user_num\",\"friend_user_num\",\"register_days\"],axis=1).columns.tolist(), dtype=int)\n",
    "user_features_onehot"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b5a186f6-dbe0-4b9b-aef6-3e17e6e4b7b8",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "df_count = df.groupby(\"video_id\")[[\"user_id\"]].nunique()\n",
    "df_count"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "87e90588-0dd0-4e69-89b1-e578a6f4aae8",
   "metadata": {},
   "outputs": [],
   "source": [
    "video_id_all_user = df_count[df_count[\"user_id\"]==df[\"user_id\"].nunique()].index.values\n",
    "video_id_all_user"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e6b8ce83-1f59-4441-b2b7-44a9325d21fd",
   "metadata": {},
   "outputs": [],
   "source": [
    "df = df[df[\"video_id\"].isin(video_id_all_user)]\n",
    "df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b02d2da0-4cf8-42bd-b7a0-3b5d57d22fd1",
   "metadata": {},
   "outputs": [],
   "source": [
    "df[\"video_id\"].nunique(),df[\"user_id\"].nunique()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bd485854-6d0d-4bb1-ad13-3ec0001695fc",
   "metadata": {},
   "outputs": [],
   "source": [
    "df_base = df.copy()\n",
    "df_base"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f8d0860e-0bee-4a43-8dc2-50ba6a7b3b42",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "use_action = rng.choice(df_base[\"video_id\"].unique(), size=num_action, replace=False, shuffle=False)\n",
    "use_action"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "753e4ca7-fd77-40fe-943b-cb0e2a3d93fd",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "df_base = df_base[df_base[\"video_id\"].isin(use_action)].sort_values([\"user_id\",\"video_id\"]).reset_index(drop=True)\n",
    "df_base"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1b665efa-c3c2-4fad-b036-bcc8d6056ca9",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "\n",
    "reindex_action = {video_index: i for i, video_index in enumerate(use_action)}\n",
    "reindex_action"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6508a273-db87-4095-848a-5285758b2d8b",
   "metadata": {},
   "outputs": [],
   "source": [
    "df_base[\"video_id\"] = df_base[\"video_id\"].map(reindex_action)\n",
    "df_base.sort_values([\"user_id\",\"video_id\"]).reset_index(drop=True)\n",
    "df_base"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bb8f406c-d76a-4738-9a71-01fa8fb0925d",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "key_ls = [\"user_id\"] + [f\"action_{a_id}\" for a_id in range(num_action)]\n",
    "user_reward_dict_for_clustering = {k:[] for k in key_ls}\n",
    "user_reward_dict_for_clustering"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "19263ccd-f673-4ae2-8012-62db53f8f259",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "for u_id in df_base[\"user_id\"].unique():\n",
    "    user_reward_dict_for_clustering[\"user_id\"].append(u_id)\n",
    "    for a_id in range(num_action):\n",
    "        r = df_base[(df_base[\"user_id\"]==u_id) & (df_base[\"video_id\"]==a_id)][\"watch_ratio\"].values[0]\n",
    "        user_reward_dict_for_clustering[f\"action_{a_id}\"].append(r)\n",
    "user_reward_dict_for_clustering       "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "700b2180-9695-486b-a40b-7b79d56d2dd3",
   "metadata": {},
   "outputs": [],
   "source": [
    "user_reward_df = pd.DataFrame(user_reward_dict_for_clustering)\n",
    "user_reward_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5d512bc8-7887-423d-a539-74a8ebb11706",
   "metadata": {},
   "outputs": [],
   "source": [
    "user_features = user_features[user_features[\"user_id\"].isin(df_base[\"user_id\"].unique())].reset_index(drop=True)\n",
    "user_features"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ee56eef6-5286-427a-83b1-60d528fb1654",
   "metadata": {},
   "outputs": [],
   "source": [
    "user_features_onehot = user_features_onehot[user_features_onehot[\"user_id\"].isin(df_base[\"user_id\"].unique())].reset_index(drop=True)\n",
    "user_features_onehot"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9f74ee1c-6858-4e80-8080-aa2ea40b4909",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "user_feature_not_null_except_user_id_row = user_features.drop(\"user_id\",axis=1).columns.tolist()\n",
    "\n",
    "user_feature_not_null_except_user_id = user_features_onehot.drop(\"user_id\",axis=1).columns.tolist()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "27517189-95ef-49da-bbe6-b74719ee331d",
   "metadata": {},
   "outputs": [],
   "source": [
    "user_feature_not_null_except_user_id_row"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d7fea139-a463-4287-8207-191c88d087d8",
   "metadata": {},
   "outputs": [],
   "source": [
    "X_u = user_features_onehot[user_feature_not_null_except_user_id].values\n",
    "X_u.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f8b31c53-3929-41cb-bde4-829221065ded",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "tsne = TSNE(n_components=2, random_state=SEED)\n",
    "X_u_reduced_tsne = tsne.fit_transform(X_u)\n",
    "X_u_reduced_tsne"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "55787c1f-997d-429e-9096-1995f1cf1b12",
   "metadata": {},
   "outputs": [],
   "source": [
    "user_features[\"TSNE-feature_0\"] = X_u_reduced_tsne.T[0]\n",
    "user_features[\"TSNE-feature_1\"] = X_u_reduced_tsne.T[1]\n",
    "user_features"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "df6ce284-6910-487f-b699-42cc14731157",
   "metadata": {},
   "outputs": [],
   "source": [
    "user_features_onehot[\"TSNE-feature_0\"] = X_u_reduced_tsne.T[0]\n",
    "user_features_onehot[\"TSNE-feature_1\"] = X_u_reduced_tsne.T[1]\n",
    "user_features_onehot"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bcfc7635-b0ec-42a4-a7e6-0a16cc57e763",
   "metadata": {},
   "outputs": [],
   "source": [
    "user_reward_df[\"all_action_reward_mean\"] = user_reward_df.drop(\"user_id\", axis=1).mean(axis=1)\n",
    "user_reward_df = user_reward_df.sort_values(\"all_action_reward_mean\").reset_index(drop=True)\n",
    "user_reward_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5a3418ef-0a76-4070-a072-529b0533d603",
   "metadata": {},
   "outputs": [],
   "source": [
    "user_reward_df.describe()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6e549c04-21e8-49b3-b04b-f6753ab7af6a",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def user_split_softmax(use_domain: int) -> List[float]:\n",
    "    reward_mean_each_user  = np.array([user_reward_df[user_reward_df[\"user_id\"]==user_id][\"all_action_reward_mean\"].values[0] for user_id in user_reward_df[\"user_id\"].unique()])\n",
    "    prob_each_user = np.exp(alpha_ls[use_domain] * reward_mean_each_user) / np.sum(np.exp(alpha_ls[use_domain] * reward_mean_each_user))\n",
    "    return prob_each_user"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b8fd793a-c25e-41b3-b028-85b9cb455e87",
   "metadata": {},
   "outputs": [],
   "source": [
    "alpha_ls = [-0.5, -0.4, -0.3, -0.2, -0.1, 0.2, 0.4, 0.6, 0.8, 1.0]\n",
    "alpha_ls"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4d4d6d39-bde5-466d-9662-f56b3b7e83d9",
   "metadata": {},
   "outputs": [],
   "source": [
    "num_domain = 10\n",
    "num_domain"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8336a6c6-c835-4dd2-8250-27f9f5683811",
   "metadata": {},
   "outputs": [],
   "source": [
    "td_num = 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4269aa40-e486-40ef-a379-5f5875c9a245",
   "metadata": {},
   "outputs": [],
   "source": [
    "df_base = df_base.merge(user_features[[\"user_id\",\"TSNE-feature_0\", \"TSNE-feature_1\"]], how=\"left\", on=\"user_id\")\n",
    "df_base"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3b66b9f8-47dc-4abe-8faa-a6e9ba19f4f3",
   "metadata": {},
   "outputs": [],
   "source": [
    "r_matrix = defaultdict(lambda: defaultdict(float))\n",
    "r_matrix"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c866c3a2-0cfe-4c3d-b9b5-ddc59c795ea7",
   "metadata": {},
   "outputs": [],
   "source": [
    "for u,i,r in zip(df_base[\"user_id\"].values,df_base[\"video_id\"].values,df_base[\"watch_ratio\"].values):\n",
    "    r_matrix[u][i] = r"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "42b39352-143b-4a4f-9a5e-b298929ce224",
   "metadata": {},
   "outputs": [],
   "source": [
    "lp_beta =[1.7554052086783454,\n",
    " -0.0778302798603224,\n",
    " 0.4358573654398761,\n",
    " 0.9618667569858963,\n",
    " -1.6885912553614273,\n",
    " 1.4200556375331574,\n",
    " 0.07585659348725438,\n",
    " 1.4690940931261292,\n",
    " -0.1691014474606618,\n",
    " 1.666700954096115]\n",
    "lp_beta"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2b9e294f-faa3-4c37-ba0b-0351837277e8",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "ep_epsilon = [0.2 for _ in range(num_domain)]\n",
    "ep_epsilon"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "47f031fc-d342-4d15-9fdb-56414e0f849b",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "\n",
    "dir = \"./noise_dict_-3_3.json\"\n",
    "encoding = \"utf-8\"\t\n",
    "with open(dir, mode=\"rt\", encoding=\"utf-8\") as f:\n",
    "\tnoise_dict_before = json.load(f)\n",
    "noise_dict = {int(k): v for k,v in noise_dict_before.items()}\n",
    "noise_dict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6af28d48-7704-42e8-a72c-823a0155697d",
   "metadata": {},
   "outputs": [],
   "source": [
    "def logging_policy_softmax(user_id: int, use_domain: int) -> List[float]:\n",
    "    reward_each_action  = np.array([r_matrix[user_id][action] + noise_dict[user_id][action] for action in range(num_action)])\n",
    "    reward_each_action_beta = lp_beta[use_domain] * reward_each_action\n",
    "    \n",
    "    reward_each_action_beta -= reward_each_action_beta.max()\n",
    "    \n",
    "    prob_each_action = np.exp(reward_each_action_beta) / np.sum(np.exp(reward_each_action_beta))\n",
    "    \n",
    "    if use_domain == td_num:\n",
    "        unsupported_action_prob_ls = []\n",
    "        for unsupported_action_index in unsupported_action_index_array:\n",
    "            unsupported_action_prob_ls.append(prob_each_action[unsupported_action_index])\n",
    "            prob_each_action[unsupported_action_index] = 0\n",
    "        for unsupported_action_prob in unsupported_action_prob_ls:\n",
    "            add_prob = unsupported_action_prob / (num_action - len(unsupported_action_prob_ls))\n",
    "            for i in range(num_action):\n",
    "                if i in unsupported_action_index_array:\n",
    "                    pass\n",
    "                else:\n",
    "                    prob_each_action[i] += add_prob\n",
    "                \n",
    "        \n",
    "    return prob_each_action\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "def evaluation_policy_mix_deterministic_and_epsilongreedy(user_id: int, use_domain: int) -> List[float]:\n",
    "    #reward_each_action  = np.array([r_matrix[user_id][action] + noise_dict[user_id][action] for action in range(num_action)])\n",
    "    reward_each_action  = np.array([r_matrix[user_id][action] for action in range(num_action)])\n",
    "    prob_each_action = np.array([(1-ep_epsilon[use_domain]) if  i == np.argmax(reward_each_action) else 0 for i in range(num_action)])+ (ep_epsilon[use_domain] / num_action)\n",
    "    return prob_each_action"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "58b391c5-2168-4d2f-b405-9db9cafad007",
   "metadata": {},
   "outputs": [],
   "source": [
    "def calc_true_v(td_domain_num):\n",
    "    true_v = 0\n",
    "    user_id_vec = user_reward_df[\"user_id\"].unique()\n",
    "    user_p = user_split_softmax(td_domain_num)\n",
    "    for i, u_id in enumerate(user_id_vec):\n",
    "        u_v = 0\n",
    "        prob_each_action = evaluation_policy_mix_deterministic_and_epsilongreedy(u_id, td_domain_num)\n",
    "        for a, p in enumerate(prob_each_action):\n",
    "            u_v += p * (r_matrix[u_id][a])\n",
    "        true_v += user_p[i]*u_v\n",
    "\n",
    "    return true_v\n",
    "\n",
    "\n",
    "def log_data_generate(domain_num: int, log_data_sample_size: int, seed: int) -> Tuple[List[List[float]], List[int], List[float]]:\n",
    "    user_id_vec = []\n",
    "    context_vec = []\n",
    "    action_vec = []\n",
    "    reward_vec = []\n",
    "    u_prob = user_split_softmax(domain_num)\n",
    "    unique_user_id = user_reward_df[\"user_id\"].unique()\n",
    "    \n",
    "    for i in range(log_data_sample_size):\n",
    "        \n",
    "        u_id = rng.choice(unique_user_id, size=1, p=u_prob)[0]\n",
    "        \n",
    "        context_sample = user_features[user_features[\"user_id\"]==u_id][user_feature_not_null_except_user_id_row].values[0].tolist()\n",
    "        \n",
    "        action_sample = rng.choice(num_action, size=1, p=logging_policy_softmax(u_id, domain_num))[0]\n",
    "        \n",
    "        reward_sample = rng.normal(r_matrix[u_id][action_sample], 1)\n",
    "\n",
    "        \n",
    "        user_id_vec.append(u_id)\n",
    "        context_vec.append(context_sample)\n",
    "        action_vec.append(action_sample)\n",
    "        reward_vec.append(reward_sample)\n",
    "        \n",
    "    return user_id_vec, context_vec, action_vec, reward_vec"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3a06aeba-e748-43a8-984d-b0d9d0666bae",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "true_v = calc_true_v(td_num)\n",
    "true_v"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7f853158-f37e-4cec-b3a1-28cf948c2a83",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "true_v_dict = {d: 0 for d in range(num_domain)}\n",
    "for d in range(num_domain):\n",
    "    true_v_dict[d] = calc_true_v(d)\n",
    "true_v_dict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c3ff024c-a36a-4ed6-babe-6e1fcce850b4",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "target_domain_sample_size = 100\n",
    "target_domain_sample_size"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "75d3d106-80ad-46eb-a490-4289438572bb",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "source_domain_sample_size = 100\n",
    "source_domain_sample_size"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e9b15503-7624-47f6-9275-1027039cdbce",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "fold_k = 3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "08319db6-d7c7-439a-b193-af89dd2048b9",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "dim_x = len(user_feature_not_null_except_user_id_row)\n",
    "dim_x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f70c57c9-d259-47fa-868d-b2ac6962775e",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def dm_estimator(user_id_vec: List[int], context_vec: List[List[float]], sample_size: int, fold_num: int) -> float:\n",
    "    \n",
    "    \n",
    "    context_and_action_vec_each_sample_each_action = []\n",
    "    for one_context_vec in context_vec:\n",
    "        context_and_action_one_sample = []\n",
    "        for one_action in range(num_action):\n",
    "            context_and_action_vec_each_sample_each_action.append(one_context_vec+[one_action])\n",
    "    context_and_action_vec_each_sample_each_action = np.array(context_and_action_vec_each_sample_each_action)\n",
    "    \n",
    "    \n",
    "    y_pred_each_sample_each_action = reward_models[\"target_domain\"][fold_num].predict(context_and_action_vec_each_sample_each_action)\n",
    "\n",
    "    evaluation_policy_prob_each_sample_each_action = []\n",
    "    for i in range(sample_size):\n",
    "        evaluation_policy_prob_each_sample_each_action += list(evaluation_policy_mix_deterministic_and_epsilongreedy(user_id_vec[i], td_num))\n",
    "        \n",
    "    estimates = 0\n",
    "    for p, q in zip(evaluation_policy_prob_each_sample_each_action, y_pred_each_sample_each_action):\n",
    "        estimates += p*q\n",
    "\n",
    "    return estimates / sample_size\n",
    "\n",
    "\n",
    "def dm_estimator_all_domain(user_id_vec: List[int], context_vec: List[List[float]], domain_index_ls: List[str], fold_num: int) -> float:\n",
    "    estimates = 0\n",
    "    user_id_vec = (np.array(user_id_vec)[valid_datas[\"ALL_domain\"][fold_num]]).tolist()\n",
    "    context_vec = (np.array(context_vec)[valid_datas[\"ALL_domain\"][fold_num]]).tolist()\n",
    "    domain_index_ls = (np.array(domain_index_ls)[valid_datas[\"ALL_domain\"][fold_num]]).tolist()\n",
    "    \n",
    "    \n",
    "    context_and_action_vec_each_sample_each_action = []\n",
    "    for one_context_vec in context_vec:\n",
    "        context_and_action_one_sample = []\n",
    "        for one_action in range(num_action):\n",
    "            context_and_action_vec_each_sample_each_action.append(one_context_vec+[one_action])\n",
    "    context_and_action_vec_each_sample_each_action = np.array(context_and_action_vec_each_sample_each_action)\n",
    "    \n",
    "    \n",
    "    y_pred_each_sample_each_action = reward_models[\"ALL_domain\"][fold_num].predict(context_and_action_vec_each_sample_each_action)\n",
    "\n",
    "    evaluation_policy_prob_each_sample_each_action = []\n",
    "    #print(len(user_id_vec),len(context_vec),len(domain_index_ls))\n",
    "    for i, d in enumerate(domain_index_ls):\n",
    "        evaluation_policy_prob_each_sample_each_action += list(evaluation_policy_mix_deterministic_and_epsilongreedy(user_id_vec[i], d))\n",
    "        \n",
    "    for p, q in zip(evaluation_policy_prob_each_sample_each_action, y_pred_each_sample_each_action):\n",
    "        estimates += p*q\n",
    "    \n",
    "    return estimates / len(context_vec)\n",
    "\n",
    "\n",
    "def ips_estimator(user_id_vec: List[int], sample_size: int, action_vec: List[int], reward_vec: List[float]) -> float:\n",
    "    estimates = 0\n",
    "    for i in range(sample_size):\n",
    "        evaluation_policy_prob = evaluation_policy_mix_deterministic_and_epsilongreedy(user_id_vec[i], td_num)[action_vec[i]]\n",
    "        logging_policy_prob = logging_policy_softmax(user_id_vec[i], td_num)[action_vec[i]]\n",
    "        estimates += (evaluation_policy_prob/logging_policy_prob) * reward_vec[i]\n",
    "    return estimates / sample_size\n",
    "\n",
    "\n",
    "def ips_estimator_all_domain(user_id_ls_all_domain: List[int], action_ls_all_domain: List[int], reward_ls_all_domain: List[float], domain_index_ls: List[str]) -> float:\n",
    "    \n",
    "    estimates = 0\n",
    "    user_id_vec = (np.array(user_id_ls_all_domain)).tolist()\n",
    "    action_vec = (np.array(action_ls_all_domain)).tolist()\n",
    "    reward_vec = (np.array(reward_ls_all_domain)).tolist()\n",
    "    domain_index_ls = (np.array(domain_index_ls)).tolist()\n",
    "\n",
    "    #print(user_id_vec)\n",
    "\n",
    "    for i, d in enumerate(domain_index_ls):\n",
    "        evaluation_policy_prob = evaluation_policy_mix_deterministic_and_epsilongreedy(user_id_vec[i], d)[action_vec[i]]\n",
    "        logging_policy_prob = logging_policy_softmax(user_id_vec[i], d)[action_vec[i]]\n",
    "        estimates += (evaluation_policy_prob/logging_policy_prob) * reward_vec[i]\n",
    "\n",
    "    return estimates / len(user_id_vec)\n",
    "\n",
    "\n",
    "def ips_estimator_all_domain_for_dr_all(user_id_ls_all_domain: List[int], context_ls_all_domain: List[List[float]], action_ls_all_domain: List[int], reward_ls_all_domain: List[float], domain_index_ls: List[str], fold_num: int) -> float:\n",
    "    estimates = 0\n",
    "    user_id_vec = (np.array(user_id_ls_all_domain)[valid_datas[\"ALL_domain\"][fold_num]]).tolist()\n",
    "    context_vec = (np.array(context_ls_all_domain)[valid_datas[\"ALL_domain\"][fold_num]]).tolist()\n",
    "    action_vec = (np.array(action_ls_all_domain)[valid_datas[\"ALL_domain\"][fold_num]]).tolist()\n",
    "    reward_vec = (np.array(reward_ls_all_domain)[valid_datas[\"ALL_domain\"][fold_num]]).tolist()\n",
    "    domain_index_ls = (np.array(domain_index_ls)[valid_datas[\"ALL_domain\"][fold_num]]).tolist()\n",
    "\n",
    "    for i, d in enumerate(domain_index_ls):\n",
    "        evaluation_policy_prob = evaluation_policy_mix_deterministic_and_epsilongreedy(user_id_vec[i], d)[action_vec[i]]\n",
    "        logging_policy_prob = logging_policy_softmax(user_id_vec[i], d)[action_vec[i]]\n",
    "        estimates += (evaluation_policy_prob/logging_policy_prob) * reward_vec[i]\n",
    "\n",
    "    return estimates / len(user_id_vec)\n",
    "\n",
    "\n",
    "\n",
    "def dr_estimator(user_id_vec: List[int], sample_size: int, context_vec: List[List[float]], action_vec: List[int], reward_vec: List[float], fold_num: int) -> float:\n",
    "    context_and_action_vec_all_target_domain_sample = np.hstack((np.array(context_vec),np.array(action_vec).reshape(len(action_vec),-1)))\n",
    "    y_pred = reward_models[\"target_domain\"][fold_num].predict(context_and_action_vec_all_target_domain_sample)\n",
    "    diff_r_q_hat = np.array(reward_vec) - np.array(y_pred) \n",
    "    return ips_estimator(user_id_vec, sample_size, action_vec, diff_r_q_hat) + dm_estimator(user_id_vec, context_vec, sample_size, fold_num)\n",
    "\n",
    "\n",
    "\n",
    "def dr_estimator_all_domain(user_id_ls_all_domain: List[int], context_ls_all_domain: List[List[float]], action_ls_all_domain: List[int], reward_ls_all_domain: List[float], domain_index_ls: List[str], fold_num: int) -> float:\n",
    "    context_and_action_vec_all_one_domain_sample = np.hstack((np.array(context_ls_all_domain),np.array(action_ls_all_domain).reshape(len(np.array(action_ls_all_domain)),-1)))\n",
    "    y_pred = reward_models[\"ALL_domain\"][fold_num].predict(context_and_action_vec_all_one_domain_sample)\n",
    "    diff_r_q_hat = (np.array(reward_ls_all_domain) - np.array(y_pred)).tolist()\n",
    "    return ips_estimator_all_domain_for_dr_all(user_id_ls_all_domain, context_ls_all_domain, action_ls_all_domain, diff_r_q_hat, domain_index_ls, fold_num) + dm_estimator_all_domain(user_id_ls_all_domain, context_ls_all_domain, domain_index_ls, fold_num)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e404dc81-2fab-49b4-a227-d6eb34579168",
   "metadata": {
    "lines_to_next_cell": 2
   },
   "outputs": [],
   "source": [
    "\n",
    "def multi_source_offcem_estimator(user_id_ls_target_cluster_domain: List[int], context_ls_target_cluster_domain: List[List[float]], action_ls_target_cluster_domain: List[int], reward_ls_target_cluster_domain: List[float], domain_index_ls_only_tg_cluster: List[str], context_and_domain_feature_ls_target_cluster_domain: List[List[float]], fold_num: int) -> float:\n",
    "    estimates = 0\n",
    "    \n",
    "    user_id_vec = (np.array(user_id_ls_target_cluster_domain)[valid_datas[\"Cluster\"][fold_num]]).tolist()\n",
    "    context_vec = (np.array(context_ls_target_cluster_domain)[valid_datas[\"Cluster\"][fold_num]]).tolist()\n",
    "    context_and_domain_feature_vec = (np.array(context_and_domain_feature_ls_target_cluster_domain)[valid_datas[\"Cluster\"][fold_num]]).tolist()\n",
    "    action_vec = (np.array(action_ls_target_cluster_domain)[valid_datas[\"Cluster\"][fold_num]]).tolist()\n",
    "    reward_vec = (np.array(reward_ls_target_cluster_domain)[valid_datas[\"Cluster\"][fold_num]]).tolist()\n",
    "    domain_index_ls = (np.array(domain_index_ls_only_tg_cluster)[valid_datas[\"Cluster\"][fold_num]]).tolist()\n",
    "    \n",
    "    \n",
    "    context_and_domain_feature_and_action_vec = np.hstack((np.array(context_and_domain_feature_vec),np.array(action_vec).reshape(len(np.array(action_vec)),-1)))\n",
    "    y_pred = reward_models[\"Cluster\"][fold_num].predict(context_and_domain_feature_and_action_vec)\n",
    "    diff_r_f_hat = np.array(reward_vec) - np.array(y_pred) \n",
    "    \n",
    "    \n",
    "    for i in range(len(diff_r_f_hat)): \n",
    "        evaluation_policy_prob = evaluation_policy_mix_deterministic_and_epsilongreedy(user_id_vec[i],td_num)[action_vec[i]]\n",
    "        prob_x_a = calc_mean_prob_joint_x_a(user_id_vec[i], context_vec[i], action_vec[i], domain_index_ls, fold_num)\n",
    "        estimates += diff_r_f_hat[i]*(evaluation_policy_prob / prob_x_a)\n",
    "    \n",
    "    return (estimates / len(diff_r_f_hat)) + dm_estimator_for_multi_source_offcem(user_id_vec, context_vec, context_and_domain_feature_vec, domain_index_ls, fold_num)\n",
    "\n",
    "\n",
    "\n",
    "def calc_mean_prob_joint_x_a(user_id: int, context: List[float], action: int, domain_index_ls: List[str], fold_num: int) -> float:\n",
    "    prob_x_a = 0\n",
    "\n",
    "    for d in sorted(td_cluster_domains):\n",
    "        logging_policy_prob = logging_policy_softmax(user_id, d)[action]\n",
    "        denstiy_ratio = density_ratio_models[d][0].compute_density_ratio(np.array(context).reshape(1,dim_x))[0]\n",
    "        \n",
    "        if d == td_num:\n",
    "            prob_x_a +=  domain_index_ls.count(d)* logging_policy_prob# * denstiy_ratio\n",
    "        else:\n",
    "            prob_x_a +=  domain_index_ls.count(d)* logging_policy_prob * denstiy_ratio\n",
    "    return prob_x_a / len(domain_index_ls)\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "def dm_estimator_for_multi_source_offcem(user_id_vec: List[int], context_vec:List[List[float]], context_and_domain_feature_vec: List[List[float]], domain_index_ls: List[str], fold_num: int) -> float:\n",
    "    estimates = 0\n",
    "    \n",
    "    \n",
    "    context_and_action_vec_each_sample_each_action = []\n",
    "    domain_index_ls_all_action = []\n",
    "    for i, one_context_and_domain_feature_vec in enumerate(context_and_domain_feature_vec):\n",
    "        domain_index_ls_all_action += [domain_index_ls[i]]*num_action\n",
    "        for one_action in range(num_action):\n",
    "            context_and_action_vec_each_sample_each_action.append(one_context_and_domain_feature_vec+[one_action])\n",
    "    context_and_action_vec_each_sample_each_action = np.array(context_and_action_vec_each_sample_each_action)\n",
    "    \n",
    "    \n",
    "    y_pred_each_sample_each_action = reward_models[\"Cluster\"][fold_num].predict(context_and_action_vec_each_sample_each_action)\n",
    "\n",
    "    evaluation_policy_prob_each_sample_each_action = []\n",
    "    for i, d in enumerate(domain_index_ls):\n",
    "        evaluation_policy_prob_each_sample_each_action += list(evaluation_policy_mix_deterministic_and_epsilongreedy(user_id_vec[i], d))\n",
    "    \n",
    "    \n",
    "    for i in range(len(domain_index_ls_all_action)):\n",
    "        if domain_index_ls_all_action[i]==td_num:\n",
    "            estimates += evaluation_policy_prob_each_sample_each_action[i]*y_pred_each_sample_each_action[i]\n",
    "            \n",
    "    return estimates / (domain_index_ls.count(td_num))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8d04f03b-45bf-489b-b664-7762c4d25f23",
   "metadata": {},
   "outputs": [],
   "source": [
    "seed_ls = [i for i in range(0,200)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "be8781d2-f42e-498b-9189-71ca1e6defdd",
   "metadata": {},
   "outputs": [],
   "source": [
    "unsupported_action_num_ls = [0, 6, 12, 18, 24]\n",
    "unsupported_action_num_ls"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "63fea814-7e0b-4c8d-87f1-2a37e9ad0bf0",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "estimate_dict = {\"DM\": {unsupported_action_num : [] for unsupported_action_num in unsupported_action_num_ls}, \n",
    "                 \"DM_ALL\":  {unsupported_action_num : [] for unsupported_action_num in unsupported_action_num_ls}, \n",
    "                 \"IPS\": {unsupported_action_num : [] for unsupported_action_num in unsupported_action_num_ls}, \n",
    "                 \"IPS_ALL\":  {unsupported_action_num : [] for unsupported_action_num in unsupported_action_num_ls}, \n",
    "                 \"DR\":  {unsupported_action_num : [] for unsupported_action_num in unsupported_action_num_ls}, \n",
    "                 \"DR_ALL\": {unsupported_action_num : [] for unsupported_action_num in unsupported_action_num_ls}, \n",
    "                 \"OFFCEM\": {unsupported_action_num : [] for unsupported_action_num in unsupported_action_num_ls}}\n",
    "estimate_dict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "18a6ef6b-66f5-475a-ae6d-8639a7827210",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "%%time\n",
    "for seed in seed_ls:\n",
    "    print(f\"NOW SEED = {seed}\")\n",
    "    for unsupport_action_num in unsupported_action_num_ls:\n",
    "        \n",
    "        rng = np.random.default_rng(seed=seed)\n",
    "        np.random.seed(seed)\n",
    "        unsupported_action_index_array = rng.choice(num_action, size=unsupport_action_num,replace=False)\n",
    "        \n",
    "        \n",
    "        user_id_dict_each_domain = {key: [] for key in range(num_domain)}\n",
    "        context_dict_each_domain = {key: [] for key in range(num_domain)}\n",
    "        action_dict_each_domain = {key: [] for key in range(num_domain)}\n",
    "        reward_dict_each_domain = {key: [] for key in range(num_domain)}\n",
    "    \n",
    "        \n",
    "        for d in reward_dict_each_domain.keys():\n",
    "            if d == td_num:\n",
    "                user_id_vec, context_vec, action_vec, reward_vec = log_data_generate(d, target_domain_sample_size, seed)\n",
    "            else:\n",
    "                user_id_vec, context_vec, action_vec, reward_vec = log_data_generate(d, source_domain_sample_size, seed)\n",
    "            user_id_dict_each_domain[d] = user_id_vec\n",
    "            context_dict_each_domain[d] = context_vec\n",
    "            action_dict_each_domain[d] = action_vec\n",
    "            reward_dict_each_domain[d] = reward_vec\n",
    "            \n",
    "        \n",
    "        domain_feature_dict_each_domain = {key: [] for key in range(num_domain)}\n",
    "        domain_feature_dict_each_domain_ = {key: [] for key in range(num_domain)}\n",
    "        for d in domain_feature_dict_each_domain.keys():\n",
    "            domain_feature_dict_each_domain[d] = np.mean(reward_dict_each_domain[d])\n",
    "            domain_feature_dict_each_domain_[d] = np.mean(reward_dict_each_domain[d])\n",
    "\n",
    "        \n",
    "        td_cluster_domains = []\n",
    "        td_d_feature = domain_feature_dict_each_domain[td_num]\n",
    "        for k,v in domain_feature_dict_each_domain_.items():\n",
    "            domain_feature_dict_each_domain_[k] = abs(td_d_feature-v)\n",
    "        \n",
    "        domain_feature_dict_each_domain_ = sorted(domain_feature_dict_each_domain_.items(), key=lambda x:x[1])\n",
    "        for i in range(td_cluster_num):\n",
    "            td_cluster_domains.append(domain_feature_dict_each_domain_[i][0])\n",
    "        \n",
    "        print(td_cluster_domains)\n",
    "        \n",
    "        \n",
    "        print(\"Fit reward model.\")\n",
    "        \n",
    "        reward_models = {\"target_domain\": [], \"ALL_domain\": [], \"Cluster\": []}\n",
    "        \n",
    "        valid_datas =  {k : [] for k in reward_models.keys()}\n",
    "        \n",
    "        \n",
    "        for data_type in reward_models.keys():\n",
    "            if data_type == \"target_domain\":\n",
    "                kfold = KFold(n_splits=fold_k, shuffle=True, random_state=SEED)\n",
    "                context_and_action_vec = np.hstack((np.array(context_dict_each_domain[td_num]),np.array(action_dict_each_domain[td_num]).reshape(len(action_dict_each_domain[td_num]),-1)))\n",
    "                for train_indices, test_indices in kfold.split(context_and_action_vec):\n",
    "                    X_train, X_test = context_and_action_vec[train_indices], context_and_action_vec[test_indices]\n",
    "                    y_train, y_test = np.array(reward_dict_each_domain[td_num])[train_indices], np.array(reward_dict_each_domain[td_num])[test_indices]\n",
    "                    \n",
    "                    forest = RandomForestRegressor(n_estimators=100, max_depth=6, min_samples_leaf=10, max_samples=0.8,random_state = SEED)\n",
    "                    forest.fit(X_train, y_train)\n",
    "                    # print(f\"MSE(Target domain data): {mean_squared_error(y_test,forest.predict(X_test))}\")\n",
    "                    \n",
    "                    reward_models[\"target_domain\"].append(forest)\n",
    "                    \n",
    "                    valid_datas[\"target_domain\"].append(test_indices)\n",
    "            \n",
    "                    \n",
    "            elif data_type == \"ALL_domain\":\n",
    "                kfold = KFold(n_splits=fold_k, shuffle=True, random_state=SEED)\n",
    "                domain_index_ls = []\n",
    "                user_id_ls_all_domain = []\n",
    "                context_ls_all_domain = []\n",
    "                action_ls_all_domain = []\n",
    "                reward_ls_all_domain = []\n",
    "                for domain_num in context_dict_each_domain.keys():\n",
    "                    domain_index_ls += [domain_num]*len(context_dict_each_domain[domain_num])\n",
    "                    user_id_ls_all_domain += user_id_dict_each_domain[domain_num]\n",
    "                    context_ls_all_domain += context_dict_each_domain[domain_num]\n",
    "                    action_ls_all_domain += action_dict_each_domain[domain_num]\n",
    "                    reward_ls_all_domain += reward_dict_each_domain[domain_num]\n",
    "                context_and_action_vec_all_domain = np.hstack((np.array(context_ls_all_domain),np.array(action_ls_all_domain).reshape(len(action_ls_all_domain),-1)))\n",
    "                for train_indices, test_indices in kfold.split(context_and_action_vec_all_domain):\n",
    "                    X_train, X_test = context_and_action_vec_all_domain[train_indices], context_and_action_vec_all_domain[test_indices]\n",
    "                    y_train, y_test = np.array(reward_ls_all_domain)[train_indices], np.array(reward_ls_all_domain)[test_indices]\n",
    "                    \n",
    "                    forest = RandomForestRegressor(n_estimators=100, max_depth=6, min_samples_leaf=10, max_samples=0.8,random_state = SEED)\n",
    "                    forest.fit(X_train, y_train)\n",
    "                    print(f\"RMSE(ALL domain data): {np.sqrt(mean_squared_error(y_test,forest.predict(X_test)))}\")\n",
    "                    \n",
    "                    reward_models[\"ALL_domain\"].append(forest)\n",
    "                    \n",
    "                    valid_datas[\"ALL_domain\"].append(test_indices)\n",
    "    \n",
    "            else:\n",
    "                domain_index_ls_only_tg_cluster = []\n",
    "                user_id_ls_target_cluster_domain = []\n",
    "                context_ls_target_cluster_domain = []\n",
    "                action_ls_target_cluster_domain = []\n",
    "                reward_ls_target_cluster_domain = []\n",
    "                context_and_domain_feature_ls_target_cluster_domain = []\n",
    "                \n",
    "                for domain_num in sorted(td_cluster_domains):\n",
    "                    domain_index_ls_only_tg_cluster += [domain_num]*len(context_dict_each_domain[domain_num])\n",
    "                    user_id_ls_target_cluster_domain += user_id_dict_each_domain[domain_num]\n",
    "                    context_ls_target_cluster_domain += context_dict_each_domain[domain_num]\n",
    "                    action_ls_target_cluster_domain += action_dict_each_domain[domain_num]\n",
    "                    reward_ls_target_cluster_domain += reward_dict_each_domain[domain_num]\n",
    "                \n",
    "                    contest_and_domain_feature_ls = copy.deepcopy(context_dict_each_domain[domain_num])\n",
    "                    for i in range(len(contest_and_domain_feature_ls)):\n",
    "                        contest_and_domain_feature_ls[i] += [domain_feature_dict_each_domain[domain_num]]\n",
    "                    context_and_domain_feature_ls_target_cluster_domain += contest_and_domain_feature_ls\n",
    "                        \n",
    "                context_and_action_and_domain_feature_vec_target_cluster_domain = np.hstack((np.array(context_and_domain_feature_ls_target_cluster_domain),np.array(action_ls_target_cluster_domain).reshape(len(action_ls_target_cluster_domain),-1)))\n",
    "                kfold = StratifiedKFold(n_splits=fold_k, shuffle=True, random_state=SEED)\n",
    "                for train_indices, test_indices in kfold.split(context_and_action_and_domain_feature_vec_target_cluster_domain, domain_index_ls_only_tg_cluster):\n",
    "                    X_train, X_test = context_and_action_and_domain_feature_vec_target_cluster_domain[train_indices], context_and_action_and_domain_feature_vec_target_cluster_domain[test_indices]\n",
    "                    y_train, y_test = np.array(reward_ls_target_cluster_domain)[train_indices], np.array(reward_ls_target_cluster_domain)[test_indices]\n",
    "                    \n",
    "                    forest = RandomForestRegressor(n_estimators=100, max_depth=6, min_samples_leaf=10, max_samples=0.8, random_state = SEED)\n",
    "                    forest.fit(X_train, y_train)\n",
    "                    #print(f\"MSE(target cluster data): {mean_squared_error(y_test,forest.predict(X_test))}\")\n",
    "                    \n",
    "                    reward_models[\"Cluster\"].append(forest)\n",
    "                    \n",
    "                    valid_datas[\"Cluster\"].append(test_indices)\n",
    "        \n",
    "        print(\"Fit dens ratio model.\")\n",
    "        \n",
    "        density_ratio_models = {domain_num : [] for domain_num in sorted(td_cluster_domains)}\n",
    "        for domain_num in density_ratio_models.keys():\n",
    "            \n",
    "            model = densratio(np.array(context_dict_each_domain[domain_num]), np.array(context_dict_each_domain[td_num]), alpha=0.95, verbose=False) \n",
    "            \n",
    "            density_ratio_models[domain_num].append(model)\n",
    "        \n",
    "        \n",
    "        \n",
    "        print(\"Calc V-hat each estimator.\")\n",
    "        print(f\"ss={(target_domain_sample_size/(target_domain_sample_size+(source_domain_sample_size*num_domain)))*100}%\")\n",
    "    \n",
    "        \n",
    "        hat_dm = 0\n",
    "        for fold_num in range(fold_k):\n",
    "            hat_dm += dm_estimator((np.array(user_id_dict_each_domain[td_num])[valid_datas[\"target_domain\"][fold_num]]).tolist(), (np.array(context_dict_each_domain[td_num])[valid_datas[\"target_domain\"][fold_num]]).tolist(), len(valid_datas[\"target_domain\"][fold_num]), fold_num)\n",
    "        hat_dm = hat_dm / fold_k\n",
    "        print(f\"hat-DM:{hat_dm}, MSE:{(hat_dm-true_v)**2}\")\n",
    "        \n",
    "        \n",
    "        hat_dm_all_domain = 0\n",
    "        for fold_num in range(fold_k):\n",
    "            hat_dm_all_domain += dm_estimator_all_domain(user_id_ls_all_domain, context_ls_all_domain, domain_index_ls, fold_num)\n",
    "        hat_dm_all_domain = hat_dm_all_domain / fold_k\n",
    "        print(f\"hat-DM-alldomain:{hat_dm_all_domain}, MSE:{(hat_dm_all_domain-true_v)**2}\")\n",
    "        \n",
    "        \n",
    "        hat_ips = ips_estimator((np.array(user_id_dict_each_domain[td_num])).tolist(), len(user_id_dict_each_domain[td_num]), action_dict_each_domain[td_num], reward_dict_each_domain[td_num])\n",
    "        print(f\"hat-IPS:{hat_ips}, MSE:{(hat_ips-true_v)**2}\")\n",
    "    \n",
    "        \n",
    "        hat_ips_all_domain = ips_estimator_all_domain(user_id_ls_all_domain, action_ls_all_domain, reward_ls_all_domain, domain_index_ls)\n",
    "        print(f\"hat-IPS-alldomain:{hat_ips_all_domain}, MSE:{(hat_ips_all_domain-true_v)**2}\")\n",
    "        \n",
    "        \n",
    "        hat_dr = 0\n",
    "        for fold_num in range(fold_k):\n",
    "            hat_dr += dr_estimator((np.array(user_id_dict_each_domain[td_num])[valid_datas[\"target_domain\"][fold_num]]).tolist(), len(valid_datas[\"target_domain\"][fold_num]), (np.array(context_dict_each_domain[td_num])[valid_datas[\"target_domain\"][fold_num]]).tolist(), (np.array(action_dict_each_domain[td_num])[valid_datas[\"target_domain\"][fold_num]]).tolist(), (np.array(reward_dict_each_domain[td_num])[valid_datas[\"target_domain\"][fold_num]]).tolist(), fold_num)\n",
    "        hat_dr = hat_dr / fold_k\n",
    "        print(f\"hat-DR:{hat_dr}, MSE:{(hat_dr-true_v)**2}\")\n",
    "        \n",
    "        \n",
    "        hat_dr_all_domain = 0\n",
    "        for fold_num in range(fold_k):\n",
    "            hat_dr_all_domain += dr_estimator_all_domain(user_id_ls_all_domain, context_ls_all_domain, action_ls_all_domain, reward_ls_all_domain, domain_index_ls, fold_num)\n",
    "        hat_dr_all_domain = hat_dr_all_domain / fold_k\n",
    "        print(f\"hat-DR-alldomain:{hat_dr_all_domain}, MSE:{(hat_dr_all_domain-true_v)**2}\")\n",
    "    \n",
    "        \n",
    "        hat_offcem = 0\n",
    "        for fold_num in range(fold_k):\n",
    "            hat_offcem += multi_source_offcem_estimator(user_id_ls_target_cluster_domain, context_ls_target_cluster_domain, action_ls_target_cluster_domain, reward_ls_target_cluster_domain, domain_index_ls_only_tg_cluster, context_and_domain_feature_ls_target_cluster_domain, fold_num)\n",
    "        hat_offcem = hat_offcem / fold_k\n",
    "        print(f\"hat-OFFCEM:{hat_offcem}, MSE:{(hat_offcem-true_v)**2}\")\n",
    "    \n",
    "        \n",
    "        estimate_dict[\"DM\"][unsupport_action_num].append(hat_dm)\n",
    "        estimate_dict[\"DM_ALL\"][unsupport_action_num].append(hat_dm_all_domain)\n",
    "        estimate_dict[\"IPS\"][unsupport_action_num].append(hat_ips)\n",
    "        estimate_dict[\"IPS_ALL\"][unsupport_action_num].append(hat_ips_all_domain)\n",
    "        estimate_dict[\"DR\"][unsupport_action_num].append(hat_dr)\n",
    "        estimate_dict[\"DR_ALL\"][unsupport_action_num].append(hat_dr_all_domain)\n",
    "        estimate_dict[\"OFFCEM\"][unsupport_action_num].append(hat_offcem)\n",
    "        print(\"----------\")\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1caf0f44-336b-4f7c-9f7c-38390d25ef60",
   "metadata": {},
   "outputs": [],
   "source": [
    "def change_dict_key(d, old_key, new_key, default_value=None):\n",
    "    d[new_key] = d.pop(old_key, default_value)\n",
    "change_dict_key(estimate_dict, 'OFFCEM', 'COPE')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "80de005f-4b38-4b29-9d8a-9a545176a0c6",
   "metadata": {},
   "outputs": [],
   "source": [
    "def calculate_metrics(estimates):\n",
    "    \n",
    "    data_list = []\n",
    "    for estimator_name, sample_sizes in estimates.items():\n",
    "        for sample_size, values in sample_sizes.items():\n",
    "            for i, value in enumerate(values):\n",
    "                data_list.append({'Estimator': estimator_name, 'Sample Size': sample_size/30, 'Value': value})\n",
    "\n",
    "    df = pd.DataFrame(data_list)\n",
    "    \n",
    "    df['MSE'] = (df['Value'] - true_value) ** 2\n",
    "\n",
    "    \n",
    "    return df\n",
    "\n",
    "true_value = true_v\n",
    "\n",
    "estimators = estimate_dict\n",
    "\n",
    "df_metrics = calculate_metrics(estimators)\n",
    "\n",
    "\n",
    "plt.style.use(\"ggplot\")\n",
    "\n",
    "\n",
    "markers = ['o', 'v', '8', 's', 'p', '*', 'h']\n",
    "palette = sns.color_palette(\"deep\")[:7]\n",
    "\n",
    "\n",
    "plt.figure(figsize=(11, 7), dpi=400)\n",
    "ax = sns.lineplot(x='Sample Size', y='MSE', hue='Estimator', data=df_metrics,\n",
    "                  markers=markers, style='Estimator', errorbar=('ci', 95), palette=palette,linewidth=4, markersize=18, dashes=False)\n",
    "\n",
    "\n",
    "\n",
    "plt.ylabel('', fontsize=0)\n",
    "ax.set_yscale('log')\n",
    "ax.tick_params(axis=\"y\", labelsize=18)\n",
    "#ax.yaxis.set_label_coords(-0.08, 0.5)\n",
    "\n",
    "plt.xlabel('ratio of new actions in the target domain', fontsize=30, labelpad=20)\n",
    "ax.tick_params(axis=\"x\", labelsize=18)\n",
    "plt.xticks([0, 6/30,12/30,18/30,24/30])\n",
    "\n",
    "plt.legend(title='Estimator')\n",
    "plt.legend(loc='lower center', bbox_to_anchor=(.5, 1.1), ncol=7,  fontsize=12);\n",
    "#plt.legend().remove()\n",
    "ax.set_title('MSE (log-scale)', fontsize=35)\n",
    "#plt.savefig(f'../../output_experiments/real-world_mse_graph_na.png', dpi=500, bbox_inches='tight');\n",
    "plt.show();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a80c173f-3cdb-465e-945c-f26985481d81",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "true_value = true_v\n",
    "\n",
    "estimators = estimate_dict\n",
    "\n",
    "\n",
    "bias_dict = {}\n",
    "for estimator, samples in estimators.items():\n",
    "    bias_dict[estimator] = {}\n",
    "    for sample_size, estimates in samples.items():\n",
    "        V_hat_mean = np.mean(np.array(estimates))\n",
    "        \n",
    "        bias = (V_hat_mean - true_value)**2\n",
    "        bias_dict[estimator][sample_size] = bias\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "data_list = []\n",
    "for estimator, samples in bias_dict.items():\n",
    "    for sample_size, bias in samples.items():\n",
    "            data_list.append({'Estimator': estimator, 'Sample Size': sample_size/30, 'Bias': bias})\n",
    "df = pd.DataFrame(data_list)\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "plt.style.use(\"ggplot\")\n",
    "\n",
    "\n",
    "markers = ['o', 'v', '8', 's', 'p', '*', 'h']\n",
    "palette = sns.color_palette(\"deep\")[:7]\n",
    "\n",
    "\n",
    "plt.figure(figsize=(11, 7), dpi=400)\n",
    "ax = sns.lineplot(x='Sample Size', y='Bias', hue='Estimator', data=df,\n",
    "                  markers=markers, style='Estimator', errorbar=('ci', 95), palette=palette,linewidth=4, markersize=18, dashes=False)\n",
    "#plt.legend(title='Estimator', loc='upper left', bbox_to_anchor=(1, 1), fontsize=12)\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "plt.ylabel('', fontsize=0)\n",
    "ax.set_yscale('log')\n",
    "ax.tick_params(axis=\"y\", labelsize=18)\n",
    "#ax.yaxis.set_label_coords(-0.08, 0.5)\n",
    "\n",
    "plt.xlabel('ratio of new actions in the target domain', fontsize=30, labelpad=20)\n",
    "#ax.set_xticks(xticklabels) \n",
    "#ax.xaxis.set_label_coords(0.5, -0.1)\n",
    "ax.tick_params(axis=\"x\", labelsize=18)\n",
    "plt.xticks([0,6/30,12/30,18/30,24/30])\n",
    "\n",
    "plt.legend(title='Estimator')\n",
    "plt.legend(loc='lower center', bbox_to_anchor=(.5, 1.1), ncol=7,  fontsize=12);\n",
    "#plt.legend().remove()\n",
    "ax.set_title('Squared Bias(log-scale)', fontsize=35)\n",
    "#plt.savefig(f'../../output_experiments/real-world_bias_graph_na.png', dpi=500, bbox_inches='tight');\n",
    "plt.show();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c914286c-3830-4ec3-8740-3589d48a05b7",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "true_value = true_v\n",
    "\n",
    "estimators = estimate_dict\n",
    "\n",
    "\n",
    "variance_dict = {}\n",
    "for estimator, samples in estimators.items():\n",
    "    variance_dict[estimator] = {}\n",
    "    for sample_size, estimates in samples.items():\n",
    "        V_hat_mean = np.mean(np.array(estimates))\n",
    "       \n",
    "        variance =  np.sum(((np.array(estimates) - V_hat_mean)**2)) / len(estimates)\n",
    "        variance_dict[estimator][sample_size] = variance\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "data_list = []\n",
    "for estimator, samples in variance_dict.items():\n",
    "    for sample_size, variance in samples.items():\n",
    "            data_list.append({'Estimator': estimator, 'Sample Size': sample_size/30, 'Variance': variance})\n",
    "df = pd.DataFrame(data_list)\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "plt.style.use(\"ggplot\")\n",
    "\n",
    "\n",
    "markers = ['o', 'v', '8', 's', 'p', '*', 'h']\n",
    "palette = sns.color_palette(\"deep\")[:7]\n",
    "\n",
    "\n",
    "plt.figure(figsize=(11, 7), dpi=400)\n",
    "ax = sns.lineplot(x='Sample Size', y='Variance', hue='Estimator', data=df,\n",
    "                  markers=markers, style='Estimator', errorbar=('ci', 95), palette=palette,linewidth=4, markersize=18, dashes=False)\n",
    "#plt.legend(title='Estimator', loc='upper left', bbox_to_anchor=(1, 1), fontsize=12)\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "plt.ylabel('', fontsize=0)\n",
    "ax.set_yscale('log')\n",
    "ax.tick_params(axis=\"y\", labelsize=18)\n",
    "#ax.yaxis.set_label_coords(-0.08, 0.5)\n",
    "\n",
    "plt.xlabel('ratio of new actions in the target domain', fontsize=30, labelpad=20)\n",
    "#ax.set_xticks(xticklabels) \n",
    "#ax.xaxis.set_label_coords(0.5, -0.1)\n",
    "ax.tick_params(axis=\"x\", labelsize=18)\n",
    "plt.xticks([0,6/30,12/30,18/30,24/30])\n",
    "\n",
    "plt.legend(title='Estimator')\n",
    "plt.legend(loc='lower center', bbox_to_anchor=(.5, 1.1), ncol=7,  fontsize=12);\n",
    "#plt.legend().remove()\n",
    "ax.set_title('Variance (log-scale)', fontsize=35)\n",
    "#plt.savefig(f'../../output_experiments/real-world_variance_graph_na.png', dpi=500,bbox_inches='tight');\n",
    "plt.show();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "35b702c0-9e37-4887-b6d6-a0f98a81149f",
   "metadata": {
    "lines_to_next_cell": 2
   },
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "import json\n",
    "dir = \"./OPE_results/estimates_na.json\"\n",
    "with open(dir, mode=\"wt\", encoding=\"utf-8\") as f:\n",
    "\tjson.dump(estimate_dict, f, ensure_ascii=False, indent=2)\n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1ca12074-6fcd-4af6-9e03-26c7fb12d91d",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.15"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
