import os, time
import json
from vllm import LLM, SamplingParams
from datasets import load_from_disk, load_dataset
from utils import DATASET_KEYS, RESPONSE_EXTRACTOR, RESPONSE_COMPARATOR
import pandas as pd
import argparse
import numpy as np


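# Evaluate a model on one benchmark dataset with vLLM using two-stage generation:
# first sample a reasoning trace (stopped at "</think>"), then append "</think>"
# and sample the final answer(s). Per-question records are written to outputs/,
# summary scores to results/.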

# Command-line configuration. Parsed values are copied into module-level names
# because the functions below read them as globals.
parser = argparse.ArgumentParser()
parser.add_argument('--model_path', type=str, default='')
# Required: every later lookup (DATASET_KEYS, RESPONSE_COMPARATOR, the
# per-dataset settings) is keyed by this name, so omitting it can only fail later.
parser.add_argument('--dataset', type=str, required=True)
parser.add_argument('--scale', type=str, default='1.5B')
parser.add_argument('--tok_limit', type=int, default=32768)
parser.add_argument('--thinking_mode', type=str, choices=['low', 'medium', 'high', 'none'], default='low')
parser.add_argument('--beta', type=int, default=300)
args = parser.parse_args()
os.environ['TOKENIZERS_PARALLELISM'] = "false"
THINKING_MODE = args.thinking_mode

dataset_name = args.dataset
model_path = args.model_path
scale = args.scale  # only printed below in the visible code
tok_limit = args.tok_limit
beta = args.beta  # thinking-length threshold for the "under_beta" statistic
results = {}

print("Dataset:", dataset_name, "\nScale:", scale)

# Fail with a readable CLI message instead of a bare KeyError.
if dataset_name not in DATASET_KEYS:
    parser.error(f"unknown --dataset {dataset_name!r}; expected one of {sorted(DATASET_KEYS)}")

# Per-dataset field names and answer-equivalence function provided by utils.
QUESTION_KEY = DATASET_KEYS[dataset_name]["question"]
ANSWER_KEY = DATASET_KEYS[dataset_name]["answer"]
eq = RESPONSE_COMPARATOR[dataset_name]

# Per-dataset evaluation settings, as
#   (loader, extra positional loader args, TEST_N, MAX_TEST_SAMPLES)
# where TEST_N is the number of answers sampled per question and
# MAX_TEST_SAMPLES caps how many test questions are evaluated.
_DATASET_SETTINGS = {
    'datasets/converted_aime_dataset': (load_from_disk, (), 1, 100),
    'di-zhang-fdu/MATH500': (load_dataset, (), 3, 500),
    'openai/gsm8k': (load_dataset, ('main',), 1, 1319),
    'datasets/gsm8k': (load_dataset, ('main',), 1, 1319),
    "datasets/MATH500": (load_dataset, (), 1, 500),
}

if dataset_name not in _DATASET_SETTINGS:
    # Fail here, before the model is loaded, rather than with a NameError on
    # `dataset` / `TEST_N` deep inside the evaluation.
    raise ValueError(
        f"No evaluation settings for dataset {dataset_name!r}; "
        f"supported: {sorted(_DATASET_SETTINGS)}"
    )

_loader, _loader_args, TEST_N, MAX_TEST_SAMPLES = _DATASET_SETTINGS[dataset_name]
dataset = _loader(dataset_name, *_loader_args)
MAX_TOKENS = tok_limit
TEST_TEMPERATURE = 0.6

def get_scores(ds, outputs, thinkings, thinking_lengths, save_file_name=None):
    """Score sampled answers against gold answers and summarize the run.

    Args:
        ds: Evaluated dataset rows; each row is indexed with QUESTION_KEY and
            ANSWER_KEY.
        outputs: Answer-generation results aligned with ``ds``; each item's
            ``.outputs`` holds the sampled completions (``.text``, ``.token_ids``).
        thinkings: Reasoning-trace text per question, aligned with ``ds``.
        thinking_lengths: Reasoning-trace token counts per question.
        save_file_name: If given, per-question records are dumped there as JSON
            (the parent directory is created if missing).

    Returns:
        dict with 'pass@1', 'pass@1(majority)', 'average_pass_rate',
        'std_pass_rate' (std of acc@k), 'acc@k' (accuracy of the i-th sample),
        'pass@k' (any of the first i+1 samples correct), 'avg_tokens',
        'avg_thinking_length' and 'num_responses_under_beta'.

    Raises:
        ValueError: if there is nothing to score.
    """
    extract = RESPONSE_EXTRACTOR[dataset_name]
    records = []
    for row, output, thinking, thinking_length in zip(ds, outputs, thinkings, thinking_lengths):
        gold = extract(row[ANSWER_KEY])
        prediction = [extract(resp.text) for resp in output.outputs]
        records.append(
            {
                QUESTION_KEY: row[QUESTION_KEY],
                ANSWER_KEY: row[ANSWER_KEY],
                "responses": [resp.text for resp in output.outputs],
                "prediction": prediction,
                "gold": gold,
                # average answer length over the sampled completions
                "tokens": sum(len(resp.token_ids) for resp in output.outputs) / len(output.outputs),
                "accuracy": [eq(gold, pred) for pred in prediction],
                "thinking": thinking,
                "thinking_length": thinking_length,
                "under_beta": thinking_length < beta,
            }
        )

    if save_file_name is not None:
        # Create the directory up front so a long generation run is not lost
        # to a missing folder at the very end.
        save_dir = os.path.dirname(save_file_name)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        with open(save_file_name, 'w') as f:
            json.dump(records, f, indent=4)

    if not records:
        raise ValueError("get_scores: no outputs to score")

    frame = pd.DataFrame(records)
    predictions, golds = frame["prediction"], frame["gold"]
    tokens, thinking_lengths, under_beta = frame["tokens"], frame["thinking_length"], frame["under_beta"]
    n_questions = len(predictions)

    pass_at_1 = sum(any(eq(g, pred) for pred in p[:1]) for p, g in zip(predictions, golds)) / n_questions

    avg_tokens = sum(tokens) / len(tokens)
    avg_thinking_length = sum(thinking_lengths) / len(thinking_lengths)
    num_under_beta = sum(under_beta)
    print("Average tokens:", avg_tokens)
    print("Average thinking length:", avg_thinking_length)
    print("Number of responses under beta:", num_under_beta)

    acc_at_k_list, pass_at_k_list = [], []
    for i in range(TEST_N):
        # pass@(i+1): any of the first i+1 samples is correct
        pass_at_i = sum(any(eq(g, pred) for pred in p[:i + 1]) for p, g in zip(predictions, golds)) / n_questions
        # acc@(i+1): the (i+1)-th sample alone is correct
        acc_at_i = sum(eq(g, p[i]) for p, g in zip(predictions, golds)) / n_questions
        acc_at_k_list.append(acc_at_i)
        pass_at_k_list.append(pass_at_i)
        print(
            f"Pass @ {i+1}: {pass_at_i}"
        )

    def get_most_common(solns):
        """Return the first answer whose ``eq``-equivalence class is largest (None if all are None)."""
        candidates = [s for s in solns if s is not None]
        best, best_count = None, 0
        for i, soln in enumerate(candidates):
            # Each sample counts once: the size of soln's equivalence class.
            count = 1 + sum(
                1 for j, other in enumerate(candidates) if j != i and eq(soln, other)
            )
            if count > best_count:  # strict: ties go to the earliest answer
                best, best_count = soln, count
        return best

    predictions_maj = [get_most_common(p) for p in predictions]
    all_preds = [eq(g, pred) for p, g in zip(predictions, golds) for pred in p]
    avg_pass_rate = sum(all_preds) / len(all_preds)
    pass_at_n = sum(eq(g, p) for p, g in zip(predictions_maj, golds)) / n_questions
    print(
        f"Pass @ 1(with majority): {pass_at_n}"
    )

    return {
        'pass@1': pass_at_1,
        'pass@1(majority)': pass_at_n,
        'average_pass_rate': avg_pass_rate,
        'std_pass_rate': np.std(acc_at_k_list),
        'acc@k': acc_at_k_list,
        'pass@k': pass_at_k_list,
        'avg_tokens': avg_tokens,
        'avg_thinking_length': avg_thinking_length,
        'num_responses_under_beta': num_under_beta,
    }



def construct_prompt_with_thinking(test_ds, thinking_outputs):
    """Append each prompt's first reasoning sample and a closing "</think>" tag.

    Args:
        test_ds: Prompt strings, aligned with ``thinking_outputs``.
        thinking_outputs: Generation results; only ``.outputs[0]`` of each is used.

    Returns:
        (new_prompts, thinkings, thinking_lengths): the extended prompts, the
        reasoning texts, and their token counts, in input order.
    """
    # Pair every prompt with its first sampled completion (zip stops at the shorter input).
    pairs = [(prompt, result.outputs[0]) for prompt, result in zip(test_ds, thinking_outputs)]

    thinkings = [sample.text for _, sample in pairs]
    thinking_lengths = [len(sample.token_ids) for _, sample in pairs]
    new_prompts = [prompt + text + "</think>" for (prompt, _), text in zip(pairs, thinkings)]

    return new_prompts, thinkings, thinking_lengths




def evaluate_model(model_name):
    """Run two-stage generation for ``model_name`` on the test split and score it.

    Stage 1 samples one reasoning trace per question, stopping at "</think>".
    Stage 2 appends that trace plus "</think>" to the prompt and samples TEST_N
    answers. Scoring is delegated to ``get_scores``, which also saves the
    per-question records under ``outputs/``.

    Returns:
        {'test': <dict from get_scores>, 'time_taken': seconds spent in both
        generation stages}
    """
    model = LLM(model_name, gpu_memory_utilization=0.4, tensor_parallel_size=1)
    # Public accessor; `llm_engine.tokenizer.tokenizer` is an engine internal.
    tokenizer = model.get_tokenizer()

    test_split = dataset['test']
    test_ds = test_split.shuffle(seed=0).select(range(min(MAX_TEST_SAMPLES, len(test_split))))

    # Detail instruction prepended per --thinking_mode; 'none' adds nothing.
    detail_prefixes = {
        'low': "Think in low detail — keep reasoning very short and minimal. Only outline the essential steps needed to reach the answer, without extra explanation. ",
        'medium': "Think in medium detail — provide a balanced explanation that covers the main reasoning steps, but avoid going into excessive depth or unnecessary derivations. ",
        'high': "Think in high detail — provide a thorough, clear explanation that fully justifies each step, covering assumptions, intermediate calculations, and checks where appropriate. ",
    }
    prefix = detail_prefixes.get(THINKING_MODE, "")

    test_prompts = []
    for x in test_ds:
        prompt = [{
            "role": "user",
            "content": f"{prefix}Please reason step by step, and put your final answer within \\boxed{{}}. Question: {x[QUESTION_KEY]}",
        }]
        test_prompts.append(
            tokenizer.apply_chat_template(prompt, tokenize=False, add_generation_prompt=True)
        )

    # construct_prompt_with_thinking consumes only outputs[0], so a single
    # trace per question is sampled; extra traces would be discarded work.
    # NOTE(review): all TEST_N answers therefore share one reasoning trace; if
    # independent pass@k samples are intended, each answer needs its own trace.
    # Token budgets are capped by --tok_limit (MAX_TOKENS); the default limit
    # leaves the original 14000 / 16384 budgets unchanged.
    thinking_sampling_params = SamplingParams(
        temperature=TEST_TEMPERATURE,
        max_tokens=min(14000, MAX_TOKENS),
        n=1,
        stop=["</think>"],
    )

    answer_sampling_params = SamplingParams(
        temperature=TEST_TEMPERATURE,
        max_tokens=min(16384, MAX_TOKENS),
        n=TEST_N,
    )

    print("Generating test outputs...")
    start_time = time.time()
    thinking_outputs = model.generate(prompts=test_prompts, sampling_params=thinking_sampling_params, use_tqdm=True)

    answer_prompts, thinkings, thinking_lengths = construct_prompt_with_thinking(test_prompts, thinking_outputs)

    test_outputs = model.generate(prompts=answer_prompts, sampling_params=answer_sampling_params, use_tqdm=True)
    end_time = time.time()

    save_path = f"outputs/{dataset_name.replace('/', '_')}_results_{model_name.replace('/', '_')}_{THINKING_MODE}_ZA.json"
    test_scores = get_scores(test_ds, test_outputs, thinkings, thinking_lengths, save_path)
    print("Test:", test_scores)
    time_taken = end_time - start_time
    print("Time taken:", time_taken)

    return {'test': test_scores, 'time_taken': time_taken}

# Evaluate the requested model and write the summary scores to results/.
print("Found model_path:", model_path)
print("This is not a checkpoint, will evaluate directly...")
scores = evaluate_model(model_path)
results[model_path] = scores

# Ensure the output directory exists; otherwise the dump would fail only after
# the (expensive) evaluation has already finished.
os.makedirs('results', exist_ok=True)
results_path = f'results/{dataset_name.replace("/", "_")}_results_{model_path.replace("/", "_")}_{THINKING_MODE}_ZA.json'
with open(results_path, 'w') as f:
    json.dump(results, f, indent=4)
