import os
import argparse
import numpy as np
import pickle
import joblib
import json
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import roc_auc_score
from biscope_utils import data_generation  # 保留原始版本
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

def load_features(path):
    """Read a pickled feature collection from ``path`` as a NumPy array.

    Args:
        path: Location of a pickle file on disk.

    Returns:
        ``np.ndarray`` built from whatever object the pickle contained.

    NOTE: unpickling can execute arbitrary code, so only point this at
    feature files produced by a trusted pipeline.
    """
    with open(path, 'rb') as handle:
        raw_features = pickle.load(handle)
    return np.array(raw_features)

def _load_feature_pair(feat_dir, task):
    """Load the human and machine feature arrays for ``task`` from ``feat_dir``.

    The files read are ``{task}_human_features.pkl`` and
    ``{task}_GPT_features.pkl``; presumably these are what ``data_generation``
    writes into ``feat_dir`` -- verify against biscope_utils.

    Returns:
        Tuple ``(human_feats, machine_feats)`` of NumPy arrays.
    """
    human_feats = load_features(os.path.join(feat_dir, f"{task}_human_features.pkl"))
    machine_feats = load_features(os.path.join(feat_dir, f"{task}_GPT_features.pkl"))
    return human_feats, machine_feats


def _dump_scores(path, scores):
    """Write ``scores`` to ``path`` as ``{"predictions": [...]}`` JSON.

    The parent directory is created first so that a missing output folder
    does not abort the run after feature generation has already completed.
    """
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(path, 'w') as f:
        json.dump({"predictions": scores}, f)


def main():
    """Command-line entry point: train or apply the random-forest detector.

    ``--mode train`` generates features via ``data_generation``, fits a
    ``RandomForestClassifier`` (human samples labelled 0, machine samples
    labelled 1) and saves it to ``<output_dir>/random_forest_model.joblib``.

    ``--mode test`` loads that saved model, generates features for the
    hard-coded DSB test files, and writes the predicted class-1 probabilities
    for the human and machine samples to two JSON score files.

    Raises:
        SystemExit: if the command-line arguments are invalid (argparse).
        FileNotFoundError: in test mode when no trained model exists yet.
    """
    parser = argparse.ArgumentParser()
    # Restrict to the two modes handled below; previously any other value
    # was accepted and the script silently did nothing.
    parser.add_argument('--mode', default='test', choices=['train', 'test'])
    # --detect_model / --summary_model are not read in this function; they
    # are forwarded to data_generation through ``args``.
    parser.add_argument('--detect_model', required=True)
    parser.add_argument('--summary_model', default='none')
    parser.add_argument('--task', default='task')  # only used as a file-name prefix
    parser.add_argument('--human_json', default=' ')
    parser.add_argument('--machine_json', default=' ')
    parser.add_argument('--output_dir', default='./results')
    args = parser.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)

    # Location of the persisted classifier, shared by both modes.
    model_path = os.path.join(args.output_dir, 'random_forest_model.joblib')

    if args.mode == 'train':
        print("=== 训练模式 ===")
        # Dedicated directory for the generated feature pickles.
        feat_dir = os.path.join(args.output_dir, 'train_features')
        os.makedirs(feat_dir, exist_ok=True)

        dataset_type = 'nonparaphrased'
        task = args.task
        gen_model = 'gpt'

        # NOTE(review): the original author assumed the two JSON inputs were
        # placed manually at ./Dataset/{task}/{task}_human.json and
        # ./Dataset/{task}/{task}_gpt.json (or copied/symlinked to whatever
        # path data_generation expects) -- confirm against biscope_utils.
        print("Generating training features...")
        data_generation(args, feat_dir, dataset_type, task, gen_model)

        human_feats, machine_feats = _load_feature_pair(feat_dir, task)

        # Stack both classes; label human rows 0 and machine rows 1.
        X = np.concatenate([human_feats, machine_feats])
        y = np.concatenate([np.zeros(len(human_feats)), np.ones(len(machine_feats))])

        clf = RandomForestClassifier(n_estimators=100, random_state=42)
        clf.fit(X, y)
        joblib.dump(clf, model_path)
        print(f"模型已保存至: {model_path}")

    elif args.mode == 'test':
        print("=== 测试模式 ===")
        if not os.path.isfile(model_path):
            # Same exception type joblib.load would raise, with a clearer hint.
            raise FileNotFoundError(
                f"Trained model not found at {model_path}; run with --mode train first."
            )
        clf = joblib.load(model_path)
        feat_dir = os.path.join(args.output_dir, 'test_features')
        os.makedirs(feat_dir, exist_ok=True)
        for model_name in ["DSB"]:
            print(model_name, "start")
            dataset_type = 'nonparaphrased'
            task = f"M4_{model_name}"
            gen_model = 'gpt'
            # Test inputs are hard-coded: any --human_json / --machine_json
            # given on the command line is overwritten here.
            args.human_json = f"./AIDetection/Related_dataset/DSB/{model_name}_human_test.json"
            args.machine_json = f"./AIDetection/Related_dataset/DSB/{model_name}_machine_test_v4.json"
            print("Generating test features...")
            data_generation(args, feat_dir, dataset_type, task, gen_model, input_length=1024)

            human_feats, machine_feats = _load_feature_pair(feat_dir, task)

            X_test = np.concatenate([human_feats, machine_feats])

            # Column 1 is the probability of the label-1 (machine) class.
            probs = clf.predict_proba(X_test)[:, 1]

            # Human rows come first in X_test, so split at len(human_feats).
            human_scores = probs[:len(human_feats)].tolist()
            machine_scores = probs[len(human_feats):].tolist()

            # NOTE: the output file names do not include model_name, so adding
            # more entries to the loop above would overwrite these files.
            _dump_scores('./AIDetection/DNA-DetectLLM/scores/RealDet_Biscope_human_test.json', human_scores)
            _dump_scores('./AIDetection/DNA-DetectLLM/scores/RealDet_Biscope_machine_test.json', machine_scores)

# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == '__main__':
    main()
