{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Filtering images with the CSF\n",
    "\n",
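    "To use the contrast sensitivity function (CSF) as a filter properly, we need to take into account the actual size at which the images were presented: the CSF is defined over spatial frequency in cycles per degree of visual angle, whereas the FFT of an image gives cycles per pixel. By default, the code below assumes 256 x 256 px images shown at 3° of visual angle.\n"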
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 64,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from PIL import Image\n",
    "from typing import Callable"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def human_csf(sf: np.ndarray, tf: float = 5.0, eps: float = 1e-7) -> np.ndarray:\n",
    "    \"\"\"\n",
    "    Calculate the human CSF for the given spatial frequencies, at a given temporal frequency.\n",
    "\n",
    "    The result is divided by its maximum over ``sf``, so it is meant to be evaluated on a\n",
    "    whole frequency grid at once; a scalar ``sf`` is normalized against itself.\n",
    "\n",
    "    NOTE(review): the constants look like Kelly's (1979) spatio-velocity CSF with retinal\n",
    "    velocity tf / sf - TODO confirm the intended reference.\n",
    "\n",
    "    :param sf: spatial frequency in cycles per degree (scalar, list or array of any shape)\n",
    "    :param tf: temporal frequency in Hz (default: 5 Hz = 200ms presentation time)\n",
    "    :param eps: epsilon for numerical stability (keeps tf / sf finite at sf = 0)\n",
    "    :return: normalized CSF value(s), same shape as ``sf``; for array input every\n",
    "             zero-frequency (DC) entry is exactly 1\n",
    "    \"\"\"\n",
    "    sf = np.asarray(sf)\n",
    "\n",
    "    velocity = tf * (1.0 / (sf + eps))\n",
    "    alpha = 2 * np.pi * sf\n",
    "    k = 6.1 + 7.3 * np.abs(np.log10(velocity / 3))**3\n",
    "    alpha_max = 45.9 / (velocity + 2)\n",
    "    csf = ((alpha**2) * (k * velocity)) * np.exp(-2 * alpha * (1.0 / alpha_max))\n",
    "\n",
    "    norm = csf / (np.max(csf) + eps)\n",
    "\n",
    "    # Set the DC component to 1, because we don't care about luminance.\n",
    "    # DC is located by value (sf == 0) rather than by assuming index [0, 0], so 1-D inputs\n",
    "    # and fftshift-ed grids work too; on an unshifted fftfreq grid [0, 0] is the only\n",
    "    # zero entry. Scalar results (ndim 0) are returned unchanged.\n",
    "    if norm.ndim > 0:\n",
    "        norm[sf == 0] = 1.0\n",
    "\n",
    "    return norm\n",
    "\n",
    "def human_mtf(f: np.ndarray) -> np.ndarray:\n",
    "    \"\"\"\n",
    "    Modulation transfer function (MTF) of the optics of the human eye.\n",
    "\n",
    "    Product of a diffraction-limited (circular pupil) term, which reaches 0 at the cutoff\n",
    "    frequency, and a weighted sum of a constant and an exponential decay.\n",
    "\n",
    "    :param f: frequency in cycles / degree (scalar or array; the sign is ignored)\n",
    "    :return: MTF value(s), approximately 1 at f = 0 and 0 at or beyond the cutoff\n",
    "    \"\"\"\n",
    "    cutoff = 82.7  # cycles / degree\n",
    "\n",
    "    # parameters based on interferometry, from Williams et al 1994\n",
    "    decay, w_flat, w_exp = 0.1212, 0.3481, 0.6519\n",
    "\n",
    "    # |f| saturated at the cutoff keeps arccos / sqrt below inside their domains\n",
    "    mag = np.minimum(np.abs(f), cutoff)\n",
    "    rel = mag / cutoff\n",
    "\n",
    "    diffraction = (2/np.pi) * (np.arccos(rel) - rel * np.sqrt(1 - rel**2))\n",
    "    return diffraction * (w_flat + w_exp * np.exp(-decay * mag))\n",
    "\n",
    "\n",
    "def gaussian_filter(f: np.ndarray, sigma: float = 2.5,\n",
    "                    px_per_deg: float = 85.16052813036795) -> np.ndarray:\n",
    "    \"\"\"\n",
    "    Calculate the MTF of a Gaussian low-pass filter.\n",
    "\n",
    "    A spatial Gaussian with standard deviation ``sigma`` pixels has the frequency response\n",
    "    exp(-2 * pi^2 * sigma^2 * nu^2), with nu in cycles per pixel; ``f`` is converted from\n",
    "    cycles per degree to cycles per pixel by dividing by ``px_per_deg``.\n",
    "\n",
    "    :param f: frequency in cycles per degree (scalar or array)\n",
    "    :param sigma: standard deviation of the Gaussian in pixels\n",
    "    :param px_per_deg: display resolution in pixels per degree of visual angle\n",
    "    :return: MTF value(s), same shape as ``f``; exactly 1 at f = 0\n",
    "    \"\"\"\n",
    "    # NOTE(review): the default px_per_deg (~85.16) differs slightly from the 256 px / 3 deg\n",
    "    # (~85.33 px/deg) grid used by get_2d_filter - confirm which matches the experiment.\n",
    "    cycles_per_px = f / px_per_deg\n",
    "    return np.exp(-2 * np.pi**2 * sigma**2 * cycles_per_px**2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 62,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_2d_filter(filter_func: Callable[[np.ndarray], np.ndarray],\n",
    "                  shape: tuple = (256, 256),\n",
    "                  deg_per_px: float = 3.0 / 256) -> np.ndarray:\n",
    "    \"\"\"\n",
    "    Produce a 2D filter in the frequency domain, laid out like ``np.fft.fft2`` output.\n",
    "\n",
    "    With the defaults this is the spectrum for a 256 x 256 image presented at 3 deg of\n",
    "    visual angle (e.g. the CSF @200ms when ``filter_func`` is ``human_csf``).\n",
    "\n",
    "    :param filter_func: the modulation to apply to each frequency; it receives the radial\n",
    "                        spatial frequency in cycles/degree as a 2D array\n",
    "    :param shape: (rows, cols) of the image the filter will be multiplied with\n",
    "    :param deg_per_px: visual angle subtended by one pixel, in degrees\n",
    "    :return: what ``filter_func`` returns for the frequency grid of the given shape;\n",
    "             entry [0, 0] corresponds to zero frequency (DC)\n",
    "    \"\"\"\n",
    "    n_rows, n_cols = shape\n",
    "\n",
    "    # Frequency axes in cycles/degree directly (rows -> fy, columns -> fx)\n",
    "    fy = np.fft.fftfreq(n_rows, d=deg_per_px)\n",
    "    fx = np.fft.fftfreq(n_cols, d=deg_per_px)\n",
    "    FY, FX = np.meshgrid(fy, fx, indexing='ij')\n",
    "    F = np.sqrt(FX**2 + FY**2)\n",
    "\n",
    "    # apply the filter\n",
    "    return filter_func(F)\n",
    "\n",
    "\n",
    "def apply_filter_primitive(img: \"Image.Image | np.ndarray\",\n",
    "                           filter_func: Callable[[np.ndarray], np.ndarray],\n",
    "                           deg_per_px: float = 3.0 / 256) -> np.ndarray:\n",
    "    \"\"\"\n",
    "    Apply a given frequency-domain filter to a given image.\n",
    "\n",
    "    Each channel is filtered independently with the same 2D gain, built for the image's\n",
    "    own height and width.\n",
    "\n",
    "    :param img: the image to be filtered (PIL image or array, H x W or H x W x C);\n",
    "                pixel values are assumed to be 8-bit, i.e. in [0, 255]\n",
    "    :param filter_func: the filter to be applied, mapping cycles/degree to a gain\n",
    "    :param deg_per_px: visual angle of one pixel in degrees\n",
    "                       (default: a 256 px wide image spans 3 deg)\n",
    "\n",
    "    :return: the filtered image as a uint8 array with the same shape as ``img``\n",
    "    \"\"\"\n",
    "    pixels = np.array(img, dtype=np.float64) / 255.0\n",
    "\n",
    "    H = get_2d_filter(filter_func, shape=pixels.shape[:2], deg_per_px=deg_per_px)\n",
    "    if pixels.ndim == 3:\n",
    "        # broadcast the same 2D gain over every channel\n",
    "        H = H[:, :, np.newaxis]\n",
    "\n",
    "    # 2D FFT over the spatial axes only, so channels stay independent\n",
    "    spectrum = np.fft.fft2(pixels, axes=(0, 1))\n",
    "    filtered = np.fft.ifft2(spectrum * H, axes=(0, 1)).real\n",
    "\n",
    "    # Clip before the cast so out-of-range values saturate instead of wrapping around,\n",
    "    # and round so FFT round-off (e.g. 127.9999) does not truncate to the value below.\n",
    "    filtered = np.clip(filtered, 0, 1)\n",
    "    return np.round(filtered * 255).astype(np.uint8)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Demo of how to apply the filter.\n",
    "# The filters assume the stimulus was shown as a 256 x 256 px image spanning 3 deg of\n",
    "# visual angle; replace the placeholder path with a real stimulus.\n",
    "IMAGE_PATH = 'some_example.png'\n",
    "\n",
    "# A context manager releases the file handle deterministically (Pillow's recommended pattern).\n",
    "with Image.open(IMAGE_PATH) as img:\n",
    "    # apply the filters\n",
    "    csf_result = apply_filter_primitive(img, human_csf)\n",
    "\n",
    "# show the filtered image as the cell output\n",
    "Image.fromarray(csf_result)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "zeus",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
