{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "\n",
    "# Project root: two directories above the notebook's working directory,\n",
    "# i.e. the folder where config.py lives.\n",
    "parent_dir = os.path.abspath(os.path.join(os.getcwd(), \"../..\"))\n",
    "\n",
    "# Prepend rather than append so project modules take precedence over any\n",
    "# same-named installed packages; the membership check keeps re-runs of this\n",
    "# cell from stacking duplicate entries.\n",
    "if parent_dir not in sys.path:\n",
    "    sys.path.insert(0, parent_dir)\n",
    "\n",
    "print(\"Added to sys.path:\", parent_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import json\n",
    "\n",
    "from config import get_config\n",
    "from train_utils.gpu_utils import get_device_summary"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ===== FLOPs / Params utilities =====\n",
    "# Drop this in a cell after you've created the model\n",
    "\n",
    "import torch\n",
    "import math\n",
    "\n",
    "def _make_dummy_input(input_shape, device):\n",
    "    \"\"\"Build a random batch of one sample, shaped (1, *input_shape), on `device`.\n",
    "\n",
    "    `input_shape` is the per-sample shape without the batch axis, e.g. (C, H, W).\n",
    "    \"\"\"\n",
    "    batch_shape = (1, *input_shape)\n",
    "    return torch.randn(batch_shape, device=device)\n",
    "\n",
    "def count_trainable_params(model: torch.nn.Module) -> int:\n",
    "    \"\"\"Count the scalar elements of every parameter that has `requires_grad=True`.\"\"\"\n",
    "    total = 0\n",
    "    for param in model.parameters():\n",
    "        if param.requires_grad:\n",
    "            total += param.numel()\n",
    "    return total\n",
    "\n",
    "def profile_with_fvcore(model, input_shape, device):\n",
    "    \"\"\"Profile `model` with fvcore on one dummy sample of shape (1, *input_shape).\n",
    "\n",
    "    Returns dict(params, macs, flops), or None when fvcore cannot be imported\n",
    "    (so the caller can fall back to ptflops).\n",
    "\n",
    "    fvcore counts one fused multiply-add as one \"flop\", i.e. the value of\n",
    "    FlopCountAnalysis(...).total() is a MAC count. It is reported as `macs`\n",
    "    and doubled for `flops`, the same convention the ptflops path uses.\n",
    "    \"\"\"\n",
    "    try:\n",
    "        from fvcore.nn import FlopCountAnalysis, parameter_count\n",
    "    except Exception:\n",
    "        # Missing (or broken) install: signal the caller to try ptflops.\n",
    "        return None\n",
    "    model_eval = model.eval()\n",
    "    dummy = _make_dummy_input(input_shape, device)\n",
    "    with torch.no_grad():\n",
    "        macs = FlopCountAnalysis(model_eval, dummy).total()  # fvcore's \"flops\" are MACs\n",
    "    params = parameter_count(model_eval)[\"\"]  # \"\" key = whole-model total\n",
    "    flops = macs * 2.0  # 1 MAC = 1 multiply + 1 add\n",
    "    return dict(params=params, macs=macs, flops=flops)\n",
    "\n",
    "def profile_with_ptflops(model, input_shape, device):\n",
    "    \"\"\"Profile `model` with ptflops on one dummy sample of shape (1, *input_shape).\n",
    "\n",
    "    Returns dict(params, macs, flops), or None when ptflops cannot be imported.\n",
    "    ptflops reports multiply-accumulates (MACs); FLOPs use the common\n",
    "    flops = 2 * macs convention.\n",
    "\n",
    "    Raises:\n",
    "        RuntimeError: if ptflops is installed but returns no counts.\n",
    "    \"\"\"\n",
    "    try:\n",
    "        from ptflops import get_model_complexity_info\n",
    "    except Exception:\n",
    "        # Missing (or broken) install: let the caller report it.\n",
    "        return None\n",
    "    # ptflops wants the per-sample resolution without the batch axis, e.g. (C, H, W).\n",
    "    input_res = tuple(input_shape)\n",
    "    model_eval = model.eval().to(device)\n",
    "    with torch.no_grad():\n",
    "        # as_strings=False returns raw numbers. The string form carries unit\n",
    "        # suffixes such as \"1.23 GMac\" that are awkward to parse back reliably.\n",
    "        macs, params = get_model_complexity_info(\n",
    "            model_eval, input_res, as_strings=False,\n",
    "            print_per_layer_stat=False, verbose=False,\n",
    "        )\n",
    "    if macs is None or params is None:\n",
    "        # Some ptflops versions catch forward-pass errors, print them and return None.\n",
    "        raise RuntimeError(\"ptflops could not profile the model (see the error it printed above).\")\n",
    "    macs = float(macs)\n",
    "    flops = macs * 2.0  # 1 MAC = 1 multiply + 1 add\n",
    "    return dict(params=params, macs=macs, flops=flops)\n",
    "\n",
    "def profile_model(model, input_shape=(1,32,32), device='cuda' if torch.cuda.is_available() else 'cpu'):\n",
    "    \"\"\"\n",
    "    Profile `model` on one dummy sample of shape (1, *input_shape).\n",
    "\n",
    "    fvcore is tried first, then ptflops. The model is moved to `device`\n",
    "    (nn.Module.to works in place) and switched to eval mode for profiling;\n",
    "    the train/eval flag of every submodule is restored afterwards, so a model\n",
    "    profiled in the middle of training keeps its dropout/batch-norm behaviour.\n",
    "\n",
    "    Returns a dict with:\n",
    "      - params (int)\n",
    "      - macs   (float, multiply-accumulate operations)\n",
    "      - flops  (float, operations)\n",
    "      - pretty strings for Params (M), MACs (G), FLOPs (G)\n",
    "\n",
    "    Raises:\n",
    "        RuntimeError: if neither fvcore nor ptflops is available.\n",
    "    \"\"\"\n",
    "    # Remember each submodule's mode; model.train(True) alone would also\n",
    "    # un-freeze submodules the caller deliberately kept in eval mode.\n",
    "    saved_modes = [(m, m.training) for m in model.modules()]\n",
    "    model = model.to(device).eval()\n",
    "    try:\n",
    "        # First: fvcore, else: ptflops\n",
    "        result = profile_with_fvcore(model, input_shape, device)\n",
    "        if result is None:\n",
    "            result = profile_with_ptflops(model, input_shape, device)\n",
    "    finally:\n",
    "        for module, was_training in saved_modes:\n",
    "            module.training = was_training\n",
    "    if result is None:\n",
    "        raise RuntimeError(\"Neither fvcore nor ptflops is available. Install one of them to compute FLOPs/MACs.\")\n",
    "\n",
    "    params = int(result[\"params\"])\n",
    "    macs   = float(result[\"macs\"])\n",
    "    flops  = float(result[\"flops\"])\n",
    "\n",
    "    pretty = {\n",
    "        \"Params (M)\": f\"{params/1e6:.2f}\",\n",
    "        \"MACs (G)\" : f\"{macs/1e9:.3f}\",\n",
    "        \"FLOPs (G)\": f\"{flops/1e9:.3f}\",\n",
    "    }\n",
    "    return {\"params\": params, \"macs\": macs, \"flops\": flops, \"pretty\": pretty}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ===== Example usage after you build your model =====\n",
    "def print_model_profile(model, cfg, device):\n",
    "    \"\"\"Print parameter counts and the FLOPs/MACs profile of `model`.\n",
    "\n",
    "    `cfg.input_shape` is the per-sample input shape handed to profile_model.\n",
    "    Parameter counts are printed before profiling so they are shown even if\n",
    "    profiling fails. A RuntimeError from profiling is printed, not raised;\n",
    "    the install hint is only shown when neither fvcore nor ptflops can be\n",
    "    found, so unrelated failures (e.g. a shape mismatch or CUDA OOM in the\n",
    "    forward pass) are not misreported as a missing dependency.\n",
    "    \"\"\"\n",
    "    import importlib.util\n",
    "\n",
    "    trainable_params = count_trainable_params(model)\n",
    "    total_params = sum(p.numel() for p in model.parameters())\n",
    "    print(\"Trainable parameters:\", f\"{trainable_params:,}\")\n",
    "    print(\"All parameters:\", f\"{total_params:,}\")\n",
    "    try:\n",
    "        prof = profile_model(model, input_shape=cfg.input_shape, device=device)\n",
    "    except RuntimeError as e:\n",
    "        print(e)\n",
    "        if all(importlib.util.find_spec(pkg) is None for pkg in (\"fvcore\", \"ptflops\")):\n",
    "            # %pip (unlike !pip) installs into the environment of the running kernel.\n",
    "            print(\"\\nQuick install (choose one):\")\n",
    "            print(\"  %pip install fvcore\")\n",
    "            print(\"  # or\")\n",
    "            print(\"  %pip install ptflops\")\n",
    "        return\n",
    "    print(\"== Model profile ==\")\n",
    "    for k, v in prof[\"pretty\"].items():\n",
    "        print(f\"{k}: {v}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pick the compute device used by the cells below (model.to(device) and the\n",
    "# dummy-input tensors created during profiling).\n",
    "# NOTE(review): assumes get_device_summary() returns a torch.device or a device\n",
    "# string such as 'cuda'/'cpu', and presumably also prints a summary - confirm.\n",
    "device= get_device_summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ConvNeXt configuration: load it, prepare its output directory, then profile the model.\n",
    "from models.model import create_model\n",
    "\n",
    "# NOTE(review): the leading '/' presumably means \"relative to the project root\"\n",
    "# inside get_config - confirm.\n",
    "cfg = get_config(\n",
    "    config_path=\"/config/convnext_gaussian_bs32_ep50_lr1e-04_ds7200000_g500_sched-RLRP.yml\"\n",
    ")\n",
    "print(json.dumps(vars(cfg), indent=2))\n",
    "\n",
    "os.makedirs(cfg.output_dir, exist_ok=True)\n",
    "print(f\"[INFO] Saving all outputs to: {cfg.output_dir}\")\n",
    "\n",
    "# create_model returns the model together with its optimizer; only the model is profiled here.\n",
    "model, optimizer = create_model(\n",
    "    cfg.backbone,\n",
    "    cfg.input_shape,\n",
    "    cfg.learning_rate,\n",
    ")\n",
    "model.to(device)\n",
    "print_model_profile(model, cfg, device)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# EfficientNet configuration: load it, prepare its output directory, then profile the model.\n",
    "from models.model import create_model\n",
    "\n",
    "# NOTE(review): the leading '/' presumably means \"relative to the project root\"\n",
    "# inside get_config - confirm.\n",
    "cfg = get_config(\n",
    "    config_path=\"/config/efficientnet_bs32_ep50_lr1e-02_ds7200000_g500.yml\"\n",
    ")\n",
    "print(json.dumps(vars(cfg), indent=2))\n",
    "\n",
    "os.makedirs(cfg.output_dir, exist_ok=True)\n",
    "print(f\"[INFO] Saving all outputs to: {cfg.output_dir}\")\n",
    "\n",
    "# create_model returns the model together with its optimizer; only the model is profiled here.\n",
    "model, optimizer = create_model(\n",
    "    cfg.backbone,\n",
    "    cfg.input_shape,\n",
    "    cfg.learning_rate,\n",
    ")\n",
    "model.to(device)\n",
    "print_model_profile(model, cfg, device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Swin configuration: load it, prepare its output directory, then profile the model.\n",
    "from models.model import create_model\n",
    "\n",
    "# NOTE(review): the leading '/' presumably means \"relative to the project root\"\n",
    "# inside get_config - confirm.\n",
    "cfg = get_config(\n",
    "    config_path=\"/config/swin_bs32_ep50_lr1e-04_ds7200000_g500.yml\"\n",
    ")\n",
    "print(json.dumps(vars(cfg), indent=2))\n",
    "\n",
    "os.makedirs(cfg.output_dir, exist_ok=True)\n",
    "print(f\"[INFO] Saving all outputs to: {cfg.output_dir}\")\n",
    "\n",
    "# create_model returns the model together with its optimizer; only the model is profiled here.\n",
    "model, optimizer = create_model(\n",
    "    cfg.backbone,\n",
    "    cfg.input_shape,\n",
    "    cfg.learning_rate,\n",
    ")\n",
    "model.to(device)\n",
    "print_model_profile(model, cfg, device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ViT-Tiny configuration from the preload experiment: load it, prepare its output\n",
    "# directory, then profile the model.\n",
    "from models.model_vit import create_model\n",
    "\n",
    "# NOTE(review): the leading '/' presumably means \"relative to the project root\"\n",
    "# inside get_config - confirm.\n",
    "cfg = get_config(\n",
    "    config_path=(\n",
    "        \"/experiments/exp_preload_trained_model_and_train_more/config/\"\n",
    "        \"vit_tiny_patch16_224_gaussian_bs32_ep50_lr1e-04_p12_ds7200000_g500_sched-RLRP_preload_p4.yml\"\n",
    "    )\n",
    ")\n",
    "print(json.dumps(vars(cfg), indent=2))\n",
    "\n",
    "os.makedirs(cfg.output_dir, exist_ok=True)\n",
    "print(f\"[INFO] Saving all outputs to: {cfg.output_dir}\")\n",
    "\n",
    "# create_model returns the model together with its optimizer; only the model is profiled here.\n",
    "model, optimizer = create_model(\n",
    "    cfg.backbone,\n",
    "    cfg.input_shape,\n",
    "    cfg.learning_rate,\n",
    ")\n",
    "model.to(device)\n",
    "print_model_profile(model, cfg, device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# MambaOut configuration from the preload experiment: load it, prepare its output\n",
    "# directory, then profile the model.\n",
    "from models.model_mamba import create_model\n",
    "\n",
    "# NOTE(review): the leading '/' presumably means \"relative to the project root\"\n",
    "# inside get_config - confirm.\n",
    "cfg = get_config(\n",
    "    config_path=(\n",
    "        \"/experiments/exp_preload_trained_model_and_train_more/config/\"\n",
    "        \"mambaout_base_plus_rw_bs16_ep50_lr1e-04_p12_ds7200000_g500_sched-RLRP_preload-p4.yml\"\n",
    "    )\n",
    ")\n",
    "print(json.dumps(vars(cfg), indent=2))\n",
    "\n",
    "os.makedirs(cfg.output_dir, exist_ok=True)\n",
    "print(f\"[INFO] Saving all outputs to: {cfg.output_dir}\")\n",
    "\n",
    "# create_model returns the model together with its optimizer; only the model is profiled here.\n",
    "model, optimizer = create_model(\n",
    "    cfg.backbone,\n",
    "    cfg.input_shape,\n",
    "    cfg.learning_rate,\n",
    ")\n",
    "model.to(device)\n",
    "print_model_profile(model, cfg, device)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pytorch",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.21"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
