# balanced_banditdiff_ucb_perfect_d.py
import torch
import torch.nn.functional as F
import numpy as np
from collections import deque
from typing import Dict, Optional, List, Tuple
import warnings

from src.new_formal.model.predictor import BanditDDPMPredictor717Compatible
from src.UEP.predictor.predictor import BanditDDPMPredictor720

class BalancedBanditDiffUCBPerfectD:
    """
    Balanced BanditDiff UCB that assumes each environment's ``d`` is known
    (no online estimation of ``d``).

    Each arm's score is ``mu_mixed + confidence_bound``. Here ``mu_mixed``
    blends a sliding-window empirical mean with an optional per-arm diffusion
    prediction, using weight ``lambda_t``. ``d`` (clamped below at 0.3 wherever
    it is used) enters the adaptive window rule, the MSE/covariance bounds and
    the bias terms of the confidence bound.

    Ablation switches:
        use_diffusion: blend in per-arm diffusion-model predictions.
        use_adaptive_window: use the ``d``/time-dependent window rule instead
            of ``fixed_window`` (or the full stored history).
        use_theory_guided: pick ``lambda_t`` from the MSE bounds instead of
            ``fixed_lambda`` (0.5 when ``fixed_lambda`` is None).
        use_virtual_data: store decayed "virtual" rewards for arms that were
            not pulled. NOTE(review): ``virtual_rewards`` is written but never
            read by ``select_arm`` in this class.
    """
    def __init__(self, n_arms: int, d_true: float, device: str = "cuda",
                 # Ablation control parameters
                 use_diffusion: bool = True,
                 use_adaptive_window: bool = True,
                 use_theory_guided: bool = True,
                 use_virtual_data: bool = True,
                 # Fixed parameters (for ablation)
                 fixed_window: Optional[int] = None,
                 fixed_lambda: Optional[float] = None,
                 # Algorithm parameters
                 hist_len: int = 20,
                 timesteps: int = 20,
                 delta: float = 0.01,
                 sigma_noise: float = 0.25,
                 virtual_method: str = "optimistic",
                 # Theory parameters - rebalanced
                 C_delta: float = 0.22,
                 C_epsilon: float = 0.65,
                 d_eff: int = 32,
                 C1: float = 0.04,
                 C2: float = 0.25,
                 C3: float = 0.04,
                 alpha: float = 0.8,
                 C_coupling: float = 0.8,
                 C_decay: float = 1.0):
        """Store configuration and build per-arm state.

        Args:
            n_arms: number of arms.
            d_true: the known ``d`` of the environment (stored as ``d_est``).
            device: torch device used for diffusion-model tensors.
            fixed_window: window used when ``use_adaptive_window`` is False;
                a falsy value means "use the whole stored history".
            fixed_lambda: mixing weight used when ``use_theory_guided`` is
                False (0.5 if None).
            hist_len: history length fed to the diffusion models.
            timesteps: accepted for signature compatibility; unused here.
            C_epsilon, d_eff: stored but not used by this class.
        """

        self.n_arms = n_arms
        self.device = device
        self.global_step = 0

        # Ablation configuration
        self.use_diffusion = use_diffusion
        self.use_adaptive_window = use_adaptive_window
        self.use_theory_guided = use_theory_guided
        self.use_virtual_data = use_virtual_data

        # Fixed parameters
        self.fixed_window = fixed_window
        self.fixed_lambda = fixed_lambda

        # Algorithm parameters
        self.hist_len = hist_len
        self.delta = delta
        self.sigma_squared = sigma_noise ** 2
        self.virtual_method = virtual_method

        # Theory parameters
        self.C_delta = C_delta
        self.C_epsilon = C_epsilon
        self.d_eff = d_eff
        self.C1 = C1
        self.C2 = C2
        self.C3 = C3
        self.alpha = alpha
        self.C_coupling = C_coupling
        self.C_decay = C_decay

        # Environment detection (None = not yet decided)
        self.is_binary_environment = None
        self.reward_type_detection_samples = 50

        # Use the known d directly. The history/buffer/frequency fields below
        # are never updated in this class; they are kept so diagnostics have
        # the same shape as an estimating variant (presumably).
        self.d_est = d_true
        self.d_est_history = deque(maxlen=150)
        self.d_estimates_buffer = deque(maxlen=8)
        self.d_update_frequency = 5

        # Initialize components
        self._initialize_components()

    def _initialize_components(self):
        """Create the per-arm diffusion models (if enabled), per-arm state and history logs."""
        if self.use_diffusion:
            try:
                self.models = [
                    BanditDDPMPredictor720(1, self.hist_len, 50, self.device)
                    for _ in range(self.n_arms)
                ]
                print(f"✅ Successfully initialized {self.n_arms} diffusion models")
            except Exception as e:
                # Best effort: run without diffusion if the models cannot be built.
                print(f"⚠️ Diffusion model initialization failed: {e}")
                self.models = None
        else:
            self.models = None

        self.arm_data = {
            i: {
                'real_rewards': deque(maxlen=800),
                'virtual_rewards': deque(maxlen=100),
                'timestamps': deque(maxlen=900),
                'last_selected': -100,
                'window_size': self.hist_len,
                'lambda_value': 0.5 if self.fixed_lambda is None else self.fixed_lambda,
                'selection_count': 0,
                'cumulative_reward': 0.0,
                'probability_estimates': deque(maxlen=200),
                'probability_changes': deque(maxlen=100)
            }
            for i in range(self.n_arms)
        }

        self.history = {
            'arm_counts': np.zeros(self.n_arms),
            'arm_rewards': np.zeros(self.n_arms),
            'cumulative_regret': [],
            'instantaneous_regret': [],
            'window_sizes': [],
            'lambda_values': [],
            'd_estimates': [],
            'ucb_values': [],
            'mse_history': [],
            'confidence_widths': [],
            'binary_analysis': [],
            'environment_type': [],
            'estimation_confidence': [],
            'bias_corrections': []
        }
        print(f"🎯 Perfect-d BanditDiff initialization completed:")
        print(f"   - Arms: {self.n_arms}")
        print(f"   - Known d: {self.d_est}")

    def detect_environment_type(self) -> bool:
        """Return True if the observed rewards look binary ({0, 1}), else False.

        The verdict is cached only once at least
        ``reward_type_detection_samples`` real rewards have been observed
        across all arms. Before that, the method returns True without caching,
        so that detection can still happen later.
        """
        if self.is_binary_environment is not None:
            return self.is_binary_environment
        all_rewards = [
            r
            for arm_idx in range(self.n_arms)
            for r in self.arm_data[arm_idx]['real_rewards']
        ]
        if len(all_rewards) < self.reward_type_detection_samples:
            # Not enough evidence yet: assume binary but do NOT cache it.
            # Caching here would freeze the verdict at the very first call.
            return True
        self.is_binary_environment = set(all_rewards).issubset({0.0, 1.0})
        return self.is_binary_environment

    def compute_optimized_window_size(self, arm_idx: int) -> int:
        """Return the sliding-window length used for arm ``arm_idx``.

        Without adaptive windows, returns ``fixed_window`` if it is truthy.
        Otherwise it returns the whole stored history length + 1. With
        adaptive windows, the window grows roughly like ``sqrt(t)``. It is
        bounded below by a log-growing floor and above by a ``t**p`` cap,
        where ``p`` depends on ``d``.
        """
        if not self.use_adaptive_window:
            if self.fixed_window:
                return self.fixed_window
            return len(self.arm_data[arm_idx]['real_rewards']) + 1
        t = max(self.global_step, 1)
        d = max(self.d_est, 0.3)
        if self.detect_environment_type():
            base_multiplier = 1.3
            min_window = max(25, int(np.log(t + 1) * 5))
        else:
            base_multiplier = 1.0
            min_window = max(20, int(np.log(t + 1) * 4))
        if d <= 0.7:
            base_window = int((25 + np.sqrt(t) * 2.5) * base_multiplier)
            max_window = min(200, int(t ** 0.55))
        elif d <= 1.0:
            base_window = int((35 + np.sqrt(t) * 3.5) * base_multiplier)
            max_window = min(300, int(t ** 0.6))
        else:
            base_window = int((45 + np.sqrt(t) * 4.5) * base_multiplier)
            max_window = min(400, int(t ** 0.65))
        # For small t the t**p cap falls below the floor. np.clip does not check
        # its bounds and would then return the cap, violating min_window.
        max_window = max(max_window, min_window)
        W_optimal = int(np.clip(base_window, min_window, max_window))
        n_samples = len(self.arm_data[arm_idx]['real_rewards'])
        if n_samples < W_optimal:
            W_optimal = max(min_window, n_samples)
        return int(W_optimal)

    def select_arm(self) -> int:
        """Refresh windows, score every arm and return the arm to pull.

        Unpulled arms score +inf and one of them is chosen uniformly at
        random. Otherwise the arm with the largest ``mu_mixed + cb`` wins.
        Appends UCB values, mean lambda and environment type to ``history``.
        """
        for arm_idx in range(self.n_arms):
            self.arm_data[arm_idx]['window_size'] = self.compute_optimized_window_size(arm_idx)
        ucb_values = []
        lambda_values = []
        for arm_idx in range(self.n_arms):
            arm_data = self.arm_data[arm_idx]
            real_rewards = list(arm_data['real_rewards'])
            if not real_rewards:
                ucb_values.append(float('inf'))
                lambda_values.append(0.0)
                continue
            window = arm_data['window_size']
            recent_rewards = real_rewards[-window:] if window <= len(real_rewards) else real_rewards
            mu_hist = np.mean(recent_rewards)
            # Plain UCB defaults; replaced only if the diffusion blend succeeds.
            mu_mixed, lambda_t = mu_hist, 0.0
            cb = np.sqrt(2 * np.log(self.global_step + 1) / len(recent_rewards))
            if self.use_diffusion and self.models and len(recent_rewards) >= 5:
                try:
                    mu_mixed, lambda_t, cb = self._diffusion_mixed_estimate(arm_idx, mu_hist)
                except Exception as exc:
                    # Best effort: fall back to plain UCB, but do not hide the failure.
                    warnings.warn(
                        f"Diffusion prediction failed for arm {arm_idx}; "
                        f"falling back to plain UCB: {exc!r}"
                    )
            arm_data['lambda_value'] = lambda_t
            lambda_values.append(lambda_t)
            ucb_values.append(mu_mixed + cb)
        self.history['ucb_values'].append(ucb_values.copy())
        self.history['lambda_values'].append(np.mean(lambda_values))
        self.history['environment_type'].append(self.detect_environment_type())
        unexplored = [i for i, v in enumerate(ucb_values) if np.isinf(v)]
        if unexplored:
            return int(np.random.choice(unexplored))
        return int(np.argmax(ucb_values))

    def _diffusion_mixed_estimate(self, arm_idx: int, mu_hist: float) -> Tuple[float, float, float]:
        """Blend ``mu_hist`` with arm ``arm_idx``'s diffusion prediction.

        Returns ``(mu_mixed, lambda_t, confidence_bound)``. Any exception from
        the predictor propagates to the caller, which handles the fallback.
        """
        mu_pred, _var_pred = self.models[arm_idx](arm_idx)
        mu_pred = float(mu_pred.item() if torch.is_tensor(mu_pred) else mu_pred)
        if self.use_theory_guided:
            mse_hist, mse_pred, cov = self.compute_mse_components_theoretical(arm_idx)
            lambda_t = self.compute_optimal_lambda_theoretical(mse_hist, mse_pred, cov)
        else:
            # `is not None` so that fixed_lambda=0.0 (pure history) is honored.
            lambda_t = self.fixed_lambda if self.fixed_lambda is not None else 0.5
        mu_mixed = (1 - lambda_t) * mu_hist + lambda_t * mu_pred
        cb = self.compute_mixed_confidence_bound(arm_idx, lambda_t)
        return mu_mixed, lambda_t, cb

    def compute_mse_components_theoretical(self, arm_idx: int) -> Tuple[float, float, float]:
        """Return theoretical ``(mse_hist, mse_pred, cov_bound)`` for arm ``arm_idx``.

        ``mse_hist`` is a variance term plus a window/time bias term.
        ``mse_pred`` is the sum of the C1/C2/C3 model-error terms.
        ``cov_bound`` is ``rho * sqrt(mse_hist * mse_pred)``, with ``rho``
        capped at 0.95. Returns ``(1.0, 1.0, 0.1)`` when at most 2 samples are
        usable.
        """
        arm_data = self.arm_data[arm_idx]
        real_rewards = list(arm_data['real_rewards'])
        window_size = arm_data['window_size']
        n_samples = min(len(real_rewards), window_size)
        if n_samples <= 2:
            return (1.0, 1.0, 0.1)
        t = max(self.global_step, 1)
        d = max(self.d_est, 0.3)
        eps = 1e-8
        var_stat = max(self.sigma_squared * self.n_arms / max(n_samples, 1), eps)
        window_term = max(window_size ** (2 * d), eps)
        time_term = max(t ** (2 * d), eps)
        bias_env = (self.C_delta ** 2 * self.n_arms ** 2 * window_term) / \
                   (max((d + 1) ** 2, eps) * time_term)
        mse_hist = var_stat + bias_env
        model_bias = max(self.C1 * (t ** (-2 * d)), eps)
        sample_var = max(self.C2 * (n_samples ** (-self.alpha)), eps)
        time_diff = max(self.C3 * (t ** (-d)), eps)
        mse_pred = model_bias + sample_var + time_diff
        log_t = max(np.log(t), 1)
        sqrt_n = max(np.sqrt(n_samples), 1)
        rho_stat = self.C_coupling * np.sqrt(log_t) / sqrt_n
        delta_t = max(t ** (-d), eps)
        rho_decay = self.C_decay / (1 + delta_t * window_size)
        rho_max = min(0.95, rho_stat, rho_decay)
        cov_bound = rho_max * np.sqrt(mse_hist * mse_pred)
        return (max(mse_hist, eps), max(mse_pred, eps), max(cov_bound, 0))

    def compute_optimal_lambda_theoretical(self, mse_hist: float, mse_pred: float, cov_bound: float) -> float:
        """Return the MSE-minimizing mixing weight, clipped to [0.05, 0.95].

        Minimizing ``(1-l)^2 A + l^2 B + 2 l (1-l) C`` gives
        ``l = (A - C) / (A + B - 2C)``. When the denominator is (near) zero,
        the weight leans toward the estimator with the smaller MSE.
        """
        if not self.use_theory_guided and self.fixed_lambda is not None:
            return self.fixed_lambda
        A, B, C = max(mse_hist, 1e-8), max(mse_pred, 1e-8), max(cov_bound, 0)
        denominator = A + B - 2 * C
        if denominator > 1e-6:
            lambda_opt = (A - C) / denominator
        else:
            lambda_opt = 0.1 if A <= B else 0.9
        return np.clip(lambda_opt, 0.05, 0.95)

    def compute_mixed_confidence_bound(self, arm_idx: int, lambda_t: float) -> float:
        """Return the confidence width of the lambda-mixed estimate (at least 0.01).

        Combines the mixed variance ``(1-l)^2 s_h + l^2 s_p + 2 l (1-l) |cov|``
        (square-rooted) with a lambda-weighted history/prediction bias term.
        Note: ``n_samples`` here is the full stored history, not the window.
        """
        n_samples = len(self.arm_data[arm_idx]['real_rewards'])
        if n_samples == 0:
            return 1.0
        t = max(self.global_step, 1)
        d = max(self.d_est, 0.3)
        eps = 1e-8
        log_term = max(np.log(4 * self.n_arms * t / self.delta), 1)
        sigma_h_sq = (2 * self.sigma_squared * log_term) / max(n_samples, 1)
        pred_var_term = max(self.C1 * (t ** (-2 * d)), eps)
        sigma_p_sq = sigma_h_sq + pred_var_term
        sqrt_log_t = np.sqrt(max(np.log(t), 1))
        sqrt_n = np.sqrt(max(n_samples, 1))
        cov_bound = min(
            np.sqrt(sigma_h_sq * sigma_p_sq),
            self.C_coupling * sqrt_log_t / sqrt_n * np.sqrt(sigma_h_sq * sigma_p_sq)
        )
        lambda_t = np.clip(lambda_t, 0, 1)
        sigma_mix_sq = ((1 - lambda_t) ** 2 * sigma_h_sq +
                        lambda_t ** 2 * sigma_p_sq +
                        2 * lambda_t * (1 - lambda_t) * abs(cov_bound))
        window_size = self.arm_data[arm_idx]['window_size']
        bias_hist = self.C_delta * self.n_arms * (window_size ** (d + 1)) / \
                    (max((d + 1) * n_samples, 1)) * max((t ** (-d)), eps)
        bias_pred = self.C3 * max((t ** (-d)), eps)
        bias_total = (1 - lambda_t) * bias_hist + lambda_t * bias_pred
        confidence_bound = np.sqrt(max(sigma_mix_sq, eps)) + bias_total
        return max(confidence_bound, 0.01)

    def update(self, arm: int, reward: float):
        """Record ``reward`` (clipped to [0, 1]) for ``arm`` and advance one step.

        Updates per-arm statistics and history counters. It also feeds the
        arm's diffusion model (best effort) and optionally stores virtual
        rewards for the other arms. Finally it increments ``global_step``.
        """
        reward = np.clip(float(reward), 0.0, 1.0)
        arm_data = self.arm_data[arm]
        arm_data['real_rewards'].append(reward)
        arm_data['timestamps'].append(self.global_step)
        arm_data['last_selected'] = self.global_step
        arm_data['selection_count'] += 1
        arm_data['cumulative_reward'] += reward
        if self.detect_environment_type():
            recent_rewards = list(arm_data['real_rewards'])[-20:]
            if len(recent_rewards) >= 10:
                prob_estimate = np.mean(recent_rewards)
                arm_data['probability_estimates'].append(prob_estimate)
        self.history['arm_counts'][arm] += 1
        self.history['arm_rewards'][arm] += reward
        if self.use_diffusion and self.models:
            try:
                self._update_diffusion_model(arm, reward)
            except Exception as exc:
                # Best effort: keep the bandit running, but surface the failure.
                warnings.warn(f"Diffusion model update failed for arm {arm}: {exc!r}")
        if self.use_virtual_data:
            for other_arm in range(self.n_arms):
                if other_arm != arm:
                    vr = self.generate_virtual_reward_enhanced(other_arm)
                    if vr is not None:
                        self.arm_data[other_arm]['virtual_rewards'].append(vr)
        avg_window = np.mean([self.arm_data[i]['window_size'] for i in range(self.n_arms)])
        self.history['window_sizes'].append(avg_window)
        self.history['d_estimates'].append(self.d_est)
        self.global_step += 1

    def _update_diffusion_model(self, arm: int, reward: float) -> None:
        """Push ``reward`` into ``arm``'s model, add a training pair and maybe train.

        Expects ``reward`` to already be the last element of the arm's
        ``real_rewards``. The training input is the history *preceding* it.
        """
        model = self.models[arm]
        model.update_history(
            torch.tensor([[reward]], device=self.device, dtype=torch.float32)
        )
        rewards = list(self.arm_data[arm]['real_rewards'])
        if len(rewards) >= 8:
            # Exclude the current reward: it is the training target, and putting
            # it in the input would leak the label into the model's history.
            hist_data = rewards[:-1][-self.hist_len:]
            hist_tensor = torch.tensor(hist_data, device=self.device, dtype=torch.float32)
            if len(hist_tensor) < self.hist_len:
                # NOTE(review): right-pads with zeros (unchanged behavior);
                # confirm the predictor expects padding after the history.
                hist_tensor = F.pad(hist_tensor, (0, self.hist_len - len(hist_tensor)))
            model.add_training_sample(
                hist_tensor.unsqueeze(-1),
                torch.tensor([[reward]], device=self.device, dtype=torch.float32),
                0
            )
        if self.global_step % 15 == 0:
            model.train_step()

    def generate_virtual_reward_enhanced(self, arm_idx: int) -> Optional[float]:
        """Return a decayed virtual reward for an arm idle for >= 12 steps, else None.

        "optimistic": recent mean + exploration bonus + 0.1 * std, clipped to
        [0, 1] (0.7 with no data). "mean": recent mean (0.5 with no data).
        Any other method gives 0.5. The result is multiplied by
        ``exp(-time_since / 65)``.
        """
        if not self.use_virtual_data:
            return None
        arm_data = self.arm_data[arm_idx]
        time_since = self.global_step - arm_data['last_selected']
        if time_since < 12:
            return None
        if self.virtual_method == "optimistic":
            if arm_data['real_rewards']:
                recent_rewards = list(arm_data['real_rewards'])[-18:]
                hist_mean = np.mean(recent_rewards)
                hist_std = np.std(recent_rewards) if len(recent_rewards) > 1 else 0.1
                exploration_bonus = np.sqrt(2 * np.log(self.global_step + 1) /
                                            (len(arm_data['real_rewards']) + 1))
                reward = hist_mean + 0.6 * exploration_bonus + 0.1 * hist_std
                reward = np.clip(reward, 0.0, 1.0)
            else:
                reward = 0.7
        elif self.virtual_method == "mean":
            if arm_data['real_rewards']:
                reward = np.mean(list(arm_data['real_rewards'])[-18:])
            else:
                reward = 0.5
        else:
            reward = 0.5
        decay = np.exp(-time_since / 65.0)
        return reward * decay

    def compute_regret(self, true_means: np.ndarray) -> float:
        """Record and return the instantaneous regret of the most recent pull.

        Must be called after ``update`` for that step: the pulled arm is the
        one whose last timestamp equals ``global_step - 1``. Returns 0.0
        (without recording) if no step has run or no such arm is found.
        """
        if self.global_step == 0:
            return 0.0
        optimal_arm = np.argmax(true_means)
        optimal_reward = true_means[optimal_arm]
        last_selected = -1
        for arm in range(self.n_arms):
            if (self.arm_data[arm]['timestamps'] and
                    self.arm_data[arm]['timestamps'][-1] == self.global_step - 1):
                last_selected = arm
                break
        if last_selected == -1:
            return 0.0
        selected_reward = true_means[last_selected]
        regret = optimal_reward - selected_reward
        self.history['instantaneous_regret'].append(regret)
        if self.history['cumulative_regret']:
            cumulative = self.history['cumulative_regret'][-1] + regret
        else:
            cumulative = regret
        self.history['cumulative_regret'].append(cumulative)
        return regret

    def get_diagnostics(self) -> Dict:
        """Return a snapshot of the current windows, lambdas, counts, regret and config."""
        return {
            'global_step': self.global_step,
            'd_estimate': self.d_est,
            'd_history': list(self.d_est_history),
            'environment_type': self.detect_environment_type(),
            'avg_window_size': np.mean([self.arm_data[i]['window_size'] for i in range(self.n_arms)]),
            'window_sizes': [self.arm_data[i]['window_size'] for i in range(self.n_arms)],
            'avg_lambda': np.mean([self.arm_data[i]['lambda_value'] for i in range(self.n_arms)]),
            'lambda_values': [self.arm_data[i]['lambda_value'] for i in range(self.n_arms)],
            'arm_counts': self.history['arm_counts'].tolist(),
            'cumulative_regret': self.history['cumulative_regret'][-1] if self.history['cumulative_regret'] else 0,
            'estimation_confidence': self.history['estimation_confidence'][-5:] if self.history['estimation_confidence'] else [0.4],
            'bias_corrections': self.history['bias_corrections'][-5:] if self.history['bias_corrections'] else [0.0],
            'config': {
                'use_diffusion': self.use_diffusion,
                'use_adaptive_window': self.use_adaptive_window,
                'use_theory_guided': self.use_theory_guided,
                'use_virtual_data': self.use_virtual_data,
                'environment_type': 'binary' if self.detect_environment_type() else 'continuous',
                'optimized_approach': True,
                'perfect_d': True
            }
        }
