import json
import os
import re
import time
from tqdm import tqdm
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

# Path to the local model snapshot directory
MODEL_PATH = "/ssd/ssd_backup/home/wuweijie/.cache/huggingface/models--meta-llama--CodeLlama-13b-Instruct-hf/snapshots/c24689b733f841ab1a6cce35bccb729f3f7b23fd"  # change this to your local model path

# Input dataset files (one JSON file per language; the language name is
# taken from the file's base name)
INPUT_PATHS = [
    "/ssd/ssd_backup/home/wuweijie/CodeJudge/evaluation/data/leetcode/dataset/kotlin.json",
    "/ssd/ssd_backup/home/wuweijie/CodeJudge/evaluation/data/leetcode/dataset/php.json",
    "/ssd/ssd_backup/home/wuweijie/CodeJudge/evaluation/data/leetcode/dataset/scala.json"
]

# Output JSONL files; paired with INPUT_PATHS by position
OUTPUT_PATHS = [
    "/ssd/ssd_backup/home/wuweijie/CodeJudge/evaluation/data/leetcode/test_cases/kotlin-test.jsonl",
    "/ssd/ssd_backup/home/wuweijie/CodeJudge/evaluation/data/leetcode/test_cases/php-test.jsonl",
    "/ssd/ssd_backup/home/wuweijie/CodeJudge/evaluation/data/leetcode/test_cases/scala-test.jsonl"
]

# Use two GPUs (these indices are the keys of the per-GPU memory limits)
DEVICES = [0, 1]

def extract_pure_code(model_output):
    """Extract the bare source code from a model response.

    Strategies are tried in order:

    1. The first complete Markdown fence (``` ... ```). Any info string on
       the opening fence line (``kotlin``, ``Kotlin``, ``java``, ...) is
       dropped, so the language tag never leaks into the program.
    2. An opening fence with no closing fence (e.g. the generation was cut
       off by the token limit): everything after the fence line is used.
    3. No fence at all: everything from the first line that looks like the
       start of a class/function/object definition or a PHP open tag.
    4. Otherwise the whole response.

    Args:
        model_output: Raw text produced by the model.

    Returns:
        The extracted code with surrounding whitespace stripped.
    """
    # Opening-fence prefix: either "<tag><newline>" for an arbitrary tag, or
    # one of the known language names followed by whitespace on the same line.
    fence_prefix = r'(?:[\w+#.-]*[ \t]*\r?\n|(?:kotlin|php|scala)\s*)?'
    closed = re.search(
        r'```' + fence_prefix + r'(.*?)```',
        model_output,
        re.DOTALL | re.IGNORECASE,
    )
    if closed:
        return closed.group(1).strip()

    # Truncated output: the fence was opened but never closed.
    unterminated = re.search(r'```[\w+#.-]*[ \t]*\r?\n(.*)', model_output, re.DOTALL)
    if unterminated and unterminated.group(1).strip():
        return unterminated.group(1).strip()

    # Unfenced output: keep everything from the first code-like line to the
    # end. Blank lines are kept on purpose; cutting at the first blank line
    # would truncate any program that separates members with empty lines.
    code_starts = ('class ', 'function ', '<?php', 'object ', 'def ')
    lines = model_output.split('\n')
    for index, line in enumerate(lines):
        if line.strip().startswith(code_starts):
            return '\n'.join(lines[index:]).strip()

    # Nothing recognisable: fall back to the raw response.
    return model_output.strip()

def load_model_and_tokenizer(model_path):
    """Load the local model and tokenizer and wrap them in a generation pipeline.

    When CUDA is available the model is loaded with ``device_map="auto"`` and
    a 20GiB memory cap for each GPU index listed in ``DEVICES``. Without CUDA
    the model is loaded with the library's default placement and no
    GPU-specific options.

    Args:
        model_path: Directory of the local model snapshot.

    Returns:
        A ``text-generation`` pipeline built from the loaded model and
        tokenizer.
    """
    print("检测可用设备...")
    load_kwargs = {"torch_dtype": torch.float16}
    if torch.cuda.is_available():
        print(f"找到 {torch.cuda.device_count()} 个GPU")
        # Let transformers/accelerate place the layers, capped per GPU.
        load_kwargs["device_map"] = "auto"
        load_kwargs["max_memory"] = {i: "20GiB" for i in DEVICES}
    else:
        # GPU-indexed memory limits are meaningless without a GPU, so they
        # are only passed on the CUDA path.
        print("未找到GPU，使用CPU")

    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForCausalLM.from_pretrained(model_path, **load_kwargs)

    generator = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
    )

    print("模型加载完成，设备分配:")
    # hf_device_map may be absent when the model was loaded without a
    # device_map (the CPU path), so do not assume the attribute exists.
    print(getattr(model, "hf_device_map", {"": "cpu"}))

    return generator

def generate_code(prompt, lang, generator):
    """Generate a solution for one problem with the local model.

    Args:
        prompt: Problem description including the predefined solution class.
        lang: Language key such as "kotlin", "php" or "scala"; unknown keys
            are used verbatim as the language name in the system prompt.
        generator: Callable text-generation pipeline. It is invoked as
            ``generator(text, **params)`` and must return a list whose first
            item is a dict with a ``'generated_text'`` entry.

    Returns:
        The code extracted from the model response, or an empty string if
        generation (or extraction) raised an exception.
    """
    lang_map = {
        "kotlin": "Kotlin",
        "php": "PHP",
        "scala": "Scala"
    }
    lang_name = lang_map.get(lang, lang)

    # System prompt. The first sentence ends with a space so the two
    # sentences are not glued together ("extra text.Complete").
    system_prompt = (
        f"You are an expert in {lang_name} programming. Below is a problem description along with a predefined solution class. Please solve the problem and output only the code — no explanations, comments, or extra text. "
        "Complete the predefined solution class to ensure the code is complete, satisfies the problem description, and is executable, including all necessary imports."
    )

    # Llama-2 style instruction prompt. No literal "<s>" is written here:
    # the tokenizer used by the pipeline is expected to prepend the BOS token
    # itself (Llama tokenizers do so by default), so spelling it out would
    # give the model two BOS tokens.
    full_prompt = f"[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n{prompt} [/INST]"

    # Sampling parameters; return_full_text=False drops the echoed prompt.
    generation_params = {
        "max_new_tokens": 1024,
        "temperature": 0.2,
        "top_p": 0.95,
        "do_sample": True,
        "return_full_text": False
    }

    try:
        result = generator(full_prompt, **generation_params)
        generated_text = result[0]['generated_text']
        return extract_pure_code(generated_text)
    except Exception as e:
        # Best effort: one failed problem must not abort the whole run.
        print(f"生成代码时出错: {e}")
        return ""

def process_language_file(input_path, output_path, generator):
    """Generate programs for every not-yet-processed problem of one language.

    The input file is a JSON object mapping question ids to objects with a
    "prompt" entry. Results are appended to ``output_path`` as JSON lines of
    the form ``{"question_id": ..., "program": ..., "pass": null}``. Ids
    already present in the output file are skipped, so an interrupted run can
    be resumed.

    Args:
        input_path: JSON dataset file; its base name (text before the first
            dot) is used as the language key.
        output_path: JSONL result file, created if missing.
        generator: Text-generation pipeline forwarded to ``generate_code``.

    Returns:
        None.
    """
    lang = os.path.basename(input_path).split('.')[0]
    print(f"\n处理语言: {lang.upper()}")
    print(f"输入文件: {input_path}")
    print(f"输出文件: {output_path}")

    # Make sure the output directory exists. A bare file name has an empty
    # dirname, and os.makedirs("") would raise.
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # Load previous progress, tolerating lines that are not valid JSON, are
    # not JSON objects, or lack a "question_id".
    processed_ids = set()
    if os.path.exists(output_path):
        print("检测到现有输出文件，加载进度...")
        with open(output_path, 'r', encoding='utf-8') as f:
            for line in f:
                try:
                    processed_ids.add(json.loads(line)["question_id"])
                except (json.JSONDecodeError, KeyError, TypeError):
                    continue
        print(f"已处理 {len(processed_ids)} 个问题")

    with open(input_path, 'r', encoding='utf-8') as f:
        problems = json.load(f)

    # Keep only the problems that have no record yet.
    to_process = {qid: prob for qid, prob in problems.items() if qid not in processed_ids}
    if not to_process:
        print(f"所有 {lang} 问题已处理完成")
        return

    print(f"开始处理 {len(to_process)} 个新问题...")

    # Both the file and the progress bar are context-managed so they are
    # closed even if a problem raises.
    with open(output_path, 'a', encoding='utf-8') as out_f, \
            tqdm(total=len(to_process), desc=f"生成 {lang} 代码") as progress_bar:
        for qid, problem in to_process.items():
            program = generate_code(problem["prompt"], lang, generator)

            # NOTE(review): a failed generation is stored with an empty
            # "program" and therefore counts as processed on resume.
            result = {
                "question_id": qid,
                "program": program,
                "pass": None
            }

            out_f.write(json.dumps(result, ensure_ascii=False) + '\n')
            out_f.flush()  # persist each record immediately for resumability

            progress_bar.update(1)
            time.sleep(0.1)  # short pause between generations

def main():
    """Load the model once, then process every configured language file."""
    print("加载CodeLlama模型...")
    generator = load_model_and_tokenizer(MODEL_PATH)

    # Input and output paths are paired by position.
    for src_path, dst_path in zip(INPUT_PATHS, OUTPUT_PATHS):
        if not os.path.exists(src_path):
            print(f"输入文件不存在: {src_path}")
            continue
        process_language_file(src_path, dst_path, generator)

    print("\n所有处理完成!")

# Run the generation job only when this file is executed as a script.
if __name__ == "__main__":
    main()