import os
import sys
import json
import numpy as np
from tqdm import tqdm
from sklearn.metrics.pairwise import euclidean_distances
import pickle
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, classification_report, roc_auc_score, roc_curve, average_precision_score
from sklearn.model_selection import train_test_split
import torch
from sklearn.linear_model import LogisticRegression

# Model whose activations are probed; its name doubles as the output folder.
MODEL = OUTPUT_DIR = "llama3_8b"
os.makedirs(name=OUTPUT_DIR, exist_ok=True)

from task_tracker.training.dataset import ActivationsDatasetDynamic, ActivationsDatasetDynamicPrimaryText
from task_tracker.training.helpers.data import load_file_paths
from task_tracker.training.utils.constants import CONSTANTS_ALL_MODELS, OOD_POISONED_FILE

# Activation-layer indices to probe, keyed by model name.
LAYERS_PER_MODEL = dict(
    llama3_70b=[0, 7, 15, 23, 31, 39, 47, 55, 63, 71, 79],
    phi3=[0, 7, 15, 23, 31],
    mixtral=[0, 7, 15, 23, 31],
    mistral=[0, 7, 15, 23, 31],
    llama3_8b=[0, 7, 15, 23, 31],
    mistral_no_priming=[0, 7, 15, 23, 31],
)


# Per-model directories, looked up from the shared constants table.
ACTIVATION_FILE_LIST_DIR, ACTIVATIONS_DIR, ACTIVATIONS_VAL_DIR = (
    CONSTANTS_ALL_MODELS[MODEL][key]
    for key in ('ACTIVATION_FILE_LIST_DIR', 'ACTIVATIONS_DIR', 'ACTIVATIONS_VAL_DIR')
)

# Configuration settings for the logistic-regression experiment. The
# __main__ section below copies this per layer and writes it to config.json.
config = dict(
    activations=ACTIVATIONS_DIR,
    activations_ood=ACTIVATIONS_VAL_DIR,
    ood_poisoned_file=OOD_POISONED_FILE,
    exp_name=f'logistic_regression_{MODEL}',
)


def train_model(train_files, num_layers):
    """Fit a logistic-regression probe on activation differences.

    Each dataset item is unpacked as ``(primary, clean, poisoned)``. The
    features are ``clean - primary`` (label 0) followed by
    ``poisoned - primary`` (label 1), each flattened and converted with
    ``.float().numpy()``.

    Args:
        train_files: file list handed to ``ActivationsDatasetDynamic``.
        num_layers: layer selection forwarded to the dataset.

    Returns:
        The fitted ``LogisticRegression`` model.
    """
    print("Loading dataset.")
    activations = ActivationsDatasetDynamic(train_files, root_dir=config['activations'], num_layers=num_layers)

    print("Processing dataset.")
    negatives, positives = [], []
    for primary, clean, poisoned in tqdm(activations):
        negatives.append((clean - primary).flatten().float().numpy())
        positives.append((poisoned - primary).flatten().float().numpy())

    # Every item contributes exactly one negative and one positive example.
    n_items = len(activations)
    labels = [0] * n_items + [1] * n_items
    features = negatives + positives

    print("Training logistic regression classifier.")
    probe = LogisticRegression()
    probe.fit(features, labels)
    return probe

def load_evaluation_data(val_files_clean, val_files_poisoned, num_layers):
    """Build the out-of-distribution evaluation set.

    Items from both datasets are unpacked as ``(primary, with_text)`` pairs
    and turned into flattened ``with_text - primary`` differences. Clean
    rows come first (label 0), then poisoned rows (label 1).

    Args:
        val_files_clean: file list for the clean validation split.
        val_files_poisoned: file list for the poisoned validation split.
        num_layers: layer selection forwarded to the dataset.

    Returns:
        ``(X_validation, y_validation)``: a stacked NumPy array and a list
        of integer labels.
    """
    print("Loading validation datasets.")
    ood_root = config.get('activations_ood')
    clean_dataset = ActivationsDatasetDynamicPrimaryText(val_files_clean, num_layers=num_layers, root_dir=ood_root)
    poisoned_dataset = ActivationsDatasetDynamicPrimaryText(val_files_poisoned, num_layers=num_layers, root_dir=ood_root)

    print("Processing validation datasets.")
    clean_diff = [
        (with_text - primary).flatten().float().numpy()
        for primary, with_text in tqdm(clean_dataset)
    ]
    poisoned_diff = [
        (with_text - primary).flatten().float().numpy()
        for primary, with_text in tqdm(poisoned_dataset)
    ]

    X_validation = np.array(clean_diff + poisoned_diff)
    y_validation = [0] * len(clean_diff) + [1] * len(poisoned_diff)
    return X_validation, y_validation


if __name__ == "__main__":
    # NOTE(review): this file has a second ``if __name__ == "__main__":`` section
    # further down; running the script executes both experiments in sequence.
    LAYERS = LAYERS_PER_MODEL[MODEL]

    # The file lists do not depend on the layer, so read them once up front.
    train_files = load_file_paths(os.path.join(ACTIVATION_FILE_LIST_DIR, 'train_files_' + MODEL + '.txt'))
    val_files_clean = load_file_paths(os.path.join(ACTIVATION_FILE_LIST_DIR, 'val_clean_files_' + MODEL + '.txt'))
    val_files_poisoned = load_file_paths(os.path.join(ACTIVATION_FILE_LIST_DIR, 'val_poisoned_files_' + MODEL + '.txt'))

    for n_layer in LAYERS:
        print(f"[*] Training model for the {n_layer}-th activation layer.")
        layer_output_dir = os.path.join(OUTPUT_DIR, str(n_layer))
        os.makedirs(layer_output_dir, exist_ok=True)

        # Per-layer copy of the experiment config, saved next to the model.
        _config = config.copy()
        _config["num_layers"] = n_layer
        _config["exp_name"] = f"{config['exp_name']}_{n_layer}"
        print(_config["exp_name"])
        with open(os.path.join(layer_output_dir, 'config.json'), 'w') as f:
            json.dump(_config, f)

        print(f"Training model with {len(train_files)} files.")
        print(f"Evaluating model with {len(val_files_clean)} clean files and {len(val_files_poisoned)} poisoned files.")

        # Train the model. num_layers=(n_layer, n_layer) presumably selects only
        # this layer — confirm against the dataset class.
        model = train_model(train_files, num_layers=(n_layer, n_layer))
        # Context manager guarantees the pickle is flushed and the handle closed.
        with open(os.path.join(layer_output_dir, 'model.pickle'), "wb") as f:
            pickle.dump(model, f)

        # Evaluate on the out-of-distribution validation split.
        X_eval, y_eval = load_evaluation_data(val_files_clean, val_files_poisoned, num_layers=(n_layer, n_layer))
        accuracy = model.score(X_eval, y_eval)
        print(accuracy)
        print("\n" * 4)

# Make packages under the current working directory importable before
# pulling in the project's risk configuration.
sys.path.append(os.path.abspath("."))
from task_tracker.CONFIG_1 import current_risk

# Datasets to process and the anchor-sample counts to sweep over.
DATASETS = ["FinQA", "hotpotqa", "Msmarco"]
ANCHOR_SAMPLES = [1000]
MODEL = "llama3_8b"

# Activation-layer indices to probe, keyed by model name. Every model except
# llama3_70b probes the same five layers; each key gets its own fresh list.
LAYERS_PER_MODEL = {
    'llama3_70b': [0, 7, 15, 23, 31, 39, 47, 55, 63, 71, 79],
    **{
        name: [0, 7, 15, 23, 31]
        for name in ('phi3', 'mixtral', 'mistral', 'llama3_8b', 'mistral_no_priming', 'vicuna')
    },
}

from task_tracker.training.dataset import ActivationsDatasetDynamic, ActivationsDatasetDynamicPrimaryText
from task_tracker.training.helpers.data import load_file_paths
#from task_tracker.training.utils.constants_1 import CONSTANTS_ALL_MODELS, OOD_POISONED_FILE

def train_binary_classifier(train_activations, train_labels):
    """Fit and return a random-forest binary classifier.

    Hyperparameters are fixed: 100 trees, max depth 10, at least 5 samples
    to split a node, and ``random_state=42`` for reproducible fits.

    Args:
        train_activations: 2-D feature matrix.
        train_labels: per-row labels.

    Returns:
        The fitted ``RandomForestClassifier``.
    """
    print("[*] Training Random Forest Classifier")
    forest_params = dict(n_estimators=100, random_state=42, max_depth=10, min_samples_split=5)
    forest = RandomForestClassifier(**forest_params)
    forest.fit(train_activations, train_labels)
    return forest

def load_activations(file_paths, num_layers, activations_dir):
    """Load activations from ``activations_dir`` and stack them into an array.

    Each dataset item is flattened and converted with ``.float().numpy()``;
    the results are stacked with ``np.array``.

    NOTE(review): ``train_model`` above unpacks items of this same dataset
    class as ``(primary, clean, poisoned)`` triples, while here each item is
    treated as a single tensor — confirm which dataset version is in use.
    """
    dataset = ActivationsDatasetDynamic(file_paths, root_dir=activations_dir, num_layers=num_layers)
    rows = [item.flatten().float().numpy() for item in tqdm(dataset)]
    return np.array(rows)

def load_val_activations(file_paths, num_layers, activations_dir):
    """Load validation activations and stack them into a NumPy array.

    The validation directory is derived from ``activations_dir`` by replacing
    the LAST occurrence of ``/training`` with ``/validation``. Earlier
    occurrences (e.g. inside a risk or dataset name) are left untouched. If
    ``/training`` does not occur, ``activations_dir`` is used unchanged.

    Args:
        file_paths: file list handed to ``ActivationsDatasetDynamic``.
        num_layers: layer selection forwarded to the dataset.
        activations_dir: the training activations directory.

    Returns:
        ``np.array`` of the flattened, float-converted dataset items.
    """
    head, sep, tail = activations_dir.rpartition('/training')
    root_dir = head + '/validation' + tail if sep else activations_dir
    # Pass root_dir by keyword, matching the other ActivationsDatasetDynamic calls.
    dataset = ActivationsDatasetDynamic(file_paths, root_dir=root_dir, num_layers=num_layers)
    activations = [activation.flatten().float().numpy() for activation in tqdm(dataset)]
    return np.array(activations)

def load_test_activations(file_paths, num_layers, activations_dir):
    """Load test activations and stack them into a NumPy array.

    The test directory is derived from ``activations_dir`` by replacing the
    LAST occurrence of ``/training`` with ``/test``. Earlier occurrences
    (e.g. inside a risk or dataset name) are left untouched. If ``/training``
    does not occur, ``activations_dir`` is used unchanged.

    Args:
        file_paths: file list handed to ``ActivationsDatasetDynamic``.
        num_layers: layer selection forwarded to the dataset.
        activations_dir: the training activations directory.

    Returns:
        ``np.array`` of the flattened, float-converted dataset items.
    """
    head, sep, tail = activations_dir.rpartition('/training')
    root_dir = head + '/test' + tail if sep else activations_dir
    # Pass root_dir by keyword, matching the other ActivationsDatasetDynamic calls.
    dataset = ActivationsDatasetDynamic(file_paths, root_dir=root_dir, num_layers=num_layers)
    activations = [activation.flatten().float().numpy() for activation in tqdm(dataset)]
    return np.array(activations)

# Redefines the ``load_file_paths`` imported above so that every path read is
# echoed, to help debug file-list contents.
def load_file_paths(file_path):
    """Read a file list and return its non-blank lines.

    Each returned path is printed as ``Loading file: <path>``. Blank or
    whitespace-only lines are skipped; otherwise they would come through as
    empty paths and break dataset loading downstream.

    Args:
        file_path: path to a UTF-8 text file with one entry per line.

    Returns:
        List of path strings, in file order.
    """
    with open(file_path, 'r', encoding='utf-8') as f:
        paths = [line for line in f.read().splitlines() if line.strip()]
    for path in paths:
        print(f"Loading file: {path}")
    return paths

def generate_distance_features(samples, clean_anchors, batch_size=1024):
    """Return each sample's mean Euclidean distance to the clean anchors.

    Distances are computed a batch of samples at a time, on CUDA when
    available. This avoids a Python loop with one device synchronisation per
    sample.

    Args:
        samples: 2-D array-like, one flattened activation vector per row.
        clean_anchors: 2-D array-like of anchor vectors with the same row
            width as ``samples``.
        batch_size: number of samples processed per step. It bounds the
            ``(batch_size, n_anchors)`` distance matrix held on the device.

    Returns:
        1-D float64 NumPy array with one mean distance per sample. The result
        is empty when ``samples`` is empty, and contains NaN when there are
        no anchors.
    """
    samples = np.asarray(samples)
    if len(samples) == 0:
        return np.array([])

    anchors_np = np.asarray(clean_anchors)
    # Promote to a common dtype, as elementwise subtraction would; cdist needs it.
    common_dtype = np.result_type(samples, anchors_np)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    anchors = torch.tensor(anchors_np.astype(common_dtype, copy=False)).to(device)

    batch_means = []
    for start in range(0, len(samples), batch_size):
        batch = torch.tensor(samples[start:start + batch_size].astype(common_dtype, copy=False)).to(device)
        # Exact pairwise distances; the matmul-based mode trades accuracy for speed.
        distances = torch.cdist(batch, anchors, compute_mode='donot_use_mm_for_euclid_dist')
        batch_means.append(distances.mean(dim=1).double().cpu())
    return torch.cat(batch_means).numpy()

def evaluate_classifier(classifier, test_activations, test_labels):
    """Score a fitted binary classifier and print a report.

    Args:
        classifier: fitted estimator exposing ``predict`` and ``predict_proba``.
        test_activations: feature matrix to evaluate on.
        test_labels: ground-truth labels.

    Returns:
        ``(accuracy, auprc, fpr, tpr)``. AUPRC is sklearn's average precision.
        ``fpr``/``tpr`` come from ``roc_curve``. Both use column 1 of
        ``predict_proba`` (class 1 when the labels are 0/1).
    """
    print("[*] Evaluating the classifier")
    predictions = classifier.predict(test_activations)
    accuracy = accuracy_score(test_labels, predictions)

    # Compute the positive-class scores once and reuse them for AUPRC and ROC,
    # instead of running predict_proba twice over the same inputs.
    positive_scores = classifier.predict_proba(test_activations)[:, 1]
    auprc = average_precision_score(test_labels, positive_scores)
    fpr, tpr, _ = roc_curve(test_labels, positive_scores)

    print(f"Test accuracy: {accuracy}")
    print(f"AUPRC score: {auprc}")
    print(classification_report(test_labels, predictions))
    return accuracy, auprc, fpr, tpr

def _distance_features_and_labels(clean_activations, poisoned_activations, clean_anchors):
    """Return ``(features, labels)``: a one-column distance matrix; clean -> 0, poisoned -> 1."""
    features_clean = generate_distance_features(clean_activations, clean_anchors)
    features_poisoned = generate_distance_features(poisoned_activations, clean_anchors)
    features = np.concatenate([features_clean, features_poisoned], axis=0).reshape(-1, 1)
    labels = np.array([0] * len(features_clean) + [1] * len(features_poisoned))
    return features, labels


def train_and_evaluate_with_distance(train_files_clean, train_files_poisoned, val_files_clean, val_files_poisoned,
                                    test_files_clean, test_files_poisoned, num_layers, activations_dir, n_anchors):
    """Train and evaluate a distance-to-anchor classifier for one layer selection.

    The single feature per sample is its mean Euclidean distance to the first
    ``n_anchors`` clean training activations. A random forest is fitted on
    the training split, then reported on the validation split and evaluated
    on the test split.

    Args:
        train_files_clean, train_files_poisoned: training file lists.
        val_files_clean, val_files_poisoned: validation file lists.
        test_files_clean, test_files_poisoned: test file lists.
        num_layers: layer selection forwarded to the dataset loaders.
        activations_dir: training activations directory. The validation and
            test directories are derived from it by the loaders.
        n_anchors: maximum number of clean anchor samples to use.

    Returns:
        ``(test_accuracy, test_auprc, fpr, tpr)`` from ``evaluate_classifier``
        on the test split.

    Raises:
        ValueError: if no clean anchor samples are available. An empty anchor
            set would otherwise produce all-NaN features.
    """
    # Load training activations.
    print("Loading training activations.")
    train_clean_activations = load_activations(train_files_clean, num_layers, activations_dir)
    train_poisoned_activations = load_activations(train_files_poisoned, num_layers, activations_dir)

    # Load test activations.
    print("Loading test activations.")
    test_clean_activations = load_test_activations(test_files_clean, num_layers, activations_dir)
    test_poisoned_activations = load_test_activations(test_files_poisoned, num_layers, activations_dir)

    # Load validation activations.
    print("Loading validation activations.")
    val_clean_activations = load_val_activations(val_files_clean, num_layers, activations_dir)
    val_poisoned_activations = load_val_activations(val_files_poisoned, num_layers, activations_dir)

    # Use (up to) the requested number of clean training samples as anchors.
    n = min(len(train_clean_activations), n_anchors)
    if n <= 0:
        raise ValueError(
            f"Need at least one clean anchor sample; got n_anchors={n_anchors} "
            f"with {len(train_clean_activations)} clean training samples."
        )
    print(f"Using {n} clean anchor samples")
    # NOTE(review): the anchors are also part of the clean training set, so each
    # anchor's own training feature includes a zero self-distance.
    clean_anchors = train_clean_activations[:n]

    print("Generating training features.")
    train_features, train_labels = _distance_features_and_labels(
        train_clean_activations, train_poisoned_activations, clean_anchors)

    print("Generating validation features.")
    val_features, val_labels = _distance_features_and_labels(
        val_clean_activations, val_poisoned_activations, clean_anchors)

    classifier = train_binary_classifier(train_features, train_labels)

    # Validation metrics are printed only; the test metrics are returned.
    print("Evaluating on validation set.")
    evaluate_classifier(classifier, val_features, val_labels)

    print("Evaluating on test set.")
    test_features, test_labels = _distance_features_and_labels(
        test_clean_activations, test_poisoned_activations, clean_anchors)

    test_accuracy, test_auprc, fpr, tpr = evaluate_classifier(classifier, test_features, test_labels)
    return test_accuracy, test_auprc, fpr, tpr

if __name__ == "__main__":
    LAYERS = LAYERS_PER_MODEL[MODEL]
    # RISKS was never defined in this file, so the loop below raised NameError.
    # Derive it from ``current_risk`` (imported from task_tracker.CONFIG_1),
    # accepting either a single risk name or a collection of names.
    # NOTE(review): confirm the actual type of ``current_risk`` in CONFIG_1.
    RISKS = [current_risk] if isinstance(current_risk, str) else list(current_risk)

    for risk in RISKS:
        print(f"\n===== Processing risk: {risk} =====\n")
        # Process each dataset in turn.
        for dataset in DATASETS:
            print(f"\n===== Processing dataset: {dataset} =====\n")

            # Dataset-specific locations.
            ACTIVATION_FILE_LIST_DIR = "/guardrail/TaskTracker/store/data/" + risk + "/" + dataset
            ACTIVATIONS_DIR = '/guardrail/TaskTracker/store/activations/' + risk + "/" + dataset + '/' + MODEL + '/training'
            ACTIVATIONS_VAL_DIR = '/guardrail/TaskTracker/store/activations/' + risk + "/" + dataset + '/' + MODEL + '/validation'
            OOD_POISONED_FILE = "/guardrail/TaskTracker/store/output_datasets/" + risk + "/" + dataset + '/dataset_poisoned_val.json'
            # Configuration for the current dataset.
            config = {
                'activations': ACTIVATIONS_DIR,
                'activations_ood': ACTIVATIONS_VAL_DIR,
                'ood_poisoned_file': OOD_POISONED_FILE,
                'exp_name': f'distance_based_classification_{MODEL}_{dataset}',
            }

            # Output directory for the current dataset.
            OUTPUT_DIR = f"/guardrail/TaskTracker/store/model/{risk}/{dataset}/{MODEL}"
            os.makedirs(OUTPUT_DIR, exist_ok=True)

            # Load the file lists once; they are shared by every anchor count and layer.
            train_files_clean = load_file_paths(os.path.join(ACTIVATION_FILE_LIST_DIR, f'train_clean_files_{MODEL}.txt'))
            train_files_poisoned = load_file_paths(os.path.join(ACTIVATION_FILE_LIST_DIR, f'train_poisoned_files_{MODEL}.txt'))
            val_files_clean = load_file_paths(os.path.join(ACTIVATION_FILE_LIST_DIR, f'val_clean_files_{MODEL}.txt'))
            val_files_poisoned = load_file_paths(os.path.join(ACTIVATION_FILE_LIST_DIR, f'val_poisoned_files_{MODEL}.txt'))
            test_files_clean = load_file_paths(os.path.join(ACTIVATION_FILE_LIST_DIR, f'test_clean_files_{MODEL}.txt'))
            test_files_poisoned = load_file_paths(os.path.join(ACTIVATION_FILE_LIST_DIR, f'test_poisoned_files_{MODEL}.txt'))

            print(f"Training with {len(train_files_clean)} clean files and {len(train_files_poisoned)} poisoned files.")
            print(f"Evaluating with {len(val_files_clean)} clean files and {len(val_files_poisoned)} poisoned files.")

            # Sweep over the configured anchor-sample counts.
            for n_anchors in ANCHOR_SAMPLES:
                print(f"\n----- Using {n_anchors} anchor samples -----\n")

                # The anchor directory does not depend on the layer; create it once.
                anchor_output_dir = os.path.join(OUTPUT_DIR, f"anchors_{n_anchors}")
                os.makedirs(anchor_output_dir, exist_ok=True)

                # Per-layer results for this dataset and anchor count.
                results = []

                for n_layer in LAYERS:
                    print(f"[*] Processing the {n_layer}-th activation layer with {n_anchors} anchors.")
                    layer_output_dir = os.path.join(anchor_output_dir, str(n_layer))
                    os.makedirs(layer_output_dir, exist_ok=True)

                    _config = config.copy()
                    _config["num_layers"] = n_layer
                    _config["n_anchors"] = n_anchors
                    _config["exp_name"] = f"{config['exp_name']}_{n_layer}_anchors_{n_anchors}"
                    print(_config["exp_name"])
                    with open(os.path.join(layer_output_dir, 'config.json'), 'w') as f:
                        json.dump(_config, f)

                    # Train and evaluate, passing the activations dir and anchor count.
                    test_accuracy, test_auprc, fpr, tpr = train_and_evaluate_with_distance(
                        train_files_clean, train_files_poisoned,
                        val_files_clean, val_files_poisoned,
                        test_files_clean, test_files_poisoned,
                        num_layers=(n_layer, n_layer),
                        activations_dir=ACTIVATIONS_DIR,
                        n_anchors=n_anchors
                    )

                    # Record the results (numpy arrays converted for JSON).
                    results.append({
                        'layer': n_layer,
                        'anchors': n_anchors,
                        'test_accuracy': test_accuracy,
                        'test_auprc': test_auprc,
                        'fpr': fpr.tolist(),
                        'tpr': tpr.tolist()
                    })

                # Save the results for this dataset and anchor count.
                results_file = os.path.join(OUTPUT_DIR, f"{risk}_{dataset}_{MODEL}_results_anchors_{n_anchors}.json")
                with open(results_file, 'w', encoding='utf-8') as f:
                    json.dump(results, f, ensure_ascii=False, indent=4)
                print(f"Results for {dataset} with {n_anchors} anchors saved to {results_file}")