import os
import torch
import random
import numpy as np
import argparse
from transformers import AutoTokenizer
from vllm import SamplingParams, LLM
from utils import read_json, write_json, is_math_verify_equiv, remove_boxed, last_boxed_only_string, get_token_count

from collections import Counter

# Seed Python's and NumPy's global RNGs so script-level randomness
# (such as the random.sample spot-check) is repeatable across runs.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)

parser = argparse.ArgumentParser()
parser.add_argument("--model", type=str, required=True, help="Model to evaluate")
parser.add_argument("--gpus", type=str, required=True, help="which GPUs to use")
parser.add_argument("--k", type=int, default=1, help="Number of samples (for average and Self-Consistency) to evaluate")
parser.add_argument("--dataset", type=str, default="MATH_500", help="Dataset to evaluate on")
parser.add_argument("--save_output", action="store_true", default=False, help="Whether to save the output to a file")
parser.add_argument("--temp", type=float, default=0.7, help="Temperature for generation")
parser.add_argument("--use_hint", action="store_true", default=False, help="Whether to use hint")
args = parser.parse_args()

# Every question needs at least one sample; k == 0 would otherwise crash much later
# when the per-sample "reasoning" field is read.
if args.k < 1:
    parser.error("--k must be at least 1")

# Normalise the comma-separated GPU list ("0, 1," -> ["0", "1"]) so stray spaces or a
# trailing comma neither leak into CUDA_VISIBLE_DEVICES nor inflate the tensor-parallel size.
gpu_ids = [gpu.strip() for gpu in args.gpus.split(",") if gpu.strip()]
if not gpu_ids:
    parser.error("--gpus must list at least one GPU id, e.g. --gpus 0,1")
# Must be set before the vLLM engine initialises CUDA so only these devices are visible.
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(gpu_ids)
num_gpus = len(gpu_ids)

# max_model_len bounds prompt + completion length for the engine.
llm = LLM(model=args.model,
          max_model_len=12000,
          tensor_parallel_size=num_gpus,
          trust_remote_code=True)
# Use the same trust_remote_code setting as the engine above: checkpoints that ship custom
# tokenizer code would otherwise load in vLLM but fail here.
tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True)

# Minimal ChatML-style template: each turn is wrapped as <|im_start|>role\n...<|im_end|>\n,
# and an assistant header is opened when add_generation_prompt is true.
_CHATML_FALLBACK_TEMPLATE = (
    "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}"
    "{% for message in messages %}"
    "{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}"
    "{% endfor %}"
    "{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
)

# Checkpoints without a chat template fall back to the ChatML template above.
if not tokenizer.chat_template:
    print("No chat template found, setting default")
    tokenizer.chat_template = _CHATML_FALLBACK_TEMPLATE
                              
# Optional generic system prompt (currently disabled; the Octo branch of the prompt
# builder defines its own system prompt):
# system_prompt = (f"A conversation between User and Assistant. The user asks a question, and "
#                 f"the Assistant solves it. The assistant first thinks about the reasoning process in the mind and "
#                 f"then provides the user with the answer. Even when the user provides the answer, "
#                 f"the assistant should still provide the reasoning process from scratch and work out the answer by itself.")

# Appended to every question so the final answer can be extracted from the last \boxed{}.
# Plain strings rather than f-strings: there is nothing to interpolate, so the braces
# need no escaping.
instruction_following = ("\n\nYou must put your answer inside \\boxed{} "
                         "and your final answer will be extracted automatically by the \\boxed{} tag.")

# Infer the model family from the model path; it selects the per-family hinted test file.
if "Llama" in args.model:
    model_name = "Llama"
elif "Qwen" in args.model:
    model_name = "Qwen"
elif "Octo" in args.model:
    model_name = "Octo"
else:
    # Unknown family: harmless without --use_hint, reported explicitly below otherwise.
    model_name = None

if args.use_hint:
    if model_name is None:
        raise ValueError(
            "--use_hint needs a Llama, Qwen or Octo model to pick the hint file; "
            f"could not infer the model family from {args.model!r}"
        )
    test_samples = read_json(f"./dataset/{args.dataset}/test_with_hint_{model_name}.json")
    print(f"================\n\n\n*** Using hint for {model_name} ***\n\n\n================")
else:
    test_samples = read_json(f"./dataset/{args.dataset}/test.json")

# Accept either naming scheme: when the canonical key is missing, copy it from its
# alias ("problem" -> "question", "expected_answer" -> "gold_answer").
for sample in test_samples:
    for canonical, alias in (("question", "problem"), ("gold_answer", "expected_answer")):
        if canonical not in sample:
            sample[canonical] = sample[alias]

# Octo models get an explicit reasoning-style system prompt; all other models receive
# only the user turn.
octo_model = "Octo" in args.model
if octo_model:
    system_prompt = ("A conversation between User and Assistant. The user asks a question, and "
                     "the Assistant solves it. The assistant first thinks about the reasoning process in the mind and "
                     "then provides the user with the answer.")

# Render each question once and repeat the rendered prompt k times, so the k samples
# for question i occupy prompts[i*k:(i+1)*k].
prompts = []
for sample in test_samples:
    user_turn = {"role": "user", "content": sample['question'] + instruction_following}
    if octo_model:
        messages = [{"role": "system", "content": system_prompt}, user_turn]
    else:
        messages = [user_turn]
    text = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,  # open the assistant turn so the model starts answering
        tokenize=False,
    )
    prompts.extend([text] * args.k)

print(prompts[0])

# Plain temperature sampling: top_p=1.0 leaves nucleus filtering off, and each
# completion is capped at 10000 new tokens.
sampling_params = SamplingParams(temperature=args.temp, top_p=1.0, max_tokens=10000)

# n is left at its default, so only the first completion of each request is read; the
# i*k+j indexing downstream relies on outputs coming back in prompt order.
outputs = llm.generate(prompts, sampling_params=sampling_params)
results = [request_output.outputs[0].text for request_output in outputs]

# Length of every completion in model tokens. add_special_tokens=False stops the
# tokenizer from prepending BOS (or other special tokens) to the decoded text, which
# would otherwise add a constant to every count for tokenizers that insert them and
# skew length comparisons across model families.
token_counts = [len(tokenizer.encode(result, add_special_tokens=False)) for result in results]

# Fresh per-sample containers for the k completions generated for each question.
for sample in test_samples:
    sample['predictions'] = []
    sample['token_counts'] = []

# Completions are grouped by question: sample i owns results[i*k : (i+1)*k].
for sample_idx, sample in enumerate(test_samples):
    base = sample_idx * args.k
    for offset in range(args.k):
        completion = results[base + offset]
        boxed = last_boxed_only_string(completion)
        # Completions without a \boxed{} answer are recorded as "N/A".
        sample['predictions'].append(remove_boxed(boxed) if boxed else "N/A")
        sample['token_counts'].append(token_counts[base + offset])
        # Only the last completion's full text survives; it is used for the printout.
        sample['reasoning'] = completion

# Self-consistency answer: the most frequent extracted prediction. "N/A" entries
# (no \boxed{} found) never vote; if nothing is left, the majority is "N/A" too.
# Ties go to the answer that appeared first (Counter keeps insertion order).
for sample in test_samples:
    votes = Counter(pred for pred in sample['predictions'] if pred != "N/A")
    sample['majority'] = votes.most_common(1)[0][0] if votes else "N/A"

# Per-draw accuracy and mean completion length: draw k_idx uses the k_idx-th
# completion of every question, i.e. one full pass over the dataset.
n_questions = len(test_samples)
accs = []
token_count_per_k = []
for k_idx in range(args.k):
    n_right = sum(
        1 for s in test_samples
        if is_math_verify_equiv(s['predictions'][k_idx], s['gold_answer'])
    )
    n_tokens = sum(s['token_counts'][k_idx] for s in test_samples)
    accs.append(n_right / n_questions)
    token_count_per_k.append(n_tokens / n_questions)

# Self-consistency accuracy: the majority-vote answer must match the gold answer.
num_sc_correct = sum(
    1 for s in test_samples
    if is_math_verify_equiv(s['majority'], s['gold_answer'])
)
sc_acc = num_sc_correct / n_questions


# pass@k for every power of two up to args.k, plus args.k itself when it is not a power
# of two (e.g. --k 5 reports pass@1, pass@2, pass@4 and pass@5).
# NOTE: this is the empirical "any of the first k samples is correct" rate, not the
# unbiased pass@k estimator of Chen et al. (2021).
valid_k_values = []
power = 1
while power <= args.k:
    valid_k_values.append(power)
    power *= 2
if valid_k_values and valid_k_values[-1] != args.k:
    valid_k_values.append(args.k)

# For each sample, find the index of its first correct prediction among the first max_k
# (None if there is none). Each prediction is verified at most once, instead of once per
# k value. pass@k_val is then "first correct index < k_val".
max_k = valid_k_values[-1] if valid_k_values else 0
first_correct_idx = []
for sample in test_samples:
    first_correct_idx.append(next(
        (j for j, pred in enumerate(sample['predictions'][:max_k])
         if is_math_verify_equiv(pred, sample['gold_answer'])),
        None,
    ))

pass_at_k_scores = {}
for k_val in valid_k_values:
    num_solved = sum(1 for idx in first_correct_idx if idx is not None and idx < k_val)
    pass_at_k_scores[k_val] = num_solved / len(test_samples)

# Spot-check printout: show up to five randomly chosen questions with their outputs.
divider = "=" * 50
print("\n" + divider)
print("5 RANDOM QUESTION+OUTPUT PAIRS:")
print(divider)

# Drawn with Python's global RNG, so the selection follows the script's seeding.
num_shown = min(5, len(test_samples))
random_indices = random.sample(range(len(test_samples)), num_shown)

for idx in random_indices:
    sample = test_samples[idx]
    print(f"\nQuestion {idx}:")
    for label, key in (("Q", "question"), ("Reasoning", "reasoning"),
                       ("Gold Answer", "gold_answer"), ("Majority", "majority")):
        print(f"{label}: {sample[key]}")
    print("-" * 30)

# Final report: mean ± std across the k draws, then SC accuracy and pass@k.
print(f"model: {args.model}")

acc_mean = np.mean(accs) * 100
acc_std = np.std(accs) * 100
print(f"avg. accuracy: {acc_mean:.02f} ± {acc_std:.02f}")

tok_mean = np.mean(token_count_per_k)
tok_std = np.std(token_count_per_k)
print(f"avg. token count: {tok_mean:.02f} ± {tok_std:.02f}")

print(f"SC accuracy: {sc_acc*100:.02f}")
for k_val in valid_k_values:
    print(f"pass@{k_val}: {pass_at_k_scores[k_val]*100:.02f}")
print(f"temperature: {args.temp}")

if args.save_output:
    # Name the file after the third-from-last component of the model path (presumably the
    # run directory for checkpoint paths shaped like <run>/<a>/<b> — confirm against usage).
    # A trailing slash is ignored so it does not shift the chosen component. Shorter paths
    # (e.g. a two-part hub id "org/name") fall back to the last component instead of
    # raising IndexError after the whole evaluation has already finished.
    model_path_parts = args.model.rstrip("/").split("/")
    if len(model_path_parts) >= 3:
        model_name = model_path_parts[-3]
    else:
        model_name = model_path_parts[-1]
    write_json(test_samples, f"{model_name}_{args.dataset}_{args.k}.json")