import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset
import json

from minimal_multitask.data import DATASETS, FileDataset
from minimal_multitask.utils import encode_with_messages_format

from tqdm import tqdm
import argparse
import os
import pickle

# Command-line interface. The script keeps only the examples of a local JSON/JSONL
# dataset that still have at least one label token != -100 after tokenization.
parser = argparse.ArgumentParser(
    description="Filter a JSON/JSONL chat dataset, keeping only examples that have "
    "at least one label token != -100 after tokenization."
)
parser.add_argument("--model_name", type=str, default="huggyllama/llama-7b")
parser.add_argument(
    "--tokenizer",
    type=str,
    default=None,
    help="Tokenizer to load; defaults to --model_name.",
)
parser.add_argument(
    "--train_dataset",
    type=str,
    default="alpaca",
    help="Path to the local JSON/JSONL file to filter (must exist).",
)
# Only these three values are handled when building the model kwargs below;
# restricting them here gives a clear usage error instead of a later crash.
parser.add_argument("--dtype", default="bf16", choices=["bf16", "fp16", "fp32"])
parser.add_argument("--prompt_only", action="store_true")
parser.add_argument("--label_only", action="store_true")
parser.add_argument(
    "--only_first_two",
    action="store_true",
    help="Only use the first two messages.",
)
# Required: the script checks os.path.exists(args.output_file) right away, and a
# missing value (None) would otherwise crash there with a TypeError.
parser.add_argument("--output_file", type=str, required=True)

args = parser.parse_args()

# Validate the input path explicitly: an `assert` would be stripped under `python -O`.
if not os.path.exists(args.train_dataset):
    raise FileNotFoundError(f"Training dataset not found: {args.train_dataset}")

# Idempotent re-runs: do nothing if the output has already been produced.
if os.path.exists(args.output_file):
    print(f'File {args.output_file} already exists. Skipping...')
    # SystemExit is always available, unlike the site-provided `exit()` helper.
    raise SystemExit(0)

# Keyword arguments for model loading, keyed by the --dtype flag.
# NOTE(review): `kwargs` is not used anywhere else in the visible script (no model
# is loaded); confirm whether this block is still needed.
_DTYPE_BY_FLAG = {
    "bf16": torch.bfloat16,
    "fp16": torch.float16,
    "fp32": torch.float32,
}
if args.dtype not in _DTYPE_BY_FLAG:
    # Previously an unknown value left `kwargs` undefined and crashed with NameError.
    raise ValueError(
        f"Unsupported --dtype {args.dtype!r}; expected one of {sorted(_DTYPE_BY_FLAG)}"
    )
kwargs = {"torch_dtype": _DTYPE_BY_FLAG[args.dtype]}
if "llama" in args.model_name:
    kwargs["attn_implementation"] = "sdpa"

# Forward an access token from the environment when one is set.
if (hf_token := os.getenv("HF_TOKEN")) is not None:
    kwargs["token"] = hf_token

# Use the explicitly requested tokenizer, falling back to the model's own one.
tokenizer_source = args.model_name if args.tokenizer is None else args.tokenizer
tokenizer = AutoTokenizer.from_pretrained(tokenizer_source, use_fast=True)

# Load the local JSON/JSONL file with the `datasets` "json" builder, taking its "train" split.
base_train_dataset = load_dataset("json", data_files=args.train_dataset)["train"]
def tokenize(x):
    """Encode one dataset row with ``encode_with_messages_format``.

    The row, the module-level tokenizer and the CLI flags (--label_only,
    --only_first_two, --prompt_only) are forwarded positionally. The meaning of
    each positional argument is defined by ``encode_with_messages_format``.
    """
    max_len = 2048  # presumably the maximum sequence length — confirm in the helper
    return encode_with_messages_format(
        x,
        tokenizer,
        max_len,
        True,
        args.label_only,
        args.only_first_two,
        args.prompt_only,
    )
# Tokenize every row across 8 worker processes, reusing the `datasets` cache when available.
map_options = dict(num_proc=8, load_from_cache_file=True, keep_in_memory=False)
train_dataset = base_train_dataset.map(tokenize, **map_options)

# Expose only the model-input columns, as torch tensors.
tensor_columns = ["input_ids", "attention_mask", "labels"]
train_dataset.set_format(type="torch", columns=tensor_columns)

print(f"Train dataset size: {len(train_dataset)}")

# Indices (in ascending order) of rows with at least one label that is not -100.
# -100 is presumably the ignore index marking unsupervised positions.
selected_idx = [
    row_idx
    for row_idx, row in tqdm(enumerate(train_dataset))
    if (row["labels"] != -100).any()
]

# Copy the raw input lines of the selected rows, so each kept example keeps its
# exact original JSON text.
# NOTE(review): this pairs dataset row i with physical line i of the input file,
# which presumably holds only for JSONL without blank/extra lines — confirm.
#
# Set membership replaces the old `selected_idx[j]` pointer. That pointer raised
# IndexError on the first line after the last selected row, and on line 0 when
# nothing was selected.
selected = set(selected_idx)

# Write to a temporary path and rename at the end. The existence check near the
# top of the script skips work when the output already exists, so an interrupted
# run must never leave a partial file under the final name.
tmp_output_file = f"{args.output_file}.tmp"
j = 0
with open(tmp_output_file, "w", encoding="utf-8") as fout, open(
    args.train_dataset, "r", encoding="utf-8"
) as fin:
    for i, line in tqdm(enumerate(fin)):
        if i % 5000 == 0:
            print(f'{i=}')

        if i in selected:
            fout.write(line)
            j += 1
            if j % 5000 == 0:
                print(f'{j=}')

if j != len(selected):
    # Fewer lines than selected rows means rows and lines are misaligned.
    print(f'WARNING: wrote {j} lines but {len(selected)} rows were selected')

os.replace(tmp_output_file, args.output_file)
                