{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "2c587a22",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import math"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b3af1d0c",
   "metadata": {},
   "outputs": [],
   "source": [
    "def multivariate_normal(z, mu, sigma, mq=None):\n",
    "    \"\"\"Log-density of a multivariate normal, evaluated through a Cholesky factor.\n",
    "\n",
    "    With ``sigma = L @ L.T`` the quadratic form equals ``||L^{-1} (z - mu)||^2``\n",
    "    and ``0.5 * log det(sigma)`` equals the sum of the logs of ``diag(L)``.\n",
    "\n",
    "    Args:\n",
    "        z: points to evaluate, shape [..., D].\n",
    "        mu: mean, broadcastable against ``z``.\n",
    "        sigma: positive-definite covariance, shape [..., D, D].\n",
    "        mq: optional mask over the last axis; ``mq.sum(-1)`` is used as the\n",
    "            dimension count in the ``log(2 * pi)`` term. When omitted, all\n",
    "            ``z.shape[-1]`` dimensions are counted.\n",
    "\n",
    "    Returns:\n",
    "        Log-probability with the last axis reduced away.\n",
    "\n",
    "    Note:\n",
    "        ``mq`` only changes the normalising constant. The quadratic form and the\n",
    "        log-determinant still run over every dimension, so masked-out dimensions\n",
    "        presumably need a zero residual and an identity covariance block to drop\n",
    "        out of the result - TODO confirm with callers.\n",
    "    \"\"\"\n",
    "    chol = torch.linalg.cholesky(sigma)              # [..., D, D], lower-triangular\n",
    "    resid = (z - mu).unsqueeze(-1)                   # [..., D, 1]\n",
    "    # Forward substitution for chol @ y = resid avoids forming sigma^{-1} explicitly\n",
    "    y = torch.linalg.solve_triangular(chol, resid, upper=False).squeeze(-1)\n",
    "    quad = y.pow(2).sum(-1)\n",
    "    # Sum of the log-diagonal of the factor is 0.5 * log det(sigma)\n",
    "    half_log_det = chol.diagonal(dim1=-2, dim2=-1).log().sum(-1)\n",
    "    n_dims = z.shape[-1] if mq is None else mq.sum(-1)\n",
    "    return -0.5 * (n_dims * math.log(2 * math.pi) + quad) - half_log_det"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "12f35d6b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Toy 3-D example; the third dimension is flagged off in `mask`.\n",
    "# Float literals keep `mean` in the same dtype as `x` and `cov` instead of int64.\n",
    "mean = torch.tensor([1.0, 1.0, 0.0])\n",
    "cov = torch.tensor([[1.0, 0.5, 0.0], [0.5, 1.0, 0.0], [0.0, 0.0, 1.0]])\n",
    "x = torch.tensor([0.2, 0.5, 0.0])\n",
    "mask = torch.tensor([True, True, False])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "5fe70bfe",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(-2.0207)\n"
     ]
    }
   ],
   "source": [
    "# Log-density of `x` under N(mean, cov), with `mask` passed as `mq`\n",
    "print(multivariate_normal(z=x, mu=mean, sigma=cov, mq=mask))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "296f023c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([6, 4]) torch.Size([6, 4, 4]) torch.Size([1000, 6, 4]) torch.Size([2, 3, 1000, 4])\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "# Example shapes\n",
    "B, D, K = 2, 3, 4\n",
    "num_samples = 1000\n",
    "\n",
    "# Random means and covariances for demonstration. The `batch_` prefix keeps these\n",
    "# from overwriting the 3-D `mean` / `cov` example that other cells read.\n",
    "batch_mean = torch.randn(B, D, K)\n",
    "A = torch.randn(B, D, K, K)\n",
    "# A @ A^T is positive semi-definite; the small diagonal jitter makes it positive definite\n",
    "batch_cov = A @ A.transpose(-1, -2) + 1e-5 * torch.eye(K)\n",
    "\n",
    "# Flatten B and D to create a batch of multivariate normals\n",
    "mean_flat = batch_mean.reshape(B * D, K)\n",
    "cov_flat = batch_cov.reshape(B * D, K, K)\n",
    "\n",
    "# Create distributions\n",
    "mvn = torch.distributions.MultivariateNormal(mean_flat, covariance_matrix=cov_flat)\n",
    "\n",
    "# Sample\n",
    "samples_flat = mvn.sample((num_samples,))  # shape [num_samples, B*D, K]\n",
    "\n",
    "# Move the batch axis first, then split it back into (B, D): [B, D, num_samples, K].\n",
    "# reshape (not view) because permute returns a non-contiguous tensor.\n",
    "samples = samples_flat.permute(1, 0, 2).reshape(B, D, num_samples, K)\n",
    "\n",
    "print(mean_flat.shape, cov_flat.shape, samples_flat.shape, samples.shape)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "1e43af7c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[ 1.3333, -0.6667,  0.0000],\n",
      "        [-0.6667,  1.3333,  0.0000],\n",
      "        [ 0.0000,  0.0000,  1.0000]]) tensor([-0.8000, -0.5000,  0.0000]) tensor([[-0.7333, -0.1333,  0.0000]]) tensor([0.6533])\n"
     ]
    }
   ],
   "source": [
    "# Quadratic form (x - mean)^T cov^{-1} (x - mean), built step by step with an explicit inverse\n",
    "inv_cov = torch.inverse(cov)\n",
    "err = x - mean\n",
    "p1 = err.unsqueeze(0) @ inv_cov   # residual as a row vector times the inverse covariance\n",
    "p2 = torch.matmul(p1, err)        # contract with the residual again\n",
    "print(inv_cov, err, p1, p2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2ef8ee1b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): this repeats the computation of the preceding cell - candidate for removal.\n",
    "# Residual and inverse covariance, then the quadratic form err^T cov^{-1} err.\n",
    "err = x - mean\n",
    "inv_cov = torch.inverse(cov)\n",
    "p1 = torch.matmul(err[None], inv_cov)\n",
    "p2 = torch.matmul(p1, err)\n",
    "print(inv_cov, err, p1, p2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "6f5cba16",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[[ 2.0025,  2.0634, -0.6963,  0.4484],\n",
      "         [-0.3921, -0.3099, -0.6954, -1.3219]],\n",
      "\n",
      "        [[-0.1683,  2.5200, -0.0513,  0.3020],\n",
      "         [-1.8195, -1.0635,  0.4651,  0.2344]],\n",
      "\n",
      "        [[-0.4483,  0.6659,  1.0655,  0.4459],\n",
      "         [-0.4508,  0.5724,  1.1171,  0.4814]]]) tensor([[[ 2.0025,  2.0634, -0.6963,  0.4484],\n",
      "         [-0.3921, -0.3099, -0.6954, -1.3219]],\n",
      "\n",
      "        [[-0.1683,  2.5200, -0.0513,  0.3020],\n",
      "         [-1.8195, -1.0635,  0.4651,  0.2344]],\n",
      "\n",
      "        [[-0.4483,  0.6659,  1.0655,  0.4459],\n",
      "         [-0.4508,  0.5724,  1.1171,  0.4814]]])\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "import torch.nn.functional as F\n",
    "\n",
    "# Round trip: merge the two leading axes with view, then split them again;\n",
    "# both tensors are printed for comparison.\n",
    "B, D, K = 3, 2, 4\n",
    "\n",
    "X = torch.randn(B, D, K)\n",
    "X_flattened = X.view(B * D, K)\n",
    "X_new = X_flattened.view(B, D, K)\n",
    "\n",
    "print(X, X_new)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "linodenet",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
