import numpy as np
import pandas as pd
from typing import List, Dict, Optional
from dataclasses import dataclass

@dataclass
class BPACConfigIPS:
    """
    Hyper-parameter configuration for the B-PAC / IPS-Hoeffding threshold learner.

    Raises:
        ValueError: if ``rho`` is not in (0, 1], ``alpha`` is not in (0, 1), or
            ``num_thresholds`` is smaller than 1. These values would otherwise fail
            later (division by ``rho``, ``log(1 / alpha_t)``, indexing an empty
            threshold grid) with a less informative error.
    """
    alpha: float = 0.1          # Failure probability (1 - confidence), e.g. 0.1 means 90% confidence
    epsilon: float = 0.1        # Tolerated risk upper bound (error tolerance), e.g. 0.05 allows a 5% performance loss
    rho: float = 0.1            # Minimum exploration probability, in (0, 1]
    num_thresholds: int = 1001  # Number of grid points in the threshold search space
    warm_up: int = 50           # Number of initial warm-up steps

    def __post_init__(self) -> None:
        # rho is a divisor in M_tilde = (1 - rho) / rho and the IPS weight 1 / pi_t.
        if not 0.0 < self.rho <= 1.0:
            raise ValueError(f"rho must be in (0, 1], got {self.rho!r}")
        # alpha_t = 6 * alpha / (pi^2 t^2) feeds log(1 / alpha_t); alpha must be a
        # proper probability for that term to be finite and positive.
        if not 0.0 < self.alpha < 1.0:
            raise ValueError(f"alpha must be in (0, 1), got {self.alpha!r}")
        # At least one candidate threshold is needed for the initial u = grid[0].
        if self.num_thresholds < 1:
            raise ValueError(
                f"num_thresholds must be >= 1, got {self.num_thresholds!r}"
            )

class IPSHoeffding:
    """
    Online threshold learner using a scaled IPS risk estimate and a Hoeffding UCB.

    Policy: the expert is always called when the uncertainty score is at or above
    the current threshold ``current_u``; below it, the expert is called with
    probability ``rho`` (exploration). After every step the threshold is moved to
    the largest grid candidate whose upper confidence bound on deployment risk is
    at most ``epsilon``. If no candidate qualifies, it falls back to u = 0 (always
    call the expert).
    """

    def __init__(self, config: BPACConfigIPS):
        cfg = config
        self.cfg = cfg
        # Evenly spaced candidate thresholds on [0, 1] (non-decreasing).
        self.threshold_candidates = np.linspace(0, 1, cfg.num_thresholds)

        # Start at the most conservative candidate, index 0.
        self.current_u_idx = 0
        self.current_u = self.threshold_candidates[0]

        # Running IPS statistics: step counter and per-candidate sum of Z_t(u).
        self.time_step = 0
        self.sum_Z = np.zeros(cfg.num_thresholds)

        # Upper end of the range of Z_t(u), which lies in [0, (1 - rho) / rho].
        # rho should be the deployment-time (minimum) exploration rate for the bound to hold.
        self.M_tilde = (1.0 - cfg.rho) / cfg.rho

        # Number of candidates N. Kept as an attribute; the bound in `update`
        # currently uses log(1 / alpha_t), not log(N / alpha_t).
        self.N_thresholds = cfg.num_thresholds

    def get_action(self, uncertainty_score: float):
        """
        Decide whether to call the expert, following pi_t = I(U >= u) + rho * I(U < u).

        The rule matches BPAC so that the estimator sees the same input distribution.

        Returns:
            ``(action, propensity)``: ``action`` is 1 (call the expert) or 0
            (instant model only). ``propensity`` is the probability of calling the
            expert under the current threshold: 1.0 at or above it, ``rho`` below it.
        """
        if uncertainty_score >= self.current_u:
            # At or above the threshold the expert call is mandatory.
            return 1, 1.0

        # Below the threshold: exploratory expert call with probability rho.
        explored = np.random.rand() < self.cfg.rho
        return (1 if explored else 0), self.cfg.rho

    def update(self, uncertainty_score: float, action: int, observed_loss: Optional[float]):
        """
        Fold one bandit observation into the IPS sums and re-select the threshold.

        Args:
            uncertainty_score: U_t for this step.
            action: xi_t, 1 if the expert was called and 0 otherwise.
            observed_loss: loss revealed by the expert call, or None when it was not
                observed (treated as 0; it is multiplied by ``action`` in any case).

        The propensity pi_t is recomputed from ``self.current_u``. This assumes
        ``update`` runs before the threshold changes from the value ``get_action``
        used for this sample.
        """
        self.time_step += 1
        t = self.time_step
        rho = self.cfg.rho

        loss = 0.0 if observed_loss is None else observed_loss
        # I(U_t < u) for every candidate u.
        below = (uncertainty_score < self.threshold_candidates).astype(float)
        # True data-generating propensity (rho below the threshold, else 1); it is
        # required for the IPS estimate to be unbiased.
        propensity = rho if uncertainty_score < self.current_u else 1.0

        # Scaled IPS estimate Z_t(u) = (1 - rho) * l_t * xi_t * I(U_t < u) / pi_t.
        self.sum_Z += (1.0 - rho) * (loss * action * below) / propensity

        # Time-uniform failure budget alpha_t = 6 * alpha / (pi^2 t^2); it sums to alpha over t.
        alpha_t = (6 * self.cfg.alpha) / (np.pi**2 * t**2)
        # NOTE(review): older comments here referred to log(N / alpha_t), but the
        # code uses log(1 / alpha_t). Because the grid is non-decreasing, the mean
        # below is non-decreasing in u, so the admissible set is a prefix of the
        # grid. Confirm this is the intended (no union bound) guarantee.
        # The 1e-9 offset is kept as-is.
        log_term = np.log(1 / alpha_t + 1e-9)
        ucb = self.sum_Z / t + self.M_tilde * np.sqrt(log_term / (2 * t))

        # Largest candidate with UCB(u) <= epsilon (epsilon is the deployment risk budget).
        admissible = np.flatnonzero(ucb <= self.cfg.epsilon)
        if admissible.size == 0:
            # Nothing certified: fall back to the safest choice, u = 0 (always call the expert).
            self.current_u_idx = 0
            self.current_u = 0.0
            return
        self.current_u_idx = admissible[-1]
        self.current_u = self.threshold_candidates[self.current_u_idx]

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from dataclasses import dataclass
from typing import List, Optional, Tuple, Dict

def compute_step_loss(y_correct_t: int, y_hat_correct_t: int) -> float:
    """
    Return the single-step loss of serving the instant model's answer.

    For 0/1 correctness flags the loss is 1 exactly when the expert (strong model)
    is correct and the instant (weak) model is wrong, and 0 otherwise.

    Args:
        y_correct_t: 1 if the expert is correct, 0 otherwise.
        y_hat_correct_t: 1 if the instant model is correct, 0 otherwise.

    Returns:
        ``y_correct_t * (1 - y_hat_correct_t)`` as a float.
    """
    return float(y_correct_t * (1 - y_hat_correct_t))

def run_simulation_ips(data_sequence: List[Dict], config: "BPACConfigIPS"):
    """
    Replay ``data_sequence`` through an ``IPSHoeffding`` model and log running metrics.

    Items are processed in order. For each item the model picks an action and
    receives bandit feedback: the loss is revealed only when the expert is called.
    The model then updates its threshold. The first ``config.warm_up`` items update
    the model but are excluded from the running totals and from the returned log.

    Args:
        data_sequence: list of item dicts with keys "uncertainty", "instant_correct",
            "expert_correct", "instant_token", "expert_token".
        config: configuration object forwarded to ``IPSHoeffding(config)``; its
            ``warm_up`` field is also read here.

    Returns:
        A ``(logs, model)`` tuple: a DataFrame with one row per post-warm-up step
        (the column schema is present even when no rows were logged), and the final
        ``IPSHoeffding`` instance.
    """
    columns = [
        "step", "uncertainty", "threshold", "action", "true_loss",
        "observed_loss", "avg_risk", "token_ratio", "expert_call_ratio", "wealth",
    ]

    model = IPSHoeffding(config)
    logs = []
    warm_up = config.warm_up

    # Running totals over post-warm-up steps only.
    total_actual_tokens = 0
    total_baseline_tokens = 0
    cumulative_loss = 0
    expert_calls = 0

    for t, item in enumerate(data_sequence):
        # 1. Features of this step.
        u_t = item['uncertainty']
        inst_corr = item['instant_correct']
        exp_corr = item['expert_correct']
        inst_tok = item['instant_token']
        exp_tok = item['expert_token']

        # 2. Decision under the current threshold. The propensity is recomputed
        #    inside model.update, so it is not needed here.
        action, _ = model.get_action(u_t)

        # 3. Loss.
        # (A) True loss (oracle view, for evaluation/plots only): 1 when the expert is
        #     correct and the instant model is wrong, even if action == 0.
        true_loss = compute_step_loss(exp_corr, inst_corr)
        # (B) Observed loss (bandit feedback): visible only when the expert was called.
        #     Otherwise None is passed, which the model treats as 0.
        observed_loss = true_loss if action == 1 else None

        # 4. Model update; warm-up steps still update the model.
        model.update(u_t, action, observed_loss)
        if t < warm_up:
            continue

        # Number of post-warm-up steps processed so far, including this one.
        n_eval = t - warm_up + 1

        # 5. Token accounting.
        # Baseline: the expert answers every sample.
        # Actual: action 0 -> instant only; action 1 -> instant + expert (cascade).
        step_baseline_tokens = exp_tok
        if action == 1:
            step_actual_tokens = inst_tok + exp_tok
            expert_calls += 1
        else:
            step_actual_tokens = inst_tok
        total_actual_tokens += step_actual_tokens
        total_baseline_tokens += step_baseline_tokens

        # Cumulative token ratio (guarded against a zero baseline).
        current_token_ratio = (
            total_actual_tokens / total_baseline_tokens if total_baseline_tokens > 0 else 1.0
        )
        # Cumulative fraction of post-warm-up steps that called the expert.
        current_expert_ratio = expert_calls / n_eval
        # Realized average risk: loss is incurred only when the instant answer is
        # served; calling the expert always incurs zero loss.
        if action == 0:
            cumulative_loss += true_loss
        current_avg_risk = cumulative_loss / n_eval

        # 6. Log. "threshold" is the model's threshold after this step's update.
        logs.append({
            "step": t,
            "uncertainty": u_t,
            "threshold": model.current_u,
            "action": action,                          # 1 = Expert, 0 = Instant
            "true_loss": true_loss,                    # oracle loss
            "observed_loss": observed_loss if observed_loss is not None else np.nan,
            "avg_risk": current_avg_risk,              # cumulative average realized risk
            "token_ratio": current_token_ratio,        # cumulative token consumption ratio
            "expert_call_ratio": current_expert_ratio, # cumulative expert-call ratio
            "wealth": 0,  # NOTE(review): always 0 here; appears to be a placeholder kept for log-schema compatibility
        })

    return pd.DataFrame(logs, columns=columns), model