{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "a459560d-b425-48d5-ba3d-515c0eabb6bd",
   "metadata": {},
   "source": [
    "# Regression Bias"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6deb8ce7-1924-4e5f-ac53-d59d1c961eda",
   "metadata": {},
   "source": [
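    "In this notebook, we analyze the regression bias that arises when designing time-series foundation models (TSFMs). On simple series that only take the values 0 and 1, we compare Chronos, which classifies each future value into a token bin, with Chronos-Bolt, which predicts quantiles directly; we then inspect Chronos's per-bin probabilities and visualize the quantile-loss and cross-entropy loss landscapes."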
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2c67c291-448f-437a-b758-3ee786902b9a",
   "metadata": {},
   "outputs": [],
   "source": [
    "%reload_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "# On Colab, install the dependencies before importing them. `%pip` (not `!pip`) installs into the\n",
    "# running kernel's environment; `chronos-forecasting` provides the `chronos` package used below.\n",
    "import sys\n",
    "if 'google.colab' in sys.modules:\n",
    "    %pip install -q transformers chronos-forecasting\n",
    "\n",
    "import math\n",
    "import random\n",
    "import statistics\n",
    "from collections import defaultdict\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import scipy\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "from matplotlib.colors import PowerNorm\n",
    "%matplotlib inline\n",
    "\n",
    "from transformers import (\n",
    "    AutoModelForSeq2SeqLM,\n",
    "    AutoModelForCausalLM,\n",
    "    AutoConfig,\n",
    "    T5Config,\n",
    "    Trainer,\n",
    "    TrainingArguments,\n",
    ")\n",
    "\n",
    "from chronos import ChronosPipeline, ChronosBoltPipeline\n",
    "\n",
    "# Seed the global RNGs so Restart & Run All is reproducible (Chronos draws forecast samples).\n",
    "SEED = 0\n",
    "random.seed(SEED)\n",
    "np.random.seed(SEED)\n",
    "torch.manual_seed(SEED)\n",
    "\n",
    "pd.set_option('display.max_rows', None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "235a7731-28bf-48eb-847d-3900115f460e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Path or Hugging Face Hub id of your own Chronos-Bolt checkpoint pretrained with a patch size of 1.\n",
    "BOLT_P1_CHECKPOINT = None  # TODO: set this before running the Chronos-Bolt cells\n",
    "\n",
    "bolt = ChronosBoltPipeline.from_pretrained(\"amazon/chronos-bolt-small\")\n",
    "chronos = ChronosPipeline.from_pretrained(\"amazon/chronos-t5-small\")\n",
    "\n",
    "if BOLT_P1_CHECKPOINT is None:\n",
    "    # Fall back to the released checkpoint so the notebook still runs end to end. It presumably\n",
    "    # uses a larger patch size, so the Chronos-Bolt plots below are then not the intended experiment.\n",
    "    print(\"BOLT_P1_CHECKPOINT is not set; using amazon/chronos-bolt-small as bolt_p1.\")\n",
    "    bolt_p1 = bolt\n",
    "else:\n",
    "    bolt_p1 = ChronosBoltPipeline.from_pretrained(BOLT_P1_CHECKPOINT)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ee071534-3238-48c5-9af1-16caf14c2cf0",
   "metadata": {},
   "source": [
    "## Test on a uniform walk (alternating runs of five 0s and five 1s)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "94dd31e8-b3de-4830-b863-b0a2db980a19",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Uniform walk: 60 alternating runs of exactly five 0s and five 1s (300 points).\n",
    "series = []\n",
    "for i in range(60):\n",
    "    series += [i % 2] * 5\n",
    "\n",
    "plt.plot(series)\n",
    "plt.title(\"Uniform walk\")\n",
    "plt.xlabel(\"time step\")\n",
    "plt.ylabel(\"value\")\n",
    "\n",
    "# Hold out the last 64 points (the forecast horizon used below) as ground truth.\n",
    "gt = series[-64:]\n",
    "context = torch.tensor(series[:-64], dtype=torch.float32).unsqueeze(0)  # (1, context_length)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e926513f-a021-4b2a-93d2-1bcb3f7b37d4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Record Chronos's raw lm_head logits for every forward pass made while forecasting the uniform walk.\n",
    "logits_collection_e = []\n",
    "\n",
    "def lm_head_hook(module, input, output):\n",
    "    logits_collection_e.append(output.detach().cpu())\n",
    "\n",
    "handle = chronos.model.model.lm_head.register_forward_hook(lm_head_hook)\n",
    "try:\n",
    "    output_chronos = chronos.predict(\n",
    "        context,\n",
    "        prediction_length=64,\n",
    "        num_samples=1,\n",
    "    ).numpy()\n",
    "finally:\n",
    "    # Detach the hook even if predict fails. A leftover hook keeps appending to this global list\n",
    "    # on later predict calls, and re-running the cell would duplicate the captured logits.\n",
    "    handle.remove()\n",
    "\n",
    "# Presumably (num_samples, prediction_length, vocab_size), one slice per decoding step -- TODO confirm.\n",
    "logits = torch.cat(logits_collection_e, dim=1)\n",
    "# The uniform-walk logit plots below read `logits_wave`; `logits` itself is overwritten later.\n",
    "logits_wave = logits"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4c901723-b466-4b01-ae8b-b8c784f5a958",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Chronos output on the uniform walk: context, sampled forecast and held-out ground truth.\n",
    "context_steps = np.arange(context.shape[1])\n",
    "forecast_steps = np.arange(context.shape[1], context.shape[1] + output_chronos.shape[2])\n",
    "\n",
    "plt.figure()\n",
    "plt.plot(context_steps, context[0, :], label=\"context\")\n",
    "plt.plot(forecast_steps, output_chronos[0, 0, :], label=\"forecast\")\n",
    "plt.plot(forecast_steps, gt, color='gray', alpha=0.5, label=\"ground truth\")\n",
    "plt.xlabel(\"time step\")\n",
    "plt.ylabel(\"value\")\n",
    "plt.title(\"Chronos's Forecast\")\n",
    "plt.legend(loc=\"lower left\");"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e5973c90-052a-4559-b5ef-717baa4316ed",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Chronos-Bolt output on the uniform walk: median forecast with 30-70% and 10-90% quantile bands.\n",
    "output_bolt = bolt_p1.predict(\n",
    "    context,\n",
    "    prediction_length=64,\n",
    ").numpy()  # (batch, quantile, prediction_length)\n",
    "\n",
    "# Indices into the quantile axis; assumes levels 0.1, 0.2, ..., 0.9 -- confirm for your checkpoint.\n",
    "Q10, Q30, Q50, Q70, Q90 = 0, 2, 4, 6, 8\n",
    "context_steps = np.arange(context.shape[1])\n",
    "# Size the horizon from Bolt's own output rather than from the Chronos forecast.\n",
    "forecast_steps = np.arange(context.shape[1], context.shape[1] + output_bolt.shape[2])\n",
    "\n",
    "plt.figure()\n",
    "plt.plot(context_steps, context[0, :], label=\"context\")\n",
    "plt.plot(forecast_steps, output_bolt[0, Q50, :], label=\"forecast\")\n",
    "plt.fill_between(forecast_steps, output_bolt[0, Q30, :], output_bolt[0, Q70, :], color='orange', alpha=0.4, label=\"30-70%\")\n",
    "plt.fill_between(forecast_steps, output_bolt[0, Q10, :], output_bolt[0, Q90, :], color='orange', alpha=0.2, label=\"10-90%\")\n",
    "plt.plot(forecast_steps, gt, color='gray', alpha=0.5, label=\"ground truth\")\n",
    "\n",
    "plt.xlabel(\"time step\")\n",
    "plt.ylabel(\"value\")\n",
    "plt.title(\"Chronos-Bolt's Forecast\")\n",
    "plt.legend(loc=\"lower left\")\n",
    "plt.ylim(-0.1, 1.1);"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "49c2b5bd-0043-4da6-8b8c-415bcfa81582",
   "metadata": {},
   "source": [
    "## Test on a random walk (alternating runs of 0s and 1s with random lengths from 1 to 5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5f89cdd7-abb3-402c-b832-a6fac9251243",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Random walk: 80 alternating runs of 0s and 1s whose lengths are drawn uniformly from 1..5.\n",
    "# A dedicated, seeded RNG makes the series identical every time this cell runs.\n",
    "walk_rng = random.Random(0)\n",
    "series = []\n",
    "for i in range(80):\n",
    "    series += [i % 2] * walk_rng.randint(1, 5)\n",
    "\n",
    "plt.plot(series)\n",
    "plt.title(\"Random walk\")\n",
    "plt.xlabel(\"time step\")\n",
    "plt.ylabel(\"value\")\n",
    "# No hold-out here: the whole series is the model context.\n",
    "context = torch.tensor(series, dtype=torch.float32).unsqueeze(0)  # (1, context_length)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5e9ee127-3da8-456e-bf16-b1c38a529e3b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Record Chronos's raw lm_head logits for every forward pass made while forecasting the random walk.\n",
    "logits_collection_d = []\n",
    "\n",
    "def lm_head_hook(module, input, output):\n",
    "    logits_collection_d.append(output.detach().cpu())\n",
    "\n",
    "handle = chronos.model.model.lm_head.register_forward_hook(lm_head_hook)\n",
    "try:\n",
    "    output_chronos = chronos.predict(\n",
    "        context,\n",
    "        prediction_length=64,\n",
    "        num_samples=1,\n",
    "    ).numpy()\n",
    "finally:\n",
    "    # Detach the hook even if predict fails. A leftover hook keeps appending to this global list\n",
    "    # on later predict calls, and re-running the cell would duplicate the captured logits.\n",
    "    handle.remove()\n",
    "\n",
    "# Presumably (num_samples, prediction_length, vocab_size), one slice per decoding step -- TODO confirm.\n",
    "logits = torch.cat(logits_collection_d, dim=1)\n",
    "# The random-walk logit plots below read `logits_bif`.\n",
    "logits_bif = logits"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "da49c165-2e54-4f95-8355-e7d3c8644290",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Chronos output on the random walk (no held-out ground truth for this series).\n",
    "context_steps = np.arange(context.shape[1])\n",
    "forecast_steps = np.arange(context.shape[1], context.shape[1] + output_chronos.shape[2])\n",
    "\n",
    "plt.figure()\n",
    "plt.plot(context_steps, context[0, :], label=\"context\")\n",
    "plt.plot(forecast_steps, output_chronos[0, 0, :], label=\"forecast\")\n",
    "plt.xlabel(\"time step\")\n",
    "plt.ylabel(\"value\")\n",
    "plt.title(\"Chronos's Forecast\")\n",
    "plt.legend(loc=\"lower left\");"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "61bacef2-0adf-4882-8fb1-3a99487dfcd9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Chronos-Bolt output on the random walk: the median forecast plus two shaded quantile bands\n",
    "# (quantile-axis indices 2/6 and 0/8 are labelled 30-70% and 10-90%).\n",
    "bolt_quantiles = bolt_p1.predict(context, prediction_length=64)\n",
    "output_bolt = bolt_quantiles.numpy()\n",
    "\n",
    "horizon_start = context.shape[1]\n",
    "horizon_x = np.arange(horizon_start, horizon_start + output_bolt.shape[2])\n",
    "median = output_bolt[0, 4, :]\n",
    "band_30_70 = (output_bolt[0, 2, :], output_bolt[0, 6, :])\n",
    "band_10_90 = (output_bolt[0, 0, :], output_bolt[0, 8, :])\n",
    "\n",
    "plt.figure()\n",
    "plt.plot(np.arange(0, horizon_start), context[0, :], label=\"context\")\n",
    "plt.plot(horizon_x, median, label=\"forecast\")\n",
    "plt.fill_between(horizon_x, *band_30_70, color='orange', alpha=0.4, label=\"30-70%\")\n",
    "plt.fill_between(horizon_x, *band_10_90, color='orange', alpha=0.2, label=\"10-90%\")\n",
    "\n",
    "plt.xlabel(\"time step\")\n",
    "plt.ylabel(\"value\")\n",
    "plt.title(\"Chronos-Bolt's Forecast\")\n",
    "plt.legend(loc=\"lower left\")\n",
    "\n",
    "plt.ylim(-0.05, 1.05)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c09a0528-be00-43da-8f27-672688f72643",
   "metadata": {},
   "source": [
    "## Understand the Logits"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ff36f2f7-6d40-48d4-9680-9411ac7d19d5",
   "metadata": {},
   "source": [
    "### Uniform walk"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "88b13c8e-cd93-41ec-863f-f14bb196fee7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Chronos's probability for every token bin at one forecast step of the uniform walk.\n",
    "# Token ids of the bins holding 0 and 1. They are hard-coded and presumably depend on the context's\n",
    "# mean scale, which is why the random-walk cells use 2320 rather than 2327 for the value 1.\n",
    "BIN_OF_0, BIN_OF_1 = 2049, 2327\n",
    "counter = 0  # forecast step to inspect (0 = first generated token)\n",
    "\n",
    "prob = F.softmax(logits_wave[0, counter, :], dim=-1)\n",
    "vocab_size = prob.shape[-1]\n",
    "\n",
    "plt.figure()\n",
    "plt.axvline(x=BIN_OF_0, color=\"blue\", linestyle=\":\")\n",
    "plt.axvline(x=BIN_OF_1, color=\"green\", linestyle=\":\")\n",
    "# All remaining bins as small red dots; the two highlighted bins get their own markers.\n",
    "for start, stop in [(0, BIN_OF_0), (BIN_OF_0 + 1, BIN_OF_1), (BIN_OF_1 + 1, vocab_size)]:\n",
    "    plt.plot(np.arange(start, stop), prob[start:stop], '.', color='r', alpha=0.8, markersize=2)\n",
    "plt.plot([BIN_OF_0], prob[[BIN_OF_0]], 'x', color='blue', alpha=0.8, markersize=8, label=\"bin of 0\")\n",
    "plt.plot([BIN_OF_1], prob[[BIN_OF_1]], 'x', color='green', alpha=0.8, markersize=8, label=\"bin of 1\")\n",
    "plt.ylim(-0.05, 1.05)\n",
    "plt.xlim(2000, 2500)\n",
    "plt.xlabel(\"bin index\")\n",
    "plt.ylabel(\"probability\")\n",
    "plt.legend()\n",
    "plt.title(\"Forecasting the First Token by Chronos (Uniform Context)\")\n",
    "print(\"most likely bin:\", int(prob.argmax()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7887c388-5ee7-4d79-b9c0-b55c2b9dc9cb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Chronos's probability for a few fixed bins at every forecast step of the uniform walk.\n",
    "prob = F.softmax(logits_wave[0, ...], dim=-1)  # presumably (prediction_length, vocab_size)\n",
    "\n",
    "plt.figure()\n",
    "plt.plot(prob[:, 2049], '--', color=\"blue\", label=\"bin of 0\")\n",
    "plt.plot(prob[:, 2327], '--', color=\"green\", label=\"bin of 1\")\n",
    "plt.plot(prob[:, 2052], '--', color=\"red\", label=\"bin of 0.01\")\n",
    "plt.title(\"Forecasting Every Token by Chronos (Uniform Context)\")\n",
    "plt.xlabel(\"forecast step\")\n",
    "plt.ylabel(\"probability\")\n",
    "plt.legend(loc=\"lower left\")\n",
    "plt.ylim(-0.05, 1.05);"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4aca682d-4bb5-46b5-b864-869bf51e2ab2",
   "metadata": {},
   "source": [
    "### Random walk"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fb23b3fd-631e-4709-b4d1-1b7d974c2f15",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Chronos's probability for every token bin at one forecast step of the random walk.\n",
    "# Token ids of the bins holding 0 and 1. They are hard-coded and presumably depend on the context's\n",
    "# mean scale, which is why the uniform-walk cells use 2327 rather than 2320 for the value 1.\n",
    "BIN_OF_0, BIN_OF_1 = 2049, 2320\n",
    "counter = 0  # forecast step to inspect (0 = first generated token)\n",
    "\n",
    "prob = F.softmax(logits_bif[0, counter, :], dim=-1)\n",
    "vocab_size = prob.shape[-1]\n",
    "\n",
    "plt.figure()\n",
    "plt.axvline(x=BIN_OF_0, color=\"blue\", linestyle=\":\")\n",
    "plt.axvline(x=BIN_OF_1, color=\"green\", linestyle=\":\")\n",
    "# All remaining bins as small red dots; the two highlighted bins get their own markers.\n",
    "for start, stop in [(0, BIN_OF_0), (BIN_OF_0 + 1, BIN_OF_1), (BIN_OF_1 + 1, vocab_size)]:\n",
    "    plt.plot(np.arange(start, stop), prob[start:stop], '.', color='r', alpha=0.8, markersize=2)\n",
    "plt.plot([BIN_OF_0], prob[[BIN_OF_0]], 'x', color='blue', alpha=0.8, markersize=8, label=\"bin of 0\")\n",
    "plt.plot([BIN_OF_1], prob[[BIN_OF_1]], 'x', color='green', alpha=0.8, markersize=8, label=\"bin of 1\")\n",
    "plt.ylim(-0.05, 1.05)\n",
    "plt.xlim(2000, 2500)\n",
    "plt.xlabel(\"bin index\")\n",
    "plt.ylabel(\"probability\")\n",
    "plt.legend()\n",
    "plt.title(\"Forecasting the First Token by Chronos (Non-uniform Context)\")\n",
    "print(\"most likely bin:\", int(prob.argmax()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a07ba20e-4279-4f92-b02b-e949378bfd0b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Chronos's probability for a few fixed bins at every forecast step of the random walk.\n",
    "prob = F.softmax(logits_bif[0, ...], dim=-1)  # presumably (prediction_length, vocab_size)\n",
    "\n",
    "plt.figure()\n",
    "plt.plot(prob[:, 2049], '--', color=\"blue\", label=\"bin of 0\")\n",
    "plt.plot(prob[:, 2320], '--', color=\"green\", label=\"bin of 1\")\n",
    "plt.plot(prob[:, 2052], '--', color=\"red\", label=\"bin of 0.01\")\n",
    "plt.ylim(-0.05, 1.05)\n",
    "plt.title(\"Forecasting Every Token by Chronos (Non-uniform Context)\")\n",
    "plt.xlabel(\"forecast step\")\n",
    "plt.ylabel(\"probability\")\n",
    "plt.legend(loc=\"lower left\");"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ddc976e9-ddd8-42de-87ed-8b4c86511a3f",
   "metadata": {},
   "source": [
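    "## Visualize the Loss Landscapes\n",
    "\n",
    "Over a grid of candidate outputs, we plot two sets of losses. The first is the quantile (pinball) loss at levels 0.25, 0.5 and 0.75 for ground truths 0 and 1. The second is the cross-entropy loss of a three-class classifier for ground-truth classes 0 and 2. Each set also includes the sum of its pair of losses, which corresponds to a target that is equally likely to take either value."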
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "274df5d5-1039-4bce-8af0-418368fab40f",
   "metadata": {},
   "outputs": [],
   "source": [
    "from itertools import permutations\n",
    "\n",
    "from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 -- registers the '3d' projection on older Matplotlib\n",
    "\n",
    "def quantile_loss(y_true, y_pred, quantile):\n",
    "    \"\"\"Elementwise pinball loss of predicting `y_pred` for target `y_true` at level `quantile`.\"\"\"\n",
    "    residual = y_true - y_pred\n",
    "    return np.maximum(quantile * residual, (quantile - 1) * residual)\n",
    "\n",
    "# Each point of a 21x21x21 grid over [0, 1]^3 is one candidate (q_25, q_50, q_75) prediction.\n",
    "QUANTILE_LEVELS = (0.25, 0.5, 0.75)\n",
    "GROUND_TRUTHS = (0, 1)  # the two values taken by the series studied above\n",
    "q_25, q_50, q_75 = np.meshgrid(np.linspace(0, 1, 21),\n",
    "                               np.linspace(0, 1, 21),\n",
    "                               np.linspace(0, 1, 21))\n",
    "predictions = dict(zip(QUANTILE_LEVELS, (q_25, q_50, q_75)))\n",
    "\n",
    "# Panels 1-2: summed pinball loss for each ground truth; panel 3: their sum (equally likely targets).\n",
    "loss1, loss2 = (\n",
    "    sum(quantile_loss(y, predictions[q], q) for q in QUANTILE_LEVELS) for y in GROUND_TRUTHS\n",
    ")\n",
    "total_loss = loss1 + loss2\n",
    "panels = [\n",
    "    (loss1, f\"Loss Landscape (Ground Truth = {GROUND_TRUTHS[0]})\", \"Loss\", QUANTILE_LEVELS),\n",
    "    (loss2, f\"Loss Landscape (Ground Truth = {GROUND_TRUTHS[1]})\", \"Loss\", QUANTILE_LEVELS),\n",
    "    (total_loss, \"Sum of Losses\", \"Total Loss\", QUANTILE_LEVELS),\n",
    "]\n",
    "# Panels 4-9: the summed loss under every ordering of the three quantile axes.\n",
    "panels += [(total_loss, \"Sum of Losses\", \"Total Loss\", order) for order in permutations(QUANTILE_LEVELS)]\n",
    "\n",
    "fig = plt.figure(figsize=(24, 24))\n",
    "for index, (loss, title, colorbar_label, axis_order) in enumerate(panels, start=1):\n",
    "    ax = fig.add_subplot(3, 3, index, projection='3d')\n",
    "    scatter = ax.scatter(*(predictions[q] for q in axis_order), c=loss.ravel(), cmap='viridis', alpha=1)\n",
    "    ax.set_title(title)\n",
    "    ax.set_xlabel(f\"{axis_order[0]:.2f} Quantile\")\n",
    "    ax.set_ylabel(f\"{axis_order[1]:.2f} Quantile\")\n",
    "    ax.set_zlabel(f\"{axis_order[2]:.2f} Quantile\")\n",
    "    fig.colorbar(scatter, ax=ax, shrink=0.5, aspect=10, label=colorbar_label)\n",
    "\n",
    "plt.tight_layout()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "40d2481c-65e1-4aa4-9b46-495cf8ccd280",
   "metadata": {},
   "outputs": [],
   "source": [
    "from itertools import permutations\n",
    "\n",
    "def softmax(z):\n",
    "    \"\"\"Numerically stable softmax over the last axis of `z`.\"\"\"\n",
    "    z_shifted = z - np.max(z, axis=-1, keepdims=True)\n",
    "    exp_z = np.exp(z_shifted)\n",
    "    return exp_z / np.sum(exp_z, axis=-1, keepdims=True)\n",
    "\n",
    "def cross_entropy_loss(logits, true_class_index):\n",
    "    \"\"\"Cross-entropy of `logits` (classes on the last axis) against a single true class index.\"\"\"\n",
    "    true_class_prob = softmax(logits)[..., true_class_index]\n",
    "    # Clip probabilities to avoid log(0)\n",
    "    true_class_prob = np.clip(true_class_prob, 1e-9, 1.0)\n",
    "    return -np.log(true_class_prob)\n",
    "\n",
    "# Each point of a 20x20x20 grid over [-1, 1]^3 is one candidate logit vector for three classes.\n",
    "CLASSES = (0, 1, 2)\n",
    "GT_CLASSES = (0, 2)  # ground-truth classes whose losses are plotted and summed\n",
    "logit0, logit1, logit2 = np.meshgrid(np.linspace(-1, 1, 20),\n",
    "                                     np.linspace(-1, 1, 20),\n",
    "                                     np.linspace(-1, 1, 20))\n",
    "logit_axes = (logit0, logit1, logit2)\n",
    "# Stack to shape (20, 20, 20, 3) so the whole grid is scored at once. Named `logit_grid` so it does\n",
    "# not overwrite the Chronos `logits` tensor captured earlier in the notebook.\n",
    "logit_grid = np.stack(logit_axes, axis=-1)\n",
    "\n",
    "# Panels 1-2: loss for each ground-truth class; panel 3: their sum (equally likely classes).\n",
    "loss1, loss2 = (cross_entropy_loss(logit_grid, c) for c in GT_CLASSES)\n",
    "total_loss = loss1 + loss2\n",
    "sum_title = f\"Sum of Losses (GT Class {GT_CLASSES[0]} + GT Class {GT_CLASSES[1]})\"\n",
    "panels = [\n",
    "    (loss1, f\"Loss Landscape (Ground Truth = Class {GT_CLASSES[0]})\", \"Cross-Entropy Loss\", CLASSES),\n",
    "    (loss2, f\"Loss Landscape (Ground Truth = Class {GT_CLASSES[1]})\", \"Cross-Entropy Loss\", CLASSES),\n",
    "    (total_loss, sum_title, \"Total Loss\", CLASSES),\n",
    "]\n",
    "# Panels 4-9: the summed loss under every ordering of the three logit axes.\n",
    "panels += [(total_loss, sum_title, \"Total Loss\", order) for order in permutations(CLASSES)]\n",
    "\n",
    "fig = plt.figure(figsize=(24, 24))\n",
    "for index, (loss, title, colorbar_label, axis_order) in enumerate(panels, start=1):\n",
    "    ax = fig.add_subplot(3, 3, index, projection='3d')\n",
    "    scatter = ax.scatter(*(logit_axes[k] for k in axis_order), c=loss.ravel(), cmap='viridis', alpha=1)\n",
    "    ax.set_title(title)\n",
    "    ax.set_xlabel(f\"Logit Class {axis_order[0]}\")\n",
    "    ax.set_ylabel(f\"Logit Class {axis_order[1]}\")\n",
    "    ax.set_zlabel(f\"Logit Class {axis_order[2]}\")\n",
    "    fig.colorbar(scatter, ax=ax, shrink=0.5, aspect=10, label=colorbar_label)\n",
    "\n",
    "plt.tight_layout()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
