{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from torch.utils.data import Dataset\n",
    "import transformers\n",
    "from typing import Dict, Optional, Sequence\n",
    "from transformers import AutoTokenizer,AutoModelForCausalLM\n",
    "import copy\n",
    "import torch\n",
    "from dataclasses import dataclass, field\n",
    "from transformers.generation.stopping_criteria import StoppingCriteria,StoppingCriteriaList,STOPPING_CRITERIA_INPUTS_DOCSTRING,add_start_docstrings\n",
    "import yaml\n",
    "import alfworld\n",
    "import alfworld.agents.environment\n",
    "import json\n",
    "import sys\n",
    "import numpy as np\n",
    "import random\n",
    "\n",
    "IGNORE_INDEX = -100\n",
    "\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = '0'   # specify the gpu id which is available\n",
    "\n",
    "model_name_or_path = 'MODEL_NAME_OR_PATH for the model used'\n",
    "\n",
    "base_model_name = \"codellama7b\"   # Change this base_model_name to the model that you use"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = transformers.AutoModelForCausalLM.from_pretrained(model_name_or_path, torch_dtype=torch.float16,device_map='auto')\n",
    "tokenizer = transformers.AutoTokenizer.from_pretrained(model_name_or_path,padding_side='left')\n",
    "if tokenizer.pad_token_id is None:\n",
    "    tokenizer.pad_token_id = tokenizer.eos_token_id"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open('succ_alfworld_examples_good_bad.json','r') as f:\n",
    "    succ_examples = json.load(f)\n",
    "with open('fail_alfworld_examples_train_good_bad.json','r') as f:\n",
    "    fail_examples = json.load(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "prefixes = {\n",
    "    'pick_and_place': 'put',\n",
    "    'pick_clean_then_place': 'clean',\n",
    "    'pick_heat_then_place': 'heat',\n",
    "    'pick_cool_then_place': 'cool',\n",
    "    'look_at_obj': 'examine',\n",
    "    'pick_two_obj': 'puttwo'\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_data_from_traj(traj):\n",
    "    data = list()\n",
    "\n",
    "    traj = traj.replace('> ','>')\n",
    "    traj = traj.strip().split('\\n')\n",
    "\n",
    "    assert (len(traj)-2)%3==0\n",
    "\n",
    "    for i in range(4,len(traj),3):\n",
    "        output = traj[i].split('==>')[-1]+'\\n'\n",
    "        if not traj[i].endswith('GOOD.') and not traj[i].endswith('BAD.'):\n",
    "            for j in range(i+1,len(traj)):\n",
    "                output += traj[j]+'\\n'\n",
    "                if traj[j].endswith('GOOD.') or traj[j].endswith('BAD.'):\n",
    "                    break\n",
    "        data.append({'input':'\\n'.join(traj[:i])+'\\n==>','output':output})\n",
    "    \n",
    "    return data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_examples = list()\n",
    "for k,v in prefixes.items():\n",
    "    for i in range(2):\n",
    "        train_examples.extend(get_data_from_traj(succ_examples[f'react_{v}_{i}']))\n",
    "    train_examples.extend(get_data_from_traj(fail_examples[f'react_{v}_0']))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(len(train_examples))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer) -> Dict:\n",
    "    \"\"\"Tokenize a list of strings.\"\"\"\n",
    "    tokenized_list = [\n",
    "        tokenizer(\n",
    "            text,\n",
    "            return_tensors=\"pt\",\n",
    "            padding=\"longest\",\n",
    "            # max_length=tokenizer.model_max_length,\n",
    "            truncation=True,\n",
    "        )\n",
    "        for text in strings\n",
    "    ]\n",
    "    input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list]\n",
    "    input_ids_lens = labels_lens = [\n",
    "        tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() for tokenized in tokenized_list\n",
    "    ]\n",
    "    return dict(\n",
    "        input_ids=input_ids,\n",
    "        labels=labels,\n",
    "        input_ids_lens=input_ids_lens,\n",
    "        labels_lens=labels_lens,\n",
    "    )\n",
    "\n",
    "def preprocess(\n",
    "    sources: Sequence[str],\n",
    "    targets: Sequence[str],\n",
    "    tokenizer: transformers.PreTrainedTokenizer,\n",
    ") -> Dict:\n",
    "    \"\"\"Preprocess the data by tokenizing.\"\"\"\n",
    "    examples = [s + t for s, t in zip(sources, targets)]\n",
    "    examples_tokenized, sources_tokenized = [_tokenize_fn(strings, tokenizer) for strings in (examples, sources)]\n",
    "    input_ids = examples_tokenized[\"input_ids\"]\n",
    "    labels = copy.deepcopy(input_ids)\n",
    "    for label, source_len in zip(labels, sources_tokenized[\"input_ids_lens\"]):\n",
    "        label[:source_len] = IGNORE_INDEX\n",
    "    return dict(input_ids=input_ids, labels=labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class SupervisedDataset(Dataset):\n",
    "    \"\"\"Dataset for supervised fine-tuning.\"\"\"\n",
    "\n",
    "    def __init__(self, all_examples: list, tokenizer: transformers.PreTrainedTokenizer=None, is_eval=False, data_size=256):\n",
    "        super(SupervisedDataset, self).__init__()\n",
    "\n",
    "        list_data_dict = all_examples\n",
    "\n",
    "        sources = [example['input'] for example in list_data_dict]\n",
    "        \n",
    "        if not is_eval:\n",
    "            targets = [f\"{example['output']}{tokenizer.eos_token}\" for example in list_data_dict]\n",
    "            # self.targets = targets\n",
    "            data_dict = preprocess(sources, targets, tokenizer)\n",
    "\n",
    "            self.input_ids = data_dict[\"input_ids\"]\n",
    "            self.labels = data_dict[\"labels\"]\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.input_ids)\n",
    "\n",
    "    def __getitem__(self, i):\n",
    "        return dict(input_ids=self.input_ids[i], labels=self.labels[i])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "@dataclass\n",
    "class DataCollatorForSupervisedDataset(object):\n",
    "    \"\"\"Collate examples for supervised fine-tuning.\"\"\"\n",
    "\n",
    "    tokenizer: transformers.PreTrainedTokenizer\n",
    "\n",
    "    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:\n",
    "        input_ids, labels = tuple([instance[key] for instance in instances] for key in (\"input_ids\", \"labels\"))\n",
    "        input_ids = torch.nn.utils.rnn.pad_sequence(\n",
    "            input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id\n",
    "        )\n",
    "        labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)\n",
    "        return dict(\n",
    "            input_ids=input_ids,\n",
    "            labels=labels,\n",
    "            attention_mask=input_ids.ne(self.tokenizer.pad_token_id),\n",
    "        )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_dataset = SupervisedDataset(train_examples,tokenizer=tokenizer)\n",
    "eval_dataset = SupervisedDataset(train_examples,tokenizer=tokenizer)\n",
    "data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_module = dict(train_dataset=train_dataset, eval_dataset=eval_dataset, data_collator=data_collator)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def print_trainable_parameters(model):\n",
    "    \"\"\"\n",
    "    Prints the number of trainable parameters in the model.\n",
    "    \"\"\"\n",
    "    trainable_params = 0\n",
    "    all_param = 0\n",
    "    for _, param in model.named_parameters():\n",
    "        all_param += param.numel()\n",
    "        if param.requires_grad:\n",
    "            trainable_params += param.numel()\n",
    "    print(\n",
    "        f\"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}\"\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from peft import prepare_model_for_kbit_training\n",
    "\n",
    "model.gradient_checkpointing_enable()\n",
    "model = prepare_model_for_kbit_training(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from peft import LoraConfig, get_peft_model\n",
    "\n",
    "config = LoraConfig(\n",
    "    r=32,\n",
    "    lora_alpha=64,\n",
    "    target_modules=[\n",
    "        \"q_proj\",\n",
    "        \"k_proj\",\n",
    "        \"v_proj\",\n",
    "        \"o_proj\",\n",
    "        \"gate_proj\",\n",
    "        \"up_proj\",\n",
    "        \"down_proj\",\n",
    "        \"lm_head\",\n",
    "    ],\n",
    "    bias=\"none\",\n",
    "    lora_dropout=0.05,\n",
    "    task_type=\"CAUSAL_LM\",\n",
    ")\n",
    "\n",
    "model = get_peft_model(model, config)\n",
    "print_trainable_parameters(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import transformers\n",
    "from datetime import datetime\n",
    "\n",
    "project = \"critic-finetune-succ12-fail6-good-bad\"\n",
    "run_name = base_model_name + \"-\" + project\n",
    "output_dir = \"./\" + run_name\n",
    "\n",
    "tokenizer.pad_token = tokenizer.eos_token\n",
    "\n",
    "trainer = transformers.Trainer(\n",
    "    model=model,\n",
    "    tokenizer=tokenizer,\n",
    "    args=transformers.TrainingArguments(\n",
    "        output_dir=output_dir,\n",
    "        warmup_steps=1,\n",
    "        per_device_train_batch_size=2,\n",
    "        gradient_accumulation_steps=1,\n",
    "        gradient_checkpointing=True,\n",
    "        max_steps=1000,\n",
    "        learning_rate=2.5e-5, # Want a small lr for finetuning\n",
    "        bf16=True,\n",
    "        optim='adamw_torch',\n",
    "        logging_dir=\"./logs\",        # Directory for storing logs\n",
    "        logging_steps=100,\n",
    "        save_strategy=\"steps\",       # Save the model checkpoint every logging step\n",
    "        save_steps=100,                # Save checkpoints every 50 steps\n",
    "        evaluation_strategy=\"steps\", # Evaluate the model every logging step\n",
    "        eval_steps=100,               # Evaluate and save checkpoints every 50 steps\n",
    "        do_eval=True,                # Perform evaluation at the end of training\n",
    "    ),\n",
    "    **data_module\n",
    ")\n",
    "\n",
    "model.config.use_cache = False  # silence the warnings. Please re-enable for inference!\n",
    "trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "alfworld",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
