{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "use sdp attention as default\n",
      "keep default attention mode\n",
      "use sdp attention as default\n",
      "keep default attention mode\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import numpy as np\n",
    "import os\n",
    "from argparse import ArgumentParser, Namespace\n",
    "from typing import Optional, Tuple, Set, List, Dict\n",
    "import sys\n",
    "\n",
    "\n",
    "sys.path.append(\"./MST/simulation/train_code\")\n",
    "cuda_device = \"cuda\"\n",
    "\n",
    "os.environ[\"CUDA_DEVICE_ORDER\"] = \"PCI_BUS_ID\"\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"1\"\n",
    "\n",
    "\n",
    "sys.path.append('.')\n",
    "sys.path.append('./DiffBIR')\n",
    "\n",
    "from accelerate.utils import set_seed\n",
    "from DiffBIR.utils.inference import InferenceLoop\n",
    "from DiffBIR.utils.common import instantiate_from_config, load_file_from_url, count_vram_usage\n",
    "from DiffBIR.utils.helpers import MSI_Pipeline\n",
    "from DiffBIR.model.gaussian_diffusion import Diffusion\n",
    "from DiffBIR.model.cldm import ControlLDM\n",
    "from DiffBIR.utils.cond_fn import Guidance\n",
    "\n",
    "from omegaconf import OmegaConf\n",
    "from argparse import ArgumentParser, Namespace\n",
    "from MST.simulation.train_code.utils import *\n",
    "from MST.simulation.train_code.architecture import *\n",
    "import scipy.io as scio\n",
    "from DiffBIR.utils.common import count_vram_usage, wavelet_decomposition_msi\n",
    "\n",
    "from torch.nn import functional as F"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_2836668/3081980065.py:127: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
      "  URSe_model = torch.load(\"weights/model_URSe_hf3_endecoder_c21_bu2_c9_DConvWoBN_resca_silu_2024-09-05_psnr49.5199.pt\", map_location=\"cpu\").to(cuda_device)\n"
     ]
    }
   ],
   "source": [
    "class ChannelAttention(nn.Module):\n",
    "    \"\"\"Channel-attention gate built from global average- and max-pooled features.\n",
    "\n",
    "    Both pooled descriptors go through the same 1x1-conv bottleneck\n",
    "    (fc1 -> ReLU -> fc2); the two results are summed and passed through a sigmoid.\n",
    "\n",
    "    Submodule attribute names are kept stable on purpose: the checkpoint loaded with\n",
    "    torch.load in this cell is used directly as a module, so its pickled attributes\n",
    "    must line up with these names.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, in_channels, reduction_ratio=2):\n",
    "        super().__init__()\n",
    "        hidden_channels = in_channels // reduction_ratio\n",
    "        self.avg_pool = nn.AdaptiveAvgPool2d(1)\n",
    "        self.max_pool = nn.AdaptiveMaxPool2d(1)\n",
    "        self.fc1 = nn.Conv2d(in_channels, hidden_channels, 1, bias=False)\n",
    "        self.relu = nn.ReLU(inplace=True)\n",
    "        self.fc2 = nn.Conv2d(hidden_channels, in_channels, 1, bias=False)\n",
    "        self.sigmoid = nn.Sigmoid()\n",
    "\n",
    "    def _bottleneck(self, pooled):\n",
    "        # Shared projection applied to each pooled descriptor.\n",
    "        return self.fc2(self.relu(self.fc1(pooled)))\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return per-channel weights in (0, 1) with a 1x1 spatial size.\"\"\"\n",
    "        logits = self._bottleneck(self.avg_pool(x)) + self._bottleneck(self.max_pool(x))\n",
    "        return self.sigmoid(logits)\n",
    "\n",
    "\n",
    "class URSEncoder(nn.Module):\n",
    "    \"\"\"Encoder half of URSe: 28 input channels -> 3 output channels.\n",
    "\n",
    "    conv1 squeezes 28 -> 14 channels, the ``ca`` gate reweights them, conv2\n",
    "    (transposed conv, kernel 2 / stride 2) doubles H and W while going 14 -> 7,\n",
    "    and conv3 projects 7 -> 3. A second gate, ``ca_out``, is evaluated on the\n",
    "    un-gated conv1 features and returned so the decoder can apply it.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.conv1 = nn.Conv2d(in_channels=28, out_channels=14, kernel_size=1)\n",
    "        self.conv2 = nn.ConvTranspose2d(in_channels=14, out_channels=7, kernel_size=2, stride=2)\n",
    "        self.conv3 = nn.Conv2d(in_channels=7, out_channels=3, kernel_size=1)\n",
    "        self.ca = ChannelAttention(14, 2)\n",
    "        self.ca_out = ChannelAttention(14, 2)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return ``(encoded, decoder_gate)``: the 3-channel code and the ``ca_out`` weights.\"\"\"\n",
    "        squeezed = self.conv1(x)\n",
    "        decoder_gate = self.ca_out(squeezed)\n",
    "        gated = squeezed * self.ca(squeezed)\n",
    "        encoded = self.conv3(self.conv2(gated))\n",
    "        return encoded, decoder_gate\n",
    "\n",
    "\n",
    "class URSDecoder(nn.Module):\n",
    "    \"\"\"Decoder half of URSe: 3 input channels -> 28 output channels.\n",
    "\n",
    "    conv1 lifts 3 -> 7 channels, conv2 (kernel 2 / stride 2) halves H and W\n",
    "    while going 7 -> 14, the result is scaled by the channel weights ``ca``,\n",
    "    and conv3 expands 14 -> 28.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.conv1 = nn.Conv2d(in_channels=3, out_channels=7, kernel_size=1)\n",
    "        self.conv2 = nn.Conv2d(in_channels=7, out_channels=14, kernel_size=2, stride=2)\n",
    "        self.conv3 = nn.Conv2d(in_channels=14, out_channels=28, kernel_size=1)\n",
    "\n",
    "    def forward(self, x, ca):\n",
    "        \"\"\"Decode ``x``; ``ca`` is presumably the (B, 14, 1, 1) gate returned by URSEncoder.\"\"\"\n",
    "        features = self.conv2(self.conv1(x))\n",
    "        return self.conv3(ca * features)\n",
    "\n",
    "\n",
    "class URSe(nn.Module):\n",
    "    \"\"\"Autoencoder chaining URSEncoder and URSDecoder; the encoder's ``ca_out``\n",
    "    gate is handed straight to the decoder.\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.encoder = URSEncoder()\n",
    "        self.decoder = URSDecoder()\n",
    "\n",
    "    def forward(self, x):\n",
    "        latent, channel_gate = self.encoder(x)\n",
    "        reconstruction = self.decoder(latent, channel_gate)\n",
    "        return reconstruction\n",
    "\n",
    "\n",
    "# The checkpoint holds a whole pickled module (it is used directly via .eval() below), so\n",
    "# the URSe classes above are presumably defined so that it can be unpickled; building a\n",
    "# fresh URSe() first is unnecessary because torch.load returns the module itself.\n",
    "# weights_only=False matches the current default and keeps this load working once PyTorch\n",
    "# flips that default to True. Only load trusted files: unpickling can run arbitrary code.\n",
    "URSe_model = torch.load(\n",
    "    \"weights/model_URSe_c21_bu2_c9_allca_shift4_conv_block_silu_wobn_2024-07-24_loss0.00001008.pt\",\n",
    "    map_location=\"cpu\",\n",
    "    weights_only=False,\n",
    ")\n",
    "assert isinstance(URSe_model, nn.Module), f\"expected a pickled nn.Module, got {type(URSe_model).__name__}\"\n",
    "URSe_model = URSe_model.eval().to(cuda_device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import os\n",
    "import torch\n",
    "from typing import Tuple\n",
    "\n",
    "\n",
    "class MeasMSEGuidance(Guidance):\n",
    "    \"\"\"Measurement-consistency guidance for the diffusion sampler.\n",
    "\n",
    "    Each guided step rescales ``pred_x0``, denormalizes it with the stored per-sample\n",
    "    bounds, decodes it back to an MSI with ``decoder``, forms a measurement (mask,\n",
    "    per-channel column shift, sum over channels) and compares it with the observed\n",
    "    measurement. The negative gradient of that loss (plus bound penalties and an\n",
    "    optional RGB-target term), times ``self.scale``, is returned.\n",
    "    \"\"\"\n",
    "\n",
    "    def load_guidance(self, target: torch.Tensor, masks: torch.Tensor, MSI_E_vectors: torch.Tensor, max_val_channel, min_val_channel, inputs_msi_lf, decoder: torch.nn.Module) -> None:\n",
    "        \"\"\"Store the tensors and decoder that ``_forward`` uses to rebuild a measurement.\n",
    "\n",
    "        Args:\n",
    "            target: observed measurement, kept as ``self.target`` (presumably handed back\n",
    "                to ``_forward`` as ``target`` by the Guidance base class / sampler - confirm).\n",
    "            masks: masks multiplied into the decoded MSI before shifting.\n",
    "            MSI_E_vectors: second input to ``decoder``.\n",
    "            max_val_channel, min_val_channel: per-sample bounds that undo the latent normalization.\n",
    "            inputs_msi_lf: component added to the decoder output.\n",
    "            decoder: maps a denormalized latent (+ ``MSI_E_vectors``) to an MSI.\n",
    "        \"\"\"\n",
    "        self.target = target\n",
    "        self.mask3d_batch = masks\n",
    "        self.max_val_channel, self.min_val_channel, self.inputs_msi_lf = max_val_channel, min_val_channel, inputs_msi_lf\n",
    "        self.decoder = decoder\n",
    "        self.rgb_target = None\n",
    "        self.MSI_E_vectors = MSI_E_vectors\n",
    "        self.bias = 0\n",
    "        self.lambda_reg = 0.005  # weight of the out-of-range penalties\n",
    "        self.msi_scale = None  # not read in this class\n",
    "\n",
    "    def load_bias(self, bias: torch.Tensor):\n",
    "        \"\"\"Set the offset added to ``pred_x0 / 2`` in ``_forward``.\"\"\"\n",
    "        self.bias = bias\n",
    "\n",
    "    def load_rgb_target(self, rgb_target: torch.Tensor, rgb_subscale: Optional[float] = None):\n",
    "        \"\"\"Set a target that the clamped prediction is pulled towards with an MSE term.\n",
    "\n",
    "        The term is active only while ``rgb_subscale > 0``. Pass ``rgb_subscale`` here or\n",
    "        set the attribute directly; if it was never set, the term is skipped.\n",
    "        \"\"\"\n",
    "        self.rgb_target = rgb_target\n",
    "        if rgb_subscale is not None:\n",
    "            self.rgb_subscale = rgb_subscale\n",
    "\n",
    "    @staticmethod\n",
    "    def _shift(inputs: torch.Tensor, step: int = 2) -> torch.Tensor:\n",
    "        \"\"\"Place channel ``i`` at column offset ``step * i`` in a wider float32 zero canvas.\n",
    "\n",
    "        The output has width ``col + (nC - 1) * step`` and is allocated on the same\n",
    "        device as ``inputs``.\n",
    "        \"\"\"\n",
    "        bs, nC, row, col = inputs.shape\n",
    "        output = torch.zeros(bs, nC, row, col + (nC - 1) * step, device=inputs.device, dtype=torch.float32)\n",
    "        for i in range(nC):\n",
    "            output[:, i, :, step * i : step * i + col] = inputs[:, i, :, :]\n",
    "        return output\n",
    "\n",
    "    @classmethod\n",
    "    def _gen_meas(cls, data_batch: torch.Tensor, mask3d_batch: torch.Tensor) -> torch.Tensor:\n",
    "        \"\"\"Mask, shift (step 2) and sum over channels: one 2-D measurement per sample.\"\"\"\n",
    "        return torch.sum(cls._shift(mask3d_batch * data_batch, 2), 1)\n",
    "\n",
    "    @staticmethod\n",
    "    def _save_debug_figure(visual_dir: str, t: int, pred_x0_clamped: torch.Tensor, meas: torch.Tensor, target: torch.Tensor) -> None:\n",
    "        \"\"\"Save pred_x0, meas, target and meas - target for the first sample as ``{t}_combined.png``.\"\"\"\n",
    "        with torch.no_grad():\n",
    "            pred_x0_np = np.transpose(pred_x0_clamped.detach().cpu().numpy()[0], (1, 2, 0))  # CHW -> HWC for imshow\n",
    "            meas_np = meas.detach().cpu().numpy()[0]\n",
    "            target_np = target.detach().cpu().numpy()[0]\n",
    "            diff_np = (meas - target).detach().cpu().numpy()[0]\n",
    "\n",
    "        fig, axs = plt.subplots(1, 4, figsize=(16, 4))\n",
    "        panels = [\n",
    "            (pred_x0_np, \"pred_x0\", None),\n",
    "            (meas_np, \"meas\", \"gray\"),\n",
    "            (target_np, \"target\", \"gray\"),\n",
    "            (diff_np, \"meas - target\", \"coolwarm\"),\n",
    "        ]\n",
    "        for ax, (img, title, cmap) in zip(axs, panels):\n",
    "            # meas/target share a 0-10 intensity scale; the other panels use -0.4..0.4.\n",
    "            is_measurement = title in [\"meas\", \"target\"]\n",
    "            ax.imshow(img, cmap=cmap, vmin=0 if is_measurement else -0.4, vmax=10 if is_measurement else 0.4)\n",
    "            ax.set_title(title)\n",
    "            ax.axis(\"off\")\n",
    "\n",
    "        # One horizontal colorbar shared by the meas and target panels.\n",
    "        cax = fig.add_axes([0.35, 0.1, 0.3, 0.03])\n",
    "        fig.colorbar(axs[1].get_images()[0], cax=cax, orientation=\"horizontal\", label=\"Intensity (0-10)\")\n",
    "\n",
    "        fig.tight_layout()\n",
    "        fig.savefig(os.path.join(visual_dir, f\"{t}_combined.png\"), bbox_inches=\"tight\", pad_inches=0)\n",
    "        plt.close(fig)  # close this figure explicitly so repeated calls do not accumulate figures\n",
    "\n",
    "    def _forward(self, target: torch.Tensor, pred_x0: torch.Tensor, t: int, visual: bool = False) -> Tuple[torch.Tensor, float]:\n",
    "        \"\"\"Compute the guidance gradient for one sampling step.\n",
    "\n",
    "        Args:\n",
    "            target: observed measurement to match.\n",
    "            pred_x0: current latent prediction; ``requires_grad`` is enabled on it in place.\n",
    "            t: current timestep, used only to select and name debug figures.\n",
    "            visual: if True, save a debug figure under ``visual/`` whenever ``t % 4 == 1``.\n",
    "\n",
    "        Returns:\n",
    "            ``(g, loss)``: the negative loss gradient times ``self.scale``, and the loss as a float.\n",
    "        \"\"\"\n",
    "        if visual:\n",
    "            visual_dir = \"visual/\"\n",
    "            os.makedirs(visual_dir, exist_ok=True)\n",
    "\n",
    "        with torch.enable_grad():\n",
    "            pred_x0.requires_grad_(True)\n",
    "            # NOTE: the name is rebound, so the gradient below is taken w.r.t. this rescaled\n",
    "            # tensor (pred_x0 / 2 + bias), not w.r.t. the tensor that was passed in.\n",
    "            pred_x0 = pred_x0 / 2 + self.bias\n",
    "            pred_x0_clamped = torch.clamp(pred_x0, 0, 1)\n",
    "\n",
    "            # Denormalize with the per-sample bounds, decode, and add back inputs_msi_lf.\n",
    "            msi = self.decoder(pred_x0_clamped * (self.max_val_channel - self.min_val_channel) + self.min_val_channel, self.MSI_E_vectors) + self.inputs_msi_lf\n",
    "\n",
    "            penalty_msi_low = torch.relu(-msi)  # penalize negative MSI values\n",
    "            msi_clamped = torch.clamp(msi, 0, 10)\n",
    "            meas = self._gen_meas(msi_clamped, self.mask3d_batch)\n",
    "\n",
    "            # Penalize the unclamped rescaled prediction for leaving [0, 1].\n",
    "            penalty_low = torch.relu(-pred_x0)\n",
    "            penalty_high = torch.relu(pred_x0 - 1)\n",
    "            # NOTE(review): if msi is 4-D (batch, channels, H, W), mean((1, 2)) leaves a (batch, W)\n",
    "            # tensor and .sum() then also sums over W, unlike the per-sample means of the other\n",
    "            # terms. Confirm whether mean((1, 2, 3)) was intended; changing it rescales this penalty.\n",
    "            regularization = penalty_low.mean((1, 2, 3)).sum() + penalty_high.mean((1, 2, 3)).sum() + penalty_msi_low.mean((1, 2)).sum()\n",
    "\n",
    "            # Equal-weight L2 + L1 data fidelity, averaged per sample and summed over the batch.\n",
    "            residual = meas - target\n",
    "            loss = residual.pow(2).mean((1, 2)).sum() * 0.5\n",
    "            loss = loss + residual.abs().mean((1, 2)).sum() * 0.5\n",
    "            loss = loss + self.lambda_reg * regularization\n",
    "\n",
    "            # rgb_subscale is not set by load_guidance; default to 0 (term disabled) so that an\n",
    "            # rgb_target loaded without a scale skips the term instead of raising AttributeError.\n",
    "            rgb_subscale = getattr(self, \"rgb_subscale\", 0)\n",
    "            if self.rgb_target is not None and rgb_subscale > 0:\n",
    "                loss = loss + (pred_x0_clamped - self.rgb_target).pow(2).mean((1, 2, 3)).sum() * rgb_subscale\n",
    "\n",
    "        g = -torch.autograd.grad(loss, pred_x0)[0] * self.scale\n",
    "\n",
    "        if visual and t % 4 == 1:\n",
    "            self._save_debug_figure(visual_dir, t, pred_x0_clamped, meas, target)\n",
    "\n",
    "        return g, loss.item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Download URLs for pretrained checkpoints, keyed by short name. load_model_from_url\n",
    "# fetches them via load_file_from_url with model_dir ./DiffBIR/weights; the sd_v21 entry\n",
    "# is the one used by InferenceLoop_NoPre.init_stage2_model.\n",
    "MODELS = {\n",
    "    # stage_1 model weights\n",
    "    \"bsrnet\": \"https://github.com/cszn/KAIR/releases/download/v1.0/BSRNet.pth\",\n",
    "    # the following checkpoint is up-to-date, but we use the old version in our paper\n",
    "    # \"swinir_face\": \"https://github.com/zsyOAOA/DifFace/releases/download/V1.0/General_Face_ffhq512.pth\",\n",
    "    \"swinir_face\": \"https://huggingface.co/lxq007/DiffBIR/resolve/main/face_swinir_v1.ckpt\",\n",
    "    \"scunet_psnr\": \"https://github.com/cszn/KAIR/releases/download/v1.0/scunet_color_real_psnr.pth\",\n",
    "    \"swinir_general\": \"https://huggingface.co/lxq007/DiffBIR/resolve/main/general_swinir_v1.ckpt\",\n",
    "    # stage_2 model weights\n",
    "    \"sd_v21\": \"https://huggingface.co/stabilityai/stable-diffusion-2-1-base/resolve/main/v2-1_512-ema-pruned.ckpt\",\n",
    "    \"v1_face\": \"https://huggingface.co/lxq007/DiffBIR-v2/resolve/main/v1_face.pth\",\n",
    "    \"v1_general\": \"https://huggingface.co/lxq007/DiffBIR-v2/resolve/main/v1_general.pth\",\n",
    "    \"v2\": \"https://huggingface.co/lxq007/DiffBIR-v2/resolve/main/v2.pth\"\n",
    "}\n",
    "\n",
    "\n",
    "def load_model_from_url(url: str) -> Dict[str, torch.Tensor]:\n",
    "    \"\"\"Fetch a checkpoint with load_file_from_url and return its state dict.\n",
    "\n",
    "    If the checkpoint wraps its weights in a ``state_dict`` entry, that entry is\n",
    "    returned. A leading ``module.`` (as typically left by DataParallel wrappers) is\n",
    "    stripped from every key that carries it; other keys are left untouched.\n",
    "    \"\"\"\n",
    "    sd_path = load_file_from_url(url, model_dir=\"./DiffBIR/weights\")\n",
    "    # NOTE: torch.load unpickles the file; only point this at trusted URLs.\n",
    "    sd = torch.load(sd_path, map_location=\"cpu\")\n",
    "    if \"state_dict\" in sd:\n",
    "        sd = sd[\"state_dict\"]\n",
    "    prefix = \"module.\"\n",
    "    # Check every key (not just the first, which also breaks on an empty dict) and\n",
    "    # strip the prefix only where it is actually present.\n",
    "    if any(key.startswith(prefix) for key in sd):\n",
    "        sd = {(key[len(prefix):] if key.startswith(prefix) else key): value for key, value in sd.items()}\n",
    "    return sd\n",
    "\n",
    "\n",
    "class InferenceLoop_NoPre:\n",
    "    \"\"\"Stage-2-only inference wrapper (no stage-1 preprocessing model).\n",
    "\n",
    "    Builds the ControlLDM (SD v2.1 weights, ControlNet from ``args.ckpt`` and an\n",
    "    optional VAE from ``args.vae``) plus the Diffusion module. ``pipeline`` starts as\n",
    "    ``None``; process_diffusion assigns an ``MSI_Pipeline`` to it. The ``MSI_model``,\n",
    "    ``encoder`` and ``decoder`` attributes read by preprocess_data/process_diffusion are\n",
    "    not created here and must be attached by the caller.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, args: Namespace) -> None:\n",
    "        self.args = args\n",
    "        self.loop_ctx = {}\n",
    "        self.pipeline: Optional[MSI_Pipeline] = None\n",
    "        self.init_stage2_model()\n",
    "\n",
    "    @count_vram_usage\n",
    "    def init_stage2_model(self) -> None:\n",
    "        \"\"\"Load the ControlLDM and Diffusion modules and move them to CUDA.\"\"\"\n",
    "        # load unet, vae, clip\n",
    "        self.cldm: ControlLDM = instantiate_from_config(OmegaConf.load(\"Predict-and-Subspace-Refine/DiffBIR/configs/inference/cldm.yaml\"))\n",
    "        sd = load_model_from_url(MODELS[\"sd_v21\"])\n",
    "        unused = self.cldm.load_pretrained_sd(sd)\n",
    "        print(f\"strictly load pretrained sd_v2.1, unused weights: {unused}\")\n",
    "        # load controlnet\n",
    "        self.cldm.load_controlnet_from_ckpt(torch.load(self.args.ckpt, map_location=\"cpu\"))\n",
    "        print(f\"strictly load controlnet weight {self.args.ckpt}\")\n",
    "        if self.args.vae is not None:\n",
    "            self.cldm.load_vae_from_ckpt(torch.load(self.args.vae, map_location=\"cpu\"))\n",
    "            print(f\"strictly load vae weight {self.args.vae}\")\n",
    "        self.cldm.eval().cuda()\n",
    "        # load diffusion\n",
    "        self.diffusion: Diffusion = instantiate_from_config(OmegaConf.load(\"Predict-and-Subspace-Refine/DiffBIR/configs/inference/diffusion.yaml\"))\n",
    "        self.diffusion.cuda()\n",
    "\n",
    "    @torch.no_grad()\n",
    "    def run(self, images: torch.Tensor) -> torch.Tensor:\n",
    "        \"\"\"Run stage-2 sampling on ``images`` using the settings in ``self.args``.\n",
    "\n",
    "        Raises:\n",
    "            RuntimeError: if ``self.pipeline`` has not been assigned yet.\n",
    "        \"\"\"\n",
    "        if self.pipeline is None:\n",
    "            raise RuntimeError(\"InferenceLoop_NoPre.pipeline is not set; assign an MSI_Pipeline (e.g. via process_diffusion) before calling run().\")\n",
    "        return self.pipeline.run_stage2(\n",
    "            images, self.args.steps, 1.0, self.args.tiled,\n",
    "            self.args.tile_size, self.args.tile_stride,\n",
    "            self.args.pos_prompt, self.args.neg_prompt, self.args.cfg_scale,\n",
    "            self.args.better_start\n",
    "        )\n",
    "\n",
    "\n",
    "@torch.no_grad()\n",
    "def preprocess_data(input_meas: torch.tensor, input_mask: torch.tensor, PSRSCI_Pipeline: InferenceLoop_NoPre) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:\n",
    "    \"\"\"\n",
    "    Reconstruct an MSI from the measurement, split it with wavelet_decomposition_msi,\n",
    "    encode the ``inputs_msi_hf`` part and normalize the encoding per sample.\n",
    "\n",
    "    Args:\n",
    "        input_meas (torch.Tensor): Input measurement data.\n",
    "        input_mask (torch.Tensor): Input mask data.\n",
    "        PSRSCI_Pipeline (InferenceLoop_NoPre): Must provide ``MSI_model`` and ``encoder``.\n",
    "\n",
    "    Returns:\n",
    "        Tuple of five tensors:\n",
    "            normalized_images: encodings rescaled so each sample's [min, max] maps to\n",
    "                [RANGE_MIN, RANGE_MAX] = [0.15, 0.85].\n",
    "            MSI_E_vectors: second encoder output, needed later by the decoder.\n",
    "            max_val_channel, min_val_channel: per-sample (n, 1, 1, 1) bounds that undo\n",
    "                the normalization.\n",
    "            inputs_msi_lf: the part of the wavelet split that is not encoded.\n",
    "    \"\"\"\n",
    "    n_samples = input_meas.shape[0]\n",
    "\n",
    "    MSI_IMAGE = PSRSCI_Pipeline.MSI_model(input_meas, input_mask)\n",
    "    inputs_msi_hf, inputs_msi_lf = wavelet_decomposition_msi(MSI_IMAGE, 3)\n",
    "    MSI_images_encoded, MSI_E_vectors = PSRSCI_Pipeline.encoder(inputs_msi_hf)\n",
    "\n",
    "    RANGE_MAX = 0.85\n",
    "    RANGE_MIN = 0.15\n",
    "\n",
    "    # Per-sample extrema over all remaining dims in one vectorized pass, kept on the\n",
    "    # encoder's device (no Python loop, no device -> host -> device round trip).\n",
    "    flat = MSI_images_encoded.flatten(start_dim=1)\n",
    "    sample_max = flat.amax(dim=1)\n",
    "    sample_min = flat.amin(dim=1)\n",
    "    range_channel = sample_max - sample_min\n",
    "    # Widen [min, max] so the data lands in [RANGE_MIN, RANGE_MAX] after normalization.\n",
    "    max_val_channel = (sample_max + range_channel / (RANGE_MAX - RANGE_MIN) * (1 - RANGE_MAX)).view(n_samples, 1, 1, 1)\n",
    "    min_val_channel = (sample_min - range_channel / (RANGE_MAX - RANGE_MIN) * RANGE_MIN).view(n_samples, 1, 1, 1)\n",
    "\n",
    "    # NOTE(review): a constant sample (range 0) gives max == min and a 0/0 division here.\n",
    "    normalized_images = (MSI_images_encoded - min_val_channel) / (max_val_channel - min_val_channel)\n",
    "\n",
    "    return normalized_images, MSI_E_vectors, max_val_channel, min_val_channel, inputs_msi_lf\n",
    "\n",
    "@torch.no_grad()\n",
    "def process_diffusion(\n",
    "    PSRSCI_Pipeline: InferenceLoop_NoPre,\n",
    "    normalized_images: torch.tensor,\n",
    "    MSI_E_vectors: torch.Tensor,\n",
    "    max_val_channel: torch.tensor,\n",
    "    min_val_channel: torch.tensor,\n",
    "    inputs_msi_lf: torch.tensor,\n",
    "    steps: int,\n",
    "    upscale: float,\n",
    "    cfg_scale: float,\n",
    "    cond_fn: Optional[MeasMSEGuidance],\n",
    "    tiled: bool,\n",
    "    tile_size: int,\n",
    "    tile_stride: int,\n",
    "    better_start: bool = False,\n",
    "    pos_prompt: str = \"\",\n",
    "    neg_prompt: str = \"low quality, blurry, low-resolution, noisy, unsharp, weird textures\",\n",
    "    n_runs: int = 2,\n",
    ") -> Tuple[torch.Tensor, torch.Tensor]:\n",
    "    \"\"\"\n",
    "    Run stage-2 diffusion on the normalized encodings and decode the result back to an MSI.\n",
    "\n",
    "    The sampler is run ``n_runs`` times; the element-wise median over those outputs and\n",
    "    the input itself is clamped to [0, 1], denormalized with ``max_val_channel`` /\n",
    "    ``min_val_channel``, decoded with ``MSI_E_vectors`` and added to ``inputs_msi_lf``.\n",
    "\n",
    "    Args:\n",
    "        PSRSCI_Pipeline (InferenceLoop_NoPre): Provides ``cldm``, ``diffusion``, ``decoder`` and\n",
    "            ``args.device``; its ``pipeline`` attribute is (re)assigned here.\n",
    "        normalized_images (torch.Tensor): Normalized encodings (e.g. from preprocess_data).\n",
    "        MSI_E_vectors (torch.Tensor): Second decoder input.\n",
    "        max_val_channel, min_val_channel (torch.Tensor): Per-sample bounds that undo the normalization.\n",
    "        inputs_msi_lf (torch.Tensor): Added to the decoder output.\n",
    "        steps (int): Sampling steps.\n",
    "        upscale (float): If > 1, sample at this upscaled resolution and resize back.\n",
    "        cfg_scale (float): Forwarded to ``run_stage2`` as the CFG scale.\n",
    "        cond_fn (MeasMSEGuidance | None): Guidance function.\n",
    "        tiled (bool): If True, a patch-based sampling strategy will be used.\n",
    "        tile_size (int): Size of patch.\n",
    "        tile_stride (int): Stride of sliding patch.\n",
    "        better_start (bool): Forwarded to ``run_stage2``.\n",
    "        pos_prompt, neg_prompt (str): Prompts forwarded to ``run_stage2``.\n",
    "        n_runs (int): Number of sampler runs entering the median (default 2, i.e. a median of 3).\n",
    "\n",
    "    Returns:\n",
    "        Tuple[torch.Tensor, torch.Tensor]: (restored MSI clamped to [0, 10], median diffusion output in [0, 1]).\n",
    "    \"\"\"\n",
    "    height, width = normalized_images.shape[-2:]\n",
    "    if upscale > 1.0:\n",
    "        # F.interpolate requires integer sizes, and upscale may be a float.\n",
    "        up_size = (int(round(height * upscale)), int(round(width * upscale)))\n",
    "        normalized_images_up = F.interpolate(normalized_images, size=up_size, mode=\"bicubic\", antialias=True)\n",
    "    else:\n",
    "        normalized_images_up = normalized_images\n",
    "\n",
    "    PSRSCI_Pipeline.pipeline = MSI_Pipeline(PSRSCI_Pipeline.cldm, PSRSCI_Pipeline.diffusion, cond_fn, PSRSCI_Pipeline.args.device)\n",
    "\n",
    "    # The input joins the median at its ORIGINAL resolution: every diffusion output is\n",
    "    # resized back to (height, width) below, so stacking the upscaled input would raise a\n",
    "    # shape mismatch whenever upscale > 1. For upscale <= 1 the two are the same tensor.\n",
    "    diffusion_outputs_list = [normalized_images]\n",
    "\n",
    "    for _ in range(n_runs):\n",
    "        diffusion_output = PSRSCI_Pipeline.pipeline.run_stage2(\n",
    "            normalized_images_up, steps, 1.0, tiled, tile_size, tile_stride,\n",
    "            pos_prompt, neg_prompt, cfg_scale, better_start\n",
    "        )\n",
    "\n",
    "        if upscale > 1.0:\n",
    "            diffusion_output = F.interpolate(\n",
    "                diffusion_output,\n",
    "                size=(height, width),\n",
    "                mode=\"bicubic\", antialias=True\n",
    "            )\n",
    "\n",
    "        diffusion_outputs_list.append(diffusion_output.contiguous())\n",
    "\n",
    "    # .median(dim=0) returns (values, indices); [0] keeps the values. For an even number of\n",
    "    # entries torch.median returns the lower of the two middle values.\n",
    "    diffusion_outputs_median = torch.stack(diffusion_outputs_list, dim=0).median(dim=0)[0]\n",
    "\n",
    "    diffusion_outputs = diffusion_outputs_median.clamp(0, 1)\n",
    "\n",
    "    restored_images = (PSRSCI_Pipeline.decoder(diffusion_outputs * (max_val_channel - min_val_channel) + min_val_channel, MSI_E_vectors) + inputs_msi_lf).clamp(0, 10)\n",
    "\n",
    "    return restored_images, diffusion_outputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Namespace(ckpt='Predict-and-Subspace-Refine/DiffBIR/exp_dir/exp_cldmhf07f03/checkpoints/0040000.pt', vae='Predict-and-Subspace-Refine/DiffBIR/exp_dir/exp_vaehf04f03/checkpoints/0010000_vae.pt', channel_vae='Predict-and-Subspace-Refine/weights/model_URSe_hf3_endecoder_c21_bu2_c9_DConvWoBN_resca_silu_2024-09-05_psnr49.5199.pt', steps=200, better_start=True, upscale=1.0, tiled=False, tile_size=512, tile_stride=256, pos_prompt='', neg_prompt='', cfg_scale=1.0, n_samples=1, guidance=True, g_scale=1, g_t_start=400, g_t_stop=-1, g_space='rgb', g_repeat=1, seed=231, output='./results/', num_evals=300, device='cuda')\n",
      "Setting up SDPCrossAttention (sdp). Query dim is 320, context_dim is None and using 5 heads.\n",
      "Setting up SDPCrossAttention (sdp). Query dim is 320, context_dim is 1024 and using 5 heads.\n",
      "Setting up SDPCrossAttention (sdp). Query dim is 320, context_dim is None and using 5 heads.\n",
      "Setting up SDPCrossAttention (sdp). Query dim is 320, context_dim is 1024 and using 5 heads.\n",
      "Setting up SDPCrossAttention (sdp). Query dim is 640, context_dim is None and using 10 heads.\n",
      "Setting up SDPCrossAttention (sdp). Query dim is 640, context_dim is 1024 and using 10 heads.\n",
      "Setting up SDPCrossAttention (sdp). Query dim is 640, context_dim is None and using 10 heads.\n",
      "Setting up SDPCrossAttention (sdp). Query dim is 640, context_dim is 1024 and using 10 heads.\n",
      "Setting up SDPCrossAttention (sdp). Query dim is 1280, context_dim is None and using 20 heads.\n",
      "Setting up SDPCrossAttention (sdp). Query dim is 1280, context_dim is 1024 and using 20 heads.\n",
      "Setting up SDPCrossAttention (sdp). Query dim is 1280, context_dim is None and using 20 heads.\n",
      "Setting up SDPCrossAttention (sdp). Query dim is 1280, context_dim is 1024 and using 20 heads.\n",
      "Setting up SDPCrossAttention (sdp). Query dim is 1280, context_dim is None and using 20 heads.\n",
      "Setting up SDPCrossAttention (sdp). Query dim is 1280, context_dim is 1024 and using 20 heads.\n",
      "Setting up SDPCrossAttention (sdp). Query dim is 1280, context_dim is None and using 20 heads.\n",
      "Setting up SDPCrossAttention (sdp). Query dim is 1280, context_dim is 1024 and using 20 heads.\n",
      "Setting up SDPCrossAttention (sdp). Query dim is 1280, context_dim is None and using 20 heads.\n",
      "Setting up SDPCrossAttention (sdp). Query dim is 1280, context_dim is 1024 and using 20 heads.\n",
      "Setting up SDPCrossAttention (sdp). Query dim is 1280, context_dim is None and using 20 heads.\n",
      "Setting up SDPCrossAttention (sdp). Query dim is 1280, context_dim is 1024 and using 20 heads.\n",
      "Setting up SDPCrossAttention (sdp). Query dim is 640, context_dim is None and using 10 heads.\n",
      "Setting up SDPCrossAttention (sdp). Query dim is 640, context_dim is 1024 and using 10 heads.\n",
      "Setting up SDPCrossAttention (sdp). Query dim is 640, context_dim is None and using 10 heads.\n",
      "Setting up SDPCrossAttention (sdp). Query dim is 640, context_dim is 1024 and using 10 heads.\n",
      "Setting up SDPCrossAttention (sdp). Query dim is 640, context_dim is None and using 10 heads.\n",
      "Setting up SDPCrossAttention (sdp). Query dim is 640, context_dim is 1024 and using 10 heads.\n",
      "Setting up SDPCrossAttention (sdp). Query dim is 320, context_dim is None and using 5 heads.\n",
      "Setting up SDPCrossAttention (sdp). Query dim is 320, context_dim is 1024 and using 5 heads.\n",
      "Setting up SDPCrossAttention (sdp). Query dim is 320, context_dim is None and using 5 heads.\n",
      "Setting up SDPCrossAttention (sdp). Query dim is 320, context_dim is 1024 and using 5 heads.\n",
      "Setting up SDPCrossAttention (sdp). Query dim is 320, context_dim is None and using 5 heads.\n",
      "Setting up SDPCrossAttention (sdp). Query dim is 320, context_dim is 1024 and using 5 heads.\n",
      "building SDPAttnBlock (sdp) with 512 in_channels\n",
      "building SDPAttnBlock (sdp) with 512 in_channels\n",
      "Setting up SDPCrossAttention (sdp). Query dim is 320, context_dim is None and using 5 heads.\n",
      "Setting up SDPCrossAttention (sdp). Query dim is 320, context_dim is 1024 and using 5 heads.\n",
      "Setting up SDPCrossAttention (sdp). Query dim is 320, context_dim is None and using 5 heads.\n",
      "Setting up SDPCrossAttention (sdp). Query dim is 320, context_dim is 1024 and using 5 heads.\n",
      "Setting up SDPCrossAttention (sdp). Query dim is 640, context_dim is None and using 10 heads.\n",
      "Setting up SDPCrossAttention (sdp). Query dim is 640, context_dim is 1024 and using 10 heads.\n",
      "Setting up SDPCrossAttention (sdp). Query dim is 640, context_dim is None and using 10 heads.\n",
      "Setting up SDPCrossAttention (sdp). Query dim is 640, context_dim is 1024 and using 10 heads.\n",
      "Setting up SDPCrossAttention (sdp). Query dim is 1280, context_dim is None and using 20 heads.\n",
      "Setting up SDPCrossAttention (sdp). Query dim is 1280, context_dim is 1024 and using 20 heads.\n",
      "Setting up SDPCrossAttention (sdp). Query dim is 1280, context_dim is None and using 20 heads.\n",
      "Setting up SDPCrossAttention (sdp). Query dim is 1280, context_dim is 1024 and using 20 heads.\n",
      "Setting up SDPCrossAttention (sdp). Query dim is 1280, context_dim is None and using 20 heads.\n",
      "Setting up SDPCrossAttention (sdp). Query dim is 1280, context_dim is 1024 and using 20 heads.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_2218378/2750478355.py:19: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
      "  sd = torch.load(sd_path, map_location=\"cpu\")\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "strictly load pretrained sd_v2.1, unused weights: {'posterior_mean_coef2', 'model_ema.num_updates', 'model_ema.decay', 'posterior_mean_coef1', 'alphas_cumprod', 'posterior_log_variance_clipped', 'sqrt_recip_alphas_cumprod', 'sqrt_one_minus_alphas_cumprod', 'log_one_minus_alphas_cumprod', 'sqrt_alphas_cumprod', 'betas', 'posterior_variance', 'alphas_cumprod_prev', 'sqrt_recipm1_alphas_cumprod'}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_2218378/2750478355.py:43: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
      "  self.cldm.load_controlnet_from_ckpt(torch.load(self.args.ckpt, map_location=\"cpu\"))\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "strictly load controlnet weight Predict-and-Subspace-Refine/DiffBIR/exp_dir/exp_cldmhf07f03/checkpoints/0040000.pt\n",
      "strictly load vae weight Predict-and-Subspace-Refine/DiffBIR/exp_dir/exp_vaehf04f03/checkpoints/0010000_vae.pt\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_2218378/2750478355.py:46: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
      "  self.cldm.load_vae_from_ckpt(torch.load(self.args.vae, map_location=\"cpu\"))\n",
      "/tmp/ipykernel_2218378/4155903439.py:44: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
      "  PSRSCI_model = torch.load(args.channel_vae, map_location='cpu')\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load model from ./MST/simulation/test_code/model_zoo/dauhst_9stg/dauhst_9stg.pth\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Predict-and-Subspace-Refine/MST/simulation/train_code/architecture/__init__.py:62: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
      "  checkpoint = torch.load(pretrained_model_path)\n"
     ]
    }
   ],
   "source": [
    "def parse_args() -> Namespace:\n",
    "    parser = ArgumentParser()\n",
    "    # model parameters\n",
    "    parser.add_argument(\"--ckpt\", type=str, default=\"Predict-and-Subspace-Refine/DiffBIR/exp_dir/exp_cldmhf07f03/checkpoints/0040000.pt\")\n",
    "    parser.add_argument(\"--vae\", type=str, default=\"Predict-and-Subspace-Refine/DiffBIR/exp_dir/exp_vaehf04f03/checkpoints/0010000_vae.pt\")\n",
    "    parser.add_argument(\"--URSe\", type=str, default=\"Predict-and-Subspace-Refine/weights/model_URSe_hf3_endecoder_c21_bu2_c9_DConvWoBN_resca_silu_2024-09-05_psnr49.5199.pt\")\n",
    "    # sampling parameters\n",
    "    parser.add_argument(\"--steps\", type=int, default=200)\n",
    "    parser.add_argument(\"--better_start\", type=bool, default=True)\n",
    "    parser.add_argument(\"--upscale\", type=int, default=1.0)\n",
    "    parser.add_argument(\"--tiled\", type=bool, default=False)\n",
    "    parser.add_argument(\"--tile_size\", type=int, default=512)\n",
    "    parser.add_argument(\"--tile_stride\", type=int, default=256)\n",
    "    parser.add_argument(\"--pos_prompt\", type=str, default=\"\")\n",
    "    parser.add_argument(\"--neg_prompt\", type=str, default=\"\")\n",
    "    parser.add_argument(\"--cfg_scale\", type=float, default=1.0)\n",
    "    # input parameters\n",
    "    parser.add_argument(\"--n_samples\", type=int, default=1)\n",
    "    # guidance parameters\n",
    "    parser.add_argument(\"--guidance\", type=bool, default=True)\n",
    "    parser.add_argument(\"--g_scale\", type=float, default=1)\n",
    "    parser.add_argument(\"--g_t_start\", type=int, default=400)\n",
    "    parser.add_argument(\"--g_t_stop\", type=int, default=-1)\n",
    "    parser.add_argument(\"--g_space\", type=str, default=\"rgb\")\n",
    "    parser.add_argument(\"--g_repeat\", type=int, default=1)\n",
    "    # output parameters\n",
    "    # common parameters\n",
    "    parser.add_argument(\"--seed\", type=int, default=231)\n",
    "    parser.add_argument(\"--output\", type=str, default=\"./results/\")\n",
    "    parser.add_argument(\"--num_evals\", type=int, default=300)\n",
    "    parser.add_argument(\"--device\", type=str, default=cuda_device)\n",
    "\n",
    "    return parser.parse_known_args()[0]\n",
    "\n",
    "args = parse_args()\n",
    "print(args)\n",
    "args.device = torch.device(args.device)\n",
    "set_seed(args.seed)\n",
    "\n",
    "PSRSCI_Pipeline = InferenceLoop_NoPre(args=args)\n",
    "\n",
    "URSe_model = URSe()\n",
    "URSe_model = torch.load(args.URSe, map_location='cpu')\n",
    "URSe_model.eval().cuda()\n",
    "PSRSCI_Pipeline.encoder = URSe_model.encoder\n",
    "PSRSCI_Pipeline.decoder = URSe_model.decoder\n",
    "\n",
    "PSRSCI_Pipeline.MSI_model = model_generator(\"dauhst_3stg\", \"./MST/simulation/test_code/model_zoo/dauhst_3stg/dauhst_3stg.pth\")\n",
    "# PSRSCI_Pipeline.MSI_model = model_generator(\"mst_l\", \"./MST/simulation/test_code/model_zoo/mst/mst_l.pth\").cuda().eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "mask3d_batch, input_mask = init_mask(\"./MST/datasets/TSA_simu_data/\", \"Phi_PhiPhiT\", 10)\n",
    "# mask3d_batch, input_mask = init_mask(\"./MST/datasets/TSA_simu_data/\", \"Phi\", 10)\n",
    "\n",
    "test_data = LoadTest(\"./MST/datasets/TSA_simu_data/Truth/\")\n",
    "test_data = test_data.cuda().float()\n",
    "input_meas = init_meas(test_data, mask3d_batch, \"Y\").cuda()\n",
    "# input_meas = init_meas(test_data, mask3d_batch, \"H\").cuda()\n",
    "\n",
    "mask3d_batch, input_mask = init_mask(\"./MST/datasets/TSA_simu_data/\", \"Phi_PhiPhiT\", 1)\n",
    "# mask3d_batch, input_mask = init_mask(\"./MST/datasets/TSA_simu_data/\", \"Phi\", 1)\n",
    "\n",
    "ts = [0]\n",
    "\n",
    "input_meas = input_meas[ts]\n",
    "# input_mask = input_mask[t]\n",
    "# mask3d_batch = mask3d_batch[t]\n",
    "test_data = test_data[ts]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([28, 256, 256])"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "mask3d_batch[0].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import sys\n",
    "from MST.simulation.train_code.architecture import *\n",
    "from MST.simulation.train_code.utils import *\n",
    "import numpy as np\n",
    "from PIL import Image\n",
    "from MST.simulation.train_code.utils import torch_psnr, torch_ssim\n",
    "import torch\n",
    "\n",
    "\n",
    "def parse_args() -> Namespace:\n",
    "    parser = ArgumentParser()\n",
    "    # model parameters\n",
    "    parser.add_argument(\"--ckpt\", type=str, default=\"Predict-and-Subspace-Refine/DiffBIR/exp_dir/exp_cldmhf07f03/checkpoints/0040000.pt\")\n",
    "    parser.add_argument(\"--vae\", type=str, default=\"Predict-and-Subspace-Refine/DiffBIR/exp_dir/exp_vaehf04f03/checkpoints/0010000_vae.pt\")\n",
    "    parser.add_argument(\"--URSe\", type=str, default=\"Predict-and-Subspace-Refine/weights/model_URSe_hf3_endecoder_c21_bu2_c9_DConvWoBN_resca_silu_2024-09-05_psnr49.5199.pt\")\n",
    "    # sampling parameters\n",
    "    parser.add_argument(\"--steps\", type=int, default=200)\n",
    "    parser.add_argument(\"--better_start\", type=bool, default=True)\n",
    "    parser.add_argument(\"--upscale\", type=int, default=1.0)\n",
    "    parser.add_argument(\"--tiled\", type=bool, default=False)\n",
    "    parser.add_argument(\"--tile_size\", type=int, default=512)\n",
    "    parser.add_argument(\"--tile_stride\", type=int, default=256)\n",
    "    parser.add_argument(\"--pos_prompt\", type=str, default=\"\")\n",
    "    parser.add_argument(\"--neg_prompt\", type=str, default=\"\")\n",
    "    parser.add_argument(\"--cfg_scale\", type=float, default=1.0)\n",
    "    # input parameters\n",
    "    parser.add_argument(\"--n_samples\", type=int, default=1)\n",
    "    # guidance parameters\n",
    "    parser.add_argument(\"--guidance\", type=bool, default=True)\n",
    "    parser.add_argument(\"--g_scale\", type=float, default=0)\n",
    "    parser.add_argument(\"--g_t_start\", type=int, default=600)\n",
    "    parser.add_argument(\"--g_t_stop\", type=int, default=-1)\n",
    "    parser.add_argument(\"--g_space\", type=str, default=\"rgb\")\n",
    "    parser.add_argument(\"--g_repeat\", type=int, default=1)\n",
    "    # output parameters\n",
    "    # common parameters\n",
    "    parser.add_argument(\"--seed\", type=int, default=366)\n",
    "    parser.add_argument(\"--output\", type=str, default=\"./results/\")\n",
    "    parser.add_argument(\"--num_evals\", type=int, default=300)\n",
    "    parser.add_argument(\"--device\", type=str, default=cuda_device)\n",
    "\n",
    "    return parser.parse_known_args()[0]\n",
    "\n",
    "\n",
    "args = parse_args()\n",
    "\n",
    "@torch.no_grad()\n",
    "def main_func():\n",
    "    global test_data\n",
    "    normalized_images, MSI_E_vectors, max_val_channel, min_val_channel, inputs_msi_lf = preprocess_data(input_meas, input_mask, PSRSCI_Pipeline)\n",
    "\n",
    "    os.makedirs(args.output, exist_ok=True)\n",
    "\n",
    "    if args.guidance:\n",
    "        cond_fn = MeasMSEGuidance(\n",
    "            scale=150,rgb_subscale=0.05, t_start=int(args.g_t_start), t_stop=args.g_t_stop,\n",
    "            space=args.g_space, repeat=int(args.g_repeat)\n",
    "        )\n",
    "        cond_fn.load_guidance(init_meas(test_data, mask3d_batch, \"Y\").cuda(), mask3d_batch, MSI_E_vectors, max_val_channel, min_val_channel,inputs_msi_lf, PSRSCI_Pipeline.decoder)\n",
    "    else:\n",
    "        cond_fn = None\n",
    "\n",
    "    restored_images, diffusion_outputs = process_diffusion(\n",
    "        PSRSCI_Pipeline,\n",
    "        normalized_images,\n",
    "        MSI_E_vectors,\n",
    "        max_val_channel,\n",
    "        min_val_channel,\n",
    "        inputs_msi_lf,\n",
    "        steps=int(args.steps),\n",
    "        upscale=args.upscale,\n",
    "        cond_fn=cond_fn,\n",
    "        cfg_scale=args.cfg_scale,\n",
    "        tiled=args.tiled,\n",
    "        tile_size=args.tile_size,\n",
    "        tile_stride=args.tile_stride,\n",
    "        better_start=args.better_start,\n",
    "        pos_prompt=args.pos_prompt,\n",
    "        neg_prompt=args.neg_prompt,\n",
    "    )\n",
    "\n",
    "    for i in range(restored_images.shape[0]):\n",
    "        save_path = os.path.join(args.output, f\"{str(ts[i])}.mat\")\n",
    "\n",
    "        print(f'Save reconstructed HSIs as {save_path}.')\n",
    "        scio.savemat(save_path, {\"truth\": test_data[i].detach().cpu().float().numpy(), \"normalized_images\": normalized_images[i].detach().cpu().numpy(), \"diffusion_outputs\": diffusion_outputs[i].detach().cpu().numpy(), \"pred\": restored_images[i].detach().cpu().numpy(),\"init\":(PSRSCI_Pipeline.decoder(normalized_images * (max_val_channel - min_val_channel) + min_val_channel) + inputs_msi_lf)[i].detach().cpu().numpy()})\n",
    "\n",
    "        print(f\"save to {save_path}\")\n",
    "\n",
    "\n",
    "main_func()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.colors as mcolors\n",
    "\n",
    "\n",
    "@torch.no_grad()\n",
    "def test():\n",
    "    t_new = range(2)\n",
    "\n",
    "    scenes = [sio.loadmat(f\"./results/{str(i)}.mat\") for i in t_new]  # /bak/v mstl\n",
    "\n",
    "    pred = torch.tensor([scenes[i][\"pred\"] for i in t_new])\n",
    "    init = torch.tensor([scenes[i][\"init\"] for i in t_new])\n",
    "    test_data = torch.tensor([scenes[i][\"truth\"] for i in t_new])\n",
    "    diffusion_outputs = torch.tensor([scenes[i][\"diffusion_outputs\"] for i in t_new])\n",
    "    normalized_images = torch.tensor([scenes[i][\"normalized_images\"] for i in t_new])\n",
    "\n",
    "    total_psnr = 0.0\n",
    "\n",
    "    total_ssim = 0.0\n",
    "\n",
    "    for i in range(test_data.shape[0]):\n",
    "        current_psnr = torch_psnr(pred[i], test_data[i])\n",
    "        current_ssim = torch_ssim(pred[i], test_data[i])\n",
    "\n",
    "        total_psnr += current_psnr\n",
    "        total_ssim += current_ssim\n",
    "\n",
    "        print(f\"For image {i+1}  PSNR: {current_psnr:.2f}   SSIM: {current_ssim:.4f}\")\n",
    "\n",
    "    average_psnr = total_psnr / test_data.shape[0]\n",
    "    average_ssim = total_ssim / test_data.shape[0]\n",
    "\n",
    "    print(f\"\\nAverage PSNR for all images: {average_psnr:.2f}\")\n",
    "    print(f\"Average SSIM for all images: {average_ssim:.4f}\")\n",
    "\n",
    "    plt.imshow(normalized_images[0].permute(1, 2, 0).detach().cpu().numpy())\n",
    "    plt.axis(\"off\")  \n",
    "    plt.subplots_adjust(left=0, right=1, top=1, bottom=0)  # Remove white borders\n",
    "    plt.show()\n",
    "\n",
    "    plt.imshow(diffusion_outputs[0].permute(1, 2, 0).detach().cpu().numpy())\n",
    "    plt.axis(\"off\") \n",
    "    plt.subplots_adjust(left=0, right=1, top=1, bottom=0)  # Remove white borders\n",
    "    plt.show()\n",
    "\n",
    "    plt.imshow(URSe_model.encoder(init.cuda())[0].permute(1, 2, 0).detach().cpu().numpy()+0.1)\n",
    "    plt.axis(\"off\") \n",
    "    plt.subplots_adjust(left=0, right=1, top=1, bottom=0)  # Remove white borders\n",
    "    plt.show()\n",
    "\n",
    "    plt.imshow(URSe_model.encoder(pred.cuda())[0].permute(1, 2, 0).detach().cpu().numpy()+0.1)\n",
    "    plt.axis(\"off\") \n",
    "    plt.subplots_adjust(left=0, right=1, top=1, bottom=0)  # Remove white borders\n",
    "    plt.show()\n",
    "\n",
    "    plt.imshow(URSe_model.encoder(test_data.cuda())[0].permute(1, 2, 0).detach().cpu().numpy()+0.1)\n",
    "    plt.axis(\"off\") \n",
    "    plt.subplots_adjust(left=0, right=1, top=1, bottom=0)  # Remove white borders\n",
    "    plt.show()\n",
    "\n",
    "    norm = mcolors.TwoSlopeNorm(vcenter=0, vmin=-0.4, vmax=0.4)\n",
    "\n",
    "    plt.figure()\n",
    "    img2 = plt.imshow(init_meas(PSRSCI_Pipeline.MSI_model(input_meas.cuda(), input_mask), mask3d_batch, \"Y\").detach().cpu().numpy()[0] - init_meas(test_data.cuda(), mask3d_batch, \"Y\").detach().cpu().numpy()[0], cmap=\"coolwarm\", norm=norm)\n",
    "    plt.colorbar(img2)\n",
    "    plt.show()\n",
    "\n",
    "    plt.figure()\n",
    "    img1 = plt.imshow(init_meas(pred.cuda(), mask3d_batch, \"Y\").detach().cpu().numpy()[0] - init_meas(test_data.cuda(), mask3d_batch, \"Y\").detach().cpu().numpy()[0], cmap=\"coolwarm\", norm=norm)\n",
    "    plt.colorbar(img1)\n",
    "    plt.show()\n",
    "\n",
    "test()\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "diffbir",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
