{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# MCal Calibration Demo\n",
    "\n",
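    "This notebook demonstrates the MCal calibration framework: it fits three calibrators (MCal, Platt scaling, and temperature scaling) on synthetic data and compares their expected outputs against a uniform target."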
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "import os\n",
    "sys.path.append('../src')\n",
    "\n",
    "import torch\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from calibrators import MCal, PlattCalibrator, TemperatureScaling\n",
    "from utils.optimization import get_expectation\n",
    "from utils.visualization import plot_training_curves"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Generate Synthetic Data\n",
    "\n",
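    "Create synthetic probability distributions for testing the calibration methods: a uniform target (`clean_probs`) and a noisy version biased towards the first class (`ablated_probs`)."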
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fix the RNG seed so the synthetic data is reproducible\n",
    "torch.manual_seed(1234)\n",
    "\n",
    "d = 4  # Number of classes\n",
    "N = 1000  # Number of samples\n",
    "\n",
    "# Target distribution: every row is uniform over the d classes\n",
    "clean_probs = torch.ones(N, d)\n",
    "clean_probs = clean_probs / clean_probs.sum(dim=1, keepdim=True)\n",
    "\n",
    "# Biased distribution: target + uniform random noise + extra mass on the first class,\n",
    "# then divided by the row sum so each row is normalized again\n",
    "noise = torch.rand_like(clean_probs)\n",
    "first_class_bias = torch.eye(d)[0]\n",
    "ablated_probs = clean_probs + noise + first_class_bias\n",
    "ablated_probs = ablated_probs / ablated_probs.sum(dim=1, keepdim=True)\n",
    "\n",
    "print(f\"Clean probs shape: {clean_probs.shape}\")\n",
    "print(f\"Ablated probs shape: {ablated_probs.shape}\")\n",
    "print(f\"Raw accuracy: {(clean_probs.argmax(dim=-1) == ablated_probs.argmax(dim=-1)).float().mean():.3f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Test MCal Calibrator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the MCal calibrator for d classes\n",
    "mcal = MCal(d)\n",
    "\n",
    "# Fit it to map the ablated probabilities onto the clean ones;\n",
    "# the options are collected in one place so they are easy to tweak\n",
    "mcal_fit_options = dict(\n",
    "    verbose=True,\n",
    "    kappa=1.0,\n",
    "    lr=1e-3,\n",
    "    early_stopping=False,\n",
    "    max_steps=1000,\n",
    ")\n",
    "mcal_stats = mcal.fit(ablated_probs, clean_probs, **mcal_fit_options)\n",
    "\n",
    "# Apply the fitted calibrator to the ablated probabilities\n",
    "mcal_output = mcal(ablated_probs)\n",
    "print(f\"MCal output shape: {mcal_output.shape}\")\n",
    "\n",
    "# Summarize the calibrated output with get_expectation\n",
    "one_hot_exp, prob_exp = get_expectation(mcal_output)\n",
    "print(f\"MCal one-hot expectation: {one_hot_exp}\")\n",
    "print(f\"MCal prob expectation: {prob_exp}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Test Platt Calibrator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the Platt calibrator for d classes\n",
    "platt = PlattCalibrator(d)\n",
    "\n",
    "# Fit it to map the ablated probabilities onto the clean ones;\n",
    "# the options are collected in one place so they are easy to tweak\n",
    "platt_fit_options = dict(\n",
    "    lr=1e-3,\n",
    "    verbose=True,\n",
    "    max_steps=1000,\n",
    ")\n",
    "platt_stats = platt.fit(ablated_probs, clean_probs, **platt_fit_options)\n",
    "\n",
    "# Apply the fitted calibrator to the ablated probabilities\n",
    "platt_output = platt(ablated_probs)\n",
    "print(f\"Platt output shape: {platt_output.shape}\")\n",
    "\n",
    "# Summarize the calibrated output with get_expectation\n",
    "one_hot_exp, prob_exp = get_expectation(platt_output)\n",
    "print(f\"Platt one-hot expectation: {one_hot_exp}\")\n",
    "print(f\"Platt prob expectation: {prob_exp}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Test Temperature Scaling"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the temperature-scaling calibrator for d classes\n",
    "temp_scaler = TemperatureScaling(d)\n",
    "\n",
    "# Fit it to map the ablated probabilities onto the clean ones\n",
    "temp_stats = temp_scaler.fit(ablated_probs, clean_probs, verbose=True)\n",
    "\n",
    "# Apply the fitted calibrator and report the temperature attribute after fitting\n",
    "temp_output = temp_scaler(ablated_probs)\n",
    "print(f\"Temperature scaling output shape: {temp_output.shape}\")\n",
    "print(f\"Learned temperature: {temp_scaler.temperature.item():.4f}\")\n",
    "\n",
    "# Summarize the calibrated output with get_expectation\n",
    "one_hot_exp, prob_exp = get_expectation(temp_output)\n",
    "print(f\"Temperature one-hot expectation: {one_hot_exp}\")\n",
    "print(f\"Temperature prob expectation: {prob_exp}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Visualize Training Curves"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Draw one training-curve figure each for MCal and Platt, in that order\n",
    "for curve_stats, curve_title in (\n",
    "    (mcal_stats, 'MCal Training Curves'),\n",
    "    (platt_stats, 'Platt Training Curves'),\n",
    "):\n",
    "    fig, ax = plot_training_curves(curve_stats, title=curve_title)\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Compare Results\n",
    "\n",
    "Compare the probability expectation of each method's output with that of the uncalibrated input and with the uniform target."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Baseline: expectations of the uncalibrated (ablated) probabilities\n",
    "orig_one_hot, orig_prob = get_expectation(ablated_probs)\n",
    "\n",
    "# Emit the whole report with a single print call; sep puts each entry on its own line\n",
    "print(\n",
    "    \"Comparison of Expectations:\",\n",
    "    f\"Original prob expectation: {orig_prob}\",\n",
    "    f\"MCal prob expectation: {get_expectation(mcal_output)[1]}\",\n",
    "    f\"Platt prob expectation: {get_expectation(platt_output)[1]}\",\n",
    "    f\"Temperature prob expectation: {get_expectation(temp_output)[1]}\",\n",
    "    f\"Target (uniform): {torch.ones(d) / d}\",\n",
    "    sep=\"\\n\",\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Conclusion\n",
    "\n",
    "This demo shows how to use the MCal framework to calibrate probability distributions using different methods:\n",
    "\n",
    "1. **MCal**: Vector scaling with learnable class-specific parameters\n",
    "2. **Platt Calibrator**: Platt scaling with logit transformation\n",
    "3. **Temperature Scaling**: Simple temperature parameter scaling\n",
    "\n",
    "Each method has different strengths and is suitable for different scenarios."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
