{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1a5ecf87-86e5-4abb-8592-04636ec42caf",
   "metadata": {},
   "outputs": [],
   "source": [
    "from unet import *\n",
    "from torch.utils.data import DataLoader\n",
    "import dino.vision_transformer as vits\n",
    "import os\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e30e50f0-7e3b-4e6e-8984-d1ef3e93b346",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "device = \"cuda\"\n",
    "timesteps = 1000#1000\n",
    "scale = 1000 / timesteps\n",
    "# beta_start = scale * 0.0001\n",
    "# beta_end = scale * 0.02\n",
    "beta_start = scale * 0.1\n",
    "beta_end = scale * 20.\n",
    "betas = torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float64)\n",
    "\n",
    "            \n",
    "alphas = 1. - betas/1000\n",
    "alphas_cumprod = torch.cumprod(alphas, axis=0)\n",
    "alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.)\n",
    "        \n",
    "sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)\n",
    "sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)\n",
    "log_one_minus_alphas_cumprod = torch.log(1.0 - alphas_cumprod)\n",
    "\n",
    "sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / alphas_cumprod)\n",
    "sqrt_recipm1_alphas_cumprod = torch.sqrt(1.0 / alphas_cumprod - 1)\n",
    "sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device)\n",
    "sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "1309cab1-8bd3-484b-b43c-344ad7cf508b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "device = \"cuda\"\n",
    "model_mu = UNetModel(256, 128, dropout=0, n_heads=4 ,in_channels=3)\n",
    "model_logvar = UNetModel(256, 128, dropout=0, n_heads=4 ,in_channels=3)\n",
    "model_mu =nn.DataParallel(model_mu,device_ids=[0, 1, 2, 3])\n",
    "model_mu.to(device)\n",
    "model_logvar = nn.DataParallel(model_logvar,device_ids=[0, 1, 2, 3])\n",
    "model_logvar.to(device)\n",
    "\n",
    "\n",
    "# model_mu.load_state_dict(torch.load('./ck/basechanel64/capsule/duad_mu_epoch_capsule_2500.pth'))\n",
    "# model_logvar.load_state_dict(torch.load('./ck/basechanel64/capsule/duad_logvar_epoch_capsule_2500.pth'))\n",
    "# model_mu.load_state_dict(torch.load('./ck/MVTec_noro/capsule/duad_mu_epoch_copy_2500.pth'))\n",
    "# model_logvar.load_state_dict(torch.load('./ck/MVTec_noro/capsule/duad_logvar_epoch_copy_2500.pth'))\n",
    "\n",
    "category = \"connector\" #['bottle', 'cable', 'capsule', 'hazelnut', 'metal_nut', 'pill', 'screw', \n",
    "                    #'toothbrush', 'transistor', 'zipper','carpet','grid', 'leather', 'tile', 'wood']    \n",
    "\n",
    "\n",
    "#['carpet', 'bottle', 'hazelnut', 'leather', 'cable', 'capsule', 'grid', 'pill', 'transistor', 'metal_nut', 'screw','toothbrush', 'zipper', 'tile', 'wood']    \n",
    " #['carpet', 'bottle', 'hazelnut', 'leather', 'cable', 'capsule', 'grid', 'pill', 'transistor', 'metal_nut', 'screw','toothbrush', 'zipper', 'tile', 'wood']    \n",
    "         # ['candle', 'capsules', 'cashew', 'chewinggum', 'fryum', 'macaroni1', 'macaroni2',\n",
    "#                  'pcb1', 'pcb2' ,'pcb3', 'pcb4', 'pipe_fryum']\n",
    "        #['bracket_black','bracket_brown','bracket_white',\n",
    "#     'connector','metal_plate','tubes']\n",
    "# model_mu.load_state_dict(torch.load(f'./ck_xishu/MVTec/{category}/duad_mu_epoch_{category}_1000.pth'))\n",
    "# model_logvar.load_state_dict(torch.load(f'./ck_xishu/MVTec/{category}/duad_logvar_epoch_{category}_1000.pth'))\n",
    "\n",
    "# model_mu.load_state_dict(torch.load(f'./ck_new/MVTec/{category}/duad_mu_epoch_{category}_1000.pth',map_location='cpu'))\n",
    "# model_logvar.load_state_dict(torch.load(f'./ck_new/MVTec/{category}/duad_logvar_epoch_{category}_1000.pth',map_location='cpu'))\n",
    "\n",
    "\n",
    "\n",
    "# model_mu.load_state_dict(torch.load(f'./ck_new/visa/{category}/duad_mu_epoch_{category}_1000.pth',map_location='cpu'))\n",
    "# model_logvar.load_state_dict(torch.load(f'./ck_new/visa/{category}/duad_logvar_epoch_{category}_1000.pth',map_location='cpu'))\n",
    "\n",
    "model_mu.load_state_dict(torch.load(f'./ck_new/MPDD/{category}/duad_mu_epoch_{category}_3000.pth',map_location='cpu'))\n",
    "model_logvar.load_state_dict(torch.load(f'./ck_new/MPDD/{category}/duad_logvar_epoch_{category}_3000.pth',map_location='cpu'))\n",
    "\n",
    "\n",
    "\n",
    "model_mu.eval()\n",
    "model_logvar.eval()\n",
    "print()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9a83e1d8-36a4-44f1-a45d-f656d7dddc2c",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "7b3c05d5-7984-4ca7-9baf-a6e72c31d359",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import json\n",
    "import cv2\n",
    "import numpy as np\n",
    "from PIL import Image\n",
    "from torch.utils.data import Dataset\n",
    "from torchvision import transforms\n",
    "import torch\n",
    "import torchvision\n",
    "# mean_train = [0.485, 0.456, 0.406]\n",
    "# std_train = [0.229, 0.224, 0.225]\n",
    "# transforms.RandomRotation(30,interpolation=transforms.InterpolationMode.BILINEAR),\n",
    "\n",
    "# transforms.Normalize(mean=mean_train, std=std_train)\n",
    "def data_transforms(size):\n",
    "    datatrans =  transforms.Compose([\n",
    "    transforms.Resize((size, size)),\n",
    "    transforms.ToTensor(),\n",
    "    \n",
    "    transforms.CenterCrop(size),\n",
    "    #transforms.CenterCrop(args.input_size),\n",
    "    transforms.Lambda(lambda t: (t * 2) - 1)])\n",
    "    return datatrans\n",
    "def gt_transforms(size):\n",
    "    gttrans =  transforms.Compose([\n",
    "    transforms.Resize((size, size)),\n",
    "    \n",
    "    transforms.CenterCrop(size),\n",
    "    transforms.ToTensor()])\n",
    "    return gttrans\n",
    "\n",
    "\n",
    "class MVTecDataset(Dataset):\n",
    "    def __init__(self,type, root):\n",
    "        self.data = []\n",
    "\n",
    "    \n",
    "    \n",
    "    \n",
    "    \n",
    "    \n",
    "        \n",
    "        with open(f'/home/jovyan/diad_new/DiAD-main/training/MVTec-AD/{category}.json', 'rt') as f:\n",
    "\n",
    "             for line in f:\n",
    "                self.data.append(json.loads(line))\n",
    "    \n",
    "#         self.data.append(data_path)\n",
    "        self.label_to_idx = {'bottle': '0', 'cable': '1', 'capsule': '2', 'carpet': '3', 'grid': '4', 'hazelnut': '5',\n",
    "                             'leather': '6', 'metal_nut': '7', 'pill': '8', 'screw': '9', 'tile': '10',\n",
    "                             'toothbrush': '11', 'transistor': '12', 'wood': '13', 'zipper': '14'}\n",
    "        self.image_size = (256, 256)\n",
    "        self.root = '/home/jovyan/diad_new/DiAD-main/training/MVTec-AD/mvtecad/'\n",
    "        \n",
    "    def __len__(self):\n",
    "        return len(self.data)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        item = self.data[idx]\n",
    "        source_filename = item['filename']\n",
    "        target_filename = item['filename']\n",
    "        label = item[\"label\"]\n",
    "        if item.get(\"maskname\", None):\n",
    "            mask = cv2.imread( self.root + item['maskname'], cv2.IMREAD_GRAYSCALE)\n",
    "        else:\n",
    "            if label == 0:  # good\n",
    "                mask = np.zeros(self.image_size).astype(np.uint8)\n",
    "            elif label == 1:  # defective\n",
    "                mask = (np.ones(self.image_size)).astype(np.uint8)\n",
    "            else:\n",
    "                raise ValueError(\"Labels must be [None, 0, 1]!\")\n",
    "\n",
    "        prompt = \"\"\n",
    "        source = cv2.imread(self.root + source_filename)\n",
    "        target = cv2.imread(self.root + target_filename)\n",
    "        source = cv2.cvtColor(source, 4)\n",
    "        target = cv2.cvtColor(target, 4)\n",
    "        source = Image.fromarray(source, \"RGB\")\n",
    "        target = Image.fromarray(target, \"RGB\")\n",
    "        mask = Image.fromarray(mask, \"L\")\n",
    "        # transform_fn = transforms.Resize(256, Image.BILINEAR)\n",
    "        transform_fn = transforms.Resize(self.image_size)\n",
    "#         transform_fn = transforms.Compose([\n",
    "#             transforms.Resize(self.image_size),\n",
    "#             transforms.RandomRotation(30,interpolation=transforms.InterpolationMode.BILINEAR),\n",
    "#             ])\n",
    "        source = transform_fn(source)\n",
    "#         target = source\n",
    "        target = transform_fn(target)\n",
    "        mask = transform_fn(mask)\n",
    "        source = transforms.ToTensor()(source)\n",
    "        target = transforms.ToTensor()(target)\n",
    "        mask = transforms.ToTensor()(mask)\n",
    "        #normalize_fn = transforms.Normalize(mean=mean_train, std=std_train)\n",
    "        normalize_fn =transforms.Lambda(lambda t: (t * 2) - 1)\n",
    "#         old_1=source     \n",
    "        \n",
    "        source = normalize_fn(source)\n",
    "#         print('source',torch.max(source),torch.min(source))\n",
    "        target = normalize_fn(target)\n",
    "        \n",
    "\n",
    "        \n",
    "        \n",
    "        \n",
    "        clsname = item[\"clsname\"]\n",
    "        image_idx = self.label_to_idx[clsname]\n",
    "        \n",
    "        b_size = source.size()\n",
    "        t_kl = torch.zeros(b_size)\n",
    "\n",
    "        return dict(jpg=target, mask=mask, filename=source_filename, clsname=clsname, label=int(image_idx),label_01 = label)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "3227b018-3271-48fc-87b4-4e1960dd4c6a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import csv\n",
    "import json\n",
    "import cv2\n",
    "import numpy as np\n",
    "\n",
    "from torch.utils.data import Dataset\n",
    "from torchvision import transforms\n",
    "from PIL import Image\n",
    "\n",
    "mean_train = [0.485, 0.456, 0.406]\n",
    "std_train = [0.229, 0.224, 0.225]\n",
    "class VisaDataset(Dataset):\n",
    "    def __init__(self,type, root):\n",
    "        self.data = []\n",
    "        with open('/home/jovyan/diad_new/DiAD-main/training/VisA/visa.csv', 'rt') as f:\n",
    "            render = csv.reader(f, delimiter=',')\n",
    "            header = next(render)\n",
    "            for row in render:       \n",
    "                if row[1] == type and row[0] == f'{category}':\n",
    "#                 if row[1] == type and row[0] == 'candle':                 \n",
    "#                 if row[1] == type and row[0] == 'capsules' :                 \n",
    "#                 if row[1] == type and row[0] == 'cashew' :                 \n",
    "#                 if row[1] == type and row[0] == 'chewinggum' :               \n",
    "#                 if row[1] == type and row[0] == 'fryum' :                \n",
    "#                 if row[1] == type and row[0] == 'macaroni1' :               \n",
    "#                 if row[1] == type and row[0] == 'macaroni2' :                 \n",
    "#                 if row[1] == type and row[0] == 'pcb1':                 \n",
    "#                 if row[1] == type and row[0] == 'pcb2' :                \n",
    "#                 if row[1] == type and row[0] == 'pcb3':                \n",
    "#                 if row[1] == type and row[0] == 'pcb4':               \n",
    "#                 if row[1] == type and row[0] == 'pipe_fryum':\n",
    "\n",
    "                 \n",
    "                    data_dict = {'object':row[0],'split':row[1],'label':row[2],'image':row[3],'mask':row[4]}\n",
    "                    self.data.append(data_dict)\n",
    "        self.label_to_idx = {'candle': '0', 'capsules': '1', 'cashew': '2', 'chewinggum': '3', 'fryum': '4', 'macaroni1': '5',\n",
    "                             'macaroni2': '6', 'pcb1': '7', 'pcb2': '8', 'pcb3': '9', 'pcb4': '10',\n",
    "                             'pipe_fryum': '11',}\n",
    "        self.image_size = (256,256)\n",
    "        self.root = root\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.data)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        item = self.data[idx]\n",
    "\n",
    "        source_filename = item['image']\n",
    "        target_filename = item['image']\n",
    "        prompt = \"\"\n",
    "        if item.get(\"mask\", None):\n",
    "#             mask = cv2.imread( self.root + item['mask'], cv2.IMREAD_GRAYSCALE)\n",
    "            mask = Image.open( self.root + item['mask'])\n",
    "            mask_array = np.array(mask)\n",
    "            mask_array[mask_array != 0] = 255\n",
    "            mask = mask_array\n",
    "        else:\n",
    "            if item['label'] == 'normal':  # good\n",
    "                mask = np.zeros(self.image_size).astype(np.uint8)\n",
    "            elif item['label'] == 'anomaly':  # defective\n",
    "                mask = (np.ones(self.image_size)).astype(np.uint8)\n",
    "            else:\n",
    "                raise ValueError(\"Labels must be [None, 0, 1]!\")\n",
    "        \n",
    "        label = 0 if item['label'] == 'normal' else 1\n",
    "        \n",
    "        source = cv2.imread(self.root + source_filename)\n",
    "        target = cv2.imread(self.root + target_filename)\n",
    "        source = cv2.cvtColor(source, 4)\n",
    "        target = cv2.cvtColor(target, 4)\n",
    "        source = Image.fromarray(source, \"RGB\")\n",
    "        target = Image.fromarray(target, \"RGB\")\n",
    "        mask = Image.fromarray(mask, \"L\")\n",
    "        transform_fn = transforms.Resize(self.image_size)\n",
    "        source = transform_fn(source)\n",
    "        target = transform_fn(target)\n",
    "        mask = transform_fn(mask)\n",
    "        source = transforms.ToTensor()(source)\n",
    "        target = transforms.ToTensor()(target)\n",
    "        mask = transforms.ToTensor()(mask)\n",
    "#         normalize_fn = transforms.Normalize(mean=mean_train, std=std_train)\n",
    "#         source = normalize_fn(source)\n",
    "#         target = normalize_fn(target)\n",
    "        normalize_fn =transforms.Lambda(lambda t: (t * 2) - 1)    \n",
    "        source = normalize_fn(source)\n",
    "        target = normalize_fn(target)\n",
    "        clsname = item[\"object\"]\n",
    "        image_idx = self.label_to_idx[clsname]\n",
    "\n",
    "        return dict(jpg=target, txt=prompt, hint=source, mask=mask, filename=source_filename, clsname=clsname, label=int(image_idx),label_01 = label)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "51e8d65f-45d2-4a6f-a12a-adcb3f8ea36c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import json\n",
    "import cv2\n",
    "import numpy as np\n",
    "from PIL import Image\n",
    "from torch.utils.data import Dataset\n",
    "from torchvision import transforms\n",
    "import torch\n",
    "import torchvision\n",
    "# mean_train = [0.485, 0.456, 0.406]\n",
    "# std_train = [0.229, 0.224, 0.225]\n",
    "# transforms.RandomRotation(30,interpolation=transforms.InterpolationMode.BILINEAR),\n",
    "\n",
    "# transforms.Normalize(mean=mean_train, std=std_train)\n",
    "def data_transforms(size):\n",
    "    datatrans =  transforms.Compose([\n",
    "    transforms.Resize((size, size)),\n",
    "    transforms.ToTensor(),\n",
    "    \n",
    "    transforms.CenterCrop(size),\n",
    "    #transforms.CenterCrop(args.input_size),\n",
    "    transforms.Lambda(lambda t: (t * 2) - 1)])\n",
    "    return datatrans\n",
    "def gt_transforms(size):\n",
    "    gttrans =  transforms.Compose([\n",
    "    transforms.Resize((size, size)),\n",
    "    \n",
    "    transforms.CenterCrop(size),\n",
    "    transforms.ToTensor()])\n",
    "    return gttrans\n",
    "\n",
    "\n",
    "class MPDDDataset(Dataset):\n",
    "    def __init__(self,type, root):\n",
    "        self.data = []\n",
    "\n",
    "    \n",
    "        \n",
    "        with open(f'/home/jovyan/dataset/MPDD/{category}.json', 'rt') as f:\n",
    "\n",
    "             for line in f:\n",
    "                self.data.append(json.loads(line))\n",
    "    \n",
    "#         self.data.append(data_path)\n",
    "#         self.label_to_idx = {'bottle': '0', 'cable': '1', 'capsule': '2', 'carpet': '3', 'grid': '4', 'hazelnut': '5',\n",
    "#                              'leather': '6', 'metal_nut': '7', 'pill': '8', 'screw': '9', 'tile': '10',\n",
    "#                              'toothbrush': '11', 'transistor': '12', 'wood': '13', 'zipper': '14'}\n",
    "        self.image_size = (256, 256)\n",
    "        self.root = root\n",
    "        \n",
    "    def __len__(self):\n",
    "        return len(self.data)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        item = self.data[idx]\n",
    "        source_filename = item['filename']\n",
    "        target_filename = item['filename']\n",
    "        label = item[\"label\"]\n",
    "        if item.get(\"maskname\", None):\n",
    "            mask = cv2.imread( self.root + item['maskname'], cv2.IMREAD_GRAYSCALE)\n",
    "        else:\n",
    "            if label == 0:  # good\n",
    "                mask = np.zeros(self.image_size).astype(np.uint8)\n",
    "            elif label == 1:  # defective\n",
    "                mask = (np.ones(self.image_size)).astype(np.uint8)\n",
    "            else:\n",
    "                raise ValueError(\"Labels must be [None, 0, 1]!\")\n",
    "\n",
    "        prompt = \"\"\n",
    "        source = cv2.imread(self.root + source_filename)\n",
    "        target = cv2.imread(self.root + target_filename)\n",
    "        source = cv2.cvtColor(source, 4)\n",
    "        target = cv2.cvtColor(target, 4)\n",
    "        source = Image.fromarray(source, \"RGB\")\n",
    "        target = Image.fromarray(target, \"RGB\")\n",
    "        mask = Image.fromarray(mask, \"L\")\n",
    "        # transform_fn = transforms.Resize(256, Image.BILINEAR)\n",
    "        transform_fn = transforms.Resize(self.image_size,Image.ANTIALIAS)\n",
    "#         transform_fn = transforms.Compose([\n",
    "#             transforms.Resize(self.image_size),\n",
    "#             transforms.RandomRotation(30,interpolation=transforms.InterpolationMode.BILINEAR),\n",
    "#             ])\n",
    "        source = transform_fn(source)\n",
    "#         target = source\n",
    "        target = transform_fn(target)\n",
    "        mask = transform_fn(mask)\n",
    "        source = transforms.ToTensor()(source)\n",
    "        target = transforms.ToTensor()(target)\n",
    "        mask = transforms.ToTensor()(mask)\n",
    "        #normalize_fn = transforms.Normalize(mean=mean_train, std=std_train)\n",
    "        normalize_fn =transforms.Lambda(lambda t: (t * 2) - 1)\n",
    "#         old_1=source     \n",
    "        \n",
    "        source = normalize_fn(source)\n",
    "#         print('source',torch.max(source),torch.min(source))\n",
    "        target = normalize_fn(target)\n",
    "        \n",
    "\n",
    "        \n",
    "        \n",
    "        \n",
    "        clsname = item[\"clsname\"]\n",
    "\n",
    "#         image_idx = self.label_to_idx[clsname]\n",
    "        image_idx = ''\n",
    "        b_size = source.size()\n",
    "        t_kl = torch.zeros(b_size)\n",
    "\n",
    "        return dict(jpg=target, mask=mask, filename=source_filename, clsname=clsname, label=image_idx,label_01 = label)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "a51d88a8-1588-487d-b256-c9173e716878",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Misc\n",
    "# CUDA_VISIBLE_DEVICES=1\n",
    "# data_path = '/home/jovyan/diad_new/DiAD-main/training/MVTec-AD/mvtecad/'\n",
    "# data_path = '/home/jovyan/visa/'\n",
    "data_path = '/home/jovyan/dataset/MPDD/'\n",
    "# dataset = MVTecDataset('test',data_path)\n",
    "dataset = MPDDDataset('test',data_path)\n",
    "\n",
    "# dataset = MVTecDataset('test',data_path)\n",
    "# dataset = VisaDataset('test',data_path)\n",
    "dataloader = DataLoader(dataset, num_workers=8, batch_size=256, shuffle=False)\n",
    "# pretrained_model = timm.create_model(\"resnet50\", pretrained=True, features_only=True)\n",
    "# pretrained_model = pretrained_model.cuda()\n",
    "# pretrained_model = pretrained_model.cpu()\n",
    "# pretrained_model.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "33a99515-90a8-4631-86e2-377fbf6c2bd8",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "7305a421-3f88-4f22-a5f2-d2494da97e30",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_vit_encoder(vit_arch, vit_model, vit_patch_size, enc_type_feats=None):\n",
    "    if vit_model == \"dino\":\n",
    "        if vit_arch == \"vit_small\" and vit_patch_size == 16:\n",
    "            url = \"dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth\"\n",
    "            initial_dim = 384\n",
    "        elif vit_arch == \"vit_small\" and vit_patch_size == 8:\n",
    "            url = \"dino_deitsmall8_300ep_pretrain/dino_deitsmall8_300ep_pretrain.pth\"\n",
    "            initial_dim = 384\n",
    "        elif vit_arch == \"vit_base\" and vit_patch_size == 16:\n",
    "            if vit_model == \"clip\":\n",
    "                url = \"5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt\"\n",
    "            elif vit_model == \"dino\":\n",
    "                url = \"dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth\"\n",
    "            initial_dim = 768\n",
    "        elif vit_arch == \"vit_base\" and vit_patch_size == 8:\n",
    "            url = \"dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth\"\n",
    "            initial_dim = 768\n",
    "\n",
    "        state_dict = torch.hub.load_state_dict_from_url(\n",
    "            url=\"https://dl.fbaipublicfiles.com/dino/\" + url\n",
    "        )\n",
    "#         state_dict = torch.load('/home/jovyan/cvpr/dino_pill_14.pth')\n",
    "        vit_encoder = vits.__dict__[vit_arch](patch_size=vit_patch_size, num_classes=0, extraction=[3, 6, 9, 12])\n",
    "        vit_encoder.load_state_dict(state_dict, strict=True)\n",
    "\n",
    "    elif vit_model == \"dino-v2\":\n",
    "        if vit_model == \"dino-v2\" and vit_arch == \"vit_base\" and vit_patch_size == 14:\n",
    "            # url = \"dinov2_vitb14/dinov2_vitb14_reg4_pretrain.pth\"\n",
    "            initial_dim = 768\n",
    "            vit_encoder = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14_reg_lc')\n",
    "#             vit_encoder = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb8_reg_lc')\n",
    "        elif vit_model == \"dino-v2\" and vit_arch == \"vit_large\" and vit_patch_size == 14:\n",
    "            # url = \"dinov2_vitl14/dinov2_vitl14_reg4_pretrain.pth\"\n",
    "            # initial_dim = 768\n",
    "            vit_encoder = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitl14_reg_lc')\n",
    "\n",
    "        # state_dict = torch.hub.load_state_dict_from_url(\n",
    "        #     url=\"https://dl.fbaipublicfiles.com/dinov2/\" + url\n",
    "        # )\n",
    "        # vit_encoder = vits2.__dict__[vit_arch](patch_size=vit_patch_size)\n",
    "\n",
    "    for p in vit_encoder.parameters():\n",
    "        p.requires_grad = False\n",
    "\n",
    "    return vit_encoder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2a943240-c13e-4d97-88fa-3fa450937b31",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "05da0aea-27e8-4d66-afed-fa56c68608b6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "env: http_proxy=http://10.60.6.186:7890\n",
      "env: https_proxy=http://10.60.6.186:7890\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/conda/lib/python3.8/site-packages/torchvision/transforms/transforms.py:287: UserWarning: Argument interpolation should be of type InterpolationMode instead of int. Please, use InterpolationMode enum.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "var tensor(19.9964, device='cuda:0') tensor(7.3338e-05, device='cuda:0')\n",
      "var tensor(17.6953, device='cuda:0') tensor(8.0455e-05, device='cuda:0')\n",
      "var tensor(17.2486, device='cuda:0') tensor(8.3690e-05, device='cuda:0')\n",
      "var tensor(19.6851, device='cuda:0') tensor(8.7090e-05, device='cuda:0')\n",
      "var tensor(28.4821, device='cuda:0') tensor(8.7212e-05, device='cuda:0')\n",
      "var tensor(32.9607, device='cuda:0') tensor(8.8289e-05, device='cuda:0')\n",
      "var tensor(40.9350, device='cuda:0') tensor(9.0459e-05, device='cuda:0')\n",
      "var tensor(74.4288, device='cuda:0') tensor(9.2537e-05, device='cuda:0')\n",
      "var tensor(197.9968, device='cuda:0') tensor(9.4368e-05, device='cuda:0')\n",
      "var tensor(585.6688, device='cuda:0') tensor(0.0001, device='cuda:0')\n",
      "var tensor(921.8633, device='cuda:0') tensor(0.0001, device='cuda:0')\n",
      "var tensor(694.1952, device='cuda:0') tensor(0.0001, device='cuda:0')\n",
      "var tensor(522.3657, device='cuda:0') tensor(0.0001, device='cuda:0')\n",
      "var tensor(472.4205, device='cuda:0') tensor(0.0001, device='cuda:0')\n",
      "var tensor(446.4328, device='cuda:0') tensor(0.0001, device='cuda:0')\n",
      "var tensor(421.4065, device='cuda:0') tensor(0.0001, device='cuda:0')\n",
      "var tensor(397.2377, device='cuda:0') tensor(0.0002, device='cuda:0')\n",
      "var tensor(304.5539, device='cuda:0') tensor(0.0002, device='cuda:0')\n",
      "var tensor(182.8138, device='cuda:0') tensor(0.0002, device='cuda:0')\n",
      "var tensor(70.1956, device='cuda:0') tensor(0.0001, device='cuda:0')\n",
      "var tensor(28.3704, device='cuda:0') tensor(9.8663e-05, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/conda/lib/python3.8/site-packages/torch/nn/functional.py:3631: UserWarning: Default upsampling behavior when mode=bicubic is changed to align_corners=False since 0.4.0. Please specify align_corners=True if the old behavior is desired. See the documentation of nn.Upsample for details.\n",
      "  warnings.warn(\n",
      "/opt/conda/lib/python3.8/site-packages/torch/nn/functional.py:3679: UserWarning: The default behavior for interpolate/upsample with float scale_factor changed in 1.6.0 to align with other frameworks/libraries, and now uses scale_factor directly, instead of relying on the computed output size. If you wish to restore the old behavior, please set recompute_scale_factor=True. See the documentation of nn.Upsample for details. \n",
      "  warnings.warn(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test-------- I-AUROC/I-AP/I-F1-max/P-AUROC/P-AP/P-F1-max/PRO:98.3/96.4/93.3/94.6/4.4/9.2/73.3-----\n"
     ]
    }
   ],
   "source": [
    "\n",
    "\n",
    "# x_list= []\n",
    "\n",
    "s_list = []\n",
    "\n",
    "def rec(x_in,z):\n",
    "    \n",
    "    with torch.no_grad():\n",
    "        for i in range(z):\n",
    "            t = torch.full((x_in.shape[0],), i, device=device)\n",
    "            noise = torch.randn_like(x_in).float() \n",
    "            a1 = 10*torch.max(torch.exp(model_logvar(x_in,t)))*0.001\n",
    "#             print('1')\n",
    "#         print(torch.mean(a))10*10*\n",
    "            x_t = torch.sqrt(1-a1)*x_in+torch.sqrt(a1)*noise\n",
    "            x_in = x_t\n",
    "            s_list.append(a1)\n",
    "            \n",
    "            \n",
    "    with torch.no_grad():\n",
    "        t = torch.full((x_in.shape[0],), z+1, device=device)\n",
    "        noise = torch.randn_like(image_input).float() \n",
    "        a1 = 10*torch.max(torch.exp(model_logvar(x_in,t)))*0.001\n",
    "        x_t1 = torch.sqrt(1-a1)*x_in+torch.sqrt(a1)*noise\n",
    "        s_list.append(a1)\n",
    "        \n",
    "        \n",
    "    img_test =x_t1\n",
    "#     img_test = x\n",
    "#     x1=x\n",
    "    x = x_t\n",
    "#     betas_list = []\n",
    "    with torch.no_grad():\n",
    "        for i in reversed(range(z)):\n",
    "            t = torch.full((x.shape[0],), i, device=device)\n",
    " \n",
    "#             noise = torch.randn_like(image_input).float()  \n",
    "#             perturbed_data = sqrt_alphas_cumprod[t].reshape(t.shape[0],1,1,1).float() * image_input + sqrt_one_minus_alphas_cumprod[t].reshape(t.shape[0],1,1,1).float() * noise\n",
    "            \n",
    "            x = x.float()\n",
    "            pred_noise = model_mu(x,t)\n",
    "            logvar = model_logvar(x,t)\n",
    "           \n",
    "#             logvar = torch.clamp(logvar,min=-4.)\n",
    "#             print('logvar',torch.max(logvar),torch.min(logvar))\n",
    "            var =10*torch.exp(logvar)\n",
    "            print('var',torch.max(var),torch.min(var))\n",
    "\n",
    "            \n",
    "    \n",
    "    \n",
    "            betas = torch.tensor(s_list)\n",
    "#             betas =s_list\n",
    "#             betas = var\n",
    "#             betas_list.append(betas)\n",
    "#             alphas = torch.tensor(1.) - betas\n",
    "            alphas = 1. - betas\n",
    "#             alphas = [1 -  number for number in betas]\n",
    "#             print(alphas.type)\n",
    "#             alphas_tensor = torch.stack(alphas, dim=0)\n",
    "#             print(alphas_tensor)\n",
    "        \n",
    "            alphas_cumprod = torch.cumprod(alphas, axis=0)\n",
    "#             alphas_cumprod = torch.cumprod(alphas_tensor, dim=0)\n",
    "            alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.)\n",
    "\n",
    "            sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)\n",
    "            sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)\n",
    "            log_one_minus_alphas_cumprod = torch.log(1.0 - alphas_cumprod)\n",
    "            #         alpha连乘的倒数\n",
    "            sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / alphas_cumprod)\n",
    "            sqrt_recipm1_alphas_cumprod = torch.sqrt(1.0 / alphas_cumprod - 1)\n",
    "            sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device)\n",
    "            sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device)\n",
    "\n",
    "    #         noise = torch.randn_like(images).float()        \n",
    "#             sqrt_alphas_cumprod_t = sqrt_alphas_cumprod[t].reshape(t.shape[0],1,1,1).float()\n",
    "#             sqrt_one_minus_alphas_cumprod_t = sqrt_one_minus_alphas_cumprod[t].reshape(t.shape[0],1,1,1).float()\n",
    "    \n",
    "    \n",
    "        \n",
    "    \n",
    "    \n",
    "    \n",
    "\n",
    "            \n",
    "#             x_recon = 1/torch.sqrt(alphas[i]).to(device)*img_test-betas[i].to(device)/(sqrt_one_minus_alphas_cumprod[i]).float()*torch.sqrt(alphas[i]).to(device)*pred_noise\n",
    "            \n",
    "            x_recon = 1/torch.sqrt(alphas[t]).to(device).view(-1, 1, 1, 1)*img_test-betas[t].to(device).view(-1, 1, 1, 1)/(sqrt_one_minus_alphas_cumprod[t].reshape(t.shape[0],1,1,1).float()*torch.sqrt(alphas[t]).to(device).view(-1, 1, 1, 1))*pred_noise\n",
    "#             x_recon = torch.clamp(x_recon, min=-1., max=1.)\n",
    "            x_recon = torch.clamp(x_recon, min=-1., max=1.)\n",
    "            img_test = x_recon.float()\n",
    "        #\n",
    "            \n",
    "            \n",
    "\n",
    "#             std = (sqrt_one_minus_alphas_cumprod[t]**2).reshape(t.shape[0],1,1,1).float().to(device)\n",
    "\n",
    "            duad_fn = -(x-x_recon)\n",
    "\n",
    "        \n",
    "            x = x*(1+var/2)+duad_fn\n",
    "    \n",
    "            x = torch.clamp(x, min=-1., max=1.)\n",
    "#             img_test = x.float()\n",
    "#             x_list.append(x)\n",
    "#             print(x.shape)\n",
    "#             x_recon=1\n",
    "    return x,x_recon,var/10,x_t\n",
    "\n",
    "from scipy.ndimage import label\n",
    "%env http_proxy=http://10.60.6.186:7890\n",
    "%env https_proxy=http://10.60.6.186:7890\n",
    "import copy\n",
    "from kornia.filters import gaussian_blur2d\n",
    "from utilize.utilize import normalize, fix_seeds, compute_pro\n",
    "from sklearn.metrics import roc_auc_score, precision_recall_curve, average_precision_score\n",
    "# import timm\n",
    "weight_dtype = torch.float32\n",
    "device = torch.device(\"cuda\")\n",
    "dino_model = get_vit_encoder(vit_arch=\"vit_base\", vit_model=\"dino\", vit_patch_size=8, enc_type_feats=None).to(device, dtype=weight_dtype)\n",
    "# dino_model = nn.DataParallel(dino_model)\n",
    "# state_dict_dino = torch.load('/home/jovyan/cvpr/dino_pill_14.pth')\n",
    "# dino_model.load_state_dict(state_dict_dino)\n",
    "dino_model.eval()\n",
    "dino_frozen = copy.deepcopy(dino_model)\n",
    "\n",
    "\n",
    "# pretrained_model = timm.create_model(\"resnet50\", pretrained=True, features_only=True)\n",
    "# pretrained_model = pretrained_model.cuda()\n",
    "# pretrained_model.eval()\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "# file_path = f\"{category}.txt\"\n",
    "preds = []\n",
    "masks = []\n",
    "scores = []\n",
    "labels = []\n",
    "\n",
    "with torch.no_grad():\n",
    "    for input in dataloader:\n",
    "        image_input = input['jpg'].to(device)\n",
    "        anomaly_mask = input['mask'].to(device)\n",
    "        z  =20\n",
    "        \n",
    "        u=981.049084243972\n",
    "        v = 0\n",
    "        w=817.573268078374\n",
    "        \n",
    "        \n",
    "        \n",
    "        \n",
    "        t = torch.full((image_input.shape[0],), z, device=device)    \n",
    "        noise = torch.randn_like(image_input).float()  \n",
    "\n",
    "        x,x_recon,var,x_t = rec(image_input,z+1)\n",
    "        \n",
    "        \n",
    "\n",
    "        reconstruct_images =x     \n",
    "#         image_input_dino = image_input[i].unsqueeze(dim=0)\n",
    "\n",
    "#         transform1 = transforms.Compose([\n",
    "#             transforms.Lambda(lambda t: (t + 1) / (2)),\n",
    "#             transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\n",
    "#         ])\n",
    "#         image_input2 = transform1(image_input)\n",
    "#         reconstruct_images2 = transform1(reconstruct_images)\n",
    "        \n",
    "        _, patch_tokens_i = dino_model(image_input.to(dtype=weight_dtype))\n",
    "        _, patch_tokens_r = dino_model(reconstruct_images.to(dtype=weight_dtype))\n",
    "        sigma = 6#6\n",
    "        kernel_size = 2 * int(4 * sigma + 0.5) + 1\n",
    "        b, n, c = patch_tokens_i[0][:, 1:, :].shape\n",
    "        h = int(n ** 0.5)\n",
    "        anomaly_maps_cos = torch.zeros((b, 1, 256, 256)).to(device)\n",
    "        anomaly_maps2 = torch.zeros((b, 1, 256, 256)).to(device)\n",
    "        for idx in range(len(patch_tokens_i)):\n",
    "            pi = patch_tokens_i[idx][:, 1:, :]\n",
    "            pr = patch_tokens_r[idx][:, 1:, :]\n",
    "\n",
    "            pi = pi / torch.norm(pi, p=2, dim=-1, keepdim=True)\n",
    "            pr = pr / torch.norm(pr, p=2, dim=-1, keepdim=True)\n",
    "\n",
    "            cos0 = torch.bmm(pi, pr.permute(0, 2, 1))\n",
    "\n",
    "            anomaly_map_cos, _ = torch.min(1 - cos0, dim=-1)\n",
    "            anomaly_map_cos = F.interpolate(anomaly_map_cos.reshape(-1, 1, h, h), size=256, mode='bilinear', align_corners=True)\n",
    "            anomaly_maps_cos += anomaly_map_cos      #(128,1,256,256)\n",
    "\n",
    "\n",
    "        \n",
    "        \n",
    "       \n",
    "        \n",
    "\n",
    "        mean_q = x \n",
    "        std_q = torch.sqrt(var)\n",
    "        mean_p = image_input\n",
    "        \n",
    "        t_kl1 = torch.full((image_input.shape[0],), 0, device=device)\n",
    "        std_p = torch.sqrt(torch.exp(model_logvar(image_input,t_kl1)))\n",
    "    #     print(std_p.shape)\n",
    "        kl_div_c = torch.log(std_q / std_p) + (std_p**2 + (mean_p - mean_q)**2) / (2 * std_q**2) - 0.5\n",
    "\n",
    "        kl_map = 0\n",
    "        kl_map += kl_div_c\n",
    "        anomaly_kl = gaussian_blur2d(\n",
    "        kl_map , kernel_size=(kernel_size,kernel_size), sigma=(sigma,sigma)\n",
    "        )\n",
    "        anomaly_kl = torch.sum(anomaly_kl, dim=1)\n",
    "        \n",
    "        \n",
    "        \n",
    "        \n",
    "        anomaly_maps_uncer = (x-image_input)**2/(var) #([1, 3, 256, 256])\n",
    "#         anomaly_maps_uncer = (x0-images0)**2/(var) #([1, 3, 256, 256])        \n",
    "            \n",
    "        distance_map = torch.mean(((x-image_input)**2/(var)), dim=1).unsqueeze(1)\n",
    "\n",
    "        anomaly_uncer = 0\n",
    "        anomaly_uncer += distance_map\n",
    "        \n",
    "#         anomaly_uncer = v*anomaly_uncer + u*anomaly_maps_cos\n",
    "#         print('anomaly_uncer',anomaly_uncer.shape)\n",
    "        anomaly_uncer = gaussian_blur2d(\n",
    "        anomaly_uncer , kernel_size=(kernel_size,kernel_size), sigma=(sigma,sigma)\n",
    "        )\n",
    "        anomaly_uncer = torch.sum(anomaly_uncer, dim=1)\n",
    "      \n",
    "    \n",
    "\n",
    "        \n",
    "        \n",
    "        anomaly_maps1_cos = gaussian_blur2d(anomaly_maps_cos, kernel_size=(kernel_size, kernel_size), sigma=(sigma, sigma))[:, 0]\n",
    "\n",
    "        anomaly_maps1 = u*(torch.max(anomaly_uncer)/ torch.max(anomaly_maps1_cos))*anomaly_maps1_cos+v*anomaly_uncer+w*anomaly_kl\n",
    "#         anomaly_maps1 = v*anomaly_uncer        \n",
    "        \n",
    "#         anomaly_maps1 = gaussian_blur2d(anomaly_maps1, kernel_size=(kernel_size, kernel_size), sigma=(sigma, sigma))[:, 0]\n",
    "        anomaly_maps = anomaly_maps1 \n",
    "#         \n",
    "\n",
    "\n",
    "        \n",
    "#         print(anomaly_maps.shape)\n",
    "    \n",
    "\n",
    "        \n",
    "        \n",
    "        batch_size, _,height, width = anomaly_maps_cos.shape\n",
    "        for i in range(batch_size):         \n",
    "            root = os.path.join('log_image_var0124/')\n",
    "            name = input[\"filename\"][i][-7:-4] #005\n",
    "            heatmap_name = \"{}-heatmapcos.png\".format(name)\n",
    "            os.makedirs(os.path.join(root, input[\"filename\"][i][:-7]), exist_ok=True)\n",
    "            \n",
    "            \n",
    "            pixel_mean = [0.485, 0.456, 0.406]\n",
    "            pixel_std = [0.229, 0.224, 0.225]\n",
    "            pixel_mean = torch.tensor(pixel_mean).unsqueeze(1).unsqueeze(1).to(device)  # 3 x 1 x 1\n",
    "            pixel_std = torch.tensor(pixel_std).unsqueeze(1).unsqueeze(1).to(device)\n",
    "\n",
    "            anomaly_maps_cos1 = anomaly_maps_cos[i].cpu()\n",
    "#             anomaly_maps_cos1 =anomaly_map[i].cpu()\n",
    "            anomaly_maps_cos1 =anomaly_maps_cos1.unsqueeze(dim=0)\n",
    "            \n",
    "            anomaly_maps_cos1 = gaussian_blur2d(anomaly_maps_cos1, kernel_size=(kernel_size, kernel_size), sigma=(sigma, sigma))[:, 0]\n",
    "            anomaly_maps_cos1 = np.round(255 * (anomaly_maps_cos1 - anomaly_maps_cos1.min()) / (anomaly_maps_cos1.max() - anomaly_maps_cos1.min()))\n",
    "            anomaly_maps_cos1 = anomaly_maps_cos1.to('cpu').squeeze().numpy().astype(np.uint8)\n",
    "            heatmap_cos = cv2.applyColorMap(anomaly_maps_cos1, colormap=cv2.COLORMAP_JET)\n",
    "            image1 = ((image_input[i]+1)/2) * 255\n",
    "#             image1 = (image_input[i]* pixel_std + pixel_mean) * 255\n",
    "            image1 = image1.permute(1, 2, 0).to('cpu').numpy().astype('uint8')\n",
    "            image_copy = image1.copy()\n",
    "            out_heat_map = cv2.addWeighted(heatmap_cos, 0.5, image_copy, 0.5, 0, image_copy)            \n",
    "            name = input[\"filename\"][i][-7:-4] #005\n",
    "            name0=\"{}-in.png\".format(name)\n",
    "            name1=\"{}-rec.png\".format(name)\n",
    "            heatmap_name = \"{}-heatmapcos.png\".format(name)\n",
    "#             print(root + input[\"filename\"][i][:-7] + heatmap_name)\n",
    "            cv2.imwrite(root + input[\"filename\"][i][:-7] + heatmap_name,out_heat_map)\n",
    "            \n",
    "        \n",
    "        \n",
    "\n",
    "            cv_uncer = anomaly_maps_uncer[i].cpu()\n",
    "            \n",
    "            cv_uncer =cv_uncer.unsqueeze(dim=0)\n",
    "            cv_uncer = gaussian_blur2d(cv_uncer, kernel_size=(kernel_size, kernel_size), sigma=(sigma, sigma))[:, 0]\n",
    "            cv_uncer = np.round(255 * (cv_uncer - cv_uncer.min()) / (cv_uncer.max() - cv_uncer.min()))\n",
    "            cv_uncer = cv_uncer.to('cpu').squeeze().numpy().astype(np.uint8)\n",
    "            heatmap_uncer = cv2.applyColorMap(cv_uncer, colormap=cv2.COLORMAP_JET)\n",
    "            image1 = ((image_input[i]+1)/2) * 255\n",
    "            image1 = image1.permute(1, 2, 0).to('cpu').numpy().astype('uint8')\n",
    "            image_copy = image1.copy()\n",
    "            uncer_heat_map = cv2.addWeighted(heatmap_uncer, 0.5, image_copy, 0.5, 0, image_copy)\n",
    "            \n",
    "            \n",
    "            name = input[\"filename\"][i][-7:-4] #005\n",
    "            name0=\"{}-in.png\".format(name)\n",
    "            name1=\"{}-rec.png\".format(name)\n",
    "            heatmap_name = \"{}-heatmap_uae.png\".format(name)\n",
    "            cv2.imwrite(root + input[\"filename\"][i][:-7] + heatmap_name,uncer_heat_map)\n",
    "            \n",
    "\n",
    "            cv_heatmap = anomaly_maps[i].cpu()\n",
    "            \n",
    "            \n",
    "            cv_heatmap = np.round(255 * (cv_heatmap - cv_heatmap.min()) / (cv_heatmap.max() - cv_heatmap.min()))\n",
    "            cv_heatmap = cv_heatmap.to('cpu').squeeze().numpy().astype(np.uint8)\n",
    "            heatmap_cv_heatmap = cv2.applyColorMap(cv_heatmap, colormap=cv2.COLORMAP_JET)\n",
    "            image1 = ((image_input[i]+1)/2) * 255\n",
    "            image1 = image1.permute(1, 2, 0).to('cpu').numpy().astype('uint8')\n",
    "            image_copy = image1.copy()\n",
    "            cv_heatmap_map = cv2.addWeighted(heatmap_cv_heatmap, 0.5, image_copy, 0.5, 0, image_copy)\n",
    "            \n",
    "            name = input[\"filename\"][i][-7:-4] #005\n",
    "            name0=\"{}-in.png\".format(name)\n",
    "            name1=\"{}-rec.png\".format(name)\n",
    "            heatmap_name = \"{}-heatmap.png\".format(name)\n",
    "            cv2.imwrite(root + input[\"filename\"][i][:-7] + heatmap_name,cv_heatmap_map)\n",
    "        \n",
    "        \n",
    "        \n",
    "\n",
    "            #.astype(np.uint8)*255\n",
    "            image_input1 = ((image_input[i]+1)/2).permute(1, 2, 0).to('cpu').numpy()*255\n",
    "#             image_input1 = ((x_t[i]+1)/2).permute(1, 2, 0).to('cpu').numpy()*255\n",
    "            \n",
    "            image_input1 = cv2.cvtColor(image_input1, cv2.COLOR_BGR2RGB)\n",
    "            x_rec = ((x[i]+1)/2).permute(1, 2, 0).to('cpu').numpy()*255\n",
    "            x_rec = cv2.cvtColor(x_rec, cv2.COLOR_BGR2RGB)\n",
    "            cv2.imwrite(root + input[\"filename\"][i][:-7] + name0,image_input1)\n",
    "            cv2.imwrite(root + input[\"filename\"][i][:-7] + name1,x_rec)\n",
    "        \n",
    "        \n",
    "        \n",
    "        \n",
    "        \n",
    "        \n",
    "        \n",
    "        \n",
    "        \n",
    "        score = torch.topk(torch.flatten(anomaly_maps, start_dim=1), 250)[0].mean(dim=1)\n",
    "#         score = torch.flatten(anomaly_maps, start_dim=1).topk(dim=-1, k=64)[0].mean(-1)\n",
    "        \n",
    "        masks.extend([m for m in anomaly_mask[:, 0, :, :].cpu().numpy()])\n",
    "        preds.extend([a for a in anomaly_maps.cpu().numpy()])\n",
    "        scores.extend([s for s in score.cpu().numpy()])\n",
    "        labels.extend([l for l in input[\"label_01\"].cpu().numpy()])\n",
    "#         print('scores',scores)\n",
    "        \n",
    "scores = normalize(np.array(scores))\n",
    "labels = np.array(labels)\n",
    "preds = np.array(preds)\n",
    "masks = np.array(masks, dtype=np.int_)\n",
    "\n",
    "\n",
    "precisions_image, recalls_image, _ = precision_recall_curve(labels, scores)\n",
    "f1_scores_image = (2 * precisions_image * recalls_image) / (precisions_image + recalls_image)\n",
    "best_f1_scores_image = np.max(f1_scores_image[np.isfinite(f1_scores_image)])\n",
    "auroc_image = roc_auc_score(labels, scores)\n",
    "AP_image = average_precision_score(labels, scores)\n",
    "\n",
    "\n",
    "precisions_pixel, recalls_pixel, _ = precision_recall_curve(masks.ravel(), preds.ravel())\n",
    "f1_scores_pixel = (2 * precisions_pixel * recalls_pixel) / (precisions_pixel + recalls_pixel)\n",
    "best_f1_scores_pixel = np.max(f1_scores_pixel[np.isfinite(f1_scores_pixel)])\n",
    "auroc_pixel = roc_auc_score(masks.ravel(), preds.ravel())\n",
    "AP_pixel = average_precision_score(masks.ravel(), preds.ravel())\n",
    "\n",
    "pro = compute_pro(masks, preds)\n",
    "\n",
    "print(f\"test-------- I-AUROC/I-AP/I-F1-max/P-AUROC/P-AP/P-F1-max/PRO:{round(auroc_image*100, 1)}/{round(AP_image*100, 1)}/{round(best_f1_scores_image*100, 1)}/\"\n",
    "              f\"{round(auroc_pixel*100, 1)}/{round(AP_pixel*100, 1)}/{round(best_f1_scores_pixel*100, 1)}/{round(pro*100, 1)}-----\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7082e137-4754-492e-abc1-9b67e850831e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "env: http_proxy=http://10.60.6.186:7890\n",
      "env: https_proxy=http://10.60.6.186:7890\n",
      "pred_noise tensor(11.6051, device='cuda:0')\n",
      "logvar tensor(-1.7964, device='cuda:0') tensor(-10.7477, device='cuda:0')\n",
      "var tensor(1.6590, device='cuda:0') tensor(0.0002, device='cuda:0')\n",
      "pred_noise tensor(12.3891, device='cuda:0')\n",
      "logvar tensor(-1.8615, device='cuda:0') tensor(-10.6773, device='cuda:0')\n",
      "var tensor(1.5543, device='cuda:0') tensor(0.0002, device='cuda:0')\n",
      "pred_noise tensor(16.9554, device='cuda:0')\n",
      "logvar tensor(-1.9254, device='cuda:0') tensor(-10.6280, device='cuda:0')\n",
      "var tensor(1.4581, device='cuda:0') tensor(0.0002, device='cuda:0')\n",
      "pred_noise tensor(14.8605, device='cuda:0')\n",
      "logvar tensor(-3.1050, device='cuda:0') tensor(-10.8704, device='cuda:0')\n",
      "var tensor(0.4483, device='cuda:0') tensor(0.0002, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/conda/lib/python3.8/site-packages/torch/nn/functional.py:3631: UserWarning: Default upsampling behavior when mode=bicubic is changed to align_corners=False since 0.4.0. Please specify align_corners=True if the old behavior is desired. See the documentation of nn.Upsample for details.\n",
      "  warnings.warn(\n",
      "/opt/conda/lib/python3.8/site-packages/torch/nn/functional.py:3679: UserWarning: The default behavior for interpolate/upsample with float scale_factor changed in 1.6.0 to align with other frameworks/libraries, and now uses scale_factor directly, instead of relying on the computed output size. If you wish to restore the old behavior, please set recompute_scale_factor=True. See the documentation of nn.Upsample for details. \n",
      "  warnings.warn(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test-------- I-AUROC:98.1-----\n",
      "test-------- I-AUROC:98.1-----\n",
      "test-------- I-AUROC:98.1-----\n",
      "test-------- I-AUROC:98.0-----\n",
      "test-------- I-AUROC:98.1-----\n",
      "test-------- I-AUROC:98.1-----\n",
      "test-------- I-AUROC:98.0-----\n",
      "test-------- I-AUROC:98.0-----\n",
      "test-------- I-AUROC:98.0-----\n",
      "test-------- I-AUROC:98.0-----\n",
      "test-------- I-AUROC:98.0-----\n",
      "test-------- I-AUROC:98.1-----\n",
      "test-------- I-AUROC:98.0-----\n",
      "test-------- I-AUROC:98.1-----\n",
      "test-------- I-AUROC:98.1-----\n",
      "test-------- I-AUROC:98.2-----\n",
      "test-------- I-AUROC:98.0-----\n",
      "test-------- I-AUROC:98.0-----\n",
      "test-------- I-AUROC:98.0-----\n",
      "test-------- I-AUROC:98.0-----\n",
      "test-------- I-AUROC:98.1-----\n",
      "test-------- I-AUROC:98.3-----\n",
      "test-------- I-AUROC:98.0-----\n",
      "test-------- I-AUROC:98.1-----\n",
      "test-------- I-AUROC:98.0-----\n",
      "test-------- I-AUROC:98.0-----\n",
      "test-------- I-AUROC:98.0-----\n",
      "test-------- I-AUROC:98.0-----\n",
      "test-------- I-AUROC:98.2-----\n",
      "test-------- I-AUROC:98.0-----\n",
      "test-------- I-AUROC:98.0-----\n",
      "test-------- I-AUROC:98.0-----\n",
      "test-------- I-AUROC:98.0-----\n",
      "test-------- I-AUROC:98.3-----\n",
      "test-------- I-AUROC:98.3-----\n",
      "test-------- I-AUROC:98.0-----\n",
      "test-------- I-AUROC:98.1-----\n",
      "test-------- I-AUROC:98.3-----\n",
      "test-------- I-AUROC:98.1-----\n",
      "test-------- I-AUROC:98.1-----\n",
      "test-------- I-AUROC:98.0-----\n",
      "test-------- I-AUROC:98.2-----\n",
      "test-------- I-AUROC:98.0-----\n",
      "test-------- I-AUROC:98.0-----\n",
      "test-------- I-AUROC:98.0-----\n",
      "test-------- I-AUROC:98.2-----\n",
      "test-------- I-AUROC:98.3-----\n",
      "test-------- I-AUROC:98.0-----\n",
      "test-------- I-AUROC:98.1-----\n",
      "test-------- I-AUROC:98.0-----\n",
      "test-------- I-AUROC:98.0-----\n",
      "test-------- I-AUROC:98.0-----\n",
      "test-------- I-AUROC:98.2-----\n",
      "test-------- I-AUROC:98.2-----\n",
      "test-------- I-AUROC:98.2-----\n",
      "test-------- I-AUROC:98.3-----\n",
      "test-------- I-AUROC:98.3-----\n",
      "test-------- I-AUROC:98.1-----\n",
      "test-------- I-AUROC:98.2-----\n",
      "test-------- I-AUROC:98.0-----\n",
      "test-------- I-AUROC:98.0-----\n",
      "test-------- I-AUROC:98.0-----\n",
      "test-------- I-AUROC:98.1-----\n",
      "test-------- I-AUROC:98.0-----\n",
      "test-------- I-AUROC:98.3-----\n",
      "test-------- I-AUROC:98.2-----\n",
      "test-------- I-AUROC:98.0-----\n",
      "test-------- I-AUROC:98.0-----\n",
      "test-------- I-AUROC:98.2-----\n",
      "test-------- I-AUROC:98.0-----\n",
      "test-------- I-AUROC:98.2-----\n",
      "test-------- I-AUROC:98.3-----\n",
      "test-------- I-AUROC:98.2-----\n",
      "test-------- I-AUROC:98.0-----\n",
      "test-------- I-AUROC:98.0-----\n",
      "test-------- I-AUROC:98.0-----\n",
      "test-------- I-AUROC:98.1-----\n",
      "test-------- I-AUROC:98.0-----\n",
      "test-------- I-AUROC:98.2-----\n",
      "test-------- I-AUROC:98.2-----\n",
      "test-------- I-AUROC:98.2-----\n",
      "test-------- I-AUROC:98.0-----\n",
      "test-------- I-AUROC:98.1-----\n",
      "test-------- I-AUROC:98.0-----\n",
      "test-------- I-AUROC:98.2-----\n",
      "test-------- I-AUROC:98.0-----\n",
      "test-------- I-AUROC:98.3-----\n",
      "test-------- I-AUROC:98.3-----\n",
      "test-------- I-AUROC:98.1-----\n",
      "test-------- I-AUROC:98.0-----\n",
      "test-------- I-AUROC:98.0-----\n",
      "test-------- I-AUROC:98.2-----\n",
      "test-------- I-AUROC:98.3-----\n",
      "test-------- I-AUROC:98.0-----\n",
      "test-------- I-AUROC:98.3-----\n",
      "test-------- I-AUROC:98.1-----\n",
      "test-------- I-AUROC:98.3-----\n",
      "test-------- I-AUROC:98.1-----\n",
      "test-------- I-AUROC:98.0-----\n",
      "test-------- I-AUROC:98.0-----\n",
      "Best AUROC: 0.9826546003016592, Best u: 83.50565511275575, Best v: 1939.2429487784095, Best w: 0.0\n"
     ]
    }
   ],
   "source": [
    "\n",
    "\n",
    "# x_list= []\n",
    "s_list = []\n",
    "\n",
    "def rec(x_in,z):\n",
    "    \n",
    "    with torch.no_grad():\n",
    "        for i in range(z):\n",
    "            t = torch.full((x_in.shape[0],), i, device=device)\n",
    "            noise = torch.randn_like(x_in).float() \n",
    "            a1 = 10*torch.max(torch.exp(model_logvar(x_in,t)))*0.001\n",
    "#             print('1')\n",
    "#         print(torch.mean(a))10*10*\n",
    "            x_t = torch.sqrt(1-a1)*x_in+torch.sqrt(a1)*noise\n",
    "            x_in = x_t\n",
    "            s_list.append(a1)\n",
    "            \n",
    "            \n",
    "    with torch.no_grad():\n",
    "        t = torch.full((x_in.shape[0],), z+1, device=device)\n",
    "        noise = torch.randn_like(x_in).float() \n",
    "        a1 = 10*torch.max(torch.exp(model_logvar(x_in,t)))*0.001\n",
    "        x_t1 = torch.sqrt(1-a1)*x_in+torch.sqrt(a1)*noise\n",
    "        s_list.append(a1)\n",
    "        \n",
    "        \n",
    "    img_test =x_t1\n",
    "#     img_test = x\n",
    "#     x1=x\n",
    "    x = x_t\n",
    "#     betas_list = []\n",
    "    with torch.no_grad():\n",
    "        for i in reversed(range(z)):\n",
    "            t = torch.full((x.shape[0],), i, device=device)\n",
    " \n",
    "#             noise = torch.randn_like(image_input).float()  \n",
    "#             perturbed_data = sqrt_alphas_cumprod[t].reshape(t.shape[0],1,1,1).float() * image_input + sqrt_one_minus_alphas_cumprod[t].reshape(t.shape[0],1,1,1).float() * noise\n",
    "            \n",
    "            x = x.float()\n",
    "            pred_noise = model_mu(x,t)\n",
    "            logvar = model_logvar(x,t)\n",
    "            print('pred_noise',torch.max(pred_noise))\n",
    "#             logvar = torch.clamp(logvar,min=-4.)\n",
    "            print('logvar',torch.max(logvar),torch.min(logvar))\n",
    "            var =10*torch.exp(logvar)\n",
    "            print('var',torch.max(var),torch.min(var))\n",
    "\n",
    "            \n",
    "    \n",
    "    \n",
    "            betas = torch.tensor(s_list)\n",
    "#             betas =s_list\n",
    "#             betas = var\n",
    "#             betas_list.append(betas)\n",
    "#             alphas = torch.tensor(1.) - betas\n",
    "            alphas = 1. - betas\n",
    "#             alphas = [1 -  number for number in betas]\n",
    "#             print(alphas.type)\n",
    "#             alphas_tensor = torch.stack(alphas, dim=0)\n",
    "#             print(alphas_tensor)\n",
    "        \n",
    "            alphas_cumprod = torch.cumprod(alphas, axis=0)\n",
    "#             alphas_cumprod = torch.cumprod(alphas_tensor, dim=0)\n",
    "            alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.)\n",
    "\n",
    "            sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)\n",
    "            sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)\n",
    "            log_one_minus_alphas_cumprod = torch.log(1.0 - alphas_cumprod)\n",
    "\n",
    "            sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / alphas_cumprod)\n",
    "            sqrt_recipm1_alphas_cumprod = torch.sqrt(1.0 / alphas_cumprod - 1)\n",
    "            sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device)\n",
    "            sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device)\n",
    "\n",
    "    #         noise = torch.randn_like(images).float()        \n",
    "#             sqrt_alphas_cumprod_t = sqrt_alphas_cumprod[t].reshape(t.shape[0],1,1,1).float()\n",
    "#             sqrt_one_minus_alphas_cumprod_t = sqrt_one_minus_alphas_cumprod[t].reshape(t.shape[0],1,1,1).float()\n",
    "    \n",
    "    \n",
    "        \n",
    "    \n",
    "    \n",
    "    \n",
    "\n",
    "            \n",
    "#             x_recon = 1/torch.sqrt(alphas[i]).to(device)*img_test-betas[i].to(device)/(sqrt_one_minus_alphas_cumprod[i]).float()*torch.sqrt(alphas[i]).to(device)*pred_noise\n",
    "            \n",
    "            x_recon = 1/torch.sqrt(alphas[t]).to(device).view(-1, 1, 1, 1)*img_test-betas[t].to(device).view(-1, 1, 1, 1)/(sqrt_one_minus_alphas_cumprod[t].reshape(t.shape[0],1,1,1).float()*torch.sqrt(alphas[t]).to(device).view(-1, 1, 1, 1))*pred_noise\n",
    "#             x_recon = torch.clamp(x_recon, min=-1., max=1.)\n",
    "            x_recon = torch.clamp(x_recon, min=-1., max=1.)\n",
    "            img_test = x_recon.float()\n",
    "        #\n",
    "            \n",
    "            \n",
    "\n",
    "#             std = (sqrt_one_minus_alphas_cumprod[t]**2).reshape(t.shape[0],1,1,1).float().to(device)\n",
    "\n",
    "            duad_fn = -(x-x_recon)\n",
    "\n",
    "        \n",
    "            x = x*(1+var/2)+duad_fn\n",
    "    \n",
    "            x = torch.clamp(x, min=-1., max=1.)\n",
    "#             img_test = x.float()\n",
    "#             x_list.append(x)\n",
    "#             print(x.shape)\n",
    "#             x_recon=1\n",
    "\n",
    "from scipy.ndimage import label\n",
    "\n",
    "import copy\n",
    "from kornia.filters import gaussian_blur2d\n",
    "from utilize.utilize import normalize, fix_seeds, compute_pro\n",
    "from sklearn.metrics import roc_auc_score, precision_recall_curve, average_precision_score\n",
    "# import timm\n",
    "weight_dtype = torch.float32\n",
    "device = torch.device(\"cuda\")\n",
    "dino_model = get_vit_encoder(vit_arch=\"vit_base\", vit_model=\"dino\", vit_patch_size=8, enc_type_feats=None).to(device, dtype=weight_dtype)\n",
    "# dino_model = nn.DataParallel(dino_model)\n",
    "# state_dict_dino = torch.load('/home/jovyan/cvpr/dino_pill_14.pth')\n",
    "# dino_model.load_state_dict(state_dict_dino)\n",
    "dino_model.eval()\n",
    "dino_frozen = copy.deepcopy(dino_model)\n",
    "\n",
    "\n",
    "# pretrained_model = timm.create_model(\"resnet50\", pretrained=True, features_only=True)\n",
    "# pretrained_model = pretrained_model.cuda()\n",
    "# pretrained_model.eval()\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "# file_path = f\"{category}.txt\"\n",
    "\n",
    "\n",
    "\n",
    "from skopt import gp_minimize\n",
    "from skopt.space import Real\n",
    "import numpy as np\n",
    "import torch\n",
    "\n",
    "with torch.no_grad():\n",
    "    for input in dataloader:\n",
    "        image_input = input['jpg'].to(device)\n",
    "        anomaly_mask = input['mask'].to(device)\n",
    "        z =3    #99\n",
    "        t = torch.full((image_input.shape[0],), z, device=device)    \n",
    "        noise = torch.randn_like(image_input).float()  \n",
    "\n",
    "        x,x_recon,var,x_t = rec(image_input,z+1)\n",
    "\n",
    "        \n",
    "\n",
    "        reconstruct_images =x     \n",
    "#         image_input_dino = image_input[i].unsqueeze(dim=0)\n",
    "\n",
    "#         transform1 = transforms.Compose([\n",
    "#             transforms.Lambda(lambda t: (t + 1) / (2)),\n",
    "#             transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\n",
    "#         ])\n",
    "#         image_input2 = transform1(image_input)\n",
    "#         reconstruct_images2 = transform1(reconstruct_images)\n",
    "\n",
    "        _, patch_tokens_i = dino_model(image_input.to(dtype=weight_dtype))\n",
    "        _, patch_tokens_r = dino_model(reconstruct_images.to(dtype=weight_dtype))\n",
    "        sigma = 6#6\n",
    "        kernel_size = 2 * int(4 * sigma + 0.5) + 1\n",
    "        b, n, c = patch_tokens_i[0][:, 1:, :].shape\n",
    "        h = int(n ** 0.5)\n",
    "        anomaly_maps_cos = torch.zeros((b, 1, 256, 256)).to(device)\n",
    "        anomaly_maps2 = torch.zeros((b, 1, 256, 256)).to(device)\n",
    "        for idx in range(len(patch_tokens_i)):\n",
    "            pi = patch_tokens_i[idx][:, 1:, :]\n",
    "            pr = patch_tokens_r[idx][:, 1:, :]\n",
    "\n",
    "            pi = pi / torch.norm(pi, p=2, dim=-1, keepdim=True)\n",
    "            pr = pr / torch.norm(pr, p=2, dim=-1, keepdim=True)\n",
    "\n",
    "            cos0 = torch.bmm(pi, pr.permute(0, 2, 1))\n",
    "\n",
    "            anomaly_map_cos, _ = torch.min(1 - cos0, dim=-1)\n",
    "            anomaly_map_cos = F.interpolate(anomaly_map_cos.reshape(-1, 1, h, h), size=256, mode='bilinear', align_corners=True)\n",
    "            anomaly_maps_cos += anomaly_map_cos      #(128,1,256,256)\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "        \n",
    "\n",
    "\n",
    "        mean_q = x  \n",
    "        std_q = torch.sqrt(var)\n",
    "        mean_p = image_input\n",
    "\n",
    "        t_kl1 = torch.full((image_input.shape[0],), 0, device=device)\n",
    "        std_p = torch.sqrt(torch.exp(model_logvar(image_input,t_kl1)))\n",
    "    #     print(std_p.shape)\n",
    "        kl_div_c = torch.log(std_q / std_p) + (std_p**2 + (mean_p - mean_q)**2) / (2 * std_q**2) - 0.5\n",
    "#         kl_map=kl_div_c\n",
    "\n",
    "    #     print(kl_div_c.shape)\n",
    "\n",
    "#         kl_map = F.interpolate(kl_div_c.resize(1,1,32,32), size=256, mode='bilinear', align_corners=True)\n",
    "        kl_map = 0\n",
    "        kl_map += kl_div_c\n",
    "        anomaly_kl = gaussian_blur2d(\n",
    "        kl_map , kernel_size=(kernel_size,kernel_size), sigma=(sigma,sigma)\n",
    "        )\n",
    "        anomaly_kl = torch.sum(anomaly_kl, dim=1)\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "        anomaly_maps_uncer = (x-image_input)**2/(var) #([1, 3, 256, 256])\n",
    "#         anomaly_maps_uncer = (x0-images0)**2/(var) #([1, 3, 256, 256])        \n",
    "\n",
    "        distance_map = torch.mean(((x-image_input)**2/(var)), dim=1).unsqueeze(1)\n",
    "#         distance_map = torch.mean(((x0-images0)**2/var), dim=1).unsqueeze(1)\n",
    "#         print(torch.max(distance_map[0]),torch.min(distance_map[0]),torch.mean(distance_map[0]))\n",
    "#         print('anomaly_maps_uncer',anomaly_maps_uncer.shape)\n",
    "#         anomaly_maps_uncer1 =torch.mean(anomaly_maps_uncer,dim=1)\n",
    "        anomaly_uncer = 0\n",
    "        anomaly_uncer += distance_map\n",
    "\n",
    "#         anomaly_uncer = v*anomaly_uncer + u*anomaly_maps_cos\n",
    "#         print('anomaly_uncer',anomaly_uncer.shape)\n",
    "        anomaly_uncer = gaussian_blur2d(\n",
    "        anomaly_uncer , kernel_size=(kernel_size,kernel_size), sigma=(sigma,sigma)\n",
    "        )\n",
    "        anomaly_uncer = torch.sum(anomaly_uncer, dim=1)\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "        anomaly_maps1_cos = gaussian_blur2d(anomaly_maps_cos, kernel_size=(kernel_size, kernel_size), sigma=(sigma, sigma))[:, 0]\n",
    "#         anomaly_maps1_cos = gaussian_blur2d(anomaly_map, kernel_size=(kernel_size, kernel_size), sigma=(sigma, sigma))[:, 0]\n",
    "#         print(anomaly_maps1_cos.shape)\n",
    "\n",
    "def objective(params):\n",
    "    \"\"\"Return the negated image-level AUROC for one (u, v, w) weighting.\n",
    "\n",
    "    The weights blend three anomaly maps read from notebook globals:\n",
    "    anomaly_maps1_cos (rescaled by max(anomaly_uncer) / max(anomaly_maps1_cos)),\n",
    "    anomaly_uncer and anomaly_kl. Each image is scored by the mean of the\n",
    "    250 largest values of its blended map; labels come from input[\"label_01\"].\n",
    "\n",
    "    NOTE(review): the maps and `input` are globals, so every call sees only\n",
    "    what they held when the preceding code finished (presumably the last\n",
    "    batch of the evaluation loop) -- confirm this is the intended test set.\n",
    "\n",
    "    Args:\n",
    "        params: sequence of three numbers, unpacked as (u, v, w).\n",
    "\n",
    "    Returns:\n",
    "        The AUROC multiplied by -1, so that a minimizer maximizes AUROC.\n",
    "    \"\"\"\n",
    "    u, v, w = params\n",
    "\n",
    "    # Rescale the cosine map so its maximum matches the uncertainty map's maximum.\n",
    "    cos_rescale = torch.max(anomaly_uncer) / torch.max(anomaly_maps1_cos)\n",
    "    anomaly_maps = u * cos_rescale * anomaly_maps1_cos + v * anomaly_uncer + w * anomaly_kl\n",
    "\n",
    "    # Image-level score: mean of the 250 strongest responses per map.\n",
    "    score = torch.topk(torch.flatten(anomaly_maps, start_dim=1), 250)[0].mean(dim=1)\n",
    "\n",
    "    # Only image-level AUROC is computed, so the per-pixel preds/masks arrays\n",
    "    # (previously copied to the host on every call and never read) are gone.\n",
    "    scores = normalize(np.array(score.cpu().numpy()))\n",
    "    labels = np.array(input[\"label_01\"].cpu().numpy())\n",
    "\n",
    "    auroc_image = roc_auc_score(labels, scores)\n",
    "    print(f\"test-------- I-AUROC:{round(auroc_image*100, 1)}-----\")\n",
    "\n",
    "    return -auroc_image\n",
    "\n",
    "# Search ranges for the three blending weights unpacked by objective().\n",
    "space = [\n",
    "    Real(low, high, name=weight)\n",
    "    for weight, low, high in (('u', 0, 100), ('v', 0.0001, 10000), ('w', 0, 10000))\n",
    "]\n",
    "\n",
    "# 100 seeded evaluations of the objective.\n",
    "result = gp_minimize(objective, space, n_calls=100, random_state=42)\n",
    "\n",
    "# objective() returns -AUROC, so negate the optimum to recover the AUROC.\n",
    "best_auroc = -result.fun\n",
    "best_u, best_v, best_w = result.x\n",
    "\n",
    "print(f\"Best AUROC: {best_auroc}, Best u: {best_u}, Best v: {best_v}, Best w: {best_w}\")\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
