{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 633,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from scipy.sparse import coo_matrix\n",
    "import cvxpy as cp\n",
    "import math\n",
    "from gurobi_optimods.min_cost_flow import min_cost_flow_scipy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 634,
   "metadata": {},
   "outputs": [],
   "source": [
    "def prepare_dataset(n, skew_function, rng=None):\n",
    "    \"\"\"Sample n synthetic (prediction, label) pairs.\n",
    "\n",
    "    Each prediction f is uniform on [0, 1); its label is 1 with probability\n",
    "    skew_function(f), where values above 1 act as 1 and values below 0 as 0.\n",
    "\n",
    "    Args:\n",
    "        n: number of samples.\n",
    "        skew_function: maps a prediction to the probability that its label is 1.\n",
    "        rng: optional random source with a uniform() method, e.g.\n",
    "            np.random.default_rng(seed). Defaults to the global np.random state,\n",
    "            which reproduces the previous behavior.\n",
    "\n",
    "    Returns:\n",
    "        (f, y): array of predictions and array of 0/1 labels, each of length n.\n",
    "    \"\"\"\n",
    "    rng = np.random if rng is None else rng\n",
    "    predictions = []\n",
    "    labels = []\n",
    "    for _ in range(n):\n",
    "        f = rng.uniform()\n",
    "        # u > 1 - p has probability p for u uniform on [0, 1)\n",
    "        y = int(rng.uniform() > 1 - skew_function(f))\n",
    "        predictions.append(f)\n",
    "        labels.append(y)\n",
    "    return (np.array(predictions), np.array(labels))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 635,
   "metadata": {},
   "outputs": [],
   "source": [
    "N_DATASETS = 10    # independent synthetic datasets\n",
    "N_SAMPLES = 2**10  # samples per dataset\n",
    "SEED = 0           # seeds the global NumPy RNG so Restart & Run All is reproducible\n",
    "np.random.seed(SEED)\n",
    "# P(y = 1 | f) = f + 0.01 (capped at 1 for f > 0.99): a constant 0.01 miscalibration\n",
    "S_calibrated = [prepare_dataset(N_SAMPLES, lambda x: x + 0.01) for _ in range(N_DATASETS)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 637,
   "metadata": {},
   "outputs": [],
   "source": [
    "class node:\n",
    "    \"\"\"Segment-tree node covering the inclusive index range [l, r].\"\"\"\n",
    "\n",
    "    def __init__(self, l, r):\n",
    "        self.l, self.r = l, r  # inclusive bounds of the covered range\n",
    "        # pending lazy tags (range assignment / range increment); None means no tag\n",
    "        self.set = self.add = None\n",
    "        # child subtrees; both remain None for a leaf\n",
    "        self.left = self.right = None\n",
    "\n",
    "def compose(root, c, func=True):\n",
    "    \"\"\"Fold a lazy update into root's pending tag.\n",
    "\n",
    "    func=True assigns c (and discards any pending add). func=False adds c:\n",
    "    it accumulates into a pending add, shifts a pending set, or becomes a\n",
    "    new add tag when root has no tag.\n",
    "    \"\"\"\n",
    "    if func:\n",
    "        root.set = c\n",
    "        root.add = None\n",
    "    elif root.add is not None:\n",
    "        root.add += c\n",
    "    elif root.set is not None:\n",
    "        root.set += c  # adding to an assigned range just shifts the assignment\n",
    "    else:\n",
    "        root.add = c\n",
    "\n",
    "class segment_tree(object):\n",
    "    \"\"\"Segment tree over indices 0..size-1 with lazy range-assign and range-add.\n",
    "\n",
    "    Indices that were never updated read as 0. An update is stored as a tag on\n",
    "    the highest nodes it fully covers and is pushed to the children (through\n",
    "    compose) only when a later operation has to descend past those nodes.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, size):\n",
    "        def createTree(l, r):\n",
    "            # Build the subtree for [l, r]; an empty range yields None.\n",
    "            if l > r:\n",
    "                return None\n",
    "            root = node(l, r)\n",
    "            if l < r:\n",
    "                mid = (l + r) // 2\n",
    "                root.left = createTree(l, mid)\n",
    "                root.right = createTree(mid + 1, r)\n",
    "            return root\n",
    "\n",
    "        self.root = createTree(0, size - 1)\n",
    "        self.size = size\n",
    "\n",
    "    def apply(self, l, r, c, func = True):\n",
    "        \"\"\"Assign (func=True) or add (func=False) c on every index in [l, r].\n",
    "\n",
    "        Parts of [l, r] outside 0..size-1 are ignored; an empty range (l > r)\n",
    "        changes no value.\n",
    "        \"\"\"\n",
    "        def applyHelper(root, l, r, c, func = True):\n",
    "            if root.l > r or root.r < l:\n",
    "                return  # no overlap\n",
    "            if root.l >= l and root.r <= r:\n",
    "                compose(root, c, func)  # fully covered: record lazily\n",
    "                return\n",
    "            # Partial overlap: hand this node's pending tag to both children\n",
    "            # before recursing, then clear it.\n",
    "            if root.add is not None:\n",
    "                compose(root.left, root.add, False)\n",
    "                compose(root.right, root.add, False)\n",
    "            elif root.set is not None:\n",
    "                compose(root.left, root.set, True)\n",
    "                compose(root.right, root.set, True)\n",
    "            applyHelper(root.left, l, r, c, func)\n",
    "            applyHelper(root.right, l, r, c, func)\n",
    "            root.add = None\n",
    "            root.set = None\n",
    "\n",
    "        return applyHelper(self.root, l, r, c, func)\n",
    "\n",
    "    def access(self, ind):\n",
    "        \"\"\"Return the current value at index ind (0 if it was never updated).\n",
    "\n",
    "        Pending add tags on the root-to-leaf path are pushed down as a side\n",
    "        effect; a pending set tag answers directly. ind is not range-checked:\n",
    "        values past either end resolve to the nearest boundary leaf.\n",
    "        \"\"\"\n",
    "        def accessHelper(root, i):\n",
    "            if root.l == root.r:\n",
    "                if root.add is not None:\n",
    "                    return root.add\n",
    "                if root.set is not None:\n",
    "                    return root.set\n",
    "                return 0\n",
    "            if root.add is not None:\n",
    "                compose(root.left, root.add, False)\n",
    "                compose(root.right, root.add, False)\n",
    "                root.add = None\n",
    "            elif root.set is not None:\n",
    "                return root.set  # the assignment covers the whole subtree\n",
    "            if i <= root.left.r:\n",
    "                return accessHelper(root.left, i)\n",
    "            return accessHelper(root.right, i)\n",
    "\n",
    "        return accessHelper(self.root, ind)\n",
    "\n",
    "    def add(self, l, r, c):\n",
    "        \"\"\"Add c to every index in [l, r].\"\"\"\n",
    "        self.apply(l, r, c, False)\n",
    "\n",
    "    def set(self, l, r, c):\n",
    "        \"\"\"Assign c to every index in [l, r].\"\"\"\n",
    "        self.apply(l, r, c, True)\n",
    "\n",
    "    def binary_search(self, c):\n",
    "        \"\"\"Binary-search the stored values, assumed non-decreasing, for c.\n",
    "\n",
    "        If no probe lands on a value equal to c, returns the largest index whose\n",
    "        value is < c (-1 if there is none). If a probe lands on a value equal to\n",
    "        c, returns that probe index minus one.\n",
    "        NOTE(review): with several values equal to c, the returned index may\n",
    "        itself hold c; confirm this is what DP expects.\n",
    "        \"\"\"\n",
    "        low = 0\n",
    "        high = self.size - 1\n",
    "        while low <= high:\n",
    "            mid = (low + high) // 2\n",
    "            value = self.access(mid)  # one lookup per probe\n",
    "            if value < c:\n",
    "                low = mid + 1\n",
    "            elif value > c:\n",
    "                high = mid - 1\n",
    "            else:\n",
    "                return mid - 1\n",
    "        return high\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 638,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.011329418691121366\n"
     ]
    }
   ],
   "source": [
    "def DP(S):\n",
    "    \"\"\"Sweep the sorted sample and build the data consumed by solve().\n",
    "\n",
    "    Sorts (x, y) by prediction x, then walks the chain of sorted samples while\n",
    "    tracking breakpoints (point) and one slope per interval between them\n",
    "    (slopes, a segment_tree with n+2 entries for n+1 breakpoints).\n",
    "    NOTE(review): this reads as a dynamic program for the same chain problem\n",
    "    that smCE_LP / smCE_flow set up; in the recorded outputs the averaged\n",
    "    solve() result agrees with those cells to about 1e-9. Treat that\n",
    "    interpretation as unverified.\n",
    "\n",
    "    Args:\n",
    "        S: tuple (x_list, y_list) of NumPy arrays (they are fancy-indexed below).\n",
    "\n",
    "    Returns:\n",
    "        point: sorted breakpoints, length n+1.\n",
    "        threshold: (n-2, 2) array. Row 0 is (min, max) of -demands[0] and 0;\n",
    "            row i >= 1 holds point[l], point[r] from step i, shifted by\n",
    "            shift[i] - shift[n-3] + demands[n-2].\n",
    "        slopes: segment_tree of size n+2 after the last update.\n",
    "        cost: gaps between consecutive sorted predictions, length n-1.\n",
    "        demands: (x - y) / n in sorted order, length n.\n",
    "\n",
    "    Requires n >= 3: shift[n-3] is read unconditionally.\n",
    "    \"\"\"\n",
    "    (x_list, y_list) = S\n",
    "    n = len(x_list)\n",
    "    # sort the sample by prediction\n",
    "    indices = np.argsort(x_list)\n",
    "    x_list = x_list[indices]\n",
    "    y_list = y_list[indices] \n",
    "    threshold = np.zeros((n-2,2))\n",
    "    demands = -np.array((y_list - x_list)/n)\n",
    "    # cost[i] = x[i+1] - x[i], the gap between consecutive sorted predictions\n",
    "    cost = (x_list- np.roll(x_list,1,axis=0))[1:len(x_list)]\n",
    "\n",
    "    # n+1 breakpoints; the first two are -demands[0] and 0, stored in ascending order\n",
    "    point = np.zeros(n+1)\n",
    "    if -demands[0] <= 0:\n",
    "        point[0] = -demands[0]\n",
    "        point[1] = 0\n",
    "    else:\n",
    "        point[0] = 0\n",
    "        point[1] = -demands[0]\n",
    "\n",
    "    # find vertices\n",
    "    # shift[i] = -(demands[1] + ... + demands[i]); shift[0] stays 0\n",
    "    shift = np.zeros(n-2)\n",
    "    for i in range(1,n-2):\n",
    "        shift[i] = shift[i-1] -demands[i]\n",
    "        point[i+1] = -shift[i]\n",
    "\n",
    "    # move point[0:n-1] into a common frame by adding the final offset\n",
    "    # shift[n-3] - demands[n-2]; point[n] is never assigned and stays 0\n",
    "    point[0:n-1] = point[0:n-1] + shift[n-3]-demands[n-2]\n",
    "    point[n-1] = demands[n-1]\n",
    "    # ind[k] = position of the original point[k] after sorting\n",
    "    ind = np.argsort(np.argsort(point))\n",
    "    point = np.sort(point)\n",
    "    \n",
    "    # find slopes\n",
    "    # l, r: sorted positions of the first two breakpoints\n",
    "    l = int(ind[0])\n",
    "    r = int(ind[1])\n",
    "    # one slope per interval between consecutive sorted breakpoints (n+2 slots)\n",
    "    slopes = segment_tree(n+2)\n",
    "    # Both branches: set [0, r] to the middle slope, then shift [0, l] down so it\n",
    "    # ends at -1-cost[0]; [r+1, n+1] gets 1+cost[0]. The middle slope is\n",
    "    # 1-cost[0] in the first branch and -1+cost[0] in the second.\n",
    "    if -demands[0] <= 0:\n",
    "        slopes.set(0, r, 1-cost[0])\n",
    "        slopes.add(0, l, -2)\n",
    "        slopes.set(r+1, n+1, 1+cost[0])\n",
    "        threshold[0][0] = -demands[0]\n",
    "        threshold[0][1] = 0\n",
    "    else:\n",
    "        slopes.set(0, r, -1+cost[0])\n",
    "        slopes.add(0, l, -2*cost[0])\n",
    "        slopes.set(r+1, n+1, 1+cost[0])\n",
    "        threshold[0][0] = 0\n",
    "        threshold[0][1] = -demands[0]\n",
    "\n",
    "    for i in range(1, n-2):\n",
    "        # overwrite slopes on [0, l] with -1 and on [r+1, n+1] with +1\n",
    "        slopes.set(0,l, -1)\n",
    "        slopes.set(r+1, n+1, 1)\n",
    "        # add -cost[i] up to and including this step's breakpoint, +cost[i] after it\n",
    "        slopes.add(0, int(ind[i+1]), -cost[i])\n",
    "        slopes.add(int(ind[i+1]+1), n+1, cost[i])\n",
    "        # new l, r via segment_tree.binary_search (assumes slopes stay non-decreasing)\n",
    "        l = slopes.binary_search(-1)\n",
    "        r = slopes.binary_search(1)\n",
    "        # undo the global offset added to point above and apply this step's shift\n",
    "        threshold[i][0] = point[l] + shift[i] - shift[n-3] +demands[n-2]\n",
    "        threshold[i][1] = point[r] + shift[i] - shift[n-3] +demands[n-2]\n",
    "    # final step: same overwrite, then add the -1/+1 change at ind[n-1] and the\n",
    "    # -cost[n-2]/+cost[n-2] change at ind[n]\n",
    "    slopes.set(0,l, -1)\n",
    "    slopes.set(r+1, n+1, 1)\n",
    "    slopes.add(0, int(ind[n-1]), -1)\n",
    "    slopes.add(int(ind[n-1]+1), n+1, 1)\n",
    "    slopes.add(0, int(ind[n]), -cost[n-2])\n",
    "    slopes.add(int(ind[n]+1), n+1, cost[n-2])\n",
    "\n",
    "    return point, threshold, slopes, cost, demands\n",
    "\n",
    "def solve(point, threshold, slope, cost, demands):\n",
    "    \"\"\"Recover the chain values from DP's output and return the objective.\n",
    "\n",
    "    The last value is point[j] for the first j with slope[j] <= 0 <= slope[j+1]\n",
    "    (it stays 0 if no such j exists). Earlier values come from a backward pass\n",
    "    that clamps next_value + demands[k+1] into [threshold[k][0], threshold[k][1]].\n",
    "    The objective is the sum of cost[k] * |v[k]| plus the terms |v[0] + d[0]|,\n",
    "    |v[k] - v[k+1] - d[k+1]| and |v[m-1] - d[m]|.\n",
    "\n",
    "    Args:\n",
    "        point: sorted breakpoints from DP, length m + 2.\n",
    "        threshold: per-step clamp bounds from DP.\n",
    "        slope: object with a size attribute and access(j) (DP's segment_tree).\n",
    "        cost: per-edge weights, at least m entries.\n",
    "        demands: per-node values, at least m + 1 entries.\n",
    "\n",
    "    Returns:\n",
    "        The objective value (0 when m == 0).\n",
    "    \"\"\"\n",
    "    m = len(point) - 2\n",
    "    values = np.zeros(m)\n",
    "    if m < 1:\n",
    "        return 0\n",
    "\n",
    "    # last value: breakpoint where the slope sequence crosses zero\n",
    "    for j in range(slope.size):\n",
    "        if slope.access(j) <= 0 and slope.access(j + 1) >= 0:\n",
    "            values[-1] = point[j]\n",
    "            break\n",
    "\n",
    "    # backward pass: clamp next value + demand into [low, high]\n",
    "    for k in reversed(range(m - 1)):\n",
    "        target = values[k + 1] + demands[k + 1]\n",
    "        low = threshold[k][0]\n",
    "        high = threshold[k][1]\n",
    "        if target <= low:\n",
    "            values[k] = low\n",
    "        elif low < target < high:\n",
    "            values[k] = target\n",
    "        else:\n",
    "            values[k] = high\n",
    "\n",
    "    total = 0\n",
    "    for k in range(m):\n",
    "        total = total + cost[k] * abs(values[k])\n",
    "        if k == 0:\n",
    "            total = total + abs(values[0] + demands[0])\n",
    "        if k == m - 1:\n",
    "            total = total + abs(values[m - 1] - demands[m])\n",
    "        else:\n",
    "            total = total + abs(values[k] - values[k + 1] - demands[k + 1])\n",
    "    return total\n",
    "\n",
    "# Average the DP objective over all datasets.\n",
    "result = 0\n",
    "for dataset in S_calibrated:\n",
    "    point, threshold, slope, cost, demands = DP(dataset)\n",
    "    result += solve(point, threshold, slope, cost, demands)\n",
    "print(result / len(S_calibrated))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 639,
   "metadata": {},
   "outputs": [],
   "source": [
    "def smCE_LP(S):\n",
    "    \"\"\"Build the data (c, A, b) of the smCE linear program for sample S.\n",
    "\n",
    "    After sorting by prediction x, A @ s <= b encodes\n",
    "    |s[i] - s[i+1]| <= x[i+1] - x[i] for every consecutive pair.\n",
    "\n",
    "    Args:\n",
    "        S: tuple (x_list, y_list) of NumPy arrays (predictions, labels).\n",
    "\n",
    "    Returns:\n",
    "        c: y - x in sorted order, length n.\n",
    "        A: matrix of shape (2(n-1), n); rows e_i - e_(i+1), then their negations.\n",
    "        b: the n-1 consecutive gaps of sorted x, stacked twice.\n",
    "    \"\"\"\n",
    "    x_list, y_list = S\n",
    "    order = np.argsort(x_list)\n",
    "    x_sorted = x_list[order]\n",
    "    y_sorted = y_list[order]\n",
    "    # negated row-differences of the identity: rows e_i - e_(i+1), i = 0..n-2\n",
    "    step = -np.diff(np.diag([1] * len(x_sorted)), axis=0)\n",
    "    gaps = np.diff(x_sorted)\n",
    "    return (y_sorted - x_sorted), np.concatenate([step, -step]), np.concatenate([gaps, gaps])\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 640,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.011329417598749591\n"
     ]
    }
   ],
   "source": [
    "# Solve the smCE LP with CVXPY for every dataset and average the optimal values.\n",
    "# The feasible set is symmetric under x -> -x (box bounds, and A stacks D and -D\n",
    "# with the same b), so max (c/n) @ x equals -min (c/n) @ x.\n",
    "result = 0\n",
    "for dataset in S_calibrated:\n",
    "    c, A, b = smCE_LP(dataset)\n",
    "    n = len(dataset[0])\n",
    "    x = cp.Variable(n)\n",
    "    objective = cp.Maximize(c @ x / n)  # scalar objective\n",
    "    constraints = [-1 <= x, x <= 1, A @ x <= b]\n",
    "    prob = cp.Problem(objective, constraints)\n",
    "    value = prob.solve()\n",
    "    if prob.status not in (cp.OPTIMAL, cp.OPTIMAL_INACCURATE):\n",
    "        raise RuntimeError(f'LP solve failed with status {prob.status}')\n",
    "    result += value\n",
    "print(result / len(S_calibrated))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 641,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.011329418372186423\n"
     ]
    }
   ],
   "source": [
    "def smCE_flow(S):\n",
    "    \"\"\"Build the min-cost-flow instance for the smCE of sample S.\n",
    "\n",
    "    Nodes 0..n-1 are the samples sorted by prediction; node n is a hub.\n",
    "    Arcs run both ways between consecutive samples (cost = gap between their\n",
    "    predictions) and both ways between every sample and the hub (cost 1).\n",
    "    Every arc has capacity 2.\n",
    "\n",
    "    Args:\n",
    "        S: tuple (x_list, y_list) of NumPy arrays (predictions, labels).\n",
    "\n",
    "    Returns:\n",
    "        graph: (n+1, n+1) coo_matrix holding a 1 for every arc.\n",
    "        capacities: (n+1, n+1) coo_matrix of arc capacities.\n",
    "        cost: (n+1, n+1) coo_matrix of arc costs.\n",
    "        demands: length n+1 array; -(y - x)/n for each sorted sample and the\n",
    "            sum of (y - x)/n for the hub, so demands sum to zero (up to rounding).\n",
    "    \"\"\"\n",
    "    x_list, y_list = S\n",
    "    n = len(x_list)\n",
    "    order = np.argsort(x_list)\n",
    "    x_sorted = x_list[order]\n",
    "    y_sorted = y_list[order]\n",
    "\n",
    "    lower = np.arange(0, n - 1)  # i for each consecutive pair (i, i+1)\n",
    "    upper = np.arange(1, n)      # i+1 for each consecutive pair\n",
    "    samples = np.arange(0, n)\n",
    "    hub = np.full(n, n)\n",
    "    # arc order: i->i+1, i+1->i, sample->hub, hub->sample\n",
    "    row = np.concatenate((lower, upper, samples, hub))\n",
    "    column = np.concatenate((upper, lower, hub, samples))\n",
    "\n",
    "    gaps = np.diff(x_sorted)\n",
    "    arc_cost = np.concatenate((gaps, gaps, np.ones(2 * n)))\n",
    "\n",
    "    shape = (n + 1, n + 1)\n",
    "    graph = coo_matrix((np.ones(len(row), dtype=int), (row, column)), shape=shape)\n",
    "    capacities = coo_matrix((np.full(len(row), 2), (row, column)), shape=shape)\n",
    "    cost = coo_matrix((arc_cost, (row, column)), shape=shape)\n",
    "\n",
    "    c = (y_sorted - x_sorted) / n\n",
    "    demands = np.concatenate((-c, np.array([np.sum(c)])))\n",
    "    return graph, capacities, cost, demands\n",
    "\n",
    "# Solve each min-cost-flow instance with gurobi-optimods and average the objectives.\n",
    "result_1 = 0\n",
    "for dataset in S_calibrated:\n",
    "    graph, capacities, cost, demands = smCE_flow(dataset)\n",
    "    obj, flow = min_cost_flow_scipy(graph, capacities, cost, demands, verbose=False)\n",
    "    result_1 += obj\n",
    "print(result_1 / len(S_calibrated))\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
