import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from dataclasses import dataclass
from typing import List, Optional, Tuple, Dict

@dataclass
class BPACConfignaive:
    """Hyper-parameters of the (naive) B-PAC cascade algorithm.

    Attributes:
        alpha: Failure probability, i.e. 1 - confidence
            (e.g. 0.1 means 90% confidence).
        epsilon: Tolerated risk upper bound / error tolerance
            (e.g. 0.05 allows a 5% performance loss).
        rho: Minimum exploration probability, between 0 and 1.
        num_thresholds: Number of grid points in the threshold search space.
        warm_up: Number of initial warm-up steps.
    """

    alpha: float = 0.1
    epsilon: float = 0.1
    rho: float = 0.1
    num_thresholds: int = 1_001
    warm_up: int = 50

def compute_step_loss(y_correct_t: int, y_hat_correct_t: int) -> float:
    """Return the single-step cascade loss.

    For 0/1 correctness flags the loss is 1 exactly when the expert is
    correct AND the instant model is wrong, and 0 in every other case.

    Args:
        y_correct_t: 1 if the expert answer is correct, else 0.
        y_hat_correct_t: 1 if the instant model answer is correct, else 0.

    Returns:
        The loss as a float.
    """
    return float(y_correct_t * (1 - y_hat_correct_t))

class Onaive:
    """
    O-Naive online threshold selector for an instant-model / expert cascade.

    Policy: escalate to the expert whenever the uncertainty score is at or
    above the current threshold ``u``; below ``u`` answer with the instant
    model but still call the expert with probability ``rho`` (exploration).

    Every candidate threshold's risk is tracked with the naive (not
    importance-weighted) running average over the ``n`` updates seen so far

        R_hat(u) = (1/n) * Sum_s( xi_s * l_s * I(U_s < u) )

    and the threshold is chosen greedily as ``max{u : R_hat(u) <= epsilon}``.

    ``config.alpha`` is not read by this naive variant.
    """

    def __init__(self, config: "BPACConfignaive", rng: Optional[np.random.Generator] = None):
        """
        Args:
            config: Hyper-parameters; ``num_thresholds``, ``rho`` and
                ``epsilon`` are used here.
            rng: Optional NumPy Generator for the exploration draws, for
                reproducible runs. ``None`` draws from NumPy's global RNG
                (``np.random.rand``).
        """
        self.cfg = config
        # Grid of candidate thresholds u on [0, 1].
        self.threshold_candidates = np.linspace(0, 1, self.cfg.num_thresholds)
        self.rng = rng

        # Start at the smallest candidate (u = 0), the most conservative
        # threshold: every score >= 0 is escalated to the expert.
        self.current_u_idx = 0
        self.current_u = self.threshold_candidates[self.current_u_idx]
        # 1-based index of the next round; after k updates it equals k + 1.
        self.t = 1
        # Running sums of xi * l * I(U < u), one entry per candidate u.
        self.sum_risk_terms = np.zeros(self.cfg.num_thresholds)

    def get_action(self, uncertainty_score: float) -> Tuple[int, float]:
        """
        Choose between the expert and the instant model.

        Policy: pi_t = I(U >= u) + rho * I(U < u).

        Returns:
            action (int): 1 (Expert), 0 (Instant).
            propensity (float): The probability that the policy calls the
                expert for this score (pi_t).
        """
        if uncertainty_score >= self.current_u:
            # At or above the threshold: the expert must be called.
            return 1, 1.0

        # Below the threshold the instant model is used, but the expert is
        # still explored with probability rho.
        propensity = self.cfg.rho
        draw = self.rng.random() if self.rng is not None else np.random.rand()
        action = 1 if draw < self.cfg.rho else 0
        return action, propensity

    def update(self, uncertainty_score: float, action: int, observed_loss: Optional[float]) -> None:
        """
        Fold one round into the risk estimates and re-select the threshold.

        R_hat(u) = (1/n) * Sum( xi * l * I(U < u) ), where n is the number of
        updates so far, including this one.

        Args:
            uncertainty_score: Uncertainty U_t of this round.
            action: xi_t, 1 if the expert was called, else 0.
            observed_loss: Loss seen this round, or ``None`` when it was not
                observed. ``None`` is treated as 0.
        """
        # Scalar part xi * l: non-zero only if the expert was called and a
        # loss was observed.
        xi = action
        l_t = observed_loss if observed_loss is not None else 0.0
        scalar_term = xi * l_t

        # Only candidates u strictly above U_t would have answered this round
        # with the instant model, so only they accumulate its loss.
        indicator_less = (uncertainty_score < self.threshold_candidates).astype(float)
        self.sum_risk_terms += scalar_term * indicator_less

        # self.t is the 1-based index of the current round, i.e. the number
        # of terms summed so far. Divide before advancing it, so k terms are
        # averaged over k rather than k + 1.
        estimated_risk = self.sum_risk_terms / self.t
        self.t += 1

        # Greedy selection: the largest candidate whose estimated risk is
        # within epsilon.
        valid_indices = np.flatnonzero(estimated_risk <= self.cfg.epsilon)
        if valid_indices.size > 0:
            self.current_u_idx = valid_indices[-1]
        else:
            # No candidate is safe (only possible when even u = 0 fails,
            # e.g. epsilon < 0). Fall back to the most conservative
            # threshold, i.e. call the expert on every score >= 0.
            self.current_u_idx = 0
        self.current_u = self.threshold_candidates[self.current_u_idx]
        

def run_simulation_onaive(data_sequence: List[Dict], config: BPACConfignaive):
    """
    Replay a data stream through the O-Naive cascade and log running metrics.

    Every item updates the model. The first ``config.warm_up`` items are left
    out of the logged metrics.

    Args:
        data_sequence: Items processed in order. Each dict must provide
            "uncertainty", "instant_correct", "expert_correct",
            "instant_token" and "expert_token".
        config: Algorithm hyper-parameters.

    Returns:
        A tuple ``(logs, model)``: a DataFrame with one row per post-warm-up
        step, and the final ``Onaive`` instance.
    """
    model = Onaive(config)
    warm_up = config.warm_up
    rows = []

    # Running totals over the post-warm-up steps.
    tokens_used = 0
    tokens_expert_only = 0
    realized_loss = 0
    n_expert = 0

    for step, item in enumerate(data_sequence):
        u_t, inst_corr, exp_corr, inst_tok, exp_tok = (
            item['uncertainty'],
            item['instant_correct'],
            item['expert_correct'],
            item['instant_token'],
            item['expert_token'],
        )

        action, _propensity = model.get_action(u_t)

        # Oracle loss: uses both labels, so it is 1 even when action == 0 if
        # the instant model failed where the expert succeeded. Evaluation only.
        true_loss = compute_step_loss(exp_corr, inst_corr)
        # Bandit feedback: the algorithm sees the loss only when it called the
        # expert. Otherwise it receives None (treated as 0 inside the model).
        observed_loss = None if action != 1 else true_loss

        model.update(u_t, action, observed_loss)
        if step < warm_up:
            continue

        # Token accounting. The baseline always uses the expert. The cascade
        # pays for the instant model, plus the expert when it escalates.
        tokens_expert_only += exp_tok
        if action == 1:
            tokens_used += inst_tok + exp_tok
            n_expert += 1
        else:
            tokens_used += inst_tok

        n_scored = step - warm_up + 1
        token_ratio = tokens_used / tokens_expert_only if tokens_expert_only > 0 else 1.0
        expert_ratio = n_expert / n_scored
        # Realized risk: loss is only incurred when the instant answer is kept;
        # an expert answer contributes no loss.
        if action == 0:
            realized_loss += true_loss
        avg_risk = realized_loss / n_scored

        rows.append({
            "step": step,
            "uncertainty": u_t,
            "threshold": model.current_u,        # threshold after this update
            "action": action,                    # 1=Expert, 0=Instant
            "true_loss": true_loss,              # oracle loss
            "observed_loss": np.nan if observed_loss is None else observed_loss,
            "avg_risk": avg_risk,                # running average realized risk
            "token_ratio": token_ratio,          # running token-usage ratio
            "expert_call_ratio": expert_ratio,   # running expert-call ratio
            "wealth": 0,                         # placeholder, always 0 here
        })

    return pd.DataFrame(rows), model