{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "f8e81b4e",
   "metadata": {},
   "source": [
    "# Publication-Quality Comparisons and Visualizations\n",
    "This notebook contains supplementary analyses and figures for the paper. Sections include basic statistical comparisons (rows per data split), a correlation-analysis check, split-distribution visualizations, an artifact-class comparison, training-loss time series for DDPM vs. WGAN (read from TensorBoard logs), and export of all figures as 300-dpi PNGs to `paper/figs/`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cb0806e7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Import Required Libraries\n",
    "import os\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import plotly.express as px\n",
    "import seaborn as sns\n",
    "\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7b9a9d32",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load Publication Data\n",
    "# NOTE: absolute, machine-specific path -- adjust to your local checkout.\n",
    "DATA_PATH = 'C:/works/ArtifactGen/data/processed/suggested_splits_subjectwise_multilabel_filtered.csv'\n",
    "df = pd.read_csv(filepath_or_buffer=DATA_PATH)\n",
    "\n",
    "# Preview the first five rows.\n",
    "df.head(n=5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "07f81088",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Basic Statistical Comparisons\n",
    "# Number of rows assigned to each split, most frequent first. The printed label says\n",
    "# 'subject', which assumes one row per subject -- TODO confirm against the CSV.\n",
    "split_counts = df['split'].value_counts()\n",
    "print('Subject count per split:', split_counts, sep='\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c4cec0db",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Correlation Analysis\n",
    "# Correlate the numeric columns of the split table when there are at least two;\n",
    "# otherwise report that there is nothing to correlate (checked, not assumed).\n",
    "numeric_cols = df.select_dtypes(include='number').columns\n",
    "if len(numeric_cols) < 2:\n",
    "    print('No metric columns available for correlation analysis.')\n",
    "else:\n",
    "    # NOTE(review): ID-like or indicator columns are numeric too -- drop any that\n",
    "    # are not meaningful metrics before interpreting this table.\n",
    "    display(df[numeric_cols].corr().round(2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7754bcc5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Distribution Visualization\n",
    "# Bar chart of the per-split counts computed in the statistics cell above,\n",
    "# exported at 300 dpi for the paper.\n",
    "FIGS_DIR = 'C:/works/ArtifactGen/paper/figs'\n",
    "os.makedirs(FIGS_DIR, exist_ok=True)  # savefig does not create missing directories\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(8, 5))\n",
    "split_counts.plot(kind='bar', ax=ax)\n",
    "ax.set_title('Subject Distribution Across Splits')\n",
    "ax.set_xlabel('Split')\n",
    "ax.set_ylabel('Count')\n",
    "fig.tight_layout()\n",
    "fig.savefig(os.path.join(FIGS_DIR, 'split_distribution.png'), dpi=300, bbox_inches='tight')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2328aa3e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Additional Visualizations: Artifact Class Frequency and Split Proportions\n",
    "PROCESSED_DIR = 'C:/works/ArtifactGen/data/processed'\n",
    "FIGS_DIR = 'C:/works/ArtifactGen/paper/figs'\n",
    "os.makedirs(FIGS_DIR, exist_ok=True)  # savefig does not create missing directories\n",
    "\n",
    "class_map_path = os.path.join(PROCESSED_DIR, 'class_map.csv')\n",
    "class_map_df = pd.read_csv(class_map_path)\n",
    "\n",
    "# Bar plot of artifact class frequencies\n",
    "# TODO: every bar is a placeholder count of 1, not a real frequency -- replace\n",
    "# with per-class counts before using this figure in the paper.\n",
    "n_classes = len(class_map_df)\n",
    "positions = list(range(n_classes))\n",
    "fig, ax = plt.subplots(figsize=(7, 4))\n",
    "# Plain matplotlib with an explicit viridis palette: seaborn >= 0.13 deprecates\n",
    "# passing `palette` to barplot without also assigning `hue`.\n",
    "ax.bar(positions, [1] * n_classes, color=sns.color_palette('viridis', n_classes))\n",
    "ax.set_xticks(positions)\n",
    "ax.set_xticklabels(class_map_df['display'].astype(str))\n",
    "ax.set_title('Artifact Class Frequency (Dummy Count)')\n",
    "ax.set_xlabel('Artifact Class')\n",
    "ax.set_ylabel('Count')\n",
    "fig.tight_layout()\n",
    "fig.savefig(os.path.join(FIGS_DIR, 'class_frequency_bar.png'), dpi=300, bbox_inches='tight')\n",
    "plt.show()\n",
    "\n",
    "# Pie chart of split proportions\n",
    "split_props = df['split'].value_counts()\n",
    "fig, ax = plt.subplots(figsize=(6, 6))\n",
    "ax.pie(split_props, labels=split_props.index, autopct='%1.1f%%', startangle=140,\n",
    "       colors=sns.color_palette('pastel'))\n",
    "ax.set_title('Proportion of Subjects in Each Split')\n",
    "fig.savefig(os.path.join(FIGS_DIR, 'split_proportion_pie.png'), dpi=300, bbox_inches='tight')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "814a1d2f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# List all available scalar tags in TensorBoard event files for both ddpm and wgan\n",
    "# Requires: pip install tensorboard\n",
    "from tensorboard.backend.event_processing.event_accumulator import EventAccumulator\n",
    "import os\n",
    "\n",
    "def list_scalars(tb_dir):\n",
    "    \"\"\"Return the sorted, de-duplicated scalar tags found across every\n",
    "    `events.out.tfevents*` file directly inside `tb_dir` (not recursive).\"\"\"\n",
    "    event_paths = [os.path.join(tb_dir, name) for name in os.listdir(tb_dir)\n",
    "                   if name.startswith('events.out.tfevents')]\n",
    "    found = set()\n",
    "    for path in event_paths:\n",
    "        accumulator = EventAccumulator(path)\n",
    "        accumulator.Reload()\n",
    "        found |= set(accumulator.Tags()['scalars'])\n",
    "    return sorted(found)\n",
    "\n",
    "ddpm_dir = r'C:/works/ArtifactGen/results/tensorboard/ddpm'\n",
    "wgan_dir = r'C:/works/ArtifactGen/results/tensorboard/wgan'\n",
    "\n",
    "# Print each model's header followed by one tag per line.\n",
    "for header, run_dir in [('DDPM Scalars:', ddpm_dir), ('\\nWGAN Scalars:', wgan_dir)]:\n",
    "    print(header)\n",
    "    for tag in list_scalars(run_dir):\n",
    "        print(tag)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e0bd9ce5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compare DDPM and WGAN training losses read from TensorBoard event files\n",
    "from tensorboard.backend.event_processing.event_accumulator import EventAccumulator\n",
    "\n",
    "def extract_scalars(tb_dir, tags):\n",
    "    \"\"\"Collect scalar series for `tags` from every events.out.tfevents* file in `tb_dir`.\n",
    "\n",
    "    Returns (steps, data): two dicts keyed by tag, each value a list ordered by step.\n",
    "    Tags missing from all files map to empty lists. When several files log the same\n",
    "    step (e.g. a resumed run), only the value with the latest wall_time is kept.\n",
    "    \"\"\"\n",
    "    latest = {tag: {} for tag in tags}  # tag -> {step: ScalarEvent}\n",
    "    for fname in os.listdir(tb_dir):\n",
    "        if not fname.startswith('events.out.tfevents'):\n",
    "            continue\n",
    "        # size_guidance 0 keeps every scalar; the default reservoir-samples each tag to 10,000 points.\n",
    "        ea = EventAccumulator(os.path.join(tb_dir, fname), size_guidance={'scalars': 0})\n",
    "        ea.Reload()\n",
    "        available = set(ea.Tags()['scalars'])\n",
    "        for tag in tags:\n",
    "            if tag not in available:\n",
    "                continue\n",
    "            for event in ea.Scalars(tag):\n",
    "                prev = latest[tag].get(event.step)\n",
    "                if prev is None or event.wall_time >= prev.wall_time:\n",
    "                    latest[tag][event.step] = event\n",
    "    steps = {tag: sorted(latest[tag]) for tag in tags}\n",
    "    data = {tag: [latest[tag][s].value for s in steps[tag]] for tag in tags}\n",
    "    return steps, data\n",
    "\n",
    "# Directories and tags\n",
    "DDPM_DIR = r'C:/works/ArtifactGen/results/tensorboard/ddpm'\n",
    "WGAN_DIR = r'C:/works/ArtifactGen/results/tensorboard/wgan'\n",
    "DDPM_TAGS = ['Loss/MSE']\n",
    "WGAN_TAGS = ['Loss/D', 'Loss/G', 'Loss/Spectral']\n",
    "FIGS_DIR = 'C:/works/ArtifactGen/paper/figs'\n",
    "\n",
    "# Extract data\n",
    "steps_ddpm, data_ddpm = extract_scalars(DDPM_DIR, DDPM_TAGS)\n",
    "steps_wgan, data_wgan = extract_scalars(WGAN_DIR, WGAN_TAGS)\n",
    "\n",
    "# The DDPM MSE and the WGAN D/G/Spectral losses are different quantities whose scales\n",
    "# need not match, so each model gets its own panel instead of one shared y-axis.\n",
    "fig, (ax_ddpm, ax_wgan) = plt.subplots(1, 2, figsize=(14, 6))\n",
    "for tag in DDPM_TAGS:\n",
    "    if data_ddpm[tag]:\n",
    "        ax_ddpm.plot(steps_ddpm[tag], data_ddpm[tag], label=f'DDPM {tag}')\n",
    "ax_ddpm.set(title='DDPM', xlabel='Step', ylabel='MSE')\n",
    "for tag in WGAN_TAGS:\n",
    "    if data_wgan[tag]:\n",
    "        ax_wgan.plot(steps_wgan[tag], data_wgan[tag], label=f'WGAN {tag}')\n",
    "ax_wgan.set(title='WGAN', xlabel='Step', ylabel='Loss')\n",
    "for panel in (ax_ddpm, ax_wgan):\n",
    "    if panel.lines:  # legend() warns when a panel has nothing to label\n",
    "        panel.legend()\n",
    "fig.suptitle('DDPM vs WGAN: Loss and MSE Comparison')\n",
    "fig.tight_layout()\n",
    "os.makedirs(FIGS_DIR, exist_ok=True)  # savefig does not create missing directories\n",
    "fig.savefig(os.path.join(FIGS_DIR, 'ddpm_wgan_loss_comparison.png'), dpi=300, bbox_inches='tight')\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
