{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "41f2ac59ae9e0807",
   "metadata": {},
   "source": [
    "# Flow Matching Example\n",
    "\n",
    "In this notebook, we demonstrate how to create correlated noise latents and pass them to a flow matching model."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "39a02827-ccad-4b7f-be9f-f40fd82b079e",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "## Positive-Definiteness Requirement for a General Covariance Matrix and $N$ Images"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "63cbc715-35a1-41c6-ab63-d9c684dc2804",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "The general covariance matrix for $N$ images $x_1, \\dots, x_N \\in \\mathbb{R}^d$ and a scalar variable $\\theta$ takes the form:\n",
    "\n",
    "$$\n",
    "\\Sigma =\n",
    "\\begin{pmatrix}\n",
    "\\Sigma_1 & \\Sigma_{12} & \\cdots & \\Sigma_{1N} & \\Sigma_{1\\theta} \\\\\n",
    "\\Sigma_{12}^T & \\Sigma_2 & \\cdots & \\Sigma_{2N} & \\Sigma_{2\\theta} \\\\\n",
    "\\vdots & \\vdots & \\ddots & \\vdots & \\vdots \\\\\n",
    "\\Sigma_{1N}^T & \\Sigma_{2N}^T & \\cdots & \\Sigma_N & \\Sigma_{N\\theta} \\\\\n",
    "\\Sigma_{1\\theta}^T & \\Sigma_{2\\theta}^T & \\cdots & \\Sigma_{N\\theta}^T & \\sigma_\\theta^2\n",
    "\\end{pmatrix}\n",
    "=\n",
    "\\begin{pmatrix}\n",
    "v_1 I_d & \\rho_{12} I_d & \\cdots & \\rho_{1N} I_d & \\rho_{1\\theta} \\mathbf{1}_d \\\\\n",
    "\\rho_{12} I_d & v_2 I_d & \\cdots & \\rho_{2N} I_d & \\rho_{2\\theta} \\mathbf{1}_d \\\\\n",
    "\\vdots & \\vdots & \\ddots & \\vdots & \\vdots \\\\\n",
    "\\rho_{1N} I_d & \\rho_{2N} I_d & \\cdots & v_N I_d & \\rho_{N\\theta} \\mathbf{1}_d \\\\\n",
    "\\rho_{1\\theta} \\mathbf{1}_d^T & \\rho_{2\\theta} \\mathbf{1}_d^T & \\cdots & \\rho_{N\\theta} \\mathbf{1}_d^T & \\sigma_\\theta^2\n",
    "\\end{pmatrix}\n",
    "$$\n",
    "\n",
    "where $v_i$ is the (per-pixel) variance of image $x_i$, $\\rho_{ij}$ is the covariance between corresponding pixels of images $x_i$ and $x_j$, $\\rho_{i\\theta}$ is the covariance between each pixel of image $x_i$ and $\\theta$, $d$ is the dimensionality of each image (e.g. 100 for a single-channel 10x10 image), $N$ is the number of images, and $\\sigma_\\theta^2$ is the variance of $\\theta$.\n",
    "\n",
    "We partition $\\Sigma$ into blocks:\n",
    "\n",
    "$$\n",
    "\\Sigma =\n",
    "\\begin{pmatrix}\n",
    "A & B \\\\\n",
    "B^T & C\n",
    "\\end{pmatrix}\n",
    "$$\n",
    "\n",
    "where $C = \\sigma_\\theta^2$ is the variance of $\\theta$ and\n",
    "\n",
    "$$\n",
    "A =\n",
    "\\begin{pmatrix}\n",
    "v_1 I_d & \\rho_{12} I_d & \\rho_{13} I_d & \\cdots & \\rho_{1N} I_d \\\\\n",
    "\\rho_{12} I_d & v_2 I_d & \\rho_{23} I_d & \\cdots & \\rho_{2N} I_d \\\\\n",
    "\\rho_{13} I_d & \\rho_{23} I_d & v_3 I_d & \\cdots & \\rho_{3N} I_d \\\\\n",
    "\\vdots & \\vdots & \\vdots & \\ddots & \\vdots \\\\\n",
    "\\rho_{1N} I_d & \\rho_{2N} I_d & \\rho_{3N} I_d & \\cdots & v_N I_d\n",
    "\\end{pmatrix}\n",
    "$$\n",
    "\n",
    "\n",
    "$A = M \\otimes I_d \\in \\mathbb{R}^{Nd \\times Nd}$ is the image-image block, with $M \\in \\mathbb{R}^{N \\times N}$ defined as:\n",
    "$$\n",
    "M_{ij} =\n",
    "\\begin{cases}\n",
    "v_i & \\text{if } i = j \\\\\n",
    "\\rho_{ij} & \\text{if } i \\neq j\n",
    "\\end{cases}\n",
    "$$\n",
    "\n",
    "$B \\in \\mathbb{R}^{Nd \\times 1}$ is the image–$\\theta$ cross-covariance block:\n",
    "\n",
    "$$\n",
    "B =\n",
    "\\begin{pmatrix}\n",
    "\\rho_{1\\theta} \\cdot \\mathbf{1}_d \\\\\n",
    "\\rho_{2\\theta} \\cdot \\mathbf{1}_d \\\\\n",
    "\\vdots \\\\\n",
    "\\rho_{N\\theta} \\cdot \\mathbf{1}_d\n",
    "\\end{pmatrix}\n",
    "= \\boldsymbol{\\rho}_\\theta \\otimes \\mathbf{1}_d\n",
    "\\quad \\text{with} \\quad\n",
    "\\boldsymbol{\\rho}_\\theta =\n",
    "\\begin{pmatrix}\n",
    "\\rho_{1\\theta} \\\\\n",
    "\\rho_{2\\theta} \\\\\n",
    "\\vdots \\\\\n",
    "\\rho_{N\\theta}\n",
    "\\end{pmatrix}\n",
    "\\in \\mathbb{R}^{N \\times 1}\n",
    "$$\n",
    "\n",
    "According to [Gallier, 2019](https://www.cis.upenn.edu/~jean/schur-comp.pdf), for a symmetric block matrix $\\Sigma = \\begin{pmatrix} A & B \\\\ B^T & C \\end{pmatrix}$ the following properties hold:\n",
    "\n",
    "$$\n",
    "\\begin{split}\n",
    "& \\text{(1)} \\quad \\Sigma \\succ 0 \\quad \\text{iff} \\quad A \\succ 0 \\quad \\text{and} \\quad C - B^T A^{-1} B \\succ 0 \\\\\n",
    "& \\text{(2)} \\quad \\text{If} \\quad A \\succ 0 \\quad \\text{then} \\quad \\Sigma \\succeq 0 \\quad \\text{iff} \\quad C - B^T A^{-1} B \\succeq 0\n",
    "\\end{split}\n",
    "$$\n",
    "\n",
    "Since the eigenvalues of $A = M \\otimes I_d$ are those of $M$, each repeated $d$ times, $A \\succ 0$ if and only if $M \\succ 0$. Thus, given $M \\succ 0$, a necessary and sufficient condition for $\\Sigma$ to be positive definite is that the Schur complement\n",
    "\n",
    "$$\n",
    "C - B^T A^{-1} B\n",
    "$$\n",
    "\n",
    "is positive (here it is a scalar), i.e.,\n",
    "\n",
    "$$\n",
    "\\sigma_\\theta^2 > B^T A^{-1} B.\n",
    "$$\n",
    "\n",
    "Using Kronecker product identities:\n",
    "$$\n",
    "A^{-1} = M^{-1} \\otimes I_d\n",
    "$$\n",
    "\n",
    "$$\n",
    "B = \\boldsymbol{\\rho}_\\theta \\otimes \\mathbf{1}_d\n",
    "$$\n",
    "\n",
    "Using the mixed-product property $(P \\otimes Q)(R \\otimes T) = (PR) \\otimes (QT)$, we compute:\n",
    "\n",
    "$$\n",
    "\\begin{aligned}\n",
    "B^T A^{-1} B\n",
    "&= (\\boldsymbol{\\rho}_\\theta^T \\otimes \\mathbf{1}_d^T)(M^{-1} \\otimes I_d)(\\boldsymbol{\\rho}_\\theta \\otimes \\mathbf{1}_d) \\\\\n",
    "&= (\\boldsymbol{\\rho}_\\theta^T M^{-1} \\boldsymbol{\\rho}_\\theta)(\\mathbf{1}_d^T I_d \\mathbf{1}_d) \\\\\n",
    "&= d \\cdot \\boldsymbol{\\rho}_\\theta^T M^{-1} \\boldsymbol{\\rho}_\\theta\n",
    "\\end{aligned}\n",
    "$$\n",
    "\n",
    "Thus, provided $M \\succ 0$, the condition for positive definiteness becomes\n",
    "\n",
    "$$\n",
    "\\boxed{\n",
    "\\sigma_\\theta^2 > d \\cdot \\boldsymbol{\\rho}_\\theta^T M^{-1} \\boldsymbol{\\rho}_\\theta\n",
    "}\n",
    "$$\n",
    "\n",
    "Numerical issues such as rounding errors may result in eigenvalues that are very close to zero or slightly negative even if the theoretical condition holds. To mitigate this, in practice, a small $\\epsilon > 0$ is added to the diagonal of $\\Sigma$:\n",
    "\n",
    "$$\n",
    "\\Sigma \\leftarrow \\Sigma + \\epsilon I.\n",
    "$$\n",
    "\n",
    "This diagonal shift increases every eigenvalue of $\\Sigma$ by exactly $\\epsilon$. As long as $\\epsilon$ exceeds the magnitude of any rounding-induced negative eigenvalues, the matrix stays numerically positive definite, so the Cholesky factorization used for sampling from a multivariate Gaussian (and for solving linear systems with $\\Sigma$) succeeds. The cost is a small perturbation of the intended variances."
   ]
  },
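  {
   "cell_type": "markdown",
   "id": "schur-check-general-md",
   "metadata": {},
   "source": [
    "As a sanity check of the boxed condition, the following cell builds $\\Sigma$ densely for a small toy case and compares the bound $d \\cdot \\boldsymbol{\\rho}_\\theta^T M^{-1} \\boldsymbol{\\rho}_\\theta$ with the smallest eigenvalue of $\\Sigma$ just below and just above it. It also checks that adding $\\epsilon I$ shifts every eigenvalue by $\\epsilon$. The sizes ($N = 3$, $d = 4$), the random $M$ and $\\boldsymbol{\\rho}_\\theta$, and the helper `build_sigma` are illustrative assumptions, not part of the notebook's pipeline."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "schur-check-general-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Toy sizes (illustrative assumption): N = 3 images of dimension d = 4, small enough to build Sigma densely.\n",
    "rng = np.random.default_rng(0)\n",
    "N, d = 3, 4\n",
    "\n",
    "# Random positive definite image-image matrix M and random image-theta covariances.\n",
    "G = rng.normal(size=(N, N))\n",
    "M = G @ G.T + N * np.eye(N)\n",
    "rho_theta = rng.normal(size=N)\n",
    "\n",
    "\n",
    "def build_sigma(M, rho_theta, sigma_theta2, d):\n",
    "    \"\"\"Hypothetical helper: assemble [[M kron I_d, rho kron 1_d], [(rho kron 1_d)^T, sigma_theta^2]].\"\"\"\n",
    "    A = np.kron(M, np.eye(d))\n",
    "    B = np.kron(rho_theta, np.ones(d)).reshape(-1, 1)\n",
    "    return np.block([[A, B], [B.T, np.array([[sigma_theta2]])]])\n",
    "\n",
    "\n",
    "# Threshold from the boxed condition: sigma_theta^2 > d * rho^T M^{-1} rho.\n",
    "bound = d * rho_theta @ np.linalg.solve(M, rho_theta)\n",
    "\n",
    "for sigma_theta2 in (0.9 * bound, 1.1 * bound):\n",
    "    min_eig = np.linalg.eigvalsh(build_sigma(M, rho_theta, sigma_theta2, d)).min()\n",
    "    print(f'sigma_theta^2 = {sigma_theta2:.3f} (bound {bound:.3f}): smallest eigenvalue = {min_eig:.4f}')\n",
    "\n",
    "# Adding epsilon * I shifts every eigenvalue by exactly epsilon.\n",
    "eps = 1e-3\n",
    "Sigma_toy = build_sigma(M, rho_theta, 1.1 * bound, d)\n",
    "shift = np.linalg.eigvalsh(Sigma_toy + eps * np.eye(Sigma_toy.shape[0])) - np.linalg.eigvalsh(Sigma_toy)\n",
    "print('largest deviation of the eigenvalue shift from epsilon:', np.abs(shift - eps).max())\n",
    "\n",
    "# The Cholesky factorization succeeds because the stabilized matrix is positive definite.\n",
    "chol = np.linalg.cholesky(Sigma_toy + eps * np.eye(Sigma_toy.shape[0]))"
   ]
  },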
  {
   "cell_type": "markdown",
   "id": "b587ee66a211f9a2",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "## Positive-Definiteness Requirement for a General Covariance Matrix and 2 Images"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b4f122608960ecc8",
   "metadata": {},
   "source": [
    "Our covariance matrix for two images and a $\\theta$ variable reduces to:\n",
    "$$\n",
    "\\Sigma =\n",
    "\\begin{pmatrix}\n",
    "\\Sigma_1 & \\Sigma_{12} & \\Sigma_{1\\theta} \\\\\n",
    "\\Sigma_{12}^T & \\Sigma_2 & \\Sigma_{2\\theta} \\\\\n",
    "\\Sigma_{1\\theta}^T & \\Sigma_{2\\theta}^T & \\sigma_\\theta^2\n",
    "\\end{pmatrix}\n",
    "=\n",
    "\\begin{pmatrix}\n",
    "v_1 I_d & \\rho_{12} I_d & \\rho_{1\\theta}\\,\\mathbf{1}_d \\\\\n",
    "\\rho_{12} I_d & v_2 I_d & \\rho_{2\\theta}\\,\\mathbf{1}_d \\\\\n",
    "\\rho_{1\\theta}\\,\\mathbf{1}_d^T & \\rho_{2\\theta}\\,\\mathbf{1}_d^T & \\sigma_\\theta^2\n",
    "\\end{pmatrix}.\n",
    "$$\n",
    "\n",
    "We partition $\\Sigma$ as\n",
    "$$\n",
    "\\Sigma =\n",
    "\\begin{pmatrix}\n",
    "A & B \\\\\n",
    "B^T & C\n",
    "\\end{pmatrix},\n",
    "$$\n",
    "where\n",
    "$$\n",
    "A = \\begin{pmatrix}\n",
    "v_1 I_d & \\rho_{12} I_d \\\\\n",
    "\\rho_{12} I_d & v_2 I_d\n",
    "\\end{pmatrix}, \\quad\n",
    "B = \\begin{pmatrix}\n",
    "\\rho_{1\\theta}\\,\\mathbf{1}_d \\\\\n",
    "\\rho_{2\\theta}\\,\\mathbf{1}_d\n",
    "\\end{pmatrix}, \\quad\n",
    "C = \\sigma_\\theta^2.\n",
    "$$\n",
    "\n",
    "Provided $A \\succ 0$, the Schur-complement requirement from the general case becomes:\n",
    "\n",
    "$$\n",
    "\\sigma_\\theta^2 > B^T A^{-1} B.\n",
    "$$\n",
    "\n",
    "Note that the block $A$ can be written as a Kronecker product:\n",
    "$$\n",
    "A = M \\otimes I_d, \\quad \\text{with } M = \\begin{pmatrix} v_1 & \\rho_{12} \\\\ \\rho_{12} & v_2 \\end{pmatrix}.\n",
    "$$\n",
    "The determinant of $M$ is\n",
    "$$\n",
    "\\det(M) = v_1v_2 - \\rho_{12}^2.\n",
    "$$\n",
    "Assuming $\\det(M) \\neq 0$, the inverse of $M$ is\n",
    "$$\n",
    "M^{-1} = \\frac{1}{v_1v_2-\\rho_{12}^2}\\begin{pmatrix} v_2 & -\\rho_{12} \\\\ -\\rho_{12} & v_1 \\end{pmatrix},\n",
    "$$\n",
    "and hence,\n",
    "$$\n",
    "A^{-1} = M^{-1} \\otimes I_d = \\frac{1}{v_1v_2-\\rho_{12}^2}\n",
    "\\begin{pmatrix}\n",
    "v_2 I_d & -\\rho_{12} I_d \\\\\n",
    "-\\rho_{12} I_d & v_1 I_d\n",
    "\\end{pmatrix}.\n",
    "$$\n",
    "\n",
    "The cross-covariance block is given by\n",
    "$$\n",
    "B = \\begin{pmatrix} \\rho_{1\\theta}\\,\\mathbf{1}_d \\\\ \\rho_{2\\theta}\\,\\mathbf{1}_d \\end{pmatrix} \\quad \\text{and} \\quad B^T = \\begin{pmatrix} \\rho_{1\\theta}\\,\\mathbf{1}_d^T & \\rho_{2\\theta}\\,\\mathbf{1}_d^T \\end{pmatrix}.\n",
    "$$\n",
    "Then,\n",
    "$$\n",
    "A^{-1}B = \\frac{1}{v_1v_2-\\rho_{12}^2} \\begin{pmatrix}\n",
    "v_2 I_d & -\\rho_{12} I_d \\\\\n",
    "-\\rho_{12} I_d & v_1 I_d\n",
    "\\end{pmatrix}\n",
    "\\begin{pmatrix} \\rho_{1\\theta}\\,\\mathbf{1}_d \\\\ \\rho_{2\\theta}\\,\\mathbf{1}_d \\end{pmatrix}.\n",
    "$$\n",
    "Carrying out the block multiplication, we obtain\n",
    "$$\n",
    "A^{-1}B = \\frac{1}{v_1v_2-\\rho_{12}^2} \\begin{pmatrix}\n",
    "\\big(v_2\\,\\rho_{1\\theta} - \\rho_{12}\\,\\rho_{2\\theta}\\big)\\mathbf{1}_d \\\\\n",
    "\\big(v_1\\,\\rho_{2\\theta} - \\rho_{12}\\,\\rho_{1\\theta}\\big)\\mathbf{1}_d\n",
    "\\end{pmatrix}.\n",
    "$$\n",
    "Now,\n",
    "$$\n",
    "B^T A^{-1} B = \\frac{1}{v_1v_2-\\rho_{12}^2}\\left[\n",
    "\\rho_{1\\theta}\\,\\mathbf{1}_d^T\\big(v_2\\,\\rho_{1\\theta} - \\rho_{12}\\,\\rho_{2\\theta}\\big)\\mathbf{1}_d + \\rho_{2\\theta}\\,\\mathbf{1}_d^T\\big(v_1\\,\\rho_{2\\theta} - \\rho_{12}\\,\\rho_{1\\theta}\\big)\\mathbf{1}_d\n",
    "\\right].\n",
    "$$\n",
    "Since $\\mathbf{1}_d^T\\mathbf{1}_d = d$, this simplifies to\n",
    "$$\n",
    "B^T A^{-1} B = \\frac{d}{v_1v_2-\\rho_{12}^2}\\left[v_2\\,\\rho_{1\\theta}^2 + v_1\\,\\rho_{2\\theta}^2 - 2\\,\\rho_{12}\\,\\rho_{1\\theta}\\,\\rho_{2\\theta}\\right].\n",
    "$$\n",
    "\n",
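    "The Schur complement condition requires\n",
    "$$\n",
    "\\sigma_\\theta^2 > B^T A^{-1} B.\n",
    "$$\n",
    "Thus, together with $v_1 > 0$ and $v_1 v_2 > \\rho_{12}^2$ (which make $M$, and hence $A$, positive definite), the positive-definiteness condition for the full covariance matrix is\n",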
    "$$\n",
    "\\boxed{\n",
    "\\sigma_\\theta^2 > \\frac{d}{v_1v_2 - \\rho_{12}^2}\\left(v_2\\,\\rho_{1\\theta}^2 + v_1\\,\\rho_{2\\theta}^2 - 2\\,\\rho_{12}\\,\\rho_{1\\theta}\\,\\rho_{2\\theta}\\right).\n",
    "}\n",
    "$$"
   ]
  },
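  {
   "cell_type": "markdown",
   "id": "two-image-check-md",
   "metadata": {},
   "source": [
    "The closed form above should agree with the general expression $d \\cdot \\boldsymbol{\\rho}_\\theta^T M^{-1} \\boldsymbol{\\rho}_\\theta$. The following cell is a quick numerical cross-check; the parameter values are arbitrary illustrative choices (with $v_1 v_2 > \\rho_{12}^2$ so that $M \\succ 0$), and $d = 3 \\cdot 32 \\cdot 32$ is assumed to match a CIFAR-10-sized image."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "two-image-check-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Arbitrary illustrative values (v1 * v2 > rho12^2, so M is positive definite).\n",
    "d = 3 * 32 * 32  # assumed CIFAR-10-sized image\n",
    "v1, v2, rho12 = 1.0, 2.0, 0.5\n",
    "rho1t, rho2t = 0.01, -0.02\n",
    "\n",
    "# Closed-form two-image bound.\n",
    "det_M = v1 * v2 - rho12 ** 2\n",
    "closed_form = d / det_M * (v2 * rho1t ** 2 + v1 * rho2t ** 2 - 2 * rho12 * rho1t * rho2t)\n",
    "\n",
    "# General expression d * rho^T M^{-1} rho via a linear solve.\n",
    "M2 = np.array([[v1, rho12], [rho12, v2]])\n",
    "rho_theta2 = np.array([rho1t, rho2t])\n",
    "general = d * rho_theta2 @ np.linalg.solve(M2, rho_theta2)\n",
    "\n",
    "print(f'closed form: {closed_form:.6f}, general: {general:.6f}')"
   ]
  },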
  {
   "cell_type": "markdown",
   "id": "9883c5efe5da71a0",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "## Positive-Definiteness Requirement for Independent Images 1 and 2"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2d563d513cf87b06",
   "metadata": {},
   "source": [
    "If images 1 and 2 are independent ($\\rho_{12} = 0$, so $\\Sigma_{12} = 0$), the covariance matrix takes the form:\n",
    "\n",
    "$$\n",
    "\\Sigma = \\begin{pmatrix}\n",
    "\\Sigma_1 & 0 & \\Sigma_{1\\theta} \\\\\n",
    "0 & \\Sigma_2 & \\Sigma_{2\\theta} \\\\\n",
    "\\Sigma_{1\\theta}^T & \\Sigma_{2\\theta}^T & \\sigma_\\theta^2\n",
    "\\end{pmatrix}\n",
    "$$\n",
    "\n",
    "The previously derived quadratic form\n",
    "\n",
    "$$\n",
    "B^T A^{-1} B = \\frac{1}{v_1v_2-\\rho_{12}^2}\\left[\n",
    "\\rho_{1\\theta}\\,\\mathbf{1}_d^T\\big(v_2\\,\\rho_{1\\theta} - \\rho_{12}\\,\\rho_{2\\theta}\\big)\\mathbf{1}_d + \\rho_{2\\theta}\\,\\mathbf{1}_d^T\\big(v_1\\,\\rho_{2\\theta} - \\rho_{12}\\,\\rho_{1\\theta}\\big)\\mathbf{1}_d\n",
    "\\right]\n",
    "$$\n",
    "\n",
    "reduces, with $\\rho_{12} = 0$, to\n",
    "\n",
    "$$\n",
    "\\begin{aligned}\n",
    "B^T A^{-1} B\n",
    "&= \\frac{1}{v_1v_2}\\left[\n",
    "\\rho_{1\\theta}\\,\\mathbf{1}_d^T\\big(v_2\\,\\rho_{1\\theta}\\big)\\mathbf{1}_d + \\rho_{2\\theta}\\,\\mathbf{1}_d^T\\big(v_1\\,\\rho_{2\\theta}\\big)\\mathbf{1}_d\\right] \\\\\n",
    "&= d \\cdot \\left( \\frac{\\rho_{1\\theta}^2}{v_1} + \\frac{\\rho_{2\\theta}^2}{v_2} \\right)\n",
    "\\end{aligned}\n",
    "$$\n",
    "\n",
    "\n",
    "\n",
    "Thus, the overall condition for $\\Sigma$ to be positive definite becomes\n",
    "\n",
    "$$\n",
    "\\sigma_\\theta^2 > d \\cdot \\left( \\frac{\\rho_{1\\theta}^2}{v_1} + \\frac{\\rho_{2\\theta}^2}{v_2} \\right).\n",
    "$$\n",
    "\n",
    "For $N$ mutually independent images, $M = \\operatorname{diag}(v_1, \\dots, v_N)$ and $\\boldsymbol{\\rho}_\\theta^T M^{-1} \\boldsymbol{\\rho}_\\theta = \\sum_{i=1}^N \\rho_{i\\theta}^2 / v_i$, so the general condition becomes:\n",
    "$$\n",
    "\\boxed{\n",
    "\\sigma_\\theta^2 > d \\cdot \\sum_{i=1}^N \\frac{\\rho_{i\\theta}^2}{v_i}\n",
    "}\n",
    "$$\n",
    "\n",
    "In the simplest case, $v_i = v$ and $\\rho_{i\\theta} = \\rho$ for all $i$, and the condition reduces to\n",
    "\n",
    "$$\n",
    "\\sigma_\\theta^2 > \\frac{N d \\rho^2}{v}.\n",
    "$$"
   ]
  },
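  {
   "cell_type": "markdown",
   "id": "independent-bound-md",
   "metadata": {},
   "source": [
    "The following cell evaluates the independent-image bound for a setting that mirrors the default arguments of `create_mutual_info_covariance` defined later in this notebook ($N = 2$, unit variances, $\\rho_{i\\theta} = 0.5$, $d = 3 \\cdot 32 \\cdot 32 = 3072$). The helper `min_theta_var_independent` is hypothetical. The bound comes out to $N d \\rho^2 / v = 1536$, so with these cross-covariances a `theta_var` of 1.0 would not yield a positive definite $\\Sigma$."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "independent-bound-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def min_theta_var_independent(v, rho_theta, d):\n",
    "    \"\"\"Hypothetical helper: lower bound d * sum_i rho_{i theta}^2 / v_i on sigma_theta^2 for independent images.\"\"\"\n",
    "    v = np.asarray(v, dtype=float)\n",
    "    rho_theta = np.asarray(rho_theta, dtype=float)\n",
    "    return d * np.sum(rho_theta ** 2 / v)\n",
    "\n",
    "\n",
    "# Illustrative setting mirroring the defaults of create_mutual_info_covariance: N = 2, v = 1, rho = 0.5.\n",
    "d = 3 * 32 * 32\n",
    "N, v, rho = 2, 1.0, 0.5\n",
    "bound = min_theta_var_independent([v] * N, [rho] * N, d)\n",
    "print('sigma_theta^2 must exceed', bound)  # N * d * rho^2 / v = 1536.0\n",
    "\n",
    "# Cross-check against the general formula d * rho^T M^{-1} rho with a diagonal M.\n",
    "M_diag = np.diag([v] * N)\n",
    "rho_vec = np.full(N, rho)\n",
    "assert np.isclose(bound, d * rho_vec @ np.linalg.solve(M_diag, rho_vec))"
   ]
  },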
  {
   "cell_type": "markdown",
   "id": "5792dfc4-22b9-417f-a64b-aaa591e3cf44",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "## Positive-Definiteness Requirement for a General Covariance Matrix and 3 Images"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "52e5e662-93d1-429c-8da8-a354526e1d6a",
   "metadata": {},
   "source": [
    "Finally, consider three images with a general covariance matrix. The covariance matrix takes the form:\n",
    "\n",
    "$$\n",
    "\\Sigma =\n",
    "\\begin{pmatrix}\n",
    "\\Sigma_1 & \\Sigma_{12} & \\Sigma_{13} & \\Sigma_{1\\theta} \\\\\n",
    "\\Sigma_{12}^T & \\Sigma_2 & \\Sigma_{23} & \\Sigma_{2\\theta} \\\\\n",
    "\\Sigma_{13}^T & \\Sigma_{23}^T & \\Sigma_3 & \\Sigma_{3\\theta} \\\\\n",
    "\\Sigma_{1\\theta}^T & \\Sigma_{2\\theta}^T & \\Sigma_{3\\theta}^T & \\sigma_\\theta^2\n",
    "\\end{pmatrix}\n",
    "=\n",
    "\\begin{pmatrix}\n",
    "v_1 I_d & \\rho_{12} I_d & \\rho_{13} I_d & \\rho_{1\\theta}\\,\\mathbf{1}_d \\\\\n",
    "\\rho_{12} I_d & v_2 I_d & \\rho_{23} I_d & \\rho_{2\\theta}\\,\\mathbf{1}_d \\\\\n",
    "\\rho_{13} I_d & \\rho_{23} I_d & v_3 I_d & \\rho_{3\\theta}\\,\\mathbf{1}_d \\\\\n",
    "\\rho_{1\\theta}\\,\\mathbf{1}_d^T & \\rho_{2\\theta}\\,\\mathbf{1}_d^T & \\rho_{3\\theta}\\,\\mathbf{1}_d^T & \\sigma_\\theta^2\n",
    "\\end{pmatrix}.\n",
    "$$\n",
    "\n",
    "We partition $\\Sigma$ as:\n",
    "$$\n",
    "\\Sigma =\n",
    "\\begin{pmatrix}\n",
    "A & B \\\\\n",
    "B^T & C\n",
    "\\end{pmatrix},\n",
    "$$\n",
    "where:\n",
    "- $A = M \\otimes I_d$ is the image–image block,\n",
    "- $B = \\boldsymbol{\\rho}_\\theta \\otimes \\mathbf{1}_d$ is the image–$\\theta$ cross-covariance vector,\n",
    "- $C = \\sigma_\\theta^2$ is the variance of $\\theta$.\n",
    "\n",
    "We define:\n",
    "$$\n",
    "M =\n",
    "\\begin{pmatrix}\n",
    "v_1 & \\rho_{12} & \\rho_{13} \\\\\n",
    "\\rho_{12} & v_2 & \\rho_{23} \\\\\n",
    "\\rho_{13} & \\rho_{23} & v_3\n",
    "\\end{pmatrix}, \\quad\n",
    "\\boldsymbol{\\rho}_\\theta =\n",
    "\\begin{pmatrix}\n",
    "\\rho_{1\\theta} \\\\\n",
    "\\rho_{2\\theta} \\\\\n",
    "\\rho_{3\\theta}\n",
    "\\end{pmatrix}.\n",
    "$$\n",
    "\n",
    "Given $M \\succ 0$ (equivalently $A \\succ 0$), the Schur complement criterion gives $\\Sigma \\succ 0$ if and only if\n",
    "$$\n",
    "\\sigma_\\theta^2 > B^T A^{-1} B.\n",
    "$$\n",
    "\n",
    "Using Kronecker identities:\n",
    "$$\n",
    "A^{-1} = M^{-1} \\otimes I_d,\n",
    "\\quad\n",
    "B = \\boldsymbol{\\rho}_\\theta \\otimes \\mathbf{1}_d,\n",
    "$$\n",
    "\n",
    "we get:\n",
    "$$\n",
    "B^T A^{-1} B\n",
    "= (\\boldsymbol{\\rho}_\\theta^T \\otimes \\mathbf{1}_d^T)(M^{-1} \\otimes I_d)(\\boldsymbol{\\rho}_\\theta \\otimes \\mathbf{1}_d)\n",
    "= (\\boldsymbol{\\rho}_\\theta^T M^{-1} \\boldsymbol{\\rho}_\\theta)(\\mathbf{1}_d^T \\mathbf{1}_d)\n",
    "= d \\cdot \\boldsymbol{\\rho}_\\theta^T M^{-1} \\boldsymbol{\\rho}_\\theta.\n",
    "$$\n",
    "\n",
    "To compute $\\boldsymbol{\\rho}_\\theta^T M^{-1} \\boldsymbol{\\rho}_\\theta$, we use the explicit inverse of a symmetric $3 \\times 3$ matrix, $M^{-1} = \\operatorname{adj}(M) / \\det M$. To avoid confusion with the blocks $A$, $B$, and $C$ above, we write the entries of the adjugate as $K_{ij}$:\n",
    "$$\n",
    "M^{-1} = \\frac{1}{\\det M}\n",
    "\\begin{pmatrix}\n",
    "K_{11} & K_{12} & K_{13} \\\\\n",
    "K_{12} & K_{22} & K_{23} \\\\\n",
    "K_{13} & K_{23} & K_{33}\n",
    "\\end{pmatrix}\n",
    "$$\n",
    "\n",
    "where:\n",
    "$$\n",
    "\\begin{aligned}\n",
    "K_{11} &= v_2 v_3 - \\rho_{23}^2, \\\\\n",
    "K_{22} &= v_1 v_3 - \\rho_{13}^2, \\\\\n",
    "K_{33} &= v_1 v_2 - \\rho_{12}^2, \\\\\n",
    "K_{12} &= \\rho_{13} \\rho_{23} - \\rho_{12} v_3, \\\\\n",
    "K_{13} &= \\rho_{12} \\rho_{23} - \\rho_{13} v_2, \\\\\n",
    "K_{23} &= \\rho_{12} \\rho_{13} - \\rho_{23} v_1,\n",
    "\\end{aligned}\n",
    "$$\n",
    "\n",
    "and\n",
    "$$\n",
    "\\det M = v_1 v_2 v_3 + 2\\rho_{12}\\rho_{13}\\rho_{23} - v_1\\rho_{23}^2 - v_2\\rho_{13}^2 - v_3\\rho_{12}^2.\n",
    "$$\n",
    "\n",
    "Then,\n",
    "$$\n",
    "\\boldsymbol{\\rho}_\\theta^T M^{-1} \\boldsymbol{\\rho}_\\theta\n",
    "= \\frac{1}{\\det M}\n",
    "\\begin{pmatrix}\n",
    "\\rho_{1\\theta} & \\rho_{2\\theta} & \\rho_{3\\theta}\n",
    "\\end{pmatrix}\n",
    "\\begin{pmatrix}\n",
    "K_{11} & K_{12} & K_{13} \\\\\n",
    "K_{12} & K_{22} & K_{23} \\\\\n",
    "K_{13} & K_{23} & K_{33}\n",
    "\\end{pmatrix}\n",
    "\\begin{pmatrix}\n",
    "\\rho_{1\\theta} \\\\\n",
    "\\rho_{2\\theta} \\\\\n",
    "\\rho_{3\\theta}\n",
    "\\end{pmatrix}\n",
    "$$\n",
    "\n",
    "Carrying out the multiplication:\n",
    "$$\n",
    "\\boldsymbol{\\rho}_\\theta^T M^{-1} \\boldsymbol{\\rho}_\\theta\n",
    "= \\frac{1}{\\det M} \\left[\n",
    "K_{11}\\,\\rho_{1\\theta}^2 + K_{22}\\,\\rho_{2\\theta}^2 + K_{33}\\,\\rho_{3\\theta}^2\n",
    "+ 2K_{12}\\,\\rho_{1\\theta}\\rho_{2\\theta}\n",
    "+ 2K_{13}\\,\\rho_{1\\theta}\\rho_{3\\theta}\n",
    "+ 2K_{23}\\,\\rho_{2\\theta}\\rho_{3\\theta}\n",
    "\\right]\n",
    "$$\n",
    "\n",
    "Thus:\n",
    "$$\n",
    "B^T A^{-1} B = d \\cdot \\boldsymbol{\\rho}_\\theta^T M^{-1} \\boldsymbol{\\rho}_\\theta\n",
    "= \\frac{d}{\\det M} \\left[\n",
    "K_{11}\\,\\rho_{1\\theta}^2 + K_{22}\\,\\rho_{2\\theta}^2 + K_{33}\\,\\rho_{3\\theta}^2\n",
    "+ 2K_{12}\\,\\rho_{1\\theta}\\rho_{2\\theta}\n",
    "+ 2K_{13}\\,\\rho_{1\\theta}\\rho_{3\\theta}\n",
    "+ 2K_{23}\\,\\rho_{2\\theta}\\rho_{3\\theta}\n",
    "\\right]\n",
    "$$\n",
    "\n",
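    "Putting everything together, and assuming $M \\succ 0$ (by Sylvester's criterion: $v_1 > 0$, $v_1 v_2 > \\rho_{12}^2$, and $\\det M > 0$), our condition becomes:\n",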
    "\n",
    "$$\n",
    "\\boxed{\n",
    "\\begin{aligned}\n",
    "\\sigma_\\theta^2 &> \\frac{d}{\\det M} \\Big[ \n",
    "(v_2 v_3 - \\rho_{23}^2)\\rho_{1\\theta}^2 +\n",
    "(v_1 v_3 - \\rho_{13}^2)\\rho_{2\\theta}^2 +\n",
    "(v_1 v_2 - \\rho_{12}^2)\\rho_{3\\theta}^2 \\\\\n",
    "&\\quad + 2(\\rho_{13} \\rho_{23} - \\rho_{12} v_3)\\rho_{1\\theta} \\rho_{2\\theta}\n",
    "+ 2(\\rho_{12} \\rho_{23} - \\rho_{13} v_2)\\rho_{1\\theta} \\rho_{3\\theta} + 2(\\rho_{12} \\rho_{13} - \\rho_{23} v_1)\\rho_{2\\theta} \\rho_{3\\theta}\n",
    "\\Big]\n",
    "\\end{aligned}\n",
    "}\n",
    "$$"
   ]
  },
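  {
   "cell_type": "markdown",
   "id": "three-image-check-md",
   "metadata": {},
   "source": [
    "Because the expanded three-image bound is easy to mistype, the following cell evaluates it term by term and compares it with $d \\cdot \\boldsymbol{\\rho}_\\theta^T M^{-1} \\boldsymbol{\\rho}_\\theta$ computed by a linear solve. The variances and covariances are arbitrary illustrative values chosen so that $M \\succ 0$, and $d = 16$ is an assumed toy dimension."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "three-image-check-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Arbitrary illustrative values (M must be positive definite).\n",
    "d = 16\n",
    "v1, v2, v3 = 1.0, 1.5, 2.0\n",
    "r12, r13, r23 = 0.3, -0.2, 0.4\n",
    "r1t, r2t, r3t = 0.05, -0.1, 0.08\n",
    "\n",
    "# Expanded boxed expression, term by term.\n",
    "det_M = v1 * v2 * v3 + 2 * r12 * r13 * r23 - v1 * r23 ** 2 - v2 * r13 ** 2 - v3 * r12 ** 2\n",
    "expanded = d / det_M * (\n",
    "    (v2 * v3 - r23 ** 2) * r1t ** 2\n",
    "    + (v1 * v3 - r13 ** 2) * r2t ** 2\n",
    "    + (v1 * v2 - r12 ** 2) * r3t ** 2\n",
    "    + 2 * (r13 * r23 - r12 * v3) * r1t * r2t\n",
    "    + 2 * (r12 * r23 - r13 * v2) * r1t * r3t\n",
    "    + 2 * (r12 * r13 - r23 * v1) * r2t * r3t\n",
    ")\n",
    "\n",
    "# General expression d * rho^T M^{-1} rho via a linear solve.\n",
    "M3 = np.array([[v1, r12, r13], [r12, v2, r23], [r13, r23, v3]])\n",
    "rho_theta3 = np.array([r1t, r2t, r3t])\n",
    "general = d * rho_theta3 @ np.linalg.solve(M3, rho_theta3)\n",
    "\n",
    "print('M positive definite:', bool(np.all(np.linalg.eigvalsh(M3) > 0)))\n",
    "print(f'expanded: {expanded:.8f}, general: {general:.8f}')"
   ]
  },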
  {
   "cell_type": "markdown",
   "id": "206783f3-2b78-4d0a-9cca-e1505b90ad86",
   "metadata": {},
   "source": [
    "## Constraints on $\\Sigma$"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "119c8db50ee3491",
   "metadata": {},
   "source": [
    "In this section only, let $A$ denote the cross-covariance $\\Sigma_{12}$, assumed here to be a column vector in $\\mathbb{R}^d$ (not the image-image block used above):\n",
    "$$\n",
    "A = \\Sigma_{12} = \\begin{bmatrix}\n",
    "a_1 \\\\\n",
    "a_2 \\\\\n",
    "\\vdots \\\\\n",
    "a_d\n",
    "\\end{bmatrix}.\n",
    "$$\n",
    "Then the product\n",
    "$$\n",
    "AA^T\n",
    "$$\n",
    "is a $d \\times d$ matrix given by the outer product of $A$ with itself.\n",
    "\n",
    "Since\n",
    "$$\n",
    "(AA^T)^T = AA^T,\n",
    "$$\n",
    "the matrix $AA^T$ is **symmetric**.\n",
    "\n",
    "Because $AA^T$ is the outer product of the vector $A$ with itself, its rank is at most 1. In fact, if $A$ is nonzero,\n",
    "$$\n",
    "\\operatorname{rank}(AA^T) = 1.\n",
    "$$\n",
    "Thus, for any $d > 1$, $AA^T$ is not full rank.\n",
    "\n",
    "For any $x \\in \\mathbb{R}^d$, we have\n",
    "$$\n",
    "x^T (AA^T) x = (A^T x)^2 \\geq 0.\n",
    "$$\n",
    "Thus, $AA^T$ is **positive semidefinite**. However, since its rank is 1 for $d > 1$, it is not **positive definite** (it will have $d-1$ zero eigenvalues).\n",
    "\n",
    "**Conditions for a Rotation Matrix**\n",
    "\n",
    "A rotation matrix $R \\in \\mathbb{R}^{d \\times d}$ must satisfy the following:\n",
    "\n",
    "1. **Orthogonality:**\n",
    "   $$\n",
    "   R^T R = I_d,\n",
    "   $$\n",
    "   which means $R$ is invertible and full rank.\n",
    "\n",
    "2. **Determinant:**\n",
    "   $$\n",
    "   \\det(R) = 1 \\quad \\text{(for a proper rotation)}.\n",
    "   $$\n",
    "\n",
    "Note that rotation matrices are generally not symmetric: a rotation with $R^T = R$ must satisfy $R^2 = I_d$, which holds only in special cases such as $R = I_d$ or a rotation by $\\pi$ within a plane. Rotations are full-rank linear transformations that preserve the Euclidean norm.\n",
    "\n",
    "Since $\\operatorname{rank}(AA^T) = 1$ (assuming $A \\neq 0$ and $d > 1$), $AA^T$ is not full rank. A rotation matrix, by definition, must be full rank (invertible) as it must satisfy $R^T R = I_d$.\n",
    "\n",
    "\n",
    "\n",
    "Given that $A$ is a $d$-dimensional vector, consider the matrix $AA^T$:\n",
    "\n",
    "- **Symmetry vs. Orthogonality:**  \n",
    "  Although $AA^T$ is symmetric, a non-trivial rotation matrix in $\\mathbb{R}^d$ is not symmetric. For instance, a 2D rotation matrix\n",
    "  $$\n",
    "  R(\\theta)= \\begin{bmatrix}\n",
    "  \\cos \\theta & -\\sin \\theta \\\\\n",
    "  \\sin \\theta & \\cos \\theta\n",
    "  \\end{bmatrix}\n",
    "  $$\n",
    "  is not symmetric unless $\\theta$ is a multiple of $\\pi$.\n",
    "\n",
    "- **Positive Definiteness:**  \n",
    "  For $AA^T$ to be positive definite, every eigenvalue must be strictly positive. However, as a rank-1 matrix in $\\mathbb{R}^d$ (with $d > 1$), $AA^T$ has $d-1$ zero eigenvalues (its single nonzero eigenvalue is $\\|A\\|^2$). Thus, it is only positive semidefinite.\n",
    "\n",
    "Hence, the matrix $AA^T$ for a $d$-dimensional vector $A$ does **not** satisfy the requirements to be a rotation matrix, and it is not positive definite either:\n",
    "- It is not full rank (it has rank 1, so it is not invertible).\n",
    "- It fails the orthogonality condition $R^T R = I_d$ required of rotation matrices.\n",
    "- It is only positive semidefinite rather than positive definite (when $d > 1$).\n",
    "\n",
    "In summary, if $A \\in \\mathbb{R}^d$ is a column vector (with $d > 1$), then $AA^T$ is a $d \\times d$ symmetric, rank-1, positive semidefinite matrix. It cannot be a rotation matrix, because a rotation matrix must be full rank, orthogonal, and satisfy $\\det(R)=1$, whereas $\\det(AA^T) = 0$.\n"
   ]
  },
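  {
   "cell_type": "markdown",
   "id": "rank-one-check-md",
   "metadata": {},
   "source": [
    "The properties of $AA^T$ listed above can be checked numerically for a concrete vector. The following cell uses an arbitrary illustrative $A \\in \\mathbb{R}^4$ with $\\|A\\|^2 = 6.25$; it is not tied to any covariance used elsewhere in the notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "rank-one-check-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Arbitrary illustrative column vector A in R^d with d = 4 (||A||^2 = 1 + 4 + 1 + 0.25 = 6.25).\n",
    "A = np.array([[1.0], [2.0], [-1.0], [0.5]])\n",
    "P = A @ A.T  # d x d outer product\n",
    "\n",
    "eigvals = np.linalg.eigvalsh(P)  # ascending order\n",
    "print('symmetric:', np.allclose(P, P.T))\n",
    "print('rank:', np.linalg.matrix_rank(P))\n",
    "print('eigenvalues:', np.round(eigvals, 10))  # d - 1 (numerically) zero eigenvalues and ||A||^2\n",
    "print('determinant:', np.linalg.det(P))\n",
    "print('orthogonal (P^T P = I):', np.allclose(P.T @ P, np.eye(A.shape[0])))"
   ]
  },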
  {
   "cell_type": "markdown",
   "id": "63c122a1e7b2287b",
   "metadata": {},
   "source": [
    "## Demo"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b3b5104fddd362b6",
   "metadata": {},
   "source": [
    "Importing Libraries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cc80caa2d987e632",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from pathlib import Path\n",
    "import json\n",
    "from models.model_configs import instantiate_model\n",
    "import torch\n",
    "from training.eval_loop import CFGScaledModel\n",
    "from flow_matching.solver.ode_solver import ODESolver\n",
    "from matplotlib import pyplot as plt"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "afd39bb4ad93fbaf",
   "metadata": {},
   "source": [
    "Loading our Trained Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "54b2383420867d21",
   "metadata": {},
   "outputs": [],
   "source": [
    "checkpoint_path = Path(\"./output_dir/checkpoint-1899.pth\")\n",
    "args_filepath = checkpoint_path.parent / 'args.json'\n",
    "with open(args_filepath, 'r') as f:\n",
    "    args_dict = json.load(f)\n",
    "\n",
    "model = instantiate_model(\n",
    "    architechture=args_dict['dataset'],\n",
    "    is_discrete=args_dict.get('discrete_flow_matching', False),\n",
    "    use_ema=args_dict['use_ema'],\n",
    ")\n",
    "checkpoint = torch.load(checkpoint_path, map_location=\"cpu\", weights_only=False)\n",
    "model.load_state_dict(checkpoint[\"model\"])\n",
    "model.train(False)\n",
    "\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "print(\"Number of GPUs available:\", torch.cuda.device_count())\n",
    "model.to(device=device)\n",
    "\n",
    "cfg_weighted_model = CFGScaledModel(model=model)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d0df1a07f42e074e",
   "metadata": {},
   "source": [
    "Creating our ODE Solver"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "56988bdca71d162b",
   "metadata": {},
   "outputs": [],
   "source": [
    "solver = ODESolver(velocity_model=cfg_weighted_model)\n",
    "ode_opts = args_dict['ode_options']\n",
    "ode_opts[\"method\"] = args_dict['ode_method']"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7121220e0fa76df",
   "metadata": {},
   "source": [
    "Setting the Resolution of the Noise"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a3b0dc1a6bc5ecf6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set the sampling resolution corresponding to the model\n",
    "if 'train_blurred_64' in args_dict['data_path'] and args_dict['dataset'] == 'imagenet':\n",
    "    sample_resolution = 64\n",
    "elif 'train_blurred_32' in args_dict['data_path'] or args_dict['dataset'] == 'cifar10':\n",
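    "    sample_resolution = 32\n",
    "else:\n",
    "    raise ValueError(f\"Cannot infer the sample resolution for dataset {args_dict['dataset']!r}\")"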
   ]
  },
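  {
   "cell_type": "markdown",
   "id": "structured-sampler-md",
   "metadata": {},
   "source": [
    "The dense construction in `create_mutual_info_covariance` below stores $(Nd + 1)^2$ entries (roughly 151 MB in float32 for $N = 2$ CIFAR-10-sized images, growing quadratically in $N$). Because $\\Sigma$ has the Kronecker structure derived above, one alternative is to draw $\\theta$ first and then sample the images conditionally, splitting pixel space into the all-ones direction $P = \\mathbf{1}_d \\mathbf{1}_d^T / d$ and its complement. This only needs Cholesky factors of $M$ and of $S = M - \\frac{d}{\\sigma_\\theta^2} \\boldsymbol{\\rho}_\\theta \\boldsymbol{\\rho}_\\theta^T$, and $S \\succ 0$ holds exactly when the boxed Schur condition does (given $M \\succ 0$). The following cell is a sketch of this swapped-in technique, not the method the notebook uses; the function `sample_structured_noise` and all demo values are hypothetical, and reshaping each row of the result to `(channels, sample_resolution, sample_resolution)` for the solver is an untested suggestion."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "structured-sampler-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def sample_structured_noise(M, rho_theta, theta_var, d, rng):\n",
    "    \"\"\"Hypothetical helper: draw images X (N x d) and a scalar theta whose joint covariance is\n",
    "    [[M kron I_d, rho_theta kron 1_d], [(rho_theta kron 1_d)^T, theta_var]],\n",
    "    without forming the dense (N*d + 1) x (N*d + 1) matrix.\"\"\"\n",
    "    N = M.shape[0]\n",
    "    # Given theta, the image covariance M kron I_d - (rho rho^T kron 1 1^T) / theta_var splits as\n",
    "    # S kron P + M kron (I - P), with P = 1 1^T / d and S = M - d * rho rho^T / theta_var.\n",
    "    S = M - d * np.outer(rho_theta, rho_theta) / theta_var\n",
    "    L_S = np.linalg.cholesky(S)  # raises LinAlgError if S is not positive definite (Schur condition violated)\n",
    "    L_M = np.linalg.cholesky(M)  # requires M to be positive definite\n",
    "    theta = np.sqrt(theta_var) * rng.normal()\n",
    "    Z = rng.normal(size=(N, d))\n",
    "    Z_P = np.repeat(Z.mean(axis=1, keepdims=True), d, axis=1)  # Z @ P: each row replaced by its mean\n",
    "    noise = L_S @ Z_P + L_M @ (Z - Z_P)\n",
    "    # Conditional mean E[x_i | theta] = rho_{i theta} * theta / theta_var for every pixel.\n",
    "    X = np.outer(rho_theta, np.ones(d)) * (theta / theta_var) + noise\n",
    "    return X, theta\n",
    "\n",
    "\n",
    "# Usage on a small example (all values are illustrative assumptions).\n",
    "rng = np.random.default_rng(0)\n",
    "M_demo = np.array([[1.0, 0.3], [0.3, 1.0]])\n",
    "rho_demo = np.array([0.2, -0.1])\n",
    "d_demo, theta_var_demo, n_draws = 8, 2.0, 10000\n",
    "\n",
    "draws = [sample_structured_noise(M_demo, rho_demo, theta_var_demo, d_demo, rng) for _ in range(n_draws)]\n",
    "X_draws = np.stack([x for x, _ in draws])  # shape (n_draws, N, d)\n",
    "theta_draws = np.array([t for _, t in draws])\n",
    "\n",
    "# Empirical moments should approach Var(theta) = 2.0, Cov(x_1[k], x_2[k]) = 0.3,\n",
    "# Cov(x_1[k], x_1[l]) = 0 for k != l, and Cov(x_1[k], theta) = 0.2.\n",
    "print('Var(theta)          ~', theta_draws.var())\n",
    "print('Cov(x_1[0], x_2[0]) ~', np.mean(X_draws[:, 0, 0] * X_draws[:, 1, 0]))\n",
    "print('Cov(x_1[0], x_1[1]) ~', np.mean(X_draws[:, 0, 0] * X_draws[:, 0, 1]))\n",
    "print('Cov(x_1[0], theta)  ~', np.mean(X_draws[:, 0, 0] * theta_draws))"
   ]
  },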
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e4f31a82f01b335a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_mutual_info_covariance(batch_size=2, dim=3 * 32 * 32, cov_images=None, cov_theta=None, theta_var=1.0,\n",
    "                                   epsilon=1e-5):\n",
    "    \"\"\"\n",
    "    This function creates a covariance matrix for a multivariate Gaussian distribution with specific properties.\n",
    "    :param batch_size: Number of images to generate.\n",
    "    :param dim: Dimension of each image (e.g., 3 * 32 * 32 for CIFAR-10).\n",
    "    :param cov_images: 2D array-like of shape (batch_size, batch_size) specifying the covariance between each pair of\n",
    "                       images. Diagonal elements are the variances for each image and off-diagonals are the covariances.\n",
    "    :param cov_theta: List or scalar specifying the cross-covariance between each image and theta.\n",
    "                      If scalar, all images will have the same covariance with theta.\n",
    "    :param theta_var: Scalar, variance for theta.\n",
    "    :param epsilon: Small value added to the diagonal for numerical stability.\n",
    "    :return: mu (mean vector) and Sigma (covariance matrix).\n",
    "    \"\"\"\n",
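    "    import warnings\n",
    "    import numpy as np\n",
    "\n",
    "    # Process cov_theta.\n",
    "    if cov_theta is None:\n",
    "        cov_theta = [0.5] * batch_size\n",
    "    elif np.isscalar(cov_theta):\n",
    "        cov_theta = [cov_theta] * batch_size\n",
    "    elif len(cov_theta) != batch_size:\n",
    "        raise ValueError(\"Length of cov_theta must match batch_size.\")\n",
    "\n",
    "    # Process cov_images.\n",
    "    if cov_images is None:\n",
    "        # Default: images are independent with unit variance.\n",
    "        cov_images = np.eye(batch_size)\n",
    "    else:\n",
    "        cov_images = np.array(cov_images)\n",
    "        if cov_images.shape != (batch_size, batch_size):\n",
    "            raise ValueError(\"cov_images must have shape (batch_size, batch_size).\")\n",
    "\n",
    "    # Check the positive-definiteness conditions derived above (ignoring epsilon):\n",
    "    # cov_images (= M) must be positive definite and theta_var > dim * rho^T M^{-1} rho.\n",
    "    rho_theta = np.asarray(cov_theta, dtype=float)\n",
    "    if np.linalg.eigvalsh(cov_images).min() <= 0:\n",
    "        warnings.warn(\"cov_images is not positive definite, so Sigma will not be positive definite.\")\n",
    "    else:\n",
    "        bound = dim * rho_theta @ np.linalg.solve(cov_images, rho_theta)\n",
    "        if theta_var <= bound:\n",
    "            warnings.warn(f\"theta_var={theta_var} must exceed {bound:.4g} for Sigma to be positive definite.\")\n",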
    "\n",
    "    # Build the image covariance block.\n",
    "    image_blocks = []\n",
    "    for i in range(batch_size):\n",
    "        row_blocks = []\n",
    "        for j in range(batch_size):\n",
    "            # Each block is (cov_images[i,j]) * I_d.\n",
    "            row_blocks.append(cov_images[i, j] * torch.eye(dim, device=device))\n",
    "        row = torch.cat(row_blocks, dim=1)\n",
    "        image_blocks.append(row)\n",
    "    Sigma_images = torch.cat(image_blocks, dim=0)\n",
    "\n",
    "    # Theta's block: 1x1 matrix (cast to float so it matches the floating-point image block).\n",
    "    Theta_block = torch.tensor([[float(theta_var)]], device=device)\n",
    "\n",
    "    # Form the full covariance matrix.\n",
    "    Sigma = torch.block_diag(Sigma_images, Theta_block)\n",
    "\n",
    "    # Fill in the cross-covariance blocks between each image and theta.\n",
    "    ones_image = torch.ones(dim, device=device)\n",
    "    for i in range(batch_size):\n",
    "        Sigma[i * dim:(i + 1) * dim, -1] = cov_theta[i] * ones_image\n",
    "        Sigma[-1, i * dim:(i + 1) * dim] = cov_theta[i] * ones_image\n",
    "\n",
    "    # Add a small epsilon to the diagonal for numerical stability.\n",
    "    Sigma += epsilon * torch.eye(Sigma.shape[0], device=device)\n",
    "\n",
    "    # Create mean vector.\n",
    "    mu = torch.zeros(Sigma.shape[0], device=device)\n",
    "    return mu, Sigma\n",
    "\n",
    "\n",
    "def gaussian_mutual_noise(batch_size=2, channels=3, cov_images=None, cov_theta=None, theta_var=1, epsilon=0, verbose=0, seed=None, filename=0):\n",
    "    \"\"\"\n",
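    "    This function generates a batch of correlated noise latents, jointly with a scalar theta, from a multivariate\n",
    "    Gaussian distribution whose covariance is built by create_mutual_info_covariance, and passes them to the solver.\n",
    "    :param batch_size: Number of images to generate.\n",
    "    :param channels: Number of channels per image.\n",
    "    :param cov_images: (batch_size, batch_size) image-image covariance matrix (variances on the diagonal).\n",
    "    :param cov_theta: List or scalar cross-covariance between each image and theta.\n",
    "    :param theta_var: Variance of theta.\n",
    "    :param epsilon: Small value added to the diagonal of Sigma for numerical stability.\n",
    "    :param verbose: If > 0, displays the generated noise latents.\n",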
    "    :param seed: Random seed for reproducibility.\n",
    "    :param filename: Suffix for the output filename.\n",
    "    :return: None. The function:\n",
    "             1. Creates correlated noise latents from a multivariate Gaussian distribution.\n",
    "             2. Generates images using flow matching from these noise latents.\n",
    "    \"\"\"\n",
    "\n",
    "    if seed is None:\n",
    "        seed = torch.randint(0, 2 ** 32 - 1, (1,)).item()\n",
    "    torch.manual_seed(seed)\n",
    "\n",
    "    dim = channels * sample_resolution * sample_resolution  # flattened size per image, e.g. 3 * 32 * 32 = 3072\n",
    "\n",
    "    mu, Sigma = create_mutual_info_covariance(\n",
    "        batch_size=batch_size,\n",
    "        dim=dim,  # flattened dimension of each image\n",
    "        cov_images=cov_images,\n",
    "        cov_theta=cov_theta,\n",
    "        theta_var=theta_var,\n",
    "        epsilon=epsilon\n",
    "    )\n",
    "\n",
    "    # Sample from the multivariate normal distribution\n",
    "    mvn = torch.distributions.MultivariateNormal(mu, covariance_matrix=Sigma)\n",
    "    samples = mvn.sample()\n",
    "\n",
    "    # Extract and reshape multiple images\n",
    "    noise_vectors = []\n",
    "    for i in range(batch_size):\n",
    "        noise_flat = samples[i * dim:(i + 1) * dim]\n",
    "        noise = noise_flat.reshape(channels, sample_resolution, sample_resolution)\n",
    "        noise_vectors.append(noise)\n",
    "\n",
    "    # Extract Theta (last element of the sample)\n",
    "    theta_value = samples[-1].item()\n",
    "\n",
    "\n",
    "    # Stack all noise tensors into a batch\n",
    "    x_0 = torch.stack(noise_vectors, dim=0)\n",
    "\n",
    "\n",
    "    if verbose > 0:\n",
    "        # Visualize the latent noise\n",
    "        plt.figure(figsize=(14, 4))\n",
    "\n",
    "        for i in range(batch_size):\n",
    "            plt.subplot(1, batch_size, i + 1)\n",
    "            img = x_0[i].cpu().permute(1, 2, 0).numpy() # Convert tensor to numpy for visualization\n",
    "            img = (img - img.min()) / (img.max() - img.min()) # Normalize to [0,1] for visualization\n",
    "            plt.imshow(img)\n",
    "\n",
    "            plt.title(f\"Noise {i + 1}\")\n",
    "            plt.axis('off')\n",
    "\n",
    "        plt.suptitle(\"Noise Latents (to be passed through Flow Matching)\", fontsize=24)\n",
    "\n",
    "        plt.tight_layout()\n",
    "        plt.subplots_adjust(bottom=0.14)  # Make room for the legend and title\n",
    "        plt.show()\n",
    "\n",
    "    labels = torch.tensor(list(range(batch_size)), dtype=torch.int32,\n",
    "                          device=device)  # Required to run the model, but not considered.\n",
    "    time_grid = torch.linspace(0, 1, 10).to(device=device)\n",
    "    synthetic_samples = solver.sample(\n",
    "        time_grid=time_grid,\n",
    "        x_init=x_0,\n",
    "        method=args_dict['ode_method'],\n",
    "        atol=args_dict['ode_options']['atol'] if 'atol' in args_dict['ode_options'] else None,\n",
    "        rtol=args_dict['ode_options']['rtol'] if 'rtol' in args_dict['ode_options'] else None,\n",
    "        step_size=args_dict['ode_options']['step_size'] if 'step_size' in args_dict['ode_options'] else None,\n",
    "        cfg_scale=args_dict['cfg_scale'],\n",
    "        label=labels,\n",
    "        return_intermediates=False,\n",
    "    )\n",
    "    # Scaling to [0, 1] from [-1, 1]\n",
    "    synthetic_samples = torch.clamp(\n",
    "        synthetic_samples * 0.5 + 0.5, min=0.0, max=1.0\n",
    "    )\n",
    "    synthetic_samples = torch.floor(synthetic_samples * 255) / 255.0\n",
    "\n",
    "    output_path = os.path.join(\"output_dir\", \"generated_examples_CIFAR10\")\n",
    "    os.makedirs(output_path, exist_ok=True)\n",
    "    plt.figure(figsize=(15, 3))\n",
    "    for j in range(batch_size):\n",
    "        plt.subplot(1, batch_size, j + 1)  # 1 row, batch_size columns\n",
    "        image = synthetic_samples[j].cpu().permute(1, 2, 0).numpy()\n",
    "        plt.imshow(image)\n",
    "        plt.axis('off')\n",
    "\n",
    "    # time_value = time_grid[-1].item()  # Use the last time step (index 9)\n",
    "\n",
    "    plt.suptitle(\n",
    "        r'Images Generated (Cov(X,Y) = ' + f'{cov_images}, ' +\n",
    "        r'Cov($\\theta$) = ' + f'{cov_theta}' + ')',\n",
    "        fontsize=20\n",
    "    )\n",
    "    \n",
    "    plt.tight_layout()\n",
    "\n",
    "    # Save as high-quality PDF in the specified directory\n",
    "    filename = f'Gaussian_Mutual_Info_Horizontal_{filename}.pdf'\n",
    "    full_path = os.path.join(output_path, filename)\n",
    "    plt.savefig(full_path, format='pdf', bbox_inches='tight')\n",
    "    plt.show()\n",
    "\n",
    "    print (f\"Theta Value: {theta_value}\")\n",
    "    print(f\"Generated Using Seed: {seed}\")"
   ]
  },
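  {
   "cell_type": "markdown",
   "id": "7c1e4b9a2f6d3e10",
   "metadata": {},
   "source": [
    "The function slices the flat sample by position, so it relies on the order of the variables in $\\Sigma$. Assuming `create_mutual_info_covariance` stacks the images first and $\\theta$ last, in the same order as the block matrix at the top of this notebook, a single draw is\n",
    "\n",
    "$$\n",
    "s = (\\underbrace{x_1}_{d}, \\underbrace{x_2}_{d}, \\dots, \\underbrace{x_N}_{d}, \\theta) \\in \\mathbb{R}^{Nd+1}.\n",
    "$$\n",
    "\n",
    "With 0-based indexing, image $i$ therefore occupies `samples[i * dim:(i + 1) * dim]` and $\\theta$ is `samples[N * dim]`, i.e. `samples[-1]`. For CIFAR-10 ($d = 3 \\cdot 32 \\cdot 32 = 3072$) and `batch_size=2`, $s$ has $2 \\cdot 3072 + 1 = 6145$ entries and $\\Sigma$ is a $6145 \\times 6145$ matrix."
   ]
  },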
  {
   "cell_type": "markdown",
   "id": "b25b70db6d56e019",
   "metadata": {},
   "source": [
    "Let's generate examples for a variety of $\\rho$. We will show the noise latents once just for visualization purposes then hide them for subsequent examples. The larger the $\\rho$, the more correlated the noise will be."
   ]
  },
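  {
   "cell_type": "markdown",
   "id": "a3d58f2c9b7e4061",
   "metadata": {},
   "source": [
    "To see why $\\rho$ controls the correlation, take two images with $v_1 = v_2 = 1$. Assume, as in the block matrix at the top of the notebook, that `cov_images` supplies $C = \\begin{pmatrix} 1 & \\rho \\\\ \\rho & 1 \\end{pmatrix}$, so that the image block is $C \\otimes I_d$. Marginalizing out $\\theta$ leaves the image block unchanged, so every pixel pair $(x_{1k}, x_{2k})$ has covariance matrix $C$, and\n",
    "\n",
    "$$\n",
    "\\operatorname{Corr}(x_{1k}, x_{2k}) = \\frac{\\rho}{\\sqrt{v_1 v_2}} = \\rho,\n",
    "\\qquad\n",
    "\\operatorname{eig}(C \\otimes I_d) = \\{1 - \\rho,\\ 1 + \\rho\\} \\ (\\text{each with multiplicity } d).\n",
    "$$\n",
    "\n",
    "For $\\rho = 0.999$ the smallest eigenvalue is $0.001$. The difference $x_1 - x_2$ then has per-pixel variance $2(1 - \\rho) = 0.002$, so the two latents are nearly identical. In contrast, $\\rho = 0$ makes the pixel pairs uncorrelated."
   ]
  },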
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5c5ebe4a226b8669",
   "metadata": {},
   "outputs": [],
   "source": [
    "import warnings\n",
    "\n",
    "# Suppress specific FutureWarning before running your code\n",
    "warnings.filterwarnings(\"ignore\",\n",
    "                       message=\".*torch.cuda.amp.autocast.*\",\n",
    "                       category=FutureWarning)\n",
    "\n",
    "gaussian_mutual_noise(batch_size=2, cov_theta=[0.9, 0.9], theta_var=6000, epsilon=0, verbose=1, filename=0,\n",
    "                      cov_images=[[1.0, 0.0],\n",
    "                                  [0.0, 1.0]]\n",
    "                      )\n",
    "\n",
    "gaussian_mutual_noise(batch_size=2, cov_theta=[0.9, 0.9], theta_var=6000, epsilon=0, verbose=1, filename=0,\n",
    "                      cov_images=[[1.0, 0.0],\n",
    "                                  [0.0, 1.0]]\n",
    "                      )\n",
    "\n",
    "\n",
    "gaussian_mutual_noise(batch_size=2, cov_theta=[0, 0], theta_var=6000, epsilon=0, verbose=1, filename=1,\n",
    "                      cov_images=[[1.0, 0.999],\n",
    "                                  [0.999, 1.0]]\n",
    "                      )\n",
    "\n",
    "gaussian_mutual_noise(batch_size=2, cov_theta=[0, 0], theta_var=6000, epsilon=0, verbose=1, filename=1,\n",
    "                      cov_images=[[1.0, 0.999],\n",
    "                                  [0.999, 1.0]]\n",
    "                      )\n",
    "\n",
    "gaussian_mutual_noise(batch_size=2, cov_theta=[0.5, 0.5], theta_var=6000, epsilon=0, verbose=1, filename=1,\n",
    "                      cov_images=[[1.0, 0.5],\n",
    "                                  [0.5, 1.0]]\n",
    "                      )\n",
    "\n",
    "gaussian_mutual_noise(batch_size=2, cov_theta=[0.5, 0.5], theta_var=6000, epsilon=0, verbose=1, filename=1,\n",
    "                      cov_images=[[1.0, 0.5],\n",
    "                                  [0.5, 1.0]]\n",
    "                      )"
   ]
  },
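  {
   "cell_type": "markdown",
   "id": "e9b2c4d7f1a83056",
   "metadata": {},
   "source": [
    "Why `theta_var=6000`? $\\Sigma$ is positive definite if and only if the image block $A = C \\otimes I_d$ is positive definite and the Schur complement $\\sigma_\\theta^2 - c^T A^{-1} c$ is positive. Assume `cov_theta` supplies $r = (\\rho_{1\\theta}, \\rho_{2\\theta})^T$, so that the image-$\\theta$ column is $c = r \\otimes \\mathbf{1}_d$. The mixed-product rule for Kronecker products then gives\n",
    "\n",
    "$$\n",
    "\\sigma_\\theta^2 > c^T A^{-1} c = (r^T C^{-1} r)(\\mathbf{1}_d^T \\mathbf{1}_d) = d\\, r^T C^{-1} r.\n",
    "$$\n",
    "\n",
    "Take $d = 3072$ and the settings used above:\n",
    "\n",
    "- $C = I_2$ and $r = (0.9, 0.9)$ give a bound of $3072 \\cdot 1.62 = 4976.64$.\n",
    "- $\\rho = 0.5$ and $r = (0.5, 0.5)$ give $3072 \\cdot \\tfrac{1}{3} = 1024$.\n",
    "- $r = (0, 0)$ gives $0$.\n",
    "\n",
    "The value $6000$ clears all three bounds."
   ]
  },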
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e291560c81be9bd3",
   "metadata": {},
   "outputs": [],
   "source": [
    "for i in range(10):\n",
    "    gaussian_mutual_noise(batch_size=2, cov_theta=[0, 0], theta_var=6000, epsilon=0, verbose=1, filename=1,\n",
    "                      cov_images=[[1.0, i*0.1],\n",
    "                                  [i*0.1, 1.0]]\n",
    "                      )\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f4e7121909363571",
   "metadata": {},
   "source": [
    "## Testing non-diagonal submatrices"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "55543a11d2cc83af",
   "metadata": {},
   "source": [
    "Consider a covariance matrix for two images where the diagonal blocks are given by\n",
    "$$\n",
    "\\Sigma_1 = v_1 I_d \\quad \\text{and} \\quad \\Sigma_2 = v_2 I_d,\n",
    "$$\n",
    "\n",
    "and the off-diagonal block is replaced by a constant matrix:\n",
    "\n",
    "$$\n",
    "\\Sigma_{12} = \\alpha J_d,\n",
    "$$\n",
    "\n",
    "where $J_d$ is the $d\\times d$ matrix with every entry equal to 1 and $I_d$ is the $d\\times d$ identity matrix,\n",
    "\n",
    "Note that $J_d$ is a rank-one matrix with one nonzero eigenvalue equal to $d$ (and $d-1$ eigenvalues equal to 0). In particular, if we consider the subspace spanned by the all-ones vector\n",
    "\n",
    "$$\n",
    "u = \\begin{pmatrix} 1 \\\\ 1 \\\\ \\vdots \\\\ 1 \\end{pmatrix} \\quad \\text{with} \\quad u^T u = d,\n",
    "$$\n",
    "\n",
    "then $J_d$ acts on $u$ as\n",
    "\n",
    "$$\n",
    "J_d\\, u = (u u^T) u = u (u^T u) = d\\, u.\n",
    "$$\n",
    "\n",
    "Thus, in the one-dimensional subspace spanned by $u$, the $d\\times d$ block $\\alpha J_d$ is effectively equivalent to the scalar $\\alpha d$.\n",
    "\n",
    "Therefore, the effective $2 \\times 2$ image covariance matrix (in this reduced subspace) becomes\n",
    "\n",
    "$$\n",
    "\\tilde{A} = \\begin{pmatrix}\n",
    "v_1 & \\alpha d \\\\\n",
    "\\alpha d & v_2\n",
    "\\end{pmatrix}.\n",
    "$$\n",
    "\n",
    "A necessary and sufficient condition for $\\tilde{A}$ to be positive definite is that its determinant is positive:\n",
    "\n",
    "$$\n",
    "\\det(\\tilde{A}) = v_1v_2 - (\\alpha d)^2 > 0.\n",
    "$$\n",
    "\n",
    "This implies:\n",
    "\n",
    "$$\n",
    "v_1v_2 > \\alpha^2 d^2.\n",
    "$$\n",
    "\n",
    "This condition ensures that the full image covariance block remains positive definite after modifying the off-diagonal blocks to be constant.\n"
   ]
  },
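  {
   "cell_type": "markdown",
   "id": "4f6a1d8e2c9b7053",
   "metadata": {},
   "source": [
    "As a worked example, take unit variances $v_1 = v_2 = 1$ and the CIFAR-10 dimensionality $d = 3072$ used above. The condition $v_1 v_2 > \\alpha^2 d^2$ becomes\n",
    "\n",
    "$$\n",
    "|\\alpha| < \\frac{\\sqrt{v_1 v_2}}{d} = \\frac{1}{3072} \\approx 3.26 \\times 10^{-4},\n",
    "$$\n",
    "\n",
    "whereas the identity-block coupling $\\Sigma_{12} = \\rho_{12} I_d$ only needs $|\\rho_{12}| < 1$. Each entry of $\\alpha J_d$ couples every pixel of $x_1$ with every pixel of $x_2$, so the admissible per-entry covariance shrinks by a factor of $d$."
   ]
  },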
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ac7a26c424d0161a",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.21"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
