import os, time
import json
from vllm import LLM, SamplingParams
from datasets import load_from_disk, load_dataset
from utils import DATASET_KEYS, RESPONSE_EXTRACTOR, RESPONSE_COMPARATOR
import pandas as pd
import argparse
import numpy as np


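# This script evaluates a reasoning model on a dataset: the model first
# generates its thinking, a separate summarizer model condenses that thinking,
# and the model then answers from the condensed thinking. Answers are scored
# against the gold answers and the results are written to JSON files.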

# Command-line configuration. Every option is copied into a module-level name
# that the functions below read as a global.
parser = argparse.ArgumentParser(
    description="Evaluate a model whose thinking is replaced by a summarizer's condensed version before answering."
)
parser.add_argument('--model_path', type=str, default='',
                    help="Path or name of the model to evaluate.")
parser.add_argument('--summarizer', type=str, default="models/Qwen2.5-7B-Instruct",
                    help="Path or name of the model used to summarize the thinking.")
# Required: without it the DATASET_KEYS lookup below fails with an opaque
# `KeyError: None` instead of a usage message.
parser.add_argument('--dataset', type=str, required=True,
                    help="Dataset identifier; must be a key of DATASET_KEYS and RESPONSE_COMPARATOR.")
parser.add_argument('--scale', type=str, default='1.5B',
                    help="Model scale label; only echoed in the startup log.")
parser.add_argument('--tok_limit', type=int, default=32768,
                    help="Token limit exposed to the rest of the script as MAX_TOKENS.")
parser.add_argument('--thinking_mode', type=str, choices=['low', 'medium', 'high', 'none'], default='low',
                    help="Tag embedded in the output file names.")
parser.add_argument('--split', type=str, default='[0:1000]',
                    help="Row range '[start:end]' selected from the shuffled dataset.")
parser.add_argument('--beta', type=int, default=300,
                    help="Thinking-length threshold in tokens; shorter thinkings are counted as 'under beta'.")
args = parser.parse_args()

os.environ['TOKENIZERS_PARALLELISM'] = "false"

THINKING_MODE = args.thinking_mode
SPLIT = args.split
dataset_name = args.dataset
model_path = args.model_path
summarizer_name = args.summarizer
scale = args.scale
tok_limit = args.tok_limit
beta = args.beta

# Final results, keyed by model path; filled in and dumped at the end of the script.
results = {}

print("Dataset:", dataset_name, "\nScale:", scale)

# Per-dataset field names and answer comparator supplied by utils.
QUESTION_KEY = DATASET_KEYS[dataset_name]["question"]
ANSWER_KEY = DATASET_KEYS[dataset_name]["answer"]
eq = RESPONSE_COMPARATOR[dataset_name]

# Per-dataset loading and evaluation settings.
# The token budget and sampling temperature are the same for every supported
# dataset, so they are set once instead of in every branch.
MAX_TOKENS = tok_limit
TEST_TEMPERATURE = 0.6

if dataset_name == 'datasets/converted_aime_dataset':
    dataset = load_from_disk(dataset_name)
    TEST_N = 1
    MAX_TEST_SAMPLES = 100
elif dataset_name == 'datasets/compression_dataset':
    dataset = load_from_disk(dataset_name)
    TEST_N = 10
    MAX_TEST_SAMPLES = 10
elif dataset_name == 'di-zhang-fdu/MATH500':
    dataset = load_dataset(dataset_name)
    TEST_N = 3
    MAX_TEST_SAMPLES = 500
elif dataset_name in ('openai/gsm8k', 'datasets/gsm8k'):
    # Hub copy and local copy share the same 'main' configuration.
    dataset = load_dataset(dataset_name, 'main')
    TEST_N = 1
    MAX_TEST_SAMPLES = 1319
elif dataset_name == "datasets/MATH500":
    dataset = load_dataset(dataset_name)
    TEST_N = 1
    MAX_TEST_SAMPLES = 500
else:
    # Fail here rather than with a NameError on `dataset`/`TEST_N` much later,
    # after the models have already been loaded.
    raise ValueError(
        f"Unsupported dataset {dataset_name!r}; expected one of: "
        "'datasets/converted_aime_dataset', 'datasets/compression_dataset', "
        "'di-zhang-fdu/MATH500', 'openai/gsm8k', 'datasets/gsm8k', 'datasets/MATH500'"
    )



def construct_prompt_for_summarization(test_ds, test_prompts, thinkings, summarizer_tokenizer, summarizer):
    """Summarize each thinking and splice the summary into the answer prompt.

    Args:
        test_ds: Records indexable by ``QUESTION_KEY``, one per thinking.
        test_prompts: Prompt strings that the summaries are appended to.
        thinkings: Reasoning texts to be summarized, aligned with ``test_ds``.
        summarizer_tokenizer: Object providing ``apply_chat_template``.
        summarizer: Object providing ``generate`` (the summarizer model).

    Returns:
        A tuple ``(new_prompts, summary_texts, summary_lengths)`` where each
        new prompt is ``test_prompt + summary + "</think>"`` and each length is
        the number of token ids in the first generated summary.
    """

    def _summary_request(question, reasoning):
        # Chat messages asking the summarizer to condense one reasoning trace.
        return [
            {
                "role": "system",
                "content": "You are a helpful assistant that summarizes the following reasoning."
            },
            {
                "role": "user",
                "content": f"Summarize the reasoning process for the following question: {question} \n\n Given Reasoning: {reasoning}. Summary should be like the reasoning process but  eliminate the verbosity and retain the main ideas and essential steps."
            }
        ]

    summarizer_inputs = [
        summarizer_tokenizer.apply_chat_template(
            _summary_request(record[QUESTION_KEY], reasoning),
            tokenize=False,
            add_generation_prompt=True,
        )
        for record, reasoning in zip(test_ds, thinkings)
    ]

    summary_sampling_params = SamplingParams(
        temperature=TEST_TEMPERATURE,
        max_tokens=1024,
        n=1
    )

    print("Generating summary thinking outputs...")
    generations = summarizer.generate(
        prompts=summarizer_inputs,
        sampling_params=summary_sampling_params,
        use_tqdm=True,
    )

    # Only the first sample of each request is used (n=1 above).
    summary_texts = []
    summary_lengths = []
    for generation in generations:
        first_sample = generation.outputs[0]
        summary_texts.append(first_sample.text)
        summary_lengths.append(len(first_sample.token_ids))

    # Close the thinking section so the model continues with its answer.
    answer_prompts = []
    for base_prompt, summary in zip(test_prompts, summary_texts):
        answer_prompts.append(base_prompt + summary + "</think>")

    return answer_prompts, summary_texts, summary_lengths



def get_scores(ds, outputs, thinkings, thinking_lengths, summary_texts, summary_lengths, save_file_name=None):
    """Score generated answers against gold answers and report aggregate metrics.

    Args:
        ds: Records indexable by ``QUESTION_KEY`` and ``ANSWER_KEY``.
        outputs: Generation results aligned with ``ds``; each exposes
            ``.outputs``, a sequence of samples with ``.text`` and ``.token_ids``.
        thinkings: Original thinking text per record.
        thinking_lengths: Token count of each thinking.
        summary_texts: Summarized thinking per record.
        summary_lengths: Token count of each summary.
        save_file_name: Optional path; when given, the per-record details are
            written there as JSON.

    Returns:
        A dict with ``pass@1``, ``avg_tokens``, ``avg_thinking_length``,
        ``avg_summary_length`` and ``num_responses_under_beta``.

    Raises:
        ValueError: If there is nothing to score (the averages would be undefined).
    """
    print("Getting scores...")
    extract = RESPONSE_EXTRACTOR[dataset_name]

    records = []
    rows = zip(ds, outputs, thinkings, thinking_lengths, summary_texts, summary_lengths)
    for x, output, thinking, thinking_length, summary_text, summary_length in rows:
        gold = extract(x[ANSWER_KEY])
        responses = [resp.text for resp in output.outputs]
        prediction = [extract(text) for text in responses]
        records.append(
            {
                QUESTION_KEY: x[QUESTION_KEY],
                ANSWER_KEY: x[ANSWER_KEY],
                "responses": responses,
                "prediction": prediction,
                "gold": gold,
                # Mean response length over the samples of this record.
                "tokens": sum(len(resp.token_ids) for resp in output.outputs) / len(output.outputs),
                "accuracy": [eq(gold, pred) for pred in prediction],
                "thinking": thinking,
                "thinking_length": thinking_length,
                "under_beta": thinking_length < beta,
                "summary": summary_text,
                "summary_length": summary_length,
            }
        )

    if not records:
        raise ValueError("get_scores received no examples to score")

    if save_file_name is not None:
        with open(save_file_name, 'w', encoding='utf-8') as f:
            json.dump(records, f, indent=4)

    total = len(records)
    # pass@1 only considers the first sample of each record; its correctness
    # was already computed above, so the comparator is not called again.
    pass_at_1 = sum(any(record["accuracy"][:1]) for record in records) / total
    avg_tokens = sum(record["tokens"] for record in records) / total
    avg_thinking_length = sum(record["thinking_length"] for record in records) / total
    avg_summary_length = sum(record["summary_length"] for record in records) / total
    num_under_beta = sum(record["under_beta"] for record in records)

    print("Average tokens:", avg_tokens)
    print("Average thinking length:", avg_thinking_length)
    print("Average summary thinking length:", avg_summary_length)
    print("Number of responses under beta:", num_under_beta)
    return {
        'pass@1': pass_at_1,
        'avg_tokens': avg_tokens,
        'avg_thinking_length': avg_thinking_length,
        'avg_summary_length': avg_summary_length,
        'num_responses_under_beta': num_under_beta
    }



def construct_prompt_with_thinking(test_ds, thinking_outputs):
    """Append each generated thinking, plus a closing ``</think>``, to its prompt.

    Args:
        test_ds: Prompt strings (anything that supports ``+`` with ``str``).
        thinking_outputs: Generation results aligned with ``test_ds``; the
            first entry of each ``.outputs`` supplies ``.text`` and ``.token_ids``.

    Returns:
        A tuple ``(new_prompts, thinkings, thinking_lengths)`` of equal-length
        lists: the extended prompts, the raw thinking texts and their token counts.
    """
    # Only the first sample of every generation result is used.
    first_samples = [
        (base_prompt, result.outputs[0])
        for base_prompt, result in zip(test_ds, thinking_outputs)
    ]

    thinkings = [sample.text for _, sample in first_samples]
    thinking_lengths = [len(sample.token_ids) for _, sample in first_samples]
    new_prompts = [
        base_prompt + text + "</think>"
        for (base_prompt, _), text in zip(first_samples, thinkings)
    ]
    return new_prompts, thinkings, thinking_lengths



# Make sure the folders for per-example outputs and aggregate results exist.
for _output_dir in ("train_outputs", "train_results"):
    os.makedirs(_output_dir, exist_ok=True)


def _resolve_eval_rows(data):
    """Return a single table of rows, picking a split when ``data`` is a dict of splits."""
    # `load_dataset` without a split returns a dict-like DatasetDict, which
    # cannot be row-selected directly; a plain Dataset passes through untouched.
    if not isinstance(data, dict):
        return data
    if 'test' in data:
        return data['test']
    if len(data) == 1:
        return next(iter(data.values()))
    raise ValueError(f"Cannot choose an evaluation split among {sorted(data)}")


def _parse_split(split, size):
    """Convert a ``'[start:end]'`` string into integer bounds for ``size`` rows.

    An omitted bound falls back to the matching end of the data (``'[:100]'``,
    ``'[50:]'``), and ``end`` is capped at ``size`` so the default
    ``'[0:1000]'`` also works on datasets with fewer rows.
    """
    bounds = split.replace('[', '').replace(']', '').split(':')
    if len(bounds) != 2:
        raise ValueError(f"--split must look like '[start:end]', got {split!r}")
    start_text, end_text = (bound.strip() for bound in bounds)
    start = int(start_text) if start_text else 0
    end = int(end_text) if end_text else size
    return start, min(end, size)


def evaluate_model(model_name):
    """Run the think -> summarize -> answer pipeline and score the answers.

    The evaluated model first generates its thinking (stopping at
    ``</think>``), the summarizer model condenses each thinking, and the
    evaluated model then answers from the prompt extended with the summary.

    Args:
        model_name: Path or name of the model handed to ``LLM``.

    Returns:
        ``{'test': <metrics from get_scores>, 'time_taken': <seconds>}`` where
        the time covers both generation passes and the summarization.

    Raises:
        ValueError: If ``SPLIT`` is malformed or selects no rows.
    """
    # Resolve the evaluation rows before loading any model so that a bad
    # --split fails immediately instead of after the expensive model start-up.
    shuffled = _resolve_eval_rows(dataset).shuffle(seed=0)
    start_idx, end_idx = _parse_split(SPLIT, len(shuffled))
    selected = shuffled.select(range(start_idx, end_idx))
    print("Length of test dataset:", len(selected))
    if len(selected) == 0:
        raise ValueError(f"Split {SPLIT!r} selects no rows from {len(shuffled)} available")

    # Repeat every question TEST_N times; each repeat gets its own dict so the
    # copies are not aliases of one shared object.
    test_ds = []
    for i, x in enumerate(selected):
        record = {
            "id": i,
            QUESTION_KEY: x[QUESTION_KEY],
            ANSWER_KEY: x[ANSWER_KEY],
        }
        test_ds.extend(dict(record) for _ in range(TEST_N))
    print("Length of new test dataset:", len(test_ds))

    # Both models are loaded together; their memory fractions sum to 0.9.
    model = LLM(model_name, gpu_memory_utilization=0.4, tensor_parallel_size=1)
    summarizer = LLM(summarizer_name, gpu_memory_utilization=0.5, tensor_parallel_size=1)
    tokenizer = model.get_tokenizer()
    summary_tokenizer = summarizer.get_tokenizer()

    test_prompts = [
        tokenizer.apply_chat_template(
            [{
                "role": "user",
                "content": f"Please reason step by step, and put your final answer within \\boxed{{}}. Question: {x[QUESTION_KEY]}",
            }],
            tokenize=False,
            add_generation_prompt=True,
        )
        for x in test_ds
    ]

    # NOTE(review): the token budgets below are hard-coded and do not use
    # MAX_TOKENS / --tok_limit -- confirm whether that is intended.
    thinking_sampling_params = SamplingParams(
        temperature=TEST_TEMPERATURE,
        max_tokens=14000,
        n=1,
        stop=["</think>"],
    )
    answer_sampling_params = SamplingParams(
        temperature=TEST_TEMPERATURE,
        max_tokens=16384,
        n=1
    )

    print("Generating thinking outputs...")
    start_time = time.time()
    thinking_outputs = model.generate(prompts=test_prompts, sampling_params=thinking_sampling_params, use_tqdm=True)

    # Only the thinking texts and lengths are needed; the prompts carrying the
    # full thinking are replaced by the summarized versions below.
    _, thinkings, thinking_lengths = construct_prompt_with_thinking(test_prompts, thinking_outputs)

    summary_thinking_prompts, summaries, summary_lengths = construct_prompt_for_summarization(
        test_ds, test_prompts, thinkings, summary_tokenizer, summarizer
    )

    print("Generating summary outputs...")
    answer_outputs = model.generate(prompts=summary_thinking_prompts, sampling_params=answer_sampling_params, use_tqdm=True)
    end_time = time.time()

    test_scores = get_scores(
        test_ds,
        answer_outputs,
        thinkings,
        thinking_lengths,
        summaries,
        summary_lengths,
        f"train_outputs/{dataset_name.replace('/', '_')}_results_{model_path.replace('/', '_')}_{THINKING_MODE}_ZSA_{SPLIT}.json",
    )
    print("Test:", test_scores)
    time_taken = end_time - start_time
    print("Time taken:", time_taken)

    return {'test': test_scores, 'time_taken': time_taken}

# Entry point: evaluate the requested model once and persist its metrics.
print("Found model_path:", model_path)
print("This is not a checkpoint, will evaluate directly...")
scores = evaluate_model(model_path)
results[model_path] = scores

# The file name mirrors the per-example output written under train_outputs/.
_dataset_tag = dataset_name.replace("/", "_")
_model_tag = model_path.replace("/", "_")
_results_path = f'train_results/{_dataset_tag}_results_{_model_tag}_{THINKING_MODE}_ZSA_{SPLIT}.json'
with open(_results_path, 'w') as f:
    f.write(json.dumps(results, indent=4))
