import json
import numpy as np

import encoder

import argparse
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data

import numpy
import io
import sys
import threading
import math
import random

import json
import collections
from collections import Counter
from collections import OrderedDict
from progress.bar import Bar as Bar


# Command-line interface. --input and --output are required: without them the
# script would otherwise fail later with an opaque TypeError from open(None).
parser = argparse.ArgumentParser(
    description='Encode a JSON-lines file of {"context", "completion"} '
                'records into BPE token ids.')
parser.add_argument('--input', required=True, type=str,
                    help='ft input file (one JSON object per line with '
                         '"context" and "completion" fields)')
parser.add_argument('--vocab', type=str, default=None, help='vocab path')
parser.add_argument('--output', required=True, type=str,
                    help='ft output file (one JSON object of token ids per line)')
parser.add_argument('--add_bos', action='store_true',
                    help='append token id 50256 to the end of each context')
parser.add_argument('--add_eos', action='store_true',
                    help='append token id 50256 to the end of each completion')
args = parser.parse_args()


def encode_example(enc, line, add_bos=False, add_eos=False, bos=50256, eos=50256):
    """Encode one JSON-lines record into lists of BPE token ids.

    Args:
        enc: Encoder whose ``encode(text)`` returns a 2-tuple. The first
            element is used as the list of token ids and the second is
            discarded (assumed from usage; confirm against ``encoder``).
        line: Text of one JSON object with ``context`` and ``completion``
            fields.
        add_bos: If true, append ``bos`` to the END of the context ids.
        add_eos: If true, append ``eos`` to the end of the completion ids.
        bos: Token id appended to the context (default 50256).
        eos: Token id appended to the completion (default 50256).

    Returns:
        A dict ``{'context': [...], 'completion': [...]}`` of token ids, in
        that key order.

    Raises:
        json.JSONDecodeError: If ``line`` is not valid JSON.
        KeyError: If the ``context`` or ``completion`` field is missing.
    """
    items = json.loads(line)

    context_bpes, _ = enc.encode(items['context'])
    # NOTE(review): despite its name, the BOS id is appended AFTER the context,
    # so it ends up between context and completion. This looks like a
    # deliberate separator; confirm with the training code before moving it.
    # Concatenate rather than `+=` so the encoder's returned list is not mutated.
    if add_bos:
        context_bpes = context_bpes + [bos]

    # The leading space presumably makes the first completion word tokenize as
    # it would mid-sentence (BPE vocabularies often have space-prefixed
    # tokens). TODO confirm.
    completion_bpes, _ = enc.encode(' ' + items['completion'])
    if add_eos:
        completion_bpes = completion_bpes + [eos]

    return {'context': context_bpes, 'completion': completion_bpes}


if __name__ == "__main__":
    enc = encoder.get_encoder(args.vocab)

    # Open the input first so a bad --input path fails before --output is
    # truncated. `with` closes (and flushes) both files even if a record fails.
    with open(args.input, 'r', encoding='utf-8') as reader, \
            open(args.output, 'w', encoding='utf-8') as writer:
        for raw_line in reader:
            line = raw_line.strip()
            if not line:
                # Skip blank lines instead of crashing in json.loads('').
                continue
            ft_json = encode_example(enc, line,
                                     add_bos=args.add_bos,
                                     add_eos=args.add_eos)
            writer.write(json.dumps(ft_json) + '\n')
