import torch
import torch.nn.functional as F
import torch.nn as nn
import math
import numpy as np

# Original ce_loss function (added explicitly for consistency_loss compatibility)
def ce_loss(logits, targets, use_hard_labels=True, reduction='mean'):
    """
    Cross-entropy between model logits and hard or soft targets.

    Args:
        logits: Model output logits; softmax is taken over the last dim.
        targets: Class indices when ``use_hard_labels`` is True, otherwise
            per-class weights with the same shape as ``logits``.
        use_hard_labels: Selects the hard-label (NLL) or soft-label path.
        reduction: 'none', 'mean' or 'sum'. On the soft-label path any
            other value leaves the per-sample losses unreduced.

    Returns:
        The reduced loss, or per-sample losses when nothing is reduced.
    """
    log_probs = F.log_softmax(logits, dim=-1)

    if use_hard_labels:
        # Index targets: negative log-likelihood of the labelled class.
        return F.nll_loss(log_probs, targets, reduction=reduction)

    # Soft targets: cross-entropy -sum_c t_c * log p_c for every sample.
    per_sample = -torch.sum(targets * log_probs, dim=-1)
    if reduction == 'mean':
        return per_sample.mean()
    if reduction == 'sum':
        return per_sample.sum()
    return per_sample

# Original classes and functions (unchanged)
class partial_loss(nn.Module):
    """Confidence-weighted cross-entropy over candidate label sets.

    A per-sample confidence table is built once from the candidate-label
    matrix (each row divided by its own sum) and can later be overwritten
    row by row through ``confidence_update``.
    """

    def __init__(self, train_givenY):
        """
        Args:
            train_givenY: Candidate-label matrix indexed as (sample, class).
                A row whose entries sum to zero yields NaN confidences.
        """
        super().__init__()
        print('Calculating uniform targets...')
        row_totals = train_givenY.sum(dim=1, keepdim=True)
        confidence = train_givenY.float() / row_totals
        # The table used to be moved with an unconditional .cuda(), which
        # raises on CPU-only hosts. Keep the GPU placement when CUDA exists
        # and otherwise leave the table on the input's device.
        target_device = 'cuda' if torch.cuda.is_available() else confidence.device
        self.confidence = confidence.to(target_device)

    def forward(self, outputs, index, targets=None):
        """
        Args:
            outputs: Logits; log-softmax is taken over dim 1.
            index: Row indices into the confidence table for this batch.
            targets: Optional explicit per-class weights; when given they
                replace the stored confidences for this call.

        Returns:
            (mean loss over the batch, per-sample loss vector).
        """
        logsm_outputs = F.log_softmax(outputs, dim=1)
        if targets is None:
            weights = self.confidence[index, :].detach()
        else:
            weights = targets.detach()
        loss_vec = -(logsm_outputs * weights).sum(dim=1)
        return loss_vec.mean(), loss_vec

    @torch.no_grad()
    def confidence_update(self, temp_un_conf, batch_index):
        """Overwrite the stored confidences of ``batch_index`` in place."""
        self.confidence[batch_index, :] = temp_un_conf
        return None


class SupConLoss(nn.Module):
    """Supervised contrastive loss with a MoCo-style fallback.

    With a positive-pair ``mask`` the first ``batch_size`` rows of
    ``features`` act as anchors contrasted against every row. Without a
    mask the rows are read as ``[queries | keys | queue]`` and an
    InfoNCE-style cross-entropy is computed with the key as the positive.
    """

    def __init__(self, temperature=0.07, base_temperature=0.07):
        super().__init__()
        self.temperature = temperature
        self.base_temperature = base_temperature

    def forward(self, features, mask=None, batch_size=-1):
        """
        Args:
            features: 2-D feature matrix, one embedding per row.
            mask: Optional positive-pair indicator with ``batch_size`` rows
                and one column per row of ``features``.
            batch_size: Number of anchor rows. The default of -1 is not
                usable as-is; callers are expected to pass the real value.

        Returns:
            Scalar loss tensor.
        """
        # Follow the features' own device instead of guessing 'cuda'/'cpu',
        # so CPU runs and non-default GPUs both work.
        device = features.device

        if mask is not None:
            mask = mask.float().detach().to(device)
            anchor_dot_contrast = torch.div(
                torch.matmul(features[:batch_size], features.T),
                self.temperature)
            # Subtract the row maximum for numerical stability; detached so
            # it does not take part in the gradient.
            logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
            logits = anchor_dot_contrast - logits_max.detach()

            # Zero the (i, i) entries so an anchor never contrasts with itself.
            logits_mask = torch.scatter(
                torch.ones_like(mask), 1,
                torch.arange(batch_size, device=device).view(-1, 1), 0)
            mask = mask * logits_mask

            exp_logits = torch.exp(logits) * logits_mask
            log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True) + 1e-12)
            # NOTE(review): an anchor with no positives left in ``mask``
            # divides by zero here and produces NaN -- confirm that callers
            # always supply at least one positive per anchor.
            mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1)
            loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos
            loss = loss.mean()
        else:
            q = features[:batch_size]
            k = features[batch_size:batch_size*2]
            queue = features[batch_size*2:]
            l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1)
            l_neg = torch.einsum('nc,kc->nk', [q, queue])
            logits = torch.cat([l_pos, l_neg], dim=1)
            logits /= self.temperature
            # The positive logit sits in column 0. Labels are created on the
            # logits' device; an unconditional .cuda() fails on CPU.
            labels = torch.zeros(logits.shape[0], dtype=torch.long, device=logits.device)
            loss = F.cross_entropy(logits, labels)

        return loss


def consistency_loss(logits_w, logits_s, sin_label_idx, name='ce', T=1.0, p_cutoff=0.0, use_hard_labels=True):
    """
    Consistency loss between weakly and strongly augmented predictions.

    Args:
        logits_w: Logits of the weak view; always detached (teacher signal).
        logits_s: Logits of the strong view; gradients flow through these.
        sin_label_idx: Per-sample class index used both as the hard target
            and to pick the confidence that is compared with ``p_cutoff``.
        name: 'L2' (MSE between softmax outputs) or 'ce'.
        T: Temperature applied to ``logits_w`` for soft pseudo labels.
        p_cutoff: Samples whose weak-view probability of ``sin_label_idx``
            is below this value are masked out.
        use_hard_labels: Use ``sin_label_idx`` (True) or the softened
            weak-view distribution (False) as the target.

    Returns:
        For 'L2' a scalar loss; for 'ce' a tuple
        ``(masked loss mean, fraction of unmasked samples)``.
    """
    assert name in ['ce', 'L2']
    logits_w = logits_w.detach()
    if name == 'L2':
        assert logits_w.size() == logits_s.size()
        pred_w = torch.softmax(logits_w, dim=1)
        # The strong-view prediction must stay attached to the graph;
        # detaching it made this loss a constant with no gradient.
        pred_s = torch.softmax(logits_s, dim=1)
        return F.mse_loss(pred_s, pred_w, reduction='mean')
    elif name == 'ce':
        weak_probs = torch.softmax(logits_w, dim=-1)
        selected_probs = weak_probs[range(weak_probs.shape[0]), sin_label_idx]
        mask = selected_probs.ge(p_cutoff).float()
        if use_hard_labels:
            per_sample = ce_loss(logits_s, sin_label_idx, use_hard_labels, reduction='none')
        else:
            pseudo_label = torch.softmax(logits_w/T, dim=-1)
            # Request per-sample losses so the mask is applied sample by
            # sample; with the default mean reduction the mask only scaled
            # an already averaged scalar.
            per_sample = ce_loss(logits_s, pseudo_label, use_hard_labels, reduction='none')
        masked_loss = per_sample * mask
        return masked_loss.mean(), mask.mean()
    else:
        raise Exception('Not Implemented consistency_loss')


def jin_lossb(outputs, partialY):
    """Per-sample bounded power loss with exponent q = 0.7.

    The candidate matrix is row-normalised and used to average the
    q-th power of the softmax probabilities; the loss is
    ``(1 - weighted average) / q``.
    """
    q = 0.7
    label_weights = partialY / partialY.sum(dim=1, keepdim=True)
    powered_probs = torch.pow(F.softmax(outputs, dim=1), q)
    weighted = (powered_probs * label_weights).sum(dim=1)
    return (1 - weighted) / q


def jin_lossu(outputs, partialY):
    """Per-sample cross-entropy against the row-normalised candidate set."""
    label_weights = partialY / partialY.sum(dim=1, keepdim=True)
    log_probs = F.log_softmax(outputs, dim=1)
    weighted_log_probs = log_probs * label_weights
    return -weighted_log_probs.sum(dim=1)


def cour_lossb(outputs, partialY):
    """Batch-mean loss built from sigmoid terms on softmax probabilities.

    The candidate term is the sigmoid of the mean candidate probability;
    the non-candidate term sums ``sigmoid(-p)`` over non-candidate classes.
    """
    probs = F.softmax(outputs, dim=1)
    candidate_mean = (probs * partialY).sum(dim=1) / partialY.sum(dim=1)
    candidate_term = torch.sigmoid(candidate_mean)
    non_candidate_term = (torch.sigmoid(-probs) * (1 - partialY)).sum(dim=1)
    return (candidate_term + non_candidate_term).mean()


def squared_hinge_loss(z):
    """Return ``max(0, 1 - z)`` squared, element-wise."""
    margin_gap = (1 - z).clamp(min=0)
    return margin_gap * margin_gap


def cour_lossu(outputs, partialY):
    """Batch-mean loss built from squared-hinge terms on softmax probabilities.

    The candidate term is the squared hinge of the mean candidate
    probability; the non-candidate term sums the squared hinge of ``-p``
    over non-candidate classes.
    """
    probs = F.softmax(outputs, dim=1)
    candidate_mean = (probs * partialY).sum(dim=1) / partialY.sum(dim=1)
    candidate_term = squared_hinge_loss(candidate_mean)
    non_candidate_term = (squared_hinge_loss(-probs) * (1 - partialY)).sum(dim=1)
    return (candidate_term + non_candidate_term).mean()


def mae_loss(outputs, partialY):
    """Per-sample L1 distance between softmax probabilities and ``partialY``."""
    probs = F.softmax(outputs, dim=1)
    abs_errors = F.l1_loss(probs, partialY.float(), reduction='none')
    return abs_errors.sum(dim=-1)


def mse_loss(outputs, Y):
    """Per-sample squared error between softmax probabilities and ``Y``."""
    probs = F.softmax(outputs, dim=1)
    squared_errors = F.mse_loss(probs, Y.float(), reduction='none')
    return squared_errors.sum(dim=-1)


def gce_loss(outputs, Y):
    """Per-sample generalised cross-entropy with exponent q = 0.7.

    Computes ``(1 - sum_c Y_c * p_c**q) / q`` from the softmax
    probabilities ``p``.
    """
    q = 0.7
    powered_probs = torch.pow(F.softmax(outputs, dim=1), q)
    selected = (powered_probs * Y).sum(dim=1)
    return (1 - selected) / q


def phuber_ce_loss(outputs, Y):
    """Per-sample partially Huberised cross-entropy (truncation point 0.1).

    Let ``c`` be the softmax mass a sample puts on the classes selected by
    ``Y``. Samples with ``c > 0.1`` get the ordinary cross-entropy against
    ``Y``; the rest get the linear surrogate
    ``-log(0.1) - c / 0.1 + 1``.

    Args:
        outputs: Logits; softmax is taken over dim 1.
        Y: Per-class target weights with the same shape as ``outputs``.

    Returns:
        Loss vector with one entry per sample, on the same device and with
        the same dtype as the computed confidences.
    """
    trunc_point = 0.1
    probs = F.softmax(outputs, dim=1)
    final_confidence = (probs * Y).sum(dim=1)

    # Allocate the result next to the data. A buffer created on a globally
    # guessed device with a fixed float32 dtype mismatches CPU inputs on a
    # CUDA host as well as non-float32 inputs.
    sample_loss = torch.zeros_like(final_confidence)

    ce_index = final_confidence > trunc_point
    if ce_index.any():
        log_probs = F.log_softmax(outputs[ce_index, :], dim=-1)
        sample_loss[ce_index] = -(log_probs * Y[ce_index, :]).sum(dim=-1)

    linear_index = final_confidence <= trunc_point
    if linear_index.any():
        sample_loss[linear_index] = (
            -math.log(trunc_point)
            + (-1 / trunc_point) * final_confidence[linear_index]
            + 1
        )

    return sample_loss


def cce_loss(outputs, Y):
    """Per-sample categorical cross-entropy ``-sum_c Y_c * log p_c``."""
    log_probs = F.log_softmax(outputs, dim=1)
    weighted_log_probs = log_probs * Y
    return -weighted_log_probs.sum(dim=1)


def focal_loss(outputs, Y):
    """Per-sample focal-style cross-entropy with focusing exponent 0.5.

    Each class term ``Y_c * log p_c`` is scaled by ``(1 - p_c) ** 0.5``.
    """
    log_probs = F.log_softmax(outputs, dim=1)
    probs = F.softmax(outputs, dim=1)
    modulator = (1 - probs) ** 0.5
    weighted_log_probs = log_probs * Y * modulator
    return -weighted_log_probs.sum(dim=1)


def pll_estimator(loss_fn, outputs, partialY, device):
    """Sum over the batch of the mean candidate-class loss of each sample.

    ``loss_fn`` is evaluated once per class against a one-hot target for
    that class; for every sample the losses of its candidate classes are
    summed and divided by the row sum of ``partialY``.

    Args:
        loss_fn: Callable ``(outputs, one_hot_targets) -> per-sample loss``.
        outputs: Model outputs forwarded unchanged to ``loss_fn``.
        partialY: Candidate-label matrix; its first two dims give the
            number of samples and classes.
        device: Device on which the temporary buffers are created.

    Returns:
        Scalar tensor: the summed (not averaged) batch loss.
    """
    n, k = partialY.shape[0], partialY.shape[1]
    per_class_loss = torch.zeros(n, k, device=device)

    for cls in range(k):
        one_hot = torch.zeros(n, k, device=device)
        one_hot[:, cls] = 1.0
        per_class_loss[:, cls] = loss_fn(outputs, one_hot)

    candidate_count = partialY.sum(dim=1)
    per_sample = (1.0 / candidate_count) * (per_class_loss * partialY).sum(dim=1)
    return per_sample.sum()


def cc_loss(outputs, partialY):
    """Batch-mean negative log of the softmax mass on candidate labels."""
    probs = F.softmax(outputs, dim=1)
    candidate_mass = (probs * partialY).sum(dim=1)
    return -torch.log(candidate_mass).mean()


def min_loss(outputs, partialY):
    """Batch-mean cross-entropy against the most probable candidate label.

    For every sample the candidate class with the largest softmax
    probability is selected and used as a one-hot target.
    """
    probs = F.softmax(outputs, dim=1)
    log_probs = F.log_softmax(outputs, dim=1)

    # Non-candidate classes are zeroed before taking the arg-max.
    best_candidate = torch.argmax(probs * partialY, dim=1)
    one_hot = torch.zeros_like(partialY)
    one_hot.scatter_(1, best_candidate.unsqueeze(1), 1)

    return -(log_probs * one_hot).sum() / len(one_hot)


class proden_loss:
    """Confidence-weighted log-likelihood with re-estimated label weights.

    Holds a per-sample confidence table that starts as the row-normalised
    candidate matrix and is refreshed from model outputs in ``update_conf``.
    """

    def __init__(self, train_p_Y, device):
        normalised = train_p_Y / train_p_Y.sum(dim=1, keepdim=True)
        self.conf = normalised.to(device)
        self.device = device

    def __call__(self, output1, indexes):
        """Return the batch-mean loss for logits ``output1`` of rows ``indexes``."""
        weights = self.conf[indexes].clone().detach()
        probs = F.softmax(output1, dim=1)
        weighted_log = weights * torch.log(probs + 1e-8)
        return (-torch.sum(weighted_log)) / weighted_log.size(0)

    def update_conf(self, output1, indexes):
        """Re-weight the rows ``indexes`` by the current softmax outputs.

        Every class with a positive stored weight keeps its softmax
        probability; the row is then renormalised (with a 1e-8 guard).
        """
        stored = self.conf[indexes].clone().detach()
        probs = F.softmax(output1, dim=1)
        support = stored.clone()
        support[support > 0] = 1
        rescored = support * probs
        rescored = rescored / (rescored.sum(dim=1, keepdim=True) + 1e-8)
        self.conf[indexes, :] = rescored.clone().detach()


class rc_loss:
    """Confidence-weighted cross-entropy with model-driven re-weighting.

    The confidence table starts as the row-normalised candidate matrix and
    is replaced by a freshly normalised copy on every ``update_conf`` call.
    """

    def __init__(self, train_p_Y, device):
        normalised = train_p_Y / train_p_Y.sum(dim=1, keepdim=True)
        self.conf = normalised.to(device)
        self.device = device

    def __call__(self, outputs, index):
        """Return the batch-mean weighted cross-entropy for rows ``index``."""
        log_probs = F.log_softmax(outputs, dim=1)
        weighted = log_probs * self.conf[index, :]
        per_sample = weighted.sum(dim=1)
        return -per_sample.mean()

    def update_conf(self, model, batchX, batchY, batch_index):
        """Refresh the confidences of ``batch_index`` from ``model(batchX)``.

        The batch rows are set to ``softmax * batchY``; afterwards every
        row of the table is divided by its own sum.
        """
        table = self.conf.clone().detach()
        with torch.no_grad():
            probs = F.softmax(model(batchX), dim=1)
            table[batch_index, :] = probs * batchY
            row_totals = table.sum(dim=1, keepdim=True)
            table = table / row_totals
        self.conf = table.clone().detach()


class cavl_loss:
    """Confidence-weighted cross-entropy with hard label re-assignment.

    ``update_conf`` replaces the confidences of a batch with one-hot
    vectors chosen from a score computed on the raw model outputs.
    """

    def __init__(self, train_p_Y, device):
        normalised = train_p_Y / train_p_Y.sum(dim=1, keepdim=True)
        self.conf = normalised.to(device)
        self.device = device

    def __call__(self, outputs, index):
        """Return the batch-mean weighted cross-entropy for rows ``index``."""
        log_probs = F.log_softmax(outputs, dim=1)
        weighted = log_probs * self.conf[index, :]
        per_sample = weighted.sum(dim=1)
        return -per_sample.mean()

    def update_conf(self, model, batchX, batchY, batch_index):
        """Set the rows ``batch_index`` to the one-hot top-scoring candidate.

        The score is ``o * |1 - o| * batchY`` on the raw outputs ``o`` of
        ``model(batchX)``. Returns the updated table (also stored as a
        detached copy in ``self.conf``).
        """
        table = self.conf.clone().detach()
        with torch.no_grad():
            raw = model(batchX)
            scores = (raw * torch.abs(1 - raw)) * batchY
            _, chosen = torch.max(scores, dim=1)
            table[batch_index, :] = F.one_hot(chosen, batchY.shape[1]).float()
        self.conf = table.clone().detach()
        return table


class lws_loss:
    """Leveraged weighted sigmoid loss over candidate / non-candidate labels.

    Keeps a per-sample confidence table (initially the row-normalised
    candidate matrix) that weights a sigmoid loss on candidate classes and
    a complementary sigmoid loss on non-candidate classes.
    """

    def __init__(self, train_p_Y, device, lw_weight=1, lw_weight0=1, epoch_ratio=None):
        """
        Args:
            train_p_Y: Candidate-label matrix indexed as (sample, class).
            device: Device holding the confidence table and the masks.
            lw_weight: Weight of the non-candidate term.
            lw_weight0: Weight of the candidate term.
            epoch_ratio: Stored only; not read by any method of this class.
        """
        self.conf = train_p_Y / train_p_Y.sum(dim=1, keepdim=True)
        self.conf = self.conf.to(device)
        self.device = device
        self.lw_weight = lw_weight
        self.lw_weight0 = lw_weight0
        self.epoch_ratio = epoch_ratio

    def __call__(self, outputs, partialY, index):
        """
        Args:
            outputs: Raw model outputs for the batch.
            partialY: Candidate-label matrix of the batch (entries > 0 mark
                candidates); may live on any device.
            index: Row indices of the batch in the confidence table.

        Returns:
            Scalar ``lw_weight0 * candidate_term + lw_weight * other_term``.
        """
        # Build the masks straight on the target device. Filling a CPU
        # buffer through ``partialY > 0`` only worked for CPU ``partialY``.
        candidate_mask = (partialY > 0).float().to(self.device)
        non_candidate_mask = 1 - candidate_mask

        # Only the batch rows are needed; indexing already copies them, so
        # the full table does not have to be cloned on every call.
        batch_conf = self.conf[index, :].detach()

        # Candidate labels are pushed up: sigmoid(-z) is small for large z.
        l1 = batch_conf * candidate_mask * torch.sigmoid(-outputs)
        average_loss1 = torch.sum(l1) / l1.size(0)

        # Non-candidate labels are pushed down: sigmoid(z) is small for
        # very negative z.
        l2 = batch_conf * non_candidate_mask * torch.sigmoid(outputs)
        average_loss2 = torch.sum(l2) / l2.size(0)

        return self.lw_weight0 * average_loss1 + self.lw_weight * average_loss2

    def update_conf(self, model, batchX, batchY, batch_index):
        """Refresh the confidences of ``batch_index`` from ``model(batchX)``.

        Softmax probabilities are normalised separately over candidate and
        non-candidate classes (each with a 1e-8 guard inside the sum) and
        the two parts are added. ``self.conf`` is rebound to a new tensor.
        """
        with torch.no_grad():
            sm_outputs = F.softmax(model(batchX), dim=1)

            candidate_mask = (batchY > 0).float().to(self.device)
            non_candidate_mask = 1 - candidate_mask

            new_weight1 = sm_outputs * candidate_mask
            new_weight1 = new_weight1 / (new_weight1 + 1e-8).sum(dim=1, keepdim=True)
            new_weight2 = sm_outputs * non_candidate_mask
            new_weight2 = new_weight2 / (new_weight2 + 1e-8).sum(dim=1, keepdim=True)

            confidence = self.conf.clone().detach()
            confidence[batch_index, :] = new_weight1 + new_weight2
        self.conf = confidence


class d2cnn_loss:
    """Self-teaching partial-label loss with an EMA teacher.

    Keeps CPU-side tables for an exponential moving average of the softmax
    outputs (``Z``), its bias-corrected version (``teacher``) and a
    per-sample weight that ``update`` derives from the teacher margin.
    """

    def __init__(self, train_p_Y, device=torch.device('cuda')):
        n_samples, n_classes = train_p_Y.shape
        self.train_p_Y = train_p_Y
        self.Z = torch.zeros([n_samples, n_classes])
        self.teacher = torch.zeros([n_samples, n_classes])
        self.weights = torch.zeros(n_samples)
        self.tmax = 100
        self.vt = 0.0
        self.alpha = 1e-3
        self.epoch = 1
        self.device = device

    def __call__(self, output, target, indexes):
        """Return the batch loss and refresh ``Z`` / ``teacher`` for ``indexes``.

        The loss is ``complement CE + alpha * entropy + vt * teacher CE``,
        where ``vt`` follows a ramp-up schedule driven by ``self.epoch``.
        """
        teacher_batch = torch.Tensor(self.teacher[indexes]).to(self.device)
        weights_batch = torch.Tensor(self.weights[indexes]).to(self.device)
        Z_batch = torch.Tensor(self.Z[indexes]).to(self.device)

        # Ramp-up of the self-teaching weight: off in the first epoch,
        # Gaussian-shaped growth up to epoch 200, constant afterwards.
        if self.epoch < 2:
            self.vt = 0.0
        elif self.epoch <= 200:
            ramp = (self.epoch-2)/(200-2)-1.0
            self.vt = self.tmax * (np.exp(-5.0 * np.square(ramp)))
        else:
            self.vt = self.tmax

        probs = F.softmax(output, dim=1)

        # Penalise probability mass on classes where ``target`` is zero.
        complement_ce = torch.mean(
            -torch.sum((1-target)*torch.log((1-probs) + 1e-5), dim=1))
        # Entropy of the prediction itself.
        entropy = torch.mean(
            -torch.sum(probs*torch.log(probs + 1e-10), dim=1))
        # Weighted cross-entropy against the stored teacher distribution.
        teacher_ce = torch.mean(
            -weights_batch*torch.sum((teacher_batch*torch.log(probs + 1e-5)), dim=1))
        loss = complement_ce + self.alpha*entropy + self.vt*teacher_ce

        # EMA of the predictions, then bias correction for the teacher.
        Z_batch = 0.6 * Z_batch + (1.0-0.6) * probs
        teacher_batch = torch.div(Z_batch, (1-pow(0.6, self.epoch)))
        for row, sample_idx in enumerate(indexes):
            self.Z[sample_idx,:] = Z_batch[row,:].detach().clone()
            self.teacher[sample_idx,:] = teacher_batch[row,:].detach().clone()

        return loss

    def update(self, epoch):
        """Set the epoch counter to ``epoch + 1`` and recompute the weights.

        Each weight is the squared, zero-clipped margin between the best
        teacher score among candidate labels and the best among the rest.
        """
        self.epoch = epoch + 1
        is_candidate = self.train_p_Y.numpy().astype(np.bool_)
        teacher_np = self.teacher.numpy()
        best_candidate = np.max(is_candidate.astype(np.float32) * teacher_np, axis=1)
        best_other = np.max(np.logical_not(is_candidate).astype(np.float32) * teacher_np, axis=1)
        margin = best_candidate - best_other
        margin[np.argwhere(margin<0)] = 0
        self.weights = torch.from_numpy(np.square(margin))


def exp_loss(outputs, Y):
    """Per-sample exponential loss on the candidate / non-candidate margin.

    Returns ``exp(non_candidate_mass - candidate_mass)`` where both masses
    are sums of softmax probabilities selected by ``Y`` and ``1 - Y``.
    """
    probs = F.softmax(outputs, dim=1)
    indicator = Y.float()
    candidate_mass = torch.sum(indicator * probs, dim=1)
    other_mass = torch.sum((1 - indicator) * probs, dim=1)
    return torch.exp(-candidate_mass + other_mass)


class VAE_Bernulli_Decoder(nn.Module):
    """Two-layer decoder used by VALEN: Linear -> BatchNorm -> ReLU -> Linear."""

    def __init__(self, n_inputs, n_hidden, n_outputs):
        super().__init__()

        # Hidden layer without bias, followed by batch normalisation.
        self.L1 = nn.Linear(n_inputs, n_hidden, bias=False)
        nn.init.xavier_uniform_(self.L1.weight)
        self.bn1 = nn.BatchNorm1d(n_hidden)
        nn.init.ones_(self.bn1.weight)

        # Output layer: Xavier-initialised weights, zero bias.
        self.L2 = nn.Linear(n_hidden, n_outputs)
        nn.init.xavier_uniform_(self.L2.weight)
        nn.init.zeros_(self.L2.bias)

    def forward(self, x):
        """Return the raw outputs of the final linear layer (no activation)."""
        hidden = F.relu(self.bn1(self.L1(x)))
        return self.L2(hidden)


class LinearEncoder(nn.Module):
    """Single linear layer used by VALEN as its encoder."""

    def __init__(self, n_inputs, n_outputs):
        super().__init__()
        # Weights are Xavier-initialised; the bias keeps nn.Linear's default.
        self.L1 = nn.Linear(n_inputs, n_outputs)
        nn.init.xavier_uniform_(self.L1.weight)

    def forward(self, x):
        """Apply the linear map and return the raw outputs."""
        return self.L1(x)


class valen_loss:
    # Stateful VALEN-style loss. Keeps two per-sample label-distribution
    # tables (o_array, d_array), a linear encoder and a small decoder, and
    # combines reconstruction, an alpha regulariser and two partial-label
    # terms in __call__. The tables are overwritten on every __call__.
    def __init__(self, train_p_Y, num_features, device, warm_up_epochs=10):
        """
        VALEN loss function.
        Args:
            train_p_Y: Initial candidate label matrix (N, C)
            num_features: Feature dimension (input size of the encoder and
                output size of the decoder)
            device: Compute device for the tables, encoder and decoder
            warm_up_epochs: Number of epochs run by ``warm_up``
        """
        self.device = device
        self.num_samples = train_p_Y.shape[0]
        self.num_classes = train_p_Y.shape[1]
        self.num_features = num_features
        self.warm_up_epochs = warm_up_epochs
        
        # Both tables start as the row-normalised candidate matrix.
        self.o_array = (train_p_Y / train_p_Y.sum(dim=1, keepdim=True)).to(device)
        self.d_array = self.o_array.clone().detach()
        
        # Encoder maps features -> one value per class; decoder maps a
        # class-sized vector back to feature size.
        self.encoder = LinearEncoder(num_features, self.num_classes).to(device)
        self.decoder = VAE_Bernulli_Decoder(self.num_classes, num_features, num_features).to(device)
        
        # Prior parameters: a (1, C) tensor of ones.
        self.prior_alpha = torch.ones(1, self.num_classes).to(device)
        
        # Adjacency data. A_dense is filled by compute_adj_matrix;
        # adj_matrix is never assigned anywhere else in this class.
        self.adj_matrix = None
        self.A_dense = None
        
        self.current_epoch = 0
        self.is_warmed_up = False
    
    def warm_up(self, model, train_loader, optimizer):
        """Warm-up phase: train ``model`` with the partial-label loss only.

        ``train_loader`` must yield 5-tuples
        ``(images_w, images_s, labels, true_labels, index)``; only
        ``images_w`` and ``index`` take part in the computation. After
        every step the rows ``index`` of ``o_array`` are replaced by the
        revised labels. Sets ``is_warmed_up`` when done.
        """
        print(f"==> VALEN Warm-up: {self.warm_up_epochs} epochs")
        model.train()
        
        for epoch in range(self.warm_up_epochs):
            for i, (images_w, images_s, labels, true_labels, index) in enumerate(train_loader):
                images = images_w.to(self.device)
                # NOTE(review): ``labels`` is moved to the device but not
                # read afterwards in this loop.
                labels = labels.to(self.device)
                
                outputs = model(images)
                loss, new_labels = self._partial_loss(outputs, self.o_array[index, :])
                
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                
                # Update o_array
                self.o_array[index, :] = new_labels.clone().detach()
        
        self.is_warmed_up = True
        print("==> VALEN Warm-up completed")
    
    def _partial_loss(self, output, target, eps=1e-12):
        """Partial label loss.

        Returns ``(loss, revisedY)``: the batch-mean of
        ``-sum(target * log(softmax(output) + eps))`` and a detached-
        probability re-weighting of ``target`` restricted to its positive
        entries and renormalised per row.
        """
        output_sm = F.softmax(output, dim=1)
        l = target * torch.log(output_sm + eps)
        loss = (-torch.sum(l)) / l.size(0)
        
        # Update pseudo labels: binarise the support, re-weight it with the
        # detached softmax output, renormalise each row.
        revisedY = target.clone()
        revisedY[revisedY > 0] = 1
        revisedY = revisedY * output_sm.clone().detach()
        revisedY = revisedY / (revisedY.sum(dim=1, keepdim=True) + eps)
        
        return loss, revisedY
    
    def _alpha_loss(self, alpha, prior_alpha, eps=1e-8):
        """Approximate KL term between ``alpha`` and ``prior_alpha``.

        Both inputs are clamped to [eps, 30], normalised to sum to one per
        row, and the KL divergence between the two normalised vectors is
        averaged over the batch. This is the KL of the normalised
        parameters, not the closed-form Dirichlet KL.
        """
        # Clamp for numerical stability
        alpha = torch.clamp(alpha, min=eps, max=30.0)
        prior_alpha = torch.clamp(prior_alpha.expand_as(alpha), min=eps, max=30.0)
        
        # Normalized distributions
        alpha_norm = alpha / (alpha.sum(dim=1, keepdim=True) + eps)
        prior_norm = prior_alpha / (prior_alpha.sum(dim=1, keepdim=True) + eps)
        
        # KL divergence approximation
        kl_approx = (alpha_norm * (torch.log(alpha_norm + eps) - torch.log(prior_norm + eps))).sum(dim=1)
        
        return kl_approx.mean()
    
    def _dot_product_decode(self, Z):
        """Graph decoder: ``sigmoid(Z @ Z.T)``. Not called within this class."""
        A_pred = torch.sigmoid(torch.matmul(Z, Z.t()))
        return A_pred
    
    def _revised_target(self, output, target, eps=1e-12):
        """Re-weight the positive entries of ``target`` by detached ``output``.

        Unlike ``_partial_loss`` no softmax is applied to ``output`` here.
        """
        revisedY = target.clone()
        revisedY[revisedY > 0] = 1
        revisedY = revisedY * output.clone().detach()
        revisedY = revisedY / (revisedY.sum(dim=1, keepdim=True) + eps)
        return revisedY
    
    def compute_adj_matrix(self, features, k=3):
        """Compute a k-nearest-neighbour adjacency matrix into ``A_dense``.

        Row i gets a 1.0 for each of the k samples closest to sample i in
        Euclidean distance (self excluded). The result is not symmetrised.
        If scikit-learn cannot be imported a warning is printed and the
        method returns without setting ``A_dense``.
        """
        try:
            from sklearn.metrics.pairwise import euclidean_distances
        except ImportError:
            print("Warning: sklearn not available, skipping adjacency matrix computation")
            return
        
        print(f"==> Computing adjacency matrix (k={k})")
        X = features.cpu().numpy()
        dm = euclidean_distances(X, X)
        adj_m = np.zeros_like(dm)
        row = np.arange(0, X.shape[0])
        # Exclude self-matches.
        dm[row, row] = np.inf
        
        # Pick the nearest remaining neighbour k times; each pick is masked
        # out with inf so the next round finds the next-nearest one.
        for _ in range(k):
            col = np.argmin(dm, axis=1)
            dm[row, col] = np.inf
            adj_m[row, col] = 1.0
        
        self.A_dense = torch.from_numpy(adj_m).float().to(self.device)
        print("==> Adjacency matrix computed")
    
    def __call__(self, model, images, features, targets, indexes, 
                 alpha=1.0, beta=1.0, gamma=1.0, theta=1.0, correct=1.0):
        """
        VALEN full loss function.
        Args:
            model: Main classifier
            images: Raw images (B, C, H, W) - for classifier
            features: Flattened feature vectors (B, D) - for encoder
            targets: Candidate labels
            indexes: Sample indexes
            alpha, beta, gamma, theta: Weights of the reconstruction,
                alpha-regulariser, encoder partial-label and classifier
                partial-label terms, in that order
            correct: Label correction coefficient (blend between the
                revised distribution and the stored ``o_array`` rows)
        Returns:
            Scalar total loss. As a side effect the rows ``indexes`` of
            ``d_array`` and ``o_array`` are overwritten.
        """
        # NOTE(review): batch_size is computed but not read below.
        batch_size = features.size(0)
        
        # Classifier output (using raw images)
        outputs = model(images)
        
        # Encoder output for alpha parameters (using features)
        alpha_params = self.encoder(features)
        s_alpha = F.softmax(alpha_params, dim=1)
        
        # Restrict alpha to candidate labels (positive entries of o_array),
        # renormalise with a detached row sum, then shift by 1e-2 so every
        # concentration passed to the Dirichlet is strictly positive.
        revised_alpha = torch.zeros_like(targets)
        revised_alpha[self.o_array[indexes, :] > 0] = 1.0
        s_alpha = s_alpha * revised_alpha
        s_alpha_sum = s_alpha.clone().detach().sum(dim=1, keepdim=True)
        s_alpha = s_alpha / (s_alpha_sum + 1e-2) + 1e-2
        
        # Alpha loss (for Encoder)
        L_d, new_d = self._partial_loss(alpha_params, self.o_array[indexes, :])
        
        # Process alpha parameters: exp(x / 4) limited to [1e-2, 30].
        alpha_exp = torch.exp(alpha_params / 4)
        alpha_exp = F.hardtanh(alpha_exp, min_val=1e-2, max_val=30)
        L_alpha = self._alpha_loss(alpha_exp, self.prior_alpha)
        
        # Dirichlet sampling (reparameterised, so gradients reach s_alpha)
        dirichlet_sampler = torch.distributions.dirichlet.Dirichlet(s_alpha)
        d = dirichlet_sampler.rsample()
        
        # Reconstruction losses
        # 1. Feature reconstruction (using features)
        x_hat = self.decoder(d)
        x_hat = x_hat.view(features.shape)
        L_recx = 0.01 * F.mse_loss(x_hat, features)
        
        # 2. Label reconstruction
        # NOTE(review): ``d`` is a Dirichlet sample (values in [0, 1]) but
        # is passed to a function that expects logits -- confirm intended.
        L_recy = 0.01 * F.binary_cross_entropy_with_logits(d, targets)
        
        # 3. Adjacency matrix reconstruction (disabled to avoid indexing
        #    issues); A_dense is therefore not used here.
        L_recA = torch.tensor(0.0).to(self.device)
        
        L_rec = L_recx + L_recy + L_recA
        
        # Classifier loss (using classifier output) against d_array
        L_o, new_o = self._partial_loss(outputs, self.d_array[indexes, :])
        
        # Total loss
        total_loss = alpha * L_rec + beta * L_alpha + gamma * L_d + theta * L_o
        
        # Update label distributions: d_array gets the sample-reweighted
        # encoder labels blended with the old o_array rows (read before
        # o_array itself is overwritten on the last line).
        new_d = self._revised_target(d, new_d)
        new_d = correct * new_d + (1 - correct) * self.o_array[indexes, :]
        self.d_array[indexes, :] = new_d.clone().detach()
        self.o_array[indexes, :] = new_o.clone().detach()
        
        return total_loss
    
    def update_epoch(self):
        """Increment the internal epoch counter by one."""
        self.current_epoch += 1
    
    def get_encoder_params(self):
        """Return the encoder's parameter iterator (for an optimizer)."""
        return self.encoder.parameters()
    
    def get_decoder_params(self):
        """Return the decoder's parameter iterator (for an optimizer)."""
        return self.decoder.parameters()

class EMAConfidence:
    """Exponential-moving-average store for per-sample confidences."""

    def __init__(self, init_conf, momentum=0.9):
        """
        Args:
            init_conf: Initial confidence tensor; it is cloned, so the
                caller's tensor is never modified.
            momentum: Weight kept from the stored value on each update.
        """
        # Keep the GPU placement when CUDA exists, but do not crash on
        # CPU-only hosts as an unconditional .cuda() does.
        target_device = 'cuda' if torch.cuda.is_available() else init_conf.device
        self.conf = init_conf.clone().to(target_device)
        self.m = momentum

    @torch.no_grad()
    def update(self, index, new_conf):
        """Blend ``new_conf`` into the rows ``index``: m * old + (1 - m) * new."""
        self.conf[index] = self.m * self.conf[index] + (1 - self.m) * new_conf

class ALIMPartialLoss(torch.nn.Module):
    """Batch-mean cross-entropy between softmax outputs and soft targets."""

    def __init__(self):
        super().__init__()

    def forward(self, logits, soft_target):
        """Return ``mean_i( -sum_c soft_target[i, c] * log(p[i, c] + 1e-8) )``."""
        probs = torch.softmax(logits, dim=1)
        log_probs = torch.log(probs + 1e-8)
        per_sample = (soft_target * log_probs).sum(dim=1)
        return -per_sample.mean()