import sys
import torch
from transformers import (
    TrainingArguments, 
    DataCollatorWithPadding, 
    AutoTokenizer,
    AutoModelForCausalLM,
)
from adapters import AdapterConfig
from peft import get_peft_model, LoraConfig, TaskType
from datetime import datetime
from datasets import Dataset
import os
# from ft_utils import read_jsonl, load_model, get_train_idx, write_jsonl
from peft import PeftModel
from transformers import Trainer, EvalPrediction
import torch
import numpy as np
import wandb
import json
from tqdm import tqdm
import argparse


# parser = argparse.ArgumentParser()
# parser.add_argument('--model', type=str, default="/gpfs/radev/home/tl688/scratch/llamaf/LLaMA-Factory/saves/llama/lora/sft_8bbase_70bllamainfo_new_updatedescription_binary/checkpoint-600/")
# parser.add_argument('--savepath',type=str, default ="./save_out/llama8b_70bdata_binaryout.pkl")
# args = parser.parse_args()
# Hub id of the base checkpoint that is loaded first.
model_name = 'meta-llama/Llama-3.1-8B'
# Fine-tuned checkpoint directory; the tokenizer and the adapter (or full model) are read from here.
store_path = "/gpfs/radev/home/tl688/scratch/llamaf/LLaMA-Factory/saves/llama/lora/sft_8bbase_70bllamainfo_new_updatedescription_binary/checkpoint-600/"
# NOTE(review): save_path is not referenced by the visible part of this script;
# the save step further down writes to its own hard-coded path.
save_path = "./save_out/llama8b_70bdata_binaryout_test.pkl"

# True: wrap the base model with the PEFT adapter stored in store_path.
# False: treat store_path as a full model checkpoint.
use_adapter = True

base_model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(store_path)
# Batched tokenization with padding=True needs a pad token. Fall back to EOS
# (the same id generation uses as pad_token_id) when the checkpoint's
# tokenizer does not define one. This does not change the vocabulary size.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
# Match the embedding matrix to the fine-tuned tokenizer's vocabulary before
# the adapter weights are attached.
base_model.resize_token_embeddings(len(tokenizer))
if use_adapter:
    print('Using adapter')
    model = PeftModel.from_pretrained(base_model, store_path)
else:
    print('No adapter used')
    # from_pretrained is a classmethod: it builds a new model from store_path
    # and does not load anything into base_model, so call it on the class.
    model = AutoModelForCausalLM.from_pretrained(store_path)

# Move model to device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()

import pandas as pd

# Test split; each record provides an 'instruction' (prompt) and an 'output' (reference answer).
df_test = pd.read_json("../../INC-Math/ft_data/llama3.1-70b/test/data_lvl_54321_greedy_binaryclass.json")

# Pull both columns out in file order so that index i of input_data lines up
# with index i of labels. Direct column extraction avoids one .loc lookup per
# row and stays correct even if the frame's index contains duplicates.
input_data = df_test['instruction'].tolist()
labels = df_test['output'].tolist()

# Greedy decoding settings, identical for every batch, so built once.
generation_config = {
    "max_new_tokens": 10,          # cap on the generated length
    "do_sample": False,
    "pad_token_id": tokenizer.eos_token_id,
    "use_cache": True,             # enable the KV cache
}
# Number of prompts decoded per forward pass.
batch_size = 4

# generated_text[i] ends up holding only the completion for input_data[i].
generated_text = []
for idx in range(0, len(input_data), batch_size):
    input_text = input_data[idx:idx + batch_size]
    # Pad on the left so every prompt ends at the same position and the
    # attention mask tells the model which positions are padding.
    encoding = tokenizer(
        input_text,
        max_length=2048,
        truncation=True,
        padding=True,
        return_tensors='pt',
        padding_side='left'
    )
    # Follow the device the model was moved to instead of assuming CUDA.
    input_ids = encoding['input_ids'].to(device)
    attention_mask = encoding['attention_mask'].to(device)
    with torch.no_grad():
        output = model.generate(input_ids=input_ids, attention_mask=attention_mask, **generation_config)
    # generate() returns prompt tokens followed by the new tokens. Cut at the
    # token level: slicing the decoded string by len(prompt) characters is off
    # whenever the prompt was truncated or decoding does not reproduce the
    # prompt text exactly.
    new_tokens = output[:, input_ids.shape[1]:]
    generated_text.extend(tokenizer.batch_decode(new_tokens, skip_special_tokens=True))


import json

# Save the list of generated completions to a file.
# NOTE(review): the content is JSON text despite the ".pkl" extension, and the
# path is hard-coded here rather than taken from save_path — confirm which
# location downstream consumers expect before changing either.
output_file = "./save_out/llama8b_70bdata_binaryout.pkl"
# Create the target directory first so a missing folder cannot discard the
# results of the whole generation run.
os.makedirs(os.path.dirname(output_file), exist_ok=True)
with open(output_file, "w") as f:
    json.dump(generated_text, f, indent=4)  # indent=4 for pretty formatting


import math

def compute_perp(prompt_text, target_text):
    """Return the perplexity of ``target_text`` given ``prompt_text``.

    The prompt and target are tokenized separately and concatenated; the loss
    is computed only on the target tokens (prompt positions are masked with
    -100), and the perplexity is ``exp`` of that mean cross-entropy.

    Uses the module-level ``tokenizer``, ``model`` and ``device``.

    Args:
        prompt_text: Conditioning text; tokenized without truncation.
        target_text: Continuation whose perplexity is measured.

    Returns:
        The perplexity as a float. ``math.inf`` when ``target_text`` produces
        no tokens, so an empty continuation never wins a lower-is-better
        comparison.
    """
    # Tokenize the prompt with the tokenizer's default special tokens.
    prompt_ids = tokenizer.encode(prompt_text, return_tensors="pt", truncation=False, max_length=2048)
    # Tokenize the target WITHOUT special tokens rather than encoding normally
    # and dropping the first id, which is only right when exactly one BOS
    # token is prepended.
    target_token_ids = tokenizer.encode(target_text, add_special_tokens=False)
    if not target_token_ids:
        # Nothing to score: the mean loss over zero tokens is undefined.
        return math.inf
    target_ids = torch.tensor([target_token_ids], dtype=prompt_ids.dtype)

    # The model sees the whole sequence, but only target positions contribute
    # to the loss.
    input_ids = torch.cat([prompt_ids, target_ids], dim=1)
    labels = input_ids.clone()
    labels[:, :prompt_ids.size(1)] = -100  # mask the prompt tokens

    # Run on the device the model lives on instead of assuming CUDA.
    with torch.no_grad():
        outputs = model(input_ids.to(device), labels=labels.to(device))

    # Perplexity is the exponential of the mean cross-entropy.
    return math.exp(outputs.loss.item())

# Calculate pairwise accuracy.
# Examples are consumed in consecutive pairs (i, i+1), tagged "cot" and "pal"
# below. A pair is correct when both predictions match their labels.
# Otherwise the prediction with the lower perplexity is selected and the pair
# is correct only if THAT prediction matches its own label.
correct_count = 0
total_count = len(labels)
predicted_class = []
# Only complete pairs are scored; a trailing unpaired example is ignored
# instead of raising IndexError at i+1.
num_pairs = total_count // 2
for i in range(0, 2 * num_pairs, 2):
    # Last space-separated piece of the completion, up to the first '.'.
    decision1 = generated_text[i].split(" ")[-1].split('.')[0]  # cot
    decision2 = generated_text[i+1].split(" ")[-1].split('.')[0]  # pal
    # NOTE(review): this is a substring test, so an empty decision matches
    # every label — confirm whether that is intended.
    hit1 = decision1 in labels[i]
    hit2 = decision2 in labels[i+1]
    if hit1 and hit2:
        correct_count += 1
        continue

    prep1 = compute_perp(input_data[i], decision1)
    prep2 = compute_perp(input_data[i+1], decision2)

    # Track the selection by position, not by string value. When both
    # decisions are the same text, comparing strings cannot tell which one was
    # picked and would credit the pair if either label matched.
    if prep1 < prep2:
        if hit1:
            correct_count += 1
    elif hit2:
        correct_count += 1

# Accuracy is undefined with no complete pair; report NaN rather than dividing by zero.
accuracy = correct_count / num_pairs if num_pairs else float("nan")

print(accuracy)

