{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "5132a9a9-9244-44d5-92ea-f0265502d3c4",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import time\n",
    "import itertools\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "import seaborn as sns\n",
    "import scipy.stats as stats\n",
    "from scipy.stats import gamma\n",
    "from scipy.stats import ortho_group\n",
    "\n",
    "import cupy as cp\n",
    "import cupyx.scipy \n",
    "\n",
    "from LFHSIC.fhsic_naive import IndpTest_naive_rff\n",
    "from LFHSIC.lfhsic_g import IndpTest_LFGaussian\n",
    "from LFHSIC.lfhsic_m import IndpTest_LFMahalanobis\n",
    "\n",
    "# The LFHSIC tests below are run on the first CUDA device.\n",
    "device = torch.device('cuda', 0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "ec1f4e87-dc3a-4f9f-9c9a-88529beaac65",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Setting used in the ISA experiment below: n = 10000, d = 16.\n",
    "def generate_ISA(n,d,sigma_normal,alpha):\n",
    "    \"\"\"Sample a dependent pair (X, Y), each of shape (n, d), for the ISA benchmark.\n",
    "\n",
    "    Two independent 1-D bimodal sources (half the points drawn from N(-1, sigma_normal^2),\n",
    "    the rest from N(+1, sigma_normal^2)) are rotated by theta = alpha * pi / 4.\n",
    "    alpha = 0 leaves them independent; e.g. alpha = 1 (theta = pi / 4) mixes them.\n",
    "    Each rotated coordinate is padded with d - 1 standard-normal columns, and the\n",
    "    result is multiplied by an independent random orthogonal d x d matrix.\n",
    "\n",
    "    Rows come out grouped by x's mixture component, so shuffle them before splitting.\n",
    "    d is expected to be >= 2 (ortho_group needs dim > 1). Odd n is supported:\n",
    "    the +1 component receives the extra point.\n",
    "    \"\"\"\n",
    "    n_low = n // 2\n",
    "    n_high = n - n_low  # so x and y always have exactly n entries, even for odd n\n",
    "    x = np.concatenate((np.random.normal(-1, sigma_normal, n_low), np.random.normal(1, sigma_normal, n_high)))\n",
    "    y = np.concatenate((np.random.normal(-1, sigma_normal, n_low), np.random.normal(1, sigma_normal, n_high)))\n",
    "    # Permute y so its component labels are independent of x's.\n",
    "    y_p = y[np.random.permutation(n)]\n",
    "\n",
    "    # Stack the two sources as rows: S has shape (2, n).\n",
    "    S = np.vstack((x, y_p))\n",
    "\n",
    "    theta = np.pi/4*alpha\n",
    "    c, s = np.cos(theta), np.sin(theta)\n",
    "    R = np.array(((c, -s), (s, c)))\n",
    "    S_R = R@S\n",
    "    X_mix = S_R[0,:].reshape(-1,1)\n",
    "    Y_mix = S_R[1,:].reshape(-1,1)\n",
    "\n",
    "    # Pad each mixed coordinate with d - 1 independent N(0, 1) columns.\n",
    "    X_con = np.concatenate((X_mix, np.random.randn(n,d-1)), axis=1)\n",
    "    Y_con = np.concatenate((Y_mix, np.random.randn(n,d-1)), axis=1)\n",
    "\n",
    "    # Hide the informative direction behind independent random orthogonal maps.\n",
    "    m_x = ortho_group.rvs(dim=d)\n",
    "    m_y = ortho_group.rvs(dim=d)\n",
    "    X = (m_x@X_con.T).T\n",
    "    Y = (m_y@Y_con.T).T\n",
    "\n",
    "    return X,Y\n",
    "\n",
    "# Frequency used in the sinusoid experiment below: w = 5.\n",
    "def Sinusoid(x, y, w):\n",
    "    \"\"\"Unnormalised sinusoid density 1 + sin(w*x) * sin(w*y); its values lie in [0, 2].\"\"\"\n",
    "    product = np.sin(w*x) * np.sin(w*y)\n",
    "    return 1 + product\n",
    "\n",
    "def Sinusoid_Generator(n,w):\n",
    "    \"\"\"Draw n points from the density proportional to Sinusoid(x0, x1, w) on [-pi, pi]^2.\n",
    "\n",
    "    Rejection sampling: propose (x0, x1) uniformly on the square and accept it with\n",
    "    probability Sinusoid(x0, x1, w) / 2. Returns the two coordinates as length-n arrays.\n",
    "    \"\"\"\n",
    "    samples = np.zeros([n,2])\n",
    "    accepted = 0\n",
    "    while accepted < n:\n",
    "        u = np.random.rand(1)\n",
    "        proposal = -np.pi + np.random.rand(2)*2*np.pi\n",
    "        if u < 0.5 * Sinusoid(proposal[0], proposal[1], w):\n",
    "            samples[accepted] = proposal\n",
    "            accepted += 1\n",
    "    return samples[:,0], samples[:,1]\n",
    "\n",
    "# Setting used in the sine-dependence experiment below: n = 3000, d = 4.\n",
    "def sinedependence(n,d):\n",
    "    \"\"\"Return X ~ N(0, I_d) of shape (n, d) and Y = 20 sin(4 pi (X_0^2 + X_1^2)) + N(0, 1), shape (n,).\n",
    "\n",
    "    Only the first two coordinates of X enter Y, so d must be at least 2.\n",
    "    \"\"\"\n",
    "    X = np.random.multivariate_normal(np.zeros(d), np.eye(d), n)\n",
    "    noise = np.random.randn(n)\n",
    "    radius_sq = X[:,0]**2 + X[:,1]**2\n",
    "    Y = 20*np.sin(4*np.pi*radius_sq) + noise\n",
    "    return X,Y\n",
    "\n",
    "# Setting used in the GSign experiment below: n = 4000, d = 5.\n",
    "def GSign(n,d):\n",
    "    \"\"\"Return X ~ N(0, I_d) of shape (n, d) and Y = |Z| * prod_j sign(X_j), Z ~ N(0, 1), shape (n,).\n",
    "\n",
    "    The sign of Y is the product of the signs of all d coordinates of X, so the\n",
    "    dependence is only visible when X is considered jointly.\n",
    "    \"\"\"\n",
    "    X = np.random.multivariate_normal(np.zeros(d), np.eye(d), n)\n",
    "    sign_parity = np.prod(np.sign(X), 1)\n",
    "    magnitude = np.abs(np.random.randn(n))\n",
    "    Y = magnitude*sign_parity\n",
    "    return X,Y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "63d05fce-73e6-4cce-96d5-fecb1e9f5db1",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 [0. 0.]\n",
      "1 [0.  0.5]\n",
      "2 [0.33333333 0.66666667]\n",
      "3 [0.25 0.75]\n",
      "4 [0.2 0.8]\n",
      "5 [0.33333333 0.66666667]\n",
      "6 [0.28571429 0.71428571]\n",
      "7 [0.25 0.75]\n",
      "8 [0.33333333 0.77777778]\n",
      "9 [0.4 0.8]\n",
      "10 [0.36363636 0.72727273]\n",
      "11 [0.33333333 0.75      ]\n",
      "12 [0.30769231 0.76923077]\n",
      "13 [0.28571429 0.78571429]\n",
      "14 [0.26666667 0.8       ]\n",
      "15 [0.25   0.8125]\n",
      "16 [0.23529412 0.82352941]\n",
      "17 [0.22222222 0.83333333]\n",
      "18 [0.26315789 0.84210526]\n",
      "19 [0.3  0.85]\n",
      "20 [0.28571429 0.85714286]\n",
      "21 [0.27272727 0.81818182]\n",
      "22 [0.30434783 0.82608696]\n",
      "23 [0.33333333 0.83333333]\n",
      "24 [0.36 0.84]\n",
      "25 [0.34615385 0.84615385]\n",
      "26 [0.37037037 0.81481481]\n",
      "27 [0.35714286 0.78571429]\n",
      "28 [0.37931034 0.79310345]\n",
      "29 [0.36666667 0.76666667]\n",
      "30 [0.35483871 0.77419355]\n",
      "31 [0.34375 0.75   ]\n",
      "32 [0.36363636 0.75757576]\n",
      "33 [0.35294118 0.76470588]\n",
      "34 [0.37142857 0.77142857]\n",
      "35 [0.38888889 0.77777778]\n",
      "36 [0.40540541 0.78378378]\n",
      "37 [0.39473684 0.76315789]\n",
      "38 [0.41025641 0.76923077]\n",
      "39 [0.4   0.775]\n",
      "40 [0.41463415 0.7804878 ]\n",
      "41 [0.42857143 0.78571429]\n",
      "42 [0.44186047 0.76744186]\n",
      "43 [0.45454545 0.77272727]\n",
      "44 [0.44444444 0.77777778]\n",
      "45 [0.43478261 0.76086957]\n",
      "46 [0.44680851 0.76595745]\n",
      "47 [0.45833333 0.77083333]\n",
      "48 [0.44897959 0.7755102 ]\n",
      "49 [0.46 0.78]\n"
     ]
    }
   ],
   "source": [
    "# Empirical rejection rate (power: Y depends on X) of LFHSIC-G (row 0) and LFHSIC-M (row 1)\n",
    "# on the sine-dependence benchmark.\n",
    "n = 3000\n",
    "d = 4\n",
    "D = 100  # passed to perform_test as rff_num\n",
    "test_num = 50\n",
    "confident_level = 0.05  # significance level of each independence test\n",
    "\n",
    "# Seed numpy and torch so the reported rejection rates are reproducible.\n",
    "# NOTE(review): LFHSIC may also draw from other RNGs (e.g. cupy) - confirm.\n",
    "np.random.seed(0)\n",
    "torch.manual_seed(0)\n",
    "\n",
    "# result_test_correct[k, it] is 1.0 when method k rejected H0 on trial it.\n",
    "result_test_correct = np.zeros([2,test_num])\n",
    "\n",
    "for it in range(test_num):\n",
    "    X, Y = sinedependence(n, d)\n",
    "    Y = Y.reshape(-1,1)\n",
    "\n",
    "    # Shuffle rows with one shared permutation so (x_i, y_i) pairs stay aligned.\n",
    "    # X and Y are already 2-D here, so no further reshape is needed.\n",
    "    rand_index = np.random.permutation(n)\n",
    "    X_tensor = torch.tensor(X[rand_index])\n",
    "    Y_tensor = torch.tensor(Y[rand_index])\n",
    "\n",
    "    lfhsic_g = IndpTest_LFGaussian(X_tensor, Y_tensor, device, alpha=confident_level, null_gamma = True, split_ratio = 0.5)\n",
    "    results_all1 = lfhsic_g.perform_test(rff_num = D, lr = 0.05, if_grid_search = True, debug = -1)\n",
    "    result_test_correct[0, it] = float(results_all1['h0_rejected'])\n",
    "\n",
    "    lfhsic_m = IndpTest_LFMahalanobis(X_tensor, Y_tensor, device, alpha=confident_level, null_gamma = True, split_ratio = 0.5)\n",
    "    results_all2 = lfhsic_m.perform_test(rff_num = D, lr = 0.05, if_grid_search = True, debug = -1)\n",
    "    result_test_correct[1, it] = float(results_all2['h0_rejected'])\n",
    "\n",
    "    # Running rejection rate over the trials completed so far.\n",
    "    print(it, np.sum(result_test_correct,1)/(it+1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "73742465-8b7e-4c92-a013-bb01bb110148",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 [1. 1.]\n",
      "1 [1. 1.]\n",
      "2 [1. 1.]\n",
      "3 [1. 1.]\n",
      "4 [1. 1.]\n",
      "5 [1. 1.]\n",
      "6 [1. 1.]\n",
      "7 [1. 1.]\n",
      "8 [1. 1.]\n",
      "9 [1. 1.]\n",
      "10 [1. 1.]\n",
      "11 [1. 1.]\n",
      "12 [1. 1.]\n",
      "13 [1. 1.]\n",
      "14 [1. 1.]\n",
      "15 [1. 1.]\n",
      "16 [1. 1.]\n",
      "17 [1. 1.]\n",
      "18 [1. 1.]\n",
      "19 [1. 1.]\n",
      "20 [1. 1.]\n",
      "21 [1. 1.]\n",
      "22 [1. 1.]\n",
      "23 [1. 1.]\n",
      "24 [1. 1.]\n",
      "25 [1. 1.]\n",
      "26 [1. 1.]\n",
      "27 [1. 1.]\n",
      "28 [0.96551724 1.        ]\n",
      "29 [0.96666667 1.        ]\n",
      "30 [0.96774194 1.        ]\n",
      "31 [0.96875 1.     ]\n",
      "32 [0.96969697 1.        ]\n",
      "33 [0.97058824 1.        ]\n",
      "34 [0.97142857 1.        ]\n",
      "35 [0.97222222 1.        ]\n",
      "36 [0.97297297 1.        ]\n",
      "37 [0.97368421 1.        ]\n",
      "38 [0.97435897 1.        ]\n",
      "39 [0.975 1.   ]\n",
      "40 [0.97560976 1.        ]\n",
      "41 [0.97619048 1.        ]\n",
      "42 [0.97674419 1.        ]\n",
      "43 [0.97727273 1.        ]\n",
      "44 [0.97777778 1.        ]\n",
      "45 [0.97826087 1.        ]\n",
      "46 [0.95744681 1.        ]\n",
      "47 [0.95833333 1.        ]\n",
      "48 [0.95918367 1.        ]\n",
      "49 [0.96 1.  ]\n"
     ]
    }
   ],
   "source": [
    "# Empirical rejection rate (power: the two coordinates are dependent) of LFHSIC-G (row 0)\n",
    "# and LFHSIC-M (row 1) on the sinusoid benchmark.\n",
    "n = 2000\n",
    "w = 5\n",
    "D = 100  # passed to perform_test as rff_num\n",
    "test_num = 50\n",
    "confident_level = 0.05  # significance level of each independence test\n",
    "\n",
    "# Seed numpy and torch so the reported rejection rates are reproducible.\n",
    "# NOTE(review): LFHSIC may also draw from other RNGs (e.g. cupy) - confirm.\n",
    "np.random.seed(0)\n",
    "torch.manual_seed(0)\n",
    "\n",
    "# result_test_correct[k, it] is 1.0 when method k rejected H0 on trial it.\n",
    "result_test_correct = np.zeros([2,test_num])\n",
    "\n",
    "for it in range(test_num):\n",
    "    x_samples, y_samples = Sinusoid_Generator(n,w)\n",
    "    X = x_samples.reshape(-1,1)\n",
    "    Y = y_samples.reshape(-1,1)\n",
    "\n",
    "    # Shuffle rows with one shared permutation so (x_i, y_i) pairs stay aligned.\n",
    "    # X and Y are already 2-D here, so no further reshape is needed.\n",
    "    rand_index = np.random.permutation(n)\n",
    "    X_tensor = torch.tensor(X[rand_index])\n",
    "    Y_tensor = torch.tensor(Y[rand_index])\n",
    "\n",
    "    lfhsic_g = IndpTest_LFGaussian(X_tensor, Y_tensor, device, alpha=confident_level, null_gamma = True, split_ratio = 0.5)\n",
    "    results_all1 = lfhsic_g.perform_test(rff_num = D, lr = 0.05, if_grid_search = True, debug = -1)\n",
    "    result_test_correct[0, it] = float(results_all1['h0_rejected'])\n",
    "\n",
    "    lfhsic_m = IndpTest_LFMahalanobis(X_tensor, Y_tensor, device, alpha=confident_level, null_gamma = True, split_ratio = 0.5)\n",
    "    results_all2 = lfhsic_m.perform_test(rff_num = D, lr = 0.05, if_grid_search = True, debug = -1)\n",
    "    result_test_correct[1, it] = float(results_all2['h0_rejected'])\n",
    "\n",
    "    # Running rejection rate over the trials completed so far.\n",
    "    print(it, np.sum(result_test_correct,1)/(it+1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "36e061c0-3fd9-40e5-9cfd-dcff452958bb",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 [0. 1.]\n",
      "1 [0.5 1. ]\n",
      "2 [0.33333333 1.        ]\n",
      "3 [0.25 1.  ]\n",
      "4 [0.2 0.8]\n",
      "5 [0.16666667 0.83333333]\n",
      "6 [0.28571429 0.85714286]\n",
      "7 [0.375 0.75 ]\n",
      "8 [0.44444444 0.77777778]\n",
      "9 [0.5 0.8]\n",
      "10 [0.54545455 0.81818182]\n",
      "11 [0.58333333 0.83333333]\n",
      "12 [0.61538462 0.84615385]\n",
      "13 [0.57142857 0.85714286]\n",
      "14 [0.53333333 0.86666667]\n",
      "15 [0.5625 0.875 ]\n",
      "16 [0.58823529 0.88235294]\n",
      "17 [0.61111111 0.88888889]\n",
      "18 [0.57894737 0.89473684]\n",
      "19 [0.6  0.85]\n",
      "20 [0.61904762 0.85714286]\n",
      "21 [0.63636364 0.81818182]\n",
      "22 [0.65217391 0.7826087 ]\n",
      "23 [0.66666667 0.75      ]\n",
      "24 [0.68 0.76]\n",
      "25 [0.69230769 0.76923077]\n",
      "26 [0.7037037  0.77777778]\n",
      "27 [0.71428571 0.78571429]\n",
      "28 [0.68965517 0.75862069]\n",
      "29 [0.7        0.76666667]\n",
      "30 [0.70967742 0.77419355]\n",
      "31 [0.71875 0.78125]\n",
      "32 [0.6969697  0.78787879]\n",
      "33 [0.70588235 0.79411765]\n",
      "34 [0.71428571 0.8       ]\n",
      "35 [0.72222222 0.80555556]\n",
      "36 [0.72972973 0.81081081]\n",
      "37 [0.73684211 0.81578947]\n",
      "38 [0.71794872 0.82051282]\n",
      "39 [0.725 0.825]\n",
      "40 [0.70731707 0.82926829]\n",
      "41 [0.71428571 0.83333333]\n",
      "42 [0.72093023 0.81395349]\n",
      "43 [0.72727273 0.81818182]\n",
      "44 [0.71111111 0.8       ]\n",
      "45 [0.69565217 0.80434783]\n",
      "46 [0.70212766 0.80851064]\n",
      "47 [0.70833333 0.79166667]\n",
      "48 [0.71428571 0.79591837]\n",
      "49 [0.72 0.78]\n"
     ]
    }
   ],
   "source": [
    "# Empirical rejection rate (power: Y depends on X) of LFHSIC-G (row 0) and LFHSIC-M (row 1)\n",
    "# on the GSign benchmark.\n",
    "n = 4000\n",
    "d = 5\n",
    "D = 100  # passed to perform_test as rff_num\n",
    "test_num = 50\n",
    "confident_level = 0.05  # significance level of each independence test\n",
    "\n",
    "# Seed numpy and torch so the reported rejection rates are reproducible.\n",
    "# NOTE(review): LFHSIC may also draw from other RNGs (e.g. cupy) - confirm.\n",
    "np.random.seed(0)\n",
    "torch.manual_seed(0)\n",
    "\n",
    "# result_test_correct[k, it] is 1.0 when method k rejected H0 on trial it.\n",
    "result_test_correct = np.zeros([2,test_num])\n",
    "\n",
    "for it in range(test_num):\n",
    "    X, Y = GSign(n,d)\n",
    "    Y = Y.reshape(-1,1)\n",
    "\n",
    "    # Shuffle rows with one shared permutation so (x_i, y_i) pairs stay aligned.\n",
    "    # X and Y are already 2-D here, so no further reshape is needed.\n",
    "    rand_index = np.random.permutation(n)\n",
    "    X_tensor = torch.tensor(X[rand_index])\n",
    "    Y_tensor = torch.tensor(Y[rand_index])\n",
    "\n",
    "    lfhsic_g = IndpTest_LFGaussian(X_tensor, Y_tensor, device, alpha=confident_level, null_gamma = True, split_ratio = 0.5)\n",
    "    results_all1 = lfhsic_g.perform_test(rff_num = D, lr = 0.05, if_grid_search = True, debug = -1)\n",
    "    result_test_correct[0, it] = float(results_all1['h0_rejected'])\n",
    "\n",
    "    lfhsic_m = IndpTest_LFMahalanobis(X_tensor, Y_tensor, device, alpha=confident_level, null_gamma = True, split_ratio = 0.5)\n",
    "    results_all2 = lfhsic_m.perform_test(rff_num = D, lr = 0.05, if_grid_search = True, debug = -1)\n",
    "    result_test_correct[1, it] = float(results_all2['h0_rejected'])\n",
    "\n",
    "    # Running rejection rate over the trials completed so far.\n",
    "    print(it, np.sum(result_test_correct,1)/(it+1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "05a2225c-1576-4531-b4ab-d76a28fbeee9",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.0 0 [1. 1.]\n",
      "1.0 1 [1. 1.]\n",
      "1.0 2 [0.66666667 1.        ]\n",
      "1.0 3 [0.75 1.  ]\n",
      "1.0 4 [0.8 1. ]\n",
      "1.0 5 [0.83333333 1.        ]\n",
      "1.0 6 [0.85714286 1.        ]\n",
      "1.0 7 [0.875 1.   ]\n",
      "1.0 8 [0.88888889 1.        ]\n",
      "1.0 9 [0.9 1. ]\n",
      "1.0 10 [0.81818182 1.        ]\n",
      "1.0 11 [0.83333333 1.        ]\n",
      "1.0 12 [0.84615385 1.        ]\n",
      "1.0 13 [0.78571429 1.        ]\n",
      "1.0 14 [0.73333333 1.        ]\n",
      "1.0 15 [0.6875 1.    ]\n",
      "1.0 16 [0.70588235 1.        ]\n",
      "1.0 17 [0.72222222 1.        ]\n",
      "1.0 18 [0.73684211 1.        ]\n",
      "1.0 19 [0.75 1.  ]\n",
      "1.0 20 [0.71428571 1.        ]\n",
      "1.0 21 [0.72727273 1.        ]\n",
      "1.0 22 [0.73913043 1.        ]\n",
      "1.0 23 [0.75 1.  ]\n",
      "1.0 24 [0.72 1.  ]\n",
      "1.0 25 [0.69230769 1.        ]\n",
      "1.0 26 [0.7037037 1.       ]\n",
      "1.0 27 [0.67857143 1.        ]\n",
      "1.0 28 [0.68965517 1.        ]\n",
      "1.0 29 [0.7 1. ]\n",
      "1.0 30 [0.70967742 1.        ]\n",
      "1.0 31 [0.6875 1.    ]\n",
      "1.0 32 [0.6969697 1.       ]\n",
      "1.0 33 [0.67647059 1.        ]\n",
      "1.0 34 [0.65714286 1.        ]\n",
      "1.0 35 [0.63888889 1.        ]\n",
      "1.0 36 [0.64864865 1.        ]\n",
      "1.0 37 [0.63157895 1.        ]\n",
      "1.0 38 [0.64102564 1.        ]\n",
      "1.0 39 [0.625 1.   ]\n",
      "1.0 40 [0.63414634 1.        ]\n",
      "1.0 41 [0.64285714 1.        ]\n",
      "1.0 42 [0.65116279 1.        ]\n",
      "1.0 43 [0.65909091 1.        ]\n",
      "1.0 44 [0.64444444 1.        ]\n",
      "1.0 45 [0.65217391 1.        ]\n",
      "1.0 46 [0.63829787 1.        ]\n",
      "1.0 47 [0.64583333 1.        ]\n",
      "1.0 48 [0.63265306 1.        ]\n",
      "1.0 49 [0.64 1.  ]\n"
     ]
    }
   ],
   "source": [
    "# ica experiments #\n",
    "# Rejection rate of LFHSIC-G (row 0) and LFHSIC-M (row 1) on the ISA benchmark,\n",
    "# one result matrix per mixing parameter in alpha_set.\n",
    "D = 500  # passed to perform_test as rff_num\n",
    "test_num = 50\n",
    "result_all_correct = []\n",
    "\n",
    "# alpha_set = np.linspace(0.0, 1.0, num=10)\n",
    "alpha_set = [1.0]  # ISA mixing parameters (generate_ISA rotates by alpha * pi / 4)\n",
    "\n",
    "# Significance level of each independence test. Not to be confused with the loop\n",
    "# variable `alpha`, which is the ISA mixing parameter.\n",
    "confident_level = 0.05\n",
    "\n",
    "# ISA data settings (the same for every trial).\n",
    "n = 10000\n",
    "d = 16\n",
    "sigma_normal = 0.1\n",
    "\n",
    "# Seed numpy (also used by scipy's ortho_group) and torch for reproducible results.\n",
    "# NOTE(review): LFHSIC may also draw from other RNGs (e.g. cupy) - confirm.\n",
    "np.random.seed(0)\n",
    "torch.manual_seed(0)\n",
    "\n",
    "for alpha in alpha_set:\n",
    "    # result_test_correct[k, it] is 1.0 when method k rejected H0 on trial it.\n",
    "    result_test_correct = np.zeros([2,test_num])\n",
    "\n",
    "    for it in range(test_num):\n",
    "        X, Y = generate_ISA(n,d,sigma_normal,alpha)\n",
    "        # generate_ISA returns rows grouped by mixture component; shuffle them with one\n",
    "        # shared permutation (keeps pairs aligned) in case the test splits by row position.\n",
    "        rand_index = np.random.permutation(n)\n",
    "        X_tensor = torch.tensor(X[rand_index])\n",
    "        Y_tensor = torch.tensor(Y[rand_index])\n",
    "\n",
    "        lfhsic_g = IndpTest_LFGaussian(X_tensor, Y_tensor, device, alpha=confident_level, null_gamma = True, split_ratio = 0.5)\n",
    "        results_all1 = lfhsic_g.perform_test(rff_num = D, lr = 0.05, if_grid_search = False, debug = -1)\n",
    "        result_test_correct[0, it] = float(results_all1['h0_rejected'])\n",
    "\n",
    "        lfhsic_m = IndpTest_LFMahalanobis(X_tensor, Y_tensor, device, alpha=confident_level, null_gamma = True, split_ratio = 0.5)\n",
    "        results_all2 = lfhsic_m.perform_test(rff_num = D, lr = 0.05, if_grid_search = False, debug = -1)\n",
    "        result_test_correct[1, it] = float(results_all2['h0_rejected'])\n",
    "\n",
    "        # Running rejection rate over the trials completed so far.\n",
    "        print(alpha, it, np.sum(result_test_correct,1)/(it+1))\n",
    "\n",
    "    result_all_correct.append(result_test_correct)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d69d9efb-6409-43f6-9157-0c5b8deccdeb",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6973ed9e-b21f-4122-b586-76a313aa199b",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
