import h5py
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
import torch.utils.data as data
import numpy as np
from sklearn.metrics import confusion_matrix, classification_report, accuracy_score, roc_auc_score
from torchvision import models
import os
import json
import random
from sklearn.model_selection import train_test_split
from sklearn.metrics import precision_score, recall_score, f1_score

# ============================================================
# 1. 数据加载部分
# ============================================================

def load_tensor_from_pt(file_path, key=None):
    """Load a ``.pt`` file onto the CPU and convert tensor content to float32.

    If the file holds a nested list (outer list of "rounds", each an inner list
    of tensors), both levels are stacked into a single tensor. Any resulting
    tensor is cast with ``.float()``.

    Args:
        file_path: path of the file passed to ``torch.load``.
        key: optional index applied to the loaded object before returning it.

    Returns:
        The loaded object (a float tensor when the file held a tensor or a
        nested list of tensors), or ``loaded[key]`` when ``key`` is given.
    """
    # NOTE(review): torch.load unpickles the file; only use it on trusted data.
    loaded = torch.load(file_path, map_location='cpu')

    if isinstance(loaded, list):
        # Stack the inner lists first, then stack the per-round results.
        per_round = []
        for round_acts in loaded:
            per_round.append(torch.stack([act.float() for act in round_acts]))
        loaded = torch.stack(per_round)

    if isinstance(loaded, torch.Tensor):
        loaded = loaded.float()

    return loaded if key is None else loaded[key]


def _find_label_file(run_dir):
    """Return the path of the alphabetically first ``.json`` file in *run_dir*."""
    json_names = sorted(name for name in os.listdir(run_dir) if name.endswith('.json'))
    if not json_names:
        raise FileNotFoundError(f"No .json label file found in {run_dir}")
    # Sorting makes the choice deterministic; os.listdir order is arbitrary.
    return os.path.join(run_dir, json_names[0])


def _attacker_mask(attacker_idxes, n_agents):
    """Return an (n_agents, 1) int array with 1 at every attacker index, else 0."""
    mask = np.zeros((n_agents, 1), dtype=int)
    for idx in attacker_idxes:
        mask[idx, 0] = 1
    return mask


def load_all_data(root="memory_attack", n_agents=6, tasks_per_dir=20,
                  round_idx=0, layer_idx=0):
    """Load per-agent activations and attacker labels for all run directories.

    Visits ``{root}/train_n{n_agents}_s{SS}_a{A}`` for SS in 02, 04, ..., 10
    and A in 1..3. Each directory must contain a ``.json`` label file (a list
    of task dicts with an ``attacker_idxes`` field) and an ``activations/``
    folder holding ``sample_{task_id:04d}.pt`` files.

    Label ``task_id`` in the JSON list is paired with ``sample_{task_id:04d}.pt``,
    so X and labels stay aligned even if the JSON lists more tasks than are
    loaded.

    Args:
        root: dataset root directory.
        n_agents: number of agents per task; must match the activation agent axis.
        tasks_per_dir: number of samples loaded from each directory.
        round_idx: index selected on axis 0 of each activation tensor.
        layer_idx: index selected on axis 2 of each activation tensor.

    Returns:
        X: (N, A, H) array of activations.
        labels: (N, A, 1) int array, 1 where the agent is an attacker.

    Raises:
        FileNotFoundError: a directory has no ``.json`` label file.
        ValueError: the JSON lists fewer than ``tasks_per_dir`` tasks, or an
            activation's agent count differs from ``n_agents``.
    """
    X = []
    labels = []

    for s in range(2, 11, 2):
        for a in range(1, 4):
            run_dir = os.path.join(root, f"train_n{n_agents}_s{s:02d}_a{a}")
            activation_dir = os.path.join(run_dir, "activations")

            label_path = _find_label_file(run_dir)
            with open(label_path, "r", encoding="utf-8") as f:
                tasks = json.load(f)
            if len(tasks) < tasks_per_dir:
                raise ValueError(
                    f"{label_path} lists {len(tasks)} tasks, expected at least {tasks_per_dir}"
                )

            for task_id in range(tasks_per_dir):
                labels.append(_attacker_mask(tasks[task_id]["attacker_idxes"], n_agents))

                file_path = os.path.join(activation_dir, f"sample_{task_id:04d}.pt")
                x = load_tensor_from_pt(file_path)  # (R, A, L, H) per the original layout note
                x = x[round_idx, :, layer_idx, :]  # (A, H): one round, one layer
                if isinstance(x, torch.Tensor):
                    x = x.numpy()
                if x.shape[0] != n_agents:
                    raise ValueError(
                        f"{file_path} has {x.shape[0]} agents, expected {n_agents}"
                    )
                X.append(x)

    X = np.stack(X, axis=0)  # (N, A, H)
    labels = np.stack(labels, axis=0)  # (N, A, 1)

    return X, labels


# ============================================================
# 2. 基于距离的异常检测
# ============================================================

def compute_knn_distance(embeddings, reference_embeddings, k=5):
    """Score each sample by its mean Euclidean distance to the k nearest references.

    Args:
        embeddings: (N, H) samples to score.
        reference_embeddings: (M, H) reference set of normal samples.
        k: number of nearest neighbours to average. If k > M, all M references
            are used.

    Returns:
        distances: (N,) anomaly scores; larger means farther from the references.

    Raises:
        ValueError: if ``k < 1`` or the reference set is empty (both would
            otherwise produce a NaN mean of an empty slice).
    """
    if k < 1:
        raise ValueError(f"k must be >= 1, got {k}")
    reference_embeddings = np.asarray(reference_embeddings)
    n_ref = len(reference_embeddings)
    if n_ref == 0:
        raise ValueError("reference_embeddings must contain at least one sample")
    k_eff = min(k, n_ref)

    distances = []
    for emb in embeddings:
        dists = np.linalg.norm(reference_embeddings - emb, axis=1)
        # np.partition selects the k smallest in O(M). Sorting only those k gives
        # exactly the same array as sorting all of them and slicing, so the mean
        # is numerically identical to a full sort.
        nearest = np.sort(np.partition(dists, k_eff - 1)[:k_eff])
        distances.append(np.mean(nearest))
    return np.array(distances)


def calibrate_threshold(distances, percentile=95):
    """Choose the detection cut-off as a percentile of normal-sample scores.

    Args:
        distances: anomaly scores of samples known to be normal.
        percentile: percentile in [0, 100]. For example, 95 accepts that about
            5% of normal scores fall above the threshold (false alarms).

    Returns:
        threshold: the ``percentile``-th percentile of ``distances``, as
            computed by ``np.percentile``.
    """
    return np.percentile(a=distances, q=percentile)


def evaluate_distance_based(X_normal_train, X_normal_val, X_test_attack, k=5, percentile=95):
    """Evaluate kNN-distance anomaly detection with a percentile threshold.

    Normal validation rows are scored against the normal training rows, and
    the threshold is set at the given percentile of those scores. Attacker
    rows are scored against the same reference set. A row is flagged as an
    attacker when its score is strictly above the threshold.

    NOTE(review): ``X_normal_val`` is used both to calibrate the threshold and
    as the normal half of the evaluation, so the false-positive rate reported
    here is optimistic (about ``100 - percentile`` percent by construction).

    Args:
        X_normal_train: (M, H) normal rows used as the kNN reference set.
        X_normal_val: (N1, H) normal rows used to calibrate the threshold.
        X_test_attack: (N2, H) attacker rows.
        k: number of nearest neighbours.
        percentile: threshold percentile of the validation scores.

    Returns:
        results: dict with threshold, accuracy, precision, recall, f1, auroc
            (NaN when only one class is present), the 2x2 confusion matrix,
            TN/FP/FN/TP counts, and the raw val/test distance arrays.

    Raises:
        ValueError: if ``X_normal_val`` is empty (no data to calibrate on).
    """
    if len(X_normal_val) == 0:
        raise ValueError("X_normal_val must contain at least one sample to calibrate the threshold")

    # 1. Scores of the (normal) validation rows.
    val_distances = compute_knn_distance(X_normal_val, X_normal_train, k=k)

    # 2. Threshold from the validation score distribution.
    threshold = calibrate_threshold(val_distances, percentile=percentile)

    # 3. Scores of the attacker rows.
    test_distances = compute_knn_distance(X_test_attack, X_normal_train, k=k)

    # 4. Predictions: 1 = flagged as attacker.
    val_pred = (val_distances > threshold).astype(int)  # ideally all 0
    test_pred = (test_distances > threshold).astype(int)  # ideally all 1

    # 5. Ground truth: validation rows are normal (0), attacker rows are 1.
    y_true = np.concatenate([
        np.zeros(len(val_distances), dtype=int),
        np.ones(len(test_distances), dtype=int),
    ])
    y_pred = np.concatenate([val_pred, test_pred])
    y_scores = np.concatenate([val_distances, test_distances])

    # 6. Metrics.
    acc = accuracy_score(y_true, y_pred)
    precision = precision_score(y_true, y_pred, zero_division=0)
    recall = recall_score(y_true, y_pred, zero_division=0)
    f1 = f1_score(y_true, y_pred, zero_division=0)
    # ROC AUC is undefined (and roc_auc_score raises) when only one class is present.
    if len(np.unique(y_true)) == 2:
        auroc = roc_auc_score(y_true, y_scores)
    else:
        auroc = float("nan")

    # 7. Confusion matrix. Fixing labels=[0, 1] guarantees a 2x2 matrix, so the
    # four-way unpacking works even if one class is absent.
    conf_matrix = confusion_matrix(y_true, y_pred, labels=[0, 1])
    TN, FP, FN, TP = conf_matrix.ravel()

    results = {
        'threshold': threshold,
        'accuracy': acc,
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'auroc': auroc,
        'confusion_matrix': conf_matrix,
        'TN': TN, 'FP': FP, 'FN': FN, 'TP': TP,
        'val_distances': val_distances,
        'test_distances': test_distances
    }

    return results


# ============================================================
# 3. 数据预处理
# ============================================================

def preprocess_data(X, labels):
    """Flatten (sample, agent) pairs into rows and split them by label.

    Args:
        X: (N, A, H) activations, one H-dimensional row per agent per sample.
        labels: array with N*A entries, e.g. (N, A, 1); 0 = normal, 1 = attacker.

    Returns:
        X_normal: (M1, H) rows whose label is 0.
        X_attack: (M2, H) rows whose label is 1.
        Rows with any other label value appear in neither output.
    """
    _, _, hidden = X.shape  # the unpacking also requires X to be 3-D
    rows = X.reshape(-1, hidden)
    flat_labels = labels.reshape(-1)

    X_normal, X_attack = (rows[flat_labels == value] for value in (0, 1))

    print(f"Total samples: {len(rows)}")
    print(f"Normal samples: {len(X_normal)}")
    print(f"Attack samples: {len(X_attack)}")

    return X_normal, X_attack


# ============================================================
# 4. 主函数
# ============================================================

def _print_distance_stats(title, distances):
    """Print *title* followed by the mean/std/min/max of a 1-D score array."""
    print(title)
    print(f"    Mean: {np.mean(distances):.4f}")
    print(f"    Std:  {np.std(distances):.4f}")
    print(f"    Min:  {np.min(distances):.4f}")
    print(f"    Max:  {np.max(distances):.4f}")


def _print_final_report(best_k, best_percentile, best_results):
    """Print the chosen hyperparameters, metrics, confusion matrix and score stats."""
    print("\n" + "=" * 60)
    print("Final Evaluation (Best Hyperparameters)")
    print("=" * 60)

    print(f"\nBest k: {best_k}")
    print(f"Best percentile: {best_percentile}")
    print(f"Threshold: {best_results['threshold']:.4f}")
    print("\nPerformance Metrics:")
    print(f"  Accuracy:  {best_results['accuracy']:.4f}")
    print(f"  Precision: {best_results['precision']:.4f}")
    print(f"  Recall:    {best_results['recall']:.4f}")
    print(f"  F1 Score:  {best_results['f1']:.4f}")
    print(f"  AUROC:     {best_results['auroc']:.4f}")

    print("\nConfusion Matrix:")
    print(best_results['confusion_matrix'])

    print("\nDetailed Breakdown:")
    print(f"  True Negatives  (Normal correctly identified): {best_results['TN']}")
    print(f"  False Positives (Normal misclassified):        {best_results['FP']}")
    print(f"  False Negatives (Attack missed):               {best_results['FN']}")
    print(f"  True Positives  (Attack correctly identified): {best_results['TP']}")

    print("\nDistance Distribution:")
    _print_distance_stats("  Normal samples (val):", best_results['val_distances'])
    _print_distance_stats("  Attack samples (test):", best_results['test_distances'])


def main():
    """Run the distance-based attacker-detection experiment end to end.

    Steps:
        1. Load the Round-0 activations and labels.
        2. Split the agent rows into normal and attacker rows.
        3. Split the normal rows 80/20 into a kNN reference set and a
           calibration set.
        4. Grid-search k and the threshold percentile, then print the
           configuration with the best F1.
    """
    # Seed Python and NumPy RNGs; the train/val split is seeded separately via random_state.
    random.seed(42)
    np.random.seed(42)

    print("=" * 60)
    print("Distance-Based Anomaly Detection for MAS")
    print("=" * 60)

    # 1. Load data.
    print("\n[Step 1] Loading data from Round 0 (first round)...")
    X, labels = load_all_data()
    print(f"X shape: {X.shape}")
    print(f"labels shape: {labels.shape}")

    # 2. Separate normal and attacker rows.
    print("\n[Step 2] Separating normal and attack samples...")
    X_normal, X_attack = preprocess_data(X, labels)

    # 3. Split normal rows into reference (train) and calibration (val) sets.
    print("\n[Step 3] Splitting normal samples into train/val...")
    X_normal_train, X_normal_val = train_test_split(
        X_normal, test_size=0.2, random_state=42
    )

    # 4. All attacker rows form the test set.
    X_test_attack = X_attack

    print(f"\nNormal train (reference): {X_normal_train.shape}")
    print(f"Normal val (calibration): {X_normal_val.shape}")
    print(f"Attack test: {X_test_attack.shape}")

    # 5. Hyperparameter search.
    print("\n" + "=" * 60)
    print("[Step 4] Distance-Based Anomaly Detection")
    print("=" * 60)

    # NOTE(review): the configuration is chosen by F1 on the same attacker rows
    # that are reported as the final result, so the reported metrics are
    # optimistic. An unbiased estimate would need a held-out split.
    # Start below any achievable F1 so the first configuration is always kept.
    # With best_f1 = 0 and every F1 == 0, best_results would stay None and the
    # report below would crash.
    best_f1 = float("-inf")
    best_k = None
    best_percentile = None
    best_results = None

    print("\nSearching for best hyperparameters...")
    for k in [3, 5, 7, 10]:
        for percentile in [90, 95, 99]:
            results = evaluate_distance_based(
                X_normal_train, X_normal_val, X_test_attack,
                k=k, percentile=percentile
            )

            if results['f1'] > best_f1:
                best_f1 = results['f1']
                best_k = k
                best_percentile = percentile
                best_results = results

            print(f"k={k}, percentile={percentile}: F1={results['f1']:.4f}, AUROC={results['auroc']:.4f}")

    # 6. Final report for the selected configuration.
    _print_final_report(best_k, best_percentile, best_results)

    print("\n" + "=" * 60)
    print("Done!")
    print("=" * 60)


# Run the experiment only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()