{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from Models.hicbridge  import Unet, GaussianDiffusion\n",
    "from Data.GM12878_DataModule import GM12878Module\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import os \n",
    "os.environ[\"CUDA_DEVICE_ORDER\"]=\"PCI_BUS_ID\"\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"0\"\n",
    "\n",
    "model_ddpm = Unet(\n",
    "    dim = 64,\n",
    "    dim_mults = (1, 1, 2, 2, 4, 4),\n",
    "    channels = 1,\n",
    "    self_condition= True\n",
    ")\n",
    "\n",
    "diffusion = GaussianDiffusion(\n",
    "    model_ddpm,\n",
    "    image_size = 256,\n",
    "    beta_schedule = 'linear',\n",
    "    timesteps = 1000,   # number of steps\n",
    "    indi = False,\n",
    "    objective = 'pred_noise',\n",
    "    noise_schedule = 'null',\n",
    "    indi_step_size = 1000,\n",
    "    loss_type = 'l1'\n",
    ")\n",
    "\n",
    "model_ddpm.load_state_dict(torch.load(\"Trained_Models/ddpm_inverse_500.ckpt\", map_location=torch.device('cpu') ))\n",
    "devices = 'cuda:0'\n",
    "\n",
    "model_ddpm.to(device= devices)\n",
    "diffusion.to(device = devices)\n",
    "\n",
    "\n",
    "dm_test = GM12878Module(batch_size= 32, res=10000, piece_size=256)\n",
    "dm_test.prepare_data()\n",
    "\n",
    "train_chrom = [1,3,5,6,7,9,11,12,13,15,17,18,19,21]\n",
    "\n",
    "with torch.no_grad():\n",
    "    for chr in train_chrom:\n",
    "        dm_test.setup(stage=chr)\n",
    "\n",
    "        print(len(dm_test.test_dataloader()))\n",
    "\n",
    "        pred_x1s = []\n",
    "\n",
    "        for i,data in enumerate(dm_test.test_dataloader()):\n",
    "\n",
    "            x0 = data[1].to(devices)\n",
    "            \n",
    "            pred_x1 = diffusion.accelated_sample(num_timesteps= 500, condition = x0).detach().cpu()\n",
    "            pred_x1 = torch.triu(pred_x1) + torch.triu(pred_x1, 1).permute(0,1,3,2)            \n",
    "            pred_x1 = pred_x1.numpy()\n",
    "\n",
    "            if i == 0:\n",
    "                pred_x1s = pred_x1\n",
    "            else:\n",
    "                pred_x1s = np.vstack([pred_x1s, pred_x1])\n",
    "                \n",
    "            print(pred_x1.shape, pred_x1s.shape)\n",
    "            \n",
    "        print(pred_x1s.shape)\n",
    "        np.save(\"Data/Splits/gm12878_ddpm_more_chr_\"+str(chr)+\"_res_10000_piece_256\", pred_x1s)\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "JaeminKim",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
