{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Select the GPU before torch is imported, so CUDA is initialised with only device 1 visible.\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = '1'\n",
    "\n",
    "# Working directory: presumably where kv_dataset_utils lives; ./results and ./logs are written here.\n",
    "# Set the TEST_TIME_GD_DIR environment variable to run on another machine.\n",
    "DATA_DIR = os.environ.get('TEST_TIME_GD_DIR', '/home/jovyan/USR/data/test_time_gd/')\n",
    "os.chdir(DATA_DIR)\n",
    "\n",
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "7!IU:pk!OSJ|u1XXdo0cZuzren|?!IU:pk!|\n",
      "decode(tokenizer.encode):\n",
      "7!IU:pk!OSJ|u1XXdo0cZuzren|?!IU:pk!|[PAD][PAD][PAD][PAD]\n"
     ]
    }
   ],
   "source": [
    "from kv_dataset_utils import create_tokenizer, generate_sequence\n",
    "\n",
    "# Tokenizer shared by every cell below.\n",
    "tokenizer = create_tokenizer()\n",
    "\n",
    "# Generation parameters for this sanity-check sample (the training cell redefines them).\n",
    "num_kv_pairs = 1\n",
    "k_length = 2\n",
    "v_length = 2\n",
    "n_segments = 2\n",
    "min_segment_len = 4\n",
    "max_segment_len = 16\n",
    "\n",
    "sample = generate_sequence(\n",
    "    num_kv_pairs=num_kv_pairs,\n",
    "    k_length=k_length,\n",
    "    v_length=v_length,\n",
    "    n_segments=n_segments,\n",
    "    min_segment_len=min_segment_len,\n",
    "    max_segment_len=max_segment_len,\n",
    ")\n",
    "full_text = sample['sequence'] + sample['target']\n",
    "print(full_text)\n",
    "\n",
    "# Round-trip through the tokenizer; decode() separates tokens with spaces, so strip them.\n",
    "print('decode(tokenizer.encode):')\n",
    "encoded = tokenizer(full_text, padding=True, pad_to_multiple_of=8)\n",
    "print(tokenizer.decode(encoded.input_ids).replace(' ', ''))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import GPT2Config\n",
    "\n",
    "# Token -> id mapping of the notebook's tokenizer, used to wire the special tokens into the config.\n",
    "vocab = tokenizer.vocab\n",
    "\n",
    "# Small GPT-2: 4 layers, 4 heads, 128-dim embeddings (remaining fields from the 'gpt2' config).\n",
    "config = GPT2Config.from_pretrained('gpt2')\n",
    "config.n_layer = 4\n",
    "config.n_head = 4\n",
    "config.n_embd = 128\n",
    "# Embedding-table size; every tokenizer id must be smaller than it.\n",
    "config.vocab_size = 128\n",
    "assert max(vocab.values()) < config.vocab_size, 'tokenizer ids exceed config.vocab_size'\n",
    "config.pad_token_id = vocab['[PAD]']\n",
    "config.bos_token_id = vocab['[BOS]']\n",
    "config.eos_token_id = vocab['[EOS]']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "\n",
    "class KVDataset(Dataset):\n",
    "    \"\"\"In-memory, map-style dataset of pre-generated key-value samples.\n",
    "\n",
    "    All `num_samples` samples are produced eagerly in the constructor by calling\n",
    "    `generate_sequence(**gen_params)` once per sample.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, num_samples, **gen_params):\n",
    "        samples = []\n",
    "        for _ in range(num_samples):\n",
    "            samples.append(generate_sequence(**gen_params))\n",
    "        self.samples = samples\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.samples)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        \"\"\"Return the prompt string ('input_seq') and its target continuation ('target_seq').\"\"\"\n",
    "        entry = self.samples[idx]\n",
    "        return {'input_seq': entry['sequence'], 'target_seq': entry['target']}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_dataset = KVDataset(1000,\n",
    "                          num_kv_pairs=num_kv_pairs,\n",
    "                          k_length=k_length,\n",
    "                          v_length=v_length,\n",
    "                          n_segments=n_segments,\n",
    "                          min_segment_len=min_segment_len,\n",
    "                          max_segment_len=max_segment_len)\n",
    "\n",
    "def collate_fn(batch):\n",
    "    \"\"\"Tokenize prompt+target pairs and build LM labels that supervise only the target.\n",
    "\n",
    "    Returns a dict with the raw strings ('input_seq', 'target_seq'), the padded\n",
    "    `input_ids`, `labels` (target token ids, -100 everywhere else) and the 0/1\n",
    "    `labels_mask` marking the target positions.\n",
    "    \"\"\"\n",
    "    input_seq = [item['input_seq'] for item in batch]\n",
    "    target_seq = [item['target_seq'] for item in batch]\n",
    "    seq = [prompt + target for prompt, target in zip(input_seq, target_seq)]\n",
    "    input_ids = tokenizer(seq, return_tensors=\"pt\", add_special_tokens=False,\n",
    "                          padding=True, pad_to_multiple_of=8).input_ids\n",
    "\n",
    "    # labels_mask: 1 on target positions, 0 on prompt and padding positions.\n",
    "    # add_special_tokens=False, so no BOS/EOS is inserted and the target starts right\n",
    "    # after the prompt. Character lengths are used as token offsets, which assumes one\n",
    "    # token per character (holds for the batch printed below) -- TODO confirm.\n",
    "    labels_mask = torch.zeros_like(input_ids)\n",
    "    for i, (prompt, target) in enumerate(zip(input_seq, target_seq)):\n",
    "        start = len(prompt)\n",
    "        labels_mask[i, start:start + len(target)] = 1\n",
    "\n",
    "    # -100 is the LM loss ignore_index, so only target tokens are supervised.\n",
    "    labels = input_ids.masked_fill(labels_mask == 0, -100)\n",
    "    return {\n",
    "        'input_seq': input_seq,\n",
    "        'target_seq': target_seq,\n",
    "        'input_ids': input_ids,\n",
    "        'labels': labels,\n",
    "        'labels_mask': labels_mask,\n",
    "    }\n",
    "\n",
    "batch_size = 2\n",
    "train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "input_seq ['!Tk:K2!hL|nRX21Ii|?!Tk:', 'MFkAnXE|pgLMgQ!v3:k4!V|?!v3:']\n",
      "target_seq ['K2!|', 'k4!|']\n",
      "input_ids tensor([[66, 49, 14, 68, 40, 58, 66, 11, 41, 69, 17, 47, 53, 58, 57, 38, 12, 69,\n",
      "         67, 66, 49, 14, 68, 40, 58, 66, 69,  0,  0,  0,  0,  0],\n",
      "        [42, 35, 14, 30, 17, 53, 34, 69, 19, 10, 41, 42, 10, 46, 66, 25, 59, 68,\n",
      "         14, 60, 66, 51, 69, 67, 66, 25, 59, 68, 14, 60, 66, 69]])\n",
      "labels tensor([[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,\n",
      "         -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,   40,\n",
      "           58,   66,   69, -100, -100, -100, -100, -100],\n",
      "        [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,\n",
      "         -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,\n",
      "         -100, -100, -100, -100,   14,   60,   66,   69]])\n",
      "labels_mask tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1,\n",
      "         1, 1, 1, 0, 0, 0, 0, 0],\n",
      "        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
      "         0, 0, 0, 0, 1, 1, 1, 1]])\n"
     ]
    }
   ],
   "source": [
    "# Peek at a single collated batch: labels should be -100 everywhere except the target tokens.\n",
    "batch = next(iter(train_dataloader), None)\n",
    "if batch is not None:\n",
    "    for key, value in batch.items():\n",
    "        print(key, value)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training with HuggingFace Trainer\n",
    "from typing import Optional\n",
    "from transformers import Trainer, TrainingArguments, EarlyStoppingCallback, GPT2LMHeadModel, set_seed\n",
    "import numpy as np\n",
    "\n",
    "# Seed python/numpy/torch before the model weights and the datasets below are created:\n",
    "# Trainer only seeds in its constructor, which runs after both. (Assumes generate_sequence\n",
    "# draws from python's or numpy's global RNG -- TODO confirm.)\n",
    "SEED = 42\n",
    "set_seed(SEED)\n",
    "\n",
    "model = GPT2LMHeadModel(config)\n",
    "\n",
    "# Task parameters for training. n_segments and min_segment_len differ from the\n",
    "# sanity-check cell above, and these assignments overwrite its globals.\n",
    "num_kv_pairs = 1\n",
    "k_length = 2\n",
    "v_length = 2\n",
    "n_segments = 4\n",
    "min_segment_len = 2\n",
    "max_segment_len = 16\n",
    "\n",
    "# Both splits use the same generator settings; train is generated first, then valid.\n",
    "gen_params = dict(\n",
    "    num_kv_pairs=num_kv_pairs,\n",
    "    k_length=k_length,\n",
    "    v_length=v_length,\n",
    "    n_segments=n_segments,\n",
    "    min_segment_len=min_segment_len,\n",
    "    max_segment_len=max_segment_len,\n",
    ")\n",
    "train_dataset = KVDataset(100000, **gen_params)\n",
    "valid_dataset = KVDataset(1000, **gen_params)\n",
    "\n",
    "def collate_fn(batch):\n",
    "    \"\"\"Collator used by the Trainer: tokenize prompt+target and supervise only the target.\n",
    "\n",
    "    Intentionally replaces the exploration `collate_fn` defined earlier. It returns only\n",
    "    `input_ids` and `labels` because the Trainer passes every returned key to\n",
    "    `model(**batch)`, and the raw strings / labels_mask are not model arguments.\n",
    "    \"\"\"\n",
    "    texts = [item['input_seq'] + item['target_seq'] for item in batch]\n",
    "    input_ids = tokenizer(texts, return_tensors=\"pt\", add_special_tokens=False,\n",
    "                          padding=True, pad_to_multiple_of=8).input_ids\n",
    "\n",
    "    # 1 on target positions, 0 on prompt and padding; character lengths double as token\n",
    "    # offsets (assumes one token per character; no BOS/EOS is added).\n",
    "    labels_mask = torch.zeros_like(input_ids)\n",
    "    for row, item in enumerate(batch):\n",
    "        begin = len(item['input_seq'])\n",
    "        end = begin + len(item['target_seq'])\n",
    "        labels_mask[row, begin:end] = 1\n",
    "\n",
    "    # Non-target positions get -100 so the LM loss ignores them.\n",
    "    labels = input_ids.masked_fill(labels_mask == 0, -100)\n",
    "    return {'input_ids': input_ids, 'labels': labels}\n",
    "\n",
    "# Target sequences look like \"XX!|\": the '!' and '|' delimiters are trivially\n",
    "# predictable, so they are excluded from the accuracy metrics.\n",
    "ignore_token_ids = [tokenizer.convert_tokens_to_ids(t) for t in ['!', '|']]\n",
    "\n",
    "\n",
    "def compute_metrics(eval_pred, num_examples_to_print: int = 5):\n",
    "    \"\"\"Token-level accuracy and per-sample exact match over the target tokens.\n",
    "\n",
    "    Args:\n",
    "        eval_pred: transformers.EvalPrediction. `predictions` are the LM logits\n",
    "            (presumably (n_samples, seq_len, vocab) -- TODO confirm), `label_ids` hold -100\n",
    "            outside the target span, and `inputs` are the input_ids (available because\n",
    "            include_for_metrics=['inputs']; they may contain -100 padding).\n",
    "        num_examples_to_print: number of decoded (input, prediction, target) triples printed.\n",
    "\n",
    "    Returns:\n",
    "        {\"token_accuracy\": float, \"exact_match\": float}; a metric is 0.0 instead of NaN\n",
    "        when there is nothing to score, so best-model selection stays well defined.\n",
    "    \"\"\"\n",
    "    logits, labels, inputs = eval_pred.predictions, eval_pred.label_ids, eval_pred.inputs\n",
    "\n",
    "    # Causal LM: the logits at position t predict token t+1, so shift to align them.\n",
    "    logits = logits[..., :-1, :]\n",
    "    labels = labels[..., 1:]\n",
    "    predictions = np.argmax(logits, axis=-1)\n",
    "\n",
    "    # Score supervised positions only (labels != -100), minus the delimiter tokens.\n",
    "    mask = (labels != -100) & ~np.isin(labels, ignore_token_ids)\n",
    "    correct = predictions == labels\n",
    "\n",
    "    accuracy = correct[mask].mean() if mask.any() else 0.0\n",
    "\n",
    "    # A sample is an exact match when all of its scored tokens are correct;\n",
    "    # samples without any scored token are left out.\n",
    "    scored_rows = mask.any(axis=-1)\n",
    "    row_exact = np.all(correct | ~mask, axis=-1)\n",
    "    exact_match = row_exact[scored_rows].mean() if scored_rows.any() else 0.0\n",
    "\n",
    "    # Print a few decoded examples. np.where returns new arrays, so the arrays held by\n",
    "    # eval_pred are left untouched; -100 becomes the pad id, which skip_special_tokens drops.\n",
    "    if inputs is not None:\n",
    "        pad_id = tokenizer.pad_token_id\n",
    "        n = num_examples_to_print\n",
    "        for pred, label, inp in zip(predictions[:n], labels[:n], inputs[:n]):\n",
    "            supervised = label != -100\n",
    "            print('i:', tokenizer.decode(np.where(inp == -100, pad_id, inp), skip_special_tokens=True).replace(' ', ''))\n",
    "            print('p:', tokenizer.decode(pred[supervised], skip_special_tokens=True).replace(' ', ''))\n",
    "            print('t:', tokenizer.decode(np.where(supervised, label, pad_id), skip_special_tokens=True).replace(' ', ''))\n",
    "            print('-' * 50)\n",
    "\n",
    "    return {\n",
    "        \"token_accuracy\": float(accuracy),\n",
    "        \"exact_match\": float(exact_match),\n",
    "    }\n",
    "\n",
    "\n",
    "class CustomTrainer(Trainer):\n",
    "    \"\"\"Trainer with a stretched LR schedule and early-stopping patience in the logs.\"\"\"\n",
    "\n",
    "    def create_scheduler(self, num_training_steps: int, optimizer: torch.optim.Optimizer = None):\n",
    "        # Build the schedule for num_training_steps / 0.9 steps so training ends at 90% of it:\n",
    "        # a linear decay then finishes at lr/10 instead of 0 (no effect on constant schedules).\n",
    "        stretched_steps = int(num_training_steps / 0.9)\n",
    "        return super().create_scheduler(stretched_steps, optimizer)\n",
    "\n",
    "    def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:\n",
    "        # Attach the first EarlyStoppingCallback's patience counter to every log entry.\n",
    "        early_stopping = next(\n",
    "            (cb for cb in self.callback_handler.callbacks if isinstance(cb, EarlyStoppingCallback)),\n",
    "            None,\n",
    "        )\n",
    "        if early_stopping is not None:\n",
    "            logs['patience'] = early_stopping.early_stopping_patience_counter\n",
    "        return super().log(logs, start_time=start_time)\n",
    "\n",
    "\n",
    "# Evaluate and checkpoint every 100 steps; the best checkpoint by token_accuracy is\n",
    "# restored at the end (load_best_model_at_end needs save_steps to be a multiple of eval_steps).\n",
    "training_args = TrainingArguments(\n",
    "    output_dir=\"./results\",\n",
    "    logging_dir=\"./logs\",\n",
    "    # optimisation\n",
    "    num_train_epochs=5,\n",
    "    per_device_train_batch_size=64,\n",
    "    per_device_eval_batch_size=64,\n",
    "    learning_rate=1e-04,\n",
    "    weight_decay=0.001,\n",
    "    lr_scheduler_type='constant_with_warmup',\n",
    "    warmup_steps=0,\n",
    "    # logging / evaluation / checkpointing\n",
    "    logging_steps=100,\n",
    "    eval_strategy=\"steps\",\n",
    "    eval_steps=100,\n",
    "    save_strategy=\"steps\",\n",
    "    save_steps=100,\n",
    "    load_best_model_at_end=True,\n",
    "    metric_for_best_model=\"token_accuracy\",\n",
    "    greater_is_better=True,\n",
    "    # keep the 'input_seq'/'target_seq' fields so collate_fn receives them\n",
    "    remove_unused_columns=False,\n",
    "    include_num_input_tokens_seen=True,\n",
    "    # compute_metrics decodes the inputs for its printed examples\n",
    "    include_for_metrics=['inputs']\n",
    ")\n",
    "\n",
    "# Stop once token_accuracy has not improved for 15 consecutive evaluations.\n",
    "early_stopping = EarlyStoppingCallback(early_stopping_patience=15)\n",
    "\n",
    "trainer = CustomTrainer(\n",
    "    model=model,\n",
    "    args=training_args,\n",
    "    train_dataset=train_dataset,\n",
    "    eval_dataset=valid_dataset,\n",
    "    data_collator=collate_fn,\n",
    "    compute_metrics=compute_metrics,\n",
    "    callbacks=[early_stopping],\n",
    ")\n",
    "\n",
    "trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "70"
      ]
     },
     "execution_count": 34,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Tokenizer vocabulary size; every token id must stay below config.vocab_size (128).\n",
    "len(tokenizer.vocab)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "62"
      ]
     },
     "execution_count": 35,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# NOTE(review): assumes kv_dataset_utils defines ALPHABET -- confirm.\n",
    "from kv_dataset_utils import ALPHABET\n",
    "\n",
    "len(ALPHABET)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!\"#$%&\\'()*+,-./:;<=>?@[\\\\]^_`{|}~ \\t\\n\\r\\x0b\\x0c'"
      ]
     },
     "execution_count": 36,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import string\n",
    "\n",
    "string.printable"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
