import os
import sys
import time
from tqdm import tqdm
import torch

# --- Environment, configuration and model/dataset loading -------------------
# PROJECT_ROOT (optional) is appended to sys.path so that the project-local
# `env` module imported below can be found.
project_root = os.environ.get("PROJECT_ROOT")
if project_root and project_root not in sys.path:
    sys.path.append(project_root)
from env import Agent, Math500Dataset

# DATA_ROOT is the directory that holds the model checkpoints and the dataset.
data_root = os.environ.get("DATA_ROOT")
if data_root is None:
    # Fail fast with a clear message: otherwise every path built below would
    # silently start with the literal string "None/".
    raise RuntimeError(
        "The DATA_ROOT environment variable must be set to the directory "
        "containing the model and dataset folders."
    )
cuda = 4  # CUDA device index passed to both Agent and Math500Dataset
num_samples = 30  # number of prompts taken from the front of the prompt list
# Started before the model is loaded, so the reported total runtime includes
# model and dataset loading.
total_start_time = time.time()

model_path = f"{data_root}/DeepSeek-R1-Distill-Qwen-7B"
# model_path = f"{data_root}/Qwen3-8B"
model_name = os.path.basename(model_path)

# Build the model exactly once, on the configured device. (A previous extra
# construction on a hard-coded "cuda:3" was immediately overwritten and only
# cost load time and GPU memory.)
model = Agent(model_path=model_path, is_anyprecision=False, device=f"cuda:{cuda}")

dataset = Math500Dataset(
    dataset_path=f"{data_root}/MATH-500",
    prompt_type="better",
    shuffle=True,
    xverify_path=f"{data_root}/xVerify-9B-C",
    device=f"cuda:{cuda}",
)
dataset_name = "MATH-500"

# Only the first `num_samples` prompts are processed.
prompt = dataset.get_prompt()[:num_samples]

# --- Generation -------------------------------------------------------------
# Per-question records filled in by the loop below; later stages of the script
# read these lists by name.
total_generation_times, evaluation_times, other_times = [], [], []
answers, lengths, token_counts = [], [], []

for i, question in enumerate(tqdm(prompt, desc="Processing prompt without kv-cache")):
    started_at = time.time()
    gen_started_at = time.time()

    # The model call returns the answer together with a length value, which is
    # reported as "Tokens" below (presumably the generated token count --
    # confirm against Agent).
    answer, length = model(question)
    print(answer)
    answers.append(answer)
    # The same value feeds both the length and the token statistics.
    lengths.append(length)
    token_counts.append(length)

    # Generation time covers the model call plus the bookkeeping above.
    gen_elapsed = time.time() - gen_started_at
    total_generation_times.append(gen_elapsed)

    question_elapsed = time.time() - started_at

    # Evaluation is not run per question, so its per-question cost is
    # recorded as zero inside the loop.
    eval_elapsed = 0.0
    evaluation_times.append(eval_elapsed)

    # Whatever part of the question's wall time is neither generation nor
    # evaluation.
    overhead = question_elapsed - gen_elapsed - eval_elapsed
    other_times.append(overhead)

    print(
        f"Question {i + 1} - Total: {question_elapsed:.3f}s, "
        f"Generation: {gen_elapsed:.3f}s, "
        f"Evaluation: {eval_elapsed:.3f}s, "
        f"Other: {overhead:.3f}s, Tokens: {length}"
    )

    # NOTE: clearing the KV-cache between questions is currently disabled.
    # The previously sketched approach was to call model.model.clear_kv_cache()
    # when that method exists and fall back to torch.cuda.empty_cache()
    # otherwise.

# --- Evaluation and aggregate statistics ------------------------------------
def _mean(values):
    """Return the arithmetic mean of *values*, or 0 when it is empty."""
    return sum(values) / len(values) if values else 0


# All answers are scored in a single call after generation has finished, so
# only the total evaluation time is actually measured.
evaluation_start_time = time.time()

# Alternative scorer kept for reference: dataset.result_eval(answers)
results = dataset.eval_xverify(answers)

evaluation_end_time = time.time()
total_evaluation_time = evaluation_end_time - evaluation_start_time

# Spread the batch evaluation cost evenly over the questions to obtain a
# per-question figure.
avg_evaluation_time_per_question = total_evaluation_time / len(prompt) if prompt else 0
evaluation_times = [avg_evaluation_time_per_question] * len(prompt)

# `other_times` is deliberately NOT rebuilt here. The generation loop already
# recorded, per question, the wall time that is neither generation nor
# evaluation. Recomputing it as "generation - evaluation" (as was done before)
# contradicts that definition and turns negative whenever evaluation is slower
# than generation.

avg_total_time = _mean(total_generation_times)
avg_evaluation_time = _mean(evaluation_times)
avg_other_time = _mean(other_times)

# Guarded like the other averages so an empty run reports 0 instead of
# raising ZeroDivisionError.
mean_length = _mean(lengths)
avg_tokens = _mean(token_counts)
total_tokens = sum(token_counts)

total_end_time = time.time()
total_runtime = total_end_time - total_start_time

# --- Reporting --------------------------------------------------------------
results_dir = "./tt"
os.makedirs(results_dir, exist_ok=True)

# The report file is named after the dataset and the model.
result_file = os.path.join(results_dir, f"{dataset_name}_{model_name}_results_xverify.txt")

separator = "=" * 50

# Lines shared by the report file and the console summary.
token_lines = [
    f"Average tokens per question: {avg_tokens:.2f}",
    f"Total tokens generated: {total_tokens}",
]
timing_lines = [
    f"Total runtime: {total_runtime:.3f} seconds",
    f"Average total time per question: {avg_total_time:.3f} seconds",
    f"Average evaluation time per question: {avg_evaluation_time:.3f} seconds",
    f"Average other time per question: {avg_other_time:.3f} seconds",
    f"Total number of questions processed: {len(prompt)}",
    f"Total evaluation time: {total_evaluation_time:.3f} seconds",
]

# File layout: results header, a blank line, a separator, a blank line, then
# the timing section. Every line (including the last) ends with a newline.
report_lines = [
    "Evaluation Results:",
    separator,
    f"Dataset: {dataset_name}",
    f"Model: {model_name}",
    f"Accuracy: {results}",
    f"Mean Length: {mean_length}",
    *token_lines,
    "",
    separator,
    "",
    "Time statistics summary:",
    *timing_lines,
]

with open(result_file, "w", encoding='utf-8') as report:
    report.write("\n".join(report_lines) + "\n")

# Console summary: timing figures first, token figures last.
print("\nTime statistics summary:")
for line in (*timing_lines, *token_lines):
    print(line)
