{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# **Library Imports**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 79,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import seaborn as sns\n",
    "import random\n",
    "import math\n",
    "import statistics\n",
    "import operator\n",
    "from scipy.stats import norm\n",
    "from collections import Counter as counter\n",
    "from sklearn.linear_model import LinearRegression\n",
    "from sklearn.linear_model import BayesianRidge\n",
    "from sklearn.ensemble import RandomForestRegressor\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from sklearn.metrics import mean_squared_error\n",
    "\n",
    "from scipy.stats import wasserstein_distance as wass_d"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Mergeable Misra-Gries [[Paper](https://arxiv.org/abs/1705.07001)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MergableMisraGries(dict):\n",
    "    #k: counters\n",
    "    #n: number of elments seen\n",
    "\n",
    "\n",
    "    def __init__(self, k=2, kstar = -1, n=0):\n",
    "        super().__init__()\n",
    "        self.k = k\n",
    "        if kstar == -1:\n",
    "          self.kstar =  k/2\n",
    "        if kstar != k/2 and kstar != -1:\n",
    "          print(\"Warning: Choosing other threshold than median!\")\n",
    "        self.kstar = kstar\n",
    "        self.n = n\n",
    "\n",
    "    def _decrement(self):\n",
    "        \"\"\"Decrement all elements and remove zeros.\"\"\"\n",
    "        c = statistics.median(list(self.values()))\n",
    "        for element in list(self):\n",
    "                self[element] -= c\n",
    "                if self[element] <= 0:\n",
    "                    del self[element]\n",
    "        return c\n",
    "\n",
    "    def record(self, element):\n",
    "        \"\"\"Record a stream element.\"\"\"\n",
    "        self.n += 1\n",
    "        # print(element)\n",
    "        if element[0] in self:\n",
    "            self[element[0]] += element[1]\n",
    "        elif len(self) < self.k:\n",
    "            self[element[0]] = element[1]\n",
    "        else:\n",
    "            c = self._decrement()\n",
    "            #  print(c)\n",
    "            #  print(self.keys())\n",
    "            if element[1] > c:\n",
    "                self[element[0]] = element[1]\n",
    "\n",
    "\n",
    "    def merge(self, other):\n",
    "        \"\"\"Merge two together to get a new one.\"\"\"\n",
    "        new = self\n",
    "        for element in list(other.keys()):\n",
    "          new.record([element,other[element]])\n",
    "        return new\n",
    "\n",
    "    def __add__(self, other):\n",
    "        return self.merge(other)\n",
    "\n",
    "    def scanunweighted(self, stream):\n",
    "        \"\"\"Scan an entire array>\"\"\"\n",
    "        for element in stream:\n",
    "            self.record([element,1])\n",
    "        return self\n",
    "\n",
    "    def scanweighted(self, stream):\n",
    "        \"\"\"Scan an entire array>\"\"\"\n",
    "        for element in stream:\n",
    "            self.record(element)\n",
    "        return self\n",
    "\n",
    "    def output(self):\n",
    "        \"\"\"Every item which occurs more than n/k times is guaranteed to appear in the output array.\"\"\"\n",
    "        nk = self.n / self.k\n",
    "        return [key for key, count in self.items() if count > nk]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# SWA and STVA"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [],
   "source": [
    "class SWA(MergableMisraGries):\n",
    "    \"\"\"Misra-Gries summary over ordered keys, used to estimate a CDF and Wasserstein distances.\"\"\"\n",
    "\n",
    "    def estimateCDF(self, element):\n",
    "        \"\"\"Share of the summary's total count held by keys <= element.\"\"\"\n",
    "        total = sum(self.values())\n",
    "        at_or_below = sum(count for key, count in self.items() if key <= element)\n",
    "        return at_or_below / total\n",
    "\n",
    "    def wasserstein(self, other):\n",
    "        \"\"\"First Wasserstein distance between two summaries, using their counts as weights.\"\"\"\n",
    "        return wass_d(\n",
    "            u_values=list(self.keys()),\n",
    "            v_values=list(other.keys()),\n",
    "            u_weights=list(self.values()),\n",
    "            v_weights=list(other.values()),\n",
    "        )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "metadata": {},
   "outputs": [],
   "source": [
    "class STVA(MergableMisraGries):\n",
    "    \"\"\"Misra-Gries summary used to estimate a probability mass function and TV distances.\"\"\"\n",
    "\n",
    "    def estimatePDF(self, element):\n",
    "        \"\"\"Estimated probability of element: its count over the total count (0 if untracked).\"\"\"\n",
    "        if element in self:\n",
    "            return self[element] / sum(self.values())\n",
    "        return 0\n",
    "\n",
    "    def TV(self, other):\n",
    "        \"\"\"Estimated total variation distance between the two normalized summaries.\n",
    "\n",
    "        Computes the sum over self's keys of max(p(x) - q(x), 0); keys present only in\n",
    "        other contribute nothing because p(x) = 0 there. Both totals are computed once\n",
    "        rather than on every estimatePDF() call, which made the loop quadratic.\n",
    "        \"\"\"\n",
    "        self_total = sum(self.values())\n",
    "        other_total = sum(other.values())\n",
    "        tv = 0\n",
    "        for key, count in self.items():\n",
    "            q = other[key] / other_total if key in other else 0\n",
    "            diff = count / self_total - q\n",
    "            if diff > 0:\n",
    "                tv += diff\n",
    "        return tv"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Trials"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Stream():\n",
    "  def __init__(self, n = 100000):\n",
    "    self.stream = []\n",
    "    self.n = n\n",
    "\n",
    "  def generatenormal(self, mu = 0, sigma = 1):\n",
    "    self.stream = np.random.normal(mu, sigma, self.n)\n",
    "\n",
    "  def input(data):\n",
    "    self.stream = data\n",
    "\n",
    "  def discretize(self, disc = 0.05):\n",
    "    m = []\n",
    "    for element in self.stream:\n",
    "      q = math.floor(element/disc)\n",
    "      m.append(q*disc+disc/2)\n",
    "    self.stream = m\n",
    "\n",
    "  def histogram(self):\n",
    "    self.hist = counter(self.stream)\n",
    "    return self.hist"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 103,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2.371705\n",
      "0.24114000000000002\n",
      "0.1974126513658474\n"
     ]
    }
   ],
   "source": [
    "# Experiment parameters, kept together so the run is easy to tweak and reproduce.\n",
    "SEED = 0\n",
    "N_SAMPLES = 100000\n",
    "BIN_WIDTH = 0.5\n",
    "MU1, SIGMA1 = 0, 4\n",
    "MU2, SIGMA2 = 2, 6\n",
    "K_SWA = 100      # counters for the Wasserstein sketch\n",
    "K_STVA = 10000   # counters for the total-variation sketch\n",
    "\n",
    "# Seed numpy's global RNG, which np.random.normal (used by Stream.generatenormal) draws from.\n",
    "np.random.seed(SEED)\n",
    "\n",
    "stream1 = Stream(N_SAMPLES)\n",
    "stream1.generatenormal(MU1, SIGMA1)\n",
    "stream1.discretize(BIN_WIDTH)\n",
    "stream2 = Stream(N_SAMPLES)\n",
    "stream2.generatenormal(MU2, SIGMA2)\n",
    "stream2.discretize(BIN_WIDTH)\n",
    "\n",
    "# Sketch-based estimates.\n",
    "SWA1 = SWA(K_SWA).scanunweighted(stream1.stream)\n",
    "SWA2 = SWA(K_SWA).scanunweighted(stream2.stream)\n",
    "STVA1 = STVA(K_STVA).scanunweighted(stream1.stream)\n",
    "STVA2 = STVA(K_STVA).scanunweighted(stream2.stream)\n",
    "\n",
    "# Exact references from the full discretized streams; valid for any mu/sigma.\n",
    "hist1 = stream1.histogram()\n",
    "hist2 = stream2.histogram()\n",
    "n1, n2 = sum(hist1.values()), sum(hist2.values())\n",
    "exact_tv = 0.5 * sum(\n",
    "    abs(hist1[b] / n1 - hist2[b] / n2) for b in set(hist1) | set(hist2)\n",
    ")\n",
    "exact_w1 = wass_d(stream1.stream, stream2.stream)\n",
    "\n",
    "print(f\"Wasserstein  sketch: {SWA1.wasserstein(SWA2):.4f}  exact: {exact_w1:.4f}\")\n",
    "print(f\"TV           sketch: {STVA1.TV(STVA2):.4f}  exact: {exact_tv:.4f}\")\n",
    "\n",
    "# Closed form TV(N(mu1, s^2), N(mu2, s^2)) = 2*Phi(|mu1 - mu2| / (2s)) - 1 only holds\n",
    "# when both streams share the same standard deviation.\n",
    "if SIGMA1 == SIGMA2:\n",
    "    closed_form_tv = 2 * norm.cdf(abs(MU1 - MU2) / (2 * SIGMA1)) - 1\n",
    "    print(f\"TV closed form (equal sigma): {closed_form_tv:.4f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Synthetic Experiments"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def TVGaussian(m1,m2):"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
