import argparse
import os
import numpy as np
import pandas as pd
import time
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer  # 根据您的模型类型
from crop import crop  # 您代码中的 crop 函数

import torch_npu
from torch_npu.contrib import transfer_to_npu

# Answer letters; position j labels the j-th option column of a question row.
choices = list("ABCD")

def softmax(x):
    """Return the softmax of a 1-D sequence of scores.

    Args:
        x: Array-like of real-valued scores (list, tuple or ``np.ndarray``).

    Returns:
        ``np.ndarray`` with one non-negative entry per score; the entries
        sum to 1.

    Raises:
        ValueError: If ``x`` is empty.
    """
    scores = np.asarray(x)
    # Shifting by the maximum does not change the result mathematically but
    # keeps np.exp from overflowing on large scores.
    shifted = scores - np.max(scores)
    exp_scores = np.exp(shifted)
    return exp_scores / np.sum(exp_scores)

def format_subject(subject):
    """Return ``subject`` with every underscore replaced by a single space."""
    return subject.replace("_", " ")

def format_example(df, idx, include_answer=True):
    """Render row ``idx`` of ``df`` as one multiple-choice question block.

    As indexed here, column 0 holds the question text, the last column holds
    the answer, and every column in between is one option.

    Args:
        df: DataFrame with one question per row in the layout above.
        idx: Positional row index of the question to render.
        include_answer: When true, append the answer and a blank line after
            ``Answer:`` (used for solved few-shot examples); when false the
            text ends right after ``Answer:``.

    Returns:
        The formatted question as a string.
    """
    num_options = df.shape[1] - 2
    question = df.iloc[idx, 0]
    option_lines = "".join(
        f"\n{choices[opt]}. {df.iloc[idx, opt + 1]}" for opt in range(num_options)
    )
    answer_suffix = f" {df.iloc[idx, num_options + 1]}\n\n" if include_answer else ""
    return question + option_lines + "\nAnswer:" + answer_suffix

def gen_prompt(train_df, subject, k=-1):
    """Build the few-shot prefix: a subject header followed by solved examples.

    Args:
        train_df: DataFrame of solved examples, one per row, in the layout
            expected by ``format_example``.
        subject: Subject identifier; it is passed through ``format_subject``
            before being placed in the header.
        k: Number of leading rows to include. ``-1`` means every row. Values
            larger than the number of rows are clamped to it instead of
            raising ``IndexError``. ``0`` (or any other negative value) yields
            the header only.

    Returns:
        The prompt prefix as a string.
    """
    available = train_df.shape[0]
    if k == -1 or k > available:
        k = available
    header = "The following are multiple choice questions (with answers) about {}.\n\n".format(format_subject(subject))
    return header + "".join(format_example(train_df, row) for row in range(k))

def load_model_and_tokenizer(model_path, tokenizer_path):
    """Load the tokenizer and the causal LM and return them as ``(model, tokenizer)``.

    Args:
        model_path: Location passed to ``AutoModelForCausalLM.from_pretrained``.
        tokenizer_path: Location passed to ``AutoTokenizer.from_pretrained``.

    Returns:
        Tuple ``(model, tokenizer)``; the model has been switched to
        evaluation mode.
    """
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
    # Module.eval() returns the module itself, so the call can be chained.
    model = AutoModelForCausalLM.from_pretrained(model_path).eval()
    return model, tokenizer

# Local-model inference that stands in for the original openai.Completion.create call.
def eval(args, subject, model, tokenizer, dev_df, test_df, device, max_length=512):
    """Score ``model`` on one subject's multiple-choice test set.

    For every test row a few-shot prompt is built, the model is run once, and
    the logits it assigns at the last prompt position to each answer letter
    are compared. The highest-scoring letter is the prediction.

    NOTE: this function shadows the built-in ``eval``; the name is kept so
    existing callers keep working.

    Args:
        args: Parsed CLI namespace; only ``args.ntrain`` is read here.
        subject: Subject identifier used in the prompt header and log line.
        model: Causal LM whose forward pass returns an object with ``.logits``.
        tokenizer: Tokenizer used to encode the prompt and the answer letters.
        dev_df: Solved examples used as few-shot demonstrations.
        test_df: Questions to score; the last column holds the gold answer.
        device: Device the input tensors are moved to before the forward pass.
        max_length: Maximum number of prompt tokens fed to the model (must be
            positive). Longer prompts keep their LAST ``max_length`` tokens.

    Returns:
        Tuple ``(cors, acc, all_probs)`` where ``cors`` is a boolean array with
        one entry per test row, ``acc`` is its mean, and ``all_probs`` is an
        array of shape ``(num_rows, num_options)`` holding the softmax over
        the answer-letter logits.
    """
    answers = choices[:test_df.shape[1] - 2]
    label_col = test_df.shape[1] - 1

    # The answer-letter token ids do not depend on the question: look them up once.
    answer_token_ids = [
        tokenizer(" {}".format(ans), add_special_tokens=False).input_ids[0]
        for ans in answers
    ]

    # Never ask for more demonstrations than dev_df holds; -1 means "all rows".
    num_dev = dev_df.shape[0]
    max_shots = num_dev if args.ntrain == -1 else max(0, min(args.ntrain, num_dev))

    # The few-shot prefix depends only on the shot count, so build each one once.
    train_prompts = {}

    cors = []
    all_probs = []
    for i in range(test_df.shape[0]):
        prompt_end = format_example(test_df, i, include_answer=False)

        # Drop demonstrations until crop() returns the prompt unchanged
        # (presumably crop() shortens over-long prompts -- it lives in the
        # project's crop module, confirm its limit there).
        # Stop at zero shots: decrementing further would reach -1, which
        # gen_prompt treats as "use every row", so the prompt would grow again
        # and the loop could spin forever.
        k = max_shots
        while True:
            if k not in train_prompts:
                train_prompts[k] = gen_prompt(dev_df, subject, k)
            prompt = train_prompts[k] + prompt_end
            if k == 0 or crop(prompt) == prompt:
                break
            k -= 1

        # Keep the LAST max_length tokens. Truncating on the right would cut
        # off the question and the trailing "Answer:", so the final position
        # would no longer be the one that predicts the answer letter.
        encoded = tokenizer(prompt, return_tensors="pt")
        inputs = {
            name: tensor[:, -max_length:].to(device)
            for name, tensor in encoded.items()
        }

        with torch.no_grad():
            logits = model(**inputs).logits

        # Scores of the answer letters at the last prompt position. These are
        # raw logits; softmax() below renormalises them over the options only.
        last_token_logits = logits[0, -1, :]
        lprobs = [last_token_logits[token_id].item() for token_id in answer_token_ids]

        pred = answers[int(np.argmax(lprobs))]
        probs = softmax(np.array(lprobs))

        label = test_df.iloc[i, label_col]
        cors.append(pred == label)
        all_probs.append(probs)

    acc = np.mean(cors)
    cors = np.array(cors)
    # Explicit reshape keeps the 2-D layout even when test_df has no rows.
    all_probs = np.array(all_probs, dtype=float).reshape(len(all_probs), len(answers))
    print("Average accuracy {:.3f} - {}".format(acc, subject))

    return cors, acc, all_probs

def main(args):
    """Evaluate the local model on every subject and write per-subject CSVs.

    Subjects are discovered from ``<data_dir>/test/*_test.csv``. For every
    engine label in ``args.engine`` one CSV per subject is written to
    ``<save_dir>/results_<engine>/<subject>.csv`` containing the test rows plus
    a correctness column and one probability column per answer option.

    The engine label only selects the output directory and column prefix; the
    model is the same for every label, so each subject is evaluated once and
    the scores are reused. As a consequence the per-subject accuracy line
    printed by ``eval`` appears once per subject, not once per engine.

    Args:
        args: Parsed CLI namespace providing ``model_path``, ``tokenizer_path``,
            ``ntrain``, ``data_dir``, ``save_dir`` and ``engine``.

    Raises:
        ValueError: If no ``*_test.csv`` file is found under ``<data_dir>/test``.
    """
    # NOTE(review): the file imports torch_npu's transfer_to_npu, so "cuda"
    # presumably resolves to the NPU on Ascend hosts -- confirm.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model, tokenizer = load_model_and_tokenizer(args.model_path, args.tokenizer_path)
    model = model.to(device)

    engines = args.engine
    test_dir = os.path.join(args.data_dir, "test")
    suffix = "_test.csv"
    subjects = sorted(f[:-len(suffix)] for f in os.listdir(test_dir) if f.endswith(suffix))
    if not subjects:
        raise ValueError("No '*{}' files found in {}".format(suffix, test_dir))

    # makedirs also creates save_dir itself (and any missing parents).
    result_dirs = {}
    for engine in engines:
        result_dirs[engine] = os.path.join(args.save_dir, "results_{}".format(engine))
        os.makedirs(result_dirs[engine], exist_ok=True)

    print(subjects)
    print(args)

    # eval() never receives the engine label, so its result for a subject is
    # identical for every engine: run the model once per subject and reuse it.
    evaluated = {}

    for engine in engines:
        print(engine)
        all_cors = []

        for subject in subjects:
            if subject not in evaluated:
                dev_df = pd.read_csv(os.path.join(args.data_dir, "dev", subject + "_dev.csv"), header=None)[:args.ntrain]
                test_df = pd.read_csv(os.path.join(test_dir, subject + "_test.csv"), header=None)
                cors, _, probs = eval(args, subject, model, tokenizer, dev_df, test_df, device)
                evaluated[subject] = (test_df, cors, probs)
            test_df, cors, probs = evaluated[subject]
            all_cors.append(cors)

            # Work on a copy so each engine's file carries only its own columns.
            out_df = test_df.copy()
            out_df["{}_correct".format(engine)] = cors
            for j in range(probs.shape[1]):
                out_df["{}_choice{}_probs".format(engine, choices[j])] = probs[:, j]
            out_df.to_csv(os.path.join(result_dirs[engine], "{}.csv".format(subject)), index=False)

        # Mean over all questions, so larger subjects weigh more.
        weighted_acc = np.mean(np.concatenate(all_cors))
        print("Average accuracy: {:.3f}".format(weighted_acc))

if __name__ == "__main__":
    # Script entry point: parse the command line and hand the namespace to main().
    engine_names = ["davinci", "curie", "babbage", "ada"]

    parser = argparse.ArgumentParser()
    # Locations of the model weights and of the tokenizer files.
    parser.add_argument(
        "--model_path",
        default="/opt/dpcvol/models/LLM_Distillation/results/gpt2_138M-token_1B/",
        type=str,
    )
    parser.add_argument(
        "--tokenizer_path",
        default="/opt/dpcvol/datasets/8625883998351850434/ckpt/minillm/minillm_official/gpt2/train/minillm/medium-init-xlarge-sft/",
        type=str,
    )
    # Number of few-shot demonstrations per prompt.
    parser.add_argument("--ntrain", "-k", default=5, type=int)
    # Dataset root (expects "dev" and "test" sub-directories) and output root.
    parser.add_argument(
        "--data_dir",
        "-d",
        default="/opt/dpcvol/datasets/8625883998351850434/datasets/llm/mmlu/",
        type=str,
    )
    parser.add_argument("--save_dir", "-s", default="./results", type=str)
    # One or more engine labels; all of them are selected by default.
    parser.add_argument(
        "--engine",
        "-e",
        nargs="+",
        choices=engine_names,
        default=list(engine_names),
    )

    args = parser.parse_args()
    main(args)
