import os
import sys
import json
import numpy as np
from tqdm import tqdm
from sklearn.metrics.pairwise import euclidean_distances
import pickle
from torch.optim.lr_scheduler import StepLR
from sklearn.metrics import accuracy_score, classification_report, roc_auc_score, roc_curve, average_precision_score
from sklearn.model_selection import train_test_split
import torch
import time

# Put the current working directory on the import path (presumably so the
# local `task_tracker` package resolves when run from the repository root).
sys.path.append(os.path.abspath("."))

# NOTE: DATASET is reassigned to "FinQA" further down in this module, before
# any use, so "SQuAD" is never the effective dataset name.
DATASET = "SQuAD"
# Model name; selects the entry of LAYERS_PER_MODEL and is used as a path
# component and file-name suffix in the main block.
MODEL = "llama3_8b"

# Layer indices to process, keyed by model name.
LAYERS_PER_MODEL = {
    'llama3_70b': [0, 7, 15, 23, 31, 39, 47, 55, 63, 71, 79],
    'phi3': [0, 7, 15, 23, 31],
    'mixtral': [0, 7, 15, 23, 31],
    'mistral': [0, 7, 15, 23, 31],
    'llama3_8b': [0, 7, 15, 23, 31],
    'mistral_no_priming': [0, 7, 15, 23, 31],
    'vicuna': [0, 7, 15, 23, 31],
}

from sklearn.ensemble import RandomForestClassifier
from task_tracker.CONFIG_1 import current_risk
# Risk category taken from task_tracker.CONFIG_1; used as a path component
# for the data, activation and output directories in the main block.
RISK = current_risk 
# Effective dataset name (replaces the "SQuAD" value assigned earlier in
# this module).
DATASET = "FinQA"
# Anchor-set sizes to sweep; each value caps how many clean training
# activations are used as anchors.
ANCHOR_SAMPLES = [1000]
from task_tracker.training.dataset import ActivationsDatasetDynamic, ActivationsDatasetDynamicPrimaryText
from task_tracker.training.helpers.data import load_file_paths
#from task_tracker.training.utils.constants_1 import CONSTANTS_ALL_MODELS, OOD_POISONED_FILE

def train_binary_classifier(train_activations, train_labels):
    """Fit a random-forest binary classifier on the supplied features.

    Args:
        train_activations: Feature matrix handed directly to ``fit``.
        train_labels: Labels aligned row-by-row with ``train_activations``.

    Returns:
        The fitted ``RandomForestClassifier`` instance.
    """
    print("[*] Training triplet probe")
    forest_kwargs = {
        "n_estimators": 100,
        "max_depth": 10,
        "min_samples_split": 5,
        # Fixed seed so repeated runs build the same forest.
        "random_state": 42,
    }
    forest = RandomForestClassifier(**forest_kwargs)
    forest.fit(train_activations, train_labels)
    return forest

def load_activations(file_paths, num_layers, activations_dir):
    """Load activations from ``activations_dir`` as one NumPy array.

    Args:
        file_paths: File list forwarded to ``ActivationsDatasetDynamic``.
        num_layers: Layer selection forwarded to the dataset.
        activations_dir: Directory passed to the dataset as ``root_dir``.

    Returns:
        ``np.ndarray`` built from the flattened, float-cast items the dataset
        yields (one entry per item).
    """
    dataset = ActivationsDatasetDynamic(
        file_paths, root_dir=activations_dir, num_layers=num_layers
    )
    # Each item is flattened, cast to float and converted to NumPy;
    # presumably items are torch tensors - confirm against the dataset class.
    flattened = [item.flatten().float().numpy() for item in tqdm(dataset)]
    return np.array(flattened)

def load_val_activations(file_paths, num_layers, activations_dir):
    """Load validation activations as one NumPy array.

    The validation directory is derived from ``activations_dir`` by replacing
    every ``'/training'`` substring with ``'/validation'``.

    Args:
        file_paths: File list forwarded to ``ActivationsDatasetDynamic``.
        num_layers: Layer selection forwarded to the dataset.
        activations_dir: Training activations directory.

    Returns:
        ``np.ndarray`` built from the flattened, float-cast items the dataset
        yields (one entry per item).
    """
    validation_dir = activations_dir.replace('/training', '/validation')
    dataset = ActivationsDatasetDynamic(file_paths, validation_dir, num_layers=num_layers)
    flattened = [item.flatten().float().numpy() for item in tqdm(dataset)]
    return np.array(flattened)

def load_test_activations(file_paths, num_layers, activations_dir):
    """Load test activations as one NumPy array.

    The test directory is derived from ``activations_dir`` by replacing every
    ``'/training'`` substring with ``'/test'``.

    Args:
        file_paths: File list forwarded to ``ActivationsDatasetDynamic``.
        num_layers: Layer selection forwarded to the dataset.
        activations_dir: Training activations directory.

    Returns:
        ``np.ndarray`` built from the flattened, float-cast items the dataset
        yields (one entry per item).
    """
    test_dir = activations_dir.replace('/training', '/test')
    dataset = ActivationsDatasetDynamic(file_paths, test_dir, num_layers=num_layers)
    flattened = [item.flatten().float().numpy() for item in tqdm(dataset)]
    return np.array(flattened)

def load_file_paths(file_path):
    """Read a list file containing one path per line.

    This definition shadows the ``load_file_paths`` imported from
    ``task_tracker.training.helpers.data`` earlier in the module. Every
    returned path is echoed to stdout to help debug which files are used.

    Blank and whitespace-only lines are skipped, and surrounding whitespace
    is stripped, so a stray empty line in the list file does not turn into
    an empty path handed to the dataset loader.

    Args:
        file_path: Path of the text file to read.

    Returns:
        List of non-empty path strings in file order.
    """
    with open(file_path, 'r') as f:
        stripped_lines = (line.strip() for line in f.read().splitlines())
        paths = [line for line in stripped_lines if line]
    for path in paths:
        print(f"Loading file: {path}")
    return paths

def generate_distance_features(samples, clean_anchors):
    """Compute, for each sample, its mean distance to the clean anchors.

    Runs on CUDA when available, otherwise on CPU.

    Args:
        samples: Iterable of per-sample vectors.
        clean_anchors: Anchor vectors; the norm is reduced over ``dim=1``, so
            this is expected to be 2-D (assumed ``(n_anchors, dim)`` - confirm
            against the caller).

    Returns:
        1-D ``np.ndarray`` with one scalar feature per sample: the mean of the
        norms of ``sample - anchor`` over all anchors.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    anchors = torch.tensor(clean_anchors).to(device)

    def _mean_anchor_distance(row):
        # Expand the single sample to the anchors' shape, then take the
        # per-anchor norm of the difference and average it.
        row_tensor = torch.tensor(row).to(device)
        differences = row_tensor.expand_as(anchors) - anchors
        return torch.norm(differences, dim=1).mean().item()

    return np.array([_mean_anchor_distance(row) for row in tqdm(samples)])

def evaluate_classifier(classifier, test_activations, test_labels):
    """Evaluate a fitted classifier and print a summary of the results.

    Prints accuracy, AUPRC and the scikit-learn classification report.

    Args:
        classifier: Fitted estimator exposing ``predict`` and
            ``predict_proba``.
        test_activations: Feature matrix to evaluate on.
        test_labels: Ground-truth labels aligned with ``test_activations``.

    Returns:
        Tuple ``(accuracy, auprc, fpr, tpr)`` where ``fpr``/``tpr`` are the
        arrays returned by ``roc_curve``.
    """
    print("[*] 进行结果评估")
    predictions = classifier.predict(test_activations)
    accuracy = accuracy_score(test_labels, predictions)

    # Score of the second class column, computed once and shared by both
    # the AUPRC and the ROC curve instead of running predict_proba twice.
    positive_scores = classifier.predict_proba(test_activations)[:, 1]

    # Area under the precision-recall curve.
    auprc = average_precision_score(test_labels, positive_scores)

    fpr, tpr, _ = roc_curve(test_labels, positive_scores)

    print(f"Test accuracy: {accuracy}")
    print(f"AUPRC score: {auprc}")
    print(classification_report(test_labels, predictions))
    return accuracy, auprc, fpr, tpr

def train_and_evaluate_with_distance(
    train_files_clean,
    train_files_poisoned,
    val_files_clean,
    val_files_poisoned,
    test_files_clean,
    test_files_poisoned,
    num_layers,
    activations_dir,
    n_anchors,
):
    """Train and evaluate a distance-to-anchor classifier.

    Loads clean and poisoned activations for the train, test and validation
    splits, turns every sample into a single feature (mean distance to the
    clean anchors), fits the binary classifier on the training features, and
    evaluates it on validation and then test.

    Args:
        train_files_clean: File list for clean training activations.
        train_files_poisoned: File list for poisoned training activations.
        val_files_clean: File list for clean validation activations.
        val_files_poisoned: File list for poisoned validation activations.
        test_files_clean: File list for clean test activations.
        test_files_poisoned: File list for poisoned test activations.
        num_layers: Layer selection forwarded to the activation loaders.
        activations_dir: Training activations directory; the validation and
            test directories are derived from it by the loaders.
        n_anchors: Maximum number of clean training samples used as anchors.

    Returns:
        Tuple ``(test_accuracy, test_auprc, fpr, tpr)`` from the test-set
        evaluation.
    """
    # Load training-set activations.
    print("加载训练集激活")
    train_clean = load_activations(train_files_clean, num_layers, activations_dir)
    train_poisoned = load_activations(train_files_poisoned, num_layers, activations_dir)

    # Load test-set activations.
    print("加载测试集激活")
    test_clean = load_test_activations(test_files_clean, num_layers, activations_dir)
    test_poisoned = load_test_activations(test_files_poisoned, num_layers, activations_dir)

    # Load validation-set activations.
    print("加载验证集激活")
    val_clean = load_val_activations(val_files_clean, num_layers, activations_dir)
    val_poisoned = load_val_activations(val_files_poisoned, num_layers, activations_dir)

    # Anchors are the first clean training samples, capped at n_anchors.
    anchor_count = min(len(train_clean), n_anchors)
    clean_anchors = train_clean[:anchor_count]

    def _features_and_labels(clean_acts, poisoned_acts):
        # Distance features for both classes, stacked clean-first, with
        # labels 0 (clean) and 1 (poisoned); features become a single column.
        clean_feats = generate_distance_features(clean_acts, clean_anchors)
        poisoned_feats = generate_distance_features(poisoned_acts, clean_anchors)
        stacked = np.concatenate([clean_feats, poisoned_feats], axis=0)
        labels = np.array([0] * len(clean_feats) + [1] * len(poisoned_feats))
        return stacked.reshape(-1, 1), labels

    # Build training features and labels.
    print("计算训练集激活残差")
    train_matrix, train_labels = _features_and_labels(train_clean, train_poisoned)

    # Build validation features and labels.
    print("计算验证集激活残差")
    val_matrix, val_labels = _features_and_labels(val_clean, val_poisoned)

    # Fit the binary classifier.
    print("将数据输入初始嵌入模型，训练新嵌入模型")
    classifier = train_binary_classifier(train_matrix, train_labels)

    # Evaluate on the validation split (result is only printed).
    print("在验证集上进行评估")
    evaluate_classifier(classifier, val_matrix, val_labels)

    # Evaluate on the test split; these metrics are returned.
    print("在测试集上进行评估")
    test_matrix, test_labels = _features_and_labels(test_clean, test_poisoned)
    accuracy, auprc, fpr, tpr = evaluate_classifier(classifier, test_matrix, test_labels)

    return accuracy, auprc, fpr, tpr

if __name__ == "__main__":
    # Script entry point: for every anchor count and every configured layer,
    # write a per-layer config, train/evaluate the distance-based classifier,
    # and save the collected test metrics as JSON.
    LAYERS = LAYERS_PER_MODEL[MODEL]

    # Report the dataset actually in use instead of a hard-coded name.
    print(f"\n===== 加载数据集: {DATASET}\n")

    # Dataset-specific locations.
    ACTIVATION_FILE_LIST_DIR = "/guardrail/TaskTracker/store/data/" + RISK + "/" + DATASET
    ACTIVATIONS_DIR = '/guardrail/TaskTracker/store/activations/'+ RISK + "/" + DATASET + '/' + MODEL + '/training'
    ACTIVATIONS_VAL_DIR = '/guardrail/TaskTracker/store/activations/'+ RISK + "/" + DATASET + '/' + MODEL + '/validation'
    OOD_POISONED_FILE = "/guardrail/TaskTracker/store/output_datasets/"+ RISK + "/" + DATASET + '/dataset_poisoned_val.json'

    # Base configuration recorded alongside each layer's outputs.
    config = {
        'activations': ACTIVATIONS_DIR,
        'activations_ood': ACTIVATIONS_VAL_DIR,
        'ood_poisoned_file': OOD_POISONED_FILE,
        'exp_name': f'distance_based_classification_{MODEL}_{DATASET}',
    }

    # Output directory for the current risk/dataset/model combination.
    OUTPUT_DIR = f"/guardrail/TaskTracker/store/model/{RISK}/{DATASET}/{MODEL}"
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    # File lists are loaded once and shared by every anchor count and layer.
    train_files_clean = load_file_paths(os.path.join(ACTIVATION_FILE_LIST_DIR, f'train_clean_files_{MODEL}.txt'))
    train_files_poisoned = load_file_paths(os.path.join(ACTIVATION_FILE_LIST_DIR, f'train_poisoned_files_{MODEL}.txt'))
    val_files_clean = load_file_paths(os.path.join(ACTIVATION_FILE_LIST_DIR, f'val_clean_files_{MODEL}.txt'))
    val_files_poisoned = load_file_paths(os.path.join(ACTIVATION_FILE_LIST_DIR, f'val_poisoned_files_{MODEL}.txt'))
    test_files_clean = load_file_paths(os.path.join(ACTIVATION_FILE_LIST_DIR, f'test_clean_files_{MODEL}.txt'))
    test_files_poisoned = load_file_paths(os.path.join(ACTIVATION_FILE_LIST_DIR, f'test_poisoned_files_{MODEL}.txt'))

    print(f"Training with {len(train_files_clean)} clean files and {len(train_files_poisoned)} poisoned files.")
    print(f"Evaluating with {len(val_files_clean)} clean files and {len(val_files_poisoned)} poisoned files.")

    # Sweep over the configured anchor-set sizes.
    for n_anchors in ANCHOR_SAMPLES:
        # Per-layer results for this anchor count.
        results = []

        # Process each configured layer.
        for n_layer in LAYERS:
            print(f"[*] 处理第 {n_layer} 层激活层")
            anchor_output_dir = os.path.join(OUTPUT_DIR, f"anchors_{n_anchors}")
            os.makedirs(anchor_output_dir, exist_ok=True)
            layer_output_dir = os.path.join(anchor_output_dir, str(n_layer))
            os.makedirs(layer_output_dir, exist_ok=True)

            # Record the exact configuration used for this layer/anchor run.
            _config = config.copy()
            _config["num_layers"] = n_layer
            _config["n_anchors"] = n_anchors
            _config["exp_name"] = f"{config['exp_name']}_{n_layer}_anchors_{n_anchors}"
            with open(os.path.join(layer_output_dir, 'config.json'), 'w') as f:
                json.dump(_config, f)

            # Train and evaluate; the same layer index is passed as both
            # elements of the num_layers pair.
            print("使用三元组探测器进行处理")
            test_accuracy, test_auprc, fpr, tpr = train_and_evaluate_with_distance(
                train_files_clean, train_files_poisoned,
                val_files_clean, val_files_poisoned,
                test_files_clean, test_files_poisoned,
                num_layers=(n_layer, n_layer),
                activations_dir=ACTIVATIONS_DIR,
                n_anchors=n_anchors
            )

            # Collect metrics; ROC arrays are converted to lists for JSON.
            results.append({
                'layer': n_layer,
                'anchors': n_anchors,
                'test_accuracy': test_accuracy,
                'test_auprc': test_auprc,
                'fpr': fpr.tolist(),
                'tpr': tpr.tolist()
            })

        # Save the results for this dataset and anchor count.
        results_file = os.path.join(OUTPUT_DIR, f"{RISK}_{DATASET}_{MODEL}_results_anchors_{n_anchors}.json")
        with open(results_file, 'w', encoding='utf-8') as f:
            json.dump(results, f, ensure_ascii=False, indent=4)
        # Report the path that was actually written.
        print(f"结果保存至 {results_file}")