from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from tqdm import tqdm
import json
import torch
import random
import numpy as np
import argparse

# Reproducibility: seed torch, the stdlib `random` module and NumPy with the
# same fixed value, in that order.
seed = 42
for _seed_rng in (torch.manual_seed, random.seed, np.random.seed):
    _seed_rng(seed)

# Single source of truth for the supported models.
# Maps the CLI short name to (model identifier passed to `from_pretrained`,
# whether the model is loaded with the 8-bit quantization config).
# The CLI `choices` are derived from this table, so an unknown name is rejected
# by argparse itself and no separate "unsupported model" branch is needed.
MODEL_REGISTRY = {
    "llama8b": ("meta-llama/Llama-3.1-8B-Instruct", False),
    "llama70b": ("meta-llama/Llama-3.3-70B-Instruct", True),
    "qwen7b": ("Qwen/Qwen2.5-7B-Instruct", False),
    "qwen72b": ("Qwen/Qwen2.5-72B-Instruct", True),
    "olmo7b": ("allenai/OLMo-2-1124-7B-Instruct", False),
    "olmo32b": ("allenai/OLMo-2-0325-32B-Instruct", False),
}

parser = argparse.ArgumentParser()
parser.add_argument(
    "--model_name",
    type=str,
    required=True,
    choices=list(MODEL_REGISTRY),
    help="Model to use",
)
args = parser.parse_args()
model_name = args.model_name

# `model_path` is the identifier to load; `quant` toggles 8-bit loading.
model_path, quant = MODEL_REGISTRY[model_name]

# Load the MATH dataset from the working directory.
# The encoding is given explicitly so the file is decoded as UTF-8 regardless
# of the platform's locale default (problem statements may contain non-ASCII
# characters).
with open("math.json", "r", encoding="utf-8") as json_file:
    data = json.load(json_file)

# Output records; filled in by the generation loop.
results = []

tokenizer = AutoTokenizer.from_pretrained(model_path)

# 8-bit quantization settings. Always constructed, but only handed to the
# model loader when `quant` is set.
bnb_config = BitsAndBytesConfig(
    load_in_8bit=True,  # Enable 8-bit quantization
    llm_int8_threshold=6.0,  # (Optional) Default threshold for LLM.int8()
    llm_int8_skip_modules=None,  # (Optional) Skip quantization for specific modules
)

# Shared loader arguments; the quantization config is added on demand.
model_load_kwargs = {"device_map": "auto"}
if quant:
    model_load_kwargs["quantization_config"] = bnb_config

model = AutoModelForCausalLM.from_pretrained(model_path, **model_load_kwargs)

# The stop-token list does not depend on the problem, so build it once
# instead of on every iteration.
terminators = [
    tokenizer.eos_token_id,
]

# For every problem: wrap it in the instruction prompt, sample five
# completions, and store the decoded completions on the problem record under
# "model_responses". Each record is mutated in place and appended to `results`.
for d in tqdm(data):
    prompt = f"""
        Solve the following math problem efficiently and clearly:

            - For simple problems (2 steps or fewer):
            Provide a concise solution with minimal explanation.

            - For complex problems (3 steps or more):
            Use this step-by-step format:

            ## Step 1: [Concise description]
            [Brief explanation and calculations]

            ## Step 2: [Concise description]
            [Brief explanation and calculations]

            ...

            Regardless of the approach, always conclude with:

            Therefore, the final answer is: $\\boxed{{answer}}$. I hope it is correct.

            Where [answer] is just the final number or expression that solves the problem.
        
        Problem: {d["problem"]}
        """

    messages = [{"role": "user", "content": prompt}]

    input_ids = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, return_tensors="pt"
    ).to(model.device)

    outputs = model.generate(
        input_ids,
        # The prompt is a single unpadded sequence, so every position is a
        # real token. Passing the all-ones mask explicitly avoids relying on
        # generate() to infer one while pad_token_id is set to the EOS id.
        attention_mask=torch.ones_like(input_ids),
        max_new_tokens=1024,
        eos_token_id=terminators,
        do_sample=True,
        temperature=0.9,
        pad_token_id=tokenizer.eos_token_id,
        num_return_sequences=5,
    )

    # Drop the prompt tokens so only the newly generated text is decoded.
    generations = tokenizer.batch_decode(
        outputs[:, input_ids.shape[-1] :], skip_special_tokens=True
    )
    d["model_responses"] = [g.strip() for g in generations]
    results.append(d)

# Write all collected records to a per-model JSON file, indented by 4 spaces.
output_path = f"5_samples_math_gen_{model_name}.json"
with open(output_path, "w") as out_file:
    json.dump(results, out_file, indent=4)
