{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "1096c99d-4952-4734-b66a-5ecc469923c6",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from New_functions import *\n",
    "from Benchmark_functions import *\n",
    "import warnings\n",
    "import matplotlib.pyplot as plt\n",
    "import time\n",
    "warnings.filterwarnings(\"ignore\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "535eddf8-b4e1-4295-8a99-11c280e9a4b4",
   "metadata": {},
   "outputs": [],
   "source": [
    "'''\n",
    "First, we generate a dataset, we can change the Alpha_s, Alpha_t and effect parameter to change the distribution of\n",
    "the generated dataset.\n",
    "'''\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "def generate(ns, nt, p, q, s, t, u, Alpha_s=1, Alpha_t=0, effect=1, x_effect=0, z_diff=0.1, threshold_X=0.5, threshold_Y=0.5):\n",
    "    # Generate normal distributions for Z\n",
    "    Zs_null = np.random.normal(0, 1, (ns, q))\n",
    "    Zt_null = np.random.normal(0, 1, (nt, q))\n",
    "    \n",
    "    # Generate Z variables with and without shift\n",
    "    Z_source = np.hstack((np.random.normal(0, 1, (ns, p)), Zs_null))\n",
    "    Z_target = np.hstack((np.random.normal(z_diff, 1, (nt, p)), Zt_null))\n",
    "    \n",
    "    # Generate X variables\n",
    "    X_source = Z_source[:, :p] @ u + np.random.normal(0, 1, ns)\n",
    "    X_target = Z_target[:, :p] @ u + np.random.normal(0, 1, nt)\n",
    "    \n",
    "    # Convert X to binary\n",
    "    X_source = (np.random.rand(ns) < 1 / (1 + np.exp(-X_source))).astype(int)\n",
    "    X_target = (np.random.rand(nt) < 1 / (1 + np.exp(-X_target))).astype(int)\n",
    "    \n",
    "    # Generate V variables\n",
    "    V_source = Z_source[:, :p] @ s + Alpha_s * X_source + np.random.normal(0, 5, ns)\n",
    "    V_target = Z_target[:, :p] @ t + Alpha_t * X_target + np.random.normal(0, 5, nt)\n",
    "    \n",
    "    # Generate Y variables\n",
    "    Y_source = (Z_source[:, :p].sum(axis=1))**2 + effect * V_source + np.random.normal(0, 1, ns) + x_effect * X_source\n",
    "    Y_target = (Z_target[:, :p].sum(axis=1))**2 + effect * V_target + np.random.normal(0, 1, nt) + x_effect * X_target\n",
    "    \n",
    "    # Convert Y to binary\n",
    "    Y_source = (np.random.rand(ns) < 1 / (1 + np.exp(-Y_source))).astype(int)\n",
    "    Y_target = (np.random.rand(nt) < 1 / (1 + np.exp(-Y_target))).astype(int)\n",
    "    \n",
    "    return Y_source.reshape(-1, 1), X_source.reshape(-1, 1), V_source.reshape(-1, 1), Z_source,\\\n",
    "           Y_target.reshape(-1, 1), X_target.reshape(-1, 1), V_target.reshape(-1, 1), Z_target"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "3c9c88b4-c59d-405b-940d-5b260bb9a9c5",
   "metadata": {},
   "outputs": [],
   "source": [
    "#Set parameter for the generation of data\n",
    "ns, p,q = 1000, 5, 10\n",
    "nt = 2000\n",
    "\n",
    "s = np.array([-1, -0.5, 0, 1, 1.5])\n",
    "t = np.array([ -1, -1, 0.5 , 0.5, 1])\n",
    "u = np.array([ 0, -1, 0.5, -0.5, 1])\n",
    "\n",
    "Y_source, X_source, V_source, Z_source,Y_target, X_target, V_target, Z_target = \\\n",
    "generate(ns,nt, p,q, s, t, u, Alpha_s=0, Alpha_t = 2,effect=2, z_diff = 0.1)\n",
    "            "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "c10a39c1-9118-4ec7-92d7-af65562e11fa",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Data Separation:\n",
    "# The data arrays X_e, Z_e, V_e are designated for density ratio estimation.\n",
    "# The arrays Z_source, X_source, V_source, Y_source are used for testing.\n",
    "# Here, we split the source data based on a specified proportion.\n",
    "\n",
    "proportion = 0.5\n",
    "num = int(proportion * X_source.shape[0])\n",
    "Z_e = Z_source[:num]\n",
    "X_e = X_source[:num]\n",
    "V_e = V_source[:num]\n",
    "Z_source = Z_source[num+1:]\n",
    "X_source = X_source[num+1:]\n",
    "V_source = V_source[num+1:]\n",
    "Y_source = Y_source[num+1:]\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5e2cf4f3-81e0-4b87-9ad9-83ba166b65eb",
   "metadata": {},
   "source": [
    "# Real data experiments to run"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "21f66334-515f-4245-b726-98a1e975696d",
   "metadata": {},
   "source": [
    "## Our method"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "79610af9-6d81-4dbc-bc16-24a98147e511",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.05528617426693039"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#1. Test without power enhancement\n",
    "Test(X_e, Z_e, V_e, X_source, Z_source, V_source, Y_source, \\\n",
    "     X_target, Z_target, V_target, L=3, K=20, datatype='binary')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "8a5e103d-3a4a-43fd-a211-c5c06221e6e0",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0.27508511 0.20949971 0.31859527]\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "0.0405945146182739"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#2. Test with power enhancement\n",
    "\n",
    "Test_pe(X_e, Z_e, V_e, X_source, Z_source, V_source, Y_source, \\\n",
    "        X_target, Z_target, V_target, L=3, K=20, datatype='binary')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "22968a9d-8980-44b3-9bf0-bcb716051aa8",
   "metadata": {},
   "source": [
    "### Use a different scoring function for testing V for the Power enhancement"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "26f3b42b-91e1-4b0f-aacc-fd482282e35f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[-0.24195826  0.22223624 -0.23156712]\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "0.11639344900462212"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Originally, we use v*x for scoring, here we use (-v)*x to detect \n",
    "# the negative correlation between Y and V\n",
    "Test_pe(X_e, Z_e, V_e, X_source, Z_source, V_source, Y_source, \\\n",
    "        X_target, Z_target, V_target, L=3, K=20, datatype='binary', score = 'neg')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cc094742-8470-45e6-86a9-f1393f7431ae",
   "metadata": {},
   "source": [
    "## Benchmark\n",
    "there are 3 benchmarks:1. Use source only data, 2. Use target only data, 3. Importance sampling method (the benchmark from others)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "9251a4ab-1f4e-4856-be09-22a44f42bffd",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.5608435703845116"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#1. Use source data only\n",
    "#2. Test with power enhancement\n",
    "\n",
    "PCR_test(X_source,Z_source,V_source,Y_source)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "5a2f8c8f-53ab-40ea-a0b5-3e83967dc13c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.518063396234775"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#2. Use target data only\n",
    "\n",
    "PCR_test(X_target,Z_target,V_target,Y_target)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "cb824e77-a28c-4122-9e25-f603c317843a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.18814915910850627"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#3. Use Importance Sampling benchmark method\n",
    "\n",
    "IS_test(X_e, Z_e, V_e, X_source, Z_source, V_source, Y_source, X_target, Z_target, V_target, L=3, K=20, datatype='binary')\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "065e3620-442d-45b9-a89b-6002c61c784a",
   "metadata": {},
   "source": [
    "## Tune hyperparameter L"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "41fb6f19-94d2-4934-9e11-e88147f7a404",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training accuracy for X|Z: 0.717\n",
      "[0.23433064 0.19167032 0.28155873]\n",
      "[167.95443516 165.08707733 174.75729811]\n",
      "L is 2, pvalue: 0.933563916771966\n",
      "Training accuracy for X|Z: 0.717\n",
      "[0.2952632  0.20613484 0.31138347]\n",
      "[140.64749537 187.91719237 174.08845475]\n",
      "L is 5, pvalue: 0.06342484042790109\n",
      "Training accuracy for X|Z: 0.717\n",
      "[0.2715509  0.20274784 0.29688373]\n",
      "[148.2607481  172.03479922 184.74858996]\n",
      "L is 8, pvalue: 0.21636110362760874\n",
      "Training accuracy for X|Z: 0.717\n",
      "[0.29121139 0.13614264 0.34778643]\n",
      "[146.78635223 166.57257311 184.67772408]\n",
      "L is 10, pvalue: 0.195118444312963\n"
     ]
    }
   ],
   "source": [
    "l_lst = [2, 5, 8, 10]\n",
    "result_lst = []\n",
    "for l in l_lst:\n",
    "    # Use any test function above\n",
    "    pvalue = Test_pe(X_e, Z_e, V_e, X_source, Z_source, V_source, Y_source, X_target, Z_target, V_target, L=3, K=20, datatype='binary')\n",
    "    result_lst.append(pvalue)\n",
    "    print(f'L is {l}, pvalue: {pvalue}')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
