{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from read_data import *\n",
    "\n",
    "# NOTE(review): AutoTokenizer and the get_*_dataset loaders are never imported by name in this\n",
    "# notebook, so they presumably arrive through the wildcard import above - confirm in read_data.\n",
    "model_name_or_path = 'mistralai/Mistral-7B-Instruct-v0.1'\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)\n",
    "\n",
    "# read openbookqa data: call each loader on the same directory with the same options and\n",
    "# print its first item; `data` is left bound to the result of the last loader.\n",
    "for _loader in (get_nq_dataset, get_trivia_dataset):\n",
    "    data = _loader('../openbookqa_data/', split='train', tokenizer=tokenizer, size=1000, sft=True)\n",
    "    print(data[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# read multi-choice data: one get_mc_dataset call per `name` prefix, printing the first item\n",
    "# of each result; `data` is left bound to the result for the last prefix.\n",
    "# Relies on `tokenizer` and `get_mc_dataset` already existing in the kernel namespace.\n",
    "for _prefix in ('nq_', 'trivia_'):\n",
    "    data = get_mc_dataset('../mc_data/', name=_prefix, split='train', tokenizer=tokenizer, size=1000, sft=True)\n",
    "    print(data[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# read machine translation data: one get_mt_dataset call per `lang` code, printing the first\n",
    "# item of each result; `data` is left bound to the result for the last language.\n",
    "# Relies on `tokenizer` and `get_mt_dataset` already existing in the kernel namespace.\n",
    "for _lang in ('ru', 'zh'):\n",
    "    data = get_mt_dataset('../mt_data/', split='train', tokenizer=tokenizer, size=1000, sft=True, lang=_lang)\n",
    "    print(data[0])"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
