{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "d6d33d61",
   "metadata": {},
   "source": [
    "### Demonstrates the performance of TT-Cross versus a neural network (NN) for function approximation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1ddacdf5-a59b-48b1-8431-57c9032fb439",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from tt_utils import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ff5e3456-9ee6-4705-a749-de3abd085edc",
   "metadata": {},
   "outputs": [],
   "source": [
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "65330fb7-0821-4cd6-94dd-b49fef8c643f",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def gmm(n=2,nmix=3,L=1,mx_coef=None,mu=None,s=0.1, device='cpu'):\n",
    "    \"\"\"\n",
    "        Mixture of spherical Gaussians (un-normalized)\n",
    "        nmix: number of mixture coefficients\n",
    "        n: dimension of the domain\n",
    "        s: variance\n",
    "        mu: the centers assumed to be in : [-L,L]^n\n",
    "    \"\"\"\n",
    "    n_sqrt = torch.sqrt(torch.tensor([n]).to(device))\n",
    "    if mx_coef is None: # if centers and mixture coef are not given, generate them randomly\n",
    "        mx_coef = torch.rand(nmix).to(device)\n",
    "        mx_coef = mx_coef/torch.sum(mx_coef)\n",
    "        mu = (torch.rand(nmix,n).to(device)-0.5)*2*L\n",
    "\n",
    "    def pdf(x):\n",
    "        result = torch.tensor([0]).to(device)\n",
    "        for k in range(nmix):\n",
    "            l = torch.linalg.norm(mu[k]-x, dim=1)/n_sqrt\n",
    "            result = result + mx_coef[k]*torch.exp(-(l/s)**2)\n",
    "        return result\n",
    "\n",
    "    return pdf\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "29d66285-42f9-4bb4-bc0f-57d82e20bd5b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Problem setup: input dimension, half-width of the domain [-L, L]^dim,\n",
    "# number of mixture components and their length scale.\n",
    "dim, L, nmix, s = 10, 1, 1, 0.2\n",
    "\n",
    "# Target function to approximate; replace it with any other function of your choice.\n",
    "pdf = gmm(\n",
    "    n=dim,\n",
    "    nmix=nmix,\n",
    "    L=L,\n",
    "    mx_coef=None,\n",
    "    mu=None,\n",
    "    s=s,\n",
    "    device=device,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d4e7ae0b",
   "metadata": {},
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "ee975647",
   "metadata": {},
   "source": [
    "##### Find TT approximation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4b86ba96-9280-4706-9f55-820ea55ce398",
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "\n",
    "# Compute the TT approximation.\n",
    "# Each of the `dim` axes of [-L, L] is discretized with 200 uniformly spaced points.\n",
    "n_discretization = torch.tensor([200 for _ in range(dim)]).to(device)\n",
    "domain = [torch.linspace(-L, L, n_points).to(device) for n_points in n_discretization]\n",
    "\n",
    "# Time the TT-Cross run so it can be compared with the NN training time.\n",
    "t1 = time.time()\n",
    "tt_gmm = cross_approximate(\n",
    "    fcn=pdf,\n",
    "    max_batch=10**6,\n",
    "    domain=domain,\n",
    "    rmax=200,\n",
    "    nswp=20,\n",
    "    eps=1e-3,\n",
    "    verbose=True,\n",
    "    kickrank=3,\n",
    "    device=device,\n",
    ")\n",
    "t2 = time.time()\n",
    "print(\"time taken: \", t2-t1)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b5c05dfe",
   "metadata": {},
   "source": [
    "#### Prepare test set and train set for NN\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7239100c-5790-43f7-bbc3-bf6a3aeedb35",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Number of samples used to train the NN and to test both models.\n",
    "ndata_train = 218400\n",
    "ndata_test = 1000\n",
    "\n",
    "# Inputs are drawn uniformly from [-L, L]^dim; targets are evaluations of pdf.\n",
    "x_train = (torch.rand((ndata_train, dim)).to(device) - 0.5) * (2*L)\n",
    "y_train = pdf(x_train)\n",
    "\n",
    "x_test = (torch.rand((ndata_test, dim)).to(device) - 0.5) * (2*L)\n",
    "y_test = pdf(x_test)\n",
    "\n",
    "# Pack each set into one tensor: inputs in the first `dim` columns, target in the last.\n",
    "data_train = torch.cat((x_train.reshape(-1, dim), y_train.reshape(-1, 1)), dim=1)\n",
    "data_test = torch.cat((x_test.reshape(-1, dim), y_test.reshape(-1, 1)), dim=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bd26218e-27c2-4c03-9e52-25f2c0640edd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate the TT model on the test inputs.\n",
    "y_tt = get_value(\n",
    "    tt_model=tt_gmm,\n",
    "    x=x_test.to(device),\n",
    "    domain=domain,\n",
    "    n_discretization=n_discretization,\n",
    "    max_batch=10**5,\n",
    "    device=device,\n",
    ")\n",
    "\n",
    "# Mean squared relative error; the 1e-9 offset keeps the ratio finite where the target is zero.\n",
    "mse_tt = torch.mean(((y_tt.view(-1) - y_test.view(-1)) / (y_test.view(-1).abs() + 1e-9)) ** 2)\n",
    "print(\"mse_tt: \", mse_tt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cfdaac36-cabf-4505-91a8-7ab4617903c4",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "c3e10c0c",
   "metadata": {},
   "source": [
    "##### Fit an NN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6817e869-498f-4627-8ad5-4876c271bf01",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "class NeuralNetwork(nn.Module):\n",
    "    def __init__(self, dim=2, width=32):\n",
    "        super(NeuralNetwork, self).__init__()\n",
    "        self.flatten = nn.Flatten()\n",
    "        self.linear_relu_stack = nn.Sequential(\n",
    "            nn.Linear(dim, width),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(width, width),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(width, 1),\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.flatten(x)\n",
    "        logits = self.linear_relu_stack(x)\n",
    "        return logits\n",
    "\n",
    "# The hidden width scales with the input dimension and the number of mixture components.\n",
    "model = NeuralNetwork(dim=dim, width=10*dim*nmix).to(device)\n",
    "def train_loop(data, model, loss_fn, optimizer, batch_size):\n",
    "    size = data.shape[0]\n",
    "    counter = 0\n",
    "    for i in range(int(size/batch_size)-1):\n",
    "        # Compute prediction and loss\n",
    "        next_counter = (counter+batch_size)\n",
    "        x_data = data[counter:next_counter,:-1]\n",
    "        y_data = data[counter:next_counter,-1].view(-1,1)\n",
    "        y_pred = model(x_data)\n",
    "        loss = loss_fn(y_pred, y_data)\n",
    "\n",
    "        # Backpropagation\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        counter = 1*next_counter\n",
    "\n",
    "        if (i % int(0.25*size/batch_size)) == 0 :\n",
    "            loss = loss.item()\n",
    "            print(f\"loss: {loss:>7f}\")\n",
    "\n",
    "\n",
    "def test_loop(data, model, loss_fn):\n",
    "    x_data = data[:,:-1]\n",
    "    y_data = data[:,-1]\n",
    "    with torch.no_grad():\n",
    "        pred = model(x_data)\n",
    "        test_loss = loss_fn(pred, y_data).item()\n",
    "    print(f\"Test Error: \", test_loss)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "66181567-4394-4706-9c8f-9fb526bd7aab",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "38b5456f",
   "metadata": {},
   "source": [
    "##### Train NN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "88f91213-495d-4c42-bd26-77a548caf7f4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hyperparameters\n",
    "learning_rate = 1e-3\n",
    "batch_size = 100\n",
    "epochs = 10\n",
    "\n",
    "# Baseline: relative MSE of the untrained network on the test set.\n",
    "# Evaluation needs no gradients, so do not build an autograd graph for it.\n",
    "with torch.no_grad():\n",
    "    y_nn_0 = model(x_test)\n",
    "mse_nn_0 = (((y_nn_0.view(-1)-y_test.view(-1))/(1e-9+y_test.view(-1).abs()))**2).mean()\n",
    "print(\"mse_nn_0: \", mse_nn_0)\n",
    "\n",
    "# Train NN\n",
    "loss_fn = nn.MSELoss()\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)\n",
    "t1 = time.time()\n",
    "for t in range(epochs):\n",
    "    print(f\"Epoch {t+1}\\n-------------------------------\")\n",
    "    train_loop(data_train, model, loss_fn, optimizer, batch_size)\n",
    "    test_loop(data_test, model, loss_fn)\n",
    "t2 = time.time()\n",
    "print(\"time taken: \", t2-t1)\n",
    "print(\"Done!\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7539baea-ecda-4204-ad3b-58b246262033",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Relative MSE of the trained network on the test set.\n",
    "# Inference only: disable gradient tracking so the metric is a plain tensor.\n",
    "with torch.no_grad():\n",
    "    y_nn = model(x_test)\n",
    "mse_nn = (((y_nn.view(-1)-y_test.view(-1))/(1e-9+y_test.view(-1).abs()))**2).mean()\n",
    "print(\"NN Relative MSE: \", mse_nn)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c542ebc4-ce13-4403-a75b-3a43300add3d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Relative MSE of the TT model on the test set.\n",
    "mse_tt = torch.mean(((y_tt.view(-1) - y_test.view(-1)) / (y_test.view(-1).abs() + 1e-9)) ** 2)\n",
    "print(\"TT Relative MSE: \", mse_tt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d815d40d-ca43-4732-a534-98b7a6a9fe96",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Absolute MSE of the trained network on the test set.\n",
    "# Flatten both tensors before subtracting: the network output has shape (N, 1)\n",
    "# while y_test has shape (N,), so a direct difference would broadcast to an\n",
    "# (N, N) matrix and report the wrong error.\n",
    "with torch.no_grad():\n",
    "    y_nn = model(x_test)\n",
    "mse_nn = ((y_nn.view(-1)-y_test.view(-1))**2).mean()\n",
    "print(\"NN Absolute MSE: \", mse_nn)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cf370189-7146-431a-a7e7-1b38ed504902",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Absolute MSE of the TT model on the test set.\n",
    "mse_tt = torch.mean((y_tt.view(-1) - y_test.view(-1)) ** 2)\n",
    "print(\"TT Absolute MSE: \", mse_tt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "75724813-e417-49aa-9510-2a9042977193",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f3c30acd-efeb-4493-a41d-b9c3f1e1a61e",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  },
  "vscode": {
   "interpreter": {
    "hash": "e7370f93d1d0cde622a1f8e1c04877d8463912d04d973331ad4851f04de6915a"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
