import json 
from transformers import AutoModelForCausalLM, AutoTokenizer
from vllm import LLM, SamplingParams
from tqdm import tqdm 

def transformers_generate(model_path, testset, device_map="cuda:1",
                          output_path="output-32b.json"):
    """Generate one greedy completion per problem with Hugging Face transformers.

    Each item is formatted as an instruction prompt (with or without starter
    code), wrapped in the tokenizer's chat template as a single user turn, and
    decoded greedily for up to 2048 new tokens. The completions are written to
    ``output_path`` as a JSON list of single-element lists, in ``testset`` order.

    Args:
        model_path: Hugging Face model id or local directory, used for both the
            model and the tokenizer.
        testset: Iterable of problem dicts; each must provide
            ``question_content`` and ``starter_code`` (falsy when the problem
            has no starter code).
        device_map: Passed to ``from_pretrained`` to place the model.
        output_path: Destination of the JSON results file.
    """
    # Loop-invariant prompt fragments.
    formatting_with_starter_code = "You will use the following starter code to write the solution to the problem and enclose your code within delimiters."
    formatting_without_starter_code = "You will read the inputs from stdin solve the problem and write the answer to stdout (do not directly test on the sample inputs). Enclose your code within delimiters."

    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype="auto",
        device_map=device_map,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_path)

    outputs = []
    for item in tqdm(testset):
        if item["starter_code"]:
            prompt = f"### Instruction:\n{item['question_content']}\n\n{formatting_with_starter_code}\n```python\n{item['starter_code']}\n```\n\n### Response:\n"
        else:
            prompt = f"### Instruction:\n{item['question_content']}\n\n{formatting_without_starter_code}\n\n```python\n# YOUR CODE HERE\n```\n\n### Response:\n"

        # The instruction is the user's request. Sending it as a system message
        # would leave the conversation without any user turn.
        messages = [{"role": "user", "content": prompt}]
        text = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )

        model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

        # Force greedy decoding to match temperature=0 in vllm_generate; without
        # it, generate() falls back to the model's generation_config, which may
        # enable sampling.
        generated_ids = model.generate(
            **model_inputs,
            max_new_tokens=2048,
            do_sample=False,
        )

        # Batch size is 1, so drop the prompt tokens from the single sequence and
        # decode only the newly generated continuation.
        prompt_length = model_inputs.input_ids.shape[1]
        response = tokenizer.decode(
            generated_ids[0][prompt_length:], skip_special_tokens=True
        )
        # tqdm.write prints without corrupting the progress bar.
        tqdm.write(response)
        outputs.append([response])

    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(outputs, f, indent=4)


def vllm_generate(model_path, testset, output_path="output-32b-vllm.json"):
    """Generate one greedy completion per problem with vLLM and write them to JSON.

    Every item is formatted as an instruction prompt (with or without starter
    code) and wrapped in the model's chat template as a single user turn. All
    prompts are submitted to vLLM in one batched ``generate`` call. The results
    are written to ``output_path`` as a JSON list with one inner list of
    completion texts per item, in ``testset`` order.

    Args:
        model_path: Hugging Face model id or local directory, used for both the
            model and the tokenizer.
        testset: Iterable of problem dicts; each must provide
            ``question_content`` and ``starter_code`` (falsy when the problem
            has no starter code).
        output_path: Destination of the JSON results file.
    """
    llm = LLM(
        model=model_path,
        tokenizer=model_path,
        gpu_memory_utilization=0.98
    )
    # Only used to render the chat template into a plain prompt string.
    tokenizer = AutoTokenizer.from_pretrained(model_path)

    # temperature=0 -> greedy decoding; one completion per prompt.
    sampling_params = SamplingParams(
        n=1,
        max_tokens=2048,
        temperature=0,
    )

    formatting_with_starter_code = "You will use the following starter code to write the solution to the problem and enclose your code within delimiters."
    formatting_without_starter_code = "You will read the inputs from stdin solve the problem and write the answer to stdout (do not directly test on the sample inputs). Enclose your code within delimiters."

    prompts = []
    for item in testset:
        if item["starter_code"]:
            instruction = f"### Instruction:\n{item['question_content']}\n\n{formatting_with_starter_code}\n```python\n{item['starter_code']}\n```\n\n### Response:\n"
        else:
            instruction = f"### Instruction:\n{item['question_content']}\n\n{formatting_without_starter_code}\n\n```python\n# YOUR CODE HERE\n```\n\n### Response:\n"

        messages = [{"role": "user", "content": instruction}]
        prompts.append(
            tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True,
            )
        )

    # A single call lets vLLM schedule the whole batch; results come back in
    # the same order as the input prompts.
    vllm_outputs = llm.generate(prompts, sampling_params)

    outputs = [
        [completion.text for completion in request_output.outputs]
        for request_output in vllm_outputs
    ]

    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(outputs, f, indent=4)

def get_testset(test_data_path):
    """Load a JSON Lines test set and return its records sorted by ``question_id``.

    Args:
        test_data_path: Path to a UTF-8 ``.jsonl`` file with one JSON object per
            line. Every object must contain a ``question_id`` key. Blank or
            whitespace-only lines are ignored.

    Returns:
        A list of the decoded records, sorted by ``question_id`` so the
        iteration order is deterministic.
    """
    with open(test_data_path, "r", encoding="utf-8") as f:
        # Skip blank lines (e.g. a trailing empty line); json.loads rejects them.
        testset = [json.loads(line) for line in f if line.strip()]

    return sorted(testset, key=lambda record: record["question_id"])

if __name__ == "__main__":
    # Script entry point: the model, data path and backend are CLI options whose
    # defaults reproduce the previously hard-coded configuration.
    import argparse

    parser = argparse.ArgumentParser(
        description="Generate code completions with transformers or vLLM."
    )
    parser.add_argument(
        "--model",
        default="Qwen/Qwen2.5-Coder-32B-Instruct",
        help="Hugging Face model id or local model directory.",
    )
    parser.add_argument(
        "--test-data",
        default="code_bench/livecodebench/code_generation_lite/test_datasets.jsonl",
        help="JSON Lines file with one problem per line.",
    )
    parser.add_argument(
        "--backend",
        choices=("transformers", "vllm"),
        default="transformers",
        help="Inference backend to use.",
    )
    args = parser.parse_args()

    testset = get_testset(args.test_data)

    if args.backend == "transformers":
        transformers_generate(args.model, testset)
    else:
        vllm_generate(args.model, testset)