{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": 128,
      "metadata": {
        "id": "EvuPYzXTcPgl"
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "from sklearn.neural_network import MLPClassifier\n",
        "import random\n",
        "import torch\n",
        "from scipy.special import erf\n",
        "import matplotlib.pyplot as plt\n",
        "from math import*\n",
        "from scipy.integrate import quad as itg\n",
        "from torch.utils.data import DataLoader, Dataset\n",
        "#torch\n",
        "import torch\n",
        "import torchvision.datasets as datasets\n",
        "from torchvision.datasets import MNIST\n",
        "from torchvision import transforms\n",
        "from sklearn.manifold import TSNE\n",
        "import seaborn as sns\n",
        "from sklearn.decomposition import PCA\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 129,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "rVTjYhW_sEcK",
        "outputId": "55fe065e-a5d3-4476-ed16-e8b3729ba9e9"
      },
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "device(type='cuda', index=0)"
            ]
          },
          "metadata": {},
          "execution_count": 129
        }
      ],
      "source": [
        "# Run on the first GPU when one is available, otherwise fall back to the CPU.\n",
        "device = torch.device(\"cuda:0\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
        "device"
      ]
    },
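    {
      "cell_type": "markdown",
      "metadata": {
        "id": "LFf8ElI3qWOI"
      },
      "source": [
        "# Creating the dataset\n",
        "\n",
        "Data points come from a three-cluster Gaussian mixture in dimension $d$: unit-variance isotropic noise around a mean of $+3$ along coordinate 0 (weight 1/6), $-3$ along coordinate 1 (weight 1/2) or $-3$ along coordinate 0 (weight 1/3). Each training input is the interpolant $x(t) = (1-t) x_0 + t x_1$ between Gaussian noise $x_0$ and a data point $x_1$, and the training target is $x_1$."
      ]
    },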
    {
      "cell_type": "code",
      "execution_count": 130,
      "metadata": {
        "id": "JKoBC-rDTbn1"
      },
      "outputs": [],
      "source": [
        "d = 1000       # ambient dimension\n",
        "N = 54000      # size of the pool of data; not all of it is used for training\n",
        "\n",
        "# Seed every RNG the notebook draws from (Python `random`, PyTorch, NumPy) so that\n",
        "# Restart & Run All draws the same random numbers each time.\n",
        "SEED = 0\n",
        "random.seed(SEED)\n",
        "torch.manual_seed(SEED)\n",
        "np.random.seed(SEED)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 131,
      "metadata": {
        "id": "ofuE23-V5Ke6"
      },
      "outputs": [],
      "source": [
        "# Three-cluster Gaussian mixture: unit-variance isotropic noise around one of three means.\n",
        "# Cluster weights are 1/6, 1/2 and 1/3; the last cluster absorbs any rounding so that the\n",
        "# cluster sizes always sum to N, even when N is not divisible by 6.\n",
        "n1, n2 = N // 6, N // 2\n",
        "sizes = (n1, n2, N - n1 - n2)\n",
        "\n",
        "# Rows of `centers`: +3 along coordinate 0, -3 along coordinate 1, -3 along coordinate 0.\n",
        "centers = np.zeros((3, d))\n",
        "centers[0, 0] = 3\n",
        "centers[1, 1] = -3\n",
        "centers[2, 0] = -3\n",
        "mu = torch.from_numpy(centers.astype(np.float32)).to(device)   # (3, d) cluster means\n",
        "\n",
        "# Per-sample means, block-ordered by cluster (the training pool is shuffled in a later cell).\n",
        "mean = np.repeat(centers, sizes, axis=0)\n",
        "assert mean.shape == (N, d)\n",
        "\n",
        "X_train = (np.random.randn(N, d) + mean).astype(np.float32)\n",
        "\n",
        "# Test set: pick the cluster of each test point at random. Slicing the first rows of the\n",
        "# block-ordered `mean` would put every test point in the first cluster.\n",
        "n_test = 2000\n",
        "test_rows = np.random.permutation(N)[:n_test]\n",
        "X_test = (np.random.randn(len(test_rows), d) + mean[test_rows]).astype(np.float32)\n",
        "\n",
        "n, d = X_train.shape"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 132,
      "metadata": {
        "id": "D9v3yr9gkL0j"
      },
      "outputs": [],
      "source": [
        "# NumPy -> torch (torch.from_numpy shares memory with the float32 arrays).\n",
        "X_train, X_test = torch.from_numpy(X_train), torch.from_numpy(X_test)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 133,
      "metadata": {
        "id": "VwHlSqwTyqFi"
      },
      "outputs": [],
      "source": [
        "T = 5    # number of discretization steps for the time integral in the definition of the denoising risk\n",
        "\n",
        "def interpolant_stack(x1, x0):\n",
        "  \"\"\"Stack x(t) = (1 - t) * x0 + t * x1 for t = k/T, k = 0, ..., T-1, along dim 2.\n",
        "\n",
        "  The k = 0 slice is x0 itself (pure noise).\n",
        "  \"\"\"\n",
        "  return torch.cat([x0] + [(k/T)*x1 + (1 - k/T)*x0 for k in range(1, T)], dim=2)\n",
        "\n",
        "X_tr = X_train.unsqueeze(2)        # (n, d, 1): clean data, the endpoint x1\n",
        "X_ts = X_test.unsqueeze(2)\n",
        "Xis = torch.randn(X_tr.shape)      # Gaussian noise, the starting point x0\n",
        "Xis_test = torch.randn(X_ts.shape)\n",
        "\n",
        "# Inputs: the interpolant at every time. Targets: the clean data, repeated for every time.\n",
        "# Each tensor is concatenated once instead of being re-allocated at every time step.\n",
        "Xt = interpolant_stack(X_tr, Xis)\n",
        "X1t = torch.cat([X_tr] * T, dim=2)\n",
        "Xt_test = interpolant_stack(X_ts, Xis_test)\n",
        "X1t_test = torch.cat([X_ts] * T, dim=2)\n"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Time encoding p_k = cos(pi * k / T) for the training times t = k/T, k = 1, ..., T-1.\n",
        "pt = torch.tensor([np.cos(np.pi*k/T) for k in range(1, T)], dtype=torch.float32, device=device)"
      ],
      "metadata": {
        "id": "pybQU7nWFmY_"
      },
      "execution_count": 134,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Keep only the last T-1 time slices (t = 1/T, ..., (T-1)/T), i.e. drop the pure-noise t = 0 slice.\n",
        "# Slicing relative to the current length makes the cell idempotent: re-running it is a no-op\n",
        "# instead of silently dropping one more time slice.\n",
        "n_keep = T - 1\n",
        "Xt, X1t, Xt_test, X1t_test = (\n",
        "    x[:, :, max(x.shape[2] - n_keep, 0):] for x in (Xt, X1t, Xt_test, X1t_test)\n",
        ")"
      ],
      "metadata": {
        "id": "YdXTU5HcNlXR"
      },
      "execution_count": 135,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Shuffle the training pool, applying the same permutation to inputs and targets.\n",
        "# The pool size is read from the tensor itself rather than from the global N, so the cell\n",
        "# stays correct even if N is reassigned elsewhere in the notebook.\n",
        "assert Xt.shape[0] == X1t.shape[0]\n",
        "indices = torch.randperm(Xt.shape[0])\n",
        "Xt = Xt[indices]\n",
        "X1t = X1t[indices]\n"
      ],
      "metadata": {
        "id": "5YLWavKYG40w"
      },
      "execution_count": 136,
      "outputs": []
    },
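    {
      "cell_type": "markdown",
      "metadata": {
        "id": "USDSs-tkqhpo"
      },
      "source": [
        "# Training the autoencoder\n",
        "\n",
        "A tied-weight two-layer autoencoder with a learned time encoding and a scalar skip connection is trained by SGD (batch size 1) on the denoising loss. Its order parameters (self-overlap of the weights, overlap with the cluster means, skip weight and time-encoding weights) are recorded after every step."
      ]
    },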
    {
      "cell_type": "code",
      "execution_count": 137,
      "metadata": {
        "id": "QjzIKO1jdxdA"
      },
      "outputs": [],
      "source": [
        "class generate_data(Dataset):\n",
        "  \"\"\"Training set of n (input, target) pairs drawn without replacement from the pool Xt / X1t.\n",
        "\n",
        "  Item i is (Xt[j], X1t[j]) for the i-th sampled pool index j: the noisy interpolants at every\n",
        "  training time and the matching clean targets. Only the selected samples are copied to\n",
        "  `device`, and the pool size is read from Xt rather than from the global N.\n",
        "\n",
        "  Raises ValueError (from random.sample) if n exceeds the pool size.\n",
        "  \"\"\"\n",
        "  def __init__(self,n):\n",
        "    self.idx = random.sample(range(Xt.shape[0]), n)\n",
        "    index = torch.as_tensor(self.idx, dtype=torch.long)\n",
        "    self.X = Xt[index].to(device)\n",
        "    self.X1 = X1t[index].to(device)\n",
        "    self.samples = n\n",
        "\n",
        "  def __getitem__(self,idx):\n",
        "    # Both tensors already live on `device`; no per-item transfer is needed.\n",
        "    return self.X[idx], self.X1[idx]\n",
        "\n",
        "  def __len__(self):\n",
        "    return self.samples\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 138,
      "metadata": {
        "id": "edalGVUoeL3E"
      },
      "outputs": [],
      "source": [
        "class AE(torch.nn.Module):\n",
        "    \"\"\"Tied-weight two-layer autoencoder with a learned time encoding and a scalar skip connection.\n",
        "\n",
        "    For x of shape (batch, d, n_times) the output is\n",
        "        yhat = we @ tanh(we^T x / sqrt(d) + vt p^T) / sqrt(d) + skip * x,\n",
        "    where p is the global time encoding `pt` (one entry per time slice of x).\n",
        "\n",
        "    Args:\n",
        "        d: input dimension; also sets the 1/sqrt(d) scaling used in `forward`.\n",
        "        r: number of hidden units.\n",
        "        informed: accepted for API compatibility; currently unused.\n",
        "    \"\"\"\n",
        "    def __init__(self, d,r, informed=False):\n",
        "        super(AE, self).__init__()\n",
        "\n",
        "        # NOTE: reseeds NumPy's *global* RNG, so every later np.random draw in the\n",
        "        # notebook also depends on this seed.\n",
        "        np.random.seed(55)\n",
        "        w0=np.random.randn(r,d).astype(np.float32)\n",
        "\n",
        "        self.d=d   # input dimension, used for the 1/sqrt(d) scaling\n",
        "        self.r=r   # number of hidden units\n",
        "        self.we=torch.nn.Parameter(torch.from_numpy(w0.T))   # weights, shape (d, r)\n",
        "        self.vt=torch.nn.Parameter(torch.ones(r))  # time encoding weights\n",
        "        self.skip=torch.nn.Parameter(torch.zeros(1))  # skip connection, initialised at 0\n",
        "\n",
        "    def forward(self, x):\n",
        "        scale = np.sqrt(self.d)\n",
        "        h1 = torch.einsum(\"ij, nit->njt\", self.we, x) / scale     # (batch, r, n_times)\n",
        "        # The (r, n_times) time encoding broadcasts over the batch dimension.\n",
        "        h = torch.tanh(h1 + torch.outer(self.vt, pt))\n",
        "        yhat = torch.einsum(\"ij, njt->nit\", self.we, h) / scale   # (batch, d, n_times)\n",
        "        return yhat + self.skip * x"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 139,
      "metadata": {
        "id": "_UzyH-Q5e4Je"
      },
      "outputs": [],
      "source": [
        "def quadloss(ypred, y):\n",
        "    \"\"\"Summed squared error between prediction and target, divided by the global T.\"\"\"\n",
        "    squared_error = (ypred - y) ** 2\n",
        "    return squared_error.sum() / T"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 140,
      "metadata": {
        "id": "aEMvSNsPcPg-"
      },
      "outputs": [],
      "source": [
        "def _record_order_parameters(ae, Q_list, M_list, b_list, v_list):\n",
        "    \"\"\"Append the current order parameters of `ae` to the four history lists.\n",
        "\n",
        "    Q = we^T we / d (r x r), M = we^T mu^T / sqrt(d) (r x number of rows of mu),\n",
        "    b = skip (float), v = vt (a CPU tensor copy).\n",
        "    \"\"\"\n",
        "    we = ae.we.detach()\n",
        "    Q_list.append((we.T @ we / d).cpu().numpy())\n",
        "    M_list.append((we.T @ mu.T / np.sqrt(d)).cpu().numpy())\n",
        "    b_list.append(float(ae.skip.detach().cpu()))\n",
        "    # clone(): on a CPU device, detach().cpu() returns the live parameter storage, and the\n",
        "    # optimizer's in-place updates would make every stored snapshot equal to the final vt.\n",
        "    v_list.append(ae.vt.detach().cpu().clone())\n",
        "\n",
        "\n",
        "def train(train_loader, informed=False, verbose=False, epochs=1):\n",
        "    \"\"\"Train an AE with r hidden units by SGD on `quadloss` and record its order parameters.\n",
        "\n",
        "    Args:\n",
        "        train_loader: iterable of (input, target) batches, e.g. a DataLoader over `generate_data`.\n",
        "        informed: forwarded to `AE` (currently unused there).\n",
        "        verbose: if True, print the mean loss after each epoch.\n",
        "        epochs: number of passes over `train_loader` (default 1: a single pass).\n",
        "\n",
        "    Returns:\n",
        "        (we, skip, Q_list, M_list, b_list, v_list): the final weights as a NumPy array of shape\n",
        "        (d, r), the final skip weight as a float, and the order-parameter histories (one entry\n",
        "        at initialisation plus one per SGD step).\n",
        "    \"\"\"\n",
        "    ae = AE(d, r=r, informed=informed).to(device)\n",
        "\n",
        "    # Per-parameter learning rates and weight decay, scaled with the dimension d.\n",
        "    optimizer = torch.optim.SGD([\n",
        "        {\"params\": [ae.we], \"weight_decay\": .1/d, \"lr\": .5},\n",
        "        {\"params\": [ae.skip], \"weight_decay\": 0., \"lr\": .5/d**2},\n",
        "        {\"params\": [ae.vt], \"weight_decay\": 0., \"lr\": .5/d},\n",
        "    ])\n",
        "\n",
        "    Q_list, M_list, b_list, v_list = [], [], [], []\n",
        "    _record_order_parameters(ae, Q_list, M_list, b_list, v_list)\n",
        "\n",
        "    for epoch in range(epochs):\n",
        "        loss_sum, n_steps = 0., 0\n",
        "        for x, y in train_loader:\n",
        "            loss = quadloss(ae(x), y)\n",
        "            optimizer.zero_grad()\n",
        "            loss.backward()\n",
        "            optimizer.step()\n",
        "            _record_order_parameters(ae, Q_list, M_list, b_list, v_list)\n",
        "            if verbose:\n",
        "                loss_sum += loss.item()\n",
        "                n_steps += 1\n",
        "        if verbose:\n",
        "            print(f\"epoch {epoch}: mean loss {loss_sum / max(n_steps, 1):.4g}\")\n",
        "\n",
        "    return ae.we.detach().cpu().numpy(), float(ae.skip.detach().cpu()), Q_list, M_list, b_list, v_list"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 141,
      "metadata": {
        "id": "c7jcKBVBcPhe"
      },
      "outputs": [],
      "source": [
        "samples = 7000   # number of training samples\n",
        "r = 2            # number of hidden units (read as a global by `train`)\n",
        "training_set = generate_data(samples)\n",
        "train_loader = DataLoader(training_set, batch_size=1, shuffle=True)\n",
        "we, skip, Q, M, b, v = train(train_loader, informed=False)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "6vdcPBEHsgVI"
      },
      "outputs": [],
      "source": [
        "# Save the order-parameter histories returned by `train`\n",
        "# (one entry at initialisation plus one per SGD step).\n",
        "for name, history in {\"M\": M, \"Q\": Q, \"b\": b, \"v\": v}.items():\n",
        "    np.save(f\"Cean_Simu_{name}.npy\", history)"
      ]
    },
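    {
      "cell_type": "markdown",
      "metadata": {
        "id": "zgt_ykDQoobm"
      },
      "source": [
        "# Sampling\n",
        "\n",
        "Generate new points by integrating the learned velocity field with an explicit Euler scheme, starting from Gaussian noise at $t=0$ and stopping at $t=0.9$."
      ]
    },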
    {
      "cell_type": "code",
      "execution_count": 151,
      "metadata": {
        "id": "QdlNYbo8sCHH"
      },
      "outputs": [],
      "source": [
        "T_gen = 50           # number of Euler steps covering t in [0, 1]\n",
        "dt = 1/T_gen\n",
        "n_gen = 5000         # number of generated samples (a separate name, so the data-pool size N is kept)\n",
        "eps = 1e-5           # regularises the 1/(1-t) factors near t = 1\n",
        "v_final = v[-1].numpy()   # time-encoding weights at the end of training\n",
        "X = np.random.randn(n_gen, d)\n",
        "\n",
        "# Euler steps from t = 0 to t = 0.9. The denoiser mirrors AE.forward with time encoding cos(pi t):\n",
        "#   D(x, t) = skip * x + tanh(x we / sqrt(d) + cos(pi t) v) we^T / sqrt(d)\n",
        "# and the velocity is (1 + t/(1-t)) * D(x, t) - x/(1-t), with eps added to 1-t.\n",
        "for k in range(int(.9*T_gen)):\n",
        "  t = k/T_gen\n",
        "  denoised = skip*X + np.tanh((X@we/np.sqrt(d)) + np.cos(np.pi*t)*v_final)@we.T/np.sqrt(d)\n",
        "  velocity = (1 + t/(1-t+eps))*denoised - 1/(1-t+eps)*X\n",
        "  X = X + velocity*dt"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 158,
      "metadata": {
        "id": "9VkKTMyGaJNF"
      },
      "outputs": [],
      "source": [
        "# Keep the first two coordinates: the cluster means only have non-zero entries there.\n",
        "Samples = X[:, :2]\n"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Save the two-coordinate projection of the generated samples.\n",
        "np.save(\"Cean_Simu_Samples.npy\", Samples)"
      ],
      "metadata": {
        "id": "GX4B9rzznqOK"
      },
      "execution_count": 156,
      "outputs": []
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "provenance": []
    },
    "gpuClass": "standard",
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}