import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed
import pandas as pd
import json
from datetime import datetime
from tqdm import tqdm
import argparse
import os

from utils import build_conversation

# Hugging Face access token. Replace the placeholder with a real token, or
# leave it untouched and export HF_TOKEN in the shell instead.
_HF_TOKEN_PLACEHOLDER = "your_token"
api_token = "your_token"
# Only export an edited, non-empty token. Exporting the placeholder would
# clobber a real HF_TOKEN already set in the environment and make every Hub
# request carry invalid credentials.
if api_token and api_token != _HF_TOKEN_PLACEHOLDER:
    os.environ["HF_TOKEN"] = api_token

def _dataset_name_from_path(dataset_path):
    """Return the benchmark name ("MATH500" or "GSM8K") found in dataset_path.

    Raises ValueError if neither name appears anywhere in the path.
    """
    # Plain substring match over the whole path (directories included).
    # MATH500 is tested first, so a path containing both names resolves to MATH500.
    for name in ("MATH500", "GSM8K"):
        if name in dataset_path:
            return name
    raise ValueError(f"Invalid dataset name '{dataset_path}'. Valid options: [MATH500, GSM8K]")


def _dataset_type_from_path(dataset_path):
    """Return the split ("train" or "test") found in dataset_path.

    Raises ValueError if neither split appears anywhere in the path.
    """
    # NOTE(review): this is a substring test over the full path, so parent
    # directories such as "trainee/" or "latest/" also match. Keep split names
    # out of directory names.
    for split in ("train", "test"):
        if split in dataset_path:
            return split
    raise ValueError(f"Invalid dataset type '{dataset_path}'. Valid options: [train, test]")


def main(args):
    """
    Run the LRM inferences over a reasoning dataset given the CLI arguments.

    Arguments
    -------------------
    - args:
    Inference settings (argparse namespace), namely:
        - model: model key - we support "DS_Llama_8B", "DS_Qwen_14B" and "QwQ_32b".
        - prompt_type: see utils.build_conversation - we support "none", "UP-ZS", "SP-ZS", "SP-FS1", "SP-FS3"
        - dataset_path: path to the JSON Lines reasoning dataset - we support MATH500 and GSM8K.
          The dataset name and the train/test split are inferred from this path.
        - seed: seed passed to transformers.set_seed before every generation, for reproducibility
        - run_number: run id

    Outputs
    -------------------
    Files written under results/<dataset>/<split>/<prompt_type>/<model>/:
    - config_run_<run_number>_TEST_SEED_<seed>.json: the run configuration.
    - run_<run_number>_TEST_SEED_<seed>.json: one record per question with the
      conversation, the question, the decoded answer, the generation runtime in
      seconds and the total token count. The file is rewritten after every
      question, so partial progress survives an interruption.

    Return
    -------------------
    None.

    Raises
    -------------------
    - ValueError: unknown model key, or dataset name / split not found in dataset_path.
    """

    model_choices = {
        "DS_Llama_8B": "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
        "DS_Qwen_14B": "deepseek-ai/DeepSeek-R1-Distill-Qwen-14B",
        "QwQ_32b": "Qwen/QwQ-32B"
    }

    # Generation budget (max_new_tokens) per model.
    max_num_tokens_models = {
        "DS_Llama_8B": 32768,
        "DS_Qwen_14B": 32768,
        "QwQ_32b": 32768
    }

    if args.model not in model_choices:
        raise ValueError(f"Invalid model key '{args.model}'. Valid options: {list(model_choices.keys())}")

    # Validate the dataset path and read the data *before* loading the model.
    # A bad path then fails immediately instead of after a multi-GB checkpoint load.
    dataset_name = _dataset_name_from_path(args.dataset_path)
    dataset_type = _dataset_type_from_path(args.dataset_path)
    df = pd.read_json(args.dataset_path, lines=True)

    # Column holding the question text in each supported dataset.
    question_column = {"MATH500": "problem", "GSM8K": "question"}[dataset_name]
    questions = df[question_column]

    model_path = model_choices[args.model]
    device = "cuda"
    dtype = "float16"

    cache_dir = "hf_models/"

    model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=getattr(torch, dtype), cache_dir=cache_dir).to(device)
    tokenizer = AutoTokenizer.from_pretrained(model_path, cache_dir=cache_dir)

    output_base_path = os.path.join("results", dataset_name, dataset_type, args.prompt_type, args.model)
    os.makedirs(output_base_path, exist_ok=True)

    config_path = os.path.join(output_base_path, f"config_run_{args.run_number}_TEST_SEED_{args.seed}.json")
    run_path = os.path.join(output_base_path, f"run_{args.run_number}_TEST_SEED_{args.seed}.json")

    config_file = [{
        "run_number": args.run_number,
        "model_path": model_path,
        "seed": args.seed,
        "prompt_type": args.prompt_type,
        "dataset_path": args.dataset_path,
    }]
    with open(config_path, "w") as f:
        json.dump(config_file, f, indent=4)

    results = []

    # question_id is the 0-based row position in the dataset file.
    for i, question in enumerate(tqdm(questions)):
        conversation = build_conversation(question, args.prompt_type)

        inputs = tokenizer.apply_chat_template(
            conversation,
            return_tensors="pt",
            return_dict=True,
            add_generation_prompt=True
        ).to(device)

        # Re-seed before every generation so each question's output is
        # reproducible regardless of how many questions came before it.
        set_seed(args.seed)
        start = datetime.now()
        output = model.generate(**inputs, max_new_tokens=max_num_tokens_models[args.model])
        end = datetime.now()

        # Decode only the newly generated tokens by skipping the prompt prefix.
        prompt_length = inputs["input_ids"].shape[1]
        prediction = tokenizer.decode(output[0, prompt_length:], skip_special_tokens=True)

        results.append({
            "question_id": i,
            "conversation": conversation,
            "question": question,
            "answer": prediction,
            "runtime": (end - start).total_seconds(),
            # Full sequence length: prompt tokens plus generated tokens.
            "number_tokens": output[0, :].size(0)
        })

        # Checkpoint: rewrite the whole results file after each question.
        pd.DataFrame(results).to_json(run_path, orient="records", indent=4)

# -----------------------------------------------------------
# ------------------------ CLI Entry ------------------------
# -----------------------------------------------------------

if __name__ == "__main__":
    cli = argparse.ArgumentParser(description="Run LRM inference on a given dataset.")

    # Required: which model to run and how to prompt it.
    cli.add_argument(
        "--model",
        type=str,
        required=True,
        choices=["DS_Llama_8B", "DS_Qwen_14B", "QwQ_32b"],
        help="Model key to use",
    )
    cli.add_argument(
        "--prompt_type",
        type=str,
        required=True,
        choices=["none", "UP-ZS", "SP-ZS", "SP-FS1", "SP-FS3"],
        help="Prompt format to use",
    )

    # Optional: data source and run identification.
    cli.add_argument(
        "--dataset_path",
        type=str,
        default="datasets/MATH500_test.jsonl",
        help="Path to the dataset",
    )
    cli.add_argument("--seed", type=int, default=42, help="Random seed for reproducibility")
    cli.add_argument("--run_number", type=int, default=0, help="Run ID number")

    main(cli.parse_args())