{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "3baaf58f",
   "metadata": {},
   "source": [
    "# Unbalanced Barycenters: Testing Idea"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "0cb891b4",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = '0'\n",
    "\n",
    "import torch\n",
    "import torchvision\n",
    "from torch.utils.data import DataLoader, TensorDataset\n",
    "\n",
    "import numpy as np\n",
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "\n",
    "import wandb\n",
    "import random\n",
    "import itertools \n",
    "from tqdm import tqdm\n",
    "\n",
    "import sys \n",
    "from src.data import load_dataset, DatasetSampler\n",
    "from src.utils import weights_init_D,freeze, unfreeze, middle_rgb\n",
    "from src.cost import cost_image_shape_latent, cost_image_color_latent  \n",
    "from src.models import Stochastic_ResNet_D\n",
    "\n",
    "# path to StyleGAN\n",
    "import dnnlib\n",
    "import legacy"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "01d2d890",
   "metadata": {},
   "source": [
    "## 1. Config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "e88941ee",
   "metadata": {},
   "outputs": [],
   "source": [
    "DEVICE = 'cuda'\n",
    "# path to styleGAN\n",
    "GENERATOR_PATH = \"pretrained_models/colored_mnist.pkl\"\n",
    "PATH_TO_DATA =  \"../../data/train_MNIST\"\n",
    "NUMBER_PALETTES = [5_000,5_000,100]\n",
    "HUE_MEANS = [0,120,0]\n",
    "HUE_STD = 0.0\n",
    "SATURATIONS = [1,1,0] # from 0 to 1: green, red and white\n",
    "BRIGHTNESS = 1 # from 0 to 1\n",
    "SATURATION_THRESHOLD = 0.8 # from 0 to 1\n",
    "\n",
    "IMG_SIZE =32\n",
    "NC =1\n",
    "LATENT_SIZE=512\n",
    "K=2\n",
    "\n",
    "LR_ENCODER = 1e-7\n",
    "LR_POTENTIAL =1e-4\n",
    "LR_MVALUE = 1e-2\n",
    "BETAS = (0.0, 0.9)\n",
    "\n",
    "NUM_EPOCHS = 3000\n",
    "INNER_ITERATIONS = 10\n",
    "LAMBDAS = [0.5,0.5]\n",
    "TAU = 10\n",
    "BATCH_SIZE = 64"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0250d32e",
   "metadata": {},
   "source": [
    "## 2. Data"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "70fde470",
   "metadata": {},
   "source": [
    "### 2.1 Shape Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "299297ee",
   "metadata": {},
   "outputs": [],
   "source": [
    "output = load_dataset(\"MNIST_2_3_7\", PATH_TO_DATA, img_size=32, batch_size=64,\n",
    "                                 shuffle=True, device='cpu')\n",
    "\n",
    "train_set_shape = output[0]\n",
    "test_set_shape = output[1]\n",
    "train_sampler_shape = output[2]\n",
    "test_sampler_shape = output[3]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "de13d920",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"Train size of 2 digits: \", len(torch.nonzero(train_set_shape[:][1] == 0)))\n",
    "print(\"Train size of 3 digits : \", len(torch.nonzero(train_set_shape[:][1] == 1)))\n",
    "print(\"Train size of 7 (outliers)  : \", len(torch.nonzero(train_set_shape[:][1] == 2)))\n",
    "print(\"\\n\")\n",
    "print(\"Test size of 2 digits : \", len(torch.nonzero(test_set_shape[:][1] == 0)))\n",
    "print(\"Test size of 3 digits : \", len(torch.nonzero(test_set_shape[:][1] == 1)))\n",
    "print(\"Test size of 7 (outliers) : \", len(torch.nonzero(test_set_shape[:][1] == 2)))\n",
    "# min and max of images\n",
    "torch.min(train_set_shape[:][0]),torch.max(train_set_shape[:][0])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a5efc879",
   "metadata": {},
   "source": [
    "### 2.2 Color Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4b5af72d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# initialize hue spectr for green color, satiration and brightness\n",
    "hsv_data = []\n",
    "for mean, nums , sats in zip(HUE_MEANS, NUMBER_PALETTES , SATURATIONS):\n",
    "    hue_vectors =  mean + np.random.randn(nums)*HUE_STD # shape:( NUMBER_PALETTES, )\n",
    "    saturation_vectors = sats*np.ones(nums) # shape:( NUMBER_PALETTES, )\n",
    "    brightness_vectors = BRIGHTNESS*np.ones(nums) # shape:( NUMBER_PALETTES, )\n",
    "    hsv_vectors = np.stack([hue_vectors.reshape(-1,1),\n",
    "                        saturation_vectors.reshape(-1,1),\n",
    "                        brightness_vectors.reshape(-1,1)],axis=1).reshape(-1, 3)# shape:(NUMBER_PALETTES,3)\n",
    "\n",
    "    # translate HSV -> RGB \n",
    "    # Importantly: now Hue from 0 to 360 and we translate it from 0 to 1\n",
    "    hsv_vectors[:,0] = hsv_vectors[:,0]/360\n",
    "    print(\"len of colors\", len(hsv_vectors))\n",
    "    hsv_data.append(hsv_vectors)\n",
    "     \n",
    "# we use matplotlib function : https://matplotlib.org/stable/api/_as_gen/matplotlib.colors.hsv_to_rgb.html \n",
    "rgb_dataset = matplotlib.colors.hsv_to_rgb(np.concatenate(hsv_data,axis=0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "9fd6fe73",
   "metadata": {},
   "outputs": [],
   "source": [
    "color_sampler = DatasetSampler(rgb_dataset,flag_label=False,batch_size=256)\n",
    "# min and max of colors\n",
    "np.min(rgb_dataset), np.max(rgb_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "8be4eb1a",
   "metadata": {},
   "outputs": [],
   "source": [
    "data_samplers = []\n",
    "data_samplers.append(train_sampler_shape)\n",
    "data_samplers.append(color_sampler)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "45ca7e79",
   "metadata": {},
   "source": [
    "## 3. Style GAN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "a2ab64dd",
   "metadata": {},
   "outputs": [],
   "source": [
    "with dnnlib.util.open_url(GENERATOR_PATH) as f:\n",
    "    G =  legacy.load_network_pkl(f)['G_ema'].to(DEVICE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "3d4f6c84",
   "metadata": {},
   "outputs": [],
   "source": [
    "generator = lambda x, c: 0.5 * (G(x,c) + 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "358f2a83",
   "metadata": {},
   "outputs": [],
   "source": [
    "z = torch.randn((1, 512)).to(DEVICE)\n",
    "image = generator(z, None)\n",
    "# print(image.max(), image.min())\n",
    "# plt.imshow(image[0].permute(1,2,0).cpu().numpy())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "23dd03d9",
   "metadata": {},
   "outputs": [],
   "source": [
    "def normalize_out_to_0_1(x):\n",
    "    #assert torch.min(x) < -0.5\n",
    "    return torch.clip(x,0,1)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "71018fa1",
   "metadata": {},
   "source": [
    "## 4. Models"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f89f5bf4",
   "metadata": {},
   "source": [
    "### 4.1 Maps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cba9a5a8",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "encoder = []\n",
    "encoder.append(Stochastic_ResNet_D(size=IMG_SIZE,\n",
    "                  nc=NC,\n",
    "                  nfilter=64, \n",
    "                  nfilter_max=512, \n",
    "                  res_ratio=0.1,\n",
    "                  noise=True,\n",
    "                  n_output=LATENT_SIZE,bn_flag=True,pn_flag=True).to(DEVICE))\n",
    "\n",
    "for f in encoder: \n",
    "    weights_init_D(f)\n",
    "\n",
    "\n",
    "class linear_model(torch.nn.Module):\n",
    "    def __init__(self, in_=3, nz=10, hidden=[128,256,512], out_=512):\n",
    "        super().__init__()\n",
    "        model= []\n",
    "        hidden = [in_+nz] + hidden + [out_]\n",
    "        for ins,outs in zip(hidden[:-1],hidden[1:]):\n",
    "            model.append(torch.nn.Linear(ins,outs,bias=True))\n",
    "            model.append(torch.nn.ReLU())\n",
    "        model.pop()\n",
    "        model = torch.nn.Sequential(*model)\n",
    "        self.all_model = model\n",
    "        self.nz = nz\n",
    "\n",
    "    def forward(self, x):\n",
    "        z = torch.randn((len(x), self.nz), device=x.device)\n",
    "        input = torch.cat([x,z], dim=1)\n",
    "        return self.all_model(input)\n",
    "\n",
    "\n",
    "encoder.append(linear_model(nz=10).to(DEVICE))\n",
    "param_enc = [net.parameters() for net in encoder]\n",
    "encoder_opt = torch.optim.Adam(  itertools.chain(*param_enc),\n",
    "                                  LR_ENCODER, betas=BETAS)\n",
    "    "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "062c2dea",
   "metadata": {},
   "source": [
    "### 4.2 Potentials"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "fadbed70",
   "metadata": {},
   "outputs": [],
   "source": [
    "nets_for_pot = [Stochastic_ResNet_D(size=IMG_SIZE,\n",
    "                  nc=3,\n",
    "                  nfilter=64, \n",
    "                  nfilter_max=512, \n",
    "                  res_ratio=0.1,\n",
    "                  noise=False,\n",
    "                  n_output=1,bn_flag=False,pn_flag=False).to(DEVICE)\n",
    "                  ]\n",
    "\n",
    "\n",
    "for f in nets_for_pot: \n",
    "    weights_init_D(f)\n",
    "    \n",
    "nets_for_pot_opt = torch.optim.Adam( nets_for_pot[0].parameters(),\n",
    "                               LR_POTENTIAL, betas=BETAS)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "77a75992",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Pots(torch.nn.Module):\n",
    "    # TODO: optimize when 2 potentials\n",
    "    def __init__(self, bary_weights):\n",
    "        assert len(bary_weights) > 1\n",
    "        super().__init__()\n",
    "        self._lambdas = bary_weights\n",
    "        self._net = nets_for_pot[0]\n",
    "        \n",
    "    def __getitem__(self, idx):\n",
    "        assert 0 <= idx < 2 # only when there are two prob\n",
    "        \n",
    "        if idx == 0:\n",
    "            def f_pot(x, m): # include m\n",
    "                res = 0.0\n",
    "                res += self._net(x)\n",
    "                res += m / len(self._lambdas) / self._lambdas[idx] # include m\n",
    "                return res\n",
    "        else:\n",
    "            def f_pot(x, m): # include m\n",
    "                res = 0.0\n",
    "                res -= self._net(x)\n",
    "                res += m / len(self._lambdas) / self._lambdas[idx] # include m\n",
    "                return res\n",
    "\n",
    "        return f_pot"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "be8e2429",
   "metadata": {},
   "outputs": [],
   "source": [
    "potentials = Pots(LAMBDAS).to(DEVICE)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e3007ba7",
   "metadata": {},
   "source": [
    "### 4.3 MinValue"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "e1d21dd3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# add this class for m\n",
    "class MinValue(torch.nn.Module):\n",
    "    def __init__(self, device):\n",
    "        super().__init__()\n",
    "        self.m = torch.nn.Parameter(torch.zeros(1).to(device)+0)\n",
    "\n",
    "    def forward(self):\n",
    "        return self.m\n",
    "\n",
    "mvalue = MinValue(DEVICE)\n",
    "mvalue_opt = torch.optim.Adam(mvalue.parameters(), LR_MVALUE, (0, 0.9))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cce2044c",
   "metadata": {},
   "source": [
    "## 5. Train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "44895dab",
   "metadata": {},
   "outputs": [],
   "source": [
    "wandb.login()\n",
    "wandb.init(project=\"Unbalanced\" ,name=\"Exp-\" )\n",
    "epoch = 0"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9ea641ec",
   "metadata": {},
   "source": [
    "### 5.1 Training Algorithm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "97994d39",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.nn.functional as F\n",
    "from tqdm import tqdm\n",
    "\n",
    "\n",
    "# base loop for training\n",
    "start_epoch = epoch\n",
    "\n",
    "for epoch in tqdm(range(start_epoch, NUM_EPOCHS)):\n",
    "    freeze(nets_for_pot[0])       \n",
    "    for k in range(K):\n",
    "        unfreeze(encoder[k])\n",
    "    freeze(mvalue)\n",
    "\n",
    "\n",
    "    for it in range(INNER_ITERATIONS):\n",
    "\n",
    "        loss = 0\n",
    "        encoder_opt.zero_grad()\n",
    "\n",
    "        # sample data from each distribution\n",
    "        data = [data_samplers[k].sample(BATCH_SIZE).to(DEVICE) for k in range(K)]\n",
    "            \n",
    "\n",
    "        for k in range(K):            \n",
    "            z_k = encoder[k](data[k])\n",
    "            mapped_k = normalize_out_to_0_1(generator(z_k, c=None))\n",
    "\n",
    "            ##======= calculation cost ========##\n",
    "            if k == 0:\n",
    "                cost_k =  cost_image_shape_latent(data[k], mapped_k).mean(dim=0) #[B,1] \n",
    "            elif k == 1:\n",
    "                cost_k =  100 * cost_image_color_latent(data[k], mapped_k).mean(dim=0)   #[B,1]\n",
    "            else:\n",
    "                raise ValueError\n",
    "                \n",
    "            if it ==  INNER_ITERATIONS - 1:\n",
    "                wandb.log({f\"cost of the {k} barycenter\": cost_k.item()}, step=epoch)\n",
    "            ##=================================##\n",
    " \n",
    "            ##====== integral of potential  ======##\n",
    "            m = mvalue()\n",
    "            pot_k = potentials[k](mapped_k, m).mean()  # [B,1].mean()\n",
    "            if it ==  INNER_ITERATIONS-1:\n",
    "                wandb.log({f\"potential of the {k} barycenter\": pot_k.item()}, step=epoch)\n",
    "            ##===================================##\n",
    "            loss_k = LAMBDAS[k]*(cost_k - pot_k) \n",
    "            loss +=  loss_k\n",
    "            if it == INNER_ITERATIONS - 1:\n",
    "                wandb.log({f\"loss of the {k} barycenter\":loss_k.item()},\n",
    "                          step=epoch)\n",
    "\n",
    "        loss.backward()\n",
    "        encoder_opt.step()\n",
    "\n",
    "    wandb.log({f\"loss of inner problem\": loss.item()},\n",
    "              step=epoch)\n",
    "\n",
    "    #================================#\n",
    "    #===========  Outer  ============#\n",
    "    #================================#\n",
    " \n",
    "    # Outer optimization problem\n",
    "    # training OT potential \n",
    "    nets_for_pot_opt.zero_grad()\n",
    "    mvalue_opt.zero_grad()\n",
    "    unfreeze(mvalue)\n",
    "    # unfreezing of potentials \n",
    "    unfreeze(nets_for_pot[0])\n",
    "    for k in range( K):\n",
    "        freeze(encoder[k])\n",
    "\n",
    "\n",
    "    m = mvalue()\n",
    "    loss = 0\n",
    "    for k in range(K):\n",
    "        ##======= get_latent_code ==========##\n",
    "        with torch.no_grad():\n",
    "            z_k = encoder[k](data[k])\n",
    "        ##====================================## \n",
    "\n",
    "        with torch.no_grad():\n",
    "            mapped_k = normalize_out_to_0_1(generator(z_k, c=None) )\n",
    "            \n",
    "        ##======= calculation cost ========##\n",
    "        if k == 0:\n",
    "            cost_k = cost_image_shape_latent(data[k], mapped_k)  #[B,1] \n",
    "        elif k ==1:\n",
    "            cost_k = 100 * cost_image_color_latent(data[k], mapped_k)\n",
    "        else:\n",
    "            raise ValueError\n",
    "\n",
    "        assert cost_k.requires_grad == False\n",
    "        ##=================================##\n",
    "\n",
    "\n",
    "        ##====== integral of potential  ======##\n",
    "        pot_k = potentials[k](mapped_k, m)  # [B,1] \n",
    "        # assert pot_k.requires_grad == True\n",
    "        ##===================================##\n",
    "        cost_k = cost_k - pot_k\n",
    "        cost_k = - 2 * TAU * (F.softplus(-cost_k/TAU) - F.softplus(cost_k*0)).mean()\n",
    "        loss_k = LAMBDAS[k]*(cost_k + m) \n",
    "        loss +=  loss_k\n",
    "\n",
    "\n",
    "    loss = -1*loss \n",
    "    loss.backward()\n",
    "    nets_for_pot_opt.step()\n",
    "    mvalue_opt.step()\n",
    "    mvalue_opt.zero_grad()\n",
    "\n",
    "    wandb.log({f\"m value\": m.item()}, step=epoch)\n",
    "\n",
    "\n",
    "\n",
    "    ##===== plotting results =====##\n",
    "\n",
    "    if epoch % 200 == 0:\n",
    "        for k in range( K):\n",
    "            freeze(encoder[k])\n",
    "\n",
    "        data = [data_samplers[k].sample(BATCH_SIZE).to(DEVICE)\n",
    "                    for k in range(K)]\n",
    "\n",
    "        for k in range( K):\n",
    "\n",
    "            with torch.no_grad():\n",
    "                z_k = encoder[k](data[k][:8])\n",
    "            \n",
    "            fig,ax = plt.subplots(7,8,figsize=(8,7),dpi=200)\n",
    "\n",
    "            for idx in range(8):\n",
    "\n",
    "                if k == 0:\n",
    "                    ax[0,idx].imshow(data[k][:8][idx].permute(1,2,0).cpu(),\n",
    "                                 cmap = 'gray' if k==0 else None)\n",
    "                if k == 1:\n",
    "                    for idx in range(8):\n",
    "                        ax[0,idx].set_aspect( 1 ) \n",
    "                        ax[0,idx].add_artist(plt.Circle(( 0.5 , 0.5 ), 0.4 ,color=data[k][:8][idx].cpu().numpy() ) ) \n",
    "                        ax[0,idx].set_xticks([]);ax[0,idx].set_yticks([]);\n",
    "\n",
    "\n",
    "            for run in range(1,6):\n",
    "                with torch.no_grad():\n",
    "                    mapped_k  = normalize_out_to_0_1(generator(z_k,c=None))#[8,3,64,64]\n",
    "\n",
    "                for idx in range(8):\n",
    "                    ax[run,idx].imshow(mapped_k[idx].detach().permute(1,2,0).cpu())\n",
    "\n",
    "\n",
    "            clr = middle_rgb(mapped_k,  SATURATION_THRESHOLD )\n",
    "            for idx in range(8):\n",
    "\n",
    "                if k == 0:\n",
    "                    ax[6,idx].imshow(data[k][:8][idx].permute(1,2,0).cpu(),\n",
    "                                 cmap = 'gray' if k==0 else None)\n",
    "                if k == 1:\n",
    "                    ax[6,idx].set_aspect( 1 ) \n",
    "                    ax[6,idx].add_artist(plt.Circle(( 0.5 , 0.5 ), 0.4 ,color=clr[idx].cpu().numpy() ) ) \n",
    "                    ax[6,idx].set_xticks([]);ax[0,idx].set_yticks([]);\n",
    "\n",
    "\n",
    "            for i in range(7):\n",
    "                for j in range(8):\n",
    "                    ax[i,j].set_xticks([]);ax[i,j].set_yticks([]);\n",
    "\n",
    "            fig.tight_layout(pad=0.01)\n",
    "\n",
    "            # wandb.log({f\"Barycenter Images {k} \" + \"unfixed\" + \" of distributions\":fig},step=epoch)\n",
    "            plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "6342963a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# save\n",
    "import os\n",
    "save_path = f'trained_uotbary/TAU_{TAU}_iter{NUM_EPOCHS}'\n",
    "os.makedirs(save_path, exist_ok=True)\n",
    "\n",
    "torch.save(encoder[0], os.path.join(save_path,'encoder0.pth'))\n",
    "torch.save(encoder[1], os.path.join(save_path,'encoder1.pth'))\n",
    "torch.save(nets_for_pot[0], os.path.join(save_path,'pot.pth'))\n",
    "torch.save(mvalue, os.path.join(save_path,'mvalue.pth'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6785d2c7",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
