import gc
import os

import pandas as pd
import torch
import transformers
from datasets import Dataset
from peft import (
    LoraConfig,
    PeftModel,
    get_peft_model,
    prepare_model_for_int8_training,
)
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
from trl import SFTTrainer

# Raise pandas' console display limits so wide/long frames are printed in full
# when inspected during a run.
for _option, _limit in (
    ('display.max_rows', 500),
    ('display.max_columns', 500),
    ('display.width', 1000),
):
    pd.set_option(_option, _limit)

# Task categories to evaluate. Each label is interpolated into the input path
# `../test/test-<label>.jsonl` and the output path `../data/<label>_llama2_chat.csv`,
# so the strings must match the file names on disk exactly.
task = [
    '分析',        # analysis
    '常识推理',    # commonsense reasoning
    '推荐',        # recommendation
    '编辑',        # editing
    '问答',        # question answering
    '其他',        # other
    '分类',        # classification
    '开放式生成',  # open-ended generation
    '描述',        # description
    '识别',        # recognition
    '写信',        # letter writing
    '判断',        # judgement
    '抽取',        # extraction
    '摘要',        # summarization
    '转换',        # conversion
]

# Sanity check: report how many task categories will be processed.
print(len(task))

# Local directory of the base model checkpoint.
model_ckpt = "../Llama-2-7b-chat-hf"
# Local directory of the fine-tuned LoRA adapter (adapter config + weights).
# NOTE(review): "resultst" may be a typo for "results" — confirm against the
# directory the training run actually wrote to.
adapter_ckpt = "../resultst/final_checkpoint"

tokenizer = AutoTokenizer.from_pretrained(model_ckpt, trust_remote_code=True)
# Use EOS as the padding token (presumably because the checkpoint's tokenizer
# does not define one — verify if padding behaviour matters).
tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(
    model_ckpt,
    device_map="auto",
    trust_remote_code=True
)

# Adapter hyper-parameters as saved by training; kept as a module-level name
# for inspection and backward compatibility.
lora_config = LoraConfig.from_pretrained(adapter_ckpt)

# Load the *trained* adapter weights onto the base model.
# `get_peft_model(model, lora_config)` would instead create a freshly
# initialised adapter from the config alone, so merging it would leave the base
# model effectively unchanged and the fine-tuning would never be evaluated.
model = PeftModel.from_pretrained(model, adapter_ckpt)
# Fold the LoRA deltas into the base weights and drop the PEFT wrapper so the
# pipeline receives a plain causal-LM model.
model = model.merge_and_unload()

# Text-generation pipeline used by the evaluation loop below.
pipeline = transformers.pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    device_map="auto",
)


# Instruction-following preamble prepended to every example; it does not depend
# on the task, so it is built once.
prompt = 'Below is an instruction that describes a task. Write a response that appropriately completes the request.\n'

# Create the output directory up front so a missing folder fails fast instead of
# after a full task's worth of generation has already been spent.
os.makedirs('../data', exist_ok=True)

# For each task category: read its test set, generate one completion per row,
# and write the rows plus a `generation` column to a per-task CSV.
for t in task:
    print(t)
    test_dataset = pd.read_json(f'../test/test-{t}.jsonl', lines=True)

    # Guard the column lookup so a file with no rows (and therefore possibly no
    # columns) still produces an output CSV, as before.
    instructions = test_dataset['en_input'].tolist() if len(test_dataset) else []

    generations = []
    for instruction in tqdm(instructions):
        text = prompt + '### Instruction:\n' + instruction + '\n\n### Response:\n'

        # NOTE(review): `max_length` counts prompt tokens as well as generated
        # ones, so long instructions leave little or no room for a response;
        # `max_new_tokens` would bound only the completion. Left unchanged to
        # keep results comparable with earlier runs.
        sequences = pipeline(
            text,
            eos_token_id=tokenizer.eos_token_id,
            max_length=200,
            num_beams=1,
            repetition_penalty=2.0,
        )

        # NOTE(review): with the pipeline's default settings `generated_text`
        # includes the prompt — confirm downstream scoring expects that.
        generations.append(sequences[0]['generated_text'])

    # Single column assignment instead of per-row `.loc` writes.
    test_dataset['generation'] = generations

    test_dataset.to_csv(f'../data/{t}_llama2_chat.csv', index=False)
