{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "13f6918f-167f-4b79-9db7-3b6bc03843f1",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "sys.path.append(\"..\")\n",
    "sys.path.append(\"../ALAE\")\n",
    "\n",
    "import torch\n",
    "import numpy as np\n",
    "from src.light_sb_ou import LightSB_OU\n",
    "from src.distributions import TensorSampler\n",
    "from alae_ffhq_inference import load_model, decode\n",
    "from PIL import Image\n",
    "from tqdm import tqdm\n",
    "import time\n",
    "import glob\n",
    "import torch.nn as nn\n",
    "from scipy import linalg\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "import torchvision.transforms as transforms\n",
    "import warnings\n",
    "warnings.filterwarnings('ignore')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "c6dd930c-ad3c-468d-ad1a-56f6ce1a4711",
   "metadata": {},
   "outputs": [],
   "source": [
    "def setup_consistent_evaluation():\n",
    "    EVAL_SEED = 42  \n",
    "    torch.manual_seed(EVAL_SEED)\n",
    "    np.random.seed(EVAL_SEED)\n",
    "    if torch.cuda.is_available():\n",
    "        torch.cuda.manual_seed(EVAL_SEED)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "5643d54e-2f09-4884-9f4f-572529ce0b55",
   "metadata": {},
   "outputs": [],
   "source": [
    "setup_consistent_evaluation()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d385a9eb-c630-454c-81c4-d6841904adec",
   "metadata": {},
   "source": [
    "# Config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "a757c66f-fc40-4ee1-a806-21d5223a8a66",
   "metadata": {},
   "outputs": [],
   "source": [
    "BASE_DIR = r\"F:\\data\"\n",
    "DIM = 512\n",
    "\n",
    "fid_real_dir = os.path.join(BASE_DIR, \"fid_real_1000\")\n",
    "fid_generated_dir = os.path.join(BASE_DIR, \"fid_generated_1000\")\n",
    "os.makedirs(fid_real_dir, exist_ok=True)\n",
    "os.makedirs(fid_generated_dir, exist_ok=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3e3a2493-3381-43c2-8674-03f0551b4570",
   "metadata": {},
   "source": [
    "# Model Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "0fdd73a9-5ad1-42d9-95c6-83d29e582826",
   "metadata": {},
   "outputs": [],
   "source": [
    "alae_model = load_model(\"../ALAE/configs/ffhq.yaml\", training_artifacts_dir=\"../ALAE/training_artifacts/ffhq/\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "a4bddb3e-d46e-4287-beab-1bcd118a15e6",
   "metadata": {},
   "outputs": [],
   "source": [
    "latents = np.load(\"../data/latents.npy\")\n",
    "gender = np.load(\"../data/gender.npy\")\n",
    "age = np.load(\"../data/age.npy\")\n",
    "    \n",
    "train_size = 60000\n",
    "train_latents = latents[:train_size]\n",
    "train_gender = gender[:train_size]\n",
    "test_latents = latents[train_size:]\n",
    "test_gender = gender[train_size:]\n",
    "test_age = age[train_size:]\n",
    "    \n",
    "source_inds = np.arange(len(test_gender))[(test_gender == \"male\").reshape(-1)]\n",
    "source_latents = test_latents[source_inds] \n",
    "\n",
    "target_inds = np.arange(len(test_gender))[(test_gender == \"female\").reshape(-1)]\n",
    "target_latents = test_latents[target_inds]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "f579a3d5-6a7d-4f71-8bdc-93c562cf4bb9",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Filtering MAN (source) and WOMAN (target) for training...\n",
      "Training samples - MAN: 26732, WOMAN: 32816\n",
      "Starting training...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training:   0%|                                                                       | 7/5000 [00:00<01:19, 63.06it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Step 0: Loss = 18280.113281\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training:  20%|█████████████▋                                                     | 1020/5000 [00:09<00:37, 106.65it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Step 1000: Loss = 6344.626953\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training:  40%|███████████████████████████                                        | 2023/5000 [00:18<00:27, 109.32it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Step 2000: Loss = 4443.502930\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training:  60%|████████████████████████████████████████▍                          | 3015/5000 [00:27<00:18, 107.74it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Step 3000: Loss = 4204.577637\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training:  80%|█████████████████████████████████████████████████████▋             | 4009/5000 [00:37<00:09, 107.66it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Step 4000: Loss = 4080.543945\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training: 100%|███████████████████████████████████████████████████████████████████| 5000/5000 [00:46<00:00, 107.58it/s]\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "ParametrizedLightSB_OU(\n",
       "  (parametrizations): ModuleDict(\n",
       "    (S_rotation_matrix): ParametrizationList(\n",
       "      (0): Stiefel(n=512, k=512, tensorial_size=(10,), triv=linalg_matrix_exp)\n",
       "    )\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "print(\"Filtering MAN (source) and WOMAN (target) for training...\")\n",
    "man_inds = np.arange(train_size)[(train_gender == \"male\").reshape(-1)]\n",
    "woman_inds = np.arange(train_size)[(train_gender == \"female\").reshape(-1)]\n",
    "    \n",
    "X_train = torch.tensor(train_latents[man_inds])   \n",
    "Y_train = torch.tensor(train_latents[woman_inds])\n",
    "    \n",
    "print(f\"Training samples - MAN: {len(X_train)}, WOMAN: {len(Y_train)}\")\n",
    "    \n",
    "X_sampler = TensorSampler(X_train, device=\"cpu\")\n",
    "Y_sampler = TensorSampler(Y_train, device=\"cpu\")\n",
    "    \n",
    "D = LightSB_OU(\n",
    "    dim=DIM, \n",
    "    n_potentials=10, \n",
    "    epsilon=0.1, \n",
    "    b=0.02,\n",
    "    m=0.0,\n",
    "    sampling_batch_size=128, \n",
    "    S_diagonal_init=0.1,\n",
    "    is_diagonal=True\n",
    ").cpu()\n",
    "    \n",
    "D.init_r_by_samples(Y_sampler.sample(10))\n",
    "D_opt = torch.optim.Adam(D.parameters(), lr=1e-3)\n",
    "    \n",
    "print(\"Starting training...\")\n",
    "MAX_STEPS = 5000 \n",
    "    \n",
    "for step in tqdm(range(MAX_STEPS), desc=\"Training\"):\n",
    "    D_opt.zero_grad()\n",
    "    X0, X1 = X_sampler.sample(128), Y_sampler.sample(128)\n",
    "        \n",
    "    log_potential = D.get_log_potential(X1)\n",
    "    log_C = D.get_log_C(X0)\n",
    "        \n",
    "    D_loss = (-log_potential + log_C).mean()\n",
    "    D_loss.backward()\n",
    "    torch.nn.utils.clip_grad_norm_(D.parameters(), max_norm=float(\"inf\"))\n",
    "    D_opt.step()\n",
    "        \n",
    "    if step % 1000 == 0:\n",
    "        print(f\"Step {step}: Loss = {D_loss.item():.6f}\")\n",
    "    \n",
    "D.eval()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0fe23503-59b4-4d7e-8b36-147b9f26bd2b",
   "metadata": {},
   "source": [
    "# Generation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "0b35f716-846f-4880-9940-ab99fd280336",
   "metadata": {},
   "outputs": [],
   "source": [
    "total_images = 1000\n",
    "batch_size = 32"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "fa39851f-648d-43cb-818b-c5bd32f1fe3b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def decode_and_save_batch(model, latents, output_dir, start_idx, batch_size=8):\n",
    "    with torch.no_grad():\n",
    "        for i in range(0, len(latents), batch_size):\n",
    "            end_idx = min(i + batch_size, len(latents))\n",
    "            batch_latents = latents[i:end_idx]\n",
    "            \n",
    "            batch_images = decode(model, batch_latents)\n",
    "            batch_images = (batch_images * 0.5 + 0.5).clamp(0, 1)\n",
    "            \n",
    "            for j, img in enumerate(batch_images):\n",
    "                global_idx = start_idx + i + j\n",
    "                img_np = (img.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)\n",
    "                pil_img = Image.fromarray(img_np)\n",
    "                pil_img.save(os.path.join(output_dir, f\"image_{global_idx:05d}.png\"))\n",
    "            \n",
    "            del batch_images\n",
    "            if torch.cuda.is_available():\n",
    "                torch.cuda.empty_cache()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "3ffa3c95-f2e5-4718-a553-4be0ff30b934",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Generating 1000 REAL images (WOMAN)...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Real images: 100%|█████████████████████████████████████████████████████████████████████| 32/32 [22:52<00:00, 42.88s/it]\n"
     ]
    }
   ],
   "source": [
    "print(f\"\\nGenerating {total_images} REAL images (WOMAN)...\")\n",
    "real_count = 0\n",
    "for start_idx in tqdm(range(0, total_images, batch_size), desc=\"Real images\"):\n",
    "    end_idx = min(start_idx + batch_size, total_images)\n",
    "    batch_latents = torch.tensor(target_latents[start_idx:end_idx])\n",
    "    decode_and_save_batch(alae_model, batch_latents, fid_real_dir, start_idx, batch_size=8)\n",
    "    real_count += (end_idx - start_idx)\n",
    "        \n",
    "    del batch_latents\n",
    "    if torch.cuda.is_available():\n",
    "        torch.cuda.empty_cache()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "6e8e88b1-c96f-4d90-90d4-9bce6ac0a348",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Generating 1000 GENERATED images (MAN -> WOMAN)...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Generated images: 100%|████████████████████████████████████████████████████████████████| 32/32 [23:05<00:00, 43.30s/it]\n"
     ]
    }
   ],
   "source": [
    "print(f\"\\nGenerating {total_images} GENERATED images (MAN -> WOMAN)...\")\n",
    "generated_count = 0\n",
    "for start_idx in tqdm(range(0, total_images, batch_size), desc=\"Generated images\"):\n",
    "    end_idx = min(start_idx + batch_size, total_images)\n",
    "        \n",
    "    source_batch = torch.tensor(source_latents[start_idx:end_idx])\n",
    "        \n",
    "    with torch.no_grad():\n",
    "        transformed_batch = D(source_batch.cpu())\n",
    "        \n",
    "    decode_and_save_batch(alae_model, transformed_batch, fid_generated_dir, start_idx, batch_size=8)\n",
    "    generated_count += (end_idx - start_idx)\n",
    "        \n",
    "    del source_batch, transformed_batch\n",
    "    if torch.cuda.is_available():\n",
    "        torch.cuda.empty_cache()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b0b0aa11-b21e-4094-a7c1-59df76f968aa",
   "metadata": {},
   "source": [
    "# FID calculation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "c2c0181e-db88-4323-9fc9-ac267d266e20",
   "metadata": {},
   "outputs": [],
   "source": [
    "class SimpleFIDCalculator:\n",
    "    def __init__(self, device='cpu'):\n",
    "        self.device = device\n",
    "        self.transform = transforms.Compose([\n",
    "            transforms.Resize(299),\n",
    "            transforms.CenterCrop(299),\n",
    "            transforms.ToTensor(),\n",
    "            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])\n",
    "        ])\n",
    "        \n",
    "    def load_inception_model(self):\n",
    "        try:\n",
    "            import torchvision.models as models\n",
    "            \n",
    "            try:\n",
    "                model = models.inception_v3(weights='DEFAULT')\n",
    "            except:\n",
    "                model = models.inception_v3(pretrained=True)\n",
    "            \n",
    "            model.eval()\n",
    "            model.fc = nn.Identity()\n",
    "            return model.to(self.device)\n",
    "            \n",
    "        except Exception as e:\n",
    "            print(f\"Could not load Inception model: {e}\")\n",
    "            return None\n",
    "    \n",
    "    def extract_features(self, directory, batch_size=16):\n",
    "        model = self.load_inception_model()\n",
    "        if model is None:\n",
    "            return None\n",
    "            \n",
    "        class ImageDataset(Dataset):\n",
    "            def __init__(self, directory, transform):\n",
    "                self.image_files = [os.path.join(directory, f) for f in os.listdir(directory) \n",
    "                                  if f.lower().endswith(('.png', '.jpg', '.jpeg'))]\n",
    "                self.transform = transform\n",
    "            \n",
    "            def __len__(self):\n",
    "                return len(self.image_files)\n",
    "            \n",
    "            def __getitem__(self, idx):\n",
    "                try:\n",
    "                    image = Image.open(self.image_files[idx]).convert('RGB')\n",
    "                    return self.transform(image)\n",
    "                except Exception as e:\n",
    "                    print(f\"Error loading {self.image_files[idx]}: {e}\")\n",
    "                    return torch.zeros(3, 299, 299)\n",
    "        \n",
    "        dataset = ImageDataset(directory, self.transform)\n",
    "        if len(dataset) == 0:\n",
    "            print(f\"No images found in {directory}\")\n",
    "            return None\n",
    "            \n",
    "        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, \n",
    "                               num_workers=0, pin_memory=False)\n",
    "        \n",
    "        features = []\n",
    "        with torch.no_grad():\n",
    "            for batch in tqdm(dataloader, desc=f\"Processing {os.path.basename(directory)}\"):\n",
    "                batch = batch.to(self.device)\n",
    "                feat = model(batch)\n",
    "                features.append(feat.cpu().numpy())\n",
    "        \n",
    "        return np.concatenate(features, axis=0)\n",
    "    \n",
    "    def calculate_fid(self, real_dir, gen_dir, batch_size=16):\n",
    "        print(\"Extracting features from real images...\")\n",
    "        real_features = self.extract_features(real_dir, batch_size)\n",
    "        if real_features is None:\n",
    "            return None\n",
    "            \n",
    "        print(\"Extracting features from generated images...\")\n",
    "        gen_features = self.extract_features(gen_dir, batch_size)\n",
    "        if gen_features is None:\n",
    "            return None\n",
    "        \n",
    "        print(\"Calculating FID...\")\n",
    "        mu1, sigma1 = np.mean(real_features, axis=0), np.cov(real_features, rowvar=False)\n",
    "        mu2, sigma2 = np.mean(gen_features, axis=0), np.cov(gen_features, rowvar=False)\n",
    "        \n",
    "        diff = mu1 - mu2\n",
    "        covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)\n",
    "        \n",
    "        if np.iscomplexobj(covmean):\n",
    "            covmean = covmean.real\n",
    "            \n",
    "        fid = diff.dot(diff) + np.trace(sigma1 + sigma2 - 2 * covmean)\n",
    "        return fid"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "7118e87c-36ec-4d32-a485-35b851032dd9",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "=== Robust FID Calculation ===\n",
      "Extracting features from real images...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing fid_real_1000: 100%|██████████████████████████████████████████████████████| 125/125 [02:40<00:00,  1.28s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Extracting features from generated images...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing fid_generated_1000: 100%|█████████████████████████████████████████████████| 125/125 [02:43<00:00,  1.31s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Calculating FID...\n",
      "\n",
      "FID = 24.019523\n"
     ]
    }
   ],
   "source": [
    "print(\"=== Robust FID Calculation ===\")\n",
    "    \n",
    "real_count = len([f for f in os.listdir(fid_real_dir) if f.lower().endswith(('.png', '.jpg', '.jpeg'))])\n",
    "gen_count = len([f for f in os.listdir(fid_generated_dir) if f.lower().endswith(('.png', '.jpg', '.jpeg'))])\n",
    "        \n",
    "calculator = SimpleFIDCalculator(device='cuda' if torch.cuda.is_available() else 'cpu')\n",
    "fid_value = calculator.calculate_fid(fid_real_dir, fid_generated_dir, batch_size=8)  \n",
    "\n",
    "print(f\"\\nFID = {fid_value:.6f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f6c7b22b-06b6-4ad1-82f6-685cd41fd5ef",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
