import sys
import torch
from transformers import (
    TrainingArguments, 
    DataCollatorWithPadding, 
    AutoTokenizer,
    AutoModelForCausalLM,
)
from adapters import AdapterConfig
from peft import get_peft_model, LoraConfig, TaskType
from datetime import datetime
from datasets import Dataset
import os
# from ft_utils import read_jsonl, load_model, get_train_idx, write_jsonl
from peft import PeftModel
from transformers import Trainer, EvalPrediction
import torch
import numpy as np
import wandb
import json
from tqdm import tqdm

model_name = 'meta-llama/Llama-3.1-8B'
store_path = "/gpfs/radev/home/tl688/scratch/llamaf/LLaMA-Factory/saves/llama/lora/sft_8bbase_70bllamainfo_new_updatedescription/checkpoint-700/"

# True: store_path holds a PEFT adapter that is applied on top of model_name.
# False: store_path is loaded directly as a complete causal-LM checkpoint.
use_adapter = True

# The tokenizer always comes from the fine-tuned checkpoint; its vocabulary may
# differ from the base model's, which is why the base embeddings are resized.
tokenizer = AutoTokenizer.from_pretrained(store_path)
if use_adapter:
    print('Using adapter')
    base_model = AutoModelForCausalLM.from_pretrained(model_name)
    # Match the embedding matrix to the checkpoint tokenizer before the adapter
    # weights are attached.
    base_model.resize_token_embeddings(len(tokenizer))
    model = PeftModel.from_pretrained(base_model, store_path)
else:
    print('No adapter used')
    # Load the checkpoint directly: loading the base model first (as before)
    # only to discard it doubled the load time and memory for nothing.
    model = AutoModelForCausalLM.from_pretrained(store_path)

# Move model to device (falls back to CPU when CUDA is unavailable).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()

import numpy as np
import pandas as pd

# Evaluation split, stored as JSON Lines (one record per line).
# (Training split, unused here: ../../INC-Math/ft_data/llama3.1-8b/train/data_lvl_543_greedy.jsonl)
df_test = pd.read_json(
    path_or_buf="../../INC-Math/ft_data/llama3.1-70b/test/data_lvl_54321_greedy.jsonl",
    lines=True,
)

# Fixed description prefix prepended to every question. It walks through one
# worked example (coplanarity of two 3D lines) under each of the four solving
# methods the model must choose between: COT, PAL, CodeNL and NLCode.
# NOTE(review): the example is specific to one problem yet is prefixed to every
# question; presumably it mirrors the prompt format used when the adapter at
# store_path was fine-tuned — confirm before editing any of this text.
meta_descrip = '''<description>\nTo analyze the possibility and potential to solve the given math question, we can consider the different approaches mentioned: COT, PAL, CodeNL, and NLCode. Let\'s explore each approach in detail:\n\n### COT (Chain of Thought in Natural Language)\n\n1. **Understand the Problem**: We have two lines in 3D space. The first line is given by a point and a direction vector, and the second line is given similarly. The lines are coplanar if there exists a plane that contains both lines.\n\n2. **Condition for Coplanarity**: Two lines are coplanar if the vector connecting any point on the first line to any point on the second line is perpendicular to the cross product of the direction vectors of the two lines.\n\n3. **Mathematical Formulation**:\n   - Let \\(\\mathbf{a} = \\begin{pmatrix} 2 \\\\ 3 \\\\ 4 \\end{pmatrix}\\), \\(\\mathbf{b} = \\begin{pmatrix} 1 \\\\ 1 \\\\ -k \\end{pmatrix}\\).\n   - Let \\(\\mathbf{c} = \\begin{pmatrix} 1 \\\\ 4 \\\\ 5 \\end{pmatrix}\\), \\(\\mathbf{d} = \\begin{pmatrix} k \\\\ 2 \\\\ 1 \\end{pmatrix}\\).\n   - The vector connecting a point on the first line to a point on the second line is \\(\\mathbf{c} + u\\mathbf{d} - (\\mathbf{a} + t\\mathbf{b})\\).\n   - The cross product of the direction vectors is \\(\\mathbf{b} \\times \\mathbf{d}\\).\n   - The coplanarity condition is \\((\\mathbf{c} - \\mathbf{a}) \\cdot (\\mathbf{b} \\times \\mathbf{d}) = 0\\).\n\n4. **Solve for k**: Calculate the cross product and the dot product, set the equation to zero, and solve for \\(k\\).\n\n### PAL (Program-Aided Language)\n\n1. 
**Write a Python Program**: Use Python to perform vector operations and solve the equation for \\(k\\).\n\n```python\nimport numpy as np\n\n# Define vectors\na = np.array([2, 3, 4])\nb = np.array([1, 1, -1])  # Placeholder for k\nc = np.array([1, 4, 5])\nd = np.array([1, 2, 1])   # Placeholder for k\n\n# Define a function to calculate k\ndef find_k():\n    # Cross product of b and d\n    def cross_product(k):\n        b_k = np.array([1, 1, -k])\n        d_k = np.array([k, 2, 1])\n        return np.cross(b_k, d_k)\n\n    # Dot product of (c - a) and cross product\n    def dot_product(k):\n        cross_prod = cross_product(k)\n        return np.dot(c - a, cross_prod)\n\n    # Solve for k such that dot_product(k) = 0\n    for k in range(-10, 11):  # Example range, adjust as needed\n        if dot_product(k) == 0:\n            print(f"Possible value of k: {k}")\n\nfind_k()\n```\n\n### CodeNL (Code First, Natural Language Explanation)\n\n1. **Write the Code**: Implement the solution in Python as shown above.\n2. **Analyze the Code**: \n   - The code calculates the cross product of the direction vectors for each \\(k\\).\n   - It then computes the dot product of this cross product with the vector connecting the two points.\n   - It checks for which values of \\(k\\) this dot product is zero, indicating coplanarity.\n\n3. **Obtain the Final Answer**: Run the code to find all possible values of \\(k\\).\n\n### NLCode (Natural Language to Code)\n\n1. **Explain the Solution**: \n   - Explain the condition for coplanarity and how it translates into a mathematical equation involving \\(k\\).\n   - Describe how to calculate the cross and dot products.\n\n2. **Translate to Code**: Implement the explanation in Python code.\n\n3. **Execute and Verify**: Run the code to find the values of \\(k\\) that satisfy the condition.\n\n### Conclusion\n\nEach approach has its strengths. 
COT is useful for a deep understanding and manual solving, PAL leverages programming for efficient computation, CodeNL combines both for clarity, and NLCode ensures a thorough understanding before coding. For this problem, using PAL or CodeNL can efficiently find the possible values of \\(k\\) by leveraging Python\'s computational capabilities.\n</description>'''

# Full instruction prefix: the description followed by the four-way method-choice
# request (cot / pal / codenl / nlcode). The question is appended per example.
prompt = meta_descrip + ' Please choose the correct method to solve the problem, you have four methods to choose from: cot, pal, codenl, nlcode. '

# Seeds NumPy's global RNG only. Decoding below is greedy (do_sample=False),
# so this is kept purely for reproducibility if sampling is re-enabled.
np.random.seed(2024)

# One instruction record per test row that has at least one gold method label.
# (The stray bare `row` expression that used to precede this loop referenced an
# undefined name and crashed the script with NameError.)
# NOTE: despite its name, train_list holds records built from the *test* split.
train_list = []
for _, row in df_test.iterrows():
    if len(row['label']) == 0:
        # Rows without a gold label cannot be scored; skip them.
        continue
    train_list.append({
        'instruction': prompt + f"Here is the question: {row['question']} Your choice: ",
        'input': '',
        'output': row['label'],
    })

print(f"Number of evaluation examples: {len(train_list)}")
if train_list:
    print(train_list[0])

# Parallel lists: the prompt fed to the model and the gold label(s) it is scored against.
input_data = [record['instruction'] for record in train_list]
labels = [record['output'] for record in train_list]

len(input_data)

# Greedy decoding settings shared by every batch. The prompt asks for a single
# method name, so only a small number of new tokens is generated.
generation_config = {
    "max_new_tokens": 10,   # caps the generated length
    "do_sample": False,
    "pad_token_id": tokenizer.eos_token_id,
    "use_cache": True,      # enable the KV cache
}
batch_size = 4

# Batched padding requires a pad token; fall back to EOS (the same id passed to
# generate as pad_token_id above) if the checkpoint tokenizer has none.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
# Decoder-only generation needs left padding so every prompt ends at the same
# position. Setting the attribute (instead of passing padding_side= per call)
# also works on older transformers releases whose tokenizer __call__ does not
# accept that keyword.
tokenizer.padding_side = 'left'

generated_text = []
for start in range(0, len(input_data), batch_size):
    input_text = input_data[start:start + batch_size]
    # NOTE(review): truncation follows tokenizer.truncation_side; if that is
    # 'right', an over-long prompt loses the question at its end rather than
    # part of the shared prefix — confirm prompts fit within 2048 tokens.
    encoding = tokenizer(
        input_text,
        max_length=2048,
        truncation=True,
        padding=True,
        return_tensors='pt',
    )
    # Use the device chosen at load time instead of hard-coding CUDA.
    input_ids = encoding['input_ids'].to(device)
    attention_mask = encoding['attention_mask'].to(device)
    with torch.no_grad():
        output = model.generate(input_ids=input_ids, attention_mask=attention_mask, **generation_config)
    # The returned sequences start with the (left-padded) prompt, which occupies
    # exactly the first input_ids.shape[1] positions. Slicing tokens there keeps
    # only the continuation; slicing the decoded string by len(prompt) broke
    # whenever decoding did not reproduce the prompt text exactly.
    new_tokens = output[:, input_ids.shape[1]:]
    text = tokenizer.batch_decode(new_tokens, skip_special_tokens=True)
    print(text[0])
    generated_text.extend(text)


# text[0]

# Router accuracy: parse the method name from each continuation and check it
# against that example's gold label(s).
correct_count = 0
total_count = len(labels)
predicted_class = []

for generated, gold in zip(generated_text, labels):
    # Keep the text after the last ": " and before the first ".", then drop
    # surrounding whitespace and case so e.g. " PAL\n" still matches "pal".
    decision = generated.split(": ")[-1].split('.')[0].strip().lower()
    predicted_class.append(decision)
    # An empty decision must never count as a hit: when gold is a string,
    # `'' in gold` is always True.
    if decision and decision in gold:
        correct_count += 1
    elif 'cot' in gold:
        # NOTE(review): this credits every example whose gold labels include
        # 'cot' regardless of the model's prediction, so the figure is really
        # "routed correctly OR cot would have worked". Kept as-is; confirm this
        # fallback metric is intended.
        correct_count += 1

accuracy = correct_count / total_count if total_count else float('nan')
print(f"Router accuracy: {accuracy:.4f}")

# Baseline: always answer 'cot'. Stored under its own name; the previous
# copy-pasted version reset `predicted_class` and overwrote `accuracy`, wiping
# out the router results computed just above.
cot_correct_count = sum('cot' in gold for gold in labels)
cot_baseline_accuracy = cot_correct_count / len(labels) if labels else float('nan')
print(f"Always-cot baseline accuracy: {cot_baseline_accuracy:.4f}")


