{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "85a43e0a-446e-4563-adc6-1be71891a5e5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import importlib\n",
    "import os.path\n",
    "import sys\n",
    "\n",
    "# Put ../src first on the import path (presumably where the local FRLC module lives).\n",
    "sys.path.insert(0,'../src')\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "import torchdyn\n",
    "from torchdyn.datasets import generate_moons\n",
    "from scipy.sparse import csgraph\n",
    "from scipy.spatial import distance\n",
    "import scipy\n",
    "import matplotlib.pyplot as plt\n",
    "import ott\n",
    "from ott.problems.linear import linear_problem\n",
    "from ott.problems.quadratic import quadratic_problem\n",
    "from ott.solvers.linear import sinkhorn, sinkhorn_lr\n",
    "from ott.solvers.quadratic import gromov_wasserstein_lr\n",
    "\n",
    "import FRLC\n",
    "# Reload FRLC so that edits made to the module while the kernel is alive are picked up.\n",
    "importlib.reload(FRLC)\n",
    "\n",
    "# Load device, set dtype\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "dtype = torch.float64\n",
    "\n",
    "def Gaussian_mixture_three_10D(n_samples):\n",
    "    # Define the means of the three Gaussians\n",
    "    means = np.array([\n",
    "    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],  # Mean of the first Gaussian\n",
    "    [0, 1, 0, 0, 0, 0, 0, 0, 0, 0],  # Mean of the second Gaussian\n",
    "    [1, 1, 0, 0, 0, 0, 0, 0, 0, 0]   # Mean of the third Gaussian\n",
    "    ])\n",
    "\n",
    "    # Shared covariance matrix\n",
    "    covariance = 0.05 * np.eye(10)\n",
    "\n",
    "    # Mixing proportions (assuming equal proportions for simplicity)\n",
    "    proportions = np.array([1/3, 1/3, 1/3])  # Sum must be 1\n",
    "\n",
    "    # Initialize an array to hold the sampled points\n",
    "    samples = np.zeros((n_samples, 10))\n",
    "\n",
    "    # Generate samples\n",
    "    for i in range(n_samples):\n",
    "        # Choose a Gaussian distribution based on mixing proportions\n",
    "        gaussian_index = np.random.choice([0, 1, 2], p=proportions)\n",
    "\n",
    "        # Sample from the chosen Gaussian\n",
    "        samples[i, :] = np.random.multivariate_normal(means[gaussian_index], covariance)\n",
    "\n",
    "    return samples\n",
    "\n",
    "def Gaussian_mixture_two_10D(n_samples):\n",
    "    # Define the means of the three Gaussians\n",
    "    means = np.array([\n",
    "    [0.5, 0.5, 0, 0, 0, 0, 0, 0, 0, 0],  # Mean of the first Gaussian\n",
    "    [-0.5, 0.5, 0, 0, 0, 0, 0, 0, 0, 0]  # Mean of the second Gaussian\n",
    "    ])\n",
    "\n",
    "    # Shared covariance matrix\n",
    "    covariance = 0.05 * np.eye(10)  # 2D identity matrix multiplied by 0.05\n",
    "\n",
    "    # Mixing proportions (assuming equal proportions for simplicity)\n",
    "    proportions = np.array([1/2, 1/2])  # Sum must be 1\n",
    "\n",
    "    # Initialize an array to hold the sampled points\n",
    "    samples = np.zeros((n_samples, 10))\n",
    "\n",
    "    # Generate samples\n",
    "    for i in range(n_samples):\n",
    "        # Choose a Gaussian distribution based on mixing proportions\n",
    "        gaussian_index = np.random.choice([0, 1], p=proportions)\n",
    "\n",
    "        # Sample from the chosen Gaussian\n",
    "        samples[i, :] = np.random.multivariate_normal(means[gaussian_index], covariance)\n",
    "\n",
    "    return samples\n",
    "\n",
    "def Gaussian_mixture_three_2D(n_samples):\n",
    "    # Define the means of the three Gaussians\n",
    "    means = np.array([[0, 0], [0, 1], [1, 1]])  # Example means\n",
    "\n",
    "    # Shared covariance matrix\n",
    "    covariance = 0.05 * np.eye(2)  # 2D identity matrix multiplied by 0.05\n",
    "\n",
    "    # Mixing proportions (assuming equal proportions for simplicity)\n",
    "    proportions = np.array([1/3, 1/3, 1/3])  # Sum must be 1\n",
    "\n",
    "    # Initialize an array to hold the sampled points\n",
    "    samples = np.zeros((n_samples, 2))\n",
    "\n",
    "    # Generate samples\n",
    "    for i in range(n_samples):\n",
    "        # Choose a Gaussian distribution based on mixing proportions\n",
    "        gaussian_index = np.random.choice([0, 1, 2], p=proportions)\n",
    "\n",
    "        # Sample from the chosen Gaussian\n",
    "        samples[i, :] = np.random.multivariate_normal(means[gaussian_index], covariance)\n",
    "\n",
    "    return samples\n",
    "\n",
    "def Gaussian_mixture_two_2D(n_samples):\n",
    "    # Define the means of the two Gaussians\n",
    "    means = np.array([[0.5, 0.5], [-0.5, 0.5]])  # Example means\n",
    "\n",
    "    # Shared covariance matrix\n",
    "    covariance = 0.05 * np.eye(2)  # 2D identity matrix multiplied by 0.05\n",
    "\n",
    "    # Mixing proportions (assuming equal proportions for simplicity)\n",
    "    proportions = np.array([1/2, 1/2])  # Sum must be 1\n",
    "\n",
    "    # Initialize an array to hold the sampled points\n",
    "    samples = np.zeros((n_samples, 2))\n",
    "\n",
    "    # Generate samples\n",
    "    for i in range(n_samples):\n",
    "        # Choose a Gaussian distribution based on mixing proportions\n",
    "        gaussian_index = np.random.choice([0, 1], p=proportions)\n",
    "\n",
    "        # Sample from the chosen Gaussian\n",
    "        samples[i, :] = np.random.multivariate_normal(means[gaussian_index], covariance)\n",
    "\n",
    "    return samples\n",
    "\n",
    "import math \n",
    "\n",
    "def eight_normal_sample(n, dim, scale=1, var=1):\n",
    "    m = torch.distributions.multivariate_normal.MultivariateNormal(\n",
    "        torch.zeros(dim), math.sqrt(var) * torch.eye(dim)\n",
    "    )\n",
    "    centers = [\n",
    "        (1, 0),\n",
    "        (-1, 0),\n",
    "        (0, 1),\n",
    "        (0, -1),\n",
    "        (1.0 / np.sqrt(2), 1.0 / np.sqrt(2)),\n",
    "        (1.0 / np.sqrt(2), -1.0 / np.sqrt(2)),\n",
    "        (-1.0 / np.sqrt(2), 1.0 / np.sqrt(2)),\n",
    "        (-1.0 / np.sqrt(2), -1.0 / np.sqrt(2)),\n",
    "    ]\n",
    "    centers = torch.tensor(centers) * scale\n",
    "    noise = m.sample((n,))\n",
    "    multi = torch.multinomial(torch.ones(8), n, replacement=True)\n",
    "    data = []\n",
    "    for i in range(n):\n",
    "        data.append(centers[multi[i]] + noise[i])\n",
    "    data = torch.stack(data)\n",
    "    return data\n",
    "\n",
    "\n",
    "def sample_moons(n):\n",
    "    \"\"\"Return ``n`` points from ``generate_moons`` (noise 0.2), mapped through ``3 * x - 1``.\"\"\"\n",
    "    points, _labels = generate_moons(n, noise=0.2)\n",
    "    return points * 3 - 1\n",
    "\n",
    "def sample_8gaussians(n):\n",
    "    \"\"\"Return ``n`` 2-D eight-Gaussians samples (scale 5, var 0.1) cast to float32.\"\"\"\n",
    "    samples = eight_normal_sample(n, 2, scale=5, var=0.1)\n",
    "    return samples.float()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "32f5cf2a-316b-40e5-b7ee-fefdfa50729b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def Gaussian_mixture_three_10D(n_samples):\n",
    "    # Define the means of the three Gaussians\n",
    "    means = np.array([\n",
    "    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],  # Mean of the first Gaussian\n",
    "    [0, 1, 0, 0, 0, 0, 0, 0, 0, 0],  # Mean of the second Gaussian\n",
    "    [1, 1, 0, 0, 0, 0, 0, 0, 0, 0]   # Mean of the third Gaussian\n",
    "    ])\n",
    "\n",
    "    # Shared covariance matrix\n",
    "    covariance = 0.05 * np.eye(10)\n",
    "\n",
    "    # Mixing proportions (assuming equal proportions for simplicity)\n",
    "    proportions = np.array([1/3, 1/3, 1/3])  # Sum must be 1\n",
    "\n",
    "    # Initialize an array to hold the sampled points\n",
    "    samples = np.zeros((n_samples, 10))\n",
    "\n",
    "    # Generate samples\n",
    "    for i in range(n_samples):\n",
    "        # Choose a Gaussian distribution based on mixing proportions\n",
    "        gaussian_index = np.random.choice([0, 1, 2], p=proportions)\n",
    "\n",
    "        # Sample from the chosen Gaussian\n",
    "        samples[i, :] = np.random.multivariate_normal(means[gaussian_index], covariance)\n",
    "\n",
    "    return samples\n",
    "\n",
    "def Gaussian_mixture_two_10D(n_samples):\n",
    "    # Define the means of the three Gaussians\n",
    "    means = np.array([\n",
    "    [0.5, 0.5, 0, 0, 0, 0, 0, 0, 0, 0],  # Mean of the first Gaussian\n",
    "    [-0.5, 0.5, 0, 0, 0, 0, 0, 0, 0, 0]  # Mean of the second Gaussian\n",
    "    ])\n",
    "\n",
    "    # Shared covariance matrix\n",
    "    covariance = 0.05 * np.eye(10)  # 2D identity matrix multiplied by 0.05\n",
    "\n",
    "    # Mixing proportions (assuming equal proportions for simplicity)\n",
    "    proportions = np.array([1/2, 1/2])  # Sum must be 1\n",
    "\n",
    "    # Initialize an array to hold the sampled points\n",
    "    samples = np.zeros((n_samples, 10))\n",
    "\n",
    "    # Generate samples\n",
    "    for i in range(n_samples):\n",
    "        # Choose a Gaussian distribution based on mixing proportions\n",
    "        gaussian_index = np.random.choice([0, 1], p=proportions)\n",
    "\n",
    "        # Sample from the chosen Gaussian\n",
    "        samples[i, :] = np.random.multivariate_normal(means[gaussian_index], covariance)\n",
    "\n",
    "    return samples\n",
    "\n",
    "def Gaussian_mixture_three_2D(n_samples):\n",
    "    # Define the means of the three Gaussians\n",
    "    means = np.array([[0, 0], [0, 1], [1, 1]])  # Example means\n",
    "\n",
    "    # Shared covariance matrix\n",
    "    covariance = 0.05 * np.eye(2)  # 2D identity matrix multiplied by 0.05\n",
    "\n",
    "    # Mixing proportions (assuming equal proportions for simplicity)\n",
    "    proportions = np.array([1/3, 1/3, 1/3])  # Sum must be 1\n",
    "\n",
    "    # Initialize an array to hold the sampled points\n",
    "    samples = np.zeros((n_samples, 2))\n",
    "\n",
    "    # Generate samples\n",
    "    for i in range(n_samples):\n",
    "        # Choose a Gaussian distribution based on mixing proportions\n",
    "        gaussian_index = np.random.choice([0, 1, 2], p=proportions)\n",
    "\n",
    "        # Sample from the chosen Gaussian\n",
    "        samples[i, :] = np.random.multivariate_normal(means[gaussian_index], covariance)\n",
    "\n",
    "    return samples\n",
    "\n",
    "def Gaussian_mixture_two_2D(n_samples):\n",
    "    # Define the means of the two Gaussians\n",
    "    means = np.array([[0.5, 0.5], [-0.5, 0.5]])  # Example means\n",
    "\n",
    "    # Shared covariance matrix\n",
    "    covariance = 0.05 * np.eye(2)  # 2D identity matrix multiplied by 0.05\n",
    "\n",
    "    # Mixing proportions (assuming equal proportions for simplicity)\n",
    "    proportions = np.array([1/2, 1/2])  # Sum must be 1\n",
    "\n",
    "    # Initialize an array to hold the sampled points\n",
    "    samples = np.zeros((n_samples, 2))\n",
    "\n",
    "    # Generate samples\n",
    "    for i in range(n_samples):\n",
    "        # Choose a Gaussian distribution based on mixing proportions\n",
    "        gaussian_index = np.random.choice([0, 1], p=proportions)\n",
    "\n",
    "        # Sample from the chosen Gaussian\n",
    "        samples[i, :] = np.random.multivariate_normal(means[gaussian_index], covariance)\n",
    "\n",
    "    return samples\n",
    "\n",
    "import math \n",
    "\n",
    "def eight_normal_sample(n, dim, scale=1, var=1):\n",
    "    m = torch.distributions.multivariate_normal.MultivariateNormal(\n",
    "        torch.zeros(dim), math.sqrt(var) * torch.eye(dim)\n",
    "    )\n",
    "    centers = [\n",
    "        (1, 0),\n",
    "        (-1, 0),\n",
    "        (0, 1),\n",
    "        (0, -1),\n",
    "        (1.0 / np.sqrt(2), 1.0 / np.sqrt(2)),\n",
    "        (1.0 / np.sqrt(2), -1.0 / np.sqrt(2)),\n",
    "        (-1.0 / np.sqrt(2), 1.0 / np.sqrt(2)),\n",
    "        (-1.0 / np.sqrt(2), -1.0 / np.sqrt(2)),\n",
    "    ]\n",
    "    centers = torch.tensor(centers) * scale\n",
    "    noise = m.sample((n,))\n",
    "    multi = torch.multinomial(torch.ones(8), n, replacement=True)\n",
    "    data = []\n",
    "    for i in range(n):\n",
    "        data.append(centers[multi[i]] + noise[i])\n",
    "    data = torch.stack(data)\n",
    "    return data\n",
    "\n",
    "\n",
    "def sample_moons(n):\n",
    "    x0, _ = generate_moons(n, noise=0.2)\n",
    "    return x0 * 3 - 1\n",
    "\n",
    "def sample_8gaussians(n):\n",
    "    return eight_normal_sample(n, 2, scale=5, var=0.1).float()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "id": "a895d422-d17d-4bf6-a7f6-821af4386a52",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Starting time experiment!\n",
      "Iteration: 0\n",
      "Iteration: 0\n",
      "Iteration: 0\n",
      "Iteration: 0\n",
      "Iteration: 0\n",
      "Iteration: 0\n",
      "Iteration: 0\n",
      "Iteration: 0\n",
      "Iteration: 0\n",
      "Iteration: 0\n",
      "Cost_Scet: 0.21017257869243622, Cost_Ours: 0.20741017048951882\n",
      "Time_Scet: 0.7505786418914795, Time_Ours: 0.3786803960800171\n",
      "LOT, left marginal error: 1.8978775187861174e-05, right marginal error: 1.8901560906670056e-05\n",
      "FRLC, mean left marginal error: 1.563960080355711e-05, right marginal error: 6.136053024849095e-18\n",
      "Iteration: 0\n",
      "Iteration: 0\n",
      "Iteration: 0\n",
      "Iteration: 0\n",
      "Iteration: 0\n",
      "Iteration: 0\n",
      "Iteration: 0\n",
      "Iteration: 0\n",
      "Iteration: 0\n",
      "Iteration: 0\n",
      "Cost_Scet: 0.1813611388206482, Cost_Ours: 0.1781790060454842\n",
      "Time_Scet: 0.7350003719329834, Time_Ours: 0.3538846969604492\n",
      "LOT, left marginal error: 1.6549891370232217e-05, right marginal error: 1.697310472081881e-05\n",
      "FRLC, mean left marginal error: 1.05568786875748e-06, right marginal error: 6.958032662280854e-18\n",
      "Iteration: 0\n",
      "Iteration: 0\n",
      "Iteration: 0\n",
      "Iteration: 0\n",
      "Iteration: 0\n",
      "Iteration: 0\n",
      "Iteration: 0\n",
      "Iteration: 0\n",
      "Iteration: 0\n",
      "Iteration: 0\n",
      "Cost_Scet: 0.30756181478500366, Cost_Ours: 0.2936506874185894\n",
      "Time_Scet: 0.6773028373718262, Time_Ours: 0.3225651741027832\n",
      "LOT, left marginal error: 1.393902948620962e-05, right marginal error: 1.46472020787769e-05\n",
      "FRLC, mean left marginal error: 2.4798471127037736e-06, right marginal error: 8.89925798965313e-18\n"
     ]
    }
   ],
   "source": [
    "from ott import utils\n",
    "import time\n",
    "import objective_grad\n",
    "import util\n",
    "\n",
    "# This experiment runs on CPU in double precision.\n",
    "device = torch.device(\"cpu\")\n",
    "dtype = torch.float64\n",
    "\n",
    "# Reload the local modules so edits made while the kernel is alive are picked up.\n",
    "importlib.reload(objective_grad)\n",
    "importlib.reload(FRLC)\n",
    "importlib.reload(util)\n",
    "\n",
    "print('Starting time experiment!')\n",
    "\n",
    "batch_size1, batch_size2 = 1000, 1000\n",
    "# Each ground cost is the pairwise Euclidean distance matrix divided by its largest entry.\n",
    "# 10-D Gaussian ground-cost\n",
    "x0 = Gaussian_mixture_three_10D(batch_size1)\n",
    "x1 = Gaussian_mixture_two_10D(batch_size2)\n",
    "C10D = torch.from_numpy(distance.cdist(x0, x1)).to(device)\n",
    "C_10DGaussian = C10D / C10D.max()\n",
    "# 2-D Gaussian ground-cost\n",
    "x0 = Gaussian_mixture_three_2D(batch_size1)\n",
    "x1 = Gaussian_mixture_two_2D(batch_size2)\n",
    "C2D = torch.from_numpy(distance.cdist(x0, x1)).to(device)\n",
    "C_2DGaussian = C2D / C2D.max()\n",
    "# 2-Moons 8-Gaussians ground-cost\n",
    "x0 = sample_8gaussians(batch_size1)\n",
    "x1 = sample_moons(batch_size2)\n",
    "C2M8G = torch.from_numpy(distance.cdist(x0, x1)).to(device)\n",
    "C_2M8G = C2M8G / C2M8G.max()\n",
    "\n",
    "cost_mats = [C_2M8G, C_2DGaussian, C_10DGaussian]\n",
    "\n",
    "# Fixing rank = 100 for the time experiments\n",
    "rank = 100\n",
    "# general costs\n",
    "Cost_Scet_list = []\n",
    "Cost_Ours_list = []\n",
    "# times\n",
    "Time_Scet_list = []\n",
    "Time_Ours_list = []\n",
    "\n",
    "# Number of repeated FRLC runs averaged per cost matrix\n",
    "N_rep = 10\n",
    "\n",
    "for i, C in enumerate(cost_mats):\n",
    "\n",
    "    # The NumPy view of the cost matrix is needed several times below; convert once.\n",
    "    C_np = C.cpu().numpy()\n",
    "\n",
    "    # Uniform marginals: a over the rows of C, b over its columns\n",
    "    one_N1 = np.ones((C.shape[0]))\n",
    "    a = one_N1 / C.shape[0]\n",
    "    one_N2 = np.ones((C.shape[1]))\n",
    "    b = one_N2 / C.shape[1]\n",
    "    \n",
    "    # Scetbon 21'\n",
    "    # Establish ground cost\n",
    "    geom_xy = ott.geometry.geometry.Geometry(cost_matrix=C_np)\n",
    "    # Define OT-problem\n",
    "    ot_prob = linear_problem.LinearProblem(geom_xy)\n",
    "    # Default initialization\n",
    "    solver = sinkhorn_lr.LRSinkhorn(rank=rank)\n",
    "    # NOTE(review): the baseline is timed on a single call, whereas FRLC below is averaged\n",
    "    # over N_rep calls; this single call presumably also pays any one-off JAX\n",
    "    # tracing/compilation cost -- confirm the comparison is the intended one.\n",
    "    # perf_counter is the monotonic, high-resolution clock meant for measuring intervals.\n",
    "    start = time.perf_counter()\n",
    "    ot_lr = solver(ot_prob)\n",
    "    end = time.perf_counter()\n",
    "    \n",
    "    # Yield coupling\n",
    "    P_Scet = ot_lr.matrix\n",
    "    \n",
    "    # Evaluate cost, time to run\n",
    "    Cost_Scet = np.sum(P_Scet * C_np)\n",
    "    Cost_Scet_list.append(Cost_Scet)\n",
    "    Time_Scet_list.append(end - start)\n",
    "\n",
    "    time_our_list = []\n",
    "    cost_our_list = []\n",
    "    left_error = []\n",
    "    right_error = []\n",
    "    \n",
    "    for j in range(N_rep):\n",
    "        # Ours\n",
    "        start = time.perf_counter()\n",
    "        P,_ = FRLC.FRLC_iteration(\n",
    "            C, r=rank, device=device, dtype=dtype,\n",
    "            printCost=False, returnFull=True,\n",
    "            initialization='Full', max_inneriters_balanced=1000,\n",
    "            max_inneriters_relaxed=50, min_iter=7,\n",
    "        ) # balanced\n",
    "        end = time.perf_counter()\n",
    "        \n",
    "        # Evaluate cost, time to run\n",
    "        P = P.cpu().numpy()\n",
    "        Cost_Ours = np.sum(P * C_np)\n",
    "        time_our_list.append(end - start)\n",
    "        cost_our_list.append(Cost_Ours)\n",
    "\n",
    "        # L2 distance between the coupling's row/column sums and the target marginals\n",
    "        left_error.append(np.linalg.norm(P @ one_N2 - a))\n",
    "        right_error.append(np.linalg.norm(P.T @ one_N1 - b))\n",
    "    \n",
    "    # Average the FRLC cost and runtime over the N_rep repetitions\n",
    "    Cost_Ours_list.append(sum(cost_our_list) / len(cost_our_list))\n",
    "    Time_Ours_list.append(sum(time_our_list) / len(time_our_list))\n",
    "    \n",
    "    # Print results\n",
    "    print(f'Cost_Scet: {Cost_Scet_list[i]}, Cost_Ours: {Cost_Ours_list[i]}')\n",
    "    print(f'Time_Scet: {Time_Scet_list[i]}, Time_Ours: {Time_Ours_list[i]}')\n",
    "    \n",
    "    # Marginal-error\n",
    "    print(f'LOT, left marginal error: {np.linalg.norm(P_Scet @ one_N2 - a)}, right marginal error: {np.linalg.norm(P_Scet.T @ one_N1 - b)}')\n",
    "    print(f'FRLC, mean left marginal error: {sum(left_error) / len(left_error)}, right marginal error: {sum(right_error) / len(right_error)}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "380b8303-7685-4df4-9b35-a2e457112759",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "peterenv2 [~/.conda/envs/peterenv2/]",
   "language": "python",
   "name": "conda_peterenv2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
