# UEP_ab_version.py - UEP Beta算法消融实验版本
import torch
import torch.nn.functional as F
import numpy as np
from collections import deque
from typing import Dict, Optional, List, Tuple
import warnings

from src.UEP.predictor.predictor import BanditDDPMPredictor720


class UEPBetaFixedPD:
    """
    Ablation-study variant of the UEP Beta bandit algorithm.

    Each arm is scored by mixing a sliding-window historical mean with a
    per-arm diffusion-model prediction,
    ``mu = (1 - lambda) * mu_hist + lambda * mu_pred``, plus a confidence
    bonus; the arm with the largest upper confidence bound is played.

    Ablation switches:
      1. ``use_adaptive_window``: when off, a fixed sliding window of
         ``fixed_window`` samples (default 20) is used.
      2. ``use_diffusion``: when off, no diffusion models are built and the
         estimate is the pure historical mean (equivalent to lambda = 0).
      3. ``use_adaptive_lambda``: when off, lambda is pinned to
         ``fixed_lambda`` (default 0.5) instead of the theoretical optimum.

    The decay exponent ``d`` is not estimated; ``d_true`` is used directly.
    """

    def __init__(self, n_arms: int, d_true: float, device: str = "cuda",
                 # Ablation switches
                 use_adaptive_window: bool = True,
                 use_diffusion: bool = True,
                 use_adaptive_lambda: bool = True,
                 # Fixed values used when the corresponding switch is off
                 fixed_window: int = 20,
                 fixed_lambda: float = 0.5,
                 # Algorithm parameters
                 hist_len: int = 20,
                 timesteps: int = 20,
                 delta: float = 0.01,
                 sigma_noise: float = 0.25,
                 # Theoretical constants
                 C_delta: float = 0.22,
                 C_epsilon: float = 0.65,
                 d_eff: int = 32,
                 C1: float = 0.04,
                 C2: float = 0.25,
                 C3: float = 0.04,
                 alpha: float = 0.8,
                 C_coupling: float = 0.8,
                 C_decay: float = 1.0):
        """
        Args:
            n_arms: Number of arms.
            d_true: Known decay exponent ``d``; used as ``d_est`` throughout.
            device: Torch device for the diffusion models and training tensors.
            use_adaptive_window, use_diffusion, use_adaptive_lambda:
                Ablation switches (see class docstring).
            fixed_window: Window size used when adaptive windowing is off.
            fixed_lambda: Mixing weight used when adaptive lambda is off.
            hist_len: Length of the reward history fed to each diffusion model.
            timesteps: NOTE(review): currently unused -- the diffusion models
                are constructed with ``timesteps=50``. Left as-is so existing
                ablation results stay comparable; confirm which is intended.
            delta: Confidence parameter inside the log term of the bound.
            sigma_noise: Reward-noise std; its square is the variance proxy.
            C_delta, C1, C2, C3, alpha, C_coupling, C_decay: Constants of the
                theoretical MSE / confidence-bound formulas.
            C_epsilon, d_eff: Stored but not read by this class.
        """
        self.n_arms = n_arms
        self.d_true = d_true
        self.device = device
        self.global_step = 0

        # Ablation configuration
        self.use_adaptive_window = use_adaptive_window
        self.use_diffusion = use_diffusion
        self.use_adaptive_lambda = use_adaptive_lambda

        # Fixed fall-back values
        self.fixed_window = fixed_window
        self.fixed_lambda = fixed_lambda

        # Algorithm parameters
        self.hist_len = hist_len
        self.delta = delta
        self.sigma_squared = sigma_noise ** 2

        # Theoretical constants
        self.C_delta = C_delta
        self.C_epsilon = C_epsilon
        self.d_eff = d_eff
        self.C1 = C1
        self.C2 = C2
        self.C3 = C3
        self.alpha = alpha
        self.C_coupling = C_coupling
        self.C_decay = C_decay

        # Reward-type detection: None until enough samples have been seen.
        self.is_binary_environment = None
        self.reward_type_detection_samples = 50

        # The known d is used directly; the buffers below are not updated by
        # this class and are kept only for diagnostics/compatibility.
        self.d_est = d_true
        self.d_est_history = deque(maxlen=150)
        self.d_estimates_buffer = deque(maxlen=8)
        self.d_update_frequency = 5

        self._initialize_components()

    def _initialize_components(self):
        """Build per-arm diffusion models (if enabled) and all state/history containers."""
        if self.use_diffusion:
            try:
                # NOTE(review): timesteps is hard-coded to 50 and ignores the
                # ``timesteps`` constructor argument -- confirm intent.
                self.models = [
                    BanditDDPMPredictor720(n_arms=1, hist_len=self.hist_len, timesteps=50, device=self.device)
                    for _ in range(self.n_arms)
                ]
                print(f"✅ 成功初始化 {self.n_arms} 个扩散模型")
            except Exception as e:
                # Best-effort: fall back to the history-only estimator.
                print(f"⚠️ 扩散模型初始化失败: {e}")
                self.models = None
        else:
            self.models = None

        # Per-arm data store
        self.arm_data = {
            i: {
                'real_rewards': deque(maxlen=800),
                'virtual_rewards': deque(maxlen=100),
                'timestamps': deque(maxlen=900),
                'last_selected': -100,
                'window_size': self.fixed_window if not self.use_adaptive_window else self.hist_len,
                'lambda_value': self.fixed_lambda if not self.use_adaptive_lambda else 0.5,
                'selection_count': 0,
                'cumulative_reward': 0.0,
                'probability_estimates': deque(maxlen=200),
                'probability_changes': deque(maxlen=100)
            }
            for i in range(self.n_arms)
        }

        # Performance tracking
        self.history = {
            'arm_counts': np.zeros(self.n_arms),
            'arm_rewards': np.zeros(self.n_arms),
            'cumulative_regret': [],
            'instantaneous_regret': [],
            'window_sizes': [],
            'lambda_values': [],
            'd_estimates': [],
            'ucb_values': [],
            'mse_history': [],
            'confidence_widths': [],
            'binary_analysis': [],
            'environment_type': [],
            'estimation_confidence': [],
            'bias_corrections': []
        }

        print(f"🎯 UEP Beta消融实验算法初始化完成:")
        print(f"   - 臂数: {self.n_arms}")
        print(f"   - 真实d值: {self.d_true}")
        print(f"   - 自适应窗口: {'开启' if self.use_adaptive_window else '关闭(固定大小' + str(self.fixed_window) + ')'}")
        print(f"   - 扩散模型: {'开启' if self.use_diffusion else '关闭(lambda=0)'}")
        print(f"   - 自适应lambda: {'开启' if self.use_adaptive_lambda else '关闭(固定' + str(self.fixed_lambda) + ')'}")

    def detect_environment_type(self):
        """Return True if the rewards look binary (only 0/1 observed), else False.

        The verdict is cached only once at least
        ``reward_type_detection_samples`` real rewards have been pooled across
        all arms. Before that, True is returned as a provisional default but
        is NOT cached, so detection runs again once enough data exists.
        Caching the default would freeze the verdict at True forever, because
        this method is first called before any reward has been observed.
        """
        if self.is_binary_environment is not None:
            return self.is_binary_environment

        all_rewards = [
            r
            for arm_idx in range(self.n_arms)
            for r in self.arm_data[arm_idx]['real_rewards']
        ]

        if len(all_rewards) < self.reward_type_detection_samples:
            # Not enough evidence yet: assume binary, but decide again later.
            return True

        self.is_binary_environment = set(all_rewards).issubset({0, 0.0, 1, 1.0})
        return self.is_binary_environment

    def compute_optimized_window_size(self, arm_idx: int) -> int:
        """Return the sliding-window size for ``arm_idx``.

        With adaptive windowing off, this returns ``fixed_window``.

        Otherwise the window grows roughly like sqrt(t). It is scaled up for
        binary rewards and for larger d, then clipped to
        [min_window, max_window]. The floor grows like log(t) and the cap like
        t**p, with p between 0.55 and 0.65.

        If the arm has fewer samples than the clipped window, the window is
        reduced to max(min_window, n_samples).
        """
        if not self.use_adaptive_window:
            return self.fixed_window

        t = max(self.global_step, 1)
        d = max(self.d_est, 0.3)
        is_binary = self.detect_environment_type()

        if is_binary:
            base_multiplier = 1.3
            min_window = max(25, int(np.log(t + 1) * 5))
        else:
            base_multiplier = 1.0
            min_window = max(20, int(np.log(t + 1) * 4))

        if d <= 0.7:
            base_window = int((25 + np.sqrt(t) * 2.5) * base_multiplier)
            max_window = min(200, int(t ** 0.55))
        elif d <= 1.0:
            base_window = int((35 + np.sqrt(t) * 3.5) * base_multiplier)
            max_window = min(300, int(t ** 0.6))
        else:
            base_window = int((45 + np.sqrt(t) * 4.5) * base_multiplier)
            max_window = min(400, int(t ** 0.65))

        # For small t the t**p cap can fall below the floor. np.clip with
        # lower > upper silently returns the upper bound, so the window
        # would drop below ``min_window``. Never let the cap undercut the floor.
        max_window = max(max_window, min_window)

        W_optimal = np.clip(base_window, min_window, max_window)

        n_samples = len(self.arm_data[arm_idx]['real_rewards'])
        if n_samples < W_optimal:
            W_optimal = max(min_window, n_samples)

        return int(W_optimal)

    def compute_mse_components_theoretical(self, arm_idx: int) -> Tuple[float, float, float]:
        """Return the theoretical (mse_hist, mse_pred, cov_bound) for one arm.

        - mse_hist: variance of the windowed mean (sigma^2 * K / n) plus a
          drift-bias term. The bias grows with the window size and shrinks
          with t.
        - mse_pred: the predictor's model-bias, sample-variance and
          time-difference terms, which decay in t or in n.
        - cov_bound: rho_max * sqrt(mse_hist * mse_pred). rho_max is the
          smallest of 0.95, a statistical coupling term and a decay term.

        ``n`` is the number of samples inside the current window. Returns
        (1.0, 1.0, 0.1) when there are at most 2 such samples.
        """
        arm_data = self.arm_data[arm_idx]
        real_rewards = list(arm_data['real_rewards'])
        window_size = arm_data['window_size']
        n_samples = min(len(real_rewards), window_size)

        if n_samples <= 2:
            return (1.0, 1.0, 0.1)

        t = max(self.global_step, 1)
        d = max(self.d_est, 0.3)
        eps = 1e-8

        # Historical-estimate MSE: variance + environment-drift bias
        var_stat = max(self.sigma_squared * self.n_arms / max(n_samples, 1), eps)
        window_term = max(window_size ** (2 * d), eps)
        time_term = max(t ** (2 * d), eps)
        bias_env = (self.C_delta ** 2 * self.n_arms ** 2 * window_term) / \
                   (max((d + 1) ** 2, eps) * time_term)
        mse_hist = var_stat + bias_env

        # Prediction-estimate MSE
        model_bias = max(self.C1 * (t ** (-2 * d)), eps)
        sample_var = max(self.C2 * (n_samples ** (-self.alpha)), eps)
        time_diff = max(self.C3 * (t ** (-d)), eps)
        mse_pred = model_bias + sample_var + time_diff

        # Covariance bound
        log_t = max(np.log(t), 1)
        sqrt_n = max(np.sqrt(n_samples), 1)
        rho_stat = self.C_coupling * np.sqrt(log_t) / sqrt_n
        delta_t = max(t ** (-d), eps)
        rho_decay = self.C_decay / (1 + delta_t * window_size)
        rho_max = min(0.95, rho_stat, rho_decay)
        cov_bound = rho_max * np.sqrt(mse_hist * mse_pred)

        return (max(mse_hist, eps), max(mse_pred, eps), max(cov_bound, 0))

    def compute_optimal_lambda_theoretical(self, mse_hist: float, mse_pred: float, cov_bound: float) -> float:
        """Return the mixing weight lambda that minimises the mixed-estimator MSE.

        Minimising (1-l)^2 A + l^2 B + 2 l (1-l) C over l gives
        l* = (A - C) / (A + B - 2C), where A = mse_hist, B = mse_pred and
        C = cov_bound.

        If the denominator is (near) zero, the result is 0.1 when A <= B
        (favour history) and 0.9 otherwise. The result is clipped to
        [0.05, 0.95]. With adaptive lambda off, ``fixed_lambda`` is returned.
        """
        if not self.use_adaptive_lambda:
            return self.fixed_lambda

        A, B, C = max(mse_hist, 1e-8), max(mse_pred, 1e-8), max(cov_bound, 0)
        denominator = A + B - 2 * C

        if denominator > 1e-6:
            lambda_opt = (A - C) / denominator
        else:
            lambda_opt = 0.1 if A <= B else 0.9

        return np.clip(lambda_opt, 0.05, 0.95)

    def compute_mixed_confidence_bound(self, arm_idx: int, lambda_t: float) -> float:
        """Return the confidence width of the lambda-mixed estimate for one arm.

        The width combines two parts:
        - the std of the mixture of the historical estimate
          (sigma_h^2 = 2 sigma^2 log(4 K t / delta) / n) and the prediction
          (sigma_h^2 + C1 t^(-2d)), including a bounded covariance term;
        - a lambda-weighted bias term.

        The result is floored at 0.01. Returns 1.0 when the arm has no samples
        in its window.

        ``n`` is the number of samples inside the current sliding window. That
        is the count the historical mean in ``select_arm`` averages over, and
        the one ``compute_mse_components_theoretical`` uses. Using the total
        stored count (up to 800) made the bound far too narrow.
        """
        arm_data = self.arm_data[arm_idx]
        window_size = arm_data['window_size']
        n_samples = min(len(arm_data['real_rewards']), window_size)
        if n_samples <= 0:
            return 1.0

        t = max(self.global_step, 1)
        d = max(self.d_est, 0.3)
        eps = 1e-8

        log_term = max(np.log(4 * self.n_arms * t / self.delta), 1)
        sigma_h_sq = (2 * self.sigma_squared * log_term) / max(n_samples, 1)
        pred_var_term = max(self.C1 * (t ** (-2 * d)), eps)
        sigma_p_sq = sigma_h_sq + pred_var_term

        sqrt_log_t = np.sqrt(max(np.log(t), 1))
        sqrt_n = np.sqrt(max(n_samples, 1))
        cov_bound = min(
            np.sqrt(sigma_h_sq * sigma_p_sq),
            self.C_coupling * sqrt_log_t / sqrt_n * np.sqrt(sigma_h_sq * sigma_p_sq)
        )

        lambda_t = np.clip(lambda_t, 0, 1)
        sigma_mix_sq = ((1 - lambda_t) ** 2 * sigma_h_sq +
                        lambda_t ** 2 * sigma_p_sq +
                        2 * lambda_t * (1 - lambda_t) * abs(cov_bound))

        bias_hist = self.C_delta * self.n_arms * (window_size ** (d + 1)) / \
                    (max((d + 1) * n_samples, 1)) * max((t ** (-d)), eps)
        bias_pred = self.C3 * max((t ** (-d)), eps)
        bias_total = (1 - lambda_t) * bias_hist + lambda_t * bias_pred

        confidence_bound = np.sqrt(max(sigma_mix_sq, eps)) + bias_total
        return max(confidence_bound, 0.01)

    def select_arm(self) -> int:
        """Choose the next arm with an upper-confidence rule.

        First refreshes every arm's window size, then scores each arm:
        - An unplayed arm scores +inf; one of these is chosen uniformly at
          random.
        - If diffusion models exist and the arm has >= 5 windowed samples, the
          score is the lambda-mixed mean plus
          ``compute_mixed_confidence_bound``.
        - Otherwise (ablation, too little data, or a prediction error), the
          score is the windowed mean plus sqrt(2 ln(t + 1) / n), with
          lambda = 0.

        Side effects: updates each arm's ``window_size`` and ``lambda_value``,
        and appends to the 'ucb_values', 'lambda_values' and
        'environment_type' histories.

        Returns:
            Index of the selected arm, as a plain Python int.
        """
        for arm_idx in range(self.n_arms):
            self.arm_data[arm_idx]['window_size'] = self.compute_optimized_window_size(arm_idx)

        ucb_values = []
        lambda_values = []

        for arm_idx in range(self.n_arms):
            arm_data = self.arm_data[arm_idx]
            real_rewards = list(arm_data['real_rewards'])

            if len(real_rewards) == 0:
                # Force exploration of arms that have never been played.
                ucb_values.append(float('inf'))
                lambda_values.append(0.0)
                continue

            window = arm_data['window_size']
            recent_rewards = real_rewards[-window:] if window <= len(real_rewards) else real_rewards
            mu_hist = np.mean(recent_rewards)

            # Mix the historical mean with the diffusion-model prediction
            if self.use_diffusion and self.models and len(recent_rewards) >= 5:
                try:
                    mu_pred, var_pred = self.models[arm_idx](0)
                    mu_pred = float(mu_pred.item() if torch.is_tensor(mu_pred) else mu_pred)

                    if self.use_adaptive_lambda:
                        mse_hist, mse_pred, cov = self.compute_mse_components_theoretical(arm_idx)
                        lambda_t = self.compute_optimal_lambda_theoretical(mse_hist, mse_pred, cov)
                    else:
                        lambda_t = self.fixed_lambda

                    mu_mixed = (1 - lambda_t) * mu_hist + lambda_t * mu_pred
                    cb = self.compute_mixed_confidence_bound(arm_idx, lambda_t)

                except Exception as e:
                    # Best-effort: fall back to a history-only UCB for this arm.
                    print(f"⚠️ 臂 {arm_idx} 预测失败: {e}")
                    mu_mixed = mu_hist
                    lambda_t = 0.0
                    cb = np.sqrt(2 * np.log(self.global_step + 1) / len(recent_rewards))
            else:
                # Diffusion off (or too little data): lambda = 0, history only.
                mu_mixed = mu_hist
                lambda_t = 0.0
                cb = np.sqrt(2 * np.log(self.global_step + 1) / len(recent_rewards))

            arm_data['lambda_value'] = lambda_t
            lambda_values.append(lambda_t)

            ucb_values.append(mu_mixed + cb)

        # Record history
        self.history['ucb_values'].append(ucb_values.copy())
        self.history['lambda_values'].append(np.mean(lambda_values))
        self.history['environment_type'].append(self.detect_environment_type())

        # Pick the arm
        if any(np.isinf(ucb_values)):
            unexplored = [i for i, v in enumerate(ucb_values) if np.isinf(v)]
            # np.random.choice returns a NumPy scalar; honour the ``int`` contract.
            selected_arm = int(np.random.choice(unexplored))
        else:
            selected_arm = int(np.argmax(ucb_values))

        return selected_arm

    def update(self, arm: int, reward: float):
        """Record the reward observed for ``arm`` and advance the global step.

        The reward is clipped to [0, 1] and appended to the arm's data. While
        the environment is treated as binary, the mean of the last 20 rewards
        is tracked once at least 10 exist.

        If diffusion models exist, the arm's model:
        - is fed the reward;
        - gets a training sample (earlier rewards -> this reward) once the
          arm has >= 8 rewards;
        - is trained whenever the global step is a multiple of 15.
        """
        reward = np.clip(float(reward), 0.0, 1.0)

        arm_data = self.arm_data[arm]
        arm_data['real_rewards'].append(reward)
        arm_data['timestamps'].append(self.global_step)
        arm_data['last_selected'] = self.global_step
        arm_data['selection_count'] += 1
        arm_data['cumulative_reward'] += reward

        # Probability tracking (binary rewards only)
        if self.detect_environment_type():
            recent_rewards = list(arm_data['real_rewards'])[-20:]
            if len(recent_rewards) >= 10:
                arm_data['probability_estimates'].append(np.mean(recent_rewards))

        # Global statistics
        self.history['arm_counts'][arm] += 1
        self.history['arm_rewards'][arm] += reward

        # Diffusion-model update
        if self.use_diffusion and self.models:
            try:
                self.models[arm].update_history(reward, arm_idx=0)

                if len(arm_data['real_rewards']) >= 8:
                    # The input history must only contain rewards observed
                    # BEFORE the target. ``reward`` was appended above, so drop
                    # it here; keeping it leaked the label into the input.
                    hist_data = list(arm_data['real_rewards'])[:-1][-self.hist_len:]
                    hist_tensor = torch.tensor(hist_data, device=self.device, dtype=torch.float32)

                    if len(hist_tensor) < self.hist_len:
                        # NOTE(review): zeros are appended at the END (right
                        # padding), so for short histories the most recent
                        # reward is not in the last slot. Confirm this matches
                        # the predictor's expected layout.
                        hist_tensor = F.pad(hist_tensor, (0, self.hist_len - len(hist_tensor)))

                    self.models[arm].add_training_sample(
                        hist_tensor.unsqueeze(-1),
                        torch.tensor([[reward]], device=self.device, dtype=torch.float32),
                        0
                    )

                if self.global_step % 15 == 0:
                    self.models[arm].train_step(0)

            except Exception as e:
                # Best-effort: a failed model update must not stop the bandit.
                print(f"⚠️ 臂 {arm} 模型更新失败: {e}")

        # Record evolution
        avg_window = np.mean([self.arm_data[i]['window_size'] for i in range(self.n_arms)])
        self.history['window_sizes'].append(avg_window)
        self.history['d_estimates'].append(self.d_est)

        self.global_step += 1

    def compute_regret(self, true_means: np.ndarray) -> float:
        """Compute and record the instantaneous regret of the most recent pull.

        The most recent pull is the arm whose last timestamp equals
        ``global_step - 1``, i.e. the arm passed to the latest ``update``.
        Regret is ``max(true_means) - true_means[that arm]``. It is appended to
        the instantaneous and cumulative regret histories.

        Returns 0.0, without recording anything, before the first update or
        when no arm matches. Calling this twice without an intervening update
        records the same pull twice.
        """
        if self.global_step == 0:
            return 0.0

        optimal_arm = np.argmax(true_means)
        optimal_reward = true_means[optimal_arm]

        last_selected = -1
        for arm in range(self.n_arms):
            if (self.arm_data[arm]['timestamps'] and
                    self.arm_data[arm]['timestamps'][-1] == self.global_step - 1):
                last_selected = arm
                break

        if last_selected == -1:
            return 0.0

        selected_reward = true_means[last_selected]
        regret = optimal_reward - selected_reward

        self.history['instantaneous_regret'].append(regret)
        if self.history['cumulative_regret']:
            cumulative = self.history['cumulative_regret'][-1] + regret
        else:
            cumulative = regret
        self.history['cumulative_regret'].append(cumulative)

        return regret

    def get_diagnostics(self) -> Dict:
        """Return a snapshot of the algorithm state and its ablation configuration."""
        return {
            'global_step': self.global_step,
            'd_estimate': self.d_est,
            'd_history': list(self.d_est_history),
            'environment_type': self.detect_environment_type(),
            'avg_window_size': np.mean([self.arm_data[i]['window_size'] for i in range(self.n_arms)]),
            'window_sizes': [self.arm_data[i]['window_size'] for i in range(self.n_arms)],
            'avg_lambda': np.mean([self.arm_data[i]['lambda_value'] for i in range(self.n_arms)]),
            'lambda_values': [self.arm_data[i]['lambda_value'] for i in range(self.n_arms)],
            'arm_counts': self.history['arm_counts'].tolist(),
            'cumulative_regret': self.history['cumulative_regret'][-1] if self.history['cumulative_regret'] else 0,
            'estimation_confidence': self.history['estimation_confidence'][-5:] if self.history['estimation_confidence'] else [0.4],
            'bias_corrections': self.history['bias_corrections'][-5:] if self.history['bias_corrections'] else [0.0],
            'config': {
                'use_adaptive_window': self.use_adaptive_window,
                'use_diffusion': self.use_diffusion,
                'use_adaptive_lambda': self.use_adaptive_lambda,
                'fixed_window': self.fixed_window,
                'fixed_lambda': self.fixed_lambda,
                'environment_type': 'binary' if self.detect_environment_type() else 'continuous',
                'optimized_approach': True,
                'perfect_d': True
            }
        }


if __name__ == "__main__":
    # Smoke banner shown when this module is executed directly.
    _banner_lines = (
        "✅ UEPBetaFixedPD 消融实验算法类创建成功！",
        "   主要特性：",
        "   - 支持自适应窗口消融",
        "   - 支持扩散模型消融",
        "   - 支持自适应lambda消融",
        "   - 使用真实d值进行理论计算",
        "   - 优化的窗口大小计算",
    )
    for _banner_line in _banner_lines:
        print(_banner_line)
