import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from dataclasses import dataclass
from typing import List, Optional, Tuple, Dict

@dataclass
class BPACConfig:
    """Hyper-parameters of the B-PAC routing algorithm.

    Fields are declared in constructor (positional) order.
    """

    # Miscoverage level; a candidate threshold is certified once its wealth
    # reaches 1 / alpha (e.g. 0.1 -> 90% confidence).
    alpha: float = 0.1
    # Risk budget: each step's payoff is epsilon minus the estimated loss, so
    # this is the tolerated average loss rate (e.g. 0.1 -> 10%).
    epsilon: float = 0.1
    # Exploration probability in [0, 1]: chance of still calling the expert
    # when the score is below the threshold. run_simulation overrides it every
    # step with rho_1 / rho_0.
    rho: float = 0.1
    # FTRL regularisation added to the running sum of squared payoffs, in [0, inf).
    beta: float = 1.0
    # Bet clipping constant in (0, 1); the bet is capped at c_clip / M_t.
    c_clip: float = 0.9
    # Number of evenly spaced candidate thresholds on [0, 1].
    num_thresholds: int = 1_001
    # Leading steps that update the model but are excluded from run_simulation's metrics.
    warm_up: int = 50
    # Exploration probability used from change_point on; also appears as the
    # (1 - rho_0) scaling of the estimated loss in BPAC.update.
    rho_0: float = 0.05
    # Exploration probability used before change_point.
    rho_1: float = 0.6
    # Step index at which run_simulation switches rho from rho_1 to rho_0.
    change_point: int = 200

def compute_step_loss(y_correct_t: int, y_hat_correct_t: int) -> float:
    """Return the single-step loss of serving the instant model's answer.

    The loss is 1 exactly when the expert is correct AND the instant model is
    wrong; every other combination costs 0.

    Args:
        y_correct_t: 1 if the expert answer is correct, else 0.
        y_hat_correct_t: 1 if the instant-model answer is correct, else 0.

    Returns:
        The loss as a float (0.0 or 1.0 for 0/1 inputs).
    """
    # (1 - y_hat_correct_t) flags an instant-model mistake; it only counts
    # when the expert would have answered correctly.
    return float(y_correct_t * (1 - y_hat_correct_t))

class BPAC:
    """Online selector of the routing threshold ``u`` (B-PAC).

    A query whose uncertainty score is >= ``u`` always goes to the expert.
    Below ``u`` the instant model is used, except that the expert is still
    consulted with probability ``cfg.rho``. This makes the loss observable for
    inverse-propensity weighting.

    One betting wealth process is kept per candidate threshold on an evenly
    spaced grid over [0, 1]. After each update the active threshold becomes the
    largest candidate such that it and every smaller candidate have wealth
    >= 1 / alpha. If even the first candidate fails, it falls back to u = 0.

    Attributes:
        cfg: the configuration object (stored by reference, not copied).
        threshold_candidates: ``cfg.num_thresholds`` evenly spaced values in [0, 1].
        current_u_idx / current_u: index and value of the active threshold.
        wealth: wealth K_t per candidate, starting at K_0 = 1.
        sum_D / sum_D_sq: running sums of payoffs and squared payoffs (FTRL state).
    """

    def __init__(self, config: BPACConfig):
        self.cfg = config
        grid_size = config.num_thresholds
        self.threshold_candidates = np.linspace(0, 1, grid_size)

        # Begin at the first (smallest) candidate, u = 0.
        self.current_u_idx = 0
        self.current_u = self.threshold_candidates[0]
        # K_0 = 1 for every candidate.
        self.wealth = np.ones(grid_size)

        # FTRL statistics.
        self.sum_D = np.zeros(grid_size)
        self.sum_D_sq = np.zeros(grid_size)

    def get_action(self, uncertainty_score: float):
        """Sample the routing decision for one query.

        Policy: pi_t = I(U >= u) + rho * I(U < u).

        Returns:
            action (int): 1 (Expert), 0 (Instant)
            propensity (float): The probability of choosing Expert (pi_t)
        """
        if uncertainty_score >= self.current_u:
            # At or above the threshold: the expert is mandatory.
            return 1, 1.0

        # Below the threshold: use the instant model, but still explore the
        # expert with probability rho. The propensity reported is that design
        # probability, regardless of which action was sampled.
        explore_prob = self.cfg.rho
        chosen = 1 if np.random.rand() < explore_prob else 0
        return chosen, explore_prob

    def update(self, uncertainty_score: float, action: int, observed_loss: Optional[float]):
        """Consume one step of bandit feedback and re-select the threshold.

        Only observed quantities are used. When the expert was not called
        (``action == 0``) the loss is unknown and is passed as ``None``. It then
        contributes nothing to the payoff.

        Args:
            uncertainty_score: U_t of the query that was just routed.
            action: 1 if the expert was called (xi_t = 1), else 0.
            observed_loss: l_t when the expert was called, otherwise ``None``.
        """
        cfg = self.cfg
        grid = self.threshold_candidates

        # 1. Observed data: an unseen loss is treated as 0 (it is multiplied by xi_t = 0 anyway).
        loss = 0.0 if observed_loss is None else observed_loss
        called_expert = action  # xi_t

        # 2. Propensity of the policy that produced this action. The threshold
        #    has not moved yet, so current_u is the one used for the decision.
        propensity = cfg.rho if uncertainty_score < self.current_u else 1.0
        # I(U_t < u) evaluated for every candidate u at once.
        below_candidate = (uncertainty_score < grid).astype(float)

        # 3. Payoff D_t(u) = epsilon - (1 - rho_0) * l_t * xi_t * I(U_t < u) / pi_t
        ipw_loss = (1 - cfg.rho_0) * (loss * called_expert * below_candidate) / propensity
        payoff = cfg.epsilon - ipw_loss

        # 4. FTRL bet lambda_t from the statistics accumulated *before* this step.
        denom = self.sum_D_sq + cfg.beta
        denom[denom == 0] = 1e-9  # guard against division by zero (e.g. beta == 0 at the first step)
        raw_bet = self.sum_D / denom
        # M_t bounds |D_t| for losses in [0, 1], so with c_clip < 1 the
        # factor 1 + lambda_t * D_t stays positive.
        max_abs_payoff = max(cfg.epsilon, ((1.0 - cfg.rho_0) / cfg.rho) - cfg.epsilon)
        bet = np.clip(raw_bet, 0, cfg.c_clip / max_abs_payoff)

        # 5. Wealth update, then fold this payoff into the statistics.
        self.wealth = self.wealth * (1.0 + bet * payoff)
        self.sum_D += payoff
        self.sum_D_sq += (payoff ** 2)

        # 6. Threshold selection: the longest prefix of candidates whose wealth
        #    has reached 1 / alpha (a safe u requires every smaller u to be safe).
        safe_prefix = np.logical_and.accumulate(self.wealth >= (1.0 / cfg.alpha))
        safe_idx = np.flatnonzero(safe_prefix)
        if safe_idx.size == 0:
            # Nothing certified: fall back to the first candidate, u = 0.
            self.current_u_idx = 0
            self.current_u = 0.0
            return
        self.current_u_idx = safe_idx[-1]
        self.current_u = grid[self.current_u_idx]

def run_simulation(data_sequence: List[Dict], config: BPACConfig):
    """Replay a stream of logged queries through B-PAC and track running metrics.

    Args:
        data_sequence: list of item dicts with keys "uncertainty",
            "instant_correct", "expert_correct", "instant_token",
            "expert_token".
        config: hyper-parameters. The model receives a shallow copy, so the
            per-step ``rho`` schedule below never modifies the caller's object.

    Returns:
        A tuple ``(logs, model)``. ``logs`` is a DataFrame with one row per
        step after the warm-up. ``model`` is the final ``BPAC`` instance.

    Notes:
        * ``config.rho`` is not used directly: steps ``t < change_point`` run
          with ``rho_1``, later steps with ``rho_0``.
        * The first ``warm_up`` steps update the model but are excluded from
          every logged metric.
        * ``threshold`` and ``wealth`` in the log are read after that step's
          update, i.e. they are the values that govern the *next* decision.
    """
    import copy

    # Private copy: the exploration schedule rewrites cfg.rho every step and
    # must not leak into the caller's config (e.g. when it is reused for another run).
    model = BPAC(copy.copy(config))
    logs = []
    warm_up = config.warm_up

    # Running totals over post-warm-up steps.
    total_actual_tokens = 0
    total_baseline_tokens = 0
    cumulative_loss = 0
    expert_calls = 0

    for t, item in enumerate(data_sequence):
        # 1. Features of this query.
        u_t = item['uncertainty']
        inst_corr = item['instant_correct']
        exp_corr = item['expert_correct']
        inst_tok = item['instant_token']
        exp_tok = item['expert_token']

        # 2. Exploration schedule, then the routing decision.
        if t < model.cfg.change_point:
            model.cfg.rho = model.cfg.rho_1
        else:
            model.cfg.rho = model.cfg.rho_0

        action, _ = model.get_action(u_t)

        # 3. Loss.
        # (A) True loss (oracle view, evaluation only): 1 iff the expert is
        #     right and the instant model is wrong, whatever the action was.
        true_loss = compute_step_loss(exp_corr, inst_corr)
        # (B) Observed loss (bandit feedback): visible only when the expert was
        #     called; otherwise None (treated as 0 inside the update).
        observed_loss = true_loss if action == 1 else None

        # 4. Model update (performed during warm-up as well).
        model.update(u_t, action, observed_loss)
        if t < warm_up:
            continue

        # 5. Token accounting. The baseline always uses the expert. The cascade
        #    pays for the instant model and, when escalating, for the expert too.
        step_baseline_tokens = exp_tok
        if action == 1:
            step_actual_tokens = inst_tok + exp_tok
            expert_calls += 1
        else:
            step_actual_tokens = inst_tok
        total_actual_tokens += step_actual_tokens
        total_baseline_tokens += step_baseline_tokens

        n_logged = t - warm_up + 1  # post-warm-up steps so far, including this one
        current_token_ratio = total_actual_tokens / total_baseline_tokens if total_baseline_tokens > 0 else 1.0
        current_expert_ratio = expert_calls / n_logged
        # Realized risk: loss is incurred only when the instant answer is
        # served; escalating to the expert contributes 0.
        if action == 0:
            cumulative_loss += true_loss
        current_avg_risk = cumulative_loss / n_logged

        # Wealth of the (post-update) selected threshold.
        wealth = model.wealth[model.current_u_idx]

        # 6. Log row.
        logs.append({
            "step": t,
            "uncertainty": u_t,
            "threshold": model.current_u,      # threshold after this step's update
            "action": action,                  # 1=Expert, 0=Instant
            "true_loss": true_loss,            # oracle loss
            "observed_loss": observed_loss if observed_loss is not None else np.nan,
            "avg_risk": current_avg_risk,      # cumulative average realized risk
            "token_ratio": current_token_ratio, # cumulative token ratio vs. expert-only baseline
            "expert_call_ratio": current_expert_ratio, # cumulative share of expert calls
            "wealth": wealth                   # wealth of the selected threshold
        })

    return pd.DataFrame(logs), model