{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from PIL import Image\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import random\n",
    "import h5py\n",
    "from torch.utils.data import DataLoader\n",
    "from torch.utils.data import Dataset, TensorDataset\n",
    "import torch.optim as optim\n",
    "\n",
    "from regression.CNOClassification import CNOClassificationModel_pl\n",
    "from utils.utils_data import get_loader, load_data, read_cli_inference, find_files_with_extension, save_errors\n",
    "from diffusion.variance_fn import marginal_prob_std_1, diffusion_coeff_1, marginal_prob_std_2, diffusion_coeff_2\n",
    "from GenCFD.model.lightning_wrap.pl_conditional_denoiser import PreconditionedDenoiser_pl\n"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Data processing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dataset identifier and how the diffusion model consumes the label:\n",
    "# with \"x&y\" the label is an extra input channel, otherwise (\"yx\") it is a conditioning input.\n",
    "which = \"cifar_class\"\n",
    "which_type = \"x&y\"\n",
    "\n",
    "# MNIST runs use a much smaller training subset than CIFAR.\n",
    "N_train = 2000 if \"mnist\" in which else 40000\n",
    "\n",
    "ood_share = 0.1\n",
    "batch_size = 64\n",
    "\n",
    "# Train/validation loaders share ood_share and batch_size;\n",
    "# the test loader uses ood_share=1.0 and batches of 16.\n",
    "train_loader = get_loader(which_data=which, which_type=\"train\", N_samples=N_train,\n",
    "                          ood_share=ood_share, batch_size=batch_size)\n",
    "valid_loader = get_loader(which_data=which, which_type=\"val\", N_samples=0,\n",
    "                          ood_share=ood_share, batch_size=batch_size)\n",
    "test_loader = get_loader(which_data=which, which_type=\"test\", N_samples=0,\n",
    "                         ood_share=1.0, batch_size=16)\n"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Show the data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show up to 20 validation images in a 2 x 10 grid.\n",
    "# The *0.5 + 0.5 step assumes the loader normalizes pixels to [-1, 1].\n",
    "images, labels = next(iter(valid_loader))\n",
    "n_show = min(20, images.shape[0])\n",
    "images = torch.permute(images, (0, 2, 3, 1))\n",
    "images = (255*(images.detach().cpu().numpy()*0.5 + 0.5)).astype(\"uint8\")\n",
    "fig = plt.figure(figsize=(25, 4))\n",
    "for idx in range(n_show):\n",
    "    ax = fig.add_subplot(2, 10, idx+1, xticks=[], yticks=[])\n",
    "    if images.shape[-1] == 1:\n",
    "        # imshow rejects (H, W, 1) arrays: drop the channel axis for single-channel data (e.g. MNIST).\n",
    "        ax.imshow(images[idx, :, :, 0], cmap=\"gray\")\n",
    "    else:\n",
    "        ax.imshow(images[idx])\n",
    "plt.show()\n"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load the diffusion and classification models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import argparse\n",
    "import functools\n",
    "import os\n",
    "\n",
    "# One device for both models and for the diffusion variance functions.\n",
    "device = \"cuda\"\n",
    "\n",
    "# Checkpoint and output locations per dataset family (replace the placeholders before running).\n",
    "if \"cifar\" in which:\n",
    "    diff_model = \"/path_to_diff_model/\"\n",
    "    path = \"/path_to_class_model/\"\n",
    "    reg_path_errors = \"/path_to_err_folder/\"\n",
    "elif \"mnist\" in which:\n",
    "    diff_model = \"/path_to_diff_model/\"\n",
    "    path = \"/path_to_class_model/\"\n",
    "    reg_path_errors = \"/path_to_err_folder/\"\n",
    "else:\n",
    "    # Fail early with a clear message instead of a NameError further down.\n",
    "    raise ValueError(f\"Unsupported dataset {which!r}: expected a name containing 'cifar' or 'mnist'\")\n",
    "\n",
    "# --- Diffusion model ---\n",
    "diffusion_model_path = str(find_files_with_extension(diff_model + \"/model\", \"ckpt\", [], is_pl = True)[0])\n",
    "diffusion_config_path = str(find_files_with_extension(diff_model, \"json\", [\"param\"])[0])\n",
    "config_diff = argparse.Namespace(**load_data(diffusion_config_path))\n",
    "config_diff_arch = load_data(config_diff.config_arch)\n",
    "\n",
    "sigma = config_diff.sigma\n",
    "marginal_prob_std_fn = functools.partial(marginal_prob_std_2, sigma_min = 0.001, sigma_max=sigma, device = device)\n",
    "diffusion_coeff_fn = functools.partial(diffusion_coeff_2, sigma_min = 0.001, sigma_max=sigma, device = device)\n",
    "\n",
    "# \"x&y\": the label is an extra input channel and there is no conditioning input;\n",
    "# otherwise the label is a 1-channel conditioning input.\n",
    "if which_type == \"x&y\":\n",
    "    dim_cond = 0\n",
    "    dim = config_diff.in_dim + 1\n",
    "else:\n",
    "    dim_cond = 1\n",
    "    dim = config_diff.in_dim\n",
    "\n",
    "print(config_diff_arch)\n",
    "print(config_diff)\n",
    "diffusion_model = PreconditionedDenoiser_pl(dim = dim,\n",
    "                                            dim_cond = dim_cond,\n",
    "                                            loss_fn = None,\n",
    "                                            marginal_prob_std_fn = marginal_prob_std_fn,\n",
    "                                            diffusion_coeff_fn = diffusion_coeff_fn,\n",
    "                                            config_train = vars(config_diff),\n",
    "                                            config_arch = config_diff_arch,\n",
    "                                            is_inference = True\n",
    "                                            )\n",
    "checkpoint = torch.load(diffusion_model_path, map_location = device)\n",
    "diffusion_model.load_state_dict(checkpoint[\"state_dict\"])\n",
    "# Keep only the `best_model_ema` sub-module for inference.\n",
    "diffusion_model = diffusion_model.best_model_ema.to(device)\n",
    "\n",
    "# --- Classification model ---\n",
    "class_path = str(find_files_with_extension(path, \"json\", [\"param\"])[0])\n",
    "config_class = load_data(class_path)\n",
    "config_class[\"workdir\"] = None\n",
    "config_arch = load_data(config_class[\"config_arch\"])\n",
    "model = CNOClassificationModel_pl(in_dim = config_class[\"in_dim\"],\n",
    "                                out_dim = config_class[\"out_dim\"],\n",
    "                                loss_fn = None,\n",
    "                                config_train = config_class,\n",
    "                                config_arch = config_arch)\n",
    "# NOTE(review): the weights file is \"model-cifar.pt\" for every dataset; confirm this is intended for MNIST.\n",
    "model.load_state_dict(torch.load(f\"{path}/model-cifar.pt\", map_location = device))\n",
    "model = model.to(device)\n",
    "\n",
    "os.makedirs(reg_path_errors, exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from scipy.io import loadmat\n",
    "from torch.utils.data import TensorDataset, DataLoader\n",
    "\n",
    "\n",
    "def load_svhn(N, device = \"cuda\", file = \"/path_to_svhn/\", batch_size = 16):\n",
    "    \"\"\"Load the first N samples of an SVHN-style .mat file into a DataLoader.\n",
    "\n",
    "    Args:\n",
    "        N: number of samples to keep (sliced from the last axis of 'X' and the first axis of 'y').\n",
    "        device: device on which the tensors are created.\n",
    "        file: path of the .mat file (defaults to the original placeholder path).\n",
    "        batch_size: batch size of the returned DataLoader.\n",
    "\n",
    "    Returns:\n",
    "        A non-shuffled DataLoader of (images, labels): float32 images of shape\n",
    "        (N, C, H, W) mapped from [0, 255] to [-1, 1], and int64 labels equal to 'y' - 1.\n",
    "    \"\"\"\n",
    "    data = loadmat(file)\n",
    "\n",
    "    x_data = data['X']\n",
    "    y_data = data['y']\n",
    "    print(x_data.shape, y_data.shape)\n",
    "\n",
    "    # Samples live on the last axis of 'X'; map pixel values [0, 255] -> [-1, 1].\n",
    "    X = torch.tensor(x_data[:, :, :, :N], device = device)\n",
    "    X = ((X/255 - 0.5)/0.5).type(torch.float32)\n",
    "\n",
    "    # NOTE(review): labels are only shifted by 1 (a stored 10 becomes 9); confirm this\n",
    "    # matches the label convention expected downstream.\n",
    "    Y = (torch.tensor(y_data[:N], device = device) - 1).type(torch.int64)\n",
    "\n",
    "    # (H, W, C, N) -> (N, C, H, W); 'y' is a single column, so keep its last column.\n",
    "    dataset = TensorDataset(X.permute(3, 2, 0, 1), Y[:, -1])\n",
    "    return DataLoader(dataset, batch_size=batch_size, shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Empirical class prior p(y), estimated from the training labels.\n",
    "\n",
    "def p(train_targets, y = 1):\n",
    "    \"\"\"Return the fraction of entries in `train_targets` equal to `y`.\n",
    "\n",
    "    Prints the match count and the total count. Returns 0.0 for an empty\n",
    "    tensor instead of raising ZeroDivisionError.\n",
    "    \"\"\"\n",
    "    n_total = len(train_targets)\n",
    "    n_match = int((train_targets == y).sum().item())\n",
    "    print(n_match, n_total)\n",
    "    if n_total == 0:\n",
    "        return 0.0\n",
    "    return n_match/n_total\n",
    "\n",
    "# Collect the training labels (the images are not needed) and concatenate once,\n",
    "# instead of moving every image batch to the GPU and growing a tensor per batch.\n",
    "label_batches = [target.reshape(-1) for _, target in train_loader]\n",
    "train_targets = torch.cat(label_batches) if label_batches else torch.zeros(0, dtype=torch.int64)\n",
    "\n",
    "n_classes = 10\n",
    "Y = torch.zeros(n_classes)\n",
    "for y in range(n_classes):\n",
    "    Y[y] = p(train_targets, y)\n"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Evaluate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_image(pred, x, target, classes, is_svhn = False):\n",
    "    \"\"\"Plot every image of one batch in a two-row grid.\n",
    "\n",
    "    Args:\n",
    "        pred: predictions for the batch; only its length (number of panels) is used.\n",
    "        x: image batch indexed as (N, C, H, W). Multi-channel images are assumed to be in\n",
    "            [-1, 1] and are converted to uint8 RGB; single-channel images are drawn with the\n",
    "            \"seismic\" colormap on [0, 1].\n",
    "        target: true labels, used for the panel titles.\n",
    "        classes: class names indexed by label.\n",
    "        is_svhn: when True, no titles are drawn.\n",
    "    \"\"\"\n",
    "    n = pred.shape[0]\n",
    "    if n == 0:\n",
    "        return\n",
    "    is_gray = x.shape[1] == 1\n",
    "    if is_gray:\n",
    "        images = x[:, 0].detach().cpu().numpy()\n",
    "    else:\n",
    "        images = (255*(x.permute(0,2,3,1).detach().cpu().numpy()*0.5 + 0.5)).astype(\"uint8\")\n",
    "\n",
    "    fig = plt.figure(figsize=(25, 4))\n",
    "    # Round the column count up so an odd batch size still fits in two rows.\n",
    "    n_cols = (n + 1)//2\n",
    "    for j in range(n):\n",
    "        ax = fig.add_subplot(2, n_cols, j+1, xticks=[], yticks=[])\n",
    "        if is_gray:\n",
    "            ax.imshow(images[j], cmap = \"seismic\", vmin =0, vmax =1)\n",
    "        else:\n",
    "            ax.imshow(images[j])\n",
    "        if not is_svhn:\n",
    "            ax.set_title(f\"True label: {classes[target[j].item()]}\")\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from diffusion.likelihood import ode_likelihood\n",
    "import copy\n",
    "\n",
    "A = 0\n",
    "\n",
    "def _plot_label_maps(label_gen, target, bpd, classes, is_svnh, max_images = 16):\n",
    "    \"\"\"Plot up to `max_images` label maps of one batch, each with a class colorbar.\n",
    "\n",
    "    `label_gen` holds class/9 per pixel in channel 0; `bpd[j]` is shown in panel j's title.\n",
    "    \"\"\"\n",
    "    n_shown = min(label_gen.shape[0], max_images)\n",
    "    if n_shown == 0:\n",
    "        return\n",
    "    fig = plt.figure(figsize=(25, 4))\n",
    "    cmap = plt.get_cmap('tab10', 10)\n",
    "    # Round the column count up so an odd number of panels still fits in two rows.\n",
    "    n_cols = (n_shown + 1)//2\n",
    "    # Round (not truncate) back to class ids: 9.0*(k/9.0) can land just below k in float32.\n",
    "    label_ids = np.rint(9.0*label_gen[:n_shown, 0].detach().cpu().numpy()).astype(\"uint8\")\n",
    "    for j in range(n_shown):\n",
    "        ax = fig.add_subplot(2, n_cols, j+1, xticks=[], yticks=[])\n",
    "        cax = ax.imshow(label_ids[j], cmap = cmap, vmin =-0.5, vmax = 9.5)\n",
    "        if not is_svnh:\n",
    "            ax.set_title(f\"True label: {classes[target[j].item()]}, l ={round(bpd[j].item(), 1)}\")\n",
    "        cbar = fig.colorbar(cax, ticks=list(np.arange(0,10)))\n",
    "        cbar.ax.set_yticklabels(classes)\n",
    "    plt.show()\n",
    "\n",
    "def evaluate(model, test_loader, stop_id = -1, which_type = \"yx\", T = 1, noisy_labels = True, is_svnh = False):\n",
    "    \"\"\"Classify a test set and score each image together with a label map via ode_likelihood.\n",
    "\n",
    "    For every batch the classifier predicts a class. A label map is built either from\n",
    "    the argmax prediction or, with `noisy_labels`, by sampling one class per pixel from\n",
    "    the temperature-scaled softmax. `ode_likelihood` is then evaluated on the image and\n",
    "    that map. The first two batches are plotted.\n",
    "\n",
    "    Uses the notebook globals `which`, `device`, `diffusion_model`,\n",
    "    `marginal_prob_std_fn` and `diffusion_coeff_fn`.\n",
    "\n",
    "    Args:\n",
    "        model: classifier whose output is indexed as (batch, classes).\n",
    "        test_loader: yields (data, target) batches.\n",
    "        stop_id: if not -1, stop after this many batches.\n",
    "        which_type: \"yx\" passes the label map as the condition; any other value\n",
    "            concatenates it to the image as an extra channel.\n",
    "        T: softmax temperature used to sample the noisy labels.\n",
    "        noisy_labels: sample the label map instead of using the argmax prediction.\n",
    "        is_svnh: when True, no \"True label\" titles are drawn.\n",
    "\n",
    "    Returns:\n",
    "        (bpds, targets, predicted, class_correct, class_total): per-sample\n",
    "        prior + delta from ode_likelihood, true and predicted labels, and the\n",
    "        per-class correct / total counts.\n",
    "    \"\"\"\n",
    "    if \"cifar\" in which:\n",
    "        classes = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']\n",
    "    else:\n",
    "        classes = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']\n",
    "\n",
    "    class_correct = [0.0]*10\n",
    "    class_total = [0.0]*10\n",
    "\n",
    "    model.eval()\n",
    "    bpds = np.zeros((0,))\n",
    "    predicted = np.zeros((0,))\n",
    "    targets = np.zeros((0,))\n",
    "\n",
    "    # Fixed random probe vectors passed to ode_likelihood as `epsilon`; each is rescaled\n",
    "    # so that its norm over the channel axis equals sqrt(dim).\n",
    "    epsilon_size = 16\n",
    "    if \"mnist\" in which:\n",
    "        dim = 1\n",
    "        s = 28\n",
    "    else:\n",
    "        dim = 3\n",
    "        s = 32\n",
    "\n",
    "    if which_type == \"x&y\":\n",
    "        # The label map is an extra channel of the diffused variable.\n",
    "        dim+=1\n",
    "    epsilon = torch.randn((epsilon_size, dim, s, s), device = device).type(torch.float32)\n",
    "    epsilon = torch.sqrt(torch.prod(torch.tensor(dim, device=device))) * epsilon / torch.norm(epsilon, dim=1, keepdim=True)\n",
    "\n",
    "    n_plot_batches = 2\n",
    "    for i,(data, target) in enumerate(test_loader):\n",
    "\n",
    "        # break (not return) so the accuracy summary is still printed.\n",
    "        if stop_id!=-1 and i>=stop_id:\n",
    "            break\n",
    "\n",
    "        if len(target.shape)>1:\n",
    "            target = target[:,0]\n",
    "\n",
    "        data, target = data.to(device), target.to(device)\n",
    "\n",
    "        # The classifier output only feeds the predictions and the label sampling,\n",
    "        # so no autograd graph is needed for it.\n",
    "        with torch.no_grad():\n",
    "            output = model(data)\n",
    "        _, pred = torch.max(output, 1)\n",
    "        # Numerically stable form of exp(output/T) / sum(exp(output/T)).\n",
    "        probs = torch.softmax(output/T, dim=1)\n",
    "\n",
    "        targets = np.concatenate((targets, target.detach().cpu().numpy()), axis=0)\n",
    "        predicted = np.concatenate((predicted, pred.detach().cpu().numpy()), axis=0)\n",
    "\n",
    "        # reshape(-1) instead of np.squeeze: squeeze turns a batch of one into a 0-d array.\n",
    "        correct = pred.eq(target.view_as(pred)).cpu().numpy().reshape(-1)\n",
    "        for label, is_correct in zip(target.tolist(), correct):\n",
    "            class_correct[int(label)] += is_correct.item()\n",
    "            class_total[int(label)] += 1\n",
    "\n",
    "        shape = (pred.shape[0], 1) + data.shape[2:]\n",
    "        if noisy_labels:\n",
    "            # One class per pixel, sampled from the temperature-scaled softmax.\n",
    "            label_gen = torch.zeros(shape, device = data.device).type(torch.float32)\n",
    "            for j in range(pred.shape[0]):\n",
    "                dist = torch.distributions.Categorical(probs=probs[j])\n",
    "                label_gen[j,0] = dist.sample((data.shape[2], data.shape[3]))/9.0\n",
    "        else:\n",
    "            # Constant map of the predicted class, scaled by 1/9.\n",
    "            label_gen = (pred.view(pred.shape[0], 1, 1, 1)/9.0) * torch.ones(shape, device = data.device).type(torch.float32)\n",
    "\n",
    "        if which_type == \"yx\":\n",
    "            variable = data\n",
    "            condition = label_gen\n",
    "        else:\n",
    "            variable = torch.cat((data, label_gen), axis = 1)\n",
    "            condition = None\n",
    "\n",
    "        _, prior, delta = ode_likelihood(diffusion_model,\n",
    "                                        variable,\n",
    "                                        condition,\n",
    "                                        marginal_prob_std_fn,\n",
    "                                        diffusion_coeff_fn,\n",
    "                                        t_batch = None,\n",
    "                                        batch_size=epsilon_size,\n",
    "                                        device=device,\n",
    "                                        eps = 1e-10,\n",
    "                                        rtol = 1e-4,\n",
    "                                        atol = 1e-4,\n",
    "                                        epsilon = epsilon,\n",
    "                                        ode_method = \"rk38\",\n",
    "                                        reduce_prior = True)\n",
    "\n",
    "        bpd = prior + delta\n",
    "        bpds = np.concatenate((bpds, bpd.detach().cpu().numpy()))\n",
    "\n",
    "        print(i, len(test_loader), np.mean(bpds))\n",
    "\n",
    "        if i < n_plot_batches:\n",
    "            _plot_label_maps(label_gen, target, bpd, classes, is_svnh)\n",
    "            plot_image(pred, data, target, classes, is_svnh)\n",
    "\n",
    "    for c in range(10):\n",
    "        if class_total[c] > 0:\n",
    "            print('Test Accuracy of %5s: %2d%% (%2d/%2d)' % (\n",
    "                classes[c], 100 * class_correct[c] / class_total[c],\n",
    "                np.sum(class_correct[c]), np.sum(class_total[c])))\n",
    "        else:\n",
    "            print('Test Accuracy of %5s: N/A (no test examples)' % (classes[c]))\n",
    "\n",
    "    n_total = np.sum(class_total)\n",
    "    if n_total > 0:\n",
    "        print('\\nTest Accuracy (Overall): %2d%% (%2d/%2d)' % (\n",
    "            100. * np.sum(class_correct) / n_total,\n",
    "            np.sum(class_correct), n_total))\n",
    "\n",
    "    return bpds, targets, predicted, class_correct, class_total"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import ast\n",
    "\n",
    "def save_errors(folder, bpds, targets, predicted, class_correct, class_total, is_svhn=False, noisy_labels=True):\n",
    "    \"\"\"Write evaluation results to `{folder}/{tag1}_{tag2}_{len(bpds)}.txt`.\n",
    "\n",
    "    tag1 is \"svhn\" or \"cifar10\" (from is_svhn); tag2 is \"noisy_labels\" or\n",
    "    \"NOT_noisy_labels\". The file holds five lines, each a Python list literal\n",
    "    (bpds, targets, predicted, class_correct, class_total), which load_errors reads back.\n",
    "    This definition shadows the save_errors imported from utils.utils_data.\n",
    "    \"\"\"\n",
    "    os.makedirs(folder, exist_ok=True)\n",
    "    tag1 = \"svhn\" if is_svhn else \"cifar10\"\n",
    "    tag2 = \"noisy_labels\" if noisy_labels else \"NOT_noisy_labels\"\n",
    "    file = f\"{folder}/{tag1}_{tag2}_{len(bpds)}.txt\"\n",
    "\n",
    "    # tolist() gives plain Python numbers. str(list(array)) would embed NumPy scalar\n",
    "    # reprs such as np.float64(1.5) on NumPy >= 2, which ast.literal_eval cannot parse.\n",
    "    columns = (bpds, targets, predicted, class_correct, class_total)\n",
    "    s = \"\\n\".join(str(np.asarray(column).tolist()) for column in columns)\n",
    "\n",
    "    with open(file, \"w\") as text_file:\n",
    "        text_file.write(s)\n",
    "\n",
    "    print(file, \"SAVED\")\n",
    "\n",
    "def load_errors(file):\n",
    "    \"\"\"Read a five-line results file (one Python list literal per line).\n",
    "\n",
    "    Returns the five lines as numpy arrays, in order: bpds, targets, predicted,\n",
    "    class_correct, class_total. Raises ValueError unless the file has exactly 5 lines.\n",
    "    \"\"\"\n",
    "    with open(file, \"r\") as handle:\n",
    "        rows = handle.readlines()\n",
    "    if len(rows) != 5:\n",
    "        raise ValueError(\"Expected 5 lines in the file\")\n",
    "\n",
    "    # Parse all five lines before converting any of them to arrays.\n",
    "    parsed = [ast.literal_eval(row.strip()) for row in rows]\n",
    "    return tuple(np.array(values) for values in parsed)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.6"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
