{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from kernelthinning import kt # kt.thin is the main thinning function; kt.split and kt.swap are other important functions\n",
    "from kernelthinning.util import isnotebook # Check whether this file is being executed as a script or as a notebook\n",
    "from kernelthinning.util import fprint  # for printing while flushing buffer\n",
    "from kernelthinning.tictoc import tic, toc # for timing blocks of code\n",
    "import numpy as np\n",
    "import numpy.random as npr\n",
    "import numpy.linalg as npl\n",
    "from scipy.spatial.distance import pdist\n",
    "\n",
    "import pathlib\n",
    "import os\n",
    "import os.path\n",
    "import pickle as pkl\n",
    "\n",
    "# Fitting linear models\n",
    "import statsmodels.api as sm\n",
    "from scipy.stats import multivariate_normal\n",
    "\n",
    "# plottibg libraries\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib as mpl\n",
    "import pylab\n",
    "import seaborn as sns\n",
    "plt.style.use('seaborn-white')\n",
    "\n",
    "from functools import partial\n",
    "\n",
    "# utils for generating samples, evaluating kernels, and mmds\n",
    "from util_sample import sample, compute_mcmc_params_p, compute_diag_mog_params, sample_string\n",
    "from util_k_mmd import get_combined_mmd_filename\n",
    "\n",
    "from sklearn.linear_model import LinearRegression\n",
    "\n",
    "from util_k_mmd import *\n",
    "\n",
    "import time \n",
    "\n",
    "import random"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from Compress import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# PARAMETERS\n",
    "\n",
    "var = 1. # Variance\n",
    "d = int(100) \n",
    "nsamp = 4**7 # Number of samples\n",
    "# if args is None else args.d\n",
    "params_p = {\"name\": \"gauss\", \"var\": var, \"d\": int(d), \"saved_samples\": False}\n",
    "params_k_swap = {\"name\": \"gauss\", \"var\": var, \"d\": int(d)}\n",
    "params_k_split = {\"name\": \"gauss_rt\", \"var\": var/2., \"d\": int(d)}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def kernel_eval(x, y, params_k):\n",
    "    \"\"\"Returns matrix of kernel evaluations kernel(xi, yi) for each row index i.\n",
    "    x and y should have the same number of columns, and x should either have the\n",
    "    same shape as y or consist of a single row, in which case, x is broadcasted \n",
    "    to have the same shape as y.\n",
    "    \"\"\"\n",
    "    if params_k[\"name\"] in [\"gauss\", \"gauss_rt\"]:\n",
    "        k_vals = np.sum((x-y)**2,axis=1)\n",
    "        scale = -.5/params_k[\"var\"]\n",
    "        return(np.exp(scale*k_vals))\n",
    "    raise ValueError(\"Unrecognized kernel name {}\".format(params_k[\"name\"]))\n",
    "split_kernel = partial(kernel_eval, params_k=params_k_split)\n",
    "swap_kernel = partial(kernel_eval, params_k=params_k_swap)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# # COMPRESS CODE\n",
    "\n",
    "# def divide4(X):\n",
    "#     return np.array_split(X,4)\n",
    "\n",
    "# def combine4(Z):\n",
    "#     return np.concatenate(Z)\n",
    "\n",
    "# def Halve(Y,split_kernel=params_k_split, swap_kernel=params_k_swap, delta=0.5, seed=None, store_K=False):\n",
    "#     return kt.thin(Y,1,split_kernel, swap_kernel, delta, seed, store_K) \n",
    "\n",
    "\n",
    "# def size(X):\n",
    "#     a = np.shape(X)\n",
    "#     return a[0]\n",
    "\n",
    "\n",
    "# def compress(X,split_kernel=split_kernel, swap_kernel=swap_kernel, delta=0.5, seed=None, store_K=False): \n",
    "#     if size(X) == 1:\n",
    "#         return X \n",
    "#     else: \n",
    "#         Z = divide4(X)\n",
    "#         h=[]\n",
    "#         for x in Z:\n",
    "#             h.append(compress(x,split_kernel,swap_kernel, delta,seed,store_K))\n",
    "#         Y = combine4(h)\n",
    "# #print(size(h))\n",
    "#         return Halve(Y,split_kernel, swap_kernel, delta, seed, store_K)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# INITIAL TRY \n",
    "X = sample(nsamp, params_p)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# # COMPRESS_trial CODE\n",
    "# def funhalve(X):\n",
    "#     A = np.array_split(X,2)\n",
    "#     return A[0]\n",
    "\n",
    "# def compress2(X,split_kernel, swap_kernel, delta=0.5, seed=None, store_K=False):\n",
    "# #     print(size(X))\n",
    "#     if size(X) == 1:\n",
    "#         return X\n",
    "#     else: \n",
    "#         Z = divide4(X)\n",
    "#         h=[]\n",
    "#         for x in Z:\n",
    "#             h.append(compress2(x,split_kernel,swap_kernel, delta,seed,store_K))\n",
    "#         Y = combine4(h)\n",
    "#         print(size(Y))\n",
    "#         return funhalve(Y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# print(size(X))\n",
    "# i=6\n",
    "# # while i < 7:\n",
    "# print(size(kt.thin(X,i,split_kernel,swap_kernel)))\n",
    "# #     i=i+1\n",
    "# # W=kt.thin(X,3,split_kernel,swap_kernel)\n",
    "# # size(W)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# def IndexHalve(X,Y,split_kernel=split_kernel, swap_kernel=swap_kernel, delta=0.5, seed=None, store_K=False):\n",
    "#     b = []\n",
    "#     i=0\n",
    "#     r = np.zeros(size(Y))\n",
    "#     for c in Y:\n",
    "#         b.append(X[int(c)])\n",
    "#         r[i]=c\n",
    "#         i=i+1\n",
    "#     b = np.array(b)\n",
    "# #     print(r)\n",
    "# #     print(b)\n",
    "#     g = kt.thin(b,1,split_kernel, swap_kernel, delta, seed, store_K)\n",
    "# #     print(g)\n",
    "#     s = []\n",
    "#     for i in g:\n",
    "#         s.append(int(r[int(i)]))\n",
    "# #     print(s)\n",
    "#     return s\n",
    "        \n",
    "# def IndexCompress(X,m,split_kernel=split_kernel, swap_kernel=swap_kernel, delta=0.5, seed=None, store_K=False):\n",
    "# #     print(size(m))\n",
    "#     if size(m)==1:\n",
    "#         return m\n",
    "#     else: \n",
    "#         Z = divide4(m)\n",
    "#         h=[]\n",
    "#         for x in Z:\n",
    "# #             print(x)\n",
    "#             h.append(IndexCompress(X,x,split_kernel,swap_kernel, delta,seed,store_K))\n",
    "#         Y = combine4(h)\n",
    "# #         print(Y)\n",
    "# #print(size(h))\n",
    "# #         J = IndexHalve(X,Y,split_kernel, swap_kernel, delta, seed, store_K)\n",
    "# #         print(J)\n",
    "# #         print(size(J))\n",
    "# #         return J\n",
    "#         return IndexHalve(X,Y,split_kernel, swap_kernel, delta, seed, store_K)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# def IndexCompressFull(X,split_kernel=split_kernel, swap_kernel=swap_kernel, delta=0.5, seed=None, store_K=False):\n",
    "#     return IndexCompress(X,np.array(range(size(X))),split_kernel=split_kernel, swap_kernel=swap_kernel, delta=0.5, seed=None, store_K=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# size(X)\n",
    "W= IndexCompressFull(X)\n",
    "# size(W)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# time(IndexCompressFull(X))\n",
    "size(IndexCompressFull(X))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "size(kt.thin(X, 4 , split_kernel,swap_kernel))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from kernelthinning import kt # kt.thin is the main thinning function; kt.split and kt.swap are other important functions\n",
    "from kernelthinning.util import isnotebook # Check whether this file is being executed as a script or as a notebook\n",
    "from kernelthinning.util import fprint  # for printing while flushing buffer\n",
    "from kernelthinning.tictoc import tic, toc # for timing blocks of code\n",
    "import numpy as np\n",
    "import numpy.random as npr\n",
    "import numpy.linalg as npl\n",
    "from scipy.spatial.distance import pdist\n",
    "\n",
    "import pathlib\n",
    "import os\n",
    "import os.path\n",
    "import pickle as pkl\n",
    "\n",
    "# Fitting linear models\n",
    "import statsmodels.api as sm\n",
    "from scipy.stats import multivariate_normal\n",
    "\n",
    "# plottibg libraries\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib as mpl\n",
    "import pylab\n",
    "import seaborn as sns\n",
    "plt.style.use('seaborn-white')\n",
    "\n",
    "from functools import partial\n",
    "\n",
    "# utils for generating samples, evaluating kernels, and mmds\n",
    "from util_sample import sample, compute_mcmc_params_p, compute_diag_mog_params, sample_string\n",
    "from util_k_mmd import get_combined_mmd_filename\n",
    "\n",
    "# PARAMETERS\n",
    "\n",
    "# var = 1. # Variance\n",
    "# d = int(3) \n",
    "# nsamp = 4**5 # Number of samples\n",
    "# # if args is None else args.d\n",
    "# params_p = {\"name\": \"gauss\", \"var\": var, \"d\": int(d), \"saved_samples\": False}\n",
    "# params_k_swap = {\"name\": \"gauss\", \"var\": var, \"d\": int(d)}\n",
    "# params_k_split = {\"name\": \"gauss_rt\", \"var\": var/2., \"d\": int(d)}\n",
    "\n",
    "def kernel_eval(x, y, params_k):\n",
    "    \"\"\"Returns matrix of kernel evaluations kernel(xi, yi) for each row index i.\n",
    "    x and y should have the same number of columns, and x should either have the\n",
    "    same shape as y or consist of a single row, in which case, x is broadcasted \n",
    "    to have the same shape as y.\n",
    "    \"\"\"\n",
    "    if params_k[\"name\"] in [\"gauss\", \"gauss_rt\"]:\n",
    "        k_vals = np.sum((x-y)**2,axis=1)\n",
    "        scale = -.5/params_k[\"var\"]\n",
    "        return(np.exp(scale*k_vals))\n",
    "    raise ValueError(\"Unrecognized kernel name {}\".format(params_k[\"name\"]))\n",
    "split_kernel = partial(kernel_eval, params_k=params_k_split)\n",
    "swap_kernel = partial(kernel_eval, params_k=params_k_swap)\n",
    "\n",
    "# Sample from X\n",
    "# X = sample(nsamp, params_p)\n",
    "\n",
    "def divide4(X):\n",
    "    return np.array_split(X,4)\n",
    "\n",
    "def combine4(Z):\n",
    "    return np.concatenate(Z)\n",
    "\n",
    "def Halve(Y,split_kernel=params_k_split, swap_kernel=params_k_swap, delta=0.5, seed=None, store_K=False):\n",
    "    return kt.thin(Y,1,split_kernel, swap_kernel, delta, seed, store_K) \n",
    "\n",
    "\n",
    "def size(X):\n",
    "    a = np.shape(X)\n",
    "    return a[0]\n",
    "\n",
    "\n",
    "def IndexHalve(X,Y,split_kernel, swap_kernel, delta=0.5, seed=None, store_K=False):\n",
    "    b = []\n",
    "    i=0\n",
    "    r = np.zeros(size(Y))\n",
    "    for c in Y:\n",
    "        b.append(X[int(c)])\n",
    "        r[i]=c\n",
    "        i=i+1\n",
    "    b = np.array(b)\n",
    "#     print(r)\n",
    "#     print(b)\n",
    "    g = kt.thin(b,1,split_kernel, swap_kernel, delta, seed, store_K)\n",
    "    print(g)\n",
    "#     print(g)\n",
    "    s = []\n",
    "    for i in g:\n",
    "        s.append(int(r[int(i)]))\n",
    "#     print(s)\n",
    "    return s\n",
    "        \n",
    "def IndexCompress(X,m,split_kernel, swap_kernel, delta=0.5, seed=None, store_K=False):\n",
    "#     print(size(m))\n",
    "    if size(m)==1:\n",
    "        return m\n",
    "    else: \n",
    "        Z = divide4(m)\n",
    "        h=[]\n",
    "        for x in Z:\n",
    "#             print(x)\n",
    "            h.append(IndexCompress(X,x,split_kernel,swap_kernel, delta,seed,store_K))\n",
    "        Y = combine4(h)\n",
    "#         print(Y)\n",
    "#print(size(h))\n",
    "#         J = IndexHalve(X,Y,split_kernel, swap_kernel, delta, seed, store_K)\n",
    "#         print(J)\n",
    "#         print(size(J))\n",
    "#         return J\n",
    "        return IndexHalve(X,Y,split_kernel, swap_kernel, delta, seed, store_K)\n",
    "\n",
    "\n",
    "def IndexCompressFull(X,split_kernel, swap_kernel, delta=0.5, seed=None, store_K=False):\n",
    "    return IndexCompress(X,np.array(range(size(X))),split_kernel=split_kernel, swap_kernel=swap_kernel, delta=0.5, seed=None, store_K=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "size(IndexCompressFull(X, split_kernel,swap_kernel))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "time(IndexCompressFull(X,split_kernel,swap_kernel))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "time(kt.thin(X,6,split_kernel,swap_kernel))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for i in range(6):\n",
    "    nsamp_1 = 4**i\n",
    "    X_1 = sample(nsamp_1, params_p)\n",
    "    print(time(IndexCompressFull(X_1,split_kernel,swap_kernel)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "IndexCompressFull"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "squared_mmd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from util_k_mmd import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "squared_mmd(params_k_swap,params_p,X[IndexCompressFull(X,split_kernel,swap_kernel)])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "squared_mmd(params_k_swap,params_p,X[kt.thin(X,4,split_kernel,swap_kernel)])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "nsamp = 4**6 \n",
    "W = sample(nsamp, params_p)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## COMPRESS MMD\n",
    "\n",
    "a = []\n",
    "for i in range(20):\n",
    "    print(i)\n",
    "    nsamp = 4**i\n",
    "    W = sample(nsamp, params_p)\n",
    "    a.append( [i,squared_mmd(params_k_swap,params_p,W[IndexCompressFull(W,split_kernel,swap_kernel)])] )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "squared_mmd(params_k_swap,params_p,W[kt.thin(W,6,split_kernel,swap_kernel)])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "a"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## KERNEL THinning MMD\n",
    "\n",
    "b = []\n",
    "for i in range(10):\n",
    "    print(i)http://localhost:8080/notebooks/code/KT_Overleaf/Code/Notebooks/Compress.ipynb#\n",
    "    nsamp2 = 4**i\n",
    "    W2 = sample(nsamp2, params_p)\n",
    "    b.append( [i , squared_mmd(params_k_swap,params_p,W2[kt.thin(W2,i,split_kernel,swap_kernel)]) ] ) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "Compress_MMD_Data = []\n",
    "Compress_Time_Data = []\n",
    "Thinning_MMD_Data =[]\n",
    "Thinning_Time_Data = []\n",
    "DATASETS =[]\n",
    "for i in range(8):\n",
    "    print(i)\n",
    "    nsamp = 4**i\n",
    "    W = sample(nsamp, params_p)\n",
    "    time_1 = time.perf_counter()\n",
    "    Wcompress = IndexCompressFull(W,split_kernel,swap_kernel)\n",
    "    time_2 = time.perf_counter()\n",
    "    Compress_MMD_Data.append( [i,squared_mmd(params_k_swap,params_p,W[Wcompress])])\n",
    "    Compress_Time_Data.append( [i, time_2 - time_1])\n",
    "    DATASETS.append(W)\n",
    "#     b.append( [i , squared_mmd(params_k_swap,params_p,W[kt.thin(W,i,split_kernel,swap_kernel)]) ] )\n",
    "        \n",
    "Compress_MMD_DATA = np.array(Compress_MMD_Data)\n",
    "Compress_Time_Data = np.array(Compress_Time_Data)\n",
    "\n",
    "# with open('COMPRESS_MMD_DATA_d=100', 'wb') as f:\n",
    "#     np.save(f, Compress_MMD_Data)\n",
    "\n",
    "# with open('COMPRESS_TIME_DATA_d=100', 'wb') as f:\n",
    "#     np.save(f, Compress_Time_Data)\n",
    "    \n",
    "# f=open('MMD_Data_1','w')\n",
    "# f.write(a)\n",
    "# g.open('MMD_Data_2','w')\n",
    "# g.write(b)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "Compress_MMD_Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Thinning_Time_Data\n",
    "# Compress_Time_Data\n",
    "Thinning_MMD_Data = np.array(Thinning_MMD_Data)\n",
    "Thinning_Time_Data = np.array(Thinning_Time_Data)\n",
    "\n",
    "with open('THINNING_MMD_DATA_11', 'wb') as f:\n",
    "    np.save(f, Thinning_MMD_Data)\n",
    "\n",
    "with open('Thinning_TIME_DATA_11', 'wb') as f:\n",
    "    np.save(f, Thinning_Time_Data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for i in range(10):\n",
    "    print(i)\n",
    "    W = DATASETS[i]\n",
    "    time_1 = time.perf_counter()\n",
    "    WCompress = kt.thin(W,i,split_kernel,swap_kernel)\n",
    "    time_2 = time.perf_counter()\n",
    "    Thinning_MMD_Data.append([i, squared_mmd(params_k_swap,params_p,W[WCompress])])\n",
    "    Thinning_Time_Data.append( [i, time_2 -time_1])\n",
    "\n",
    "Thinning_MMD_Data = np.array(Thinning_MMD_Data)\n",
    "Thinning_Time_Data = np.array(Thinning_Time_Data)\n",
    "\n",
    "with open('THINNING_MMD_DATA_10', 'wb') as f:\n",
    "    np.save(f, Thinning_MMD_Data)\n",
    "\n",
    "with open('Thinning_TIME_DATA_10', 'wb') as f:\n",
    "    np.save(f, Thinning_Time_Data)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for i in range(8,10):\n",
    "    print(i)\n",
    "    W = DATASETS[i]\n",
    "    time_1 = time.perf_counter()\n",
    "    WCompress = kt.thin(W,i,split_kernel,swap_kernel)\n",
    "    time_2 = time.perf_counter()\n",
    "    Thinning_MMD_Data.insert([i, squared_mmd(params_k_swap,params_p,W[WCompress])])\n",
    "    Thinning_Time_Data.insert( [i, time_2 -time_1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for i in range(20):\n",
    "    print(i)\n",
    "    W1= G[i]\n",
    "    b.append( [i , squared_mmd(params_k_swap,params_p,W1[kt.thin(W1,i,split_kernel,swap_kernel)]) ] )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "a = np.array(a)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for i in range(20):\n",
    "    print(i)\n",
    "    W1= G[i]\n",
    "    b.append( [i , squared_mmd(params_k_swap,params_p,W1[kt.thin(W1,i,split_kernel,swap_kernel)]) ] )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# from tempfile import TemporaryFile\n",
    "# open(\"MMD_DATA_1\", \"w\")\n",
    "with open('MMD_DATA_1', 'wb') as f:\n",
    "    np.save(f, a )\n",
    "\n",
    "#     np.save(f, np.array([1, 3]))\n",
    "\n",
    "\n",
    "# f = open(\"MMD_DATA_1\",\"w\")\n",
    "# np.save(f,a)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.plot(a)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.plot(b)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "a = np.array(a)\n",
    "b = np.array(b)\n",
    "np.save(MMD_DATA_1,a)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "c = a[:,1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "d = b[:,1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.plot(a[:,0], a[:,1], b[:,0] , b[:,1]   )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.plot(a[:,0], a[:,1]- b[:,1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "b"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "a"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "times_compress = [] \n",
    "Datasets = []\n",
    "\n",
    "for i in range(20):\n",
    "    print(i)\n",
    "    nsamp = 4**i\n",
    "    W = sample(nsamp, params_p)\n",
    "    tic = time.perf_counter()\n",
    "    Wsparse = IndexCompressFull(W,split_kernel,swap_kernel)\n",
    "    toc = time.perf_counter()\n",
    "    times_compress.append([i, toc -tic ])\n",
    "    Datasets.append(W)\n",
    "\n",
    "times_compress = np.array(times_compress)\n",
    "\n",
    "with open('MMD_DATA_7', 'wb') as f:\n",
    "    np.save(f, times_compress )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "times_thinning = [] \n",
    "# Datasets = []\n",
    "\n",
    "for i in range(20):\n",
    "    print(i)\n",
    "#     nsamp = 4**i\n",
    "#     W = sample(nsamp, params_p)\n",
    "    W2 = Datasets[i]\n",
    "    tic = time.perf_counter()\n",
    "    W2sparse = kt.thin(W2,i,split_kernel,swap_kernel)\n",
    "    toc = time.perf_counter()\n",
    "    times_thinning.append([i, toc -tic ])\n",
    "#     Datasets.append(W)\n",
    "\n",
    "with open('MMD_DATA_8', 'wb') as f:\n",
    "    np.save(f, times_thinning)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.plot(2*times_compress[:,0],np.log(times_compress[:,1]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "p = compute_diag_mog_params(10)\n",
    "X = sample(10,p)\n",
    "X"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mixtures of Gaussians \n",
    "Compress_MOG_MMD_Data = []\n",
    "Compress_MOG_Time_Data = []\n",
    "MOG_DATASETS =[]\n",
    "\n",
    "for j in range(1,10):\n",
    "    print(j)\n",
    "    p = compute_diag_mog_params(j)\n",
    "    for i in range(10):\n",
    "        print(i)\n",
    "        nsamp = 4**i\n",
    "        W = sample(nsamp, p)\n",
    "        time_1 = time.perf_counter()\n",
    "        Wcompress = IndexCompressFull(W,split_kernel,swap_kernel)\n",
    "        time_2 = time.perf_counter()\n",
    "        Compress_MOG_MMD_Data.append( [j,i,squared_mmd(params_k_swap,params_p,W[Wcompress])])\n",
    "        Compress_MOG_Time_Data.append( [j,i, time_2 - time_1])\n",
    "        DATASETS.append([j,i,W])\n",
    "#     b.append( [i , squared_mmd(params_k_swap,params_p,W[kt.thin(W,i,split_kernel,swap_kernel)]) ] )\n",
    "\n",
    "\n",
    "        \n",
    "Compress_MOG_MMD_DATA = np.array(Compress_MOG_MMD_Data)\n",
    "Compress_MOG_Time_Data = np.array(Compress_MOG_Time_Data)\n",
    "\n",
    "with open('COMPRESS_MOG_MMD_DATA_12', 'wb') as f:\n",
    "    np.save(f, Compress_MOG_MMD_Data)\n",
    "\n",
    "with open('COMPRESS_MOG_TIME_DATA_12', 'wb') as f:\n",
    "    np.save(f, Compress_MOG_Time_Data)\n",
    "    \n",
    "# f=open('MMD_Data_1','w')\n",
    "# f.write(a)\n",
    "# g.open('MMD_Data_2','w')\n",
    "# g.write(b)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "size(W[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## IID Subsampling \n",
    "\n",
    "def IndexIIDSub(X,m,l):\n",
    "    Sublist = np.random.choice(m,l)\n",
    "    return Sublist\n",
    "\n",
    "def IndexIIDSubFull(X,l):\n",
    "    return IndexIIDSub(X, np.array(range(size(X))) , l)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "o = np.array(range(100))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "np.random.choice(o,99)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "Compress_MMD_Data1 = np.empty((5,8,20))\n",
    "Compress_Time_Data1 = np.empty((5,8,20))\n",
    "IID_MMD_Data1 = np.empty((5,8,20))\n",
    "IID_time_Data1 = np.empty((5,8,20))\n",
    "# Thinning_MMD_Data =[]\n",
    "# Thinning_Time_Data = []\n",
    "\n",
    "# DATASETS =np.empty((4,8,19))\n",
    "\n",
    "for j in [2,3,4]:\n",
    "    p={\"name\": \"gauss\", \"var\": var, \"d\": int(j), \"saved_samples\": False}\n",
    "    for i in range(8):\n",
    "        print(i)\n",
    "        nsamp = 4**i\n",
    "        for l in range(20):\n",
    "            W = sample(nsamp,p)\n",
    "#             W = sample(nsamp, params_p)\n",
    "            time_1 = time.perf_counter()\n",
    "            Wcompress = IndexCompressFull(W,split_kernel,swap_kernel)\n",
    "            time_2 = time.perf_counter()\n",
    "            Compress_MMD_Data1[j,i,l] = squared_mmd(params_k_swap,params_p,W[Wcompress])\n",
    "            Compress_Time_Data1[j,i,l] =  time_2 - time_1\n",
    "            time_3 = time.perf_counter()\n",
    "            Wsubsample = IndexIIDSubFull(W,2**i)\n",
    "            time_4 = time.perf_counter()\n",
    "            IID_MMD_Data1 = squared_mmd(params_k_swap,params_p,W[Wsubsample])\n",
    "            IID_time_Data1[j,i,l] = time_4 - time_3\n",
    "#             DATASETS[j,i,l] = W\n",
    "\n",
    "# DATASETS =[]\n",
    "# for i in range(8):\n",
    "#     print(i)\n",
    "#     nsamp = 4**i\n",
    "#     W = sample(nsamp, params_p)\n",
    "#     time_1 = time.perf_counter()\n",
    "#     Wcompress = IndexCompressFull(W,split_kernel,swap_kernel)\n",
    "#     time_2 = time.perf_counter()\n",
    "#     Compress_MMD_Data.append( [i,squared_mmd(params_k_swap,params_p,W[Wcompress])])\n",
    "#     Compress_Time_Data.append( [i, time_2 - time_1])\n",
    "#     DATASETS.append(W)\n",
    "#     b.append( [i , squared_mmd(params_k_swap,params_p,W[kt.thin(W,i,split_kernel,swap_kernel)]) ] )\n",
    "        \n",
    "Compress_MMD_DATA = np.array(Compress_MMD_Data)\n",
    "Compress_Time_Data = np.array(Compress_Time_Data)\n",
    "IID_MMD_Data = np.array(IID_MMD_Data)\n",
    "IID_time_Data= np.array(IID_time_Data)\n",
    "\n",
    "# with open('COMPRESS_MMD_DATA_d=100', 'wb') as f:\n",
    "#     np.save(f, Compress_MMD_Data)\n",
    "\n",
    "# with open('COMPRESS_TIME_DATA_d=100', 'wb') as f:\n",
    "#     np.save(f, Compress_Time_Data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open('COMPRESS_MMD_DATA_MULTI', 'wb') as f:\n",
    "    np.save(f, Compress_MMD_Data1)\n",
    "\n",
    "with open('COMPRESS_TIME_DATA_MULTI', 'wb') as f:\n",
    "    np.save(f, Compress_Time_Data1)\n",
    "    \n",
    "with open('IID_MMD_DATA_MULTI','wb') as f:\n",
    "    np.save(f,IID_MMD_Data1)\n",
    "\n",
    "with open('IID_TIME_DATA_MULTI','wb') as f:\n",
    "    np.save(f,IID_time_Data1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "a = np.empty( (3,3,3) )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "a[1,2,2]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "Compress_MMD_Data2 = np.empty((5,8,100))\n",
    "Compress_Time_Data2 = np.empty((5,8,100))\n",
    "IID_MMD_Data2 = np.empty((5,8,100))\n",
    "IID_time_Data2 = np.empty((5,8,100))\n",
    "# Thinning_MMD_Data =[]\n",
    "# Thinning_Time_Data = []\n",
    "\n",
    "# DATASETS =np.empty((4,8,19))\n",
    "\n",
    "for j in [2,3,4]:\n",
    "    p={\"name\": \"gauss\", \"var\": var, \"d\": int(j), \"saved_samples\": False}\n",
    "    for i in range(8):\n",
    "        print(i)\n",
    "        nsamp = 4**i\n",
    "        for l in range(100):\n",
    "            W = sample(nsamp,p)\n",
    "#             W = sample(nsamp, params_p)\n",
    "            time_1 = time.perf_counter()\n",
    "            Wcompress = IndexCompressFull(W,split_kernel,swap_kernel)\n",
    "            time_2 = time.perf_counter()\n",
    "            Compress_MMD_Data2[j,i,l] = squared_mmd(params_k_swap,params_p,W[Wcompress])\n",
    "            Compress_Time_Data2[j,i,l] =  time_2 - time_1\n",
    "            time_3 = time.perf_counter()\n",
    "            Wsubsample = IndexIIDSubFull(W,2**i)\n",
    "            time_4 = time.perf_counter()\n",
    "            IID_MMD_Data2 = squared_mmd(params_k_swap,params_p,W[Wsubsample])\n",
    "            IID_time_Data2[j,i,l] = time_4 - time_3\n",
    "#             DATASETS[j,i,l] = W\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "Compress_MMD_Data2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "d"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for j in [0,10,20,30,40,50,60,70,80,80,100]:\n",
    "    p={\"name\": \"gauss\", \"var\": var, \"d\": int(j), \"saved_samples\": False}\n",
    "    nsamp = 4**7\n",
    "    print('dim=')\n",
    "    print(j)\n",
    "    for l in range(5):\n",
    "            W = sample(nsamp,p)\n",
    "#             W = sample(nsamp, params_p)\n",
    "            time_1 = time.perf_counter()\n",
    "            Wcompress = IndexCompressFull(W,split_kernel,swap_kernel)\n",
    "            time_2 = time.perf_counter()\n",
    "            print(squared_mmd(params_k_swap,p,W[Wcompress]))\n",
    "#             Compress_MMD_Data2[j,i,l] = squared_mmd(params_k_swap,params_p,W[Wcompress])\n",
    "#             Compress_Time_Data2[j,i,l] =  time_2 - time_1\n",
    "#             time_3 = time.perf_counter()\n",
    "#             Wsubsample = IndexIIDSubFull(W,2**i)\n",
    "#             time_4 = time.perf_counter()\n",
    "#             IID_MMD_Data2 = squared_mmd(params_k_swap,params_p,W[Wsubsample])\n",
    "#             IID_time_Data2[j,i,l] = time_4 - time_3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "i= 7\n",
    "print(i)\n",
    "nsamp = 4**i\n",
    "W = sample(nsamp, params_p)\n",
    "# time_1 = time.perf_counter()\n",
    "Wcompress = IndexCompressFull(W,split_kernel,swap_kernel)\n",
    "# time_2 = time.perf_counter()\n",
    "#     Compress_MMD_Data.append( [i,squared_mmd(params_k_swap,params_p,W[Wcompress])])\n",
    "#     Compress_Time_Data.append( [i, time_2 - time_1])\n",
    "print(squared_mmd(params_k_swap,params_p,W[Wcompress]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
