{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "from pytorch_fid.fid_score import calculate_frechet_distance\n",
    "import torch\n",
    "from utils_dqa import compute_kernel\n",
    "_SCALE=1000\n",
    "# Set random seed for reproducibility\n",
    "np.random.seed(42)\n",
    "N_train = 1000\n",
    "def compute_mmd(fx, fy, sigma=10):\n",
    "\n",
    "    fx = torch.tensor(fx)\n",
    "    fy = torch.tensor(fy)\n",
    "    # Reshape features if necessary\n",
    "    fx = fx.view(fx.size(0), -1)\n",
    "    fy = fy.view(fy.size(0), -1)\n",
    "\n",
    "    # Define kernel parameter gamma\n",
    "    gamma = 1 / (2 * sigma ** 2)\n",
    "\n",
    "    # Compute kernel matrices\n",
    "    K_xx = compute_kernel(fx, fx, gamma)\n",
    "    K_yy = compute_kernel(fy, fy, gamma) \n",
    "    K_xy = compute_kernel(fx, fy, gamma)\n",
    "\n",
    "    # Calculate MMD\n",
    "    mmd = K_xx.mean() + K_yy.mean() - 2 * K_xy.mean()\n",
    "    return mmd *_SCALE\n",
    "def auto_fid(A,B):\n",
    "    mu_A = np.mean(A, axis=0)\n",
    "    sigma_A = np.cov(A, rowvar=False)\n",
    "    mu_B = np.mean(B, axis=0)\n",
    "    sigma_B = np.cov(B, rowvar=False)\n",
    "    fid = calculate_frechet_distance(mu_A, sigma_A, mu_B, sigma_B)\n",
    "    return fid\n",
    "\n",
    "\n",
    "def generate_non_gaussian_data(N_train, N_test, translation_A=0, translation_B=0):\n",
    "    \"\"\"\n",
    "    Generate non-Gaussian training and test data with translation applied to each group.\n",
    "    \n",
    "    Args:\n",
    "    - N_train: Number of training samples\n",
    "    - N_test: Number of test samples\n",
    "    - translation_A: Translation applied to group A's data\n",
    "    - translation_B: Translation applied to group B's data\n",
    "    \n",
    "    Returns:\n",
    "    - train_A, train_B, test_A, test_B, train_data\n",
    "    \"\"\"\n",
    "    np.random.seed(0)\n",
    "    mean_A = [0, 0]\n",
    "    cov_A = [[1, 0], [0, 1]]  # Small variance\n",
    "    mean_B = [15, 15]\n",
    "    cov_B = [[8, 0], [0, 8]]\n",
    "    train_A = (\n",
    "        np.random.multivariate_normal(mean_A, cov_A, N_train) +\n",
    "        # np.random.uniform(low=-6, high=6, size=(N_train, 2)) + translation_A\n",
    "        + np.random.exponential(scale=1.0, size=(N_train, 2)) \n",
    "    )\n",
    "\n",
    "    # Generate train_B by summing components and applying translation\n",
    "    train_B = (\n",
    "        np.random.multivariate_normal(mean_B, cov_B, N_train) +\n",
    "        # np.random.uniform(low=5, high=12, size=(N_train, 2)) + translation_B\n",
    "        + np.random.exponential(scale=2.0, size=(N_train, 2))\n",
    "    )\n",
    "\n",
    "    # Similarly, generate test_A\n",
    "    test_A = (\n",
    "        np.random.multivariate_normal(mean_A, cov_A, N_test) +\n",
    "        # np.random.uniform(low=-6, high=6, size=(N_test, 2)) + translation_A\n",
    "        + np.random.exponential(scale=1.0, size=(N_test, 2)) \n",
    "    )\n",
    "\n",
    "    # Generate test_B\n",
    "    test_B = (\n",
    "        np.random.multivariate_normal(mean_B, cov_B, N_test) +\n",
    "        # np.random.uniform(low=5, high=12, size=(N_test, 2)) + translation_B\n",
    "        + np.random.exponential(scale=2.0, size=(N_test, 2)) \n",
    "    )\n",
    "\n",
    "    train_data = np.vstack((train_A, train_B))\n",
    "\n",
    "    return train_A,train_B, test_A, test_B, train_data\n",
    "def generate_non_gaussian_test_data(N_test, translation_A1=0,translation_A2=0 ,translation_B1=0,translation_B2=0):\n",
    "    \"\"\"\n",
    "    Generate non-Gaussian test data with translation applied to each group.\n",
    "    \n",
    "    Args:\n",
    "    - N_test: Number of test samples\n",
    "    - translation_A: Translation applied to group A's data\n",
    "    - translation_B: Translation applied to group B's data\n",
    "    \n",
    "    Returns:\n",
    "    - test_A, test_B\n",
    "    \"\"\"\n",
    "    np.random.seed(0)\n",
    "    mean_A = [0, 0]\n",
    "    cov_A = [[3, 0], [0, 3]]  # Small variance\n",
    "    mean_B = [15, 15]\n",
    "    cov_B = [[12, 0], [0, 12]]\n",
    "    test_A = (\n",
    "        np.random.multivariate_normal(mean_A, cov_A, N_test) + translation_A1 +\n",
    "        # np.random.uniform(low=-6, high=6, size=(N_test, 2))+ translation_A1 \n",
    "        + np.random.exponential(scale=1.0+translation_A2, size=(N_test, 2))\n",
    "    )\n",
    "\n",
    "    # Generate test_B with summation for 2D and apply different translations\n",
    "    test_B = (\n",
    "        np.random.multivariate_normal(mean_B, cov_B, N_test) + translation_B1 +\n",
    "        # np.random.uniform(low=5, high=12, size=(N_test, 2)) + translation_B1 \n",
    "        +np.random.exponential(scale=2.0+translation_B2, size=(N_test, 2))\n",
    "    )\n",
    "\n",
    "    return test_A, test_B\n",
    "def plot_data(ax, train_A, train_B, test_A, test_B,title,results):\n",
    "    (mmd_all_A, mmd_all_B, mmd_fid_A, mmd_fid_B, mmd_A, mmd_B,fid_A, fid_B, DQA_mmd, DQA_fid) = results\n",
    "    # Plot the data points for reference and generated data\n",
    "    ax.scatter(train_A[:, 0], train_A[:, 1], alpha=0.7, label='Reference Group A',color='C0')\n",
    "    ax.scatter(train_B[:, 0], train_B[:, 1], alpha=0.7, label='Reference Group B',color='C1')\n",
    "    ax.scatter(test_A[:, 0], test_A[:, 1], alpha=0.7, label='Generated A', marker='x',color='C2')\n",
    "    ax.scatter(test_B[:, 0], test_B[:, 1], alpha=0.7, label='Generated B', marker='x',color='C3')\n",
    "    \n",
    "\n",
    "    # Set plot limits and labels\n",
    "    ax.legend(loc='lower right',fontsize=16)\n",
    "    # ax.set_title(f'DQA-MMD: {DQA_mmd:.2f}, DQA-FD: {DQA_fid:.2f}',fontsize=20)\n",
    "    ax.set_xlim((-4, 4.2))\n",
    "    ax.set_ylim((-4, 4.2))\n",
    "    ax.set_xticks([])\n",
    "    ax.set_yticks([])\n",
    "\n",
    "    # Add p-values as text in the figure\n",
    "    textstr = f\"\\u2194 All Ref. \\n MMD A: {mmd_all_A:.2f}\\n MMD B: {mmd_all_B:.2f}\\n FD A: {mmd_fid_A:.2f}\\n FD B: {mmd_fid_B:.2f}\"\n",
    "    textstr2 = f\"\\u2194Group-Specific Ref. \\n MMD A: {mmd_A:.2f}\\n MMD B: {mmd_B:.2f}\\n FD A: {fid_A:.2f}\\n FD B: {fid_B:.2f}\"\n",
    "\n",
    "    ax.text(0.05, 0.95, textstr, transform=ax.transAxes, fontsize=16,\n",
    "            verticalalignment='top', bbox=dict(boxstyle='round', facecolor='white', alpha=0.3))\n",
    "    ax.text(0.05, 0.23, textstr2, transform=ax.transAxes, fontsize=16,\n",
    "            verticalalignment='top', bbox=dict(boxstyle='round', facecolor='white', alpha=0.3))\n",
    "\n",
    "\n",
    "def compute_DQA(scaler, ax, test_A, test_B, train_A, train_B,train_data,title):\n",
    "    epsilon = 1e-3\n",
    "\n",
    "    all_ref_data = scaler.transform(train_data)\n",
    "    ref_A = scaler.transform(train_A)\n",
    "    ref_B = scaler.transform(train_B)\n",
    "    gen_A = scaler.transform(test_A)\n",
    "    gen_B = scaler.transform(test_B)\n",
    "\n",
    "    \n",
    "    gen_data = np.vstack((gen_A, gen_B))\n",
    "    \n",
    "\n",
    "    \n",
    "    std_ref = np.std(all_ref_data)\n",
    "    std_gen = np.std(gen_data)\n",
    "    std = np.sqrt((std_ref**2 +std_gen**2)/2 )\n",
    "    dim =gen_data.shape[1]\n",
    "    ref_data = np.vstack((ref_A, ref_B))\n",
    "\n",
    "\n",
    "    \n",
    "    mmd_all_A = compute_mmd(gen_A,ref_data)\n",
    "    mmd_all_B = compute_mmd(gen_B,ref_data)\n",
    "    mmd_fid_A = auto_fid(gen_A,ref_data)\n",
    "    mmd_fid_B = auto_fid(gen_B,ref_data)\n",
    "\n",
    "    mmd_A = compute_mmd(gen_A, ref_A)\n",
    "    mmd_B = compute_mmd(gen_B, ref_B)\n",
    "    obs_diff_mmd = abs(mmd_A - mmd_B)\n",
    "    fid_A = auto_fid(gen_A, ref_A)\n",
    "    fid_B = auto_fid(gen_B, ref_B)\n",
    "    obs_diff_fid = abs(fid_A - fid_B)\n",
    "    \n",
    "    \n",
    "    D_mmd = compute_mmd(ref_data, gen_data)\n",
    "    D_fid = auto_fid(ref_data, gen_data)\n",
    "    DQA_mmd = obs_diff_mmd / (D_mmd+0.1*std)\n",
    "    DQA_fid = obs_diff_fid / (D_fid+0.1*std)\n",
    "        \n",
    "    plot_data(ax, ref_A, ref_B, gen_A, gen_B,title, (mmd_all_A, mmd_all_B, mmd_fid_A, mmd_fid_B, mmd_A, mmd_B,fid_A, fid_B, DQA_mmd, DQA_fid))\n",
    "\n",
    "    return obs_diff_fid, obs_diff_mmd\n",
    "# Create subplots\n",
    "N = 1\n",
    "fig, axs = plt.subplots(N, 2, figsize=(12, 6*N))\n",
    "\n",
    "right_lim=15\n",
    "x = 2\n",
    "N_train = 500\n",
    "N_test = 500\n",
    "train_A, train_B, test_A, test_B, train_data = generate_non_gaussian_data(N_train, N_test, translation_A=0, translation_B=0)\n",
    "scaler = StandardScaler()\n",
    "scaler.fit(train_data)\n",
    "test_A, test_B = generate_non_gaussian_test_data(N_test, translation_A1=3,translation_A2=0.2,translation_B1=-3,translation_B2=0.2)\n",
    "compute_DQA(scaler, axs[0], test_A, test_B, train_A, train_B, train_data, \"DQA with out-of-distribution (mean-shift case 1)\")\n",
    "\n",
    "test_A, test_B = generate_non_gaussian_test_data(N_test, translation_A1=1,translation_A2=0.2,translation_B1=-11,translation_B2=0.2)\n",
    "compute_DQA(scaler, axs[1], test_A, test_B, train_A, train_B, train_data, \"DQA with out-of-distribution (mean-shift case 2)\")\n",
    "\n",
    "plt.tight_layout()\n",
    "# plt.savefig('DQA_fair_non_gaussian.png')\n",
    "plt.show()\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py312",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.12.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
