{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "16db7d5f-bc9f-4092-be25-8d9590226f54",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import numpy as np \n",
    "import torch \n",
    "import ot \n",
    "import os\n",
    "import matplotlib.pyplot as plt\n",
    "os.chdir('.')\n",
    "from lib.gromov_test import partial_gromov_ver1,cost_matrix_d,tensor_dot_param,tensor_dot_func,gwgrad_partial,partial_gromov_wasserstein,gwgrad_partial1\n",
    "from lib.opt import *\n",
    "from lib.pu_learning import *\n",
    "\n",
    "import numpy as np \n",
    "import numba as nb\n",
    "import warnings\n",
    "import time\n",
    "from ot.backend import get_backend, NumpyBackend\n",
    "from ot.lp import emd\n",
    "\n",
    "from sklearn.datasets import load_svmlight_file\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.decomposition import PCA\n",
    "\n",
    "from sklearn.metrics import accuracy_score, recall_score, precision_score\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "80b3356e-d428-4f4e-bea0-ae4481cd1c03",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "@nb.njit(cache=True)\n",
    "def tensor_dot_param(C1,C2,Lambda=0,loss='square_loss'):\n",
    "    if loss=='square_loss':\n",
    "        def f1(r1):\n",
    "            return r1**2-2*Lambda\n",
    "        def f2(r2):\n",
    "            return r2**2\n",
    "        def h1(r1):\n",
    "            return r1\n",
    "        def h2(r2):\n",
    "            return 2*r2\n",
    "    # else:\n",
    "    #     warnings.warn(\"loss function error\")\n",
    "\n",
    "    fC1=f1(C1)\n",
    "    fC2=f2(C2)\n",
    "    hC1=h1(C1)\n",
    "    hC2=h2(C2)\n",
    "    \n",
    "    return fC1,fC2,hC1,hC2\n",
    "\n",
    "@nb.njit(cache=True)\n",
    "def tensor_dot_func(fC1,fC2,hC1,hC2,Gamma):\n",
    "    #Gamma=np.ascontiguousarray(Gamma)\n",
    "    n,m=Gamma.shape\n",
    "    Gamma_1=Gamma.sum(1).reshape((-1,1))\n",
    "    Gamma_2=Gamma.sum(0).reshape((-1,1))\n",
    "    C1=fC1.dot(Gamma_1).dot(np.ones((1,m)))\n",
    "    C2=np.ones((n,1)).dot(Gamma_2.T).dot(fC2.T)\n",
    "    tensor_dot=C1+C2-hC1.dot(Gamma).dot(hC2.T) \n",
    "    return tensor_dot\n",
    "\n",
    "@nb.njit(cache=True)\n",
    "def gwgrad_partial1(C1, C2, T,loss='square'):\n",
    "    \"\"\"Compute the GW gradient. Note: we can not use the trick in :ref:`[12] <references-gwgrad-partial>`\n",
    "    as the marginals may not sum to 1.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    C1: array of shape (n_p,n_p)\n",
    "        intra-source (P) cost matrix\n",
    "\n",
    "    C2: array of shape (n_u,n_u)\n",
    "        intra-target (U) cost matrix\n",
    "\n",
    "    T : array of shape(n_p+nb_dummies, n_u) (default: None)\n",
    "        Transport matrix\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    numpy.array of shape (n_p+nb_dummies, n_u)\n",
    "        gradient\n",
    "\n",
    "\n",
    "    .. _references-gwgrad-partial:\n",
    "    References\n",
    "    ----------\n",
    "    .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon,\n",
    "        \"Gromov-Wasserstein averaging of kernel and distance matrices.\"\n",
    "        International Conference on Machine Learning (ICML). 2016.\n",
    "    \"\"\"\n",
    "    #T=np.ascontiguousarray(T)\n",
    "    if loss=='square':\n",
    "        cC1 = np.dot(C1 ** 2 , np.dot(T, np.ones(C2.shape[0]).reshape(-1, 1)))\n",
    "        cC2 = np.dot(np.dot(np.ones(C1.shape[0]).reshape(1, -1), T), C2 ** 2 )\n",
    "        constC = cC1 + cC2\n",
    "        A = -2*np.dot(C1, T).dot(C2.T)\n",
    "        tens = constC + A\n",
    "    elif loss=='dot':\n",
    "        constC=0\n",
    "        A = -2*np.dot(C1, T).dot(C2.T)\n",
    "        tens = constC + A\n",
    "    return tens \n",
    "\n",
    "def partial_gromov_ver1(C1, C2, p, q, Lambda, G0=None,nb_dummies=1,\n",
    "                               thres=1, numItermax_gw=1000,numItermax=None, tol=1e-7,\n",
    "                               log=False, verbose=False, line_search=True,seed=0,truncate=True, **kwargs):\n",
    "   \n",
    "    r\"\"\"\n",
    "    Solves the partial optimal transport problem\n",
    "    and returns the OT plan\n",
    "\n",
    "    The function considers the following problem:\n",
    "\n",
    "    .. math::\n",
    "        \\gamma = \\mathop{\\arg \\min}_\\gamma \\quad \\langle \\gamma, \\mathbf{M} \\rangle_F\n",
    "\n",
    "    .. math::\n",
    "        s.t. \\ \\gamma \\mathbf{1} &\\leq \\mathbf{a}\n",
    "\n",
    "             \\gamma^T \\mathbf{1} &\\leq \\mathbf{b}\n",
    "\n",
    "             \\gamma &\\geq 0\n",
    "\n",
    "             \\mathbf{1}^T \\gamma^T \\mathbf{1} = m &\\leq \\min\\{\\|\\mathbf{a}\\|_1, \\|\\mathbf{b}\\|_1\\}\n",
    "\n",
    "    where :\n",
    "\n",
    "    - :math:`\\mathbf{M}` is the metric cost matrix\n",
    "    - :math:`\\Omega` is the entropic regularization term, :math:`\\Omega(\\gamma) = \\sum_{i,j} \\gamma_{i,j}\\log(\\gamma_{i,j})`\n",
    "    - :math:`\\mathbf{a}` and :math:`\\mathbf{b}` are the sample weights\n",
    "    - `m` is the amount of mass to be transported\n",
    "\n",
    "    The formulation of the problem has been proposed in\n",
    "    :ref:`[29] <references-partial-gromov-wasserstein>`\n",
    "\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    C1 : ndarray, shape (ns, ns)\n",
    "        Metric cost matrix in the source space\n",
    "    C2 : ndarray, shape (nt, nt)\n",
    "        Metric costfr matrix in the target space\n",
    "    p : ndarray, shape (ns,)\n",
    "        Distribution in the source space\n",
    "    q : ndarray, shape (nt,)\n",
    "        Distribution in the target space\n",
    "    m : float, optional\n",
    "        Amount of mass to be transported\n",
    "        (default: :math:`\\min\\{\\|\\mathbf{p}\\|_1, \\|\\mathbf{q}\\|_1\\}`)\n",
    "    nb_dummies : int, optional\n",
    "        Number of dummy points to add (avoid instabilities in the EMD solver)\n",
    "    G0 : ndarray, shape (ns, nt), optional\n",
    "        Initialization of the transportation matrix\n",
    "    thres : float, optional\n",
    "        quantile of the gradient matrix to populate the cost matrix when 0\n",
    "        (default: 1)\n",
    "    numItermax : int, optional\n",
    "        Max number of iterations\n",
    "    tol : float, optional\n",
    "        tolerance for stopping iterations\n",
    "    log : bool, optional\n",
    "        return log if True\n",
    "    verbose : bool, optional\n",
    "        Print information along iterations\n",
    "    **kwargs : dict\n",
    "        parameters can be directly passed to the emd solver\n",
    "\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    gamma : (dim_a, dim_b) ndarray\n",
    "        Optimal transportation matrix for the given parameters\n",
    "    log : dict\n",
    "        log dictionary returned only if `log` is `True`\n",
    "\n",
    "\n",
    "    Examples\n",
    "    --------\n",
    "    >>> import ot\n",
    "    >>> import scipy as sp\n",
    "    >>> a = np.array([0.25] * 4)\n",
    "    >>> b = np.array([0.25] * 4)\n",
    "    >>> x = np.array([1,2,100,200]).reshape((-1,1))\n",
    "    >>> y = np.array([3,2,98,199]).reshape((-1,1))\n",
    "    >>> C1 = sp.spatial.distance.cdist(x, x)\n",
    "    >>> C2 = sp.spatial.distance.cdist(y, y)\n",
    "    >>> np.round(partial_gromov_wasserstein(C1, C2, a, b),2)\n",
    "    array([[0.  , 0.25, 0.  , 0.  ],\n",
    "           [0.25, 0.  , 0.  , 0.  ],\n",
    "           [0.  , 0.  , 0.25, 0.  ],\n",
    "           [0.  , 0.  , 0.  , 0.25]])\n",
    "    >>> np.round(partial_gromov_wasserstein(C1, C2, a, b, m=0.25),2)\n",
    "    array([[0.  , 0.  , 0.  , 0.  ],\n",
    "           [0.  , 0.  , 0.  , 0.  ],\n",
    "           [0.  , 0.  , 0.25, 0.  ],\n",
    "           [0.  , 0.  , 0.  , 0.  ]])\n",
    "\n",
    "\n",
    "    .. _references-partial-gromov-wasserstein:\n",
    "    References\n",
    "    ----------\n",
    "    ..  [29] Chapel, L., Alaya, M., Gasso, G. (2020). \"Partial Optimal\n",
    "        Transport with Applications on Positive-Unlabeled Learning\".\n",
    "        NeurIPS.\n",
    "\n",
    "    \"\"\"\n",
    "\n",
    "    # if m is None:\n",
    "    #     m = np.min((np.sum(p), np.sum(q)))\n",
    "    # elif m < 0:\n",
    "    #     raise ValueError(\"Problem infeasible. Parameter m should be greater\"\n",
    "    #                      \" than 0.\")\n",
    "    # elif m > np.min((np.sum(p), np.sum(q))):\n",
    "    #     raise ValueError(\"Problem infeasible. Parameter m should lower or\"\n",
    "    #                      \" equal than min(|a|_1, |b|_1).\")\n",
    "    \n",
    "        \n",
    "    if G0 is None:\n",
    "        G0 = np.outer(p, q)\n",
    "\n",
    "    cpt = 0\n",
    "    err = 1\n",
    "    \n",
    "    if log:\n",
    "        log_dict = {'err': [],'G0_mass':[],'Gprev_mass':[]}\n",
    "        \n",
    "    fC1,fC2,hC1,hC2=tensor_dot_param(C1,C2,Lambda=Lambda,loss='square_loss')\n",
    "    fC1,fC2,hC1,hC2=np.ascontiguousarray(fC1),np.ascontiguousarray(fC2),np.ascontiguousarray(hC1),np.ascontiguousarray(hC2)\n",
    "    C1,C2=np.ascontiguousarray(C1),np.ascontiguousarray(C2)\n",
    "    iter_num=0\n",
    "    n,m=C1.shape[0],C2.shape[0]\n",
    "    if numItermax is None:\n",
    "        numItermax=n*100\n",
    "    p_sum,q_sum=p.sum(),q.sum()\n",
    "    G0_orig=np.zeros((n,m))\n",
    "    \n",
    "    mu_extended,nu_extended,M_extended=np.zeros(n+1),np.zeros(m+1),np.zeros((n+1,m+1))\n",
    "    mu_extended[0:n],mu_extended[-1]=p,q_sum\n",
    "    nu_extended[0:m],nu_extended[-1]=q,p_sum\n",
    "        \n",
    "    while (err > tol and cpt < numItermax_gw):\n",
    "        #iter_num+=1\n",
    "        Gprev = G0.copy()\n",
    "\n",
    "        Mt_circ_G=tensor_dot_func(fC1,fC2,hC1,hC2,Gprev)\n",
    "        reg=2*Lambda*np.sum(Gprev)\n",
    "        \n",
    "        M_circ_G=gwgrad_partial1(C1, C2, Gprev)-reg\n",
    "        print('difference is', np.linalg.norm(M_circ_G-Mt_circ_G))\n",
    "        #M_tilde_circ_gamma=M_circ_gamma-reg \n",
    "        \n",
    "        # opt solver: \n",
    "        # Flamary's trick to fasten the computation: select only the subset of columns/lines\n",
    "        \n",
    "#        G0,innerlog_=opt_lp(p,q,Grad,Lambda=0,log=log,numItermax=numItermax,**kwargs)\n",
    "        \n",
    "#        eps=reg\n",
    "#        M_extended[:,-1],M_extended[-1,:]=reg,reg\n",
    "    \n",
    "        M_extended[0:n,0:m]=Mt_circ_G #-reg\n",
    "        \n",
    "        #M_extended[:idx_x.shape[0], :idx_y.shape[0]]= M_star[np.ix_(idx_x, idx_y)]\n",
    "        gamma_extended,log_dict=emd_lp(mu_extended,nu_extended,M_extended,numItermax=numItermax,log=log,**kwargs)\n",
    "        \n",
    "        G0=G0_orig.copy()\n",
    "        G0[0:n,0:m]=gamma_extended[:-1,:-1]\n",
    "        #G0[np.ix_(idx_x, idx_y)] = gamma_extended[:-nb_dummies, :-nb_dummies]\n",
    "        if cpt % 10 == 0:  # to speed up the computations\n",
    "            err = np.linalg.norm(G0 - Gprev)\n",
    "            if log:\n",
    "                log['err'].append(err)\n",
    "            if verbose:\n",
    "                if cpt % 200 == 0:\n",
    "                    print('{:5s}|{:12s}|{:12s}'.format(\n",
    "                        'It.', 'Err', 'Loss') + '\\n' + '-' * 31)\n",
    "                print('{:5d}|{:8e}|{:8e}'.format(cpt, err,\n",
    "                                                 gwloss_partial(C1, C2, G0)))\n",
    "\n",
    "        \n",
    "        \n",
    "        # line search \n",
    "        deltaG = G0 - Gprev\n",
    "        \n",
    "        # line search \n",
    "        if line_search:\n",
    "            \n",
    "            \n",
    "            Mt_circ_deltaG=tensor_dot_func(fC1,fC2,hC1,hC2,deltaG)\n",
    "            a=np.sum(Mt_circ_deltaG*deltaG)\n",
    "            b=2 * (np.sum(Mt_circ_G * deltaG))\n",
    "            \n",
    "            M_circ_deltaG=gwgrad_partial1(C1, C2, deltaG)\n",
    "            deltaG_sum=np.sum(deltaG)\n",
    "            a1=np.sum(M_circ_deltaG*deltaG)-2*Lambda*deltaG_sum**2\n",
    "            b1= 2 * (np.sum(M_circ_G * deltaG)-reg*deltaG_sum)\n",
    "            \n",
    "            print('a1-a',a1-a)\n",
    "            print('b1-b',b1-b)\n",
    "            if a>0:  # due to numerical precision\n",
    "                if b>=0:\n",
    "                    alpha = 0\n",
    "                    cpt = numItermax_gw\n",
    "                else:\n",
    "                    alpha = min(1, np.divide(-b, 2.0 * a))\n",
    "            else:\n",
    "                if (a + b) < 0:\n",
    "                    alpha = 1\n",
    "                else:\n",
    "                    alpha = 0\n",
    "                    cpt = numItermax_gw\n",
    "        else:\n",
    "            alpha=1\n",
    "        \n",
    "        G0 = Gprev + alpha * deltaG\n",
    "        cpt += 1\n",
    "        print('cpt is',cpt)\n",
    "    if log:\n",
    "        log_dict.update(innerlog_)\n",
    "        return G0, log_dict #,iter_num\n",
    "    else:\n",
    "        return G0 #,iter_num"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c032785a-c620-436d-9b70-a6ea0927d5db",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def data_process(name='amazon_surf'):\n",
    "    # open the data file \n",
    "    if name in ['MNIST','EMNIST']:\n",
    "        data_file=torch.load('pu_learning/data/'+name+'.pt')\n",
    "        (X,l)=data_file\n",
    "        classes= None\n",
    "    elif 'surf' in name or 'decaf' in name:        \n",
    "        with open('pu_learning/data/'+name+'_fts.pkl', 'rb') as f:\n",
    "            data_file = pickle.load(f)\n",
    "     \n",
    "        if 'surf' in name:\n",
    "            X0=data_file['features']\n",
    "            l=data_file['labels']\n",
    "            classes=data_file['classes']\n",
    "            pca = PCA(n_components=10, random_state=0)\n",
    "            pca.fit(X0.T)\n",
    "            X = pca.components_.T\n",
    "        elif 'decaf' in name:\n",
    "            X0=data_file['fc8']\n",
    "            l=data_file['labels']\n",
    "            classes=data_file['classes']\n",
    "            pca = PCA(n_components=40, random_state=0)\n",
    "            pca.fit(X0.T)\n",
    "            X = pca.components_.T\n",
    "    return (X,l),classes\n",
    "\n",
    "\n",
    "def MNIST_figure(figure_list,label_list):\n",
    "    plt.figure(figsize=(10, 4))\n",
    "    for i in range(10):\n",
    "        plt.subplot(2, 5, i + 1)\n",
    "        plt.imshow(figure_list[i][0], cmap='gray')\n",
    "        plt.title(f\"Label: {label_list[i]}\")\n",
    "        plt.axis('off')\n",
    "    plt.tight_layout()\n",
    "    plt.show()\n",
    "\n",
    "def normalize_X(X):\n",
    "    div = np.max(X, axis=0) - np.min(X, axis=0)\n",
    "    div[div == 0] = 1 # Avoid division by zero\n",
    "    X = (X - np.min(X, axis=0)) / div\n",
    "    return X\n",
    "    \n",
    "# def convert_data(dataset,name='MNIST',visual=False):\n",
    "#     if name in ['MNIST','EMNIST']:\n",
    "#         X_list,label_list=dataset\n",
    "#             label_list_all.append(label_list)\n",
    "#         embedding_list_all=np.vstack(embedding_list_all)\n",
    "#         label_list_all=np.vstack(label_list_all).reshape(-1).astype(np.int64)\n",
    "#     return embedding_list_all,label_list_all\n",
    "\n",
    "\n",
    "# it is modified version \n",
    "def draw_pu_dataset_scar(dataset_p, dataset_u=None, size_p=10, size_u=20, prior=0.5, p_label=0,seed_nb=None,same_dataset=True):\n",
    "    \"\"\"Draw a Positive and Unlabeled dataset \"at random\"\"\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    dataset_p: name of the dataset among which the positives are drawn\n",
    "\n",
    "    dataset_u: name of the dataset among which the unlabeled are drawn\n",
    "\n",
    "    size_p: number of points in the positive dataset\n",
    "\n",
    "    size_u: number of points in the unlabeled dataset\n",
    "\n",
    "    prior: percentage of positives on the dataset (s)\n",
    "\n",
    "    seed_nb: seed\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    pandas.DataFrame of shape (n_p, d_p)\n",
    "        Positive dataset\n",
    "\n",
    "    pandas.DataFrame of shape (n_u, d_u)\n",
    "        Unlabeled dataset\n",
    "\n",
    "    pandas.Series of len (n_u)\n",
    "        labels of the unlabeled dataset\n",
    "    \"\"\"\n",
    "    x, l = dataset_p[0].copy(),dataset_p[1].copy()\n",
    "    A=l==p_label\n",
    "    B=l!=p_label\n",
    "    l[A],l[B]=1,0\n",
    "    x=normalize_X(x)\n",
    "\n",
    "    size_u_p = int(prior * size_u)\n",
    "    size_u_n = size_u - size_u_p\n",
    "    \n",
    "    xp_t = x[l == 1]\n",
    "    tp_t = l[l == 1]\n",
    "\n",
    "    xp, xp_other, _, tp_o = train_test_split(xp_t, tp_t, train_size=size_p,\n",
    "                                             random_state=seed_nb)\n",
    "    #print('xp_other shape',xp_other.shape)\n",
    "    if same_dataset or dataset_u is None:\n",
    "        xup, _, lup, _ = train_test_split(xp_other, tp_o, train_size=size_u_p,\n",
    "                                        random_state=seed_nb)\n",
    "    else:\n",
    "        x, l = dataset_u[0].copy(),dataset_u[1].copy()\n",
    "        x=normalize_X(x)\n",
    "        A=l==p_label\n",
    "        B=l!=p_label\n",
    "        l[A],l[B]=1,0\n",
    "        # x, t = make_data(dataset=dataset_u)\n",
    "        \n",
    "        # div = np.max(x, axis=0) - np.min(x, axis=0)\n",
    "        # div[div == 0] = 1\n",
    "        # x = (x - np.min(x, axis=0)) / div\n",
    "        xp_other = x[l == 1]\n",
    "        tp_o = l[l == 1]\n",
    "        xup, _, lup, _ = train_test_split(xp_other, tp_o,\n",
    "                                        train_size=size_u_p,\n",
    "                                        random_state=seed_nb)\n",
    "\n",
    "    xn_t = x[l == 0]\n",
    "    tn_t = l[l == 0]\n",
    "    xun, _, lun, _ = train_test_split(xn_t, tn_t, train_size=size_u_n,\n",
    "                                    random_state=seed_nb)\n",
    "    \n",
    "    xu = np.concatenate([xup, xun], axis=0)\n",
    "    yu = np.concatenate((np.ones(len(xup)), np.zeros(len(xun)))).astype(np.int64)\n",
    "    yu_2=np.concatenate((lup,lun))\n",
    "    #print(np.linalg.norm(yu-yu_2))\n",
    "    return xp, xu, yu_2\n",
    "\n",
    "def init_pgw_param(C1,C2,r):\n",
    "    n,m=C1.shape[0],C2.shape[0]\n",
    "    q=np.ones(m)/m  \n",
    "    p=np.ones(n)/n*r # make the mass of p to be r\n",
    "    mass=np.min((p.sum(),r))\n",
    "    return p,q,mass\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "            \n",
    "\n",
    "def gamma_to_l(G,r):\n",
    "    n,m=G.shape\n",
    "    G_2=G.sum(0)\n",
    "    quantile=np.quantile(G_2,1-r)\n",
    "    l_G=np.zeros(m)\n",
    "    l_G[G_2>=quantile]=1\n",
    "    return l_G\n",
    "\n",
    "def init_param_ugw(C1,C2):\n",
    "    n,m=C1.shape[0],C2.shape[0]\n",
    "    n_pos,n_unl=n,m\n",
    "    nb_try=1\n",
    "    mu = (torch.ones([n_pos]) / n_pos).expand(nb_try, -1)\n",
    "    nu = (torch.ones([n_unl]) / n_unl).expand(nb_try, -1)\n",
    "    \n",
    "    grid_eps = [2. ** k for k in range(-9, -8, 1)]\n",
    "    grid_rho = [2. ** k for k in range(-10, -4, 1)]\n",
    "    eps=grid_eps[0]\n",
    "    rho=grid_rho[0]\n",
    "    rho2=grid_rho[0]\n",
    "    Cx=torch.from_numpy(C1).to(torch.float32).reshape((nb_try,n,n))\n",
    "    Cy=torch.from_numpy(C2).to(torch.float32).reshape((nb_try,m,m))\n",
    "    return mu,nu,eps,rho,rho2,Cx,Cy\n",
    "\n",
    "def init_flb_uot(C1,C2):\n",
    "    mu,nu,eps,rho,rho2,Cx,Cy=init_param_ugw(C1,C2)\n",
    "    print('eps in flb_uot is',eps)\n",
    "    _, _, init_plan = compute_batch_flb_plan(\n",
    "            mu, Cx, nu, Cy, eps=eps, rho=rho, rho2=rho2,\n",
    "            nits_sinkhorn=50000, tol_sinkhorn=1e-5)\n",
    "    \n",
    "    return init_plan[0].numpy().astype(np.float64)\n",
    "\n",
    "def init_flb_pot(C1,C2,p,q,r,Lambda=30.0,n=100):\n",
    "    p,q,mass=init_pgw_param(C1,C2,r)\n",
    "    S1,S2=C1.mean(0),C2.mean(0)\n",
    "    C=cost_matrix(S1,S2)\n",
    "    gamma,_=opt_lp(p,q,C,Lambda=Lambda,numItermax=n*500)\n",
    "    \n",
    "    return gamma\n",
    "\n",
    "def pu_prediction_gw(C1,C2,r=0.2,G0=None,method='pgw',param={'Lambda':30.0}):\n",
    "    C1,C2=C1.astype(np.float64),C2.astype(np.float64)\n",
    "    #C1,C2=cost_matrix_d(X_p,X_p),cost_matrix_d(X_u,X_u)\n",
    "    n,m=C1.shape[0],C2.shape[0]\n",
    "    size_p=int(m*r)\n",
    "    if size_p!=n:\n",
    "        print('# of positives in X_p and X_u are different, we suggest to modify them')\n",
    "    if method=='gw':\n",
    "        p=np.ones(n)/n\n",
    "    if method=='primal_pgw':\n",
    "        p,q,mass=init_pgw_param(C1,C2,r)\n",
    "#       mass=min(r*np.sum(q),np.sum(p)) # this used to avoid numerical issue \n",
    "        C1,C2=C1.astype(np.float64),C2.astype(np.float64)\n",
    "        gamma=partial_gromov_wasserstein(C1,C2,p,q,m=mass,G0=G0,numItermax=n*1000,nb_dummies=1,line_search=False)\n",
    "        \n",
    "    if method=='pgw':\n",
    "        Lambda=param['Lambda']\n",
    "        p,q,mass=init_pgw_param(C1,C2,r)\n",
    "        C1,C2=C1.astype(np.float64),C2.astype(np.float64)\n",
    "        gamma=partial_gromov_ver1(C1,C2,p,q,Lambda=Lambda,G0=G0,numItermax=n*1000,nb_dummies=1,line_search=False)\n",
    "    if method=='ugw':\n",
    "        mu,nu,eps,rho,rho2,Cx,Cy=init_param_ugw(C1,C2)\n",
    "        if 'rho' in param:\n",
    "            rho=param['rho']\n",
    "            rho2=rho\n",
    "        if 'eps' in param:\n",
    "            eps=param['eps']\n",
    "        # need to try different rho for better performance\n",
    "#        rho=0.0023 surf A\n",
    "        if type(G0)==np.ndarray:\n",
    "            init_plan=torch.from_numpy(G0).to(torch.float32).reshape((1,n,m))\n",
    "        elif type(G0)==torch.Tensor:\n",
    "            init_plan=G0\n",
    "        gamma = log_batch_ugw_sinkhorn(mu, Cx, nu, Cy, init=init_plan,\n",
    "                                eps=eps, rho=rho, rho2=rho2,\n",
    "                                nits_plan=3000, tol_plan=1e-5,\n",
    "                                nits_sinkhorn=3000, tol_sinkhorn=1e-6)\n",
    "        print('gamma_mass_diff',gamma.sum()-r)\n",
    "        gamma=gamma[0]\n",
    "    return gamma"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3569c936-73f5-45ce-9e51-e676e47b4f03",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "#nb_dummies=1\n",
    "p_label=1\n",
    "nb_dummies=1\n",
    "name1='MNIST' \n",
    "name2='EMNIST'\n",
    "#name3='webcam_surf'\n",
    "file_name=name1+name2+'.pt' #'surf.pt' #name1+'-'+name2+'.pt'\n",
    "try:\n",
    "    result=torch.load('pu_learning/result/'+filename)\n",
    "except:\n",
    "    result={}\n",
    "n=1000\n",
    "r=1/5\n",
    "m=int(n/r)\n",
    "seed_nb=3\n",
    "\n",
    "dataset1,_=data_process(name=name1)\n",
    "dataset2,_=data_process(name=name2)\n",
    "\n",
    "dataname_list=[name1,name2]\n",
    "dataset_list=[dataset1,dataset2]\n",
    "init_method_list=['flb_uot','flb_pot']\n",
    "method_list=['primal_pgw','pgw'] #\n",
    "for (data1_name,data1) in zip(dataname_list,dataset_list):\n",
    "    for (data2_name,data2) in zip(dataname_list,dataset_list):\n",
    "\n",
    "        print('data 1 is',data1_name)\n",
    "        print('data 2 is',data2_name)\n",
    "        if data1_name==data2_name:\n",
    "            same_dataset=True\n",
    "        else:\n",
    "            same_dataset=False\n",
    "        for init_method in init_method_list:\n",
    "            G0 = None\n",
    "\n",
    "            X_p,X_u,label_u=draw_pu_dataset_scar(data1,data2,p_label=p_label,prior=r,size_p=n, size_u=m,seed_nb=seed_nb,same_dataset=same_dataset)\n",
    "            C, C1, C2, mu, nu=compute_cost_matrices(P=X_p, U=X_u, prior=r, nb_dummies=1)\n",
    "            p,q=mu[0:n],nu[0:m]\n",
    "            C1=C1[0:n,0:n]\n",
    "            C2=C2[0:m,0:m]\n",
    "\n",
    "\n",
    "            time1=time.time()\n",
    "            if init_method=='pot_r' and C is not None:\n",
    "                G0=ot.emd(mu, nu, C)[:n, :] \n",
    "                #pu_w_emd(mu, nu, C, nb_dummies=nb_dummies)\n",
    "                #G0=G0[0:-nb_dummies,:]\n",
    "            elif init_method=='flb_pot':\n",
    "                G0=init_flb_pot(C1,C2,p,q,r,Lambda=30.0)\n",
    "            elif init_method=='flb_uot':\n",
    "                G0=init_flb_uot(C1,C2)\n",
    "\n",
    "            time2=time.time()\n",
    "            run_time=time2-time1\n",
    "            if G0 is not None:\n",
    "                l_G0=gamma_to_l(G0,r)\n",
    "                acc0=accuracy_score(l_G0,label_u)\n",
    "                result[init_method+'-'+data1_name+'-'+data2_name]={}\n",
    "                result[init_method+'-'+data1_name+'-'+data2_name]['accuracy']=acc0\n",
    "                result[init_method+'-'+data1_name+'-'+data2_name]['time']=run_time\n",
    "                #result[init_method+'-'+data1+'-'+data2]['G0']=G0\n",
    "                print('init method is',init_method)\n",
    "                print('accuracy is',acc0)\n",
    "                print('time is',run_time)    \n",
    "            # if G0 is not None:    \n",
    "                for method in method_list:\n",
    "                    if True: #if init_method+'-'+data1+'-'+data2+'-'+method not in result:\n",
    "                        if method=='ugw':\n",
    "                            param={'None'}\n",
    "                        elif method=='pgw':\n",
    "                            param={'Lambda':20.0}\n",
    "                        else:\n",
    "                            param=None\n",
    "                        time1=time.time()\n",
    "                        G=pu_prediction_gw(C1.copy(),C2.copy(),r=r,G0=G0.copy(),method=method,param=param)\n",
    "                        time2=time.time()\n",
    "                        run_time=time2-time1\n",
    "\n",
    "                        l_G=gamma_to_l(G,r)\n",
    "                        acc=accuracy_score(l_G,label_u)\n",
    "                        result[init_method+'-'+data1_name+'-'+data2_name+'-'+method]={}\n",
    "                        result[init_method+'-'+data1_name+'-'+data2_name+'-'+method]['time']=run_time\n",
    "                        result[init_method+'-'+data1_name+'-'+data2_name+'-'+method]['accuracy']=acc\n",
    "                        print('method is',method)\n",
    "                        print('accuracy is',acc)\n",
    "                        print('time is',run_time)\n",
    "                #torch.save(result,'pu_learning/result/'+file_name)\n",
    "                    \n",
    "# p,q=np.ones(n)*r/n,np.ones(m)/m\n",
    "\n",
    "# l_G0=gamma_to_l(G0,r)\n",
    "# acc_G0=accuracy_score(l_G0,label_u)\n",
    "# print('acc_G0',acc_G0)\n",
    "\n",
    "# if C is not None:\n",
    "#     G0=pu_w_emd(mu, nu, C, nb_dummies=nb_dummies)\n",
    "#     G0=G0[0:-nb_dummies,:]\n",
    "\n",
    "#     l_G0=gamma_to_l(G0,r)\n",
    "#     acc_G0=accuracy_score(l_G0,label_u)\n",
    "#     print('acc_G0',acc_G0)\n",
    "# gamma=pu_prediction_gw(C1,C2,r=r,method='ugw',G0=G0,param={'Lambda':30.0})\n",
    "# l_G=gamma_to_l(gamma,r)\n",
    "# acc=accuracy_score(l_G,label_u)\n",
    "# print('acc',acc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "af141b9f-ef84-42c8-9717-37e3c66cb293",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "print('done')\n",
    "result=torch.load('pu_learning/result/MNIST-EMNIST.pt')\n",
    "for key in result:\n",
    "    str_list=key.split(\"-\")\n",
    "    if len(str_list)==3:\n",
    "        init,data1,data2=str_list[0],str_list[1],str_list[2]\n",
    "        print('data1 is',data1)\n",
    "        print('data2 is',data2)\n",
    "        print('init method is',init)\n",
    "        print('accuracy is', result[key]['accuracy'])\n",
    "        print('time is', result[key]['time'])\n",
    "\n",
    "    \n",
    "    elif len(str_list)==4:\n",
    "        init,data1,data2,method=str_list[0],str_list[1],str_list[2],str_list[3]\n",
    "        print('method is',method)\n",
    "        print('accuracy is', result[key]['accuracy'])\n",
    "        print('time is', result[key]['time'])\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3ac2b8b1-e94b-49e7-92d9-2e46c4308d5e",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c838c0b8-19b8-4a0e-9ac0-cabf3e193299",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
