{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true,
    "ExecuteTime": {
     "start_time": "2023-04-14T14:47:17.625906Z",
     "end_time": "2023-04-14T14:47:17.654373Z"
    }
   },
   "outputs": [],
   "source": [
    "DATA_DIR = \"../../datasets/toydata.txt\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "outputs": [],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "from scipy.optimize import minimize\n",
    "from copy import deepcopy"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "start_time": "2023-04-14T14:47:17.644232Z",
     "end_time": "2023-04-14T14:47:25.080711Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "outputs": [],
   "source": [
    "def read_txt(path):\n",
    "    datas = []\n",
    "    with open(path, 'r') as f:\n",
    "        for sample in f.readlines():\n",
    "            x, y = sample.strip().split(\" \")\n",
    "            datas.append(np.array([float(x), float(y)]))\n",
    "    return datas"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "start_time": "2023-04-14T14:47:28.405406Z",
     "end_time": "2023-04-14T14:47:28.422484Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "-0.26276621109429227\n",
      "True\n",
      "[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
      "[array([0., 0.]), array([0., 0.]), array([0., 0.]), array([0., 0.]), array([0., 0.]), array([0., 0.]), array([0., 0.]), array([0., 0.])]\n"
     ]
    }
   ],
   "source": [
    "datas = read_txt(DATA_DIR)\n",
    "n = 12\n",
    "f = 5\n",
    "T = 8\n",
    "LR = 0.5\n",
    "args1 = {'n': n, 'f': f, 'm': 1, 'clip_tau': 10.0, 'clip_scaling': 'linear', 'momentum': 0.0, 'mimic-warmup': 1, 'last': -1}\n",
    "args = (n, f, T, LR, np.array([0,0]), np.array([2,2]), datas, 'krum', args1)\n",
    "g0 = []\n",
    "for i in range(2 * T):\n",
    "    g0.append(0)\n",
    "g0 = np.array(g0)\n",
    "\n",
    "\n",
    "def agg(aggr, grads, args1, glb=None):\n",
    "    if aggr == 'avg':\n",
    "        return np.mean(grads), glb\n",
    "\n",
    "    elif aggr == 'cm':\n",
    "        stacked = np.stack(grads, axis=0)\n",
    "        values_upper = np.median(stacked, axis=0)\n",
    "        values_lower = np.median((-stacked), axis=0)\n",
    "        return (values_upper - values_lower) / 2, glb\n",
    "\n",
    "    elif aggr == 'krum':\n",
    "        def _compute_scores(distances, i, n, f):\n",
    "            s = [distances[j][i] ** 2 for j in range(i)] + [\n",
    "                distances[i][j] ** 2 for j in range(i + 1, n)\n",
    "            ]\n",
    "            _s = sorted(s)[: n - f - 2]\n",
    "            return sum(_s)\n",
    "\n",
    "        def multi_krum(distances, n, f, m):\n",
    "            if n < 1:\n",
    "                raise ValueError(\n",
    "                    \"Number of workers should be positive integer. Got {}.\".format(f)\n",
    "                )\n",
    "            if m < 1 or m > n:\n",
    "                raise ValueError(\n",
    "                    \"Number of workers for aggregation should be >=1 and <= {}. Got {}.\".format(\n",
    "                        m, n\n",
    "                    )\n",
    "                )\n",
    "            if 2 * f + 2 > n:\n",
    "                raise ValueError(\"Too many Byzantine workers: 2 * {} + 2 >= {}.\".format(f, n))\n",
    "\n",
    "            for i in range(n - 1):\n",
    "                for j in range(i + 1, n):\n",
    "                    if distances[i][j] < 0:\n",
    "                        raise ValueError(\n",
    "                            \"The distance between node {} and {} should be non-negative: Got {}.\".format(\n",
    "                                i, j, distances[i][j]\n",
    "                            )\n",
    "                        )\n",
    "\n",
    "            scores = [(i, _compute_scores(distances, i, n, f)) for i in range(n)]\n",
    "            sorted_scores = sorted(scores, key=lambda x: x[1])\n",
    "            return list(map(lambda x: x[0], sorted_scores))[:m]\n",
    "\n",
    "        def _compute_euclidean_distance(v1, v2):\n",
    "            return np.linalg.norm(v1 - v2)\n",
    "\n",
    "        def pairwise_euclidean_distances(vectors):\n",
    "            n = len(vectors)\n",
    "            vectors = [v.flatten() for v in vectors]\n",
    "\n",
    "            distances = {}\n",
    "            for i in range(n - 1):\n",
    "                distances[i] = {}\n",
    "                for j in range(i + 1, n):\n",
    "                    distances[i][j] = _compute_euclidean_distance(vectors[i], vectors[j]) ** 2\n",
    "            return distances\n",
    "\n",
    "        distances = pairwise_euclidean_distances(grads)\n",
    "        top_m_indices = multi_krum(distances, args1['n'], args1['f'], args1['m'])\n",
    "        values = sum(grads[i] for i in top_m_indices)\n",
    "        glb = top_m_indices\n",
    "        return values, glb\n",
    "\n",
    "    elif aggr == 'cp':\n",
    "        if args1['clip_scaling'] is None:\n",
    "            tau = args1['clip_tau']\n",
    "        elif args1['clip_scaling'] == \"linear\":\n",
    "            tau = args1['clip_tau'] / (1 - args1['momentum'])\n",
    "        elif args1['clip_scaling'] == \"sqrt\":\n",
    "            tau = args1['clip_tau'] / np.sqrt(1 - args1['momentum'])\n",
    "        else:\n",
    "            raise NotImplementedError(args1['clip_scaling'])\n",
    "\n",
    "        n_iter=3\n",
    "\n",
    "        def clip(v):\n",
    "            v_norm = np.linalg.norm(v)\n",
    "            scale = min(1, tau / v_norm)\n",
    "            return v * scale\n",
    "\n",
    "        if glb is None:\n",
    "            glb = np.zeros_like(grads[0])\n",
    "\n",
    "        for _ in range(n_iter):\n",
    "            glb = (\n",
    "                sum(clip(v - glb) for v in grads) / len(grads)\n",
    "                + glb\n",
    "            )\n",
    "        return np.copy(glb), glb\n",
    "\n",
    "    elif aggr == 'rfa':\n",
    "        if glb:\n",
    "            T, nu = glb\n",
    "        else:\n",
    "            T = 8\n",
    "            nu = 0.1\n",
    "            glb = (T, nu)\n",
    "\n",
    "        def _compute_euclidean_distance(v1, v2):\n",
    "            return np.linalg.norm(v1 - v2)\n",
    "\n",
    "        def smoothed_weiszfeld(weights, alphas, z, nu, T):\n",
    "            m = len(weights)\n",
    "            if len(alphas) != m:\n",
    "                raise ValueError\n",
    "\n",
    "            if nu < 0:\n",
    "                raise ValueError\n",
    "\n",
    "            for t in range(T):\n",
    "                betas = []\n",
    "                for k in range(m):\n",
    "                    distance = _compute_euclidean_distance(z, weights[k])\n",
    "                    betas.append(alphas[k] / max(distance, nu))\n",
    "\n",
    "                z = 0\n",
    "                for w, beta in zip(weights, betas):\n",
    "                    z += w * beta\n",
    "                z /= sum(betas)\n",
    "            return z\n",
    "\n",
    "        alphas = [1 / len(grads) for _ in grads]\n",
    "        z = np.zeros_like(grads[0])\n",
    "        return smoothed_weiszfeld(grads, alphas, z=z, nu=nu, T=T), glb\n",
    "\n",
    "    elif aggr == 'tm':\n",
    "        def topk_(matrix, K, axis=1):\n",
    "            if axis == 0:\n",
    "                row_index = np.arange(matrix.shape[1 - axis])\n",
    "                topk_index = np.argpartition(-matrix, K, axis=axis)[0:K, :]\n",
    "                topk_data = matrix[topk_index, row_index]\n",
    "                topk_index_sort = np.argsort(-topk_data,axis=axis)\n",
    "                topk_data_sort = topk_data[topk_index_sort,row_index]\n",
    "                topk_index_sort = topk_index[0:K,:][topk_index_sort,row_index]\n",
    "            else:\n",
    "                column_index = np.arange(matrix.shape[1 - axis])[:, None]\n",
    "                topk_index = np.argpartition(-matrix, K, axis=axis)[:, 0:K]\n",
    "                topk_data = matrix[column_index, topk_index]\n",
    "                topk_index_sort = np.argsort(-topk_data, axis=axis)\n",
    "                topk_data_sort = topk_data[column_index, topk_index_sort]\n",
    "                topk_index_sort = topk_index[:,0:K][column_index,topk_index_sort]\n",
    "            return topk_data_sort, topk_index_sort\n",
    "\n",
    "        glb = args1['f']\n",
    "        if len(grads) - 2 * glb > 0:\n",
    "            b = glb\n",
    "        else:\n",
    "            b = glb\n",
    "            while len(grads) - 2 * b <= 0:\n",
    "                b -= 1\n",
    "            if b < 0:\n",
    "                raise RuntimeError\n",
    "\n",
    "        stacked = np.stack(grads, axis=0)\n",
    "        largest, _ = topk_(stacked, b, 0)\n",
    "        neg_smallest, _ = topk_(-stacked, b, 0)\n",
    "        new_stacked = sum(np.concatenate([stacked, -largest, neg_smallest]), 0)\n",
    "        new_stacked /= len(grads) - 2 * b\n",
    "        return new_stacked, glb\n",
    "\n",
    "\n",
    "def fun(args):\n",
    "    n, f, T, LR, ori, goal, data, aggr, args1 = args\n",
    "    def v(g):\n",
    "        glb = None\n",
    "        thetas = []\n",
    "        theta = ori\n",
    "        coe = 1\n",
    "        for t in range(T):\n",
    "            grads = []\n",
    "            for k in range(n - f):\n",
    "                grads.append(theta - data[k])\n",
    "            for k in range(f):\n",
    "                grads.append(np.array([g[2*t], g[2*t+1]]))\n",
    "            aggregated, glb = agg(aggr, grads, args1, glb)\n",
    "            theta = theta - LR * aggregated\n",
    "            thetas.append(coe * np.mean(np.power((theta - goal), 2)))\n",
    "            coe *= args1['last']\n",
    "        return -np.mean(np.array(thetas))\n",
    "    return v\n",
    "\n",
    "\n",
    "def fun_last(args):\n",
    "    n, f, T, LR, ori, goal, data, aggr, args1 = args\n",
    "    def v(g):\n",
    "        glb = None\n",
    "        theta = ori\n",
    "        for t in range(T):\n",
    "            grads = []\n",
    "            for k in range(n - f):\n",
    "                grads.append(theta - data[k])\n",
    "            for k in range(f):\n",
    "                grads.append(np.array([g[2*t], g[2*t+1]]))\n",
    "            aggregated, glb = agg(aggr, grads, args1, glb)\n",
    "            theta = theta - LR * aggregated\n",
    "        return -np.mean(np.power((theta - goal), 2))\n",
    "    return v\n",
    "\n",
    "\n",
    "def fun_lower(args):\n",
    "    n, f, T, LR, ori, goal, data, aggr, args1 = args\n",
    "    def v(g):\n",
    "        glb = None\n",
    "        thetas = []\n",
    "        theta = ori\n",
    "        for t in range(T):\n",
    "            grads = []\n",
    "            for k in range(n - f):\n",
    "                grads.append(theta - data[k])\n",
    "            for k in range(f):\n",
    "                grads.append(np.array([g[2*t], g[2*t+1]]))\n",
    "            aggregated, glb = agg(aggr, grads, args1, glb)\n",
    "            theta = theta - LR * aggregated\n",
    "            thetas.append(np.mean(np.power((theta - goal), 2)))\n",
    "        return -np.min(np.array(thetas))\n",
    "    return v\n",
    "\n",
    "\n",
    "def fun_coe(args):\n",
    "    n, f, T, LR, ori, goal, data, aggr, args1 = args\n",
    "    def v(g):\n",
    "        glb = None\n",
    "        thetas = []\n",
    "        theta = ori\n",
    "        coe = 1\n",
    "        for t in range(T):\n",
    "            grads = []\n",
    "            for k in range(n - f):\n",
    "                grads.append(theta - data[k])\n",
    "            for k in range(f):\n",
    "                grads.append(np.array([g[2*t], g[2*t+1]]))\n",
    "            aggregated, glb = agg(aggr, grads, args1, glb)\n",
    "            theta = theta - LR * aggregated\n",
    "            thetas.append(coe * np.mean(np.power((theta - goal), 2)))\n",
    "            coe *= args1['last']\n",
    "        return -np.mean(np.array(thetas))\n",
    "    return v\n",
    "\n",
    "# def con(args, t):\n",
    "#     n, f, T, LR, ori, goal, data, aggr, args1 = args\n",
    "#     def v(x):\n",
    "#         grads = []\n",
    "#         for k in range(n):\n",
    "#             grads.append(2 * (x[t] - data[k]))\n",
    "#         return x[t + 1] - (x[t] - LR * agg(self.agg, grads))\n",
    "#     return v\n",
    "\n",
    "res = minimize(fun_lower(args), g0, method='Nelder-Mead')\n",
    "print(res.fun)\n",
    "print(res.success)\n",
    "print(res.x)\n",
    "grads = []\n",
    "for t in range(T):\n",
    "    grads.append(np.array([res.x[2*t], res.x[2*t+1]]))\n",
    "print(grads)\n"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "start_time": "2023-04-17T12:08:34.268396Z",
     "end_time": "2023-04-17T12:08:35.133394Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "12 5 8 0.5 [0 0] [2 2] [array([2.18919433, 2.1889828 ]), array([1.72742965, 2.43838245]), array([2.0969758 , 1.51393248]), array([1.31970575, 2.53694255]), array([2.30125713, 1.95114008]), array([1.38611146, 2.40500619]), array([2.97932588, 0.96561345]), array([2.00205083, 1.63445234]), array([1.95190893, 1.58818792]), array([1.67720725, 2.48391858]), array([3.14872776, 2.20074892]), array([1.76087083, 2.05553847])] krum {'n': 12, 'f': 5, 'm': 1, 'clip_tau': 10.0, 'clip_scaling': 'linear', 'momentum': 0.0, 'mimic-warmup': 1, 'last': -1}\n"
     ]
    }
   ],
   "source": [
    "n, f, T, LR, ori, goal, data, aggr, args1 = args\n",
    "print(n, f, T, LR, ori, goal, data, aggr, args1)\n",
    "\n",
    "thetas = []\n",
    "gradients = []\n",
    "parameters = []\n",
    "theta = ori\n",
    "glb = None\n",
    "for t in range(T):\n",
    "    params = []\n",
    "    for k in range(n - f):\n",
    "        params.append(theta - data[k])\n",
    "    for k in range(f):\n",
    "        params.append(grads[t])\n",
    "    parameters.append(params)\n",
    "    gradient, glb = agg(aggr, params, args1, glb)\n",
    "    gradients.append(gradient)\n",
    "    theta = theta - LR * gradient\n",
    "    thetas.append(theta)"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "start_time": "2023-04-17T12:08:37.011961Z",
     "end_time": "2023-04-17T12:08:37.033979Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "outputs": [
    {
     "data": {
      "text/plain": "[array([0.86371482, 1.21919122]),\n array([1.29557224, 1.82878684]),\n array([1.29557224, 1.82878684]),\n array([1.29557224, 1.82878684]),\n array([1.29557224, 1.82878684]),\n array([1.29557224, 1.82878684]),\n array([1.29557224, 1.82878684]),\n array([1.29557224, 1.82878684])]"
     },
     "execution_count": 49,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "thetas"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "start_time": "2023-04-17T12:08:38.280055Z",
     "end_time": "2023-04-17T12:08:38.290166Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "outputs": [
    {
     "data": {
      "text/plain": "[array([-1.72742965, -2.43838245]),\n array([-0.86371482, -1.21919122]),\n array([0., 0.]),\n array([0., 0.]),\n array([0., 0.]),\n array([0., 0.]),\n array([0., 0.]),\n array([0., 0.])]"
     },
     "execution_count": 50,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "gradients"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "start_time": "2023-04-17T12:08:39.957951Z",
     "end_time": "2023-04-17T12:08:39.981468Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 188,
   "outputs": [
    {
     "data": {
      "text/plain": "[[array([-2.04560236, -3.10227814]),\n  array([-2.54564137, -2.39741382]),\n  array([-1.02651485, -2.48821055]),\n  array([-1.30682523, -1.40828643]),\n  array([-0.85175421, -2.95818181]),\n  array([660.21528374, -78.4716266 ]),\n  array([660.21528374, -78.4716266 ]),\n  array([660.21528374, -78.4716266 ])],\n [array([-4.13185995, -1.60468236]),\n  array([-4.63189895, -0.89981804]),\n  array([-3.11277243, -0.99061477]),\n  array([-3.39308282,  0.08930935]),\n  array([-2.9380118 , -1.46058603]),\n  array([138.77599701, -42.33461917]),\n  array([138.77599701, -42.33461917]),\n  array([138.77599701, -42.33461917])],\n [array([-5.24582103, -0.27096652]),\n  array([-5.74586003,  0.4338978 ]),\n  array([-4.22673351,  0.34310107]),\n  array([-4.5070439 ,  1.42302519]),\n  array([-4.05197288, -0.12687019]),\n  array([114.92693056, -48.60325924]),\n  array([114.92693056, -48.60325924]),\n  array([114.92693056, -48.60325924])],\n [array([-5.682762  ,  0.70382954]),\n  array([-6.18280101,  1.40869386]),\n  array([-4.66367449,  1.31789713]),\n  array([-4.94398488,  2.39782125]),\n  array([-4.48891385,  0.84792587]),\n  array([122.08255858, -61.76175166]),\n  array([122.08255858, -61.76175166]),\n  array([122.08255858, -61.76175166])],\n [array([-5.79318776,  1.38385119]),\n  array([-6.29322676,  2.08871551]),\n  array([-4.77410024,  1.99791878]),\n  array([-5.05441063,  3.0778429 ]),\n  array([-4.59933961,  1.52794752]),\n  array([113.37541773, -64.45852986]),\n  array([113.37541773, -64.45852986]),\n  array([113.37541773, -64.45852986])],\n [array([-5.76744206,  1.85274649]),\n  array([-6.26748107,  2.55761081]),\n  array([-4.74835455,  2.46681408]),\n  array([-5.02866494,  3.5467382 ]),\n  array([-4.57359392,  1.99684282]),\n  array([101.60507382, -62.38402056]),\n  array([101.60507382, -62.38402056]),\n  array([101.60507382, -62.38402056])],\n [array([-5.69673158,  2.17420199]),\n  array([-6.19677059,  2.87906631]),\n  array([-4.67764407,  2.78826958]),\n  array([-4.95795446,  3.8681937 ]),\n  array([-4.50288344,  2.31829832]),\n  array([101.70611256, -65.55077017]),\n  array([101.70611256, -65.55077017]),\n  array([101.70611256, -65.55077017])],\n [array([-5.62190771,  2.39223512]),\n  array([-6.12194672,  3.09709944]),\n  array([-4.60282019,  3.00630271]),\n  array([-4.88313058,  4.08622683]),\n  array([-4.42805956,  2.53633145]),\n  array([ 73.33092827, -48.81445777]),\n  array([ 73.33092827, -48.81445777]),\n  array([ 73.33092827, -48.81445777])]]"
     },
     "execution_count": 188,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "parameters"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 189,
   "outputs": [
    {
     "data": {
      "text/plain": "(array([ 2.00293208, -2.52630131]), array([ 2.00293208, -2.52630131]))"
     },
     "execution_count": 189,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "agg('cp', parameters[1],args1)"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 174,
   "outputs": [],
   "source": [
    "from codes.aggregators.clipping import Clipping"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 175,
   "outputs": [],
   "source": [
    "cl = Clipping(tau=args1['clip_tau'], n_iter=3)"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 177,
   "outputs": [
    {
     "data": {
      "text/plain": "tensor([ 2.0715, -2.3267], dtype=torch.float64)"
     },
     "execution_count": 177,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "inputs = []\n",
    "for p in parameters[1]:\n",
    "    inputs.append(torch.from_numpy(p))\n",
    "cl.__call__(inputs)"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 178,
   "outputs": [
    {
     "data": {
      "text/plain": "tensor([ 4.3015, -2.9161], dtype=torch.float64)"
     },
     "execution_count": 178,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "inputs = []\n",
    "for p in parameters[0]:\n",
    "    inputs.append(torch.from_numpy(p))\n",
    "cl.__call__(inputs)"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 181,
   "outputs": [
    {
     "data": {
      "text/plain": "tensor([ 4.3015, -2.9161], dtype=torch.float64)"
     },
     "execution_count": 181,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "cl.momentum"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [],
   "metadata": {
    "collapsed": false
   }
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
