from datasets import load_from_disk, concatenate_datasets
from vllm import LLM, SamplingParams
from transformers import AutoModelForCausalLM, AutoTokenizer
import argparse
import torch
from tqdm import tqdm
import json
import os
from pathlib import Path
import json
from typing import Any


def to_jsonable(seq):
    """
    Convert per-position logprob entries into JSON-serialisable rows.

    Input is a list shaped like ``[ {token_id: Logprob | None} | None, ... ]``
    (as produced by vLLM's ``prompt_logprobs``); output is a list of
    ``{"token_id", "logprob", "rank", "token"}`` dicts, one per element.

    Missing or unusable values are written as the string ``"None"``, the same
    placeholder ``_null_row`` uses. A logprob that is absent, not convertible
    to float, or not finite (``-inf`` / ``nan``) also becomes ``"None"``, so
    the result can always be written with ``json.dump(..., allow_nan=False)``.

    Args:
        seq: Sequence of per-position entries, or None (treated as empty).

    Returns:
        list[dict]: One row per element of ``seq``.
    """
    import math  # local import: only needed for the finiteness check

    def _safe_logprob(value):
        """Return ``value`` as a finite float, else the ``"None"`` placeholder."""
        try:
            number = float(value)
        except (TypeError, ValueError):
            return "None"
        # Strict JSON has no representation for inf/nan.
        return number if math.isfinite(number) else "None"

    out = []
    for elem in (seq or []):  # seq itself being None yields []
        if elem is None:
            # (1) the whole element is None
            out.append(_null_row())
            continue

        if not isinstance(elem, dict) or not elem:
            # (2) not a dict (e.g. an unexpected str/int), or an empty dict
            out.append(_null_row(extra=str(elem)))
            continue

        # Only the first key/value pair of the dict is used.
        token_id, lp_obj = next(iter(elem.items()))
        if lp_obj is None:
            # (3) the dict's value is None
            out.append(_null_row(token_id=token_id))
            continue

        out.append({
            "token_id": token_id,
            "logprob": _safe_logprob(getattr(lp_obj, "logprob", None)),
            "rank": getattr(lp_obj, "rank", "None"),
            "token": getattr(lp_obj, "decoded_token", "None"),
        })
    return out


def _null_row(token_id=None, extra=None):
    """生成字段全为空的占位行；extra 用于调试"""
    return {
        "token_id": "None",
        "logprob":  "None",
        "rank":     "None",
        "token":    "None",   # 额外信息可写在这里，也可以改回 None
    }

def ensure_file_directory_exists(file_path):
    """
    Make sure the directory that will hold ``file_path`` exists.

    The parent directory, and any missing ancestors, is created if it is
    absent; a message is printed when that happens. The file itself is not
    created.

    Args:
        file_path: Path (str or path-like) of a file that will be written.

    Returns:
        str: The absolute, resolved form of ``file_path``.
    """
    target = Path(file_path)
    parent = target.parent

    if parent.exists():
        return str(target.resolve())

    parent.mkdir(parents=True, exist_ok=True)
    print(f"Created directory: {parent}")
    return str(target.resolve())


def get_eval_prompt(question, answer, options, tokenizer):
    """
    Build the chat-formatted prompt used to score an answer to a question.

    The single user turn is ``"Q: <question>\\n options: <options> A: <answer>"``.
    It is rendered with ``tokenizer.apply_chat_template`` as text (not token
    ids), with the generation prompt appended.

    Args:
        question: Question text.
        answer: Answer / explanation text to be scored.
        options: Answer options, interpolated into the prompt via an f-string.
        tokenizer: Object exposing ``apply_chat_template`` and a call that
            returns a mapping with ``"input_ids"`` (presumably a Hugging Face
            tokenizer).

    Returns:
        tuple[str, int]: The rendered prompt, and the number of ids
        ``tokenizer`` produces for the ``"Q: ... options: ..."`` part alone.
        NOTE(review): that count is taken on the raw question text, not on
        the rendered prompt, so it does not include tokens the chat template
        inserts before the user content. Confirm before using it as an offset
        into prompt logprobs.
    """
    question_part = "Q: " + question + f"\n options: {options}"
    user_content = " ".join([question_part, "A: " + answer])
    chat = [{"role": "user", "content": user_content}]

    rendered = tokenizer.apply_chat_template(
        chat,
        tokenize=False,
        add_generation_prompt=True,
    )
    question_token_count = len(tokenizer(question_part)["input_ids"])
    return rendered, question_token_count

def main(args, model, dataset, output_dir, task):
    """
    Score each example's explanation with vLLM prompt logprobs and save them.

    For every example, a prompt is built from its question, explanation and
    options (via ``get_eval_prompt``). vLLM is run once per prompt with
    ``prompt_logprobs=0`` and ``max_tokens=1``, and the per-token prompt
    logprobs are converted with ``to_jsonable``. The results are written as a
    JSON list to ``<output_dir>/<model_name>_<task>.json``.

    Args:
        args: Parsed CLI namespace; ``tensor_parallel_size`` and
            ``gpu_memory_utilization`` are read from it.
        model: Model name or local path, passed to transformers and vLLM.
        dataset: Iterable of dicts with ``"question"``, ``"explanation"``,
            ``"options"`` and ``"label"`` keys. It is iterated twice, so it
            must be re-iterable (e.g. a list).
        output_dir: Directory for the output file; it is created if missing.
        task: Task name, used as the output file-name suffix.
    """
    tokenizer = AutoTokenizer.from_pretrained(model)
    tokenizer.pad_token = tokenizer.eos_token

    llm = LLM(
        model=model,
        tensor_parallel_size=args.tensor_parallel_size,
        max_model_len=4096,
        gpu_memory_utilization=args.gpu_memory_utilization,
    )

    # Only the prompt's own logprobs are of interest; a single greedy token
    # is generated just so the request runs.
    greedy_params = SamplingParams(
        temperature=0,
        max_tokens=1,
        stop=["<|endoftext|>", "<|im_end|>"],
        prompt_logprobs=0,
    )

    # get_eval_prompt also returns a question-token count; it is not used here.
    greedy_prompts = [
        get_eval_prompt(
            instance['question'], instance['explanation'], instance['options'], tokenizer
        )[0]
        for instance in dataset
    ]

    outputs = llm.generate(greedy_prompts, greedy_params)
    greedy_logprob = [to_jsonable(output.prompt_logprobs) for output in outputs]

    torch.cuda.empty_cache()

    result = [
        {
            "question": example['question'],
            "solution": example['explanation'],
            "solution_prob": solution_prob,
            "label": example['label'],
        }
        for example, solution_prob in zip(dataset, greedy_logprob)
    ]

    # normpath strips a trailing separator, so "/models/foo/" yields "foo"
    # instead of an empty basename (which would give "_<task>.json").
    model_name = os.path.basename(os.path.normpath(model))
    output_path = ensure_file_directory_exists(
        os.path.join(output_dir, model_name + f"_{task}.json")
    )
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(result, f, ensure_ascii=False, indent=4, allow_nan=False)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="自动化代码补全处理脚本")
    parser.add_argument("--model", type=str, required=True, help="使用的模型名称")
    parser.add_argument("--inference_dataset_path", type=str, required=True, help="数据集存储路径")
    parser.add_argument("--output_dir", type=str, required=True, help="结果输出目录")
    parser.add_argument("--gpu_memory_utilization", type=float, default=0.4)
    parser.add_argument("--tensor_parallel_size", type=int, default=2)
    parser.add_argument("--task", type=str, required=True, help="")
    args = parser.parse_args()

    torch.cuda.empty_cache()

    # Load the dataset: a JSON array of example objects. Read it explicitly
    # as UTF-8 so the result does not depend on the platform's locale
    # encoding; the output is also written as UTF-8.
    with open(args.inference_dataset_path, "r", encoding="utf-8") as f:
        dataset = json.load(f)

    # Create the output directory.
    os.makedirs(args.output_dir, exist_ok=True)

    # Run the scoring pipeline.
    main(args, args.model, dataset, args.output_dir, args.task)