import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from collections import deque
from typing import Dict, Optional, List, Tuple
import warnings


# ========== Embedding layer ==========
class BanditEmbedding(nn.Module):
    """Joint embedding of an arm index and a (bandit) step index.

    The two learned embeddings are summed and passed through a LayerNorm.

    Args:
        n_arms: Number of rows in the arm embedding table.
        embed_dim: Width of both embedding tables and of the output.
        max_steps: Number of rows in the step embedding table.
    """

    def __init__(self, n_arms, embed_dim, max_steps=10000):
        super().__init__()
        self.step_embed = nn.Embedding(max_steps, embed_dim)
        self.arm_embed = nn.Embedding(n_arms, embed_dim)
        self.layer_norm = nn.LayerNorm(embed_dim)

    def forward(self, step, arm_idx):
        """Return ``LayerNorm(step_embedding + arm_embedding)``.

        Args:
            step: Integer tensor of step indices; a trailing size-1 dimension
                is squeezed away before the lookup.
            arm_idx: Integer tensor of arm indices, squeezed the same way.

        Returns:
            Tensor whose last dimension is ``embed_dim``.
        """
        # Step counters can grow past the size of the embedding table, which
        # would make the lookup raise IndexError.  Saturate at the last row
        # (and at 0 for negative values) so long-running callers keep working.
        last_step = self.step_embed.num_embeddings - 1
        step_index = torch.clamp(step.squeeze(-1), min=0, max=last_step)

        step_emb = self.step_embed(step_index)
        arm_emb = self.arm_embed(arm_idx.squeeze(-1))
        return self.layer_norm(step_emb + arm_emb)


# ========== Conditional noise-prediction network ==========
class BanditConditionalModel(nn.Module):
    """Conditional noise-prediction network with a companion variance head.

    The original author's note says this is meant to match the "Model623"
    implementation; that model is not visible here, so treat the claim as
    unverified.

    Args:
        timesteps: Number of diffusion timesteps (rows of the time embedding).
        n_arms: Number of bandit arms (forwarded to ``BanditEmbedding``).
        input_dim: Feature width of the reward history / noisy sample.
        hidden_dim: Width of the embeddings and hidden layers.
    """

    def __init__(self, timesteps, n_arms, input_dim, hidden_dim=128):
        super().__init__()
        self.time_embed = nn.Embedding(timesteps, hidden_dim)
        self.context_embed = BanditEmbedding(n_arms, hidden_dim)
        self.hist_projection = nn.Linear(input_dim, hidden_dim)

        # Noise head input = [y_t, y0_hat, gx] (3 * input_dim) concatenated
        # with [time_emb, context_emb, hist_mean] (3 * hidden_dim).
        # With the defaults (input_dim=1, hidden_dim=128) this is 387, the
        # value that used to be hard-coded.
        noise_in_dim = 3 * input_dim + 3 * hidden_dim
        self.noise_predictor = nn.Sequential(
            nn.Linear(noise_in_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim)
        )

        # Variance head input = [context_emb, time_emb, hist_mean]
        # (3 * hidden_dim) plus five window statistics (var, max, min, q25,
        # q75), each input_dim wide.  389 with the defaults, as before.
        sigma_in_dim = 3 * hidden_dim + 5 * input_dim
        self.variance_estimator = nn.Sequential(
            nn.Linear(sigma_in_dim, hidden_dim * 2),
            nn.GELU(),
            nn.LayerNorm(hidden_dim * 2),
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, input_dim),
            nn.Softplus()
        )

    def forward(self, hist_rewards, y_t, y0_hat, gx, t, step, arm_idx):
        """Predict the diffusion noise and a positive per-feature sigma.

        Args:
            hist_rewards: Reward window; statistics are reduced over dim 1, so
                it is expected to be (batch, window, input_dim) — inferred
                from the reductions, confirm against callers.
            y_t: Noisy sample; a 2-D tensor is expanded to (batch, 1, feat).
            y0_hat: Optional tensor shaped like ``y_t``; zeros when ``None``.
            gx: Optional tensor shaped like ``y_t``; zeros when ``None``.
            t: Integer diffusion timestep indices for ``time_embed``.
            step: Integer step indices for the context embedding.
            arm_idx: Integer arm indices for the context embedding.

        Returns:
            ``(noise_pred, sigma)`` — the noise head output and the Softplus
            output of the variance head.
        """
        y0_hat = torch.zeros_like(y_t) if y0_hat is None else y0_hat
        gx = torch.zeros_like(y_t) if gx is None else gx

        time_emb = self.time_embed(t)
        context_emb = self.context_embed(step, arm_idx)

        # Project the window mean into the hidden space.
        hist_mean = self.hist_projection(hist_rewards.mean(dim=1, keepdim=True))

        # Make sure every tensor joined below is [batch, 1, feature].
        if y_t.dim() == 2:
            y_t = y_t.unsqueeze(1)
        if y0_hat.dim() == 2:
            y0_hat = y0_hat.unsqueeze(1)
        if gx.dim() == 2:
            gx = gx.unsqueeze(1)

        combined = torch.cat(
            [y_t, y0_hat, gx, time_emb.unsqueeze(1), context_emb.unsqueeze(1), hist_mean],
            dim=-1,
        )
        noise_pred = self.noise_predictor(combined)

        # Window statistics for the variance head, each reduced over dim 1.
        window_var = torch.var(hist_rewards, dim=1, unbiased=False)
        window_max = hist_rewards.max(dim=1).values
        window_min = hist_rewards.min(dim=1).values

        # Lower-index empirical quartiles.  int(q * (n - 1)) always lies in
        # [0, n - 1] for n >= 1, so no extra clamping or single-element
        # special case is required.
        hist_sorted, _ = torch.sort(hist_rewards, dim=1)
        last_index = hist_rewards.shape[1] - 1
        q25 = hist_sorted[:, int(0.25 * last_index), :]
        q75 = hist_sorted[:, int(0.75 * last_index), :]

        sigma_input = torch.cat([
            context_emb,
            time_emb,
            hist_mean.squeeze(1),
            window_var,
            window_max,
            window_min,
            q25,
            q75
        ], dim=-1)

        sigma = self.variance_estimator(sigma_input)
        return noise_pred, sigma


# ========== Main predictor class ==========
class BanditDDPMPredictor720(nn.Module):
    """
    修复后的预测器：每个臂独立追踪历史数据和预测
    """

    def __init__(self, n_arms=1, hist_len=100, timesteps=20, device="cuda", learning_rate=1e-3):
        super().__init__()
        self.n_arms = n_arms
        self.hist_len = hist_len
        self.timesteps = timesteps
        self.device = device

        # 1. 扩散过程参数（与Model623相同）
        self.betas = self._cosine_beta_schedule(timesteps)
        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        self.alphas_cumprod_prev = F.pad(self.alphas_cumprod[:-1], (1, 0), value=1.0)
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - self.alphas_cumprod)
        self.sqrt_recip_alphas = torch.sqrt(1.0 / self.alphas)

        # 2. 条件扩散模型（与Model623相同）
        self.denoise_model = BanditConditionalModel(
            timesteps, n_arms, 1, hidden_dim=128  # 修复：input_dim=1，因为每个臂的历史数据是1维的
        )

        # 3. ✅ 修复：每个臂独立的历史数据管理
        self.hist_buffers = [deque(maxlen=hist_len) for _ in range(n_arms)]
        self.step_counters = [0 for _ in range(n_arms)]
        self.training_buffers = [deque(maxlen=500) for _ in range(n_arms)]
        self.training_losses = [deque(maxlen=100) for _ in range(n_arms)]

        # 4. ✅ 修复：每个臂独立的预测历史和校正因子
        self.pred_mu_histories = [[] for _ in range(n_arms)]
        self.variance_correction_factors = [1.0 for _ in range(n_arms)]
        self.correction_decay = 0.95

        # 5. 优化器
        self.optimizer = torch.optim.Adam(self.parameters(), lr=learning_rate)

        # 6. ✅ 兼容性：保留全局步数计数器（用于与旧接口兼容）
        self.step_counter = 0

        self.to(device)

    def _cosine_beta_schedule(self, timesteps, s=0.008):
        """标准余弦调度（与Model623相同）"""
        steps = timesteps + 1
        x = torch.linspace(0, timesteps, steps)
        alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
        alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
        betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
        return torch.clip(betas, 0, 0.999).to(self.device)

    def _extract(self, arr, timesteps, broadcast_shape):
        """从序列中提取对应时间步的值（与Model623相同）"""
        device = arr.device
        res = arr.to(device=device)[timesteps].float()
        while len(res.shape) < len(broadcast_shape):
            res = res[..., None]
        return res.expand(broadcast_shape)

    def q_sample(self, y0, t, noise):
        """前向扩散过程（与Model623相同）"""
        sqrt_alphas_cumprod_t = self._extract(self.sqrt_alphas_cumprod, t, y0.shape)
        sqrt_one_minus_alphas_cumprod_t = self._extract(self.sqrt_one_minus_alphas_cumprod, t, y0.shape)
        return sqrt_alphas_cumprod_t * y0 + sqrt_one_minus_alphas_cumprod_t * noise

    def _predict_mean(self, y_t, t, noise_pred):
        """去噪预测（与Model623相同）"""
        sqrt_recip_alphas_t = self._extract(self.sqrt_recip_alphas, t, y_t.shape)
        sqrt_one_minus_alphas_cumprod_t = self._extract(self.sqrt_one_minus_alphas_cumprod, t, y_t.shape)
        return sqrt_recip_alphas_t * (y_t - sqrt_one_minus_alphas_cumprod_t * noise_pred)

    def update_history(self, new_rewards, arm_idx=0):
        """✅ 修复：更新指定臂的历史数据"""
        if isinstance(new_rewards, torch.Tensor):
            new_rewards = new_rewards.item()

        # 只更新对应臂的历史
        self.hist_buffers[arm_idx].append(float(new_rewards))
        self.step_counters[arm_idx] += 1

        # 保持兼容性：同时更新全局计数器
        self.step_counter += 1

    def predict_next_reward(self, hist_rewards=None, arm_idx=0, window_size=None):
        """
        ✅ 修复：使用臂特定的历史数据进行预测
        """
        if window_size is None:
            window_size = min(20, len(self.hist_buffers[arm_idx]))

        if len(self.hist_buffers[arm_idx]) < max(5, window_size):
            return 0.5, 0.1

        # ✅ 使用臂特定的历史数据
        arm_history = list(self.hist_buffers[arm_idx])
        hist_data = arm_history[-window_size:] if window_size <= len(arm_history) else arm_history
        hist = torch.tensor(hist_data, device=self.device).unsqueeze(0).unsqueeze(-1)

        arm_idx_tensor = torch.tensor([[arm_idx]], device=self.device)
        step_tensor = torch.tensor([[self.step_counters[arm_idx]]], device=self.device)
        y_t = torch.zeros(1, 1, 1, device=self.device)
        t = torch.tensor([0], device=self.device)

        with torch.no_grad():
            # 1. 准备输入（使用臂特定的历史）
            y0 = hist[:, -1].clone()  # [1, 1] - 2维
            noise = torch.randn_like(y0)  # [1, 1] - 2维
            y_t_sampled = self.q_sample(y0, t, noise)  # 可能变成3维

            # 修复：确保y_t_sampled有正确的3维形状 [1, 1, 1]
            if y_t_sampled.dim() == 2:
                y_t_sampled = y_t_sampled.unsqueeze(1)
            elif y_t_sampled.dim() > 3:
                y_t_sampled = y_t_sampled.squeeze()

            # 2. 条件去噪预测
            noise_pred, sigma_pred = self.denoise_model(
                hist, y_t_sampled, None, None, t, step_tensor, arm_idx_tensor
            )

            # 3. 计算预测均值
            mu_pred = self._predict_mean(y_t_sampled, t, noise_pred)

            # 4. ✅ 方差校准（使用臂特定的方差校正因子）
            true_var = torch.var(hist[0, -window_size:], unbiased=False).item()
            pred_var = sigma_pred.mean().item()

            if pred_var > 1e-6:
                correction = true_var / pred_var
                self.variance_correction_factors[arm_idx] = (
                        self.correction_decay * self.variance_correction_factors[arm_idx] +
                        (1 - self.correction_decay) * correction
                )
            sigma_pred = sigma_pred * self.variance_correction_factors[arm_idx]

            # 5. ✅ 记录预测值（臂特定）
            self.pred_mu_histories[arm_idx].append(mu_pred.mean().item())

            return mu_pred[0, 0].item(), sigma_pred[0, 0].item()

    def add_training_sample(self, hist_state, next_reward, arm_idx=0):
        """✅ 修复：添加训练样本到对应臂的缓冲区"""
        # 处理torch.Tensor类型的输入
        if isinstance(hist_state, torch.Tensor):
            hist_state = hist_state.squeeze().cpu().numpy()
        if isinstance(next_reward, torch.Tensor):
            next_reward = next_reward.item()

        if len(hist_state) >= 5:
            self.training_buffers[arm_idx].append((hist_state, next_reward))

    def train_step(self, arm_idx=0):
        """✅ 修复：使用臂特定的训练数据进行训练"""
        if len(self.training_buffers[arm_idx]) < 10:
            return 0.0

        # 简化的训练逻辑
        batch_size = min(16, len(self.training_buffers[arm_idx]))
        batch = np.random.choice(len(self.training_buffers[arm_idx]), batch_size, replace=False)

        total_loss = 0.0

        for idx in batch:
            hist_state, target_reward = self.training_buffers[arm_idx][idx]

            # 转换为张量
            hist_tensor = torch.tensor(hist_state, device=self.device).float().unsqueeze(0).unsqueeze(-1)
            target_tensor = torch.tensor([target_reward], device=self.device).float().unsqueeze(0).unsqueeze(0)

            # 随机时间步
            t = torch.randint(0, self.timesteps, (1,), device=self.device)

            # 添加噪声
            noise = torch.randn_like(target_tensor)
            y_t = self.q_sample(target_tensor, t, noise)

            # 预测噪声
            step_tensor = torch.tensor([[self.step_counters[arm_idx]]], device=self.device)
            arm_tensor = torch.tensor([[arm_idx]], device=self.device)

            noise_pred, _ = self.denoise_model(
                hist_tensor, y_t, None, None, t, step_tensor, arm_tensor
            )

            # 计算损失
            loss = F.mse_loss(noise_pred, noise)
            total_loss += loss.item()

            # 反向传播
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

        avg_loss = total_loss / batch_size
        self.training_losses[arm_idx].append(avg_loss)

        return avg_loss

    def __call__(self, arm_idx=0):
        """✅ 修复：调用接口，返回臂特定的预测结果"""
        # 使用默认窗口大小，如果没有足够的历史数据则返回默认值
        default_window_size = min(20, len(self.hist_buffers[arm_idx]))
        if default_window_size < 5:
            return 0.5, 0.1

        return self.predict_next_reward(arm_idx=arm_idx, window_size=default_window_size)

    def get_diagnostics(self, arm_idx=0):
        """✅ 修复：获取臂特定的诊断信息"""
        recent_loss = np.mean(list(self.training_losses[arm_idx])[-10:]) if self.training_losses[arm_idx] else 0.0

        return {
            'arm_idx': arm_idx,
            'window_size': len(self.hist_buffers[arm_idx]),
            'hist_len': len(self.hist_buffers[arm_idx]),
            'recent_loss': recent_loss,
            'step_counter': self.step_counters[arm_idx],
            'training_samples': len(self.training_buffers[arm_idx]),
            'variance_correction': self.variance_correction_factors[arm_idx],
            'recent_predictions': self.pred_mu_histories[arm_idx][-5:] if self.pred_mu_histories[arm_idx] else [],
            'recent_history': list(self.hist_buffers[arm_idx])[-10:] if self.hist_buffers[arm_idx] else []
        }

    def get_global_diagnostics(self):
        """获取所有臂的全局诊断信息"""
        diagnostics = {}
        for arm_idx in range(self.n_arms):
            diagnostics[f'arm_{arm_idx}'] = self.get_diagnostics(arm_idx)

        diagnostics['global_step'] = self.step_counter
        diagnostics['total_arms'] = self.n_arms

        return diagnostics

    def forward(self, x, x_mark, y_t, y0_hat, gx, t, window_size=None):
        """
        ✅ 修复：前向传播（与Model623兼容，但使用臂特定数据）
        """
        arm_idx = x.item() if torch.is_tensor(x) else x

        if window_size is None:
            window_size = min(20, len(self.hist_buffers[arm_idx]))

        if len(self.hist_buffers[arm_idx]) < window_size:
            return torch.tensor([[0.5]], device=self.device), torch.tensor([[0.1]], device=self.device)

        # ✅ 使用臂特定的历史数据
        arm_history = list(self.hist_buffers[arm_idx])
        hist_data = arm_history[-window_size:] if window_size <= len(arm_history) else arm_history
        hist = torch.tensor(hist_data, device=self.device).unsqueeze(0).unsqueeze(-1)

        step = x_mark

        # 1. 准备输入（与Model623相同）
        y0 = hist[:, -1].clone()
        noise = torch.randn_like(y0)
        y_t_sampled = self.q_sample(y0, t, noise)

        # 2. 条件去噪预测
        noise_pred, sigma_pred = self.denoise_model(
            hist, y_t_sampled, None, gx, t, step, torch.tensor([[arm_idx]], device=self.device)
        )

        # 3. 计算预测均值
        mu_pred = self._predict_mean(y_t_sampled, t, noise_pred)

        # 4. ✅ 方差校准（使用臂特定的校正因子）
        true_var = torch.var(hist[0, -window_size:], unbiased=False).item()
        pred_var = sigma_pred.mean().item()

        if pred_var > 1e-6:
            correction = true_var / pred_var
            self.variance_correction_factors[arm_idx] = (
                    self.correction_decay * self.variance_correction_factors[arm_idx] +
                    (1 - self.correction_decay) * correction
            )
        sigma_pred = sigma_pred * self.variance_correction_factors[arm_idx]

        # 5. ✅ 记录预测值（臂特定）
        self.pred_mu_histories[arm_idx].append(mu_pred.mean().item())

        return mu_pred, sigma_pred