import argparse
import json
import os
import sys
import numpy as np
import pandas as pd
import torch
from pathlib import Path
from transformers import AutoTokenizer, AutoModelForCausalLM
from tqdm import tqdm
import re

import torch_npu
from torch_npu.npu import amp
from torch_npu.contrib import transfer_to_npu

# Letters used to label the answer options of each question.
choices = ["A", "B", "C", "D"]

def softmax(x):
    """Return the softmax of ``x``.

    The maximum of ``x`` is subtracted before exponentiating so that large
    inputs do not overflow.
    """
    exps = np.exp(x - np.max(x))
    return exps / np.sum(exps)

def format_subject(subject):
    """Turn an underscore-separated subject id into a space-separated name."""
    return subject.replace("_", " ")

def format_example(df, idx, include_answer=True):
    """Render row ``idx`` of ``df`` as a multiple-choice question block.

    Column 0 holds the question, the last column holds the answer, and every
    column in between is one option, labelled with the letters in ``choices``.
    The text always ends with ``"\\nAnswer:"``; when ``include_answer`` is true
    the answer and a blank line are appended after it.
    """
    n_options = df.shape[1] - 2
    pieces = [df.iloc[idx, 0]]
    pieces.extend(
        "\n{}. {}".format(choices[col], df.iloc[idx, col + 1])
        for col in range(n_options)
    )
    pieces.append("\nAnswer:")
    if include_answer:
        pieces.append(" {}\n\n".format(df.iloc[idx, n_options + 1]))
    return "".join(pieces)

def gen_prompt(train_df, subject, k=-1):
    """Build the few-shot prefix for ``subject``.

    The prefix is an instruction header followed by the first ``k`` rows of
    ``train_df`` rendered with their answers. ``k == -1`` means "use every
    row of ``train_df``".
    """
    header = "The following are multiple choice questions (with answers) about {}.\n\n".format(format_subject(subject))
    shots = train_df.shape[0] if k == -1 else k
    return header + "".join(format_example(train_df, row) for row in range(shots))

# Characters that may wrap an extracted option letter, e.g. "(B)", "B。", "**B**".
_ANSWER_WRAPPERS = " \t\r\n()（）[]【】{}<>《》\"'“”‘’`*.。,，:：;；!！、$"

# The option letter itself must stay upper-case even though the surrounding
# keywords are matched case-insensitively; otherwise lowercase a-d inside
# ordinary words would match first and hide the real answer.
_OPTION_LETTER = r'(?-i:[A-D])'

# Extraction patterns, tried in order; group 1 always holds the candidate answer.
# They are written with a plain ``[A-D]`` for readability and compiled once,
# with every ``[A-D]`` swapped for the case-sensitive class above.
_ANSWER_PATTERNS = tuple(
    re.compile(source.replace('[A-D]', _OPTION_LETTER), re.DOTALL | re.IGNORECASE)
    for source in (
        r'Answer:\s*([A-D])',
        r'答案是?\s*([A-D])',
        r'答案是?\s*：\s*([A-D])',
        r'答案是?\s*:\s*([A-D])',
        r'答案选项应?该?是\s*([A-D])',
        r'答案选项应?该?为\s*([A-D])',
        r'答案应该?是\s*([A-D])',
        r'答案应该?选\s*([A-D])',
        r'答案选项为?\s*：\s*([A-D])',
        r'答案选项是?\s*:\s*([A-D])',
        r'答案为\s*([A-D])',
        r'答案选\s*([A-D])',
        r'选择?\s*([A-D])',
        r'故选?\s*([A-D])',
        r'只有选?项?\s?([A-D])\s?是?对',
        r'只有选?项?\s?([A-D])\s?是?错',
        r'只有选?项?\s?([A-D])\s?不?正确',
        r'只有选?项?\s?([A-D])\s?错误',
        r'说法不?对选?项?的?是\s?([A-D])',
        r'说法不?正确选?项?的?是\s?([A-D])',
        r'说法错误选?项?的?是\s?([A-D])',
        r'([A-D])\s?是正确的',
        r'([A-D])\s?是正确答案',
        r'选项\s?([A-D])\s?正确',
        r'所以答\s?([A-D])',
        r'所以\s?([A-D][.。$]?$)',
        r'所有\s?([A-D][.。$]?$)',
        # The compile flags already make this case-insensitive, so no inline (?i).
        r'ANSWER\s*:\s*([A-D])',
        r'[\s，：:,]([A-D])[。，,\.]?$',
        r'[\s，,：:][故即]([A-D])[。\.]?$',
        r'[\s，,：:]因此([A-D])[。\.]?$',
        r'[是为。]\s?([A-D])[。\.]?$',
        r'因此\s?([A-D])[。\.]?$',
        r'显然\s?([A-D])[。\.]?$',
        r'答案是\s?(\S+)(?:。|$)',
        r'答案应该是\s?(\S+)(?:。|$)',
        r'答案为\s?(\S+)(?:。|$)',
        r'[Tt]he answer is:?\s+\(?([A-D])\)?',
        r'[Tt]he answer is option:?\s+\(?([A-D])\)?',
        r'[Tt]he correct answer is:?\s+\(?([A-D])\)?',
        r'[Tt]he correct answer is option:?\s+\(?([A-D])\)?',
        # Single escaped braces: these must match LaTeX such as \boxed{C}.
        r'[Tt]he correct answer is:?.*?boxed\{([A-D])\}',
        r'[Tt]he correct option is:?.*?boxed\{([A-D])\}',
        r'[Tt]he correct answer option is:?.*?boxed\{([A-D])\}',
        r'[Tt]he answer to the question is:?\s+\(?([A-D])\)?',
        r'^选项\s?([A-D])',
        r'^([A-D])\s?选?项',
        # The letter, not the leading separator, is the capture group.
        r'(?:\s|^)([A-D])[\s。，,：:\.$]',
        r'1.\s?(.*?)$',
        r'1.\s?([A-D])[.。$]?$',
        r'([A-D]):',
        r'([A-D])',
    )
)


def postprocess(text: str, options: str = "ABCD") -> str:
    """Extract the chosen option letter from free-form model output.

    The patterns in ``_ANSWER_PATTERNS`` are tried in order. For each match the
    captured text is stripped of surrounding brackets and punctuation; the
    first capture that reduces to exactly one character contained in
    ``options`` is returned. Captures that do not reduce to a single valid
    letter are skipped and the next pattern is tried.

    Args:
        text: Raw text produced by the model. Non-string input yields ``""``.
        options: The valid option letters.

    Returns:
        A single option letter, or ``""`` when no pattern yields a valid one.
    """
    if not isinstance(text, str):
        return ""

    text = text.strip()

    for pattern in _ANSWER_PATTERNS:
        match = pattern.search(text)
        if match is None:
            continue
        candidate = match.group(1).strip(_ANSWER_WRAPPERS)
        # Require exactly one letter: a plain substring test would also accept
        # "" or multi-letter runs such as "AB".
        if len(candidate) == 1 and candidate in options:
            return candidate

    return ""

def eval_subject(args, subject, model, tokenizer, dev_df, test_df, device):
    """Evaluate the model on every question of one subject.

    For each row of ``test_df`` a few-shot prompt is built from ``dev_df``,
    the model generates a continuation, and ``postprocess`` extracts the
    option letter from it. When no letter can be extracted, the prediction
    falls back to comparing the next-token logits of the option letters.

    Args:
        args: Parsed options; reads ``ntrain``, ``max_length``,
            ``max_new_tokens``, ``do_sample``, ``top_k``, ``top_p``,
            ``temperature`` and ``repetition_penalty``.
        subject: Subject id, used in the prompt header and progress output.
        model: Causal LM exposing ``generate`` and a forward pass with ``.logits``.
        tokenizer: Tokenizer matching ``model``.
        dev_df: Few-shot examples (question, options..., answer).
        test_df: Questions to evaluate, same column layout as ``dev_df``.
        device: Device the tokenized inputs are moved to.

    Returns:
        ``(cors, acc, all_probs)``: per-question correctness array, mean
        accuracy, and per-question option probabilities. The probabilities are
        a softmax over option logits when the logits fallback was used, and a
        uniform placeholder when the answer came from the generated text.
    """
    cors = []
    all_probs = []
    answers = choices[:test_df.shape[1] - 2]
    label_col = test_df.shape[1] - 1

    # Leave some room below max_length for the generated tokens.
    max_prompt_tokens = args.max_length - 50
    # Never ask for more shots than dev_df actually contains.
    max_shots = min(args.ntrain, dev_df.shape[0])

    # The few-shot prefix depends only on the shot count, so build each once.
    prefix_cache = {}

    def few_shot_prefix(n_shots):
        if n_shots not in prefix_cache:
            prefix_cache[n_shots] = gen_prompt(dev_df, subject, n_shots)
        return prefix_cache[n_shots]

    # Token ids of " A", " B", ... used by the logits fallback.
    answer_token_ids = [
        tokenizer(" {}".format(ans), add_special_tokens=False).input_ids
        for ans in answers
    ]
    # Placeholder distribution for answers extracted from text.
    uniform_probs = np.full(len(answers), 1.0 / max(len(answers), 1))

    print(f"Evaluating {subject}...")

    for i in tqdm(range(test_df.shape[0]), desc=f"Processing {subject}"):
        prompt_end = format_example(test_df, i, include_answer=False)
        k = max_shots
        prompt = few_shot_prefix(k) + prompt_end

        # Drop shots one at a time until the prompt fits, down to zero-shot.
        while k > 0 and len(tokenizer.encode(prompt)) > max_prompt_tokens:
            k -= 1
            prompt = few_shot_prefix(k) + prompt_end

        inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=args.max_length).to(device)
        prompt_len = inputs["input_ids"].shape[1]

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=args.max_new_tokens,
                do_sample=args.do_sample,
                top_k=args.top_k,
                top_p=args.top_p,
                temperature=args.temperature,
                pad_token_id=tokenizer.eos_token_id,
                eos_token_id=tokenizer.eos_token_id,
                repetition_penalty=args.repetition_penalty
            )

            # Decode only the newly generated tokens. Slicing the decoded text
            # by len(prompt) is unreliable: decoding need not reproduce the
            # prompt character-for-character, and the prompt may have been
            # truncated by the tokenizer.
            new_tokens = outputs[0][prompt_len:]
            answer_text = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

            pred = postprocess(answer_text)

            if not pred:
                # No letter found in the text: rank the options by the logit
                # of their first token at the position after the prompt.
                logits = model(**inputs).logits
                last_token_logits = logits[0, -1, :]

                lprobs = [
                    last_token_logits[token_ids[0]].item() if token_ids else -float('inf')
                    for token_ids in answer_token_ids
                ]

                pred = answers[int(np.argmax(lprobs))]
                probs = softmax(np.array(lprobs))
            else:
                probs = uniform_probs

        label = test_df.iloc[i, label_col]
        cors.append(pred == label)
        all_probs.append(probs)

    acc = np.mean(cors)
    cors = np.array(cors)
    all_probs = np.array(all_probs)
    print(f"Average accuracy {acc:.3f} - {subject}")

    return cors, acc, all_probs

def evaluate_mmlu_model(model, tokenizer, args):
    """Run the MMLU evaluation over every subject found in ``args.data_dir``.

    Subjects are discovered from ``<data_dir>/test/*_test.csv``. For each
    subject the few-shot examples are read from ``<data_dir>/dev``, the
    subject is evaluated, and a per-subject CSV is written to
    ``args.output_dir``. A subject that raises is reported, recorded as
    failed, and skipped. When at least one subject succeeded, a
    ``results.json`` summary is written as well.

    Args:
        model: Model to evaluate; it is moved to the selected device and put
            in eval mode.
        tokenizer: Tokenizer matching ``model``.
        args: Parsed options; reads ``device``, ``data_dir``, ``output_dir``,
            ``ntrain``, ``model_name``, ``max_length`` and ``max_new_tokens``.

    Returns:
        Accuracy over all evaluated questions as a float, or ``0.0`` when no
        subject could be evaluated.
    """
    device = torch.device(args.device if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    model.eval()

    # Match on the suffix only, so stray files such as "x_test.csv.bak" are
    # not mistaken for subjects.
    suffix = "_test.csv"
    test_dir = os.path.join(args.data_dir, "test")
    subjects = sorted(
        name[:-len(suffix)] for name in os.listdir(test_dir) if name.endswith(suffix)
    )

    print(f"Found {len(subjects)} subjects: {subjects}")

    os.makedirs(args.output_dir, exist_ok=True)

    all_cors = []
    subject_accuracies = {}
    failed_subjects = []

    for subject in subjects:
        try:
            dev_df = pd.read_csv(os.path.join(args.data_dir, "dev", subject + "_dev.csv"), header=None)[:args.ntrain]
            test_df = pd.read_csv(os.path.join(args.data_dir, "test", subject + "_test.csv"), header=None)

            cors, acc, probs = eval_subject(args, subject, model, tokenizer, dev_df, test_df, device)
            all_cors.append(cors)
            # Plain float keeps the JSON summary free of NumPy scalar types.
            subject_accuracies[subject] = float(acc)

            # Append correctness and per-option probabilities to the test rows.
            test_df[f"{args.model_name}_correct"] = cors
            for j in range(probs.shape[1]):
                choice = choices[j]
                test_df[f"{args.model_name}_choice{choice}_probs"] = probs[:, j]

            result_path = os.path.join(args.output_dir, f"{subject}.csv")
            test_df.to_csv(result_path, index=None)

        except Exception as e:
            # Best effort: one broken subject must not abort the whole run,
            # but it is recorded so the summary shows what was left out.
            print(f"Error evaluating {subject}: {e}")
            failed_subjects.append(subject)
            continue

    if failed_subjects:
        print(f"\nWarning: {len(failed_subjects)} subject(s) failed: {failed_subjects}")

    if not all_cors:
        return 0.0

    weighted_acc = float(np.mean(np.concatenate(all_cors)))
    print(f"\nOverall MMLU accuracy: {weighted_acc:.3f}")

    results = {
        "overall_accuracy": weighted_acc,
        "subject_accuracies": subject_accuracies,
        "model_name": args.model_name,
        "ntrain": args.ntrain,
        "max_length": args.max_length,
        "max_new_tokens": args.max_new_tokens,
        "failed_subjects": failed_subjects,
    }

    with open(os.path.join(args.output_dir, "results.json"), "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2)

    print("\nSubject-wise accuracies:")
    for subject, acc in subject_accuracies.items():
        print(f"{subject}: {acc:.3f}")

    return weighted_acc

def main():
    """Parse the command line, load tokenizer and model, and run the MMLU evaluation."""
    cli = argparse.ArgumentParser(description="Evaluate 537M model on MMLU dataset")

    # Mandatory filesystem locations.
    required_paths = (
        ("--model_path", "Path to the model"),
        ("--tokenizer_path", "Path to the tokenizer"),
        ("--data_dir", "Path to MMLU dataset"),
        ("--output_dir", "Output directory for results"),
    )
    for flag, description in required_paths:
        cli.add_argument(flag, type=str, required=True, help=description)

    # Prompt size limits.
    cli.add_argument("--ntrain", type=int, default=5, help="Number of few-shot examples")
    cli.add_argument("--max_length", type=int, default=2048, help="Maximum input length")
    cli.add_argument("--max_new_tokens", type=int, default=50, help="Maximum new tokens to generate")

    # Decoding settings forwarded to generation.
    cli.add_argument("--do_sample", action="store_true", help="Use sampling for generation")
    cli.add_argument("--top_k", type=int, default=0, help="Top-k sampling")
    cli.add_argument("--top_p", type=float, default=1.0, help="Top-p sampling")
    cli.add_argument("--temperature", type=float, default=1.0, help="Temperature for generation")
    cli.add_argument("--repetition_penalty", type=float, default=1.0, help="Repetition penalty")

    # Runtime and reporting.
    cli.add_argument("--device", type=str, default="cuda", help="Device to use")
    cli.add_argument("--model_name", type=str, default="537M", help="Model name for results")

    args = cli.parse_args()

    print(f"Loading model from {args.model_path}")
    print(f"Loading tokenizer from {args.tokenizer_path}")

    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path)
    # Automatic device placement is requested only for the "cuda" device option.
    placement = "auto" if args.device == "cuda" else None
    model = AutoModelForCausalLM.from_pretrained(
        args.model_path,
        torch_dtype=torch.float16,
        device_map=placement,
    )

    # Reuse the EOS token for padding when the tokenizer defines no pad token.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    print("Model loaded successfully!")

    overall_acc = evaluate_mmlu_model(model, tokenizer, args)

    print(f"\nFinal MMLU accuracy: {overall_acc:.3f}")

# Run the evaluation only when this file is executed as a script.
if __name__ == "__main__":
    main()
