{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "6f35e4b3-0b4c-43f2-8970-49e1892fa36b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
    "from datasets import load_dataset\n",
    "from tqdm.notebook import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "f9aca2fe-0d03-4a36-91f0-5ec116e22ebc",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.nn.functional as F"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "f2df127b-66e9-4100-973e-e785e59dcb7f",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'code': '#!/usr/bin/env python\\n# -*- encoding: utf-8 -*-\\n# Created on 2017-10-21 01:46:53\\n# Project: gat_zhoubianyou\\n\\nfrom pyspider.libs.base_handler import *\\nimport re\\nimport logging\\nimport random\\nfrom collections import defaultdict\\n\\nlogger = logging.getLogger(__name__)\\n\\n\\nclass Handler(BaseHandler):\\n    crawl_config = {\\n        \\'headers\\': {\\n            \\'User-Agent\\': \\'Mozilla/5.0 (Windows NT 6.1; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/61.0.3163.79 Safari/537.36\\',\\n        }\\n    }\\n\\n    PROXY_UPADATER = \\'update_proxy\\'\\n    PROXY_POOL = defaultdict(int)\\n    FAIL_THRESHOLD = 3\\n    COMMENT_FETCHER = \\'comment_fetcher\\'\\n    LOCATIONS = [\\'taipei\\', \\'hongkong\\', \\'macau\\', \\'kaohsiung\\', \\'kenting\\', \\'hualien\\']\\n\\n    @every(minutes=1)\\n    def on_start(self):\\n        # all location\\n        for location in self.LOCATIONS:\\n            proxy = random.choice(self._get_valid_proxies())\\n            self.crawl(\\n                \\'http://www.dianping.com/{location}/attraction\\'.format(location=location),\\n                callback=self.index_page,\\n                proxy=proxy,\\n                save={\\'proxy\\': proxy}\\n            )\\n\\n    @config(age=100)\\n    @catch_status_code_error\\n    def index_page(self, response):\\n        proxy = response.save[\\'proxy\\']\\n        if not response.ok:\\n            self.PROXY_POOL[proxy] += 1\\n            return\\n        page = response.doc(\\'div.Pages a.NextPage\\')\\n        if page is not None:\\n            proxy = random.choice(self._get_valid_proxies())\\n            self.crawl(\\n                page.attr.href,\\n                cookies=response.cookies,\\n                callback=self.index_page,\\n                proxy=proxy,\\n                save={\\'proxy\\': proxy}\\n            )\\n\\n        shop_selector = \\'div.poi-ctn > ul > li > div.txt > div.poi-title a\\'\\n        for shop in 
response.doc(shop_selector).items():\\n            proxy = random.choice(self._get_valid_proxies())\\n            self.crawl(\\n                shop.attr.href + \\'/review_more\\',\\n                cookies=response.cookies,\\n                callback=self.comment_index_page,\\n                proxy=proxy,\\n                save={\\'proxy\\': proxy}\\n            )\\n\\n    @config(priority=2, age=100)\\n    @catch_status_code_error\\n    def comment_index_page(self, response):\\n        proxy = response.save[\\'proxy\\']\\n        if not response.ok:\\n            self.PROXY_POOL[proxy] += 1\\n            return\\n        for each in response.doc(\\'a[href^=\"http\"]\\').items():\\n            # all shop comment pages\\n            if re.match(\"http://www.dianping.com/shop/\\\\d+/review_more\\\\?pageno=\\\\d+\", each.attr.href):\\n                # save comment detail\\n                self.send_message(self.COMMENT_FETCHER, {\\n                    \\'url\\': each.attr.href,\\n                    \\'cookies\\': response.cookies,\\n                }, each.attr.href)\\n                # follow\\n                proxy = random.choice(self._get_valid_proxies())\\n                self.crawl(\\n                    each.attr.href,\\n                    cookies=response.cookies,\\n                    callback=self.comment_index_page,\\n                    proxy=proxy,\\n                    save={\\'proxy\\': proxy}\\n                )\\n\\n    def on_message(self, project, message):\\n        if project == self.PROXY_UPDATER:\\n            # new proxy added\\n            proxy_host = message\\n            if not self.PROXY_POOL.get(proxy_host):\\n                self.PROXY_POOL[proxy_host] = 0\\n\\n    def _get_valid_proxies(self):\\n        # TODO: GC proxy pool\\n        proxies = [\\n            k for k, v in filter(\\n                lambda item: item[1] <= self.FAIL_THRESHOLD, self.PROXY_POOL.iteritems()\\n            )\\n        ]\\n        return proxies if 
proxies else [\\'\\']\\n', 'repo_name': 'manongbang/zhoubianyou', 'path': 'collector/gat_zhoubianyou.py', 'language': 'Python', 'license': 'mit', 'size': 3660}\n"
     ]
    }
   ],
   "source": [
    "# Stream the GitHub code corpus from `../../.gh_code_data`, shuffle it with a fixed\n",
    "# seed, keep only records whose language is Python, and print one record as a smoke test.\n",
    "ds = (\n",
    "    load_dataset(\n",
    "        \"codeparrot/github-code\",\n",
    "        streaming=True,\n",
    "        split=\"train\",\n",
    "        data_dir='../../.gh_code_data',\n",
    "    )\n",
    "    .shuffle(buffer_size=1_000, seed=42)\n",
    "    .filter(lambda record: record['language'] == 'Python')\n",
    ")\n",
    "print(next(iter(ds)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "c9499d26-615c-4dbf-b11f-c48d22806972",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Instantiate the pretrained PyCodeGPT causal LM and its matching tokenizer\n",
    "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
    "\n",
    "model = AutoModelForCausalLM.from_pretrained(\"Daoguang/PyCodeGPT\")\n",
    "tokenizer = AutoTokenizer.from_pretrained(\"Daoguang/PyCodeGPT\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "151b85ce-8e83-4780-b7a9-b37af875f0c4",
   "metadata": {},
   "outputs": [],
   "source": [
    "def batchify(dataset, batch_size):\n",
    "    batch = []\n",
    "    for example in dataset:\n",
    "        batch.append(example)\n",
    "        if len(batch) == batch_size:\n",
    "            yield batch\n",
    "            batch = []\n",
    "    if batch:\n",
    "        yield batch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "9b3c47f7-9338-42c9-95fd-b09aa622f8b9",
   "metadata": {},
   "outputs": [],
   "source": [
    "def shift_batch(batch, dim=None):\n",
    "    if dim is None:\n",
    "        return batch, batch.shape[1]\n",
    "    else:\n",
    "        fin_shape = batch.shape[1] - dim + 1\n",
    "        ans = [torch.roll(batch, -i, dims=1)[:, :fin_shape] for i in range(dim)]\n",
    "        return ans, fin_shape"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "37f8f6d9-5522-4c34-8ef4-a271f0adcde2",
   "metadata": {},
   "source": [
    "# random generated stuff go!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "0f622277-12f7-4a7b-b48d-34587e716b6d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
    "from datasets import load_dataset\n",
    "\n",
    "# Load dataset. Re-apply the same seeded shuffle and Python-only filter that were used\n",
    "# when the dataset was first opened: rebinding `ds` to the raw stream would silently\n",
    "# drop them, and every later cell that iterates `ds` would see files in all languages.\n",
    "ds = load_dataset(\"codeparrot/github-code\", streaming=True, split=\"train\", data_dir='../../.gh_code_data')\n",
    "ds = ds.shuffle(buffer_size=1_000, seed=42)\n",
    "ds = ds.filter(lambda x: x['language'] == 'Python')\n",
    "\n",
    "# Load tokenizer and model\n",
    "tokenizer = AutoTokenizer.from_pretrained(\"Daoguang/PyCodeGPT\")\n",
    "model = AutoModelForCausalLM.from_pretrained(\"Daoguang/PyCodeGPT\")\n",
    "\n",
    "# Dimensions used to size a replacement head on top of the base model\n",
    "input_dim = model.config.hidden_size\n",
    "output_dim = tokenizer.vocab_size  # Output dimension should match the vocabulary size for language modeling\n",
    "\n",
    "# Combine the model and the new head\n",
    "class CombinedModel(nn.Module):\n",
    "    def __init__(self, base_model, new_head):\n",
    "        super(CombinedModel, self).__init__()\n",
    "        self.base_model = base_model\n",
    "        self.new_head = new_head\n",
    "\n",
    "    def forward(self, input_ids, attention_mask=None, targets=None):\n",
    "        outputs = self.base_model(input_ids, attention_mask=attention_mask, output_hidden_states=True)\n",
    "        last_hidden_state = outputs.hidden_states[-1]\n",
    "        logits, loss = self.new_head(last_hidden_state, targets=targets)\n",
    "        return logits, loss\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b41d325d-86be-4b87-b0df-5df886011373",
   "metadata": {},
   "source": [
    "# Heads"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "53815fb7-0fb5-4c72-917d-43c4000dec6c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 1 token forward\n",
    "class DefaultHead(nn.Module):\n",
    "    def __init__(self, input_dim, output_dim):\n",
    "        super(DefaultHead, self).__init__()\n",
    "        self.linear = nn.Linear(input_dim, output_dim)\n",
    "\n",
    "    def forward(self, x, targets=None):\n",
    "        if targets == None:\n",
    "            raise ValueError\n",
    "            logits = self.linear(x[:, [-1], :]) \n",
    "            loss = None\n",
    "        else:\n",
    "            logits = self.linear(x)\n",
    "            loss = F.cross_entropy(\n",
    "                logits.view(-1, logits.size(-1)), targets.view(-1),\n",
    "                ignore_index=-1)\n",
    "\n",
    "        return logits, loss\n",
    "\n",
    "class CPHead(nn.Module):\n",
    "    def __init__(self, input_dim, output_dim, n_tokens=2,  r=2):\n",
    "\n",
    "        super(CPHead, self).__init__()\n",
    "        \n",
    "        self.n_embd = input_dim\n",
    "        self.n = output_dim\n",
    "        self.d = n_tokens\n",
    "        self.r = r\n",
    "        self.lm_heads = []\n",
    "        for k in range(n_tokens):\n",
    "            sz = self.r * self.n\n",
    "            self.lm_heads.append(\n",
    "                nn.Linear(self.n_embd, sz, bias=False))\n",
    "        self.lm_heads = nn.ModuleList(self.lm_heads)\n",
    "        \n",
    "        self.lm_head_weight = nn.Linear(self.n_embd, self.r, bias=True)\n",
    "\n",
    "    \n",
    "    def _build_core(self, k, x, targets=None):\n",
    "        B, T, C = x.shape  # Here x is [batch_size, block_size, n_embd]\n",
    "        r1 = 1 if k == 0 else self.r         # Left rank\n",
    "        n = self.n                           # Vocabulary size\n",
    "        r2 = 1 if k == self.d-1 else self.r  # Right rank\n",
    "\n",
    "        G = self.lm_heads[k](x)\n",
    "        G = G.reshape(B*T, n, self.r)\n",
    "        G = nn.functional.log_softmax(G, dim=1)\n",
    "\n",
    "        if targets is None:\n",
    "            G = G[-1]  # We select last output\n",
    "        else:\n",
    "            t = targets[k].reshape(-1, 1, 1).repeat(1, 1, self.r)\n",
    "            G = torch.gather(G, dim=1, index=t).squeeze(1)\n",
    "\n",
    "        return G\n",
    "\n",
    "    \n",
    "    def forward(self, x, targets, with_w_norm=True):\n",
    "        pred = None\n",
    "        loss = None\n",
    "\n",
    "        if targets is None:\n",
    "            pred = self._head_forward_pred(x)\n",
    "        else:        \n",
    "            loss = self._head_forward_loss(x, targets, with_w_norm=with_w_norm)\n",
    "            \n",
    "        return pred, loss\n",
    "\n",
    "\n",
    "    def _head_forward_loss(self, x, targets, with_w_norm=True):\n",
    "        w = self.lm_head_weight(x).reshape(-1, self.r)\n",
    "\n",
    "        w_norm = nn.functional.softmax(w, dim=1) \n",
    "        log_w = torch.log(w_norm)\n",
    "        log_cores = [self._build_core(k, x, targets) for k in range(self.d)]\n",
    "        log_cores.append(log_w)\n",
    "        loss = torch.sum(torch.stack(log_cores), dim=0)\n",
    "        loss = torch.logsumexp(loss, dim=1)\n",
    "        loss = -1. * torch.mean(loss) / 2\n",
    "\n",
    "        if with_w_norm:\n",
    "            wmx = torch.argmax(w, dim=1)\n",
    "            fs = torch.bincount(wmx, minlength=self.r) / w.shape[0]\n",
    "            ps = torch.mean(w_norm, dim=0)\n",
    "\n",
    "            aux_ls = (fs * ps).sum() * self.r - 1\n",
    "\n",
    "            loss += aux_ls * 1.E0\n",
    "\n",
    "        return loss\n",
    "\n",
    "    def _head_forward_pred(self, x):\n",
    "        w = self.lm_head_weight(x).reshape(-1, self.r)\n",
    "        w = nn.functional.softmax(w, dim=1)[-1, :]\n",
    "        \n",
    "        cores = [self._build_core(k, x) for k in range(self.d)]\n",
    "        \n",
    "        if self.d == 2:\n",
    "            core1 = torch.nn.functional.softmax(cores[0], dim=0).cpu().unsqueeze(0)\n",
    "            \n",
    "            core1 = torch.einsum('ijk,k->ijk', core1, w.cpu())\n",
    "            core2 = torch.nn.functional.softmax(cores[1], dim=0).cpu().unsqueeze(0).permute(2, 1, 0).numpy()\n",
    "            tt_tensor = [core1, core2]\n",
    "            idxs = sample(tt_tensor)\n",
    "            pred = torch.tensor(idxs, dtype=int).to(w.device)\n",
    "        else:\n",
    "            item_next = torch.multinomial(w, num_samples=1)[0]\n",
    "            idxs_next = []\n",
    "            for k in range(self.d):\n",
    "                G_curr = nn.functional.softmax(cores[k][:, item_next])\n",
    "                idx_next = torch.multinomial(G_curr, num_samples=1)\n",
    "                idxs_next.append(idx_next)       \n",
    "            \n",
    "            pred = torch.stack(idxs_next, dim=1)\n",
    "\n",
    "        return pred\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "20dad4d7-659f-4918-a2ea-463aa86fd7cc",
   "metadata": {},
   "source": [
    "# Train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "5d12435f-502c-41b3-b608-2f749c2a12eb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of examples grouped into one batch by batchify().\n",
    "batch_size = 1\n",
    "# Passed to the tokenizer as max_length (with truncation enabled) when batches are tokenized.\n",
    "max_seq_length = 256\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "95e68ec9-b27d-4c2e-9882-dc188cc6f758",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "1a62034cbb5f43508f8f61d3c90c4fcf",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/10 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([1, 256])\n"
     ]
    }
   ],
   "source": [
    "# Sanity check: tokenize only the first batch of the first epoch and print the shape\n",
    "# of its input ids. Both loops exit right after that first batch.\n",
    "for epoch in tqdm(range(10)):  # Example number of epochs\n",
    "    ds.set_epoch(epoch)\n",
    "    i = 0\n",
    "    for batch in batchify(ds, batch_size):\n",
    "        # Tokenize the 'code' field of every example, with padding on and truncation at max_seq_length.\n",
    "        inputs = tokenizer([example['code'] for example in batch], return_tensors='pt', padding=True, truncation=True, max_length=max_seq_length)\n",
    "        labels = inputs['input_ids']\n",
    "\n",
    "        print(labels.shape)\n",
    "        \n",
    "        break\n",
    "    break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "89ea79ca-a584-4837-b238-d66e67f98f4e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Empty list; presumably collects loss values during training - TODO confirm against the training cell.\n",
    "lss = []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "b8ed8c38-8afb-41a9-bc9a-d4f790c8ea56",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use the GPU when CUDA is available, otherwise fall back to the CPU.\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "04542f4a-f5f2-4af0-8b95-74ff1a933ddc",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "b1023055f71346c7bdea7da5e5b42b54",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/10 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1, iter 49, Loss: 3.1212821006774902\n",
      "Epoch 1, iter 99, Loss: 2.815415859222412\n",
      "Epoch 1, iter 149, Loss: 3.8389081954956055\n",
      "Epoch 1, iter 199, Loss: 1.6362663507461548\n",
      "Epoch 1, iter 249, Loss: 1.5634301900863647\n",
      "Epoch 1, iter 299, Loss: 1.5851682424545288\n",
      "Epoch 1, iter 349, Loss: 1.5759323835372925\n",
      "Epoch 1, iter 399, Loss: 1.744019627571106\n",
      "Epoch 1, iter 449, Loss: 1.3101390600204468\n",
      "Epoch 1, iter 499, Loss: 1.6151516437530518\n",
      "Epoch 1, iter 549, Loss: 1.2113717794418335\n",
      "Epoch 1, iter 599, Loss: 1.149249792098999\n",
      "Epoch 1, iter 649, Loss: 1.6295090913772583\n",
      "Epoch 1, iter 699, Loss: 1.3771244287490845\n",
      "Epoch 1, iter 749, Loss: 1.1608221530914307\n",
      "Epoch 1, iter 799, Loss: 1.2058619260787964\n",
      "Epoch 1, iter 849, Loss: 1.7569851875305176\n",
      "Epoch 1, iter 899, Loss: 1.4447206258773804\n",
      "Epoch 1, iter 949, Loss: 1.488002061843872\n",
      "Epoch 1, iter 999, Loss: 1.5170942544937134\n",
      "Epoch 1, iter 1049, Loss: 0.6695148944854736\n",
      "Epoch 1, iter 1099, Loss: 0.9726626873016357\n",
      "Epoch 1, iter 1149, Loss: 1.168906331062317\n",
      "Epoch 1, iter 1199, Loss: 0.8974885940551758\n",
      "Epoch 1, iter 1249, Loss: 0.9510067105293274\n",
      "Epoch 1, iter 1299, Loss: 1.2436408996582031\n",
      "Epoch 1, iter 1349, Loss: 0.5787196159362793\n",
      "Epoch 1, iter 1399, Loss: 1.0897878408432007\n",
      "Epoch 1, iter 1449, Loss: 1.0857479572296143\n",
      "Epoch 1, iter 1499, Loss: 1.112429141998291\n",
      "Epoch 1, iter 1549, Loss: 0.8287778496742249\n",
      "Epoch 1, iter 1599, Loss: 1.3047486543655396\n",
      "Epoch 1, iter 1649, Loss: 1.615628719329834\n",
      "Epoch 1, iter 1699, Loss: 1.0071465969085693\n",
      "Epoch 1, iter 1749, Loss: 1.2802793979644775\n",
      "Epoch 1, iter 1799, Loss: 0.969286322593689\n",
      "Epoch 1, iter 1849, Loss: 0.8618493676185608\n",
      "Epoch 1, iter 1899, Loss: 1.5973984003067017\n",
      "Epoch 1, iter 1949, Loss: 0.9801539182662964\n",
      "Epoch 1, iter 1999, Loss: 1.1028289794921875\n",
      "Epoch 1, iter 2049, Loss: 1.1380289793014526\n",
      "Epoch 1, iter 2099, Loss: 1.1610127687454224\n",
      "Epoch 1, iter 2149, Loss: 1.0002868175506592\n",
      "Epoch 1, iter 2199, Loss: 0.6402846574783325\n",
      "Epoch 1, iter 2249, Loss: 0.7217965126037598\n",
      "Epoch 1, iter 2299, Loss: 0.5549029111862183\n",
      "Epoch 1, iter 2349, Loss: 1.0287548303604126\n",
      "Epoch 1, iter 2399, Loss: 0.5772472620010376\n",
      "Epoch 1, iter 2449, Loss: 0.4062609076499939\n",
      "Epoch 1, iter 2499, Loss: 0.687863290309906\n",
      "Epoch 1, iter 2549, Loss: 1.0387920141220093\n",
      "Epoch 1, iter 2599, Loss: 1.0308129787445068\n",
      "Epoch 1, iter 2649, Loss: 0.830735445022583\n",
      "Epoch 1, iter 2699, Loss: 0.7491565346717834\n",
      "Epoch 1, iter 2749, Loss: 1.6273232698440552\n",
      "Epoch 1, iter 2799, Loss: 0.8406375050544739\n",
      "Epoch 1, iter 2849, Loss: 0.7786165475845337\n",
      "Epoch 1, iter 2899, Loss: 0.8739748001098633\n",
      "Epoch 1, iter 2949, Loss: 0.8700061440467834\n",
      "Epoch 1, iter 2999, Loss: 0.6621016263961792\n",
      "Epoch 1, iter 3049, Loss: 0.6750820875167847\n",
      "Epoch 1, iter 3099, Loss: 0.7607225179672241\n",
      "Epoch 1, iter 3149, Loss: 0.7082582712173462\n",
      "Epoch 1, iter 3199, Loss: 0.8669043183326721\n",
      "Epoch 1, iter 3249, Loss: 0.7809939384460449\n",
      "Epoch 1, iter 3299, Loss: 0.5670786499977112\n",
      "Epoch 1, iter 3349, Loss: 0.6136247515678406\n",
      "Epoch 1, iter 3399, Loss: 1.6970748901367188\n",
      "Epoch 1, iter 3449, Loss: 0.7870987057685852\n",
      "Epoch 1, iter 3499, Loss: 0.49941208958625793\n",
      "Epoch 1, iter 3549, Loss: 0.28150802850723267\n",
      "Epoch 1, iter 3599, Loss: 0.4602346420288086\n",
      "Epoch 1, iter 3649, Loss: 0.3829125463962555\n",
      "Epoch 1, iter 3699, Loss: 0.6033856272697449\n",
      "Epoch 1, iter 3749, Loss: 0.616432249546051\n",
      "Epoch 1, iter 3799, Loss: 0.7385809421539307\n",
      "Epoch 1, iter 3849, Loss: 0.6574898362159729\n",
      "Epoch 1, iter 3899, Loss: 0.677036464214325\n",
      "Epoch 1, iter 3949, Loss: 0.9326784610748291\n",
      "Epoch 1, iter 3999, Loss: 0.5160709619522095\n",
      "Epoch 1, iter 4049, Loss: 0.7219968438148499\n",
      "Epoch 1, iter 4099, Loss: 0.8425723910331726\n",
      "Epoch 1, iter 4149, Loss: 0.6083604693412781\n",
      "Epoch 1, iter 4199, Loss: 0.6329906582832336\n",
      "Epoch 1, iter 4249, Loss: 0.8802624344825745\n",
      "Epoch 1, iter 4299, Loss: 0.7741968631744385\n",
      "Epoch 1, iter 4349, Loss: 0.41575831174850464\n",
      "Epoch 1, iter 4399, Loss: 0.6521817445755005\n",
      "Epoch 1, iter 4449, Loss: 0.6317262053489685\n",
      "Epoch 1, iter 4499, Loss: 0.6529626250267029\n",
      "Epoch 1, iter 4549, Loss: 0.4778100550174713\n",
      "Epoch 1, iter 4599, Loss: 0.7723530530929565\n",
      "Epoch 1, iter 4649, Loss: 0.8189315795898438\n",
      "Epoch 1, iter 4699, Loss: 0.5745095014572144\n",
      "Epoch 1, iter 4749, Loss: 0.49568110704421997\n",
      "Epoch 1, iter 4799, Loss: 0.5627462863922119\n",
      "Epoch 1, iter 4849, Loss: 1.5025770664215088\n",
      "Epoch 1, iter 4899, Loss: 0.5178912878036499\n",
      "Epoch 1, iter 4949, Loss: 0.9492175579071045\n",
      "Epoch 1, iter 4999, Loss: 1.2202634811401367\n",
      "Epoch 1, iter 5049, Loss: 0.5798580646514893\n",
      "Epoch 1, iter 5099, Loss: 0.7319478988647461\n",
      "Epoch 1, iter 5149, Loss: 0.6565324664115906\n",
      "Epoch 1, iter 5199, Loss: 0.535319983959198\n",
      "Epoch 1, iter 5249, Loss: 1.0117807388305664\n",
      "Epoch 1, iter 5299, Loss: 0.6302080154418945\n",
      "Epoch 1, iter 5349, Loss: 0.6724334359169006\n",
      "Epoch 1, iter 5399, Loss: 0.6787680983543396\n",
      "Epoch 1, iter 5449, Loss: 0.4345482885837555\n",
      "Epoch 1, iter 5499, Loss: 1.1291089057922363\n",
      "Epoch 1, iter 5549, Loss: 0.8137646913528442\n",
      "Epoch 1, iter 5599, Loss: 0.5360690355300903\n",
      "Epoch 1, iter 5649, Loss: 0.6674485206604004\n",
      "Epoch 1, iter 5699, Loss: 0.5762836337089539\n",
      "Epoch 1, iter 5749, Loss: 0.8750637173652649\n",
      "Epoch 1, iter 5799, Loss: 0.5703371167182922\n",
      "Epoch 1, iter 5849, Loss: 1.1112709045410156\n",
      "Epoch 1, iter 5899, Loss: 0.5713849663734436\n",
      "Epoch 1, iter 5949, Loss: 0.45026758313179016\n",
      "Epoch 1, iter 5999, Loss: 0.3358539044857025\n",
      "Epoch 1, iter 6049, Loss: 0.5991744995117188\n",
      "Epoch 1, iter 6099, Loss: 0.912899911403656\n",
      "Epoch 1, iter 6149, Loss: 0.3186110258102417\n",
      "Epoch 1, iter 6199, Loss: 0.5868042707443237\n",
      "Epoch 1, iter 6249, Loss: 0.5554691553115845\n",
      "Epoch 1, iter 6299, Loss: 0.5308990478515625\n",
      "Epoch 1, iter 6349, Loss: 1.0692439079284668\n",
      "Epoch 1, iter 6399, Loss: 0.5854967832565308\n",
      "Epoch 1, iter 6449, Loss: 1.258205533027649\n",
      "Epoch 1, iter 6499, Loss: 0.5196616649627686\n",
      "Epoch 1, iter 6549, Loss: 0.45179447531700134\n",
      "Epoch 1, iter 6599, Loss: 0.49724435806274414\n",
      "Epoch 1, iter 6649, Loss: 0.46973928809165955\n",
      "Epoch 1, iter 6699, Loss: 0.40244752168655396\n",
      "Epoch 1, iter 6749, Loss: 0.6960927844047546\n",
      "Epoch 1, iter 6799, Loss: 0.6703901290893555\n",
      "Epoch 1, iter 6849, Loss: 0.3451893627643585\n",
      "Epoch 1, iter 6899, Loss: 0.9348644018173218\n",
      "Epoch 1, iter 6949, Loss: 0.5330748558044434\n",
      "Epoch 1, iter 6999, Loss: 0.7467839121818542\n",
      "Epoch 1, iter 7049, Loss: 0.455165833234787\n",
      "Epoch 1, iter 7099, Loss: 0.4408777952194214\n",
      "Epoch 1, iter 7149, Loss: 0.44364723563194275\n",
      "Epoch 1, iter 7199, Loss: 1.15677809715271\n",
      "Epoch 1, iter 7249, Loss: 0.525337815284729\n",
      "Epoch 1, iter 7299, Loss: 0.8112031817436218\n",
      "Epoch 1, iter 7349, Loss: 0.4286497235298157\n",
      "Epoch 1, iter 7399, Loss: 0.5570491552352905\n",
      "Epoch 1, iter 7449, Loss: 0.3835654556751251\n",
      "Epoch 1, iter 7499, Loss: 0.5100639462471008\n",
      "Epoch 1, iter 7549, Loss: 0.5613664388656616\n",
      "Epoch 1, iter 7599, Loss: 0.5676849484443665\n",
      "Epoch 1, iter 7649, Loss: 0.5152788162231445\n",
      "Epoch 1, iter 7699, Loss: 0.5608341097831726\n",
      "Epoch 1, iter 7749, Loss: 1.1779799461364746\n",
      "Epoch 1, iter 7799, Loss: 0.5508455038070679\n",
      "Epoch 1, iter 7849, Loss: 0.5777753591537476\n",
      "Epoch 1, iter 7899, Loss: 0.45882031321525574\n",
      "Epoch 1, iter 7949, Loss: 0.4754669666290283\n",
      "Epoch 1, iter 7999, Loss: 0.6197246313095093\n",
      "Epoch 1, iter 8049, Loss: 0.44745585322380066\n",
      "Epoch 1, iter 8099, Loss: 0.6352418661117554\n",
      "Epoch 1, iter 8149, Loss: 0.5656026005744934\n",
      "Epoch 1, iter 8199, Loss: 0.4492414891719818\n",
      "Epoch 1, iter 8249, Loss: 0.4688359797000885\n",
      "Epoch 1, iter 8299, Loss: 0.7494032382965088\n",
      "Epoch 1, iter 8349, Loss: 1.0732836723327637\n",
      "Epoch 1, iter 8399, Loss: 0.806303083896637\n",
      "Epoch 1, iter 8449, Loss: 0.3599080741405487\n",
      "Epoch 1, iter 8499, Loss: 0.3013931214809418\n",
      "Epoch 1, iter 8549, Loss: 0.5256626605987549\n",
      "Epoch 1, iter 8599, Loss: 0.6085220575332642\n",
      "Epoch 1, iter 8649, Loss: 0.30817729234695435\n",
      "Epoch 1, iter 8699, Loss: 0.5415157079696655\n",
      "Epoch 1, iter 8749, Loss: 0.36323943734169006\n",
      "Epoch 1, iter 8799, Loss: 0.4086468815803528\n",
      "Epoch 1, iter 8849, Loss: 0.48809683322906494\n",
      "Epoch 1, iter 8899, Loss: 0.49047866463661194\n",
      "Epoch 1, iter 8949, Loss: 1.083860158920288\n",
      "Epoch 1, iter 8999, Loss: 0.3271395266056061\n",
      "Epoch 1, iter 9049, Loss: 0.42238450050354004\n",
      "Epoch 1, iter 9099, Loss: 0.26786568760871887\n",
      "Epoch 1, iter 9149, Loss: 0.5641554594039917\n",
      "Epoch 1, iter 9199, Loss: 0.8878674507141113\n",
      "Epoch 1, iter 9249, Loss: 0.35602787137031555\n",
      "Epoch 1, iter 9299, Loss: 0.5185374021530151\n",
      "Epoch 1, iter 9349, Loss: 0.5238265991210938\n",
      "Epoch 1, iter 9399, Loss: 0.43237823247909546\n",
      "Epoch 1, iter 9449, Loss: 0.2863785922527313\n",
      "Epoch 1, iter 9499, Loss: 0.3449312746524811\n",
      "Epoch 1, iter 9549, Loss: 0.2594788372516632\n",
      "Epoch 1, iter 9599, Loss: 0.8797675967216492\n",
      "Epoch 1, iter 9649, Loss: 0.24675533175468445\n",
      "Epoch 1, iter 9699, Loss: 0.5762588977813721\n",
      "Epoch 1, iter 9749, Loss: 0.4463930130004883\n",
      "Epoch 1, iter 9799, Loss: 0.717012345790863\n",
      "Epoch 1, iter 9849, Loss: 0.6285265684127808\n",
      "Epoch 1, iter 9899, Loss: 0.7853543162345886\n",
      "Epoch 1, iter 9949, Loss: 0.4842085540294647\n",
      "Epoch 1, iter 9999, Loss: 0.48675984144210815\n",
      "Epoch 1, iter 10049, Loss: 0.22412027418613434\n",
      "Epoch 1, iter 10099, Loss: 0.6983969211578369\n",
      "Epoch 1, iter 10149, Loss: 0.5395795702934265\n",
      "Epoch 1, iter 10199, Loss: 0.5216236710548401\n",
      "Epoch 1, iter 10249, Loss: 0.6413951516151428\n",
      "Epoch 1, iter 10299, Loss: 0.5272928476333618\n",
      "Epoch 1, iter 10349, Loss: 0.5756781697273254\n",
      "Epoch 1, iter 10399, Loss: 0.3975771963596344\n",
      "Epoch 1, iter 10449, Loss: 0.6457111239433289\n",
      "Epoch 1, iter 10499, Loss: 0.4156190752983093\n",
      "Epoch 1, iter 10549, Loss: 0.6887861490249634\n",
      "Epoch 1, iter 10599, Loss: 0.4721302390098572\n",
      "Epoch 1, iter 10649, Loss: 0.4616475999355316\n",
      "Epoch 1, iter 10749, Loss: 0.4870109558105469\n",
      "Epoch 1, iter 10799, Loss: 0.4528903067111969\n",
      "Epoch 1, iter 10849, Loss: 0.6945491433143616\n",
      "Epoch 1, iter 10899, Loss: 0.7176368832588196\n",
      "Epoch 1, iter 10949, Loss: 0.48355644941329956\n",
      "Epoch 1, iter 10999, Loss: 0.7421779632568359\n",
      "Epoch 1, iter 11049, Loss: 0.38444891571998596\n",
      "Epoch 1, iter 11099, Loss: 0.4159281551837921\n",
      "Epoch 1, iter 11149, Loss: 0.41810375452041626\n",
      "Epoch 1, iter 11199, Loss: 0.4496159255504608\n",
      "Epoch 1, iter 11249, Loss: 0.48001042008399963\n",
      "Epoch 1, iter 11299, Loss: 0.4814954102039337\n",
      "Epoch 1, iter 11349, Loss: 0.6593631505966187\n",
      "Epoch 1, iter 11399, Loss: 0.32246875762939453\n",
      "Epoch 1, iter 11449, Loss: 0.3741908371448517\n",
      "Epoch 1, iter 11499, Loss: 0.36032670736312866\n",
      "Epoch 1, iter 11549, Loss: 0.5536816120147705\n",
      "Epoch 1, iter 11599, Loss: 0.5873249173164368\n",
      "Epoch 1, iter 11649, Loss: 0.44044458866119385\n",
      "Epoch 1, iter 11699, Loss: 0.4085703492164612\n",
      "Epoch 1, iter 11749, Loss: 0.4723142385482788\n",
      "Epoch 1, iter 11799, Loss: 0.48492202162742615\n",
      "Epoch 1, iter 11849, Loss: 0.5371261239051819\n",
      "Epoch 1, iter 11899, Loss: 0.5357022285461426\n",
      "Epoch 1, iter 11949, Loss: 0.5463935732841492\n",
      "Epoch 1, iter 11999, Loss: 0.6238596439361572\n",
      "Epoch 1, iter 12049, Loss: 0.5716202855110168\n",
      "Epoch 1, iter 12099, Loss: 0.48773688077926636\n",
      "Epoch 1, iter 12149, Loss: 0.38733917474746704\n",
      "Epoch 1, iter 12199, Loss: 0.4491721987724304\n",
      "Epoch 1, iter 12249, Loss: 0.5991108417510986\n",
      "Epoch 1, iter 12299, Loss: 0.4511454999446869\n",
      "Epoch 1, iter 12349, Loss: 0.4736548364162445\n",
      "Epoch 1, iter 12399, Loss: 0.5043595433235168\n",
      "Epoch 1, iter 12449, Loss: 0.24921149015426636\n",
      "Epoch 1, iter 12499, Loss: 0.6136733889579773\n",
      "Epoch 1, iter 12549, Loss: 0.34853240847587585\n",
      "Epoch 1, iter 12599, Loss: 0.48390236496925354\n",
      "Epoch 1, iter 12649, Loss: 0.4543877840042114\n",
      "Epoch 1, iter 12699, Loss: 0.3550724983215332\n",
      "Epoch 1, iter 12749, Loss: 0.5426774621009827\n",
      "Epoch 1, iter 12799, Loss: 0.3952428996562958\n",
      "Epoch 1, iter 12849, Loss: 0.39713799953460693\n",
      "Epoch 1, iter 12899, Loss: 0.5450232028961182\n",
      "Epoch 1, iter 12949, Loss: 0.37484219670295715\n",
      "Epoch 1, iter 12999, Loss: 0.6502382755279541\n",
      "Epoch 1, iter 13049, Loss: 0.5390260219573975\n",
      "Epoch 1, iter 13099, Loss: 0.4892347455024719\n",
      "Epoch 1, iter 13149, Loss: 0.40763887763023376\n",
      "Epoch 1, iter 13199, Loss: 0.4128453731536865\n",
      "Epoch 1, iter 13249, Loss: 0.9366653561592102\n",
      "Epoch 1, iter 13299, Loss: 0.4623536169528961\n",
      "Epoch 1, iter 13349, Loss: 0.48702624440193176\n",
      "Epoch 1, iter 13399, Loss: 0.4636775553226471\n",
      "Epoch 1, iter 13449, Loss: 0.49984848499298096\n",
      "Epoch 1, iter 13499, Loss: 0.5499081015586853\n",
      "Epoch 1, iter 13549, Loss: 0.5302959084510803\n",
      "Epoch 1, iter 13599, Loss: 0.4349346160888672\n",
      "Epoch 1, iter 13649, Loss: 0.41960409283638\n",
      "Epoch 1, iter 13699, Loss: 0.67790687084198\n",
      "Epoch 1, iter 13749, Loss: 0.5351594090461731\n",
      "Epoch 1, iter 13799, Loss: 0.6144745945930481\n",
      "Epoch 1, iter 13849, Loss: 0.4176826477050781\n",
      "Epoch 1, iter 13899, Loss: 0.21118378639221191\n",
      "Epoch 1, iter 13949, Loss: 0.4454456567764282\n",
      "Epoch 1, iter 13999, Loss: 0.46540045738220215\n",
      "Epoch 1, iter 14049, Loss: 0.40801697969436646\n",
      "Epoch 1, iter 14099, Loss: 0.9979115724563599\n",
      "Epoch 1, iter 14149, Loss: 0.30207404494285583\n",
      "Epoch 1, iter 14199, Loss: 0.4945093095302582\n",
      "Epoch 1, iter 14249, Loss: 0.7119024991989136\n",
      "Epoch 1, iter 14299, Loss: 0.4488457441329956\n",
      "Epoch 1, iter 14349, Loss: 0.6182435750961304\n",
      "Epoch 1, iter 14399, Loss: 0.4677220284938812\n",
      "Epoch 1, iter 14449, Loss: 0.4087035357952118\n",
      "Epoch 1, iter 14499, Loss: 1.048825740814209\n",
      "Epoch 1, iter 14549, Loss: 0.4911051392555237\n",
      "Epoch 1, iter 14599, Loss: 0.4915575683116913\n",
      "Epoch 1, iter 14649, Loss: 0.30046430230140686\n",
      "Epoch 1, iter 14699, Loss: 0.4385083317756653\n",
      "Epoch 1, iter 14749, Loss: 1.1807423830032349\n",
      "Epoch 1, iter 14799, Loss: 0.5214922428131104\n",
      "Epoch 1, iter 14849, Loss: 0.4983356297016144\n",
      "Epoch 1, iter 14899, Loss: 0.39507007598876953\n",
      "Epoch 1, iter 14949, Loss: 0.48845359683036804\n",
      "Epoch 1, iter 14999, Loss: 0.2837924063205719\n",
      "Epoch 1, iter 15049, Loss: 0.8146432042121887\n",
      "Epoch 1, iter 15099, Loss: 0.4554211497306824\n",
      "Epoch 1, iter 15149, Loss: 0.4286242723464966\n",
      "Epoch 1, iter 15199, Loss: 0.474740207195282\n",
      "Epoch 1, iter 15249, Loss: 0.3306503891944885\n",
      "Epoch 1, iter 15299, Loss: 0.4137367606163025\n",
      "Epoch 1, iter 15349, Loss: 0.6922922134399414\n",
      "Epoch 1, iter 15399, Loss: 0.5414167046546936\n",
      "Epoch 1, iter 15449, Loss: 0.35694006085395813\n",
      "Epoch 1, iter 15499, Loss: 0.7253087162971497\n",
      "Epoch 1, iter 15549, Loss: 0.5955715179443359\n",
      "Epoch 1, iter 15599, Loss: 0.6924446225166321\n",
      "Epoch 1, iter 15649, Loss: 0.5913186073303223\n",
      "Epoch 1, iter 15699, Loss: 0.5844463109970093\n",
      "Epoch 1, iter 15749, Loss: 0.5139960050582886\n",
      "Epoch 1, iter 15799, Loss: 0.5787766575813293\n",
      "Epoch 1, iter 15849, Loss: 0.6490830779075623\n",
      "Epoch 1, iter 15899, Loss: 0.5081668496131897\n",
      "Epoch 1, iter 15949, Loss: 0.45147252082824707\n",
      "Epoch 1, iter 15999, Loss: 0.45459502935409546\n",
      "Epoch 1, iter 16049, Loss: 0.5689467191696167\n",
      "Epoch 1, iter 16099, Loss: 0.5833981037139893\n",
      "Epoch 1, iter 16149, Loss: 0.6352296471595764\n",
      "Epoch 1, iter 16199, Loss: 0.5339817404747009\n",
      "Epoch 1, iter 16249, Loss: 0.47534021735191345\n",
      "Epoch 1, iter 16299, Loss: 0.3649521768093109\n",
      "Epoch 1, iter 16349, Loss: 0.5009918212890625\n",
      "Epoch 1, iter 16399, Loss: 0.841126024723053\n",
      "Epoch 1, iter 16449, Loss: 0.39880526065826416\n",
      "Epoch 1, iter 16499, Loss: 0.48172569274902344\n",
      "Epoch 1, iter 16549, Loss: 0.4299085736274719\n",
      "Epoch 1, iter 16599, Loss: 0.5887877941131592\n",
      "Epoch 1, iter 16649, Loss: 0.32056725025177\n",
      "Epoch 1, iter 16699, Loss: 0.4583275318145752\n",
      "Epoch 1, iter 16749, Loss: 0.5258041024208069\n",
      "Epoch 1, iter 16799, Loss: 0.8984687924385071\n",
      "Epoch 1, iter 16849, Loss: 1.119257926940918\n",
      "Epoch 1, iter 16899, Loss: 0.5229347944259644\n",
      "Epoch 1, iter 16949, Loss: 0.5490875244140625\n",
      "Epoch 1, iter 16999, Loss: 0.4250674247741699\n",
      "Epoch 1, iter 17049, Loss: 0.5404412150382996\n",
      "Epoch 1, iter 17099, Loss: 0.7760878801345825\n",
      "Epoch 1, iter 17149, Loss: 0.4481006860733032\n",
      "Epoch 1, iter 17199, Loss: 0.48803454637527466\n",
      "Epoch 1, iter 17249, Loss: 0.23219579458236694\n",
      "Epoch 1, iter 17299, Loss: 0.3924061954021454\n",
      "Epoch 1, iter 17349, Loss: 0.5241193771362305\n",
      "Epoch 1, iter 17399, Loss: 0.6610113978385925\n",
      "Epoch 1, iter 17449, Loss: 0.46530216932296753\n",
      "Epoch 1, iter 17499, Loss: 0.4836755394935608\n",
      "Epoch 1, iter 17549, Loss: 0.4510023593902588\n",
      "Epoch 1, iter 17599, Loss: 0.6759055852890015\n",
      "Epoch 1, iter 17649, Loss: 0.43574637174606323\n",
      "Epoch 1, iter 17699, Loss: 0.6719537377357483\n",
      "Epoch 1, iter 17749, Loss: 0.3458125591278076\n",
      "Epoch 1, iter 17799, Loss: 0.2915322184562683\n",
      "Epoch 1, iter 17849, Loss: 0.37252116203308105\n",
      "Epoch 1, iter 17899, Loss: 0.42159131169319153\n",
      "Epoch 1, iter 17949, Loss: 0.42481669783592224\n",
      "Epoch 1, iter 17999, Loss: 0.4450947344303131\n",
      "Epoch 1, iter 18049, Loss: 0.3909890949726105\n",
      "Epoch 1, iter 18099, Loss: 0.6365718245506287\n",
      "Epoch 1, iter 18149, Loss: 0.41497907042503357\n",
      "Epoch 1, iter 18199, Loss: 0.49946728348731995\n",
      "Epoch 1, iter 18249, Loss: 0.5808010101318359\n",
      "Epoch 1, iter 18299, Loss: 0.4586760103702545\n",
      "Epoch 1, iter 18349, Loss: 0.32198891043663025\n",
      "Epoch 1, iter 18399, Loss: 0.5525920987129211\n",
      "Epoch 1, iter 18449, Loss: 0.3036528527736664\n",
      "Epoch 1, iter 18499, Loss: 0.34318941831588745\n",
      "Epoch 1, iter 18549, Loss: 0.36441904306411743\n",
      "Epoch 1, iter 18599, Loss: 0.44645950198173523\n",
      "Epoch 1, iter 18649, Loss: 0.3598591983318329\n",
      "Epoch 1, iter 18699, Loss: 0.5903543829917908\n",
      "Epoch 1, iter 18749, Loss: 0.40728288888931274\n",
      "Epoch 1, iter 18799, Loss: 0.4548740088939667\n",
      "Epoch 1, iter 18849, Loss: 0.7617127299308777\n",
      "Epoch 1, iter 18899, Loss: 0.36053767800331116\n",
      "Epoch 1, iter 18949, Loss: 0.46292224526405334\n",
      "Epoch 1, iter 18999, Loss: 0.695456862449646\n",
      "Epoch 1, iter 19049, Loss: 0.4444076418876648\n",
      "Epoch 1, iter 19099, Loss: 0.24886776506900787\n",
      "Epoch 1, iter 19149, Loss: 0.3985143303871155\n",
      "Epoch 1, iter 19199, Loss: 0.6678828001022339\n",
      "Epoch 1, iter 19249, Loss: 1.0412769317626953\n",
      "Epoch 1, iter 19299, Loss: 0.40019193291664124\n",
      "Epoch 1, iter 19349, Loss: 0.910286545753479\n",
      "Epoch 1, iter 19399, Loss: 0.8972539305686951\n",
      "Epoch 1, iter 19449, Loss: 0.699773371219635\n",
      "Epoch 1, iter 19499, Loss: 0.5403860211372375\n",
      "Epoch 1, iter 19549, Loss: 0.42979973554611206\n",
      "Epoch 1, iter 19599, Loss: 0.6086570620536804\n",
      "Epoch 1, iter 19649, Loss: 0.5679246783256531\n",
      "Epoch 1, iter 19699, Loss: 0.5574814677238464\n",
      "Epoch 1, iter 19749, Loss: 0.7106163501739502\n",
      "Epoch 1, iter 19799, Loss: 0.6173287034034729\n",
      "Epoch 1, iter 19849, Loss: 0.4307824373245239\n",
      "Epoch 1, iter 19899, Loss: 0.33494555950164795\n",
      "Epoch 1, iter 19949, Loss: 0.5813589692115784\n",
      "Epoch 1, iter 19999, Loss: 0.5473554134368896\n",
      "Epoch 1, iter 20049, Loss: 0.5786235928535461\n",
      "Epoch 1, iter 20099, Loss: 0.3729066848754883\n",
      "Epoch 1, iter 20149, Loss: 0.28402435779571533\n",
      "Epoch 1, iter 20199, Loss: 0.29238390922546387\n",
      "Epoch 1, iter 20249, Loss: 0.46344563364982605\n",
      "Epoch 1, iter 20299, Loss: 0.37522295117378235\n",
      "Epoch 1, iter 20349, Loss: 0.371530681848526\n",
      "Epoch 1, iter 20399, Loss: 0.33756330609321594\n",
      "Epoch 1, iter 20449, Loss: 0.9239720702171326\n",
      "Epoch 1, iter 20499, Loss: 0.4708491563796997\n",
      "Epoch 1, iter 20549, Loss: 0.42001527547836304\n",
      "Epoch 1, iter 20599, Loss: 0.36276403069496155\n",
      "Epoch 1, iter 20649, Loss: 0.368743360042572\n",
      "Epoch 1, iter 20699, Loss: 0.4651585519313812\n",
      "Epoch 1, iter 20749, Loss: 0.4650946259498596\n",
      "Epoch 1, iter 20799, Loss: 0.606928288936615\n",
      "Epoch 1, iter 20849, Loss: 0.5596572160720825\n",
      "Epoch 1, iter 20899, Loss: 0.294989675283432\n",
      "Epoch 1, iter 20949, Loss: 0.6201750040054321\n",
      "Epoch 1, iter 20999, Loss: 0.5126871466636658\n",
      "Epoch 1, iter 21049, Loss: 0.6758185029029846\n",
      "Epoch 1, iter 21099, Loss: 0.5652550458908081\n",
      "Epoch 1, iter 21149, Loss: 1.039006233215332\n",
      "Epoch 1, iter 21199, Loss: 0.3413277268409729\n",
      "Epoch 1, iter 21249, Loss: 0.450190007686615\n",
      "Epoch 1, iter 21299, Loss: 0.9402927160263062\n",
      "Epoch 1, iter 21349, Loss: 0.4239976108074188\n",
      "Epoch 1, iter 21399, Loss: 0.4193718135356903\n",
      "Epoch 1, iter 21449, Loss: 0.29953113198280334\n",
      "Epoch 1, iter 21499, Loss: 0.43580615520477295\n",
      "Epoch 1, iter 21549, Loss: 0.5083497166633606\n",
      "Epoch 1, iter 21599, Loss: 0.5041278600692749\n",
      "Epoch 1, iter 21649, Loss: 0.4822164475917816\n",
      "Epoch 1, iter 21699, Loss: 0.41728323698043823\n",
      "Epoch 1, iter 21749, Loss: 0.5509250164031982\n",
      "Epoch 1, iter 21799, Loss: 0.42467769980430603\n",
      "Epoch 1, iter 21849, Loss: 0.48956239223480225\n",
      "Epoch 1, iter 21899, Loss: 0.4364137351512909\n",
      "Epoch 1, iter 21949, Loss: 0.35061416029930115\n",
      "Epoch 1, iter 21999, Loss: 0.4591214656829834\n",
      "Epoch 1, iter 22049, Loss: 0.5740695595741272\n",
      "Epoch 1, iter 22099, Loss: 0.508955180644989\n",
      "Epoch 1, iter 22149, Loss: 0.4526319205760956\n",
      "Epoch 1, iter 22199, Loss: 0.6612211465835571\n",
      "Epoch 1, iter 22249, Loss: 0.3007740378379822\n",
      "Epoch 1, iter 22299, Loss: 0.3628166913986206\n",
      "Epoch 1, iter 22349, Loss: 0.38378629088401794\n",
      "Epoch 1, iter 22399, Loss: 0.354465126991272\n",
      "Epoch 1, iter 22449, Loss: 0.39485833048820496\n",
      "Epoch 1, iter 22499, Loss: 0.5114896893501282\n",
      "Epoch 1, iter 22549, Loss: 0.5157807469367981\n",
      "Epoch 1, iter 22599, Loss: 0.5817335844039917\n",
      "Epoch 1, iter 22649, Loss: 0.4792875349521637\n",
      "Epoch 1, iter 22699, Loss: 0.4823263883590698\n",
      "Epoch 1, iter 22749, Loss: 0.5256558060646057\n",
      "Epoch 1, iter 22799, Loss: 0.42263394594192505\n",
      "Epoch 1, iter 22849, Loss: 0.3515787720680237\n",
      "Epoch 1, iter 22899, Loss: 0.7328521609306335\n",
      "Epoch 1, iter 22949, Loss: 0.5316333174705505\n",
      "Epoch 1, iter 22999, Loss: 0.30043545365333557\n",
      "Epoch 1, iter 23049, Loss: 0.5161182880401611\n",
      "Epoch 1, iter 23099, Loss: 0.6147794127464294\n",
      "Epoch 1, iter 23149, Loss: 0.5367518663406372\n",
      "Epoch 1, iter 23199, Loss: 0.5803971290588379\n",
      "Epoch 1, iter 23249, Loss: 0.37121692299842834\n",
      "Epoch 1, iter 23299, Loss: 0.3971558213233948\n",
      "Epoch 1, iter 23349, Loss: 0.28810495138168335\n",
      "Epoch 1, iter 23399, Loss: 0.45270583033561707\n",
      "Epoch 1, iter 23449, Loss: 0.5335826873779297\n",
      "Epoch 1, iter 23499, Loss: 0.46416613459587097\n",
      "Epoch 1, iter 23549, Loss: 0.48225533962249756\n",
      "Epoch 1, iter 23599, Loss: 0.36873695254325867\n",
      "Epoch 1, iter 23649, Loss: 0.390213280916214\n",
      "Epoch 1, iter 23699, Loss: 0.32959944009780884\n",
      "Epoch 1, iter 23749, Loss: 0.8608492016792297\n",
      "Epoch 1, iter 23799, Loss: 0.5320635437965393\n",
      "Epoch 1, iter 23849, Loss: 0.2755299508571625\n",
      "Epoch 1, iter 23899, Loss: 0.47592976689338684\n",
      "Epoch 1, iter 23949, Loss: 0.3554219603538513\n",
      "Epoch 1, iter 23999, Loss: 0.36596930027008057\n",
      "Epoch 1, iter 24049, Loss: 0.308322548866272\n",
      "Epoch 1, iter 24099, Loss: 0.37769755721092224\n",
      "Epoch 1, iter 24149, Loss: 0.38423916697502136\n",
      "Epoch 1, iter 24199, Loss: 0.25022655725479126\n",
      "Epoch 1, iter 24249, Loss: 0.5640965700149536\n",
      "Epoch 1, iter 24299, Loss: 0.31266242265701294\n",
      "Epoch 1, iter 24349, Loss: 0.3877889811992645\n",
      "Epoch 1, iter 24399, Loss: 0.4608120322227478\n",
      "Epoch 1, iter 24449, Loss: 0.36243104934692383\n",
      "Epoch 1, iter 24499, Loss: 0.8535250425338745\n",
      "Epoch 1, iter 24549, Loss: 0.5410448312759399\n",
      "Epoch 1, iter 24599, Loss: 0.2677982449531555\n",
      "Epoch 1, iter 24649, Loss: 0.7620446681976318\n",
      "Epoch 1, iter 24699, Loss: 0.2891354262828827\n",
      "Epoch 1, iter 24749, Loss: 0.3890286087989807\n",
      "Epoch 1, iter 24799, Loss: 0.4077194333076477\n",
      "Epoch 1, iter 24849, Loss: 0.46321338415145874\n",
      "Epoch 1, iter 24899, Loss: 0.40945887565612793\n",
      "Epoch 1, iter 24949, Loss: 0.6449191570281982\n",
      "Epoch 1, iter 24999, Loss: 0.3355913758277893\n",
      "Epoch 1, iter 25049, Loss: 0.4131660461425781\n",
      "Epoch 1, iter 25099, Loss: 0.4951544404029846\n",
      "Epoch 1, iter 25149, Loss: 0.6010679602622986\n",
      "Epoch 1, iter 25199, Loss: 0.7546350955963135\n",
      "Epoch 1, iter 25249, Loss: 0.23390167951583862\n",
      "Epoch 1, iter 25299, Loss: 0.41392385959625244\n",
      "Epoch 1, iter 25349, Loss: 0.9156453609466553\n",
      "Epoch 1, iter 25399, Loss: 0.4599466621875763\n",
      "Epoch 1, iter 25449, Loss: 0.6037635803222656\n",
      "Epoch 1, iter 25499, Loss: 0.3594077229499817\n",
      "Epoch 1, iter 25549, Loss: 0.43203800916671753\n",
      "Epoch 1, iter 25599, Loss: 0.7027798891067505\n",
      "Epoch 1, iter 25649, Loss: 0.48244941234588623\n",
      "Epoch 1, iter 25699, Loss: 0.634736955165863\n",
      "Epoch 1, iter 25749, Loss: 0.6853585243225098\n",
      "Epoch 1, iter 25799, Loss: 0.4029219448566437\n",
      "Epoch 1, iter 25849, Loss: 0.4011194705963135\n",
      "Epoch 1, iter 25899, Loss: 0.4585818648338318\n",
      "Epoch 1, iter 25949, Loss: 0.4339938759803772\n",
      "Epoch 1, iter 25999, Loss: 0.34321683645248413\n",
      "Epoch 1, iter 26049, Loss: 0.3867975175380707\n",
      "Epoch 1, iter 26099, Loss: 0.5168933868408203\n",
      "Epoch 1, iter 26149, Loss: 0.4324333965778351\n",
      "Epoch 1, iter 26199, Loss: 0.2991330027580261\n",
      "Epoch 1, iter 26249, Loss: 0.3389105796813965\n",
      "Epoch 1, iter 26299, Loss: 0.4445664584636688\n",
      "Epoch 1, iter 26349, Loss: 0.5697652697563171\n",
      "Epoch 1, iter 26399, Loss: 0.6271389722824097\n"
     ]
    }
   ],
   "source": [
    "batch_size = 4\n",
    "max_seq_length = 2048\n",
    "# dim = 1\n",
    "# new_head = DefaultHead(model.config.hidden_size, tokenizer.vocab_size)\n",
    "dim = 2\n",
    "new_head = CPHead(model.config.hidden_size, tokenizer.vocab_size, n_tokens=dim)\n",
    "combined_model = CombinedModel(model, new_head).to(device)\n",
    "\n",
    "\n",
    "# Define loss and optimizer\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "optimizer = optim.Adam(combined_model.parameters(), lr=1e-4)\n",
    "\n",
    "\n",
    "# Training loop\n",
    "for epoch in tqdm(range(10)):  # Example number of epochs\n",
    "    ds.set_epoch(epoch)\n",
    "    i = 0\n",
    "    for batch in batchify(ds, batch_size):\n",
    "        inputs = tokenizer([example['code'] for example in batch], return_tensors='pt', padding=True, truncation=True, max_length=max_seq_length)\n",
    "        labels, seq_len = shift_batch(inputs['input_ids'], dim=dim)\n",
    "        labels = [elem.to(device) for elem in labels]\n",
    "        xs, attn_mask = inputs['input_ids'], inputs['attention_mask']\n",
    "        xs = xs[:, :seq_len]\n",
    "        attn_mask = attn_mask[:, :seq_len]\n",
    "        \n",
    "        # Forward pass\n",
    "        logits, loss = combined_model(xs.to(device), attention_mask=attn_mask.to(device), targets=labels)\n",
    "        #loss = criterion(logits.view(-1, logits.size(-1)), labels.view(-1))\n",
    "\n",
    "        # Backward pass and optimization\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        i += 1\n",
    "\n",
    "        if i % 50 == 49:\n",
    "            lss.append(loss.item())\n",
    "            print(f\"Epoch {epoch + 1}, iter {i}, Loss: {loss.item()}\")\n",
    "            \n",
    "    print(f\"Epoch {epoch + 1}, Loss: {loss.item()}\")\n",
    "\n",
    "print(\"Training complete!\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b87d3740-aac9-40ce-aa49-15097ee17788",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Total element count across all parameter tensors of `model`.\n",
    "sum(param.numel() for param in model.parameters())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8dc9c0cd-be18-4c9a-8097-a3da4729fde5",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
