import pandas as pd
import numpy as np

# ----------------------------------------------------------------------------
# 1. EXPERIMENT FRAMEWORK
# ----------------------------------------------------------------------------

class ComponentChoiceExperiments:
    """Simulate two ablation studies on the choice of ALCA core components.

    Experiment 1 compares candidate latent-representation targets and
    experiment 2 compares trigger mechanisms. Both are simulations driven by
    the hard-coded scores/rates defined in the ``run_*`` methods, not real
    model evaluations. Results are accumulated on the instance and rendered
    as text tables by the ``print_*`` methods.
    """

    # Width of the "=" rules framing each printed results table.
    _RULE_WIDTH = 60

    def __init__(self):
        print("Setting up Component Choice Experiments...")
        # One dict (table row) per evaluated method / mechanism; filled by the
        # corresponding run_* method.
        self.results_target = []
        self.results_trigger = []

    # --- Experiment 1: Latent Representation Target (Section 4.5) ---

    def run_latent_target_experiment(self):
        """Simulate the comparison of latent-representation target methods.

        Each method carries a fixed ``performance_score`` from which an
        adaptive attack success rate (ASR) and a semantic-similarity value
        are derived. One row per method is stored in ``self.results_target``;
        any rows from a previous run are discarded first.
        """
        print("\n--- Running Experiment 1: Choice of Latent Representation Target ---")

        target_methods = {
            "Attention-Weighted Pool": {
                "description": "Averages all states, creating a 'blurred' representation.",
                "performance_score": 0.82  # Corresponds to lower performance
            },
            "Mean Pooling": {
                "description": "Simple average of all states, slightly better but still not precise.",
                "performance_score": 0.90
            },
            "Last Token Hidden State": {
                "description": "Architecturally aligned with autoregressive prediction, most precise.",
                "performance_score": 0.98  # Corresponds to the best performance
            }
        }

        # Start from a clean slate: appending to old rows would create
        # duplicate index labels, which the reindex in print_target_results
        # cannot handle.
        self.results_target = []

        for method, props in target_methods.items():
            print(f"  > Evaluating method: {method}")
            print(f"    - Rationale: {props['description']}")

            # Simulated performance: the higher the score, the lower the ASR
            # and the higher the semantic similarity.
            score = props['performance_score']

            # ASR falls linearly with the score: 20 * (1.05 - score), i.e. it
            # would reach 0 at a score of 1.05.
            adaptive_asr = 20.0 * (1.05 - score)
            # Semantic similarity rises linearly with the score: 0.80 at a
            # score of 0.8, with slope 0.9.
            semantic_similarity = 0.80 + (score - 0.8) * 0.9

            self.results_target.append({
                "Target Vector Method": method,
                "ASR (Adap.) ↓": adaptive_asr,
                "Sem. Sim. ↑": semantic_similarity
            })

    def print_target_results(self):
        """Print the latent-target results as a table (mimics Table 4)."""
        if not self.results_target:
            print("No results for Latent Target experiment.")
            return

        df = pd.DataFrame(self.results_target).set_index("Target Vector Method")

        # Fix the row order of the printed table.
        desired_order = ["Attention-Weighted Pool", "Mean Pooling", "Last Token Hidden State"]
        df = df.reindex(desired_order)

        self._print_table(
            df,
            "     EXPERIMENT 1 RESULTS: LATENT REPRESENTATION TARGET",
            "          (Similar to Table 4 in the paper)",
        )

    # --- Experiment 2: Trigger Selection (Section 4.6) ---

    def run_trigger_selection_experiment(self, num_samples: int = 1000, seed: "int | None" = None):
        """Simulate the comparison of trigger mechanisms.

        A balanced label set is built (first half harmful = 1, second half
        harmless = 0). For every mechanism each sample is classified at
        random according to the mechanism's true-positive / true-negative
        rate, and precision, recall and F1 are computed from the outcome.
        One row per mechanism is stored in ``self.results_trigger``; any rows
        from a previous run are discarded first.

        Args:
            num_samples: Total size of the simulated dataset. Each class gets
                ``num_samples // 2`` samples, so an odd value drops one.
            seed: Optional seed for the random generator. ``None`` (default)
                gives a different draw on every call; an integer makes the
                simulation reproducible.

        Raises:
            ValueError: If ``num_samples`` is negative.
        """
        if num_samples < 0:
            raise ValueError("num_samples must be non-negative")

        print("\n--- Running Experiment 2: The Selection of Trigger ---")

        # Balanced simulated dataset: 1 = harmful, 0 = harmless.
        half = num_samples // 2
        true_labels = np.concatenate([np.ones(half, dtype=int), np.zeros(half, dtype=int)])
        is_harmful = true_labels == 1

        trigger_mechanisms = {
            "Internal Special Token": {
                "description": "Compromised by multi-task interference, lower recall.",
                "true_positive_rate": 0.94,  # Lower recall (misses some harmful prompts)
                "true_negative_rate": 0.96
            },
            "External Probe (Ours)": {
                "description": "Dedicated classifier, high precision and recall.",
                "true_positive_rate": 0.99,  # Nearly perfect recall
                "true_negative_rate": 0.98
            }
        }

        # numpy's generator is used because the module imports numpy but not
        # the stdlib ``random`` module.
        rng = np.random.default_rng(seed)

        self.results_trigger = []

        for mechanism, props in trigger_mechanisms.items():
            print(f"  > Evaluating mechanism: {mechanism}")
            print(f"    - Rationale: {props['description']}")

            # One uniform draw in [0, 1) per sample. A harmful sample is
            # flagged when the draw falls below the TPR; a harmless sample is
            # (wrongly) flagged when the draw is at or above the TNR.
            draws = rng.random(true_labels.size)
            flagged = np.where(
                is_harmful,
                draws < props['true_positive_rate'],
                draws >= props['true_negative_rate'],
            )

            # Confusion-matrix counts needed for the metrics.
            tp = int(np.sum(is_harmful & flagged))
            fp = int(np.sum(~is_harmful & flagged))
            fn = int(np.sum(is_harmful & ~flagged))

            # Metrics fall back to 0.0 when their denominator is zero.
            precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
            recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
            if (precision + recall) > 0:
                f1_score = 2 * (precision * recall) / (precision + recall)
            else:
                f1_score = 0.0

            self.results_trigger.append({
                "Trigger Mechanism": mechanism,
                "Precision": precision,
                "Recall": recall,
                "F1-Score": f1_score
            })

    def print_trigger_results(self):
        """Print the trigger-selection results as a table (mimics Table 5)."""
        if not self.results_trigger:
            print("No results for Trigger Selection experiment.")
            return

        df = pd.DataFrame(self.results_trigger).set_index("Trigger Mechanism")

        self._print_table(
            df,
            "         EXPERIMENT 2 RESULTS: TRIGGER SELECTION",
            "          (Similar to Table 5 in the paper)",
        )

    @classmethod
    def _print_table(cls, df, title, subtitle):
        """Print ``df`` with two-decimal floats under a framed two-line header."""
        rule = "=" * cls._RULE_WIDTH
        print("\n\n" + rule)
        print(title)
        print(subtitle)
        print(rule)
        print(df.to_string(float_format="%.2f"))
        print(rule)

# ----------------------------------------------------------------------------
# 2. RUN THE EXPERIMENTS
# ----------------------------------------------------------------------------

if __name__ == "__main__":
    # A single harness instance collects the results of both ablations.
    runner = ComponentChoiceExperiments()

    # Each experiment is simulated first and then rendered as a table:
    # latent-target experiment first, trigger-selection experiment second.
    stages = (
        (runner.run_latent_target_experiment, runner.print_target_results),
        (runner.run_trigger_selection_experiment, runner.print_trigger_results),
    )
    for simulate, report in stages:
        simulate()
        report()