{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d600f08e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "import numpy as np\n",
    "import torch\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from discriminative_metrics import discriminative_score_metrics\n",
    "from predictive_metrics import predictive_score_metrics\n",
    "from context_fid import Context_FID\n",
    "from cross_correlation import CrossCorrelLoss\n",
    "from metric_utils import display_scores\n",
    "from dtw import dtw_js_divergence_distance"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b5f4a164",
   "metadata": {},
   "source": [
    "### Load Real and Generated Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "96588f6e",
   "metadata": {},
   "outputs": [],
   "source": [
    "experiment_name = \"default_experiment\"\n",
    "real_data_path = f\"../../outputs/{experiment_name}/real_samples.npy\"\n",
    "gen_data_path = f\"../../outputs/{experiment_name}/ddpm_samples.npy\"\n",
    "real_data = np.load(real_data_path)\n",
    "generated_data = np.load(gen_data_path)\n",
    "real_data.shape, generated_data.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e6fb459a",
   "metadata": {},
   "outputs": [],
   "source": [
    "num_samples = min(real_data.shape[0], generated_data.shape[0])\n",
    "if real_data.shape[0] > num_samples:\n",
    "    print(f\"WARNING: Generated data only has {generated_data.shape[0]} samples, less than real data's {real_data.shape[0]} samples. Using all {num_samples} generated samples for evaluation.\")\n",
    "else:\n",
    "    print(f\"number of samples: {num_samples}\")\n",
    "\n",
    "random_indices = np.random.choice(len(real_data), num_samples, replace=False)\n",
    "real_data = real_data[random_indices]\n",
    "random_indices = np.random.choice(len(generated_data), num_samples, replace=False)\n",
    "generated_data = generated_data[random_indices]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "71296ef8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# minmax scale the inputs for fair comparison\n",
    "data_min = np.min(real_data, axis=(0,1), keepdims=True)\n",
    "data_max = np.max(real_data, axis=(0,1), keepdims=True)\n",
    "\n",
    "real_data = (real_data - data_min) / (data_max - data_min)\n",
    "generated_data = (generated_data - data_min) / (data_max - data_min)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "da09d10a",
   "metadata": {},
   "source": [
    "### Discriminative Score"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b5b47b22",
   "metadata": {},
   "outputs": [],
   "source": [
    "iterations = 5\n",
    "discriminative_score = []\n",
    "\n",
    "for i in range(iterations):\n",
    "    temp_disc, fake_acc, real_acc = discriminative_score_metrics(real_data, generated_data)\n",
    "    discriminative_score.append(temp_disc)\n",
    "    print(f'Iter {i}: ', temp_disc, '\\n')\n",
    "      \n",
    "display_scores(discriminative_score)\n",
    "print()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "32ec7287",
   "metadata": {},
   "source": [
    "### Predictive Score"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d766f30c",
   "metadata": {},
   "outputs": [],
   "source": [
    "iterations = 5\n",
    "predictive_score = []\n",
    "for i in range(iterations):\n",
    "    temp_pred = predictive_score_metrics(real_data, generated_data)\n",
    "    predictive_score.append(temp_pred)\n",
    "    print(i, ' epoch: ', temp_pred, '\\n')\n",
    "      \n",
    "display_scores(predictive_score)\n",
    "print()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "18a2b417",
   "metadata": {},
   "source": [
    "### Context-FID Score"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2bdc9d5e",
   "metadata": {},
   "outputs": [],
   "source": [
    "context_fid_score = []\n",
    "\n",
    "for i in range(iterations):\n",
    "    context_fid = Context_FID(real_data, generated_data)\n",
    "    context_fid_score.append(context_fid)\n",
    "    print(f'Iter {i}: ', 'context-fid =', context_fid, '\\n')\n",
    "      \n",
    "display_scores(context_fid_score)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8ea6fe97",
   "metadata": {},
   "source": [
    "### Correlational Score"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "64dba75d",
   "metadata": {},
   "outputs": [],
   "source": [
    "def random_choice(size, num_select=100):\n",
    "    select_idx = np.random.randint(low=0, high=size, size=(num_select,))\n",
    "    return select_idx\n",
    "\n",
    "x_real = torch.from_numpy(real_data)\n",
    "x_fake = torch.from_numpy(generated_data)\n",
    "\n",
    "correlational_score = []\n",
    "# size = int(x_real.shape[0] / iterations)\n",
    "size = 1000\n",
    "\n",
    "for i in range(iterations):\n",
    "    real_idx = random_choice(x_real.shape[0], size)\n",
    "    fake_idx = random_choice(x_fake.shape[0], size)\n",
    "    corr = CrossCorrelLoss(x_real[real_idx, :, :], name='CrossCorrelLoss')\n",
    "    loss = corr.compute(x_fake[fake_idx, :, :])\n",
    "    correlational_score.append(loss.item())\n",
    "    print(f'Iter {i}: ', 'cross-correlation =', loss.item(), '\\n')\n",
    "\n",
    "display_scores(correlational_score)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4237eebe",
   "metadata": {},
   "source": [
    "### DTW distance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cc548e60",
   "metadata": {},
   "outputs": [],
   "source": [
    "iterations = 5\n",
    "js_results = []\n",
    "for i in range(iterations):\n",
    "    js_dist = dtw_js_divergence_distance(real_data, generated_data, n_samples=100)['js_divergence']\n",
    "    print(\"js_dist: \", round(js_dist, 4))\n",
    "    js_results.append(js_dist)\n",
    "display_scores(js_results)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "wavediff",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
