{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "import os\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import torch\n",
    "from cycler import cycler\n",
    "from ema_pytorch import EMA\n",
    "from matplotlib import rc\n",
    "from matplotlib.ticker import ScalarFormatter\n",
    "\n",
    "# Typeset all figure text with LaTeX using an 11pt Computer Modern serif font.\n",
    "# NOTE: text.usetex requires a working LaTeX installation on this machine.\n",
    "rc('font', family='serif', serif='Computer Modern Roman', size=11)\n",
    "rc('text', usetex=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "def return_gaussians(size, means=((0.0, -0.5), (0.0, 0.5)), std=0.01):\n",
    "    \"\"\"Sample ``size`` 2-D points from an equal-weight mixture of isotropic Gaussians.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    size : int\n",
    "        Total number of points returned. When ``size`` is not a multiple of\n",
    "        ``len(means)``, the first ``size % len(means)`` modes get one extra point,\n",
    "        so exactly ``size`` points are always returned.\n",
    "    means : sequence of (float, float), default=((0, -0.5), (0, 0.5))\n",
    "        Centre of each mode.\n",
    "    std : float, default=0.01\n",
    "        Per-axis standard deviation of every mode (it scales standard normal\n",
    "        noise, so it is a standard deviation, not a variance).\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    torch.Tensor\n",
    "        float32 tensor of shape ``(size, 2)``. Rows are grouped by mode in the\n",
    "        order of ``means``; shuffle downstream if order matters.\n",
    "    \"\"\"\n",
    "    base, extra = divmod(size, len(means))\n",
    "    chunks = []\n",
    "    for k, (mean_x, mean_y) in enumerate(means):\n",
    "        n_k = base + (1 if k < extra else 0)\n",
    "        center = torch.tensor([mean_x, mean_y], dtype=torch.float32)\n",
    "        chunks.append(torch.randn((n_k, 2)) * std + center)\n",
    "    return torch.cat(chunks, dim=0)\n",
    "\n",
    "def return_input_distrib(size):\n",
    "    \"\"\"Draw ``size`` samples of the 2-D source (noise) distribution.\n",
    "\n",
    "    The distribution is a single Gaussian centred at the origin with unit\n",
    "    per-axis scale, i.e. a standard normal in two dimensions.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    torch.Tensor\n",
    "        float32 tensor of shape ``(size, 2)``.\n",
    "    \"\"\"\n",
    "    centres = [(0, 0)]\n",
    "    scales = (1., 1.)\n",
    "    per_mode = size // len(centres)\n",
    "    pieces = []\n",
    "    for cx, cy in centres:\n",
    "        noise = torch.randn((per_mode, 2))\n",
    "        offset = torch.tensor([cx, cy], dtype=noise.dtype)\n",
    "        scale = torch.tensor(scales, dtype=noise.dtype)\n",
    "        pieces.append(offset + noise * scale)\n",
    "    return torch.cat(pieces, dim=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "class PositionalEmbedding(torch.nn.Module):\n",
    "    \"\"\"Sinusoidal embedding of a 1-D tensor of scalars (here: the noise conditioning).\n",
    "\n",
    "    Each scalar ``x`` becomes ``[cos(x f_0), ..., cos(x f_{h-1}), sin(x f_0), ..., sin(x f_{h-1})]``\n",
    "    with ``h = num_channels // 2`` and ``f_k = (1 / max_positions) ** (k / denom)``, i.e. the\n",
    "    frequencies decay geometrically from 1 towards ``1 / max_positions``.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    num_channels : int\n",
    "        Output width. Should be even; an odd value yields ``num_channels - 1`` channels.\n",
    "    max_positions : float, default=10000\n",
    "        Sets the lowest frequency.\n",
    "    endpoint : bool, default=False\n",
    "        If True, ``denom = h - 1`` so the last frequency is exactly ``1 / max_positions``;\n",
    "        otherwise ``denom = h``.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, num_channels, max_positions=10000, endpoint=False):\n",
    "        super().__init__()\n",
    "        self.num_channels = num_channels\n",
    "        self.max_positions = max_positions\n",
    "        self.endpoint = endpoint\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Embed a 1-D tensor ``x`` of length N into an ``(N, 2 * (num_channels // 2))`` tensor.\"\"\"\n",
    "        half = self.num_channels // 2\n",
    "        freqs = torch.arange(start=0, end=half, dtype=torch.float32, device=x.device)\n",
    "        # max(..., 1) avoids 0/0 = NaN frequencies when half == 1 and endpoint=True.\n",
    "        denom = max(half - (1 if self.endpoint else 0), 1)\n",
    "        freqs = (1 / self.max_positions) ** (freqs / denom)\n",
    "        # torch.outer replaces the deprecated Tensor.ger alias (same result).\n",
    "        x = torch.outer(x, freqs.to(x.dtype))\n",
    "        return torch.cat([x.cos(), x.sin()], dim=1)\n",
    "\n",
    "class MLP(torch.nn.Module):\n",
    "    \"\"\"Noise-conditioned MLP mapping ``(x, t)`` to a vector of size ``data_dim``.\n",
    "\n",
    "    ``x`` is lifted to ``hidden_dim`` features by ``net_0``; ``t`` is embedded by a\n",
    "    sinusoidal ``PositionalEmbedding(hidden_dim)``. The two are concatenated and fed\n",
    "    to ``net_1`` (two GELU hidden layers, then a linear projection to ``data_dim``).\n",
    "    ``hidden_dim`` is assumed even so that the concatenation has ``2 * hidden_dim`` features.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, data_dim=2, hidden_dim=256):\n",
    "        super().__init__()\n",
    "        nn = torch.nn\n",
    "        self.map_noise = PositionalEmbedding(hidden_dim)\n",
    "        self.net_0 = nn.Sequential(\n",
    "            nn.Linear(data_dim, hidden_dim),\n",
    "            nn.GELU(),\n",
    "        )\n",
    "        self.net_1 = nn.Sequential(\n",
    "            nn.Linear(2 * hidden_dim, hidden_dim),\n",
    "            nn.GELU(),\n",
    "            nn.Linear(hidden_dim, hidden_dim),\n",
    "            nn.GELU(),\n",
    "            nn.Linear(hidden_dim, data_dim),\n",
    "        )\n",
    "\n",
    "    def forward(self, x, t):\n",
    "        \"\"\"Return ``net_1([net_0(x), embed(t)])`` for ``x`` of shape (N, data_dim) and 1-D ``t`` of length N.\"\"\"\n",
    "        noise_features = self.map_noise(t)\n",
    "        data_features = self.net_0(x)\n",
    "        joint = torch.cat((data_features, noise_features), dim=1)\n",
    "        return self.net_1(joint)\n",
    "\n",
    "class ConsistencyModel(torch.nn.Module):\n",
    "    \"\"\"Consistency function ``D(x, sigma)`` built from an MLP with skip/output/input scalings.\n",
    "\n",
    "    ``D(x, sigma) = c_skip(sigma) * x + c_out(sigma) * F(c_in(sigma) * x, log(sigma) / 4)``,\n",
    "    where ``F`` is the MLP. Because ``c_skip(sigma_min) = 1`` and ``c_out(sigma_min) = 0``,\n",
    "    ``D(x, sigma_min) = x`` holds by construction.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    sigma_min : float\n",
    "        Smallest noise level (the boundary point above).\n",
    "    sigma_max : float\n",
    "        Largest noise level; stored on the module but not used by ``forward``.\n",
    "    sigma_data : float, default=0.5\n",
    "        Data scale entering all three coefficients.\n",
    "    hidden_dim : int, default=256\n",
    "        Width of the underlying MLP.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, sigma_min, sigma_max, sigma_data=0.5, hidden_dim=256):\n",
    "        super().__init__()\n",
    "        self.sigma_min = sigma_min\n",
    "        self.sigma_max = sigma_max\n",
    "        self.sigma_data = sigma_data\n",
    "        self.model = MLP(hidden_dim=hidden_dim)\n",
    "\n",
    "    def forward(self, x, sigma):\n",
    "        \"\"\"Evaluate ``D(x, sigma)`` for ``x`` of shape (N, 2) and a 1-D ``sigma`` of length N.\"\"\"\n",
    "        s = sigma.unsqueeze(1)  # (N, 1) so the coefficients broadcast against x\n",
    "        sd = self.sigma_data\n",
    "        shifted = s - self.sigma_min\n",
    "        norm_sq = sd ** 2 + s ** 2\n",
    "        c_skip = sd ** 2 / (shifted ** 2 + sd ** 2)\n",
    "        c_out = (sd * shifted) / norm_sq ** 0.5\n",
    "        c_in = 1 / norm_sq.sqrt()\n",
    "        c_noise = (s.log() / 4).flatten()\n",
    "        net_out = self.model(c_in * x, c_noise).to(torch.float32)\n",
    "        return c_skip * x + c_out * net_out\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_sigmas_karras(num_timesteps, sigma_min, sigma_max, rho=7.0, device=\"cpu\"):\n",
    "    \"\"\"Constructs the noise schedule of Karras et al. (2022).\n",
    "\n",
    "    ``sigma_k = (sigma_min^(1/rho) + k / (N - 1) * (sigma_max^(1/rho) - sigma_min^(1/rho)))^rho``\n",
    "    for ``k = 0, ..., N - 1``.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    num_timesteps : int\n",
    "        Number of noise levels ``N``. Integral floats such as ``11.0`` are accepted.\n",
    "    sigma_min, sigma_max : float\n",
    "        First and last noise level.\n",
    "    rho : float, default=7.0\n",
    "        Schedule curvature; larger values pack more levels near ``sigma_min``.\n",
    "    device : str or torch.device, default=\"cpu\"\n",
    "        Device of the returned tensor.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    torch.Tensor\n",
    "        1-D tensor of length ``N`` in ascending order (``sigma_min`` first).\n",
    "    \"\"\"\n",
    "    num_timesteps = int(num_timesteps)\n",
    "    rho_inv = 1.0 / rho\n",
    "    min_inv_rho = sigma_min ** rho_inv\n",
    "    max_inv_rho = sigma_max ** rho_inv\n",
    "    # max(N - 1, 1) avoids dividing by zero when N == 1 (the single level is sigma_min).\n",
    "    steps = torch.arange(num_timesteps, device=device) / max(num_timesteps - 1, 1)\n",
    "    sigmas = (min_inv_rho + steps * (max_inv_rho - min_inv_rho)) ** rho\n",
    "    return sigmas\n",
    "\n",
    "def improved_timesteps_schedule(current_training_step, total_training_steps, initial_timesteps = 10, final_timesteps = 1280):\n",
    "    \"\"\"Implements the improved timestep discretization schedule.\n",
    "\n",
    "    ``N(k) = min(s0 * 2 ** floor(k / K'), s1) + 1`` with\n",
    "    ``K' = floor(K / (log2(floor(s1 / s0)) + 1))``: the number of discretization\n",
    "    intervals starts at ``s0`` and doubles every ``K'`` steps until it reaches ``s1``.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    current_training_step : int\n",
    "        Current step ``k`` in the training loop (non-negative).\n",
    "    total_training_steps : int\n",
    "        Total number of steps ``K`` the model will be trained for.\n",
    "    initial_timesteps : int, default=10\n",
    "        ``s0``, number of intervals at the start of training.\n",
    "    final_timesteps : int, default=1280\n",
    "        ``s1``, cap on the number of intervals.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    int\n",
    "        Number of noise levels (intervals + 1) at the current point in training,\n",
    "        suitable as ``num_timesteps`` for ``get_sigmas_karras``.\n",
    "\n",
    "    References\n",
    "    ----------\n",
    "    [1] [Improved Techniques For Consistency Training](https://arxiv.org/pdf/2310.14189.pdf)\n",
    "    \"\"\"\n",
    "    total_training_steps_prime = math.floor(\n",
    "        total_training_steps\n",
    "        / (math.log2(math.floor(final_timesteps / initial_timesteps)) + 1)\n",
    "    )\n",
    "    # Very short runs would give K' == 0 and a ZeroDivisionError below.\n",
    "    total_training_steps_prime = max(total_training_steps_prime, 1)\n",
    "    # Integer arithmetic so the documented int is actually returned.\n",
    "    doublings = current_training_step // total_training_steps_prime\n",
    "    num_timesteps = min(initial_timesteps * 2 ** doublings, final_timesteps)\n",
    "\n",
    "    return int(num_timesteps) + 1\n",
    "\n",
    "def lognormal_timestep_distribution(num_samples, sigmas, mean = -1.1, std = 2.0):\n",
    "    \"\"\"Draws timesteps from a lognormal distribution.\n",
    "\n",
    "    The weight of index ``i`` is proportional to the probability that a\n",
    "    log-normal variable (``log(sigma) ~ N(mean, std**2)``) falls in\n",
    "    ``[sigmas[i], sigmas[i + 1]]``, computed as a difference of ``erf`` values.\n",
    "    The weights are left unnormalised; ``torch.multinomial`` accepts that.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    num_samples : int\n",
    "        Number of indices to draw (with replacement).\n",
    "    sigmas : Tensor\n",
    "        Noise levels; must be positive (for the log) and ascending (so every\n",
    "        weight is non-negative), as produced by ``get_sigmas_karras``.\n",
    "    mean : float, default=-1.1\n",
    "        Mean of ``log(sigma)``.\n",
    "    std : float, default=2.0\n",
    "        Standard deviation of ``log(sigma)``.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    Tensor\n",
    "        ``num_samples`` integer indices in ``[0, len(sigmas) - 2]``; index ``i``\n",
    "        selects the adjacent pair ``(sigmas[i], sigmas[i + 1])``.\n",
    "\n",
    "    References\n",
    "    ----------\n",
    "    [1] [Improved Techniques For Consistency Training](https://arxiv.org/pdf/2310.14189.pdf)\n",
    "    \"\"\"\n",
    "    scale = std * math.sqrt(2)\n",
    "    upper = torch.erf((torch.log(sigmas[1:]) - mean) / scale)\n",
    "    lower = torch.erf((torch.log(sigmas[:-1]) - mean) / scale)\n",
    "    weights = upper - lower\n",
    "    return torch.multinomial(weights, num_samples, replacement=True)\n",
    "\n",
    "def improved_loss_weighting(sigmas):\n",
    "    \"\"\"Computes the weighting for the consistency loss.\n",
    "\n",
    "    Entry ``i`` is ``1 / (sigmas[i + 1] - sigmas[i])``, so pairs of closely spaced\n",
    "    noise levels receive larger weights.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    sigmas : Tensor\n",
    "        Noise levels, 1-D, length ``N``.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    Tensor\n",
    "        Length ``N - 1`` weights, one per adjacent pair of noise levels.\n",
    "\n",
    "    References\n",
    "    ----------\n",
    "    [1] [Improved Techniques For Consistency Training](https://arxiv.org/pdf/2310.14189.pdf)\n",
    "    \"\"\"\n",
    "    gaps = sigmas[1:] - sigmas[:-1]\n",
    "    return 1 / gaps\n",
    "\n",
    "def _transport_costs(endpoints, batch_z, steps, n_levels):\n",
    "    \"\"\"Return per-sample costs ``||endpoints_j - batch_z_j||^2`` and their mean per index in ``steps``.\n",
    "\n",
    "    Both results are numpy arrays: the per-sample costs (one per row) and, for each of\n",
    "    the ``n_levels`` indices, the mean cost of the samples carrying that index\n",
    "    (0 where no sample carries it).\n",
    "    \"\"\"\n",
    "    cost = ((endpoints - batch_z) ** 2).sum(dim=1).detach().float().cpu()\n",
    "    idx = steps.detach().cpu().long()\n",
    "    total = torch.zeros(n_levels).index_add_(0, idx, cost)\n",
    "    count = torch.bincount(idx, minlength=n_levels).float()\n",
    "    # Empty indices have total == 0, so clamping their count to 1 keeps the mean at 0.\n",
    "    return cost.numpy(), (total / count.clamp(min=1)).numpy()\n",
    "\n",
    "\n",
    "def draw_transport_cost_per_timestep(batch_data_preds, batch_data, batch_z, sigmas_i, sigmas, steps, training_step):\n",
    "    \"\"\"Plot the quadratic transport cost of two couplings against the noise level.\n",
    "\n",
    "    For each sample j the cost ``||endpoint_j - batch_z_j||^2`` is computed with two\n",
    "    endpoints: ``batch_data_preds`` (legend label \"GC\"; the training loop passes the EMA\n",
    "    model's predictions here) and ``batch_data`` (legend label \"IC\"). Every cost is\n",
    "    scattered at x = ``sigmas_i[j]``; the lines join the mean cost per index ``steps[j]``,\n",
    "    drawn at x = ``sigmas[index]``.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    batch_data_preds, batch_data, batch_z : torch.Tensor\n",
    "        2-D tensors with one row per sample (any device).\n",
    "    sigmas_i : torch.Tensor\n",
    "        Noise level of each sample, 1-D.\n",
    "    sigmas : torch.Tensor\n",
    "        Full noise schedule; its length sets the number of indices averaged over.\n",
    "    steps : torch.Tensor\n",
    "        Integer index into ``sigmas`` for each sample.\n",
    "    training_step : int\n",
    "        Only used to name the output file ``viz/transport_cost_1to2_<training_step>.pdf``.\n",
    "    \"\"\"\n",
    "    color_std = 'red'\n",
    "    color_GI = 'dodgerblue'\n",
    "    n_levels = len(sigmas)\n",
    "    gc_cost, gc_mean = _transport_costs(batch_data_preds, batch_z, steps, n_levels)\n",
    "    ic_cost, ic_mean = _transport_costs(batch_data, batch_z, steps, n_levels)\n",
    "    x_points = sigmas_i.detach().cpu().numpy()\n",
    "    x_levels = sigmas.detach().cpu().numpy()\n",
    "\n",
    "    fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(4.5, 3.5))\n",
    "    # A single vectorised scatter per coupling keeps plotting fast for large batches.\n",
    "    ax.scatter(x_points, gc_cost, alpha=1., marker='x', color=color_GI, label=\"Pointwise cost for GC\")\n",
    "    ax.scatter(x_points, ic_cost, alpha=1., marker='+', color=color_std, label=\"Pointwise cost for IC\")\n",
    "    # The last schedule entry is not drawn: indices from lognormal_timestep_distribution\n",
    "    # stop at len(sigmas) - 2, so that entry never has samples.\n",
    "    ax.plot(x_levels[:-1], gc_mean[:-1], color=color_GI, label=\"Mean cost for GC\")\n",
    "    ax.plot(x_levels[:-1], ic_mean[:-1], color=color_std, label=\"Mean cost for IC\")\n",
    "\n",
    "    ax.grid(linestyle='--')\n",
    "    ax.set_axisbelow(True)\n",
    "    ax.tick_params(axis=\"x\", direction=\"in\")\n",
    "    ax.tick_params(axis=\"y\", direction=\"in\")\n",
    "    # The x-values are noise levels sigma, not integer timesteps.\n",
    "    ax.set_xlabel(r'Noise level $\\sigma$')\n",
    "    ax.set_ylabel(r'Mean quadratic transport cost')\n",
    "    ax.set_title(r'Gaussians 1m-2m')\n",
    "\n",
    "    for axis in (ax.xaxis, ax.yaxis):\n",
    "        formatter = ScalarFormatter()\n",
    "        formatter.set_scientific(False)\n",
    "        axis.set_major_formatter(formatter)\n",
    "\n",
    "    handles, legend_labels = ax.get_legend_handles_labels()\n",
    "    fig.legend(handles, legend_labels, ncol=1, bbox_to_anchor=(0.6, 0.9))\n",
    "    fig.tight_layout()\n",
    "    os.makedirs('viz', exist_ok=True)\n",
    "    fig.savefig('viz/transport_cost_1to2_' + str(training_step) + '.pdf', bbox_inches='tight')\n",
    "    plt.show()\n",
    "    # Free the figure; this function is called repeatedly from the training loop.\n",
    "    plt.close(fig)\n",
    "\n",
    "def draw_arrows(batch_z, batch_preds, color):\n",
    "    \"\"\"Draw, on the current pyplot axes, one arrow per row of ``batch_z`` pointing to the matching row of ``batch_preds``.\n",
    "\n",
    "    Only the first two columns are used as (x, y) coordinates.\n",
    "    \"\"\"\n",
    "    n_rows = batch_z.shape[0]\n",
    "    starts = batch_z[:, :2].cpu().numpy()\n",
    "    deltas = (batch_preds[:n_rows, :2] - batch_z[:, :2]).cpu().numpy()\n",
    "    arrow_style = dict(alpha=0.4, length_includes_head=True, facecolor=color,\n",
    "                       edgecolor=color, width=0.005)\n",
    "    for (x0, y0), (dx, dy) in zip(starts, deltas):\n",
    "        plt.arrow(x0, y0, dx, dy, **arrow_style)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "###### Configuration #########\n",
    "len_data = 10000\n",
    "batch_size = 256\n",
    "training_steps = 10000\n",
    "lr = 0.00005\n",
    "s0 = 10  # initial number of discretization intervals (improved_timesteps_schedule)\n",
    "s1 = 100  # final number of discretization intervals\n",
    "rho = 7\n",
    "sigma_min = 0.001\n",
    "sigma_max = 5\n",
    "hidden_dim = 256\n",
    "# If True, x_i / x_{i+1} are built around the EMA model's prediction instead of the data point.\n",
    "generator_induced_trajectory = False\n",
    "device = 'cuda:0' if torch.cuda.is_available() else 'cpu'\n",
    "print_freq = 1000\n",
    "seed = 0\n",
    "\n",
    "torch.manual_seed(seed)\n",
    "os.makedirs('viz', exist_ok=True)  # figures below are saved into viz/\n",
    "\n",
    "model = ConsistencyModel(sigma_min, sigma_max, hidden_dim=hidden_dim).to(device)\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=lr)\n",
    "model_ema = EMA(\n",
    "    model,\n",
    "    beta = 0.9999,              # exponential moving average factor\n",
    "    update_after_step = 0,      # only after this number of .update() calls will it start updating\n",
    "    update_every = 2,           # how often to actually update (every 2nd .update() call)\n",
    ")\n",
    "datapoints = return_gaussians(len_data)\n",
    "dataset = torch.utils.data.TensorDataset(datapoints)\n",
    "loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)\n",
    "\n",
    "\n",
    "def plot_generations(generated, data, training_step):\n",
    "    \"\"\"Scatter one-step samples against up to 32 data points and save viz/generations_<step>.pdf.\"\"\"\n",
    "    fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(4.5, 3.5))\n",
    "    ax.scatter(generated[:, 0].cpu().numpy(), generated[:, 1].cpu().numpy(), marker='>', alpha=0.4, label='generated', color='purple')\n",
    "    ax.scatter(data[:32, 0].cpu().numpy(), data[:32, 1].cpu().numpy(), marker='o', alpha=0.4, label='data', color='darkcyan')\n",
    "    ax.grid(linestyle='--')\n",
    "    ax.set_axisbelow(True)\n",
    "    ax.tick_params(axis=\"x\", direction=\"in\")\n",
    "    ax.tick_params(axis=\"y\", direction=\"in\")\n",
    "    ax.set_ylim([-1, 1])\n",
    "    ax.set_xlim([-1, 1])\n",
    "    ax.set_xlabel(r'$x$')\n",
    "    ax.set_ylabel(r'$y$')\n",
    "    for axis in (ax.xaxis, ax.yaxis):\n",
    "        formatter = ScalarFormatter()\n",
    "        formatter.set_scientific(False)\n",
    "        axis.set_major_formatter(formatter)\n",
    "    handles, legend_labels = ax.get_legend_handles_labels()\n",
    "    # Put the last-added artist ('data') first in the legend.\n",
    "    handles = [handles[-1]] + handles[:-1]\n",
    "    legend_labels = [legend_labels[-1]] + legend_labels[:-1]\n",
    "    fig.legend(handles, legend_labels, ncol=2, bbox_to_anchor=(0.9, 0.05), handletextpad=0.1)\n",
    "    ax.set_title(r'Gaussians 1m-2m')\n",
    "    fig.tight_layout()\n",
    "    fig.savefig('viz/generations_' + str(training_step) + '.pdf', bbox_inches='tight')\n",
    "    plt.show()\n",
    "    plt.close(fig)\n",
    "\n",
    "\n",
    "###### Training Loop #########\n",
    "current_training_step = 0\n",
    "while current_training_step < training_steps:\n",
    "    for batch in loader:\n",
    "        data = batch[0].to(device)\n",
    "        batch_z = return_input_distrib(data.shape[0]).to(device)\n",
    "        # The number of noise levels grows during training (improved_timesteps_schedule).\n",
    "        current_n_step = improved_timesteps_schedule(current_training_step, training_steps,\n",
    "                                        initial_timesteps = s0, final_timesteps = s1)\n",
    "        sigmas = get_sigmas_karras(current_n_step, sigma_min, sigma_max, rho=rho)\n",
    "        # steps[j] selects the adjacent pair (sigmas[steps[j]], sigmas[steps[j] + 1]).\n",
    "        steps = lognormal_timestep_distribution(len(data), sigmas)\n",
    "        loss_weights = improved_loss_weighting(sigmas)[steps].to(device)\n",
    "        sigmas_i = sigmas[steps].to(device)\n",
    "        sigmas_ip1 = sigmas[steps + 1].to(device)\n",
    "        batch_z_i = data + sigmas_i.view(-1, 1) * batch_z\n",
    "        batch_z_ip1 = data + sigmas_ip1.view(-1, 1) * batch_z\n",
    "        if generator_induced_trajectory:\n",
    "            with torch.no_grad():\n",
    "                data_pred = model_ema(batch_z_i, sigmas_i)\n",
    "            batch_z_i = data_pred + sigmas_i.view(-1, 1) * batch_z\n",
    "            batch_z_ip1 = data_pred + sigmas_ip1.view(-1, 1) * batch_z\n",
    "        optimizer.zero_grad()\n",
    "        # Target: the current model at the lower noise level, with gradients stopped.\n",
    "        with torch.no_grad():\n",
    "            pred_z_i = model(batch_z_i, sigmas_i)\n",
    "        # Prediction being trained: the same model at the higher noise level.\n",
    "        pred_z_ip1 = model(batch_z_ip1, sigmas_ip1)\n",
    "        loss = ((pred_z_ip1 - pred_z_i) ** 2).sum(dim=1)\n",
    "        loss = (loss_weights * loss).mean()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        model_ema.update()\n",
    "        if (current_training_step % print_freq) == 0:\n",
    "            print('step : ', current_training_step)\n",
    "            print('loss : ', loss.item())\n",
    "            with torch.no_grad():\n",
    "                # One-step generation: map noise scaled to sigma_max straight to a sample.\n",
    "                sigmas_max = torch.full_like(sigmas_i, sigma_max)\n",
    "                generated = model_ema(batch_z * sigma_max, sigmas_max)\n",
    "                if generator_induced_trajectory:\n",
    "                    draw_transport_cost_per_timestep(data_pred, data, batch_z, sigmas_i, sigmas, steps, current_training_step)\n",
    "            plot_generations(generated, data, current_training_step)\n",
    "\n",
    "        current_training_step += 1\n",
    "        if current_training_step == training_steps:\n",
    "            break"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "consistency",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
