import torch
import numpy as np
from sklearn.decomposition import PCA
from sklearn.ensemble import IsolationForest
from sklearn.covariance import EllipticEnvelope
from sklearn.svm import OneClassSVM
from sklearn.neighbors import LocalOutlierFactor
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, confusion_matrix
from sklearn.model_selection import train_test_split
import os
import json
import random
import matplotlib.pyplot as plt
import seaborn as sns

# ============================================================
# 1. Data loading (reused from existing code)
# ============================================================

def load_tensor_from_pt(file_path, key=None):
    """Load the contents of a ``.pt`` file onto the CPU as float data.

    A top-level list is treated as a sequence of rounds, each round being an
    iterable of tensors: every round is stacked into a single tensor and the
    per-round tensors are stacked again. A tensor (either loaded directly or
    produced by that stacking) is cast with ``.float()``. Any other loaded
    object is passed through untouched.

    Args:
        file_path: Path handed to ``torch.load``.
        key: Optional index/key applied to the loaded object before returning.

    Returns:
        The loaded (possibly stacked and cast) object, or ``obj[key]`` when
        ``key`` is not ``None``.
    """
    # NOTE(review): torch.load deserializes the file; only point it at trusted files.
    loaded = torch.load(file_path, map_location='cpu')

    if isinstance(loaded, list):
        stacked_rounds = []
        for round_acts in loaded:
            stacked_rounds.append(torch.stack([act.float() for act in round_acts]))
        loaded = torch.stack(stacked_rounds)

    if isinstance(loaded, torch.Tensor):
        loaded = loaded.float()

    return loaded if key is None else loaded[key]


def load_all_data(root_dir="agent_graph_dataset/memory_attack", n_tasks=20):
    """Load round-0 activations and per-agent attacker labels for every run.

    Visits the run directories ``train_n8_s{02,04,06,08,10}_a{1,2,3}`` under
    ``root_dir``. In each one, the first ``.json`` file (in sorted name order)
    provides the task list, and ``activations/sample_XXXX.pt`` provides the
    activations of task ``XXXX``. Labels and activations are appended
    pair-wise per task so the two returned arrays always stay aligned.

    Args:
        root_dir: Directory containing the ``train_n8_*`` run directories.
        n_tasks: Number of tasks (sample files) read from each run directory.
            If the label file lists more tasks, only the first ``n_tasks``
            are used.

    Returns:
        Tuple ``(X, labels)``: ``X`` stacks one ``x[0, :, -1, :]`` slice per
        task (described as ``(N, A, H)`` by the original comments) and
        ``labels`` is an integer array of shape ``(N, A, 1)`` holding 1 for
        every index listed in the task's ``"attacker_idxes"`` and 0 elsewhere.

    Raises:
        FileNotFoundError: If a run directory contains no ``.json`` file.
        ValueError: If a label file lists fewer than ``n_tasks`` tasks.
    """
    X = []
    labels = []

    for s in range(2, 11, 2):
        for a in range(1, 4):
            run_dir = os.path.join(root_dir, f"train_n8_s{s:02d}_a{a}")
            activation_dir = os.path.join(run_dir, "activations")

            # os.listdir returns entries in arbitrary order; sort so that the
            # label file picked is deterministic when several .json files exist.
            json_names = sorted(f for f in os.listdir(run_dir) if f.endswith('.json'))
            if not json_names:
                raise FileNotFoundError(f"No .json label file found in {run_dir}")
            with open(os.path.join(run_dir, json_names[0]), "r", encoding="utf-8") as f:
                tasks = json.load(f)

            if len(tasks) < n_tasks:
                raise ValueError(
                    f"{run_dir}: label file lists {len(tasks)} tasks, "
                    f"expected at least {n_tasks}"
                )

            # Walk labels and activations together so sample i and label i
            # always describe the same task.
            for task_id, task in enumerate(tasks[:n_tasks]):
                file_path = os.path.join(activation_dir, f"sample_{task_id:04d}.pt")
                x = load_tensor_from_pt(file_path)  # (R, A, L, H) per the original comment
                x = x[0, :, -1, :]  # (A, H): first round, last position
                if isinstance(x, torch.Tensor):
                    x = x.numpy()

                # One label row per agent present in the activation slice.
                label = np.zeros((x.shape[0], 1), dtype=int)
                for idx in task["attacker_idxes"]:
                    label[idx, 0] = 1

                X.append(x)
                labels.append(label)

    X = np.stack(X, axis=0)  # (N, A, H)
    labels = np.stack(labels, axis=0)  # (N, A, 1)

    return X, labels


def preprocess_data_one_class(X, labels):
    """Split per-agent feature vectors into normal and attacker groups.

    ``X`` must be 3-dimensional; its first two axes are flattened so each row
    is one feature vector, and ``labels`` is flattened to match. Rows whose
    label equals 0 form the normal set and rows whose label equals 1 form the
    attack set. A short statistics summary is printed.

    Args:
        X: Array of shape ``(N, A, H)``.
        labels: Array with ``N * A`` entries in total (0 = normal, 1 = attack).

    Returns:
        Tuple ``(X_normal, X_attack)`` of 2-D arrays with ``H`` columns.
    """
    _, _, feat_dim = X.shape

    rows = X.reshape(-1, feat_dim)  # (N*A, H)
    row_labels = labels.reshape(-1)  # (N*A,)

    normal_rows = rows[row_labels == 0]
    attack_rows = rows[row_labels == 1]

    total = len(rows)
    n_normal = len(normal_rows)
    n_attack = len(attack_rows)
    rule = '=' * 60

    print(f"\n{rule}")
    print("Data Statistics:")
    print(rule)
    print(f"Total samples:    {total}")
    print(f"Normal samples:   {n_normal} ({n_normal/total*100:.1f}%)")
    print(f"Attack samples:   {n_attack} ({n_attack/total*100:.1f}%)")
    print(f"Feature dim:      {feat_dim}")
    print(f"{rule}\n")

    return normal_rows, attack_rows


# ============================================================
# 2. Baseline methods
# ============================================================

class BaselineDetector:
    """Common base for the baseline anomaly detectors.

    Handles the optional PCA reduction and delegates the actual model to the
    two hooks ``_fit_detector`` and ``_predict_scores``, which subclasses
    must implement.
    """

    def __init__(self, name, use_pca=True, n_components=50):
        self.name = name                  # display name of the method
        self.use_pca = use_pca            # whether to reduce with PCA first
        self.n_components = n_components  # upper bound on PCA components
        self.pca = None                   # set by fit() when use_pca is True
        self.detector = None              # for subclasses to store their model

    def fit(self, X_normal):
        """Fit (optionally PCA, then) the detector on normal samples; returns self."""
        features = X_normal
        if self.use_pca:
            # Cap the component count by what the data can support.
            n_samples, n_features = X_normal.shape[0], X_normal.shape[1]
            k = min(self.n_components, n_samples - 1, n_features)
            self.pca = PCA(n_components=k)
            features = self.pca.fit_transform(X_normal)
            print(f"  PCA: {X_normal.shape[1]} -> {features.shape[1]} dims")

        self._fit_detector(features)
        return self

    def predict_scores(self, X):
        """Return anomaly scores for ``X`` (larger means more anomalous)."""
        features = self.pca.transform(X) if self.use_pca else X
        return self._predict_scores(features)

    def predict(self, X, threshold):
        """Return 0/1 labels: 1 where the anomaly score exceeds ``threshold``."""
        return (self.predict_scores(X) > threshold).astype(int)

    def _fit_detector(self, X):
        """Hook: fit the underlying model on the (possibly reduced) features."""
        raise NotImplementedError

    def _predict_scores(self, X):
        """Hook: score the (possibly reduced) features."""
        raise NotImplementedError


class IsolationForestDetector(BaselineDetector):
    """Baseline detector backed by scikit-learn's ``IsolationForest``."""

    def __init__(self, use_pca=True, n_components=50):
        super().__init__("Isolation Forest", use_pca=use_pca, n_components=n_components)

    def _fit_detector(self, X):
        """Create a fresh forest with fixed settings and fit it on ``X``."""
        forest_params = {
            'contamination': 0.1,
            'random_state': 42,
            'n_jobs': -1,
        }
        self.detector = IsolationForest(**forest_params)
        self.detector.fit(X)

    def _predict_scores(self, X):
        """Return the negated ``score_samples`` output.

        The sign is flipped so that larger values mean more anomalous.
        """
        raw_scores = self.detector.score_samples(X)
        return -raw_scores


class EllipticEnvelopeDetector(BaselineDetector):
    """Covariance-based baseline using scikit-learn's ``EllipticEnvelope``."""

    def __init__(self, use_pca=True, n_components=50):
        super().__init__("Elliptic Envelope", use_pca=use_pca, n_components=n_components)

    def _fit_detector(self, X):
        """Create a fresh envelope with fixed settings and fit it on ``X``."""
        envelope_params = {'contamination': 0.1, 'random_state': 42}
        self.detector = EllipticEnvelope(**envelope_params)
        self.detector.fit(X)

    def _predict_scores(self, X):
        """Return the negated ``score_samples`` output.

        The original comment describes this as a Mahalanobis distance; the
        sign flip makes larger values mean more anomalous.
        """
        raw_scores = self.detector.score_samples(X)
        return -raw_scores


class OneClassSVMDetector(BaselineDetector):
    """Baseline detector backed by scikit-learn's ``OneClassSVM``."""

    def __init__(self, use_pca=True, n_components=50):
        super().__init__("One-Class SVM", use_pca=use_pca, n_components=n_components)

    def _fit_detector(self, X):
        """Create a fresh RBF one-class SVM with fixed settings and fit it on ``X``."""
        svm_params = {'kernel': 'rbf', 'gamma': 'auto', 'nu': 0.1}
        self.detector = OneClassSVM(**svm_params)
        self.detector.fit(X)

    def _predict_scores(self, X):
        """Return the negated ``score_samples`` output.

        The sign is flipped so that larger values mean more anomalous.
        """
        raw_scores = self.detector.score_samples(X)
        return -raw_scores


class LOFDetector(BaselineDetector):
    """Baseline detector backed by scikit-learn's ``LocalOutlierFactor``."""

    def __init__(self, use_pca=True, n_components=50):
        super().__init__("LOF", use_pca=use_pca, n_components=n_components)

    def _fit_detector(self, X):
        """Create a fresh LOF model with fixed settings and fit it on ``X``."""
        lof_params = {
            'n_neighbors': 20,
            'contamination': 0.1,
            'novelty': True,  # important: needed to score samples not seen in fit
            'n_jobs': -1,
        }
        self.detector = LocalOutlierFactor(**lof_params)
        self.detector.fit(X)

    def _predict_scores(self, X):
        """Return the negated ``score_samples`` output.

        The sign is flipped so that larger values mean more anomalous.
        """
        raw_scores = self.detector.score_samples(X)
        return -raw_scores


class SimpleDistanceDetector(BaselineDetector):
    """Simple distance baseline: distance to the mean of the normal samples."""

    def __init__(self, use_pca=True, n_components=50):
        super().__init__("Distance to Center", use_pca=use_pca, n_components=n_components)
        self.center = None  # mean of the training samples, set by _fit_detector

    def _fit_detector(self, X):
        """Store the per-feature mean of ``X`` as the center."""
        self.center = X.mean(axis=0)

    def _predict_scores(self, X):
        """Return each row's Euclidean distance to the stored center."""
        offsets = X - self.center
        return np.linalg.norm(offsets, axis=1)


# ============================================================
# 3. Evaluation
# ============================================================

def evaluate_detector(detector, X_normal_train, X_normal_val, X_attack_test):
    """Fit one detector on normal data and evaluate it on normal + attack data.

    The detector is fitted on ``X_normal_train``. The decision threshold is
    the 95th percentile of its scores on ``X_normal_val``. Metrics are then
    computed on the union of ``X_normal_val`` (true label 0) and
    ``X_attack_test`` (true label 1), and a report is printed.

    Args:
        detector: Object exposing ``name``, ``fit`` and ``predict_scores``.
        X_normal_train: Normal samples used for fitting.
        X_normal_val: Held-out normal samples (threshold + negatives).
        X_attack_test: Attack samples (positives).

    Returns:
        Dict with keys ``name``, ``accuracy``, ``precision``, ``recall``,
        ``f1``, ``auroc``, ``threshold``, ``confusion_matrix``,
        ``val_scores`` and ``test_scores``.
    """
    rule = '=' * 60
    print(f"\n{rule}")
    print(f"Training: {detector.name}")
    print(rule)

    detector.fit(X_normal_train)

    # Anomaly scores for the held-out normal set and the attack set
    normal_scores = detector.predict_scores(X_normal_val)
    attack_scores = detector.predict_scores(X_attack_test)

    # Threshold: 95th percentile of the scores on held-out normal data
    cutoff = np.percentile(normal_scores, 95)
    print(f"  Threshold (95th percentile): {cutoff:.4f}")

    normal_flags = (normal_scores > cutoff).astype(int)
    attack_flags = (attack_scores > cutoff).astype(int)

    # Ground truth: validation rows are all normal, test rows are all attacks
    y_true = np.concatenate([np.zeros(len(normal_flags)), np.ones(len(attack_flags))])
    y_pred = np.concatenate([normal_flags, attack_flags])
    y_scores = np.concatenate([normal_scores, attack_scores])

    results = {
        'name': detector.name,
        'accuracy': accuracy_score(y_true, y_pred),
        'precision': precision_score(y_true, y_pred, zero_division=0),
        'recall': recall_score(y_true, y_pred, zero_division=0),
        'f1': f1_score(y_true, y_pred, zero_division=0),
        'auroc': roc_auc_score(y_true, y_scores),
        'threshold': cutoff,
        'confusion_matrix': confusion_matrix(y_true, y_pred),
        'val_scores': normal_scores,
        'test_scores': attack_scores,
    }

    # Metric report
    print("\n  Results:")
    metric_rows = [
        ("  ├─ Accuracy:   ", 'accuracy'),
        ("  ├─ Precision:  ", 'precision'),
        ("  ├─ Recall:     ", 'recall'),
        ("  ├─ F1 Score:   ", 'f1'),
        ("  └─ AUROC:      ", 'auroc'),
    ]
    for prefix, metric_key in metric_rows:
        print(f"{prefix}{results[metric_key]:.4f}")

    # Confusion-matrix breakdown (only when both classes are represented)
    cm = results['confusion_matrix']
    if cm.shape == (2, 2):
        captions = [
            "  ├─ TN (Normal correct):  ",
            "  ├─ FP (Normal as Attack): ",
            "  ├─ FN (Attack as Normal): ",
            "  └─ TP (Attack correct):  ",
        ]
        print("\n  Confusion Matrix:")
        for caption, count in zip(captions, cm.ravel()):
            print(f"{caption}{count}")

    return results


# ============================================================
# 4. Visualization
# ============================================================

def plot_results(all_results, save_path='baseline_comparison.png'):
    """Draw a 2x2 comparison figure for all evaluated detectors and save it.

    Panels: (1) grouped bars of accuracy/precision/recall/F1, (2) AUROC bars
    with value annotations, (3) box plots of the normal vs. attack score
    distributions per method, (4) confusion-matrix heatmap of the method with
    the highest AUROC.

    Args:
        all_results: List of result dicts as returned by ``evaluate_detector``
            (uses the keys ``name``, ``accuracy``, ``precision``, ``recall``,
            ``f1``, ``auroc``, ``val_scores``, ``test_scores`` and
            ``confusion_matrix``).
        save_path: File the figure is written to.

    Returns:
        None. When ``all_results`` is empty nothing is drawn or saved.
    """
    # Nothing to draw; np.argmax below would raise on an empty list.
    if not all_results:
        print("No results to plot.")
        return

    fig, axes = plt.subplots(2, 2, figsize=(14, 10))

    methods = [r['name'] for r in all_results]
    x = np.arange(len(methods))

    # 1. Threshold-based metrics, grouped per method
    ax = axes[0, 0]
    metrics = ['accuracy', 'precision', 'recall', 'f1']
    width = 0.2

    for i, metric in enumerate(metrics):
        values = [r[metric] for r in all_results]
        ax.bar(x + i*width, values, width, label=metric.capitalize())

    ax.set_xlabel('Method')
    ax.set_ylabel('Score')
    ax.set_title('Performance Comparison')
    ax.set_xticks(x + width * 1.5)
    ax.set_xticklabels(methods, rotation=45, ha='right')
    ax.legend()
    ax.grid(axis='y', alpha=0.3)
    ax.set_ylim([0, 1])

    # 2. AUROC comparison
    ax = axes[0, 1]
    aurocs = [r['auroc'] for r in all_results]
    colors = plt.cm.viridis(np.linspace(0, 1, len(methods)))
    # Numeric positions + explicit ticks: set_xticklabels requires fixed tick
    # locations, which a categorical (string) bar axis does not provide.
    bars = ax.bar(x, aurocs, color=colors)
    ax.set_ylabel('AUROC')
    ax.set_title('AUROC Comparison')
    ax.set_xticks(x)
    ax.set_xticklabels(methods, rotation=45, ha='right')
    ax.grid(axis='y', alpha=0.3)
    ax.set_ylim([0, 1])

    # Annotate each bar with its value
    for bar, auroc in zip(bars, aurocs):
        height = bar.get_height()
        ax.text(bar.get_x() + bar.get_width()/2., height,
                f'{auroc:.3f}',
                ha='center', va='bottom', fontsize=9)

    # 3. Score distributions (box plots)
    ax = axes[1, 0]
    box_data = []
    tick_labels = []
    for r in all_results:
        box_data.extend([r['val_scores'], r['test_scores']])
        tick_labels.extend([f"{r['name']}\n(Normal)", f"{r['name']}\n(Attack)"])

    # The boxplot ``labels=`` keyword was renamed ``tick_labels`` in
    # Matplotlib 3.9; labelling the ticks afterwards works on every version.
    bp = ax.boxplot(box_data, patch_artist=True)
    ax.set_xticklabels(tick_labels)

    # Alternate colors: even boxes are Normal, odd boxes are Attack
    for i, patch in enumerate(bp['boxes']):
        if i % 2 == 0:
            patch.set_facecolor('lightblue')
        else:
            patch.set_facecolor('lightcoral')

    ax.set_ylabel('Anomaly Score')
    ax.set_title('Score Distribution')
    ax.tick_params(axis='x', rotation=90)
    ax.grid(axis='y', alpha=0.3)

    # 4. Confusion-matrix heatmap of the method with the best AUROC
    ax = axes[1, 1]
    best_idx = np.argmax(aurocs)
    best_result = all_results[best_idx]

    cm = best_result['confusion_matrix']
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=ax,
                xticklabels=['Normal', 'Attack'],
                yticklabels=['Normal', 'Attack'])
    ax.set_title(f'Confusion Matrix: {best_result["name"]}\n(Best AUROC)')
    ax.set_ylabel('True Label')
    ax.set_xlabel('Predicted Label')

    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)
    print(f"\n{'='*60}")
    print(f"Plot saved to: {save_path}")
    print(f"{'='*60}\n")


# ============================================================
# 5. Main entry point
# ============================================================

def main():
    """Run every baseline detector on the round-0 data and report the results.

    Steps: seed the RNGs, load the data, split it into normal/attack samples,
    hold out 20% of the normal samples for validation, evaluate each detector,
    print a summary table plus the best method by AUROC, and save the
    comparison plot. If no detector finishes successfully, a message is
    printed and the summary and plot are skipped.
    """
    # Fix the random seeds
    random.seed(42)
    np.random.seed(42)

    print("\n" + "="*60)
    print("BASELINE ANOMALY DETECTION METHODS")
    print("="*60)

    # 1. Load the data
    print("\nLoading data from Round 0...")
    X, labels = load_all_data()
    print(f"X shape: {X.shape}")
    print(f"labels shape: {labels.shape}")

    # 2. Separate normal and attacker samples
    X_normal, X_attack = preprocess_data_one_class(X, labels)

    # 3. Split: normal samples into train/validation, all attacks form the test set
    X_normal_train, X_normal_val = train_test_split(
        X_normal, test_size=0.2, random_state=42
    )
    X_attack_test = X_attack

    print(f"Training on:   {X_normal_train.shape[0]} normal samples")
    print(f"Validating on: {X_normal_val.shape[0]} normal samples")
    print(f"Testing on:    {X_attack_test.shape[0]} attack samples")

    # 4. Define all baseline methods
    detectors = [
        IsolationForestDetector(use_pca=True, n_components=50),
        EllipticEnvelopeDetector(use_pca=True, n_components=50),
        OneClassSVMDetector(use_pca=True, n_components=50),
        LOFDetector(use_pca=True, n_components=50),
        SimpleDistanceDetector(use_pca=True, n_components=50),
    ]

    # 5. Evaluate every method
    all_results = []
    for detector in detectors:
        try:
            results = evaluate_detector(
                detector,
                X_normal_train,
                X_normal_val,
                X_attack_test
            )
        except Exception as e:
            # Best-effort: one failing baseline must not abort the comparison.
            print(f"\n  ✗ Error: {e}")
        else:
            all_results.append(results)

    # Every detector failed: np.argmax on an empty list would raise below.
    if not all_results:
        print("\nNo detector finished successfully; nothing to summarize or plot.")
        return

    # 6. Summary table
    print("\n" + "="*60)
    print("SUMMARY OF ALL METHODS")
    print("="*60)
    print(f"\n{'Method':<25} {'Accuracy':<10} {'Precision':<10} {'Recall':<10} {'F1':<10} {'AUROC':<10}")
    print("-"*85)

    for r in all_results:
        print(f"{r['name']:<25} {r['accuracy']:<10.4f} {r['precision']:<10.4f} "
              f"{r['recall']:<10.4f} {r['f1']:<10.4f} {r['auroc']:<10.4f}")

    # 7. Pick the best method by AUROC
    best_idx = np.argmax([r['auroc'] for r in all_results])
    best_result = all_results[best_idx]

    print("\n" + "="*60)
    print(f"🏆 BEST METHOD: {best_result['name']}")
    print("="*60)
    print(f"AUROC: {best_result['auroc']:.4f}")
    print(f"Accuracy: {best_result['accuracy']:.4f}")
    print(f"F1 Score: {best_result['f1']:.4f}")
    print("="*60 + "\n")

    # 8. Visualization
    plot_results(all_results, save_path='baseline_comparison.png')

# Run the baseline comparison only when this file is executed as a script.
if __name__ == "__main__":
    main()