import sys
import torch
from transformers import (
    TrainingArguments, 
    DataCollatorWithPadding, 
    AutoTokenizer,
    AutoModelForCausalLM,
)
from adapters import AdapterConfig
from peft import get_peft_model, LoraConfig, TaskType
from datetime import datetime
from datasets import Dataset
import os
# from ft_utils import read_jsonl, load_model, get_train_idx, write_jsonl
from peft import PeftModel
from transformers import Trainer, EvalPrediction
import torch
import numpy as np
import wandb
import json
from tqdm import tqdm
import argparse


# Command-line interface. Every option has a default, so the script can be run
# without arguments.
parser = argparse.ArgumentParser(
    description="Generate completions for a test set with a fine-tuned causal LM and save them as JSON."
)
parser.add_argument(
    '--model',
    type=str,
    default="/gpfs/radev/home/tl688/scratch/llamaf/LLaMA-Factory/saves/llama/lora/sft_8bbase_70bllamainfo_new_updatedescription_binary/checkpoint-600/",
    help="Checkpoint directory holding the tokenizer and the fine-tuned (adapter) weights.",
)
parser.add_argument(
    '--savepath',
    type=str,
    default="./save_out/llama8b_70bdata_binaryout.pkl",
    help="File the generated completions are written to.",
)
parser.add_argument(
    '--testdatapath',
    type=str,
    default="../../INC-Math/ft_data/llama3.1-70b/test/data_lvl_54321_greedy_4class.json",
    help="JSON file with the test examples.",
)
# Previously hard-coded; exposed so other base models can be used without
# editing the script. The default keeps the original behaviour.
parser.add_argument(
    '--basemodel',
    type=str,
    default='meta-llama/Llama-3.1-8B',
    help="Name or path of the base model the checkpoint was fine-tuned from.",
)
args = parser.parse_args()

model_name = args.basemodel
store_path = args.model
save_path = args.savepath
test_path = args.testdatapath

# True: `store_path` is loaded as a PEFT adapter on top of the base model.
# False: `store_path` is loaded directly as a full causal-LM checkpoint.
use_adapter = True

# The tokenizer always comes from the checkpoint directory, not the base model.
tokenizer = AutoTokenizer.from_pretrained(store_path)

if use_adapter:
    print('Using adapter')
    base_model = AutoModelForCausalLM.from_pretrained(model_name)
    # Resize the embedding matrix to the checkpoint tokenizer's vocabulary
    # size before the adapter weights are attached.
    base_model.resize_token_embeddings(len(tokenizer))
    model = PeftModel.from_pretrained(base_model, store_path)
else:
    print('No adapter used')
    # Load the checkpoint directly; the base weights are not needed on this
    # path, so they are no longer loaded (and resized) only to be discarded.
    model = AutoModelForCausalLM.from_pretrained(store_path)

# Move model to device and switch to inference mode (disables dropout etc.).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()

import pandas as pd
import numpy as np

# df_train = pd.read_json("../../INC-Math/ft_data/llama3.1-8b/train/data_lvl_543_greedy.jsonl")
import pandas as pd    
# Load the test set and show the first rows as a quick sanity check.
df_test = pd.read_json(test_path)
print(df_test.head())

# Prompts are taken from the 'instruction' column and the reference answers
# from the 'output' column. Extracting whole columns avoids one `.loc` lookup
# per row and stays correct even if index labels are not unique.
input_data = df_test['instruction'].tolist()
labels = df_test['output'].tolist()

# Batched greedy generation over all prompts in `input_data`.
#
# Prompts are left-padded so that, within a batch, every prompt ends at the
# same position and generation continues directly from the prompt. The side is
# set on the tokenizer itself rather than per call, which does not depend on
# the tokenizer call accepting a `padding_side` keyword.
tokenizer.padding_side = 'left'
if tokenizer.pad_token is None:
    # `padding=True` requires a pad token; fall back to EOS, which is also the
    # pad id handed to `generate` below.
    tokenizer.pad_token = tokenizer.eos_token

batch_size = 4

# Loop-invariant generation settings, built once instead of once per batch.
generation_config = {
    "max_new_tokens": 10,  # limit the length of the generated continuation
    "do_sample": False,  # deterministic (greedy) decoding
    "pad_token_id": tokenizer.eos_token_id,
    "use_cache": True,  # enable the KV cache
}

generated_text = []
for idx in range(0, len(input_data), batch_size):
    input_text = input_data[idx:idx + batch_size]
    encoding = tokenizer(
        input_text,
        max_length=2048,
        truncation=True,
        padding=True,
        return_tensors='pt',
    )
    # Use the device the model was moved to instead of unconditionally calling
    # `.cuda()`, which fails on CPU-only machines.
    input_ids = encoding['input_ids'].to(device)
    # Pass attention mask to handle padding properly
    attention_mask = encoding['attention_mask'].to(device)

    with torch.no_grad():
        output = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **generation_config,
        )

    # `generate` returns prompt + continuation for each row. Drop the prompt
    # by token count: slicing the decoded string by `len(prompt)` is wrong
    # whenever the prompt was truncated or detokenization does not reproduce
    # the prompt text character for character.
    completion_ids = output[:, input_ids.shape[1]:]
    generated_text.extend(
        tokenizer.batch_decode(completion_ids, skip_special_tokens=True)
    )


import json
# Save the list of generated completions as JSON.
# NOTE: the content is JSON regardless of the file extension in `save_path`.
# Create the target directory first so the results of the inference run are
# not lost to a FileNotFoundError when it does not exist yet.
save_dir = os.path.dirname(save_path)
if save_dir:
    os.makedirs(save_dir, exist_ok=True)
with open(save_path, "w") as f:
    json.dump(generated_text, f, indent=4)  # indent=4 for pretty formatting

# # Calculate accuracy
# correct_count = 0
# total_count = len(labels)
# predicted_class = []

# for i in range(total_count):
# #     decision = generated_text[i].split(": ")[-1].split('.')[0]
#     decision = generated_text[i].split(" ")[-1].split('.')[0]
# #     if decision not in ['Yes', 'No']:
# #         decision = 'Yes'
#     predicted_class.append(decision)
# #     print(decision)
# #     print(labels[i])
#     if decision in labels[i]:
#         correct_count += 1

# accuracy = correct_count / total_count

# accuracy

