{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "ca240e68",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Collecting cvxpy\n",
      "  Downloading cvxpy-1.2.2-cp39-cp39-macosx_10_9_x86_64.whl (882 kB)\n",
      "\u001b[K     |████████████████████████████████| 882 kB 7.2 MB/s eta 0:00:01\n",
      "\u001b[?25hRequirement already satisfied: scipy>=1.1.0 in /Users/rexhsieh/opt/anaconda3/lib/python3.9/site-packages (from cvxpy) (1.7.1)\n",
      "Requirement already satisfied: numpy>=1.15 in /Users/rexhsieh/opt/anaconda3/lib/python3.9/site-packages (from cvxpy) (1.20.3)\n",
      "Collecting scs>=1.1.6\n",
      "  Downloading scs-3.2.2-cp39-cp39-macosx_10_9_x86_64.whl (11.6 MB)\n",
      "\u001b[K     |████████████████████████████████| 11.6 MB 54.8 MB/s eta 0:00:01\n",
      "\u001b[?25hCollecting ecos>=2\n",
      "  Downloading ecos-2.0.10-cp39-cp39-macosx_10_9_x86_64.whl (88 kB)\n",
      "\u001b[K     |████████████████████████████████| 88 kB 24.1 MB/s eta 0:00:01\n",
      "\u001b[?25hCollecting osqp>=0.4.1\n",
      "  Downloading osqp-0.6.2.post5-cp39-cp39-macosx_10_9_x86_64.whl (249 kB)\n",
      "\u001b[K     |████████████████████████████████| 249 kB 111.7 MB/s eta 0:00:01\n",
      "\u001b[?25hCollecting qdldl\n",
      "  Downloading qdldl-0.1.5.post2-cp39-cp39-macosx_10_9_x86_64.whl (98 kB)\n",
      "\u001b[K     |████████████████████████████████| 98 kB 22.8 MB/s eta 0:00:01\n",
      "\u001b[?25hInstalling collected packages: qdldl, scs, osqp, ecos, cvxpy\n",
      "Successfully installed cvxpy-1.2.2 ecos-2.0.10 osqp-0.6.2.post5 qdldl-0.1.5.post2 scs-3.2.2\n",
      "Note: you may need to restart the kernel to use updated packages.\n"
     ]
    }
   ],
   "source": [
    "# Install the third-party packages imported below into the kernel's environment.\n",
    "# `ot` is distributed on PyPI as POT.\n",
    "%pip install cvxpy POT seaborn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "06568c41",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pylab as plt\n",
    "import ot\n",
    "import cvxpy as cp\n",
    "import seaborn as sns\n",
    "\n",
    "# Supplementary Packages\n",
    "#import scipy.stats as stats\n",
    "#import scipy.special as sps\n",
    "#import time as t"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "30b38653",
   "metadata": {},
   "source": [
    "## Functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "532c1063",
   "metadata": {},
   "outputs": [],
   "source": [
    "def baryc_proj(source, target, method = 'emd'):\n",
    "    \"\"\"Barycentric projection of `source` onto `target` via optimal transport.\n",
    "\n",
    "    Both samples get uniform weights. A cost matrix is built with `ot.dist`,\n",
    "    a transport plan is computed, and every source row is mapped to the\n",
    "    plan-weighted average of the target rows.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    source : 2-D array, one observation per row\n",
    "    target : 2-D array, one observation per row\n",
    "    method : {'emd', 'entropic'}\n",
    "        'emd' computes the plan with `ot.emd`; 'entropic' uses\n",
    "        `ot.bregman.sinkhorn_stabilized` with `reg = 5e-3`.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    numpy.ndarray of dtype float32 with one row per source observation.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If `method` is not 'emd' or 'entropic'.\n",
    "    \"\"\"\n",
    "    if method not in ('emd', 'entropic'):\n",
    "        # Fail early: an unknown method used to surface later as an UnboundLocalError.\n",
    "        raise ValueError(\"method must be 'emd' or 'entropic', got %r\" % (method,))\n",
    "\n",
    "    n1 = source.shape[0]\n",
    "    n2 = target.shape[0]\n",
    "    a_ones, b_ones = np.ones((n1,)) / n1, np.ones((n2,)) / n2\n",
    "\n",
    "    M = ot.dist(source, target)\n",
    "    M = M.astype('float64')\n",
    "    # Rescale the costs by their maximum; skip when every cost is zero to avoid 0/0 = NaN.\n",
    "    max_cost = M.max()\n",
    "    if max_cost > 0:\n",
    "        M /= max_cost\n",
    "\n",
    "    if method == 'emd':\n",
    "        # The iteration cap is passed as an integer rather than the float 1e7.\n",
    "        OTplan = ot.emd(a_ones, b_ones, M, numItermax = int(1e7))\n",
    "    else:\n",
    "        OTplan = ot.bregman.sinkhorn_stabilized(a_ones, b_ones, M, reg = 5*1e-3)\n",
    "\n",
    "    # Normalize every row of the plan so it sums to one, then take all\n",
    "    # conditional expectations with a single matrix product (this replaces\n",
    "    # the row-by-row np.vstack loop, which re-copied the result on each step).\n",
    "    cond_plan = OTplan / OTplan.sum(axis = 1, keepdims = True)\n",
    "    OTmap = cond_plan @ np.asarray(target)\n",
    "\n",
    "    return(np.asarray(OTmap).astype('float32'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "a0f6c1b5",
   "metadata": {},
   "outputs": [],
   "source": [
    "def DSCreplication(target, controls, method = 'emd', projtype = 'wass'):\n",
    "    \"\"\"Replicate `target` as an optimally weighted combination of `controls`.\n",
    "\n",
    "    Every control sample is mapped onto the target's observations with\n",
    "    `baryc_proj`. Weights on the unit simplex are then chosen to minimize the\n",
    "    squared distance between the weighted projections and the target, and the\n",
    "    replication is built from those weights.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    target : 2-D array, one observation per row\n",
    "    controls : sequence of 2-D arrays\n",
    "    method : str\n",
    "        Forwarded to `baryc_proj`.\n",
    "    projtype : {'wass', 'eucl'}\n",
    "        'eucl' returns the weighted average of the projections; 'wass' passes\n",
    "        that average as `X_init` to `ot.lp.free_support_barycenter`.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    weights : numpy.ndarray with one entry per control\n",
    "    projection : array with the replicated sample\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If `controls` is empty or `projtype` is unknown.\n",
    "    RuntimeError\n",
    "        If the solver does not return a solution.\n",
    "    \"\"\"\n",
    "    if projtype not in ('eucl', 'wass'):\n",
    "        # Fail before the expensive transport step instead of after it.\n",
    "        raise ValueError(\"projtype must be 'eucl' or 'wass', got %r\" % (projtype,))\n",
    "\n",
    "    J = len(controls)\n",
    "    if J == 0:\n",
    "        raise ValueError('controls must contain at least one sample')\n",
    "\n",
    "    n = target.shape[0]\n",
    "    d = target.shape[1]\n",
    "\n",
    "    # Stabilizer: rescales the objective. It must be strictly positive, since a\n",
    "    # negative divisor turns the minimization into a concave problem and zero\n",
    "    # is a division by zero. Any positive scale leaves the minimizer unchanged.\n",
    "    S = abs(np.mean(target)) * n * d * J\n",
    "    if S == 0:\n",
    "        S = 1.0\n",
    "\n",
    "    # Barycentric projection of the target onto every control\n",
    "    G_list = [baryc_proj(target, control, method) for control in controls]\n",
    "\n",
    "    # One column per control: the flattened displacement G_j - target.\n",
    "    # A @ lambda is then the flattened weighted sum of the displacements.\n",
    "    A = np.column_stack([np.asarray(G - target).ravel() for G in G_list])\n",
    "\n",
    "    # Obtain optimal weights\n",
    "    mylambda = cp.Variable(J)\n",
    "    objective = cp.Minimize(cp.sum_squares(A @ mylambda) / S)\n",
    "    constraints = [mylambda >= 0, mylambda <= 1, cp.sum(mylambda) == 1]\n",
    "\n",
    "    prob = cp.Problem(objective, constraints)\n",
    "    prob.solve()\n",
    "\n",
    "    if mylambda.value is None:\n",
    "        raise RuntimeError('weight optimization returned no solution (status: %s)' % (prob.status,))\n",
    "\n",
    "    # The solved variable is the weight vector (this name was previously never assigned).\n",
    "    weights = np.asarray(mylambda.value).ravel()\n",
    "\n",
    "    testproj = sum(w * G for w, G in zip(weights, G_list))\n",
    "\n",
    "    if projtype == 'eucl':\n",
    "        projection = testproj\n",
    "    else:\n",
    "        measureweights = [ot.unif(n)]*J\n",
    "        projection = ot.lp.free_support_barycenter(G_list, measureweights, X_init = testproj, weights = weights)\n",
    "\n",
    "    return(weights, projection)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "90fd3262",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inputs for the step-by-step cell below.\n",
    "# NOTE: Y1-Y4 are created in the 'Mixed Multivariate Normal' section further down;\n",
    "# that section has to be run before this cell.\n",
    "target, controls = Y1, [Y2, Y3, Y4]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "6eeac868",
   "metadata": {},
   "outputs": [
    {
     "ename": "NameError",
     "evalue": "name 'method' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "\u001b[0;32m/var/folders/jm/_lq4nz4n6bx4dr2qgtlnbqb40000gn/T/ipykernel_97780/964293986.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      9\u001b[0m \u001b[0mproj_list\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     10\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcontrols\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 11\u001b[0;31m     \u001b[0mtemp\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mbaryc_proj\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtarget\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcontrols\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmethod\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     12\u001b[0m     \u001b[0mG_list\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtemp\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     13\u001b[0m     \u001b[0mproj_list\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtemp\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0mtarget\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mNameError\u001b[0m: name 'method' is not defined"
     ]
    }
   ],
   "source": [
    "# Step-by-step version of the weight optimization, run on the global `target` / `controls`.\n",
    "n, d = target.shape[0], target.shape[1]\n",
    "J = len(controls)\n",
    "S = np.mean(target) * n * d * J  # stabilizer: rescales the optimization objective\n",
    "\n",
    "# Barycentric projection of the target onto every control\n",
    "G_list, proj_list = [], []\n",
    "for i in range(J):\n",
    "    temp = baryc_proj(target, controls[i], \"emd\")\n",
    "    G_list += [temp]\n",
    "    proj_list += [temp - target]\n",
    "\n",
    "# Optimal weights: least squares over the unit simplex\n",
    "mylambda = cp.Variable(J)\n",
    "residual = cp.sum([a*b for a,b in zip(mylambda, proj_list)], axis = 0)\n",
    "objective = cp.Minimize(cp.sum_squares(residual)/S)\n",
    "constraints = [mylambda >= 0, mylambda <= 1, cp.sum(mylambda) == 1]\n",
    "\n",
    "prob = cp.Problem(objective, constraints)\n",
    "prob.solve()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c5c09fdd",
   "metadata": {},
   "source": [
    "## Mixed Multivariate Normal"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "c3cef93e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def mixed_multi_gauss(mean1, mean2, mean3, cov1, cov2, cov3, samplesize, partition1, partition2):\n",
    "    \n",
    "    size1 = int(samplesize * partition1)\n",
    "    size2 = int(samplesize * partition2)\n",
    "    size3 = int(samplesize - size1 - size2)\n",
    "    \n",
    "    gauss1 = np.random.multivariate_normal(mean = mean1, cov = cov1, size = size1)\n",
    "    gauss2 = np.random.multivariate_normal(mean = mean2, cov = cov2, size = size2)\n",
    "    gauss3 = np.random.multivariate_normal(mean = mean3, cov = cov3, size = size3)\n",
    "    \n",
    "    mixed = np.concatenate((gauss1, gauss2, gauss3), axis = 0)\n",
    "    np.random.shuffle(mixed)\n",
    "    \n",
    "    return(mixed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "d957c4a0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed the global NumPy generator so the simulated samples are reproducible.\n",
    "np.random.seed(31)\n",
    "\n",
    "dim = 10      # number of features\n",
    "obs = 10000   # observations per sample\n",
    "\n",
    "# Component means: constant vectors of length `dim`.\n",
    "mu1, mu2, mu3, mu4, mu5 = ([level]*dim for level in (10, 50, 200, -50, -100))\n",
    "\n",
    "# Shared covariance: 1 on the diagonal, 0.5 everywhere else.\n",
    "covmat = np.where(np.eye(dim, dtype = bool), 1.0, 0.5)\n",
    "\n",
    "\n",
    "# Single-Gaussian samples, drawn in the order X1, X2, X3, X4.\n",
    "X1, X2, X3, X4 = (\n",
    "    np.random.multivariate_normal(mean = mu, cov = covmat, size = obs)\n",
    "    for mu in (mu1, mu2, mu3, mu4)\n",
    ")\n",
    "\n",
    "\n",
    "# Three-component mixtures, drawn in the order Y1, Y2, Y3, Y4.\n",
    "# Each row: the three component means, then the shares of the first two components.\n",
    "Y1, Y2, Y3, Y4 = (\n",
    "    mixed_multi_gauss(m1, m2, m3, covmat, covmat, covmat, obs, share1, share2)\n",
    "    for m1, m2, m3, share1, share2 in (\n",
    "        (mu1, mu2, mu3, 0.3, 0.6),\n",
    "        (mu1, mu2, mu3, 0.6, 0.3),\n",
    "        (mu2, mu3, mu4, 0.5, 0.4),\n",
    "        (mu1, mu3, mu4, 0.6, 0.2),\n",
    "    )\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9fa612c5",
   "metadata": {},
   "source": [
    "## X"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "601df883",
   "metadata": {},
   "outputs": [
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[0;32m/var/folders/jm/_lq4nz4n6bx4dr2qgtlnbqb40000gn/T/ipykernel_97780/4101260807.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mweightsX\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mprojectionX\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mDSCreplication\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mX2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mX3\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mX4\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      2\u001b[0m \u001b[0mweightsX\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/var/folders/jm/_lq4nz4n6bx4dr2qgtlnbqb40000gn/T/ipykernel_97780/1954538347.py\u001b[0m in \u001b[0;36mDSCreplication\u001b[0;34m(target, controls, method, projtype)\u001b[0m\n\u001b[1;32m     11\u001b[0m     \u001b[0mproj_list\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     12\u001b[0m     \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcontrols\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 13\u001b[0;31m         \u001b[0mtemp\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mbaryc_proj\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtarget\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcontrols\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmethod\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     14\u001b[0m         \u001b[0mG_list\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtemp\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     15\u001b[0m         \u001b[0mproj_list\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtemp\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0mtarget\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/var/folders/jm/_lq4nz4n6bx4dr2qgtlnbqb40000gn/T/ipykernel_97780/542484844.py\u001b[0m in \u001b[0;36mbaryc_proj\u001b[0;34m(source, target, method)\u001b[0m\n\u001b[1;32m     11\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     12\u001b[0m     \u001b[0;32mif\u001b[0m \u001b[0mmethod\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;34m'emd'\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 13\u001b[0;31m         \u001b[0mOTplan\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mot\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0memd\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0ma_ones\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mb_ones\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mM\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnumItermax\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m1e7\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     14\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     15\u001b[0m     \u001b[0;32melif\u001b[0m \u001b[0mmethod\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;34m'entropic'\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/opt/anaconda3/lib/python3.9/site-packages/ot/lp/__init__.py\u001b[0m in \u001b[0;36memd\u001b[0;34m(a, b, M, numItermax, log, center_dual, numThreads)\u001b[0m\n\u001b[1;32m    319\u001b[0m     \u001b[0mnumThreads\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcheck_number_threads\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnumThreads\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    320\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 321\u001b[0;31m     \u001b[0mG\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcost\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mu\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mv\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mresult_code\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0memd_c\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0ma\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mb\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mM\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnumItermax\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnumThreads\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    322\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    323\u001b[0m     \u001b[0;32mif\u001b[0m \u001b[0mcenter_dual\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "# Replicate X1 from the control samples X2, X3 and X4 (default method and projection type).\n",
    "weightsX, projectionX = DSCreplication(target = X1, controls = [X2, X3, X4])\n",
    "weightsX"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2a416fe7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Feature-wise means: first row is the target X1, second row is its replication.\n",
    "meansX = pd.concat(\n",
    "    [pd.DataFrame(sample).describe().loc['mean'] for sample in (X1, projectionX)],\n",
    "    axis = 1,\n",
    ").T\n",
    "meansX"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "14124da6",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Overlay the per-observation feature averages of X1 (opaque) and of its replication (translucent).\n",
    "sns.histplot(X1.mean(axis = 1), color = 'skyblue', alpha = 1)\n",
    "sns.histplot(projectionX.mean(axis = 1), color = 'orange', alpha = 0.5)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "704ae985",
   "metadata": {},
   "source": [
    "## Y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "d7ee0f57",
   "metadata": {},
   "outputs": [
    {
     "ename": "NameError",
     "evalue": "name 'weights' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "\u001b[0;32m/var/folders/jm/_lq4nz4n6bx4dr2qgtlnbqb40000gn/T/ipykernel_97780/833946819.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mweightsY\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mprojectionY\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mDSCreplication\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mY1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mY2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mY3\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mY4\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      2\u001b[0m \u001b[0mweightsY\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/var/folders/jm/_lq4nz4n6bx4dr2qgtlnbqb40000gn/T/ipykernel_97780/1954538347.py\u001b[0m in \u001b[0;36mDSCreplication\u001b[0;34m(target, controls, method, projtype)\u001b[0m\n\u001b[1;32m     30\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     31\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 32\u001b[0;31m     \u001b[0mtestproj\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0msum\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0ma\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mb\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0ma\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mb\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mzip\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mweights\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mG_list\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     33\u001b[0m     \u001b[0mmeasureweights\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mot\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0munif\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mJ\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     34\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mNameError\u001b[0m: name 'weights' is not defined"
     ]
    }
   ],
   "source": [
    "# Replicate the mixture Y1 from the control mixtures Y2, Y3 and Y4 (default method and projection type).\n",
    "weightsY, projectionY = DSCreplication(target = Y1, controls = [Y2, Y3, Y4])\n",
    "weightsY"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "66f82749",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Feature-wise means: first row is the target Y1, second row is its replication.\n",
    "meansY = pd.concat(\n",
    "    [pd.DataFrame(sample).describe().loc['mean'] for sample in (Y1, projectionY)],\n",
    "    axis = 1,\n",
    ").T\n",
    "meansY"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "36c9f128",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Overlay the per-observation feature averages of Y1 and of its replication (translucent).\n",
    "sns.histplot(Y1.mean(axis = 1), color = 'royalblue')\n",
    "sns.histplot(projectionY.mean(axis = 1), color = 'darkorange', alpha = 0.5)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
