import torch
import numpy as np

class DynamicUncertaintySelector:
    """Select the samples whose max-softmax confidence fluctuates the most.

    For every sample a ring buffer holds its last ``window_size`` recorded
    max-softmax probabilities. Once ``select`` has been called
    ``start_epoch`` times, the samples with the largest standard deviation
    over that buffer are returned.

    Args:
        num_samples: Size of the full dataset (rows of the history buffer).
        window_size: Number of values kept per sample (k).
        start_epoch: Number of ``select`` calls (J) before filtering starts;
            earlier calls return every index.
        keep_ratio: Fraction of ``num_samples`` returned once filtering starts.
        update: Stored as ``self.updatef``; not read inside this class.
        device: Device for the history buffers and returned indices.
            Defaults to ``"cuda"`` (the previous hard-coded behaviour).
    """

    def __init__(
        self,
        num_samples,
        window_size=5,
        start_epoch=10,
        keep_ratio=0.6,
        update=1,
        device="cuda",
    ):
        self.num_samples = num_samples
        self.window_size = window_size  # k
        self.start_epoch = start_epoch  # J
        self.keep_ratio = keep_ratio
        # Named ``updatef`` because ``self.update`` would shadow the method.
        self.updatef = update
        self.device = device

        # Sliding window [N, k] of recorded max-softmax probabilities.
        self.pred_history = torch.zeros(num_samples, window_size, device=device)
        # Next write position in each sample's ring buffer.
        self.ptr = torch.zeros(num_samples, dtype=torch.long, device=device)
        # Number of ``select`` calls made so far.
        self.epoch = 0

    def update(self, sample_indices, logits):
        """Record each batch sample's max-softmax probability in its window.

        Args:
            sample_indices: Tensor [B] (or sequence of ints) with the dataset
                index of every sample in the batch.
            logits: Tensor [B, C] of raw model outputs for those samples.
        """
        # The history is pure bookkeeping. Without no_grad, writing a value
        # that requires grad into the buffer attaches an ever-growing autograd
        # graph to it (keeping every batch's graph alive) -> memory leak.
        with torch.no_grad():
            probs = torch.softmax(logits, dim=1)
            max_probs = torch.max(probs, dim=1).values  # [B]

            for i, idx in enumerate(sample_indices):
                pos = self.ptr[idx].item()
                self.pred_history[idx, pos] = max_probs[i]
                self.ptr[idx] = (pos + 1) % self.window_size

    def select(self):
        """Advance the epoch counter and return the indices to keep.

        The uncertainty of a sample is the standard deviation of its recorded
        max-softmax probabilities.

        Returns:
            Tensor of indices: all ``num_samples`` indices while
            ``epoch < start_epoch``, otherwise the
            ``int(keep_ratio * num_samples)`` most uncertain samples.
        """
        self.epoch += 1
        if self.epoch < self.start_epoch:
            return torch.arange(self.num_samples, device=self.device)

        with torch.no_grad():
            uncertainties = torch.std(self.pred_history, dim=1)  # [N]
            num_keep = int(self.keep_ratio * self.num_samples)
            return torch.topk(uncertainties, num_keep).indices
        


class DynamicEntropySelector:
    """Select the samples whose prediction entropy fluctuates the most.

    For every sample, ring buffers hold the last ``window_size`` recorded
    softmax entropies, argmax labels and max probabilities. Once ``select``
    has been called ``start_epoch`` times, the samples with the largest
    standard deviation of entropy are returned.

    Args:
        num_samples: Size of the full dataset (rows of the history buffers).
        window_size: Number of values kept per sample.
        start_epoch: Number of ``select`` calls before filtering starts;
            earlier calls return every index.
        keep_ratio: Fraction of ``num_samples`` returned once filtering starts.
        update: Stored as ``self.updatef``; not read inside this class.
        device: Device for the history buffers and returned indices.
            Defaults to ``"cuda"`` (the previous hard-coded behaviour).
    """

    def __init__(
        self,
        num_samples,
        window_size=5,
        start_epoch=10,
        keep_ratio=0.6,
        update=1,
        device="cuda",
    ):
        self.num_samples = num_samples
        self.window_size = window_size
        self.start_epoch = start_epoch
        self.keep_ratio = keep_ratio
        # Named ``updatef`` because ``self.update`` would shadow the method.
        self.updatef = update
        self.device = device

        # Ring buffers [N, k]: softmax entropy, argmax label, max probability.
        self.entropy_history = torch.zeros(num_samples, window_size, device=device)
        self.pseudo_history = torch.zeros(num_samples, window_size, device=device)
        self.conf_history = torch.zeros(num_samples, window_size, device=device)
        # Next write position in each sample's ring buffer.
        self.ptr = torch.zeros(num_samples, dtype=torch.long, device=device)
        # Number of ``select`` calls made so far.
        self.epoch = 0

    def update(self, sample_indices, logits):
        """Record entropy, argmax label and max probability for a batch.

        Args:
            sample_indices: Tensor [B] (or sequence of ints) with the dataset
                index of every sample in the batch.
            logits: Tensor [B, C] of raw model outputs for those samples.
        """
        # Bookkeeping only: without no_grad, storing a grad-requiring entropy
        # into the buffer attaches an ever-growing autograd graph to it.
        with torch.no_grad():
            probs = torch.softmax(logits, dim=1) + 1e-8  # avoid log(0)
            entropy = -torch.sum(probs * torch.log(probs), dim=1)  # [B]

            for i, idx in enumerate(sample_indices):
                pos = self.ptr[idx].item()
                self.entropy_history[idx, pos] = entropy[i]
                self.pseudo_history[idx, pos] = probs[i].argmax().item()  # pseudo-label
                self.conf_history[idx, pos] = probs[i].max().item()
                self.ptr[idx] = (pos + 1) % self.window_size

    def select(self):
        """Advance the epoch counter and return the indices to keep.

        Returns:
            Tensor of indices: all ``num_samples`` indices while
            ``epoch < start_epoch``, otherwise the
            ``int(keep_ratio * num_samples)`` samples with the highest
            standard deviation of recorded entropy.
        """
        self.epoch += 1

        if self.epoch < self.start_epoch:
            return torch.arange(self.num_samples, device=self.device)

        with torch.no_grad():
            uncertainties = torch.std(self.entropy_history, dim=1)  # [N]
            num_keep = int(self.keep_ratio * self.num_samples)
            return torch.topk(uncertainties, num_keep).indices
        
        
class DynamicPseudoSelector:
    """Select the samples whose pseudo-label (argmax) is the most stable.

    For every sample, ring buffers hold the last ``window_size`` recorded
    softmax entropies, argmax labels and max probabilities. Once ``select``
    has been called ``start_epoch`` times, the samples with the fewest
    pseudo-label changes between consecutive recorded updates are returned.

    Args:
        num_samples: Size of the full dataset (rows of the history buffers).
        window_size: Number of values kept per sample.
        start_epoch: Number of ``select`` calls before filtering starts;
            earlier calls return every index.
        keep_ratio: Fraction of ``num_samples`` returned once filtering starts.
        update: Stored as ``self.updatef``; not read inside this class.
        device: Device for the history buffers and returned indices.
            Defaults to ``"cuda"`` (the previous hard-coded behaviour).
    """

    def __init__(
        self,
        num_samples,
        window_size=5,
        start_epoch=10,
        keep_ratio=0.6,
        update=1,
        device="cuda",
    ):
        self.num_samples = num_samples
        self.window_size = window_size
        self.start_epoch = start_epoch
        self.keep_ratio = keep_ratio
        # Named ``updatef`` because ``self.update`` would shadow the method.
        self.updatef = update
        self.device = device

        # Ring buffers [N, k]: softmax entropy, argmax label, max probability.
        self.entropy_history = torch.zeros(num_samples, window_size, device=device)
        # -1 marks a slot that has not been written yet, so the validity mask
        # in ``select`` can skip it (argmax labels are always >= 0).
        self.pseudo_history = torch.full(
            (num_samples, window_size), -1.0, device=device
        )
        self.conf_history = torch.zeros(num_samples, window_size, device=device)
        # Next write position in each sample's ring buffer; once a row is
        # full, this slot also holds that row's oldest entry.
        self.ptr = torch.zeros(num_samples, dtype=torch.long, device=device)
        # Number of ``select`` calls made so far.
        self.epoch = 0

    def update(self, sample_indices, logits):
        """Record entropy, argmax label and max probability for a batch.

        Args:
            sample_indices: Tensor [B] (or sequence of ints) with the dataset
                index of every sample in the batch.
            logits: Tensor [B, C] of raw model outputs for those samples.
        """
        # Bookkeeping only: without no_grad, storing a grad-requiring entropy
        # into the buffer attaches an ever-growing autograd graph to it.
        with torch.no_grad():
            probs = torch.softmax(logits, dim=1) + 1e-8  # avoid log(0)
            entropy = -torch.sum(probs * torch.log(probs), dim=1)  # [B]

            for i, idx in enumerate(sample_indices):
                pos = self.ptr[idx].item()
                self.entropy_history[idx, pos] = entropy[i]
                self.pseudo_history[idx, pos] = probs[i].argmax().item()  # pseudo-label
                self.conf_history[idx, pos] = probs[i].max().item()
                self.ptr[idx] = (pos + 1) % self.window_size

    def select(self):
        """Advance the epoch counter and return the indices to keep.

        Returns:
            Tensor of indices: all ``num_samples`` indices while
            ``epoch < start_epoch``, otherwise the
            ``int(keep_ratio * num_samples)`` samples with the fewest
            pseudo-label changes between consecutive recorded updates.
        """
        self.epoch += 1

        if self.epoch < self.start_epoch:
            return torch.arange(self.num_samples, device=self.device)

        with torch.no_grad():
            k = self.window_size
            # Unroll each ring buffer into oldest -> newest order so adjacent
            # columns really are consecutive updates (slot order is not).
            offsets = torch.arange(k, device=self.pseudo_history.device)
            order = (self.ptr.unsqueeze(1) + offsets) % k  # [N, k]
            history = self.pseudo_history.gather(1, order)

            prev, nxt = history[:, :-1], history[:, 1:]
            changed = prev != nxt
            valid = (prev >= 0) & (nxt >= 0)  # ignore never-written slots
            label_changes = (changed & valid).float().sum(dim=1)  # [N]

            num_keep = int(self.keep_ratio * self.num_samples)
            # Fewer changes is better: negate so topk picks the smallest.
            return torch.topk(-label_changes, num_keep).indices
        