import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import json
import re
from tqdm import tqdm
import os


def generate_output(tokenizer, model, input_texts, temperature=None, do_sample=False,
                    *, max_new_tokens=7, num_return_sequences=16):
    """Generate short continuations for each prompt, one prompt at a time.

    Args:
        tokenizer: Callable that encodes a prompt (called with
            ``return_tensors="pt"``) into a mapping containing ``input_ids``
            and that exposes ``decode`` and ``eos_token_id``.
        model: Model exposing ``parameters()`` and ``generate(**kwargs)``.
        input_texts: Iterable of prompt strings.
        temperature: Sampling temperature; only forwarded when ``do_sample``
            is true, otherwise ``None`` is passed to ``generate``.
        do_sample: Sample when true, decode greedily when false.
        max_new_tokens: Maximum number of tokens generated per sequence.
        num_return_sequences: Number of sequences requested per prompt when
            sampling. Greedy decoding always requests exactly one.

    Returns:
        A ``(batch_outputs, all_lengths)`` tuple. ``batch_outputs[i]`` is the
        list of decoded continuations (prompt tokens stripped) for
        ``input_texts[i]``. ``all_lengths`` is always an empty list; it is
        kept only so the return shape stays unchanged for existing callers.
    """
    device = next(model.parameters()).device

    # Greedy decoding is deterministic, and `generate` refuses to return more
    # than one sequence without sampling or beam search, so clamp to 1 there.
    sequences_per_prompt = num_return_sequences if do_sample else 1

    all_lengths = []  # never populated; preserved for interface compatibility
    batch_outputs = []
    for input_text in input_texts:
        inputs = tokenizer(input_text, return_tensors="pt").to(device)

        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            pad_token_id=tokenizer.eos_token_id,
            do_sample=do_sample,
            # Explicit None under greedy decoding overrides any temperature
            # carried by the model's own generation defaults.
            temperature=temperature if do_sample else None,
            num_return_sequences=sequences_per_prompt,
        )

        # Each returned sequence starts with the prompt tokens; drop them so
        # only the newly generated tail is decoded.
        input_len = inputs['input_ids'].shape[1]
        batch_outputs.append([
            tokenizer.decode(sequence[input_len:], skip_special_tokens=True)
            for sequence in outputs
        ])

    return batch_outputs, all_lengths


# Location of the checkpoint; both the tokenizer and the model are read from it.
model_path = "path/to/model/DeepSeek-R1-Distill-Qwen-14B"

tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side='left')

# Keyword options for loading the weights, gathered in one place.
_model_load_options = dict(
    torch_dtype=torch.bfloat16,
    device_map="auto",
    low_cpu_mem_usage=True,
    local_files_only=True,
)
model = AutoModelForCausalLM.from_pretrained(model_path, **_model_load_options)

# Preferred device string. generate_output() takes its device from the model's
# parameters instead, so nothing in the visible script reads this value.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"


# Every prompt shares one chat skeleton: the user turn, a pre-filled (already
# closed) <think> block, and a lead-in after which the model continues.
# "{origin_text}" is left as a str.format placeholder for the problem text;
# the two %s slots take the canned thought and the answer lead-in.
_PROMPT_SHELL = "<｜User｜>{origin_text}<｜Assistant｜><think>\n%s</think>\n%s"

_PROMPT_VARIANTS = (
    ("Let me answer him without thinking more.", "Answer: "),
    ("I will answer directly. I won't output any thinking process.", "Answer: "),
    ("I will answer with a single number.", "The answer is: "),
    ("I should not think, but should answer directly.", "The answer is: "),
)

formats = [_PROMPT_SHELL % variant for variant in _PROMPT_VARIANTS]

# Sampling temperatures to run in addition to the greedy pass.
temperatures = [0.5]

# Input problems (one JSON object per line) and the folder for per-problem results.
input_jsonl_file = "path/to/LongCoT/AIME/aime_2024_problems.jsonl"
output_directory = "path/to/LongCoT/AIME/aime_2024_direct_answers"

# Create the results folder up front so the per-problem writes below cannot
# fail on a missing directory after the generation work has been done.
os.makedirs(output_directory, exist_ok=True)

with open(input_jsonl_file, 'r', encoding='utf-8') as infile:
    index = 0
    for line in tqdm(infile):
        line = line.strip()
        if not line:
            # Tolerate blank lines (e.g. a trailing newline) in the JSONL file.
            continue
        index += 1
        data = json.loads(line)
        problem = data.get("Problem")

        # Check the required fields before generating: a missing "Problem"
        # would prompt the model with the text "None", and a missing "ID"
        # would only fail when the result file is written.
        if problem is None or "ID" not in data:
            print(f"Problem {index} skipped: missing 'Problem' or 'ID' field.")
            continue

        # The prompts do not depend on the sampling mode, so build them once.
        batch_input_texts = [
            input_format.format(origin_text=problem) for input_format in formats
        ]

        # One greedy pass, then one sampled pass per configured temperature.
        for sample_type in ["greedy"] + temperatures:
            if sample_type == "greedy":
                outputs, _ = generate_output(tokenizer, model, batch_input_texts, do_sample=False)
            else:
                outputs, _ = generate_output(tokenizer, model, batch_input_texts, temperature=sample_type, do_sample=True)

            # Append to any "direct_answers" already present in the record:
            # one list of answers per prompt format.
            data.setdefault("direct_answers", []).extend(outputs)

        with open(os.path.join(output_directory, f"{data['ID']}.json"), 'w', encoding='utf-8') as f:
            json.dump(data, f, ensure_ascii=False, indent=4)
        print(f"Problem {index} done.")
            