{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f9dab9f9",
   "metadata": {},
   "outputs": [],
   "source": [
    "#\n",
    "# For licensing see accompanying LICENSE file.\n",
    "# Copyright (C) 2024 Apple Inc. All Rights Reserved.\n",
    "#"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "89151fde-b3c9-4265-892f-dffdccfb0494",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import torch as T\n",
    "\n",
    "# Load every checkpoint shard and merge the entries into one dict, mapped to CPU.\n",
    "SHARD_COUNT = 4\n",
    "# Both fields are zero-padded to five digits, e.g. pytorch_model-00001-of-00004.bin\n",
    "SHARD_PATH = './_snapshot/mgie/checkpoint/pytorch_model-%05d-of-%05d.bin'\n",
    "\n",
    "pt = {}\n",
    "for shard in range(1, SHARD_COUNT + 1):\n",
    "    # NOTE(review): T.load unpickles the file -- only point this at checkpoints from a trusted source.\n",
    "    pt.update(T.load(SHARD_PATH % (shard, SHARD_COUNT), map_location='cpu'))\n",
    "print(len(pt))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "22503532-f4c3-4751-be4f-3250069aa179",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Split the merged checkpoint into two dicts (values cast with .half()) and save each to disk.\n",
    "OUT_DIR = './_ckpt/mgie_7b'\n",
    "# A key containing any of these substrings goes into the `mllm` dict.\n",
    "MLLM_MARKERS = ('embed_tokens', 'lm_head', 'edit_head')\n",
    "\n",
    "mllm, unet = {}, {}\n",
    "for key, value in pt.items():\n",
    "    if any(marker in key for marker in MLLM_MARKERS):\n",
    "        mllm[key] = value.half()\n",
    "    elif 'unet' in key:\n",
    "        # NOTE(review): replace() drops every 'unet.' occurrence in the key, not only a leading prefix.\n",
    "        unet[key.replace('unet.', '')] = value.half()\n",
    "print(len(mllm))\n",
    "print(len(unet))\n",
    "\n",
    "# T.save does not create missing parent directories, so make sure the target exists first.\n",
    "os.makedirs(OUT_DIR, exist_ok=True)\n",
    "# Separate statements (not a tuple expression) so the cell does not display `(None, None)`.\n",
    "T.save(mllm, OUT_DIR + '/mllm.pt')\n",
    "T.save(unet, OUT_DIR + '/unet.pt')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
