{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "KeyG3z0UIccS"
   },
   "source": [
    "# Basic Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "2khWiVyyhJPv"
   },
   "outputs": [],
   "source": [
    "# Colab setup: install the bfm package (provides w2.BFM) and POT (imported as ot).\n",
    "!git clone https://github.com/Math-Jacobs/bfm\n",
    "%pip install bfm/python\n",
    "%pip install pot"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "udUUh02Phe6o"
   },
   "outputs": [],
   "source": [
    "from google.colab import drive\n",
    "drive.mount('/content/drive', force_remount = True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "2VKyTzbvhdQ4"
   },
   "outputs": [],
   "source": [
    "import sys\n",
    "# Absolute path so the project modules stored on Drive resolve regardless of the working directory.\n",
    "sys.path.append('/content/drive/MyDrive/WDHA')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "UweIx0jKchjt"
   },
   "outputs": [],
   "source": [
    "from time import time\n",
    "import numpy as np\n",
    "import numpy.ma as ma\n",
    "from scipy.stats import norm\n",
    "import scipy.integrate as integrate\n",
    "import gc\n",
    "from scipy.stats import multivariate_normal\n",
    "from matplotlib.colors import LinearSegmentedColormap\n",
    "# NOTE(review): these star imports presumably provide plotting, frechet_mean,\n",
    "# frechet_mean_pot(_debiased), avgw2, w2dist and l2dist used below; explicit\n",
    "# imports would make it clear which module each helper comes from.\n",
    "from metric import *\n",
    "from functions import *\n",
    "\n",
    "import ot\n",
    "from scipy.fftpack import dctn, idctn\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "from w2 import BFM\n",
    "%config InlineBackend.figure_format = 'retina'\n",
    "plt.rcParams['figure.figsize'] = (13, 8)\n",
    "plt.rcParams['image.cmap'] = 'viridis'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "cGwQ20kbcws5"
   },
   "source": [
    "## Dataset and settings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "ixiLFyA1coaK"
   },
   "outputs": [],
   "source": [
    "n1, n2 = 1024, 1024\n",
    "# Cell-centred grid on the unit square; arrays have shape (n2, n1), rows follow y and columns follow x.\n",
    "x, y = np.meshgrid(np.linspace(0.5/n1,1-0.5/n1,n1),\n",
    "                    np.linspace(0.5/n2,1-0.5/n2,n2))\n",
    "func1 = 1/2 * (x**2 + y**2)\n",
    "\n",
    "# Ball example: four discs of radius r near the corners; `ans` is the disc at the\n",
    "# centre, used below as the ground-truth barycenter.\n",
    "r = 0.15\n",
    "\n",
    "cir1, cir2, cir3, cir4, ans = np.zeros((n2, n1)), np.zeros((n2, n1)),np.zeros((n2, n1)),np.zeros((n2, n1)), np.zeros((n2, n1))\n",
    "cir1[(x-0.2)**2 + (y-0.2)**2 < r**2] = 1\n",
    "cir2[(x-0.2)**2 + (y-0.8)**2 < r**2] = 1\n",
    "cir3[(x-0.8)**2 + (y-0.2)**2 < r**2] = 1\n",
    "cir4[(x-0.8)**2 + (y-0.8)**2 < r**2] = 1\n",
    "ans[(x-0.5)**2 + (y-0.5)**2 < r**2] = 1\n",
    "\n",
    "# Rescale every density so it sums to n1*n2 (mean value 1 per pixel).\n",
    "cir1 *= n1 * n2 / np.sum(cir1)\n",
    "cir2 *= n1 * n2 / np.sum(cir2)\n",
    "cir3 *= n1 * n2 / np.sum(cir3)\n",
    "cir4 *= n1 * n2 / np.sum(cir4)\n",
    "ans *= n1 * n2 / np.sum(ans)\n",
    "circles4 = [cir1,cir2,cir3,cir4]\n",
    "\n",
    "# Shape example: disc, cross, heart and square.\n",
    "r = 0.1\n",
    "# Initialize densities\n",
    "mu1 = np.zeros((n2, n1))\n",
    "mu1[(x-0.8)**2 + (y-0.8)**2 < r**2] = 1\n",
    "# mu2: a cross built from a vertical and a horizontal bar, both centred at (0.8, 0.3)\n",
    "mu2 = np.zeros((n2, n1))\n",
    "mu2[(0.8-r/2.5<x) & (x<0.8+r/2.5) & (0.3-r < y) & (y < 0.3+r)] = 1\n",
    "mu2[(0.3-r/2.5<y) & (y<0.3+r/2.5) & (0.8-r < x) & (x < 0.8+r)] = 1\n",
    "\n",
    "# Normalize\n",
    "mu1 *= n1*n2 / np.sum(mu1)\n",
    "mu2 *= n1*n2 / np.sum(mu2)\n",
    "\n",
    "# Heart: interior of (X^2 + Y^2 - 1)^3 - X^2 Y^3 < 0 with X = 10x - 2, Y = 10(y - 0.3)\n",
    "heart = np.zeros((n2, n1))\n",
    "heart[((10*x-2)**2+(10*(y-0.3))**2-1)**3-(10*x-2)**2*(10*(y-0.3))**3<0] = 1\n",
    "heart *= n1 * n2 / np.sum(heart)\n",
    "\n",
    "rectangle = np.zeros((n2, n1))\n",
    "rectangle[(x<0.3) & (x > 0.1) & (y>0.7) & (y<0.9)] = 1\n",
    "rectangle *= n1*n2 / np.sum(rectangle)\n",
    "\n",
    "mu = [mu1,mu2, heart, rectangle]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true,
    "id": "ijc0IbN5hJP0"
   },
   "source": [
    "# Shape Example"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "M3rzajjFhJP0"
   },
   "outputs": [],
   "source": [
    "plotting(mu, np.zeros((n2,n1)),'_',save_option = False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "8mCdEwAQhJP1"
   },
   "source": [
    "## WDHA"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "BLtIDJDGhJP1"
   },
   "outputs": [],
   "source": [
    "mu_WGHA = frechet_mean(mu, 300, 'MU', save_option = False, return_option = True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "FQZmpOHXhJP1"
   },
   "source": [
    "## CWB"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "uUldpwcnhJP2"
   },
   "outputs": [],
   "source": [
    "mu_CWB = frechet_mean_pot(mu, 5e-3,'MU',print_option=True, plot_option=True, save_option=False, return_option = True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "TNcm_lGHhJP2"
   },
   "source": [
    "## DSB"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "G3wrWi3_Uhmp"
   },
   "outputs": [],
   "source": [
    "mu_DSB = frechet_mean_pot_debiased(mu, 5e-3,'MU',print_option=True, plot_option=True, save_option=False, return_option = True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "LVrMsiY4hJP2"
   },
   "source": [
    "## Analysis"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "9Np2RSqUhJP3"
   },
   "source": [
    "### Average $W_2$ distance between given distributions and barycenter"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "yNg0YVwDhJP3"
   },
   "outputs": [],
   "source": [
    "print('WGHA:', avgw2(mu,mu_WGHA), '| CWB:', avgw2(mu,mu_CWB), '| DSB:', avgw2(mu,mu_DSB))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true,
    "id": "C5XgeUpkhJP3"
   },
   "source": [
    "# Ball Examples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "NG3e9rIMhJP3"
   },
   "outputs": [],
   "source": [
    "plotting(circles4, np.zeros((n2,n1)),'_',save_option = False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "8bZ3TCyKhJP3"
   },
   "source": [
    "## WGHA"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "ORcxg7GyhJP3"
   },
   "outputs": [],
   "source": [
    "Ball4_WGHA = frechet_mean(circles4, 300, '4circles', save_option = False, return_option = True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "0UbBw1_KhJP4"
   },
   "source": [
    "## CWB"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "aNHik7eUhJP4"
   },
   "outputs": [],
   "source": [
    "Ball4_CWB = frechet_mean_pot(circles4, 5e-3,'4circles',print_option=True, plot_option=True, save_option=False, return_option = True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "7XATJez6hJP4"
   },
   "source": [
    "## DSB"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Mqg008JnhJP4"
   },
   "outputs": [],
   "source": [
    "Ball4_DSB = frechet_mean_pot_debiased(circles4, 5e-3,'4circles',print_option=True, plot_option=True, save_option=False, return_option = True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true,
    "id": "wZgu4A2shJP4"
   },
   "source": [
    "## Analysis"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "hgk-CbwWhJP4"
   },
   "source": [
    "### $W_2$ distance from ground truth"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "NljPyYjhJP5"
   },
   "outputs": [],
   "source": [
    "print('WGHA:', w2dist(Ball4_WGHA,ans), '| CWB:', w2dist(Ball4_CWB,ans), '| DSB:', w2dist(Ball4_DSB,ans))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "ZOiRtZT7hJP5"
   },
   "source": [
    "### $L_2$ distance from ground truth"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "beJ2Z0RZhJP5"
   },
   "outputs": [],
   "source": [
    "print('WGHA:', l2dist(Ball4_WGHA,ans), '| CWB:', l2dist(Ball4_CWB,ans), '| DSB:', l2dist(Ball4_DSB,ans))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "QMtVkWb9hJP5"
   },
   "source": [
    "### Average $W_2$ distance from given distributions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "eCFJwWNmhJP5"
   },
   "outputs": [],
   "source": [
    "print('WGHA:', avgw2(circles4,Ball4_WGHA), '| CWB:', avgw2(circles4,Ball4_CWB), '| DSB:', avgw2(circles4,Ball4_DSB))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "eXz5IkwIoo_g"
   },
   "source": [
    "# MNIST Examples"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "JI3rpAPuYyD3"
   },
   "source": [
    "Before you start, download **Images(500x500).npy** and **WriterInfo.npy** from https://drive.google.com/drive/folders/1f2o1kjXLvcxRgtmMMuDkA2PQ5Zato4Or and place them in `My Drive/WDHA`, which is where the next cell loads them from."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "LtU8kVELMEd_"
   },
   "outputs": [],
   "source": [
    "Images = np.load('/content/drive/My Drive/WDHA/Images(500x500).npy')\n",
    "WriterInfo = np.load('/content/drive/My Drive/WDHA/WriterInfo.npy')\n",
    "digit = WriterInfo[:,0]  # presumably the digit label of each image (filtered with == 8 below)\n",
    "user = WriterInfo[:,-1]\n",
    "num_image = 100\n",
    "num_iter = 300\n",
    "# Invert intensities so the strokes carry the mass (assumes dark strokes on a light 0-255 background - TODO confirm).\n",
    "numbers8 = 255 - Images[(digit == 8)][:num_image].astype('float64')\n",
    "\n",
    "# Rescale each image in place to total mass = its number of pixels, matching the synthetic examples.\n",
    "for img in numbers8:\n",
    "    img /= np.sum(img)\n",
    "    img *= img.size\n",
    "del Images, WriterInfo, user, digit"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "tVZOELBOpHzz"
   },
   "outputs": [],
   "source": [
    "plt.imshow(numbers8[0], origin = 'lower');"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Lv1Nu__evTwW"
   },
   "outputs": [],
   "source": [
    "def plotting_mnist(dist, name, save_option = False):\n",
    "    '''Display a 2-D density (e.g. an MNIST barycenter) with a fixed white-to-blue colormap.\n",
    "\n",
    "    Args:\n",
    "        dist: 2-D array to display; drawn with origin='lower'.\n",
    "        name: file stem used when saving (written as '<name>.jpg').\n",
    "        save_option: if True, save the figure before showing it.\n",
    "    '''\n",
    "    colors = [\n",
    "        (1.0, 1.0, 1.0),  # white (background)\n",
    "        (0.0, 0.0, 0.8),  # vivid blue\n",
    "        (0.0, 0.0, 0.7),  # darker blue\n",
    "        (0.0, 0.0, 0.6),  # darkest blue\n",
    "    ]\n",
    "    custom_cmap = LinearSegmentedColormap.from_list('custom_cmap', colors)\n",
    "    # Fixed colour range so every MNIST barycenter is drawn on the same scale.\n",
    "    vmin, vmax = 0, 70\n",
    "    plt.imshow(dist, cmap=custom_cmap, origin = 'lower', vmin = vmin, vmax = vmax)\n",
    "\n",
    "    # Label the axes in unit-square coordinates instead of pixel indices.\n",
    "    plt.xticks([0, plt.gca().get_xlim()[1]], ['0', '1'])\n",
    "    plt.yticks([0, plt.gca().get_ylim()[1]], ['0', '1'])\n",
    "    if save_option:\n",
    "        plt.savefig('%s.jpg'%(name))\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "NRylHpJoeRO-"
   },
   "outputs": [],
   "source": [
    "mean_WGHA = frechet_mean(numbers8, num_iter, 'mnist', plot_option = False, save_option = False, return_option = True)\n",
    "plotting_mnist(mean_WGHA, '8WGHA')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Xc2pW7rIpMQ3"
   },
   "outputs": [],
   "source": [
    "tic = time()\n",
    "weights = np.array([1/num_image] * num_image)\n",
    "# stopThr=0.0 disables early stopping so the solver runs num_iter iterations.\n",
    "mean_CWB = ot.bregman.convolutional_barycenter2d(numbers8, 5*10**(-3), weights, numItermax = num_iter,stopThr=0.0)\n",
    "# Stop the clock before plotting so only the barycenter solve is timed.\n",
    "toc = time()\n",
    "plotting_mnist(mean_CWB, '8CWB', save_option = False)\n",
    "print(toc-tic)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "qyMfJMs2ITCT"
   },
   "outputs": [],
   "source": [
    "tic = time()\n",
    "weights = np.array([1/num_image] * num_image)\n",
    "# stopThr=0.0 disables early stopping so the solver runs num_iter iterations.\n",
    "mean_DSB = ot.bregman.convolutional_barycenter2d_debiased(numbers8, 5*10**(-3), weights, numItermax = num_iter,stopThr=0.0)\n",
    "# Stop the clock before plotting so only the barycenter solve is timed.\n",
    "toc = time()\n",
    "plotting_mnist(mean_DSB, '8DSB', save_option = False)\n",
    "print(toc-tic)"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "gpuType": "T4",
   "machine_shape": "hm",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}