{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from faulty_policy import SoftmaxDataPolicy\n",
    "from sklearn.model_selection import train_test_split\n",
    "import pandas as pd\n",
    "import os\n",
    "from pathlib import Path\n",
    "import numpy as np\n",
    "from sklearn.preprocessing import LabelEncoder\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def calculate_pm_expected_value(pm_lambda, q, p, func_values):\n",
    "    w =  p / q\n",
    "    power_mean_w = w / (1 - pm_lambda + (pm_lambda * w))\n",
    "    return np.mean(func_values * power_mean_w)\n",
    "\n",
    "def calculate_es_expected_value(es_lambda, q, p, func_values):\n",
    "    return np.mean(func_values * p / (q ** es_lambda))\n",
    "\n",
    "def calculate_ls_expected_value(ls_lambda, q, p, func_values):\n",
    "    return -np.mean((1 / ls_lambda) * p * np.log(1 - ((ls_lambda * func_values) / (q + 1e-8))))\n",
    "\n",
    "def calculate_lsnl_expected_value(ls_lambda, q, p, func_values):\n",
    "    return -np.mean((1 / ls_lambda) * np.log(1 - ((ls_lambda * func_values * p) / (q + 1e-8))))\n",
    "\n",
    "def calculate_ops_expected_value(ops_lambda, q, p, func_values):\n",
    "    w = p / q\n",
    "    w2 = w * w\n",
    "    ops_w = (ops_lambda * w) / (w2 + ops_lambda)\n",
    "    return np.mean(ops_w * func_values)\n",
    "\n",
    "def calculate_ix_expected_value(ix_lambda, q, p, func_values):\n",
    "    return np.mean(func_values * p / (q + ix_lambda))\n",
    "\n",
    "def calculate_tr_expected_value(tr_lambda, q, p, func_values):\n",
    "    return np.mean(func_values * np.minimum(p / q, tr_lambda))\n",
    "\n",
    "def calculate_sn_expected_value(tr_lambda, q, p, func_values):\n",
    "    return np.mean(func_values * (p / q)) / np.mean(p / q)\n",
    "\n",
    "def calculate_lse_expected_value(lse_lambda, q, p, func_values):\n",
    "    if lse_lambda <= 0:\n",
    "        return np.mean(func_values * (p/q))\n",
    "    result = lse_lambda * (func_values) * (p / q)\n",
    "    result = np.exp(result)\n",
    "    result = np.log(np.mean(result))\n",
    "    return ((1 / lse_lambda) * result)\n",
    "\n",
    "def calculate_lsen_expected_value(lse_lambda, q, p, func_values):\n",
    "    result = lse_lambda * (func_values) * (p / q)\n",
    "    result = np.exp(result)\n",
    "    dom = np.exp(lse_lambda * (p / q))\n",
    "    result = np.log(np.mean(result) / (np.mean(dom) + 1e-8))\n",
    "    return ((1 / lse_lambda) * result)\n",
    "\n",
    "\n",
    "def calculate_mrdr_expected_value(mrdr_lambda, q, p, func_values):\n",
    "    return np.mean((1 - q) * p * func_values / q ** 2)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from scipy.optimize import minimize\n",
    "from functools import partial"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_hyperparameter(method, p, q, n, r):\n",
    "    significance = 0.1\n",
    "    if method == 'pm':\n",
    "        d2_Renyi = np.mean(p ** 2 / q)\n",
    "        out = (np.log(1 / significance) / (3 * d2_Renyi * n)) ** 0.5\n",
    "        return np.array([out])\n",
    "    if method == 'ix':\n",
    "        out = (np.log(2 / significance) / n) ** 0.5\n",
    "        return np.array([out])\n",
    "    if method == 'es':\n",
    "        return np.array([0.0, 0.3, 0.5, 0.7, 1.0])\n",
    "    if method == 'lse':\n",
    "        # return np.array([1 / n ** 0.5])\n",
    "        return np.array([0.0, 0.001, 0.01, 0.1, 1.0])\n",
    "    if method in ['ls', 'tr']: \n",
    "        return np.array([1 / n ** 0.5])\n",
    "    if method in ['mrdr', 'sn']:\n",
    "        return np.array([1.0])\n",
    "    if method == 'os':\n",
    "        def obj(lambda_):\n",
    "            shrinkage_weight = (lambda_ * (p/q)) / ((p/q) ** 2 + lambda_)\n",
    "            estimated_rewards_ = shrinkage_weight * r\n",
    "            variance = np.var(estimated_rewards_)\n",
    "            bias = np.sqrt(np.mean((p/q - shrinkage_weight) ** 2)) * max(r)\n",
    "            return bias ** 2 + variance\n",
    "        landa_opt = minimize(obj, x0=np.array([1]), bounds=[(0, np.inf)], method='Powell').x\n",
    "        return np.array([landa_opt])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Registry of estimators. 'param' is called with keyword arguments p, q, n, r and\n",
    "# returns the candidate lambdas; 'func' is called as func(lambda, q, p, func_values).\n",
    "estimators = {\n",
    "    \"pm\": {\n",
    "        'param': partial(get_hyperparameter, method='pm'),\n",
    "        'func': calculate_pm_expected_value\n",
    "    },\n",
    "    \"es\": {\n",
    "        'param': partial(get_hyperparameter, method='es'),\n",
    "        'func': calculate_es_expected_value\n",
    "    },\n",
    "    \"ix\": {\n",
    "        'param': partial(get_hyperparameter, method='ix'),\n",
    "        'func': calculate_ix_expected_value\n",
    "    },\n",
    "    \"lse\": {\n",
    "        'param': partial(get_hyperparameter, method='lse'),\n",
    "        'func': calculate_lse_expected_value\n",
    "    },\n",
    "    \"ls\": {\n",
    "        'param': partial(get_hyperparameter, method='ls'),\n",
    "        'func': calculate_lsnl_expected_value\n",
    "    },\n",
    "    \"ls_lin\": {\n",
    "        # Fixed grid wrapped in a callable: the experiment loop invokes 'param' with\n",
    "        # keyword arguments, which a bare list cannot accept.\n",
    "        'param': lambda **_unused: [0.01, 0.05, 0.1, 0.2, 0.5, 1.0, 1.5, 2, 5],\n",
    "        'func': calculate_ls_expected_value\n",
    "    },\n",
    "    \"tr\": {\n",
    "        'param': partial(get_hyperparameter, method='tr'),\n",
    "        # NOTE(review): 'tr' is wired to the power-mean estimator, not to\n",
    "        # calculate_tr_expected_value - confirm which one is intended.\n",
    "        'func': calculate_pm_expected_value\n",
    "    },\n",
    "    \"os\": {\n",
    "        'param': partial(get_hyperparameter, method='os'),\n",
    "        'func': calculate_ops_expected_value\n",
    "    },\n",
    "    \"sn\": {\n",
    "        'param': partial(get_hyperparameter, method='sn'),\n",
    "        'func': calculate_sn_expected_value\n",
    "    },\n",
    "    \"mrdr\": {\n",
    "        'param': partial(get_hyperparameter, method='mrdr'),\n",
    "        'func': calculate_mrdr_expected_value\n",
    "    },\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_data(log_temp, target_temp, dataset, noise=None):\n",
    "    \"\"\"Simulate logged bandit feedback from one classification CSV in ``UCI/``.\n",
    "\n",
    "    The last CSV column is the label (one action per distinct label value) and\n",
    "    the remaining columns are one-hot encoded into the features. For each of\n",
    "    the train and test splits, actions are sampled from the logging policy's\n",
    "    probabilities and the reward is 1.0 when the sampled action equals the label.\n",
    "\n",
    "    Args:\n",
    "        log_temp: ``temperature`` passed to the logging ``SoftmaxDataPolicy``.\n",
    "        target_temp: ``temperature`` passed to the target ``SoftmaxDataPolicy``.\n",
    "        dataset: CSV file name inside the ``UCI`` directory.\n",
    "        noise: If not None, probability with which each binary reward is flipped.\n",
    "\n",
    "    Returns:\n",
    "        ``(names, ps, qs, rs, true_reward)``. ``names`` is a one-element list with\n",
    "        the file name up to its last dot. The others are dicts keyed by\n",
    "        'train' / 'test' holding one-element lists: target probabilities of the\n",
    "        sampled actions, logging probabilities of the sampled actions, rewards,\n",
    "        and the mean target probability of the correct label.\n",
    "    \"\"\"\n",
    "    names = []\n",
    "    ps = {'train': [], 'test': []}\n",
    "    qs = {'train': [], 'test': []}\n",
    "    rs = {'train': [], 'test': []}\n",
    "    true_reward = {'train': [], 'test': []}\n",
    "\n",
    "    df = pd.read_csv(Path('UCI') / dataset)\n",
    "    label_values = df[df.columns[-1]]\n",
    "    y = LabelEncoder().fit_transform(label_values)\n",
    "    action_size = len(label_values.unique())\n",
    "\n",
    "    # Separate the label column before one-hot encoding. Slicing the last\n",
    "    # ``action_size`` columns off the encoded frame only removes the label when\n",
    "    # it is a string/categorical column; a numeric label would stay in the\n",
    "    # features and unrelated columns would be dropped instead.\n",
    "    features = df.iloc[:, :-1]\n",
    "    categorical_columns = features.select_dtypes(include=['object', 'category']).columns\n",
    "    features = pd.get_dummies(features, columns=categorical_columns)\n",
    "    x = features.values.astype(np.float32)\n",
    "\n",
    "    # The encoded label rides along as the last column so it is split with its row.\n",
    "    train, test = train_test_split(np.concatenate([x, y[:, None]], axis=1), test_size=0.8)\n",
    "    y = {\n",
    "        'train': train[:, -1].astype(int),\n",
    "        'test': test[:, -1].astype(int)\n",
    "    }\n",
    "    names.append(dataset[:dataset.rindex(\".\")])\n",
    "\n",
    "    # The two policies receive complementary halves of the action set as\n",
    "    # ``faulty_actions``: upper half for logging, lower half for target.\n",
    "    u = action_size // 2\n",
    "    logging_policy = SoftmaxDataPolicy(train[:, :-1], train[:, -1],\n",
    "                            test[:, :-1], test[:, -1],\n",
    "                            action_set=np.arange(action_size),\n",
    "                            temperature=log_temp,\n",
    "                            faulty_actions=np.arange(u, action_size)\n",
    "                            )\n",
    "    target_policy = SoftmaxDataPolicy(train[:, :-1], train[:, -1],\n",
    "                            test[:, :-1], test[:, -1],\n",
    "                            action_set=np.arange(action_size),\n",
    "                            temperature=target_temp,\n",
    "                            faulty_actions=np.arange(0, u)\n",
    "                            )\n",
    "    for mode in ['train', 'test']:\n",
    "        logging_probs = logging_policy.get_probs(mode)\n",
    "        target_probs = target_policy.get_probs(mode)\n",
    "        rows = np.arange(len(target_probs))\n",
    "        # One action per row, drawn from that row's logging distribution.\n",
    "        action = np.array([np.random.choice(action_size, p=prob) for prob in logging_probs])\n",
    "        p = target_probs[rows, action]\n",
    "        q = logging_probs[rows, action]\n",
    "        r = (action == y[mode]).astype(float)\n",
    "        # Ground truth is computed from the labels, before any reward noise is applied.\n",
    "        tr = np.mean(target_probs[rows, y[mode]])\n",
    "        true_reward[mode].append(tr)\n",
    "        if noise is not None:\n",
    "            # Adding a Bernoulli(noise) draw modulo 2 flips the binary reward.\n",
    "            r += np.random.binomial(n=1, p=noise, size=len(r))\n",
    "            r = r % 2\n",
    "        ps[mode].append(p)\n",
    "        qs[mode].append(q)\n",
    "        rs[mode].append(r)\n",
    "    return names, ps, qs, rs, true_reward"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of repeated generate_data draws per (log_temp, target_temp) pair in the experiment loop.\n",
    "n_exp = 10"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "pm : functools.partial(<function get_hyperparameter at 0x7f7cd43cf0a0>, method='pm')\n",
      "es : functools.partial(<function get_hyperparameter at 0x7f7cd43cf0a0>, method='es')\n",
      "ix : functools.partial(<function get_hyperparameter at 0x7f7cd43cf0a0>, method='ix')\n",
      "lse : functools.partial(<function get_hyperparameter at 0x7f7cd43cf0a0>, method='lse')\n",
      "ls : functools.partial(<function get_hyperparameter at 0x7f7cd43cf0a0>, method='ls')\n",
      "tr : functools.partial(<function get_hyperparameter at 0x7f7cd43cf0a0>, method='tr')\n",
      "os : functools.partial(<function get_hyperparameter at 0x7f7cd43cf0a0>, method='os')\n",
      "sn : functools.partial(<function get_hyperparameter at 0x7f7cd43cf0a0>, method='sn')\n",
      "mrdr : functools.partial(<function get_hyperparameter at 0x7f7cd43cf0a0>, method='mrdr')\n"
     ]
    }
   ],
   "source": [
    "# List every registered estimator next to its hyperparameter source.\n",
    "for name in estimators:\n",
    "    print(name, \":\", estimators[name]['param'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import defaultdict\n",
    "all_lambdas_ls = defaultdict(lambda: 0)\n",
    "all_lambdas_es = defaultdict(lambda: 0)\n",
    "for dataset in sorted(os.listdir('UCI')):\n",
    "    print('-' * 100)\n",
    "    print(dataset)\n",
    "    # Signed test-split errors (estimate - ground truth) per estimator, pooled over\n",
    "    # all temperature pairs and repetitions for this dataset.\n",
    "    all_errors = {k: [] for k in estimators.keys()}\n",
    "    for log_temp in [0.6, 0.7, 0.8]:\n",
    "        for target_temp in [0.1, 0.3, 0.5]:\n",
    "            for _ in range(n_exp):\n",
    "                names, ps, qs, rs, true_reward = generate_data(log_temp, target_temp, dataset, noise=0.3)\n",
    "                for i in range(len(ps['train'])):\n",
    "                    p = ps['train'][i]\n",
    "                    q = qs['train'][i]\n",
    "                    r = rs['train'][i]\n",
    "                    true_r = true_reward['train'][i]\n",
    "                    p_test = ps['test'][i]\n",
    "                    q_test = qs['test'][i]\n",
    "                    r_test = rs['test'][i]\n",
    "                    true_r_test = true_reward['test'][i]\n",
    "\n",
    "                    for estimator, info in estimators.items():\n",
    "                        # 'param' may be a callable producing the candidates or a plain\n",
    "                        # sequence of them; calling a list would raise TypeError.\n",
    "                        param_source = info['param']\n",
    "                        if callable(param_source):\n",
    "                            candidates = param_source(p=p, q=q, n=len(p), r=r)\n",
    "                        else:\n",
    "                            candidates = param_source\n",
    "                        # Keep the candidate whose train-split estimate is closest to the\n",
    "                        # train-split ground truth. Starting from infinity means any\n",
    "                        # finite error is accepted.\n",
    "                        best_landa = None\n",
    "                        best_error = np.inf\n",
    "                        for landa in candidates:\n",
    "                            # Estimators receive negated rewards; the sign is restored here.\n",
    "                            est_reward = -(info['func'](landa, q, p, -r))\n",
    "                            error = np.abs(est_reward - true_r)\n",
    "                            if error < best_error:\n",
    "                                best_error = error\n",
    "                                best_landa = landa\n",
    "                        est_reward = -(info['func'](best_landa, q_test, p_test, -r_test))\n",
    "                        error = est_reward - true_r_test\n",
    "                        all_errors[estimator].append(error)\n",
    "    print(' Method', '   BIAS', '    VAR', '    MSE')\n",
    "    print('Best method=', sorted([(k, np.mean(np.array(all_errors[k])**2)) for k in all_errors.keys()], key=lambda x: x[1])[0][0])\n",
    "    for k, v in all_errors.items():\n",
    "        bias = np.mean(v)\n",
    "        mse = np.mean(np.array(v)**2)\n",
    "        # Variance from the decomposition MSE = bias^2 + variance.\n",
    "        var = mse - bias**2\n",
    "        print(k.rjust(7, ' '), str(int(bias * 1e+4) / 1e+4).rjust(7, ' '), str(int(var * 1e+4) / 1e+4).rjust(7, ' '), str(int(mse * 1e+4) / 1e+4).rjust(7, ' '))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "nlp_env",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
