{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "import numpy as np\n",
    "import os\n",
    "import h5py\n",
    "import pickle\n",
    "import cv2\n",
    "from pathlib import Path\n",
    "\n",
    "from src.metrics import obtain_metrics_minimality, obtain_metrics_sufficiency"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read Factors\n",
    "def read_factors(dataset):\n",
    "    if dataset=='dsprites':\n",
    "        fname = \"/mnt/cephfs/home/voz/shared/database/disentanglement/dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz\"\n",
    "        data = np.load(fname)\n",
    "        factors = data['latents_values']\n",
    "        data.close()\n",
    "    elif dataset=='3dshapes':\n",
    "        fname = \"/mnt/cephfs/home/voz/shared/database/disentanglement/3dshapes.h5\"\n",
    "        factors = np.array(h5py.File(fname, 'r')['labels'])\n",
    "    elif dataset=='mpi3d':\n",
    "        fname = \"/mnt/cephfs/home/voz/shared/database/disentanglement/mpi3d/mpi3d_real.npz\"\n",
    "        data = np.load(fname)\n",
    "        factors = data['labels']\n",
    "        data.close()\n",
    "    else:\n",
    "        raise ValueError\n",
    "    return factors\n",
    "\n",
    "\n",
    "# Get Factors Correlated\n",
    "def sample_all_factors(\n",
    "        latent_factor_indices,\n",
    "        factor_sizes,\n",
    "        corr_indices, \n",
    "        line_width, \n",
    "        num, \n",
    "        random_seed=42\n",
    "    ):\n",
    "    \n",
    "    corr_factor_sizes = [\n",
    "        factor_sizes[corr_indices[0]],\n",
    "        factor_sizes[corr_indices[1]]\n",
    "    ]\n",
    "    \n",
    "    unnormalized_joint_prob = np.zeros(corr_factor_sizes, np.uint8)\n",
    "    \n",
    "    if line_width >= 1.0:  # choose uniform distribution if line_width >= 1\n",
    "        unnormalized_joint_prob = np.ones(corr_factor_sizes)\n",
    "    else:\n",
    "    \n",
    "        width = math.ceil(line_width * min(corr_factor_sizes))\n",
    "    \n",
    "        offset = 0\n",
    "        start = (0, offset)\n",
    "        end = (corr_factor_sizes[1], corr_factor_sizes[0])\n",
    "    \n",
    "        kernel_width = min(corr_factor_sizes) // 4\n",
    "    \n",
    "        if not kernel_width % 2:  # kernels widths must be odd\n",
    "            kernel_width += 1\n",
    "    \n",
    "        kernel_width_x = kernel_width\n",
    "        kernel_width_y = kernel_width\n",
    "    \n",
    "        cv2.line(unnormalized_joint_prob, start, end, 255, width)\n",
    "    \n",
    "        unnormalized_joint_prob = cv2.GaussianBlur(unnormalized_joint_prob,\n",
    "                                                   (kernel_width_x, kernel_width_y), 0)\n",
    "        unnormalized_joint_prob = unnormalized_joint_prob.astype(np.float_)\n",
    "    \n",
    "    # normalize\n",
    "    joint_prob = unnormalized_joint_prob / unnormalized_joint_prob.sum()\n",
    "    \n",
    "    random_state = np.random.RandomState(random_seed)\n",
    "\n",
    "    n_x, n_y = corr_factor_sizes\n",
    "    pairs = np.indices(dimensions=(n_x, n_y))\n",
    "    pairs = pairs.reshape(2, -1).T\n",
    "\n",
    "    inds = random_state.choice(np.arange(n_x * n_y),\n",
    "                               p=joint_prob.reshape(-1),\n",
    "                               size=num, replace=True)\n",
    "    \n",
    "    correlated_samples = pairs[inds]\n",
    "    \n",
    "    factors = np.zeros(\n",
    "        shape=(num, len(latent_factor_indices)), dtype=np.int64)\n",
    "\n",
    "    idx = np.argwhere(np.isin(latent_factor_indices, corr_indices))\n",
    "    idx = idx.flatten()\n",
    "    factors[:, idx] = correlated_samples\n",
    "\n",
    "    for pos, i in enumerate(latent_factor_indices):\n",
    "        if not i in corr_indices:\n",
    "            factors[:, pos] = random_state.randint(factor_sizes[i], size=num)\n",
    "\n",
    "    return np.float64(factors)\n",
    "\n",
    "\n",
    "def get_factors_cor(factors, corr_indices, num, sigma, factor_indices=None):\n",
    "    factor_indices = list(range(factors.shape[1])) if factor_indices is None else factor_indices\n",
    "    return sample_all_factors(\n",
    "        latent_factor_indices=factor_indices,\n",
    "        factor_sizes=[len(np.unique(factors[:, j])) for j in range(factors.shape[1])],\n",
    "        corr_indices=corr_indices, \n",
    "        line_width=sigma, \n",
    "        num=num, \n",
    "        random_seed=42\n",
    "    )\n",
    "\n",
    "\n",
    "# Get Rrepresentations\n",
    "def create_random_matrix(n, alpha):\n",
    "    matrix = np.zeros((n, n))\n",
    "    for j in range(n):\n",
    "        matrix[j, j] = 1 - alpha\n",
    "        off_diag_randoms = np.random.rand(n - 1)\n",
    "        off_diag_randoms *= alpha / off_diag_randoms.sum()\n",
    "        matrix[:j, j] = off_diag_randoms[:j]   # Fill upper part of the column\n",
    "        matrix[j+1:, j] = off_diag_randoms[j:] # Fill lower part of the column\n",
    "    return matrix\n",
    "\n",
    "\n",
    "def create_uniform_matrix(n, alpha):\n",
    "    matrix = np.zeros((n, n))\n",
    "    for j in range(n):\n",
    "        matrix[j, j] = 1 - alpha\n",
    "        if j == n - 1:\n",
    "            matrix[0, j] = alpha\n",
    "        for i in range(n):\n",
    "            if i != j:  # Exclude the diagonal element\n",
    "                matrix[i, j] = np.random.uniform(0, alpha)\n",
    "    return matrix\n",
    "\n",
    "\n",
    "def create_perm_matrix(n, alpha):\n",
    "    matrix = np.zeros((n, n))\n",
    "    for j in range(n):\n",
    "        matrix[j, j] = 1 - alpha\n",
    "        if j < n - 1:\n",
    "            matrix[j + 1, j] = alpha\n",
    "        else:\n",
    "            matrix[0, j] = alpha\n",
    "        remaining_indices = [i for i in range(n) if i != j and i != (j + 1) % n]  # Exclude diagonal and subdiagonal\n",
    "        if remaining_indices:\n",
    "            off_diag_randoms = np.random.rand(len(remaining_indices))\n",
    "            off_diag_randoms *= 0\n",
    "            for idx, i in enumerate(remaining_indices):\n",
    "                matrix[i, j] = off_diag_randoms[idx]\n",
    "    return matrix\n",
    "\n",
    "\n",
    "def get_representations(factors, alpha=0., matrix_type='uniform'):\n",
    "    matrix = create_uniform_matrix(factors.shape[1], alpha)\n",
    "    return factors @ matrix\n",
    "\n",
    "\n",
    "# Get Correlated Factos + Representations\n",
    "def get_factors_representations(dataset, corr_indices, num, sigma, alpha, factor_indices=None):\n",
    "    factors = read_factors(dataset)\n",
    "    factors_cor = get_factors_cor(factors, corr_indices, num, sigma, factor_indices)\n",
    "    representations = get_representations(factors_cor, alpha)\n",
    "    return factors_cor, representations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%capture\n",
    "dataset = 'dsprites'\n",
    "corr_indices = [3, 4]\n",
    "factor_indices = [1,2,3,4,5]\n",
    "num = 5000\n",
    "alphas = np.arange(0, 1., 0.05)\n",
    "sigmas = np.arange(0.01, 1., 0.05)\n",
    "#sigmas = [0.2, 0.7]\n",
    "\n",
    "\n",
    "scores = {alpha: {sigma: obtain_metrics_minimality(\n",
    "    *get_factors_representations(dataset, corr_indices, num, sigma, alpha)\n",
    ") for sigma in sigmas} for alpha in alphas}\n",
    "fname = os.path.join(Path.cwd(), 'dsprites_minimality.pkl')\n",
    "with open(fname, 'wb') as file:\n",
    "    pickle.dump(scores, file)\n",
    "\n",
    "\"\"\"\n",
    "scores = {alpha: {sigma: obtain_metrics_sufficiency(\n",
    "    *get_factors_representations(dataset, corr_indices, num, sigma, alpha, factor_indices)\n",
    ") for sigma in sigmas} for alpha in alphas}\n",
    "fname = os.path.join(Path.cwd(), 'dsprites_sufficiency.pkl')\n",
    "with open(fname, 'wb') as file:\n",
    "    pickle.dump(scores, file)\n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%capture\n",
    "dataset = '3dshapes'\n",
    "corr_indices = [3, 5]\n",
    "num = 5000\n",
    "alphas = np.arange(0, 1., 0.05)\n",
    "sigmas = np.arange(0.01, 1., 0.05)\n",
    "#sigmas = [0.2, 0.7]\n",
    "\n",
    "scores = {alpha: {sigma: obtain_metrics_minimality(\n",
    "    *get_factors_representations(dataset, corr_indices, num, sigma, alpha)\n",
    ") for sigma in sigmas} for alpha in alphas}\n",
    "fname = os.path.join(Path.cwd(), '3dshapes_minimality.pkl')\n",
    "with open(fname, 'wb') as file:\n",
    "    pickle.dump(scores, file)\n",
    "\n",
    "scores = {alpha: {sigma: obtain_metrics_sufficiency(\n",
    "    *get_factors_representations(dataset, corr_indices, num, sigma, alpha)\n",
    ") for sigma in sigmas} for alpha in alphas}\n",
    "fname = os.path.join(Path.cwd(), '3dshapes_sufficiency.pkl')\n",
    "with open(fname, 'wb') as file:\n",
    "    pickle.dump(scores, file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%capture\n",
    "dataset = 'mpi3d'\n",
    "corr_indices = [5,6]\n",
    "num = 5000\n",
    "alphas = np.arange(0, 1., 0.05)\n",
    "sigmas = np.arange(0.01, 1., 0.05)\n",
    "#sigmas = [0.2, 0.7]\n",
    "\n",
    "scores = {alpha: {sigma: obtain_metrics_minimality(\n",
    "    *get_factors_representations(dataset, corr_indices, num, sigma, alpha)\n",
    ") for sigma in sigmas} for alpha in alphas}\n",
    "fname = os.path.join(Path.cwd(), 'mpi3d_minimality.pkl')\n",
    "with open(fname, 'wb') as file:\n",
    "    pickle.dump(scores, file)\n",
    "\n",
    "scores = {alpha: {sigma: obtain_metrics_sufficiency(\n",
    "    *get_factors_representations(dataset, corr_indices, num, sigma, alpha)\n",
    ") for sigma in sigmas} for alpha in alphas}\n",
    "fname = os.path.join(Path.cwd(), 'mpi3d_sufficiency.pkl')\n",
    "with open(fname, 'wb') as file:\n",
    "    pickle.dump(scores, file)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
