{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Moment propagation and first covering attempt"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This notebook documents our attempt at the moment propagation approach, which assumes that all propagated variables are Gaussian at every layer of the network. It also contains our first experiments with propagating a covering.\n",
    "\n",
    "This notebook is not relevant to our final result either, except that it shows how unrealistic the Gaussianity assumption is; it is included to document our work."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "include(\"../src/DSI.jl\")\n",
    "include(\"../src/Zono_utils.jl\")\n",
    "include(\"../src/PZono.jl\")\n",
    "include(\"../src/DSZ.jl\")\n",
    "include(\"../src/propagation.jl\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "using SpecialFunctions"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# P-box setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "using PyPlot\n",
    "using Distributions\n",
    "\n",
    "# P-box built from a normal distribution with interval mean [0, 1] and second parameter 1.\n",
    "c = normal(interval(0,1),1)\n",
    "ProbabilityBoundsAnalysis.plot(c)\n",
    "\n",
    "# Overlay the CDFs of individual normal distributions on a grid of means in [0, 1].\n",
    "arr = []\n",
    "for mu in 0:0.01:1  # `mu` rather than `mean`, which would shadow the `mean` function inside the loop\n",
    "    normal_dist = Normal(mu, 1)\n",
    "    push!(arr, normal_dist)\n",
    "    # Named `supp`, not `x`: a later cell stores the network input in the global `x`, and\n",
    "    # assigning `x` inside this loop would overwrite that global when this cell is re-run.\n",
    "    supp = range(mu - 4, mu + 4, length=1000)\n",
    "    cdf_values = cdf.(normal_dist, supp)\n",
    "    plot(supp, cdf_values)\n",
    "end\n",
    "\n",
    "PyPlot.savefig(\"parametric.png\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Toy network: a ReLU layer followed by an identity layer.\n",
    "# Both layers use the same 2x2 weight values and share the zero bias vector `b`.\n",
    "W1 = [1.0 -1.0; 1.0 1.0]\n",
    "b = zeros(2)\n",
    "W2 = [1.0 -1.0; 1.0 1.0]\n",
    "\n",
    "L1 = Layer(W1, b, ReLU())\n",
    "L2 = Layer(W2, b, Id())\n",
    "full_net = Network([L1; L2])\n",
    "\n",
    "# Input range: one `normal(interval(0, 1), 1)` p-box per network input.\n",
    "x = [normal(interval(0, 1), 1) for _ in 1:2]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Moment-propagation approach"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Implementation of the moment propagation approach introduced in https://arxiv.org/abs/2403.16163. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "    propagate_network(mean_x, cov_x, network::Network)\n",
    "\n",
    "Push the mean `mean_x` and covariance `cov_x` through every layer of `network`,\n",
    "in order, by calling `propagate_layer` on each one.\n",
    "\n",
    "Returns the `(mean, covariance)` pair obtained after the last layer.\n",
    "\"\"\"\n",
    "function propagate_network(mean_x, cov_x, network::Network)\n",
    "    # Left fold over the layers; the accumulator is the current (mean, covariance) pair.\n",
    "    return foldl(network.layers; init=(mean_x, cov_x)) do (layer_mean, layer_cov), layer\n",
    "        propagate_layer(layer_mean, layer_cov, layer)\n",
    "    end\n",
    "end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "    propagate_layer(mean_x, cov_x, layer::Layer)\n",
    "\n",
    "Propagate a mean vector and a covariance matrix through a single layer.\n",
    "\n",
    "The affine part gives the mean `layer.weights * mean_x .+ layer.bias` and the\n",
    "covariance `layer.weights * cov_x * transpose(layer.weights)`. The activation is\n",
    "then applied to these moments: `ReLU()` goes through `propagate_relu`, while\n",
    "`Id()` leaves them unchanged.\n",
    "\n",
    "Returns the tuple `(mean_z, cov_z)`. Throws an `ArgumentError` for any other\n",
    "activation.\n",
    "\"\"\"\n",
    "function propagate_layer(mean_x, cov_x, layer::Layer)\n",
    "    mean_y = layer.weights * mean_x .+ layer.bias\n",
    "    cov_y = layer.weights * cov_x * transpose(layer.weights)\n",
    "\n",
    "    if layer.activation === ReLU()\n",
    "        return propagate_relu(mean_y, cov_y)\n",
    "    elseif layer.activation === Id()\n",
    "        return mean_y, cov_y\n",
    "    else\n",
    "        # Without this branch an unsupported activation fell through and the function\n",
    "        # failed with an UndefVarError on `mean_z`; report the actual problem instead.\n",
    "        throw(ArgumentError(\"unsupported activation $(layer.activation); expected ReLU() or Id()\"))\n",
    "    end\n",
    "end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "    propagate_relu(mean_y, cov_y)\n",
    "\n",
    "Approximate the mean and covariance after a ReLU, given the mean `mean_y` and\n",
    "covariance `cov_y` before it and treating the entries as Gaussian.\n",
    "\n",
    "The mean is `mean * CDF(alpha) + std * PDF(alpha)` with `alpha = mean / std`.\n",
    "The covariance is a five-term series in the pairwise correlation coefficient,\n",
    "whose coefficients are the entries of `derelu` scaled by powers of the standard\n",
    "deviation.\n",
    "\n",
    "Returns the tuple `(mean_z, cov_z)`.\n",
    "\"\"\"\n",
    "function propagate_relu(mean_y, cov_y)\n",
    "    # Per-component standard deviation; the small offset keeps the divisions below finite.\n",
    "    sigma = sqrt.(diag(cov_y)) .+ 1e-8\n",
    "\n",
    "    # Standard normal CDF and PDF evaluated at alpha = mean / std.\n",
    "    alpha = mean_y ./ sigma\n",
    "    Phi = 0.5 * (1 .+ erf.(alpha ./ sqrt(2)))\n",
    "    phi = exp.(-0.5 .* alpha.^2) ./ sqrt(2π)\n",
    "\n",
    "    mean_z = mean_y .* Phi .+ sigma .* phi\n",
    "\n",
    "    # Series coefficients of order 1 to 5, one vector per order.\n",
    "    derelu = [\n",
    "        Phi,\n",
    "        phi ./ sigma,\n",
    "        -mean_y ./ (sigma.^3) .* phi,\n",
    "        phi ./ (sigma.^3) .* ((mean_y.^2 ./ (sigma.^2)) .- 1),\n",
    "        - mean_y ./ (sigma.^5) .* ((mean_y ./ sigma).^2 .- 3) .* phi,\n",
    "    ]\n",
    "\n",
    "    dim = length(mean_y)\n",
    "    cov_z = zeros(dim, dim)\n",
    "    for row in 1:dim, col in 1:dim\n",
    "        corr = cov_y[row, col] / (sigma[row] * sigma[col])\n",
    "        for order in 1:5\n",
    "            row_term = sigma[row]^order * derelu[order][row]\n",
    "            col_term = sigma[col]^order * derelu[order][col]\n",
    "            cov_z[row, col] += ((corr^order)/factorial(order)) * row_term * col_term\n",
    "        end\n",
    "    end\n",
    "\n",
    "    return mean_z, cov_z\n",
    "end"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Testing this on the toy example:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "nb_discretization_steps = 100\n",
    "ProbabilityBoundsAnalysis.setSteps(nb_discretization_steps)\n",
    "# Result of `pbox_approximate_nnet` for the first output, plotted as the reference.\n",
    "pz = pbox_approximate_nnet(full_net, x, true)\n",
    "ProbabilityBoundsAnalysis.plot(pz[1])\n",
    "\n",
    "mean_arr = [0:0.01:1;]\n",
    "\n",
    "# NOTE(review): j starts at i, so only pairs with mean_arr[i] <= mean_arr[j] are propagated.\n",
    "# The toy network is not symmetric in its two inputs, so confirm that this half grid is intended.\n",
    "for i in 1:length(mean_arr)\n",
    "    for j in i:length(mean_arr)\n",
    "        mean_input = [mean_arr[i], mean_arr[j]]\n",
    "        cov_input = [1.0 0; 0 1.0]\n",
    "        mean_output, cov_output = propagate_network(mean_input, cov_input, full_net)\n",
    "        # `Normal` expects a standard deviation as its second argument, whereas the diagonal\n",
    "        # of `cov_output` holds variances, hence the square root.\n",
    "        normal_dist = Normal(mean_output[1], sqrt(cov_output[1,1]))\n",
    "        supp = range(mean_output[1] - 10, mean_output[1] + 10, length=1000)\n",
    "        cdf_values = cdf.(normal_dist, supp)\n",
    "        plot(supp, cdf_values)\n",
    "    end\n",
    "end\n",
    "PyPlot.savefig(\"MomentPropagationfst.png\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "nb_discretization_steps = 100\n",
    "ProbabilityBoundsAnalysis.setSteps(nb_discretization_steps)\n",
    "# Result of `pbox_approximate_nnet` for the second output, plotted as the reference.\n",
    "pz = pbox_approximate_nnet(full_net, x, true)\n",
    "ProbabilityBoundsAnalysis.plot(pz[2])\n",
    "\n",
    "mean_arr = [0:0.01:1;]\n",
    "\n",
    "# NOTE(review): j starts at i, so only pairs with mean_arr[i] <= mean_arr[j] are propagated.\n",
    "# Confirm that this half grid is intended as a covering of the input means.\n",
    "for i in 1:length(mean_arr)\n",
    "    for j in i:length(mean_arr)\n",
    "        mean_input = [mean_arr[i], mean_arr[j]]\n",
    "        cov_input = [1.0 0; 0 1.0]\n",
    "        mean_output, cov_output = propagate_network(mean_input, cov_input, full_net)\n",
    "        # `Normal` expects a standard deviation as its second argument, whereas the diagonal\n",
    "        # of `cov_output` holds variances, hence the square root.\n",
    "        normal_dist = Normal(mean_output[2], sqrt(cov_output[2,2]))\n",
    "        supp = range(mean_output[2] - 20, mean_output[2] + 20, length=1000)\n",
    "        cdf_values = cdf.(normal_dist, supp)\n",
    "        plot(supp, cdf_values)\n",
    "    end\n",
    "end\n",
    "PyPlot.savefig(\"MomentPropagationsnd.png\")\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Julia 1.11.2",
   "language": "julia",
   "name": "julia-1.11"
  },
  "language_info": {
   "file_extension": ".jl",
   "mimetype": "application/julia",
   "name": "julia",
   "version": "1.11.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
