import argparse
import collections
import math
import time
from typing import List
import os
from collections import Counter

import torch
from transformers import pipeline, AutoTokenizer
from trl.data_utils import is_conversational, apply_chat_template, maybe_apply_chat_template
from transformers import GenerationConfig
from vllm import LLM, SamplingParams

from calib.data import preprocess_dataset
# from hivemind.utils.model_utils import get_model, get_tokenizer
from calib.utils import extract_answer, math_equal


def _parse_args() -> argparse.Namespace:
    """Build the command-line parser for this script and parse ``sys.argv``."""
    parser = argparse.ArgumentParser(description="Evaluate model performance on math or countdown dataset.")
    parser.add_argument("--model_name", type=str, default="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B", help="Model name or path.")
    parser.add_argument("--dataset_name", type=str, default="gsm8k", help="Dataset name.")
    parser.add_argument("--num_prompts", type=int, default=1, help="Number of prompts to evaluate.")
    parser.add_argument("--num_completions_per_prompt", type=int, default=16)
    parser.add_argument("--max_completion_length", type=int, default=4096, help="Maximum completion length.")
    parser.add_argument("--max_model_len", type=int, default=4096, help="Maximum model length.")
    parser.add_argument("--temperature", type=float, default=1.0, help="Sampling temperature.")
    parser.add_argument("--top_k", type=int, default=-1, help="Top-k sampling.")
    parser.add_argument("--top_p", type=float, default=1.0, help="Top-p sampling.")
    parser.add_argument("--additional_prompt", type=str, default="Please reason step by step, and put your final answer within \\boxed{}.", help="Additional prompt for the model.")
    parser.add_argument("--difficulty", type=str, default=None, help="Difficulty level of the dataset.")
    parser.add_argument("--output_dir", type=str, default="outputs/completions", help="Output directory.")
    parser.add_argument("--dtype", type=str, default="bfloat16", help="Data type.")
    parser.add_argument(
        "--max_answer_chars",
        type=int,
        default=None,
        help="If set, drop dataset rows whose answer length exceeds this many characters.",
    )
    return parser.parse_args()


def _score_completions(completions, answers):
    """Extract an answer from every completion and compare it to its reference.

    Args:
        completions: Nested list ``[prompt][sample]`` of decoded completion strings.
        answers: One reference answer per prompt, aligned with ``completions``.

    Returns:
        ``(responses, rewards)``, both nested ``[prompt][sample]`` lists.
        ``responses`` holds the ``extract_answer`` output for each completion.
        ``rewards`` holds the ``math_equal(response, answer)`` result for each one.
    """
    responses = [[extract_answer(text) for text in group] for group in completions]
    rewards = [
        [math_equal(response, answer) for response in group]
        for group, answer in zip(responses, answers)
    ]
    return responses, rewards


def main():
    """Sample completions with vLLM, score them, and save the results with ``torch.save``.

    Pipeline:
        1. Parse CLI args and create the output directory.
        2. Load the dataset via ``preprocess_dataset``.
        3. Render each ``"prompt"`` with the tokenizer's chat template.
        4. Draw ``--num_completions_per_prompt`` samples per prompt.
        5. Extract an answer from every completion and compare it to the
           example's ``"answer"`` field.
        6. Save ids, texts, lengths, responses, answers, rewards and args.

    Raises:
        ValueError: if the preprocessed dataset contains no prompts.
    """
    args = _parse_args()

    filename = os.path.join(args.output_dir, f"completions_{args.model_name.replace('/', '_')}_{args.dataset_name}_{args.num_prompts}_{args.num_completions_per_prompt}_{args.temperature}_{args.max_completion_length}_{args.max_answer_chars}_{args.difficulty}.pt")
    print(f"Will save to {filename}")

    # Create the output directory before the (potentially long) generation run,
    # so an unusable path fails immediately instead of after all the work is done.
    # dirname() is "" when --output_dir is "", and os.makedirs("") raises.
    output_parent = os.path.dirname(filename)
    if output_parent:
        os.makedirs(output_parent, exist_ok=True)

    tokenizer = AutoTokenizer.from_pretrained(args.model_name)

    dataset = preprocess_dataset(
        args.dataset_name,
        additional_prompt=args.additional_prompt,
        n=args.num_prompts,
        max_answer_chars=args.max_answer_chars,
        difficulty=args.difficulty,
    )
    n_prompts = len(dataset)
    if n_prompts == 0:
        # Fail before paying the cost of loading the model onto the GPU.
        raise ValueError(
            f"preprocess_dataset returned no prompts for dataset {args.dataset_name!r} "
            f"(num_prompts={args.num_prompts}, max_answer_chars={args.max_answer_chars}, "
            f"difficulty={args.difficulty!r})."
        )

    # NOTE: for reasoning models, generation could be stopped at "</think>" by
    # passing stop_token_ids=[tokenizer.convert_tokens_to_ids("</think>")].
    llm = LLM(model=args.model_name, dtype=args.dtype, gpu_memory_utilization=0.9, max_model_len=args.max_model_len)
    sampling_params = SamplingParams(
        n=args.num_completions_per_prompt,
        temperature=args.temperature,
        repetition_penalty=1.0,
        top_k=args.top_k,
        top_p=args.top_p,
        min_p=0.0,
        max_tokens=args.max_completion_length,
    )

    start_time = time.time()

    prompts_text = [
        tokenizer.apply_chat_template(example, tokenize=False, add_generation_prompt=True)
        for example in dataset["prompt"]
    ]
    request_outputs = llm.generate(prompts_text, sampling_params)

    # Nested [prompt][sample] structures, one entry per request output.
    completion_ids = [[out.token_ids for out in output.outputs] for output in request_outputs]
    completions = [tokenizer.batch_decode(ids, skip_special_tokens=True) for ids in completion_ids]
    prompt_ids = [output.prompt_token_ids for output in request_outputs]
    completion_lengths = [[len(ids) for ids in group] for group in completion_ids]

    assert len(prompts_text) == len(prompt_ids) == len(completions) == len(completion_lengths) == n_prompts
    # Check every prompt, not just the first, got the requested number of samples.
    assert all(len(group) == args.num_completions_per_prompt for group in completions)

    answers = [example["answer"] for example in dataset]
    responses, rewards = _score_completions(completions, answers)

    # zip() in _score_completions truncates on a length mismatch; this catches it.
    assert len(responses) == len(answers) == len(rewards) == n_prompts

    end_time = time.time()
    print(f"Time taken: {end_time - start_time:.2f} seconds")

    torch.save({
        "completion_ids": completion_ids,
        "prompt_ids": prompt_ids,
        "completions": completions,
        "prompts_text": prompts_text,
        "completion_lengths": completion_lengths,
        "responses": responses,
        "answers": answers,
        "rewards": rewards,
        "args": args,
    }, filename)
    
    
# Run the evaluation only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
