{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "b46ef322-fc09-40f5-82a6-855dff7e5cb1",
   "metadata": {},
   "source": [
    "## 1. Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9ef57cab-9d2f-4383-a241-57b31e4e83a8",
   "metadata": {},
   "outputs": [],
   "source": [
    "import copy\n",
    "from collections import defaultdict\n",
    "from typing import Dict, List, Optional, Tuple\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import seaborn as sns\n",
    "from densratio import densratio\n",
    "from scipy import stats\n",
    "from sklearn.ensemble import RandomForestRegressor\n",
    "from sklearn.metrics import mean_squared_error\n",
    "from sklearn.model_selection import KFold\n",
    "from sklearn.model_selection import StratifiedKFold"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "46124440-9e8e-4f8f-bbe9-8551c6025223",
   "metadata": {},
   "source": [
    "## 2. Default parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6c3605b1-860d-4051-8ab0-95367544b671",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ---- default parameters ----\n",
    "SEED = 111111  # global random seed\n",
    "\n",
    "# domain_0 is the target domain; domain_1 .. domain_{num_source_domain} are source domains\n",
    "num_source_domain = 29\n",
    "\n",
    "# number of Monte-Carlo contexts used for the ground-truth policy value\n",
    "# (the misspelled name is kept because later cells reference it)\n",
    "test_data_sample_saize = 100_000\n",
    "\n",
    "num_cluster = 6   # number of domain clusters\n",
    "num_action = 20   # size of the discrete action space\n",
    "\n",
    "# each context coordinate ~ N(mu_k, sigma_x^2); mu_k is drawn per domain\n",
    "dim_x = 10\n",
    "sigma_x = 1\n",
    "\n",
    "# domain feature e ~ N(mu_e, sigma_e^2) with dim_e coordinates\n",
    "dim_e = 5\n",
    "mu_e = 0\n",
    "sigma_e = 1\n",
    "\n",
    "# std of the Gaussian noise added to the expected reward of a logged round\n",
    "sigma_r = 1\n",
    "\n",
    "# exploration rate of the epsilon-greedy evaluation policy\n",
    "ep_epsilon = 0.2\n",
    "\n",
    "# reward_function: weight of the cluster effect relative to the residual effect\n",
    "Lambda = 0.5\n",
    "\n",
    "# number of cross-fitting folds for the reward models\n",
    "fold_k = 3"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "92c49621-f809-4f84-8b2b-e73888242274",
   "metadata": {},
   "source": [
    "## 3. Define functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "527fb7cb-856c-463e-aa3a-06a381327ed9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed the legacy global NumPy RNG and the Generator `rng` used by the helpers below.\n",
    "np.random.seed(SEED)\n",
    "rng = np.random.default_rng(seed=SEED)\n",
    "\n",
    "def reward_function(cluster: int, context: List[float], action: int, domain_feature: List[float]) -> float:\n",
    "    \"\"\"Expected (noise-free) reward of `action` for `context` in a domain.\n",
    "\n",
    "    reward = residual_effect + Lambda * cluster_effect, with x = context,\n",
    "    e = domain_feature, a = one-hot(action) and c = one-hot(cluster):\n",
    "      residual_effect = theta_e.e + x^T M_x_e e + a^T M_a_e e\n",
    "      cluster_effect  = theta_x_c[cluster].x + theta_a_c[cluster].a + theta_c.c + x^T M_x_a_c[cluster] a\n",
    "    All parameters are module-level globals created in the data-generation section.\n",
    "    \"\"\"\n",
    "    x = np.array(context)\n",
    "    e = np.array(domain_feature)\n",
    "    a = np.array([int(i == action) for i in range(num_action)])\n",
    "    c = np.array([int(i == cluster) for i in range(num_cluster)])\n",
    "\n",
    "    residual_effect = np.array(theta_e) @ e + x @ np.array(M_x_e) @ e + a @ np.array(M_a_e) @ e\n",
    "\n",
    "    cluster_effect = (\n",
    "        np.array(theta_x_c[cluster]) @ x\n",
    "        + np.array(theta_a_c[cluster]) @ a\n",
    "        + np.array(theta_c) @ c\n",
    "        + x @ np.array(M_x_a_c[cluster]) @ a\n",
    "    )\n",
    "\n",
    "    return residual_effect + Lambda * cluster_effect\n",
    "\n",
    "def reward_function_add_noise(cluster: int, context: List[float], action: int, domain_feature: List[float]) -> float:\n",
    "    \"\"\"`reward_function` plus one U(3.0, 10.0) draw from the global `rng`.\"\"\"\n",
    "    # NOTE(review): this noise has mean 6.5 rather than 0 -- confirm the offset is intended.\n",
    "    expected = reward_function(cluster, context, action, domain_feature)\n",
    "    return expected + rng.uniform(3.0, 10.0)\n",
    "\n",
    "def logging_policy(cluster: int, context: List[float], domain_feature:List[float], use_domain: str, noise_ls) -> List[float]:\n",
    "    \"\"\"Softmax logging policy of domain `use_domain` over all actions.\n",
    "\n",
    "    Logits are lp_beta[use_domain] * (expected reward of each action + noise_ls), where\n",
    "    noise_ls holds one perturbation per action. Returns a normalized array of\n",
    "    num_action probabilities.\n",
    "    \"\"\"\n",
    "    expected_rewards = np.array([reward_function(cluster, context, a, domain_feature) for a in range(num_action)])\n",
    "    logits = lp_beta[use_domain] * (expected_rewards + np.array(noise_ls))\n",
    "    # shift by the max before exponentiating so np.exp cannot overflow\n",
    "    logits = logits - logits.max()\n",
    "    weights = np.exp(logits)\n",
    "    return weights / np.sum(weights)\n",
    "\n",
    "\n",
    "def evaluation_policy(cluster: int, context: List[float], domain_feature: List[float]) -> List[float]:\n",
    "    \"\"\"Epsilon-greedy evaluation policy built on the true expected reward.\n",
    "\n",
    "    The reward-maximizing action (first one on ties, as np.argmax) gets\n",
    "    1 - ep_epsilon + ep_epsilon / num_action; every other action gets ep_epsilon / num_action.\n",
    "    \"\"\"\n",
    "    expected_rewards = np.array([reward_function(cluster, context, a, domain_feature) for a in range(num_action)])\n",
    "    # argmax is computed once instead of once per action\n",
    "    greedy = np.zeros(num_action)\n",
    "    greedy[np.argmax(expected_rewards)] = 1 - ep_epsilon\n",
    "    return greedy + ep_epsilon / num_action"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "df3f592e-c71d-4d8c-915a-42ef11466f39",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def calc_ground_truth_policy_value(cluster: int, sample_size: int, context_vec: List[List[float]], domain_feature_vec: List[float]) -> float:\n",
    "    \"\"\"Monte-Carlo ground truth V(pi_e) = mean_i sum_a pi_e(a|x_i) * q(x_i, a).\n",
    "\n",
    "    q is the noise-free `reward_function`; the first `sample_size` contexts of\n",
    "    `context_vec` are used, and `cluster` / `domain_feature_vec` identify the domain.\n",
    "    \"\"\"\n",
    "    total = 0\n",
    "    for i in range(sample_size):\n",
    "        x_i = context_vec[i]\n",
    "        pi_e = evaluation_policy(cluster, x_i, domain_feature_vec)\n",
    "        # accumulate sample by sample, action by action\n",
    "        for a, prob in enumerate(pi_e):\n",
    "            total += prob * reward_function(cluster, x_i, a, domain_feature_vec)\n",
    "    return total / sample_size\n",
    "\n",
    "\n",
    "def calc_mse(V_true: float, V_hat_ls: List[float]) -> float:\n",
    "    mse = np.sum(((V_true - np.array(V_hat_ls))**2)) / len(V_hat_ls)\n",
    "    return mse\n",
    "\n",
    "\n",
    "def calc_bias(V_true: float, V_hat_ls: List[float], is_square: bool = True) -> float:\n",
    "    V_hat_mean = np.mean(np.array(V_hat_ls))\n",
    "    if is_square: bias = (V_hat_mean - V_true)**2\n",
    "    else: bias = V_hat_mean - V_true\n",
    "    return bias\n",
    "\n",
    "\n",
    "def calc_variance(V_hat_ls: List[float]) -> float:\n",
    "    V_hat_mean = np.mean(np.array(V_hat_ls))\n",
    "    variance = np.sum(((np.array(V_hat_ls) - V_hat_mean)**2)) / len(V_hat_ls)\n",
    "    return variance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c345aa91-1845-436b-b248-adc252529a8d",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def dm_estimator(cluster: int, sample_size: int, context_vec: List[List[float]], domain_feature_vec:List[float], fold_num: int) -> float:\n",
    "    \"\"\"Direct-method estimate of the target-domain policy value.\n",
    "\n",
    "    Predicts every (context, action) reward with the domain_0 reward model of fold\n",
    "    `fold_num` and averages sum_a pi_e(a|x) * q_hat(x, a) over the first `sample_size`\n",
    "    contexts. Each context must be a plain list (rows are built with `+ [action]`).\n",
    "    \"\"\"\n",
    "    # NOTE(review): unlike the pooled variants this does not select a validation fold\n",
    "    # itself -- presumably the caller passes held-out samples.\n",
    "    # one feature row [x_1, ..., x_dim_x, action] per (sample, action), sample-major\n",
    "    features = np.array([x + [a] for x in context_vec for a in range(num_action)])\n",
    "    q_hat = reward_models[\"domain_0\"][fold_num].predict(features)\n",
    "\n",
    "    # evaluation-policy probabilities flattened in the same sample-major order\n",
    "    pi_e = [p for i in range(sample_size) for p in evaluation_policy(cluster, context_vec[i], domain_feature_vec)]\n",
    "\n",
    "    estimates = sum(p * q for p, q in zip(pi_e, q_hat))\n",
    "    return estimates / sample_size\n",
    "\n",
    "\n",
    "\n",
    "def dm_estimator_all_domain(context_vec: List[List[float]], domain_index_ls: List[str], fold_num: int) -> float:\n",
    "    \"\"\"Direct-method estimate that pools every domain.\n",
    "\n",
    "    Keeps validation fold `fold_num` of the pooled data, predicts rewards with the pooled\n",
    "    reward model of that fold and averages sum_a pi_e(a|x) * q_hat(x, a), where each\n",
    "    sample's pi_e uses the cluster and domain feature of the domain it came from.\n",
    "    \"\"\"\n",
    "    # NOTE(review): because pi_e follows each sample's own domain, this averages over\n",
    "    # samples from all domains rather than targeting domain_0 -- confirm intended.\n",
    "    valid_idx = valid_datas[\"ALL_domain\"][fold_num]\n",
    "    context_vec = np.array(context_vec)[valid_idx].tolist()\n",
    "    domain_index_ls = np.array(domain_index_ls)[valid_idx].tolist()\n",
    "\n",
    "    features = np.array([x + [a] for x in context_vec for a in range(num_action)])\n",
    "    q_hat = reward_models[\"ALL_domain\"][fold_num].predict(features)\n",
    "\n",
    "    pi_e = []\n",
    "    for x, d in zip(context_vec, domain_index_ls):\n",
    "        info = domain_info[d]\n",
    "        pi_e.extend(evaluation_policy(info[\"cluster\"], x, info[\"domain_feature\"]))\n",
    "\n",
    "    estimates = sum(p * q for p, q in zip(pi_e, q_hat))\n",
    "    return estimates / len(context_vec)\n",
    "\n",
    "\n",
    "\n",
    "def ips_estimator(cluster: int, sample_size: int, context_vec: List[List[float]], action_vec: List[int], reward_vec: List[float], domain_feature_vec:List[float], noise_vec) -> float:\n",
    "    \"\"\"Inverse-propensity-score estimate on target-domain (domain_0) logs.\n",
    "\n",
    "    Averages pi_e(a_i|x_i) / pi_0(a_i|x_i) * r_i over the first `sample_size` samples,\n",
    "    where pi_0 is domain_0's logging policy evaluated with the logged noise noise_vec[i].\n",
    "    \"\"\"\n",
    "    def weighted_reward(i: int) -> float:\n",
    "        x_i, a_i = context_vec[i], action_vec[i]\n",
    "        p_e = evaluation_policy(cluster, x_i, domain_feature_vec)[a_i]\n",
    "        p_0 = logging_policy(cluster, x_i, domain_feature_vec, \"domain_0\", noise_vec[i])[a_i]\n",
    "        return (p_e / p_0) * reward_vec[i]\n",
    "\n",
    "    return sum(weighted_reward(i) for i in range(sample_size)) / sample_size\n",
    "\n",
    "\n",
    "\n",
    "def ips_estimator_all_domain(context_ls_all_domain: List[List[float]], action_ls_all_domain: List[int], reward_ls_all_domain: List[float], domain_index_ls: List[str], fold_num: int, noise_ls_all_domain) -> float:\n",
    "    \"\"\"IPS estimate that pools every domain (validation fold `fold_num` of the pooled data).\n",
    "\n",
    "    Each sample is weighted by pi_e(a|x) / pi_d(a|x), where both policies use the cluster\n",
    "    and domain feature of the sample's own domain d and pi_d is d's logging policy.\n",
    "    \"\"\"\n",
    "    valid_idx = valid_datas[\"ALL_domain\"][fold_num]\n",
    "    context_vec, action_vec, reward_vec, domain_index_ls, noise_vec = (\n",
    "        np.array(values)[valid_idx].tolist()\n",
    "        for values in (context_ls_all_domain, action_ls_all_domain, reward_ls_all_domain, domain_index_ls, noise_ls_all_domain)\n",
    "    )\n",
    "\n",
    "    estimates = 0\n",
    "    for x, a, r, d, noise in zip(context_vec, action_vec, reward_vec, domain_index_ls, noise_vec):\n",
    "        cluster, feature = domain_info[d][\"cluster\"], domain_info[d][\"domain_feature\"]\n",
    "        p_e = evaluation_policy(cluster, x, feature)[a]\n",
    "        p_d = logging_policy(cluster, x, feature, d, noise)[a]\n",
    "        estimates += (p_e / p_d) * r\n",
    "    return estimates / len(context_vec)\n",
    "\n",
    "\n",
    "def dr_estimator(cluster: int, sample_size: int, context_vec: List[List[float]], action_vec: List[int], reward_vec: List[float], domain_feature_vec:List[float], fold_num: int, noise_vec) -> float:\n",
    "    \"\"\"Doubly-robust estimate on target-domain logs: IPS on r - q_hat(x, a) plus DM.\n",
    "\n",
    "    Both parts use the domain_0 reward model of fold `fold_num`.\n",
    "    \"\"\"\n",
    "    logged_features = np.hstack((np.array(context_vec), np.array(action_vec).reshape(len(action_vec), -1)))\n",
    "    residuals = np.array(reward_vec) - np.array(reward_models[\"domain_0\"][fold_num].predict(logged_features))\n",
    "    ips_part = ips_estimator(cluster, sample_size, context_vec, action_vec, residuals, domain_feature_vec, noise_vec)\n",
    "    dm_part = dm_estimator(cluster, sample_size, context_vec, domain_feature_vec, fold_num)\n",
    "    return ips_part + dm_part\n",
    "\n",
    "\n",
    "\n",
    "def dr_estimator_all_domain(context_ls_all_domain: List[List[float]], action_ls_all_domain: List[int], reward_ls_all_domain: List[float], domain_index_ls: List[str], fold_num: int, noise_ls_all_domain) -> float:\n",
    "    \"\"\"Doubly-robust estimate that pools every domain: pooled IPS on r - q_hat(x, a) plus pooled DM.\n",
    "\n",
    "    q_hat is predicted for the whole pooled dataset; the IPS and DM parts then keep only\n",
    "    validation fold `fold_num` of it.\n",
    "    \"\"\"\n",
    "    actions = np.array(action_ls_all_domain)\n",
    "    logged_features = np.hstack((np.array(context_ls_all_domain), actions.reshape(len(actions), -1)))\n",
    "    residuals = (np.array(reward_ls_all_domain) - np.array(reward_models[\"ALL_domain\"][fold_num].predict(logged_features))).tolist()\n",
    "    ips_part = ips_estimator_all_domain(context_ls_all_domain, action_ls_all_domain, residuals, domain_index_ls, fold_num, noise_ls_all_domain)\n",
    "    dm_part = dm_estimator_all_domain(context_ls_all_domain, domain_index_ls, fold_num)\n",
    "    return ips_part + dm_part"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b985826d-39c6-44ce-bcfd-631a7abeb74a",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def multi_source_offcem_estimator(cluster: int, context_ls_target_cluster_domain: List[List[float]], action_ls_target_cluster_domain: List[int], reward_ls_target_cluster_domain: List[float], domain_index_ls_only_tg_cluster: List[str], context_and_domain_feature_ls_target_cluster_domain: List[List[float]], fold_num: int, noise_ls_target_cluster_domain) -> float:\n",
    "    \"\"\"Multi-source OffCEM-style estimate of the target-domain (domain_0) policy value.\n",
    "\n",
    "    Works on validation fold `fold_num` of the pooled logs of every domain in the target\n",
    "    cluster. Returns mean_i (r_i - f_hat(x_i, e_i, a_i)) * pi_e(a_i|x_i) / p(x_i, a_i)\n",
    "    plus the model-based term of `dm_estimator_for_multi_source_offcem`, where f_hat is\n",
    "    the cluster reward model of this fold and p(x, a) comes from `calc_mean_prob_joint_x_a`.\n",
    "    \"\"\"\n",
    "    valid_idx = valid_datas[\"Cluster\"][fold_num]\n",
    "\n",
    "    def take(values):\n",
    "        # keep the validation-fold rows, returned as plain Python lists\n",
    "        return (np.array(values)[valid_idx]).tolist()\n",
    "\n",
    "    context_vec = take(context_ls_target_cluster_domain)\n",
    "    context_and_domain_feature_vec = take(context_and_domain_feature_ls_target_cluster_domain)\n",
    "    action_vec = take(action_ls_target_cluster_domain)\n",
    "    reward_vec = take(reward_ls_target_cluster_domain)\n",
    "    domain_index_ls = take(domain_index_ls_only_tg_cluster)\n",
    "    noise_vec = take(noise_ls_target_cluster_domain)\n",
    "\n",
    "    model_input = np.hstack((np.array(context_and_domain_feature_vec), np.array(action_vec).reshape(len(np.array(action_vec)), -1)))\n",
    "    residuals = np.array(reward_vec) - np.array(reward_models[\"Cluster\"][fold_num].predict(model_input))\n",
    "\n",
    "    # NOTE(review): pi_e and p(x, a) use the domain feature of each sample's own domain,\n",
    "    # not domain_0's -- confirm this matches the intended estimator.\n",
    "    correction = 0\n",
    "    for i, residual in enumerate(residuals):\n",
    "        feature_i = domain_info[domain_index_ls[i]][\"domain_feature\"]\n",
    "        pi_e = evaluation_policy(cluster, context_vec[i], feature_i)[action_vec[i]]\n",
    "        p_x_a = calc_mean_prob_joint_x_a(cluster, context_vec[i], action_vec[i], feature_i, domain_index_ls, fold_num, noise_vec[i])\n",
    "        correction += residual * (pi_e / p_x_a)\n",
    "\n",
    "    return (correction / len(residuals)) + dm_estimator_for_multi_source_offcem(context_vec, context_and_domain_feature_vec, domain_index_ls, fold_num)\n",
    "\n",
    "def dm_estimator_for_multi_source_offcem(context_vec:List[List[float]], context_and_domain_feature_vec: List[List[float]], domain_index_ls: List[str], fold_num: int) -> float:\n",
    "    \"\"\"Model-based term of the multi-source OffCEM estimator.\n",
    "\n",
    "    Averages sum_a pi_e(a|x) * f_hat(x, e, a) over the domain_0 samples of the validation\n",
    "    fold, where f_hat is the cluster reward model of fold `fold_num` (input: context, domain\n",
    "    feature, action). Samples from other domains never contribute, so only the domain_0\n",
    "    rows are sent to the reward model and the evaluation policy.\n",
    "\n",
    "    Raises:\n",
    "        ZeroDivisionError: if the fold contains no domain_0 sample.\n",
    "    \"\"\"\n",
    "    target_rows = [i for i, d in enumerate(domain_index_ls) if d == \"domain_0\"]\n",
    "    if not target_rows:\n",
    "        raise ZeroDivisionError(\"dm_estimator_for_multi_source_offcem: no domain_0 sample in this fold\")\n",
    "\n",
    "    # one row [x, e, action] per (target sample, action), sample-major\n",
    "    features = np.array([context_and_domain_feature_vec[i] + [a] for i in target_rows for a in range(num_action)])\n",
    "    f_hat = reward_models[\"Cluster\"][fold_num].predict(features)\n",
    "\n",
    "    # evaluation-policy probabilities flattened in the same sample-major order\n",
    "    pi_e = []\n",
    "    for i in target_rows:\n",
    "        info = domain_info[domain_index_ls[i]]\n",
    "        pi_e.extend(evaluation_policy(info[\"cluster\"], context_vec[i], info[\"domain_feature\"]))\n",
    "\n",
    "    estimates = 0\n",
    "    for p, q in zip(pi_e, f_hat):\n",
    "        estimates += p * q\n",
    "    return estimates / len(target_rows)\n",
    "\n",
    "\n",
    "def calc_mean_prob_joint_x_a(cluster: int, context: List[float], action: int, domain_feature: List[float], domain_index_ls: List[str], fold_num: int, noise_ls) -> float:\n",
    "    \"\"\"Sample-share-weighted mixture of the cluster's logging policies at (context, action).\n",
    "\n",
    "    Returns sum_d (n_d / n) * pi_d(action|context) * w_d(context) over the domains d of\n",
    "    `cluster`, where n_d counts d in `domain_index_ls`, pi_d is d's logging policy and\n",
    "    w_d = density_ratio_models[d][0].compute_density_ratio(context); w_d is not applied\n",
    "    (and not computed) for domain_0.\n",
    "    \"\"\"\n",
    "    # NOTE(review): every pi_d is evaluated with the caller's `domain_feature` and `noise_ls`\n",
    "    # rather than domain d's own, and the fold-0 density-ratio model is used regardless of\n",
    "    # `fold_num` -- confirm both are intended.\n",
    "    context_row = np.array(context).reshape(1, dim_x)\n",
    "    prob_x_a = 0\n",
    "    for d in cluster_info[cluster]:\n",
    "        logging_policy_prob = logging_policy(cluster, context, domain_feature, d, noise_ls)[action]\n",
    "        if d == \"domain_0\":\n",
    "            # presumably the density ratios are relative to domain_0, so it gets weight 1\n",
    "            prob_x_a += domain_index_ls.count(d) * logging_policy_prob\n",
    "        else:\n",
    "            density_ratio = density_ratio_models[d][0].compute_density_ratio(context_row)[0]\n",
    "            prob_x_a += domain_index_ls.count(d) * logging_policy_prob * density_ratio\n",
    "\n",
    "    return prob_x_a / len(domain_index_ls)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7d80ce7c-0762-4ed7-864e-12207e4d973e",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7f83317b-9034-44ab-a28d-0f323f68cc1f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def log_data_generate(seed: int, sample_size: int, domain: Optional[str] = None) -> Tuple[List[List[float]], List[int], List[float], List[List[float]]]:\n",
    "    \"\"\"Simulate `sample_size` logged rounds for one domain.\n",
    "\n",
    "    Each round draws a context (each coordinate ~ N(mu_k, sigma_x^2)), a noise vector for\n",
    "    the logging policy, an action from that domain's logging policy, and a reward\n",
    "    r ~ N(reward_function(...), sigma_r^2).\n",
    "\n",
    "    Args:\n",
    "        seed: unused; contexts, actions and rewards come from the module-level `rng`,\n",
    "            which the caller re-seeds.\n",
    "        sample_size: number of rounds to generate.\n",
    "        domain: key into `domain_info`. Defaults to the module-level `domain_num`\n",
    "            (the calling cell's loop variable), matching the previous implicit behaviour.\n",
    "\n",
    "    Returns:\n",
    "        (context_vec, action_vec, reward_vec, noise_vec), one entry per round.\n",
    "    \"\"\"\n",
    "    if domain is None:\n",
    "        domain = domain_num  # backward-compatible fallback to the global loop variable\n",
    "    info = domain_info[domain]\n",
    "\n",
    "    # NOTE(review): re-created with a fixed seed on every call, so every domain and every\n",
    "    # experiment seed receives the same noise sequence -- confirm this is intended.\n",
    "    noise_rng = np.random.default_rng(seed=123456)\n",
    "\n",
    "    context_vec, action_vec, reward_vec, noise_vec = [], [], [], []\n",
    "    for _ in range(sample_size):\n",
    "        context_sample = rng.normal(info[\"mu_k\"], sigma_x, dim_x).tolist()\n",
    "        noise_ls = noise_rng.uniform(-0.5, 0.5, num_action).tolist()\n",
    "\n",
    "        prob_each_action = logging_policy(info[\"cluster\"], context_sample, info[\"domain_feature\"], domain, noise_ls)\n",
    "        action_sample = rng.choice(num_action, size=1, p=prob_each_action)[0]\n",
    "\n",
    "        mu_r = reward_function(info[\"cluster\"], context_sample, action_sample, info[\"domain_feature\"])\n",
    "        reward_sample = rng.normal(mu_r, sigma_r)\n",
    "\n",
    "        context_vec.append(context_sample)\n",
    "        action_vec.append(action_sample)\n",
    "        reward_vec.append(reward_sample)\n",
    "        noise_vec.append(noise_ls)\n",
    "\n",
    "    return context_vec, action_vec, reward_vec, noise_vec"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5ac204ad-76ae-488f-8143-8a3124184c88",
   "metadata": {},
   "source": [
    "## 4. Data generation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ee43c056-1895-4472-8488-5edef2de66f2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Re-seed so the reward-model parameters below do not depend on earlier cells.\n",
    "np.random.seed(SEED)\n",
    "rng = np.random.default_rng(seed=SEED)\n",
    "\n",
    "def _uniform_nested(shape, low=-1.0, high=1.0):\n",
    "    \"\"\"Nested lists of U(low, high) draws from the global `rng`, filled in row-major order.\"\"\"\n",
    "    if not shape:\n",
    "        return rng.uniform(low, high)\n",
    "    return [_uniform_nested(shape[1:], low, high) for _ in range(shape[0])]\n",
    "\n",
    "# cluster effect (one parameter set per cluster)\n",
    "theta_x_c = _uniform_nested((num_cluster, dim_x))\n",
    "theta_a_c = _uniform_nested((num_cluster, num_action))\n",
    "M_x_a_c = _uniform_nested((num_cluster, dim_x, num_action))\n",
    "theta_c = _uniform_nested((num_cluster,))\n",
    "\n",
    "# residual effect (not indexed by cluster; enters through the domain feature)\n",
    "M_x_e = _uniform_nested((dim_x, dim_e))\n",
    "M_a_e = _uniform_nested((num_action, dim_e))\n",
    "theta_e = _uniform_nested((dim_e,), -10.0, 10.0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "70a44c38-e769-4daa-942e-21d571f8360a",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# One record per domain (domain_0 is the target); the next cell fills in the fields.\n",
    "domain_info = {\n",
    "    f\"domain_{i}\": dict(target_domain_flag=-1, cluster=-1, mu_k=-1, domain_feature=[])\n",
    "    for i in range(num_source_domain + 1)\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3f54d5d1-8d23-41e2-bfa4-91e4143e41e4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Draw each domain's cluster, domain feature and context mean (draw order fixes reproducibility).\n",
    "for domain_num, info in domain_info.items():\n",
    "    info[\"cluster\"] = rng.choice(num_cluster, size=1)[0]\n",
    "    info[\"domain_feature\"] = rng.normal(mu_e, sigma_e, dim_e).tolist()\n",
    "    info[\"mu_k\"] = rng.uniform(-1.0, 1.0)\n",
    "    info[\"target_domain_flag\"] = 1 if domain_num == \"domain_0\" else 0\n",
    "\n",
    "# cluster id -> list of member domains\n",
    "cluster_info = {c: [] for c in range(num_cluster)}\n",
    "for domain_num, info in domain_info.items():\n",
    "    cluster_info[info[\"cluster\"]].append(domain_num)\n",
    "cluster_info"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4ff87bb7-fb9a-4897-bc6b-5ccfe601e36f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-domain coefficient lp_beta of the softmax logging policy, drawn from U(-0.5, 0.5).\n",
    "# This rebinds the global `rng`; the next cell re-seeds it before further use.\n",
    "rng = np.random.default_rng(seed=123456)\n",
    "lp_beta = {domain: rng.uniform(-0.5, 0.5) for domain in domain_info}\n",
    "lp_beta"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bd7de570-8d83-48d0-97e1-300f75aac43b",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "rng = np.random.default_rng(seed=SEED)\n",
    "np.random.seed(SEED)\n",
    "\n",
    "# Each ground truth needs test_data_sample_saize contexts x num_action reward evaluations,\n",
    "# so only the target domain is evaluated by default (later visible cells read only\n",
    "# domain_0). Use list(domain_info) to evaluate every domain; others stay at 0.\n",
    "domains_to_evaluate = [\"domain_0\"]\n",
    "\n",
    "true_v_each_domain = {k: 0 for k in domain_info.keys()}\n",
    "for domain_num in domains_to_evaluate:\n",
    "    context_test_data = []\n",
    "    for i in range(test_data_sample_saize):\n",
    "        context_sample = rng.normal(domain_info[domain_num][\"mu_k\"], sigma_x, dim_x).tolist()\n",
    "        context_test_data.append(context_sample)\n",
    "    true_v_each_domain[domain_num] = calc_ground_truth_policy_value(domain_info[domain_num][\"cluster\"], test_data_sample_saize, context_test_data, domain_info[domain_num][\"domain_feature\"])\n",
    "    print(f\"{domain_num}-TRUE_V:{true_v_each_domain[domain_num]}\")\n",
    "\n",
    "# average only over the domains actually evaluated; dividing by all domains understates it\n",
    "print(f\"MEAN-True-V: {sum(true_v_each_domain[d] for d in domains_to_evaluate) / len(domains_to_evaluate)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "82f9c547-0f49-4d84-9592-fa7e3506f972",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Estimand: ground-truth policy value of the target domain (domain_0) from the cell above.\n",
    "true_v = true_v_each_domain[\"domain_0\"]\n",
    "true_v"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "13a9c52d-d085-4896-bbc6-1aa1d6d07ab4",
   "metadata": {},
   "source": [
    "## 5. Experiment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ed2dc994-5ac0-4d0b-9c1b-5a8db727f235",
   "metadata": {},
   "outputs": [],
   "source": [
    "seed_ls = list(range(150))  # simulation seeds\n",
    "# target-domain (domain_0) sample sizes; each entry is one experimental condition\n",
    "target_domain_sample_size_ls = [50, 75, 100, 125, 150, 175, 200, 250, 300]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bacec048-b3fc-43b2-a3c2-9215de446c52",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-domain sample sizes, one entry per experimental condition: domain_0 follows\n",
    "# target_domain_sample_size_ls and every source domain always logs 100 rounds.\n",
    "sample_size_dict = {\n",
    "    k: target_domain_sample_size_ls if k == \"domain_0\" else [100] * len(target_domain_sample_size_ls)\n",
    "    for k in domain_info.keys()\n",
    "}\n",
    "sample_size_dict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "81c02c65-4848-4d10-8cc6-9becc339bd01",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Total number of logged rounds over all domains, for each experimental condition\n",
    "for i in range(len(target_domain_sample_size_ls)):\n",
    "    N = 0\n",
    "    for domain_num, sizes in sample_size_dict.items():\n",
    "        N += sizes[i]\n",
    "    print(N)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c45ebb0b-93eb-42cd-97bc-4a58a536eb3b",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# estimator name -> {target-domain sample size -> list of estimates (presumably one per seed)}\n",
    "estimate_dict = {\n",
    "    name: {td_ss: [] for td_ss in sample_size_dict[\"domain_0\"]}\n",
    "    for name in (\"DM\", \"DM_ALL\", \"IPS\", \"IPS_ALL\", \"DR\", \"DR_ALL\", \"OFFCEM\")\n",
    "}\n",
    "estimate_dict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "47c81a2c-4f9f-4861-abc1-14fe2b3cf7a9",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "%%time\n",
    "for seed in seed_ls:\n",
    "    print(f\"Now calc seed={seed}.\")\n",
    "    for sample_size_index in range(len(target_domain_sample_size_ls)):\n",
    "        np.random.seed(seed)\n",
    "        rng = np.random.default_rng(seed=seed)\n",
    "        \n",
    "        context_dict_each_domain = {key: [] for key in domain_info.keys()}\n",
    "        action_dict_each_domain = {key: [] for key in domain_info.keys()}\n",
    "        reward_dict_each_domain = {key: [] for key in domain_info.keys()}\n",
    "        noise_dict_each_domain = {key: [] for key in domain_info.keys()}\n",
    "        print(\"Generate log data.\")\n",
    "        \n",
    "        for domain_num in domain_info.keys():\n",
    "            \n",
    "            context_vec, action_vec, reward_vec, noise_vec = log_data_generate(seed, sample_size_dict[domain_num][sample_size_index])\n",
    "            \n",
    "            context_dict_each_domain[domain_num] = context_vec\n",
    "            action_dict_each_domain[domain_num] = action_vec\n",
    "            reward_dict_each_domain[domain_num] = reward_vec\n",
    "            noise_dict_each_domain[domain_num] = noise_vec\n",
    "        \n",
    "        \n",
    "        \n",
    "        print(\"Fit reward model.\")\n",
    "        \n",
    "        reward_models = {\"domain_0\": [], \"ALL_domain\": [], \"Cluster\": []}\n",
    "        \n",
    "        valid_datas =  {k : [] for k in reward_models.keys()}\n",
    "        \n",
    "        \n",
    "        for data_type in reward_models.keys():\n",
    "            if data_type == \"domain_0\":\n",
    "                kfold = KFold(n_splits=fold_k, shuffle=True, random_state=seed)\n",
    "                context_and_action_vec = np.hstack((np.array(context_dict_each_domain[\"domain_0\"]),np.array(action_dict_each_domain[\"domain_0\"]).reshape(len(action_dict_each_domain[\"domain_0\"]),-1)))\n",
    "                for train_indices, test_indices in kfold.split(context_and_action_vec):\n",
    "                    X_train, X_test = context_and_action_vec[train_indices], context_and_action_vec[test_indices]\n",
    "                    y_train, y_test = np.array(reward_dict_each_domain[\"domain_0\"])[train_indices], np.array(reward_dict_each_domain[\"domain_0\"])[test_indices]\n",
    "                    \n",
    "                    forest = RandomForestRegressor(n_estimators=100, max_depth=6, min_samples_leaf=10, max_samples=0.8,random_state = seed)\n",
    "                    forest.fit(X_train, y_train)\n",
    "                    print(f\"MSE(Target domain data): {mean_squared_error(y_test,forest.predict(X_test))}\")\n",
    "                    \n",
    "                    reward_models[\"domain_0\"].append(forest)\n",
    "                    \n",
    "                    valid_datas[\"domain_0\"].append(test_indices)\n",
    "                    \n",
    "            elif data_type == \"ALL_domain\":\n",
    "                kfold = KFold(n_splits=fold_k, shuffle=True, random_state=seed)\n",
    "                domain_index_ls = []\n",
    "                context_ls_all_domain = []\n",
    "                action_ls_all_domain = []\n",
    "                reward_ls_all_domain = []\n",
    "                noise_ls_all_domain = []\n",
    "                for domain_num in domain_info.keys():\n",
    "                    domain_index_ls += [domain_num]*len(context_dict_each_domain[domain_num])\n",
    "                    context_ls_all_domain += context_dict_each_domain[domain_num]\n",
    "                    action_ls_all_domain += action_dict_each_domain[domain_num]\n",
    "                    reward_ls_all_domain += reward_dict_each_domain[domain_num]\n",
    "                    noise_ls_all_domain += noise_dict_each_domain[domain_num]\n",
    "                context_and_action_vec_all_domain = np.hstack((np.array(context_ls_all_domain),np.array(action_ls_all_domain).reshape(len(action_ls_all_domain),-1)))\n",
    "                for train_indices, test_indices in kfold.split(context_and_action_vec_all_domain):\n",
    "                    X_train, X_test = context_and_action_vec_all_domain[train_indices], context_and_action_vec_all_domain[test_indices]\n",
    "                    y_train, y_test = np.array(reward_ls_all_domain)[train_indices], np.array(reward_ls_all_domain)[test_indices]\n",
    "                    \n",
    "                    forest = RandomForestRegressor(n_estimators=100, max_depth=6, min_samples_leaf=10, max_samples=0.8, random_state = seed)\n",
    "                    forest.fit(X_train, y_train)\n",
    "                    print(f\"MSE(ALL domain data): {mean_squared_error(y_test,forest.predict(X_test))}\")\n",
    "                    \n",
    "                    reward_models[\"ALL_domain\"].append(forest)\n",
    "                    \n",
    "                    valid_datas[\"ALL_domain\"].append(test_indices)\n",
    "\n",
    "            else:\n",
    "                domain_index_ls_only_tg_cluster = []\n",
    "                context_ls_target_cluster_domain = []\n",
    "                action_ls_target_cluster_domain = []\n",
    "                reward_ls_target_cluster_domain = []\n",
    "                noise_ls_target_cluster_domain = []\n",
    "                context_and_domain_feature_ls_target_cluster_domain = []\n",
    "                \n",
    "                for domain_num in cluster_info[domain_info[\"domain_0\"][\"cluster\"]]:\n",
    "                    domain_index_ls_only_tg_cluster += [domain_num]*len(context_dict_each_domain[domain_num])\n",
    "                    context_ls_target_cluster_domain += context_dict_each_domain[domain_num]\n",
    "                    action_ls_target_cluster_domain += action_dict_each_domain[domain_num]\n",
    "                    reward_ls_target_cluster_domain += reward_dict_each_domain[domain_num]\n",
    "                    noise_ls_target_cluster_domain += noise_dict_each_domain[domain_num]\n",
    "                \n",
    "                    contest_and_domain_feature_ls = copy.deepcopy(context_dict_each_domain[domain_num])\n",
    "                    for i in range(len(contest_and_domain_feature_ls)):\n",
    "                        contest_and_domain_feature_ls[i] += domain_info[domain_num][\"domain_feature\"]\n",
    "                    context_and_domain_feature_ls_target_cluster_domain += contest_and_domain_feature_ls\n",
    "                        \n",
    "                context_and_action_and_domain_feature_vec_target_cluster_domain = np.hstack((np.array(context_and_domain_feature_ls_target_cluster_domain),np.array(action_ls_target_cluster_domain).reshape(len(action_ls_target_cluster_domain),-1)))\n",
    "                kfold = StratifiedKFold(n_splits=fold_k, shuffle=True, random_state=seed)\n",
    "                for train_indices, test_indices in kfold.split(context_and_action_and_domain_feature_vec_target_cluster_domain, domain_index_ls_only_tg_cluster):\n",
    "                    X_train, X_test = context_and_action_and_domain_feature_vec_target_cluster_domain[train_indices], context_and_action_and_domain_feature_vec_target_cluster_domain[test_indices]\n",
    "                    y_train, y_test = np.array(reward_ls_target_cluster_domain)[train_indices], np.array(reward_ls_target_cluster_domain)[test_indices]\n",
    "                    \n",
    "                    forest = RandomForestRegressor(n_estimators=100, max_depth=6, min_samples_leaf=10, max_samples=0.8, random_state = seed)\n",
    "                    #forest = RandomForestRegressor(n_estimators=100, max_depth=8, min_samples_leaf=10, max_samples=0.8, random_state = seed)\n",
    "                    forest.fit(X_train, y_train)\n",
    "                    print(f\"MSE(target cluster data): {mean_squared_error(y_test,forest.predict(X_test))}\")\n",
    "                    \n",
    "                    reward_models[\"Cluster\"].append(forest)\n",
    "                    \n",
    "                    valid_datas[\"Cluster\"].append(test_indices)\n",
    "        \n",
    "        print(\"Fit dens ratio model.\")\n",
    "        \n",
    "        density_ratio_models = {domain_num : [] for domain_num in cluster_info[domain_info[\"domain_0\"][\"cluster\"]]}\n",
    "        for domain_num in density_ratio_models.keys():\n",
    "            \n",
    "            model = densratio(np.array(context_dict_each_domain[domain_num]), np.array(context_dict_each_domain[\"domain_0\"]), verbose=False) \n",
    "            \n",
    "            density_ratio_models[domain_num].append(model)\n",
    "        \n",
    "        \n",
    "        \n",
    "        print(\"Calc V-hat each estimator.\")\n",
    "        #print(f\"ss={int((sample_size_dict['domain_0'][sample_size_index]/all_domain_sample_size)*100)}%\")\n",
    "\n",
    "        \n",
    "        hat_dm = 0\n",
    "        for fold_num in range(fold_k):\n",
    "            hat_dm += dm_estimator(domain_info[\"domain_0\"][\"cluster\"], len(valid_datas[\"domain_0\"][fold_num]), (np.array(context_dict_each_domain[\"domain_0\"])[valid_datas[\"domain_0\"][fold_num]]).tolist(), domain_info[\"domain_0\"][\"domain_feature\"],fold_num)\n",
    "        hat_dm = hat_dm / fold_k\n",
    "        print(f\"hat-DM:{hat_dm}\")\n",
    "        \n",
    "        \n",
    "        hat_dm_all_domain = 0\n",
    "        for fold_num in range(fold_k):\n",
    "            hat_dm_all_domain += dm_estimator_all_domain(context_ls_all_domain, domain_index_ls, fold_num)\n",
    "        hat_dm_all_domain = hat_dm_all_domain / fold_k\n",
    "        print(f\"hat-DM-alldomain:{hat_dm_all_domain}\")\n",
    "\n",
    "        \n",
    "        hat_ips = ips_estimator(domain_info[\"domain_0\"][\"cluster\"], sample_size_dict[\"domain_0\"][sample_size_index], context_dict_each_domain[\"domain_0\"], action_dict_each_domain[\"domain_0\"], reward_dict_each_domain[\"domain_0\"], domain_info[\"domain_0\"][\"domain_feature\"], noise_dict_each_domain[\"domain_0\"])\n",
    "        print(f\"hat-IPS:{hat_ips}\")\n",
    "\n",
    "        \n",
    "        hat_ips_all_domain = 0\n",
    "        for fold_num in range(fold_k):\n",
    "            hat_ips_all_domain += ips_estimator_all_domain(context_ls_all_domain, action_ls_all_domain, reward_ls_all_domain, domain_index_ls, fold_num, noise_ls_all_domain)\n",
    "        hat_ips_all_domain = hat_ips_all_domain / fold_k\n",
    "        print(f\"hat-IPS-alldomain:{hat_ips_all_domain}\")\n",
    "        \n",
    "        \n",
    "        hat_dr = 0\n",
    "        for fold_num in range(fold_k):\n",
    "            hat_dr += dr_estimator(domain_info[\"domain_0\"][\"cluster\"], len(valid_datas[\"domain_0\"][fold_num]), (np.array(context_dict_each_domain[\"domain_0\"])[valid_datas[\"domain_0\"][fold_num]]).tolist(), (np.array(action_dict_each_domain[\"domain_0\"])[valid_datas[\"domain_0\"][fold_num]]).tolist(), (np.array(reward_dict_each_domain[\"domain_0\"])[valid_datas[\"domain_0\"][fold_num]]).tolist(), domain_info[\"domain_0\"][\"domain_feature\"], fold_num, (np.array(noise_dict_each_domain[\"domain_0\"])[valid_datas[\"domain_0\"][fold_num]]).tolist())\n",
    "        hat_dr = hat_dr / fold_k\n",
    "        print(f\"hat-DR:{hat_dr}\")\n",
    "\n",
    "        \n",
    "        hat_dr_all_domain = 0\n",
    "        for fold_num in range(fold_k):\n",
    "            hat_dr_all_domain += dr_estimator_all_domain(context_ls_all_domain, action_ls_all_domain, reward_ls_all_domain, domain_index_ls, fold_num, noise_ls_all_domain)\n",
    "        hat_dr_all_domain = hat_dr_all_domain / fold_k\n",
    "        print(f\"hat-DR-alldomain:{hat_dr_all_domain}\")\n",
    "\n",
    "        \n",
    "        hat_offcem = 0\n",
    "        for fold_num in range(fold_k):\n",
    "            hat_offcem += multi_source_offcem_estimator(domain_info[\"domain_0\"][\"cluster\"], context_ls_target_cluster_domain, action_ls_target_cluster_domain, reward_ls_target_cluster_domain, domain_index_ls_only_tg_cluster, context_and_domain_feature_ls_target_cluster_domain, fold_num, noise_ls_target_cluster_domain)\n",
    "        hat_offcem = hat_offcem / fold_k\n",
    "        print(f\"hat-OFFCEM:{hat_offcem}\")\n",
    "\n",
    "        \n",
    "        estimate_dict[\"DM\"][sample_size_dict[\"domain_0\"][sample_size_index]].append(hat_dm)\n",
    "        estimate_dict[\"DM_ALL\"][sample_size_dict[\"domain_0\"][sample_size_index]].append(hat_dm_all_domain)\n",
    "        estimate_dict[\"IPS\"][sample_size_dict[\"domain_0\"][sample_size_index]].append(hat_ips)\n",
    "        estimate_dict[\"IPS_ALL\"][sample_size_dict[\"domain_0\"][sample_size_index]].append(hat_ips_all_domain)\n",
    "        estimate_dict[\"DR\"][sample_size_dict[\"domain_0\"][sample_size_index]].append(hat_dr)\n",
    "        estimate_dict[\"DR_ALL\"][sample_size_dict[\"domain_0\"][sample_size_index]].append(hat_dr_all_domain)\n",
    "        estimate_dict[\"OFFCEM\"][sample_size_dict[\"domain_0\"][sample_size_index]].append(hat_offcem)\n",
    "        print(\"----------\")\n",
    "    if seed%10 ==0:\n",
    "        print(estimate_dict)\n",
    "    \n",
    "    print(\"======================\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5fdbaa55-ec4f-45ec-a7c9-30fd6f168b86",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "print(estimate_dict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "56035d76-432b-4c04-91b8-54a0755aacd6",
   "metadata": {},
   "outputs": [],
   "source": [
    "def change_dict_key(d, old_key, new_key, default_value=None):\n",
    "    d[new_key] = d.pop(old_key, default_value)\n",
    "change_dict_key(estimate_dict, 'OFFCEM', 'COPE')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f201cb57-cf1c-440b-8af2-43624fbaacd5",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "true_value = true_v\n",
    "def calculate_metrics(estimates):\n",
    "    \n",
    "    data_list = []\n",
    "    for estimator_name, sample_sizes in estimates.items():\n",
    "        for sample_size, values in sample_sizes.items():\n",
    "            for value in values:\n",
    "                data_list.append({'Estimator': estimator_name, 'Sample Size': sample_size, 'Value': value})\n",
    "\n",
    "    df = pd.DataFrame(data_list)\n",
    "    \n",
    "    df['MSE'] = (df['Value'] - true_value) ** 2\n",
    "    \n",
    "    return df\n",
    "\n",
    "\n",
    "df_metrics = calculate_metrics(estimate_dict)\n",
    "\n",
    "\n",
    "plt.style.use(\"ggplot\")\n",
    "\n",
    "\n",
    "markers = ['o', 'v', '8', 's', 'p', '*', 'h']#, 'D']\n",
    "palette = sns.color_palette(\"deep\")[:7]\n",
    "\n",
    "\n",
    "plt.figure(figsize=(11, 7), dpi=400)\n",
    "ax = sns.lineplot(x='Sample Size', y='MSE', hue='Estimator', data=df_metrics,\n",
    "                  markers=markers, style='Estimator', errorbar=('ci', 95), palette=palette,linewidth=4, markersize=18, dashes=False)\n",
    "#plt.legend(title='Estimator', loc='upper left', bbox_to_anchor=(1, 1), fontsize=12)\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "#plt.ylabel('MSE(log-scale)', fontsize=25)\n",
    "ax.set_yscale('log')\n",
    "#ax.tick_params(axis=\"y\", labelsize=30)\n",
    "plt.ylabel('', fontsize=-100)\n",
    "ax.tick_params(axis=\"y\", labelsize=18)\n",
    "ax.yaxis.set_label_coords(-0.08, 0.5)\n",
    "\n",
    "plt.xlabel('logged data size in the target domain', fontsize=30, labelpad=20)\n",
    "#ax.set_xticks(xticklabels) \n",
    "#ax.xaxis.set_label_coords(0.5, -0.1)\n",
    "#plt.xticks([0.2,0.4,0.6,0.8])\n",
    "plt.xticks([50, 75, 100, 125, 150, 175, 200, 250, 300])\n",
    "ax.tick_params(axis=\"x\", labelsize=18)\n",
    "\n",
    "plt.legend(title='Estimator')\n",
    "plt.legend(loc='lower center', bbox_to_anchor=(.5, 1.1), ncol=7,  fontsize=12);\n",
    "#plt.legend().remove()\n",
    "ax.set_title('MSE (log-scale)', fontsize=30)\n",
    "#plt.savefig(f'../output_experiments/syns_mse_graph_samplesize.png', dpi=500,bbox_inches='tight');\n",
    "plt.show();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8ee2dc3a-8bed-4cb9-8029-c51a06d86b4b",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "bias_dict = {}\n",
    "for estimator, samples in estimate_dict.items():\n",
    "    bias_dict[estimator] = {}\n",
    "    for sample_size, estimates in samples.items():\n",
    "        V_hat_mean = np.mean(np.array(estimates))\n",
    "        \n",
    "        bias = (V_hat_mean - true_value)**2\n",
    "        bias_dict[estimator][sample_size] = bias\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "data_list = []\n",
    "for estimator, samples in bias_dict.items():\n",
    "    for sample_size, bias in samples.items():\n",
    "            data_list.append({'Estimator': estimator, 'Sample Size': sample_size, 'Bias': bias})\n",
    "df = pd.DataFrame(data_list)\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "plt.style.use(\"ggplot\")\n",
    "\n",
    "\n",
    "markers = ['o', 'v', '8', 's', 'p', '*', 'h']#, 'D']\n",
    "palette = sns.color_palette(\"deep\")[:7]\n",
    "\n",
    "\n",
    "\n",
    "plt.figure(figsize=(11, 7), dpi=400)\n",
    "ax = sns.lineplot(x='Sample Size', y='Bias', hue='Estimator', data=df,\n",
    "                  markers=markers, style='Estimator', errorbar=('ci', 95), palette=palette,linewidth=4, markersize=18, dashes=False)\n",
    "#plt.legend(title='Estimator', loc='upper left', bbox_to_anchor=(1, 1), fontsize=12)\n",
    "\n",
    "\n",
    "#plt.ylabel('MSE(log-scale)', fontsize=25)\n",
    "ax.set_yscale('log')\n",
    "#ax.tick_params(axis=\"y\", labelsize=30)\n",
    "plt.ylabel('', fontsize=-100)\n",
    "ax.tick_params(axis=\"y\", labelsize=18)\n",
    "ax.yaxis.set_label_coords(-0.08, 0.5)\n",
    "\n",
    "plt.xlabel('logged data size in the target domain', fontsize=30, labelpad=20)\n",
    "#ax.set_xticks(xticklabels) \n",
    "#ax.xaxis.set_label_coords(0.5, -0.1)\n",
    "#plt.xticks([0.2,0.4,0.6,0.8])\n",
    "plt.xticks([50, 75, 100, 125, 150, 175, 200, 250, 300])\n",
    "ax.tick_params(axis=\"x\", labelsize=18)\n",
    "\n",
    "plt.legend(title='Estimator')\n",
    "plt.legend(loc='lower center', bbox_to_anchor=(.5, 1.1), ncol=7,  fontsize=12);\n",
    "#plt.legend().remove()\n",
    "ax.set_title('Squared Bias (log-scale)', fontsize=30)\n",
    "#plt.savefig(f'../output_experiments/syns_bias_graph_samplesize.png', dpi=500,bbox_inches='tight');\n",
    "plt.show();\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "450730b3-a2ef-40b1-9acb-1a6be66f3893",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "variance_dict = {}\n",
    "for estimator, samples in estimate_dict.items():\n",
    "    variance_dict[estimator] = {}\n",
    "    for sample_size, estimates in samples.items():\n",
    "        V_hat_mean = np.mean(np.array(estimates))\n",
    "       \n",
    "        variance =  np.sum(((np.array(estimates) - V_hat_mean)**2)) / len(estimates)\n",
    "        variance_dict[estimator][sample_size] = variance\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "data_list = []\n",
    "for estimator, samples in variance_dict.items():\n",
    "    for sample_size, variance in samples.items():\n",
    "            data_list.append({'Estimator': estimator, 'Sample Size': sample_size, 'Variance': variance})\n",
    "df = pd.DataFrame(data_list)\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "plt.style.use(\"ggplot\")\n",
    "\n",
    "\n",
    "markers = ['o', 'v', '8', 's', 'p', '*', 'h']#, 'D']\n",
    "palette = sns.color_palette(\"deep\")[:7]\n",
    "\n",
    "\n",
    "plt.figure(figsize=(11, 7), dpi=400)\n",
    "ax = sns.lineplot(x='Sample Size', y='Variance', hue='Estimator', data=df,\n",
    "                  markers=markers, style='Estimator', errorbar=('ci', 95), palette=palette,linewidth=4, markersize=18, dashes=False)\n",
    "#plt.legend(title='Estimator', loc='upper left', bbox_to_anchor=(1, 1), fontsize=12)\n",
    "\n",
    "\n",
    "#plt.ylabel('MSE(log-scale)', fontsize=25)\n",
    "ax.set_yscale('log')\n",
    "#ax.tick_params(axis=\"y\", labelsize=30)\n",
    "plt.ylabel('', fontsize=-100)\n",
    "ax.tick_params(axis=\"y\", labelsize=18)\n",
    "ax.yaxis.set_label_coords(-0.08, 0.5)\n",
    "\n",
    "plt.xlabel('logged data size in the target domain', fontsize=30, labelpad=20)\n",
    "#ax.set_xticks(xticklabels) \n",
    "#ax.xaxis.set_label_coords(0.5, -0.1)\n",
    "#plt.xticks([0.2,0.4,0.6,0.8])\n",
    "plt.xticks([50, 75, 100, 125, 150, 175, 200, 250, 300])\n",
    "ax.tick_params(axis=\"x\", labelsize=18)\n",
    "\n",
    "plt.legend(title='Estimator')\n",
    "plt.legend(loc='lower center', bbox_to_anchor=(.5, 1.1), ncol=7,  fontsize=12);\n",
    "plt.legend().remove()\n",
    "ax.set_title('Variance (log-scale)', fontsize=30)\n",
    "#plt.savefig(f'../output_experiments/syns_variance_graph_samplesize.png', dpi=500,bbox_inches='tight');\n",
    "plt.show();\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.15"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
