import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# ----------------------------------------------------------------------------
# 1. 模拟核心组件 (MOCK COMPONENTS)
# ----------------------------------------------------------------------------

class MockAblationModel:
    """Simulated model whose metrics evolve with its training configuration.

    Two internal scalars, both kept within ``[0.0, 1.0]``, stand in for what
    a real model would learn:

    * ``semantic_scaffolding`` -- structured defensive knowledge; it only
      grows when the config enables ``use_latent_loss``.
    * ``info_completeness`` -- retained information; it only grows when the
      config enables ``use_causal_loss``.

    Several behaviours are keyed on the exact ``name`` strings
    ``"Latent-Only"`` and ``"Causal-Only"`` rather than on the loss flags.
    """

    def __init__(self, config: dict) -> None:
        """Create an untrained model.

        Args:
            config: Mapping with the keys ``"name"``, ``"use_latent_loss"``
                and ``"use_causal_loss"``. Only ``"name"`` is read here; the
                two flags are read on every call to :meth:`train_step`.

        Raises:
            KeyError: If ``"name"`` is missing from ``config``.
        """
        self.config = config
        self.name = config["name"]

        # Internal state that simulates the learning process.
        # semantic_scaffolding: structured defensive knowledge (from L_latent).
        # info_completeness: information kept under L_causal / L_decode pressure.
        self.semantic_scaffolding = 0.0
        self.info_completeness = 0.0

    def train_step(self, step: int, total_steps: int) -> None:
        """Advance the simulated internal state by one training step.

        Args:
            step: Index of the current step; training progress is computed
                as ``step / total_steps``.
            total_steps: Planned number of training steps. Must be positive.

        Raises:
            ValueError: If ``total_steps`` is not positive.
            KeyError: If the config lacks ``"use_latent_loss"`` or
                ``"use_causal_loss"``.
        """
        # Fail with a clear message instead of a bare ZeroDivisionError (or a
        # silently negative progress value for negative totals).
        if total_steps <= 0:
            raise ValueError(
                f"total_steps must be a positive number, got {total_steps!r}"
            )
        progress = step / total_steps

        if self.config["use_latent_loss"]:
            # The latent loss teaches the defensive structure quickly, with
            # gains that shrink as training progresses. Any model other than
            # "Latent-Only" learns a little slower (0.7 instead of 0.9).
            learning_rate = 0.9 if self.name == "Latent-Only" else 0.7
            self.semantic_scaffolding = min(
                1.0,
                self.semantic_scaffolding + learning_rate * (1 - progress) / 20,
            )

        if self.config["use_causal_loss"]:
            # The causal loss preserves information, with gains that grow as
            # training progresses.
            self.info_completeness = min(
                1.0,
                self.info_completeness + 0.8 * progress / 20,
            )

        # Simulated "representational overfitting": past the halfway point the
        # Latent-Only model sheds information that decoding would need.
        # NOTE(review): info_completeness only ever grows under
        # use_causal_loss, so for a "Latent-Only" config with
        # use_causal_loss=False this subtraction is clamped at 0.0 and has no
        # visible effect -- confirm whether that is intended.
        if self.name == "Latent-Only" and progress > 0.5:
            self.info_completeness = max(0.0, self.info_completeness - 0.1)

    def get_robustness_asr(self) -> float:
        """Return the current attack success rate (ASR) in percent.

        ASR starts at 100 and falls as the model gains knowledge. A model
        named ``"Causal-Only"`` improves only through ``info_completeness``
        (at most 20 points); every other model improves through
        ``semantic_scaffolding`` (at most 95 points).
        """
        base_asr = 100.0
        # Without structural guidance, Causal-Only gains robustness slowly.
        if self.name == "Causal-Only":
            return base_asr - self.info_completeness * 20

        return base_asr - self.semantic_scaffolding * 95

    def get_auditability_sem_sim(self) -> float:
        """Return the current auditability score (semantic similarity).

        The score is a name-dependent baseline plus ``0.85`` times
        ``info_completeness``.
        """
        # Baseline similarity before any information is retained: 0.4 for
        # "Latent-Only", 0.1 for every other model.
        # NOTE(review): the original comment described the Latent-Only
        # starting point as low, yet it receives the larger baseline --
        # confirm which was intended.
        initial_sim = 0.4 if self.name == "Latent-Only" else 0.1
        return initial_sim + self.info_completeness * 0.85


# ----------------------------------------------------------------------------
# 2. 实验主框架 (EXPERIMENT FRAMEWORK)
# ----------------------------------------------------------------------------

class HybridSupervisionAblation:
    """Drive the simulated hybrid-supervision ablation study.

    Three supervision regimes ("Latent-Only", "Causal-Only", "Hybrid") are
    each simulated from scratch with ``MockAblationModel``. Their metrics are
    sampled periodically into ``training_logs`` and can then be plotted.
    """

    def __init__(self, total_training_steps=100, eval_interval=2):
        """Configure the study.

        Args:
            total_training_steps: Number of simulated training steps per
                model.
            eval_interval: Metrics are recorded on every step whose index is
                a multiple of this value. Must be at least 1.

        Raises:
            ValueError: If ``eval_interval`` is smaller than 1.
        """
        # Reject a zero/negative interval here rather than letting run() fail
        # later with a ZeroDivisionError on ``step % eval_interval``.
        if eval_interval < 1:
            raise ValueError(
                f"eval_interval must be at least 1, got {eval_interval!r}"
            )

        print("Setting up Hybrid Supervision Ablation Study...")
        self.total_steps = total_training_steps
        self.eval_interval = eval_interval

        self.training_configs = {
            "Latent-Only": {"name": "Latent-Only", "use_latent_loss": True, "use_causal_loss": False},
            "Causal-Only": {"name": "Causal-Only", "use_latent_loss": False, "use_causal_loss": True},
            "Hybrid": {"name": "Hybrid", "use_latent_loss": True, "use_causal_loss": True},
        }

        # One dict per evaluation point; filled by run().
        self.training_logs = []

    def run(self):
        """Simulate training for every configuration and record metrics.

        Each call starts from an empty log, so running the study again
        replaces the previous results instead of appending duplicate rows.
        """
        print("\nStarting simulated training and evaluation runs...")

        # Rebind (rather than clear) so a list handed out after an earlier
        # run is left untouched.
        self.training_logs = []

        for name, config in self.training_configs.items():
            print(f"--- Simulating training for: {name} ---")
            model = MockAblationModel(config)

            for step in range(self.total_steps):
                model.train_step(step, self.total_steps)

                # Only sample metrics on evaluation steps.
                if step % self.eval_interval != 0:
                    continue

                self.training_logs.append({
                    "Model": name,
                    "Training Step": step,
                    "ASR (%)": model.get_robustness_asr(),
                    "Semantic Similarity": model.get_auditability_sem_sim(),
                })
        print("Simulation finished.")

    def plot_results(self):
        """Plot the recorded metrics as two side-by-side line charts.

        The left panel shows ASR and the right panel shows semantic
        similarity, each against the training step with one line per model.
        Prints a message and returns without plotting when no data has been
        recorded.
        """
        if not self.training_logs:
            print("No data to plot.")
            return

        df = pd.DataFrame(self.training_logs)

        sns.set_theme(style="whitegrid")
        fig, axes = plt.subplots(1, 2, figsize=(16, 6))

        # (a) Robustness: attack success rate over training.
        sns.lineplot(data=df, x="Training Step", y="ASR (%)", hue="Model", ax=axes[0], linewidth=2.5)
        axes[0].set_title("(a) Evolution of Safety against Adaptive Attacks", fontsize=14)
        axes[0].set_ylabel("Attack Success Rate (ASR, %)", fontsize=12)
        axes[0].set_xlabel("Training Steps", fontsize=12)
        axes[0].legend(title="Supervision")
        axes[0].set_ylim(0, 101)

        # (b) Auditability: semantic similarity over training.
        sns.lineplot(data=df, x="Training Step", y="Semantic Similarity", hue="Model", ax=axes[1], linewidth=2.5)
        axes[1].set_title("(b) Evolution of the Honesty of Latent Reasoning", fontsize=14)
        axes[1].set_ylabel("Auditability (Semantic Similarity)", fontsize=12)
        axes[1].set_xlabel("Training Steps", fontsize=12)
        axes[1].legend(title="Supervision")
        axes[1].set_ylim(0, 1.01)

        fig.suptitle("Dissecting the Hybrid Supervision: A Tale of Synergy and Collapse", fontsize=16, y=1.02)
        plt.tight_layout()
        plt.show()

# ----------------------------------------------------------------------------
# 3. 运行实验 (RUN THE EXPERIMENT)
# ----------------------------------------------------------------------------

if __name__ == "__main__":
    # Script entry point: build the study with its default settings,
    # simulate every configuration, then display the resulting figure.
    study = HybridSupervisionAblation()
    study.run()
    study.plot_results()